flags_rs/
lib.rs

1use std::collections::HashMap;
2use std::env;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use log::__private_api::loc;
11use reqwest::header::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15pub mod cache;
16pub mod flag;
17mod tests;
18
19use crate::cache::{Cache, CacheSystem, MemoryCache};
20use crate::flag::{Details, FeatureFlag};
21
/// Default number of consecutive failed fetches after which the circuit breaker opens.
const MAX_RETRIES: u32 = 3;
/// Default root URL of the flags API.
const BASE_URL: &str = "https://api.flags.gg";
24
/// Credentials identifying the caller to the flags API.
///
/// Every ID is sent as a request header and must be non-empty for a fetch to be attempted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Auth {
    /// Project identifier, sent as the `X-Project-ID` header.
    pub project_id: String,
    /// Agent identifier, sent as the `X-Agent-ID` header.
    pub agent_id: String,
    /// Environment identifier, sent as the `X-Environment-ID` header.
    pub environment_id: String,
}
31
32pub struct Flag<'a> {
33    name: String,
34    client: &'a Client,
35}
36
37#[derive(Debug, Error)]
38pub enum FlagError {
39    #[error("HTTP error: {0}")]
40    HttpError(#[from] reqwest::Error),
41
42    #[error("Cache error: {0}")]
43    CacheError(String),
44
45    #[error("Missing authentication: {0}")]
46    AuthError(String),
47
48    #[error("API error: {0}")]
49    ApiError(String),
50}
51
52#[derive(Debug)]
53struct CircuitState {
54    is_open: bool,
55    failure_count: u32,
56    last_failure: Option<DateTime<Utc>>,
57}
58
59#[derive(Debug, Deserialize)]
60struct ApiResponse {
61    #[serde(rename = "intervalAllowed")]
62    interval_allowed: i32,
63    flags: Vec<flag::FeatureFlag>,
64}
65
66pub struct Client {
67    base_url: String,
68    http_client: reqwest::Client,
69    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
70    max_retries: u32,
71    circuit_state: RwLock<CircuitState>,
72    auth: Option<Auth>,
73}
74
75impl Client {
76    pub fn builder() -> ClientBuilder {
77        ClientBuilder::new()
78    }
79
80    pub fn debug_info(&self) -> String {
81        format!(
82            "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
83            self.base_url, self.max_retries, self.auth
84        )
85    }
86
87    pub fn is(&self, name: &str) -> Flag {
88        Flag {
89            name: name.to_string(),
90            client: self,
91        }
92    }
93
94    pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
95        // Check if cache needs refresh before listing
96        {
97            let cache = self.cache.read().await;
98            if cache.should_refresh_cache().await {
99                drop(cache); // Release the read lock before acquiring write lock
100                if let Err(e) = self.refetch().await {
101                    error!("Failed to refetch flags for list: {}", e);
102                    // Continue with potentially stale cache data if refetch fails
103                }
104            }
105        }
106
107        let cache = self.cache.read().await;
108        cache.get_all().await
109            .map_err(|e| FlagError::CacheError(e.to_string()))
110    }
111
112    async fn is_enabled(&self, name: &str) -> bool {
113        let name = name.to_lowercase();
114
115        // Check if cache needs refresh
116        {
117            let cache = self.cache.read().await;
118            if cache.should_refresh_cache().await {
119                drop(cache); // Release the read lock before acquiring write lock
120                if let Err(e) = self.refetch().await {
121                    error!("Failed to refetch flags: {}", e);
122                    // If refetch fails, continue to check the potentially stale cache.
123                }
124            }
125        }
126
127        // Check cache (which now contains combined API and local flags with overrides)
128        let cache = self.cache.read().await;
129        match cache.get(&name).await {
130            Ok((enabled, exists)) => {
131                if exists {
132                    enabled
133                } else {
134                    false
135                }
136            }
137            Err(_) => false, // Treat cache errors as flag not found
138        }
139    }
140
141    async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
142        let auth = match &self.auth {
143            Some(auth) => auth,
144            None => return Err(FlagError::AuthError("Authentication is required".to_string())),
145        };
146
147        if auth.project_id.is_empty() {
148            return Err(FlagError::AuthError("Project ID is required".to_string()));
149        }
150        if auth.agent_id.is_empty() {
151            return Err(FlagError::AuthError("Agent ID is required".to_string()));
152        }
153        if auth.environment_id.is_empty() {
154            return Err(FlagError::AuthError("Environment ID is required".to_string()));
155        }
156
157        let mut headers = HeaderMap::new();
158        headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
159        headers.insert("Accept", HeaderValue::from_static("application/json"));
160        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
161        headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
162        headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
163        headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
164
165        let url = format!("{}/flags", self.base_url);
166        let response = self.http_client
167            .get(&url)
168            .headers(headers)
169            .send()
170            .await?;
171
172        if !response.status().is_success() {
173            return Err(FlagError::ApiError(format!(
174                "Unexpected status code: {}",
175                response.status()
176            )));
177        }
178
179        let api_resp = response.json::<ApiResponse>().await?;
180        Ok(api_resp)
181    }
182
183    async fn refetch(&self) -> Result<(), FlagError> {
184        let mut circuit_state = self.circuit_state.write().await;
185
186        if circuit_state.is_open {
187            if let Some(last_failure) = circuit_state.last_failure {
188                let now = Utc::now();
189                // Keep the circuit open for a bit after failure
190                if (now - last_failure).num_seconds() < 10 { // You can adjust this duration
191                    warn!("Circuit breaker is open, skipping refetch.");
192                    return Ok(());
193                }
194            }
195            // If enough time has passed, attempt to close the circuit
196            warn!("Attempting to close circuit breaker.");
197            circuit_state.is_open = false;
198            circuit_state.failure_count = 0;
199        }
200        drop(circuit_state); // Release the write lock
201
202        let api_resp = match self.fetch_flags().await {
203            Ok(resp) => {
204                let mut circuit_state = self.circuit_state.write().await;
205                circuit_state.failure_count = 0; // Reset failure count on success
206                resp
207            }
208            Err(e) => {
209                let mut circuit_state = self.circuit_state.write().await;
210                circuit_state.failure_count += 1;
211                circuit_state.last_failure = Some(Utc::now());
212                if circuit_state.failure_count >= self.max_retries {
213                    circuit_state.is_open = true;
214                    error!("Refetch failed after {} retries, opening circuit breaker: {}", self.max_retries, e);
215                } else {
216                    warn!("Refetch failed (attempt {}/{}), retrying: {}", circuit_state.failure_count, self.max_retries, e);
217                }
218                drop(circuit_state); // Release the write lock
219                // If fetching fails, we should still attempt to use local flags and potentially old cache data
220                let local_flags = build_local(); // Build local flags even on API failure
221                let mut cache = self.cache.write().await;
222                // Attempt to refresh cache with only local flags if API failed
223                cache.refresh(&local_flags, 60).await // Use a default interval if API interval is not available
224                    .map_err(|e| FlagError::CacheError(e.to_string()))?;
225                return Err(e); // Propagate the error
226            }
227        };
228
229        let mut api_flags: Vec<flag::FeatureFlag> = api_resp.flags
230            .into_iter()
231            .map(|f| flag::FeatureFlag {
232                enabled: f.enabled,
233                details: flag::Details {
234                    name: f.details.name.to_lowercase(),
235                    id: f.details.id,
236                },
237            })
238            .collect();
239
240        let local_flags = build_local();
241
242        // Combine API flags and local flags, with local overriding API
243        let mut combined_flags = Vec::new();
244        let mut local_flags_map: HashMap<String, FeatureFlag> = local_flags.into_iter().map(|f| (f.details.name.clone(), f)).collect();
245
246        for api_flag in api_flags.drain(..) {
247            if let Some(local_flag) = local_flags_map.remove(&api_flag.details.name) {
248                // Local flag with the same name exists, use the local one
249                combined_flags.push(local_flag);
250            } else {
251                // No local flag with the same name, use the API one
252                combined_flags.push(api_flag);
253            }
254        }
255
256        // Add any remaining local flags that didn't have a corresponding API flag
257        combined_flags.extend(local_flags_map.into_values());
258
259
260        let mut cache = self.cache.write().await;
261        cache.refresh(&combined_flags, api_resp.interval_allowed).await
262            .map_err(|e| FlagError::CacheError(e.to_string()))?;
263
264        Ok(())
265    }
266}
267
268impl Clone for Client {
269    fn clone(&self) -> Self {
270        Client {
271            base_url: self.base_url.clone(),
272            http_client: self.http_client.clone(),
273            cache: Arc::clone(&self.cache),
274            max_retries: self.max_retries,
275            circuit_state: RwLock::new(CircuitState{
276                is_open: self.circuit_state.blocking_read().is_open,
277                failure_count: self.circuit_state.blocking_read().failure_count,
278                last_failure: self.circuit_state.blocking_read().last_failure,
279            }),
280            auth: self.auth.clone(),
281        }
282    }
283}
284
285impl<'a> Flag<'a> {
286    pub async fn enabled(&self) -> bool {
287        self.client.is_enabled(&self.name).await
288    }
289}
290
291pub struct ClientBuilder {
292    base_url: String,
293    max_retries: u32,
294    auth: Option<Auth>,
295    use_memory_cache: bool,
296    file_name: Option<String>,
297}
298
299impl ClientBuilder {
300    fn new() -> Self {
301        Self {
302            base_url: BASE_URL.to_string(),
303            max_retries: MAX_RETRIES,
304            auth: None,
305            use_memory_cache: false,
306            file_name: None,
307        }
308    }
309
310    pub fn with_base_url(mut self, base_url: &str) -> Self {
311        self.base_url = base_url.to_string();
312        self
313    }
314
315    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
316        self.max_retries = max_retries;
317        self
318    }
319
320    pub fn with_auth(mut self, auth: Auth) -> Self {
321        self.auth = Some(auth);
322        self
323    }
324
325    pub fn with_file_name(mut self, file_name: &str) -> Self {
326        self.file_name = Some(file_name.to_string());
327        self
328    }
329
330    pub fn with_memory_cache(mut self) -> Self {
331        self.use_memory_cache = true;
332        self
333    }
334
335    pub fn build(self) -> Client {
336        let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
337            Box::new(MemoryCache::new())
338        } else {
339            #[cfg(feature = "rusqlite")]
340            {
341                if let Some(file_name) = self.file_name {
342                    Box::new(cache::SqliteCache::new(&file_name))
343                } else {
344                    Box::new(MemoryCache::new())
345                }
346            }
347            #[cfg(not(feature = "rusqlite"))]
348            {
349                Box::new(MemoryCache::new())
350            }
351        };
352
353        Client {
354            base_url: self.base_url,
355            http_client: reqwest::Client::builder()
356                .timeout(Duration::from_secs(10))
357                .build()
358                .unwrap(),
359            cache: Arc::new(RwLock::new(cache)),
360            max_retries: self.max_retries,
361            circuit_state: RwLock::new(CircuitState {
362                is_open: false,
363                failure_count: 0,
364                last_failure: None,
365            }),
366            auth: self.auth,
367        }
368    }
369}
370
371fn build_local() -> Vec<FeatureFlag> {
372    let mut result = Vec::new();
373
374    for (key, value) in env::vars() {
375        if !key.starts_with("FLAGS_") {
376            continue;
377        }
378
379        let enabled = value == "true";
380        let flag_name_env = key.trim_start_matches("FLAGS_").to_string();
381        let flag_name_lower = flag_name_env.to_lowercase();
382
383        // Create a FeatureFlag for the flag name as it appears in the environment variable (lowercase)
384        result.push(FeatureFlag {
385            enabled,
386            details: Details {
387                name: flag_name_lower.clone(),
388                id: format!("local_{}", flag_name_lower), // Using a simple identifier for local flags
389            },
390        });
391
392        // Optionally, create FeatureFlags for common variations (hyphens and spaces)
393        if flag_name_lower.contains('_') {
394            let flag_name_hyphenated = flag_name_lower.replace('_', "-");
395            result.push(FeatureFlag {
396                enabled,
397                details: Details {
398                    name: flag_name_hyphenated.clone(),
399                    id: format!("local_{}", flag_name_hyphenated),
400                },
401            });
402        }
403
404        if flag_name_lower.contains('_') || flag_name_lower.contains('-') {
405            let flag_name_spaced = flag_name_lower.replace('_', " ").replace('-', " ");
406            result.push(FeatureFlag {
407                enabled,
408                details: Details {
409                    name: flag_name_spaced.clone(),
410                    id: format!("local_{}", flag_name_spaced),
411                },
412            });
413        }
414
415    }
416
417    result
418}
419