flags_rs/
lib.rs

1use std::collections::HashMap;
2use std::env;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use tokio::sync::RwLock;
6use std::time::Duration;
7
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use reqwest::header::{HeaderMap, HeaderValue};
11use serde::Deserialize;
12use thiserror::Error;
13
14pub mod cache;
15pub mod flag;
16mod tests;
17
18#[cfg(feature = "tower-middleware")]
19pub mod middleware;
20
21#[cfg(all(test, feature = "tower-middleware"))]
22mod middleware_tests;
23
24use crate::cache::{Cache, MemoryCache};
25use crate::flag::{Details, FeatureFlag};
26
/// Default root of the flags.gg API; override with `ClientBuilder::with_base_url`.
const BASE_URL: &str = "https://api.flags.gg";
/// Default number of consecutive failed refreshes after which the circuit breaker opens.
const MAX_RETRIES: u32 = 3;
29
/// Credentials identifying the caller to the flags API.
///
/// Each identifier is sent as a request header when flags are fetched.
#[derive(Clone, Debug)]
pub struct Auth {
    /// Project identifier, sent as the `X-Project-ID` header.
    pub project_id: String,
    /// Agent identifier, sent as the `X-Agent-ID` header.
    pub agent_id: String,
    /// Environment identifier, sent as the `X-Environment-ID` header.
    pub environment_id: String,
}
36
37pub struct Flag<'a> {
38    name: String,
39    client: &'a Client,
40}
41
42#[derive(Debug, Error)]
43pub enum FlagError {
44    #[error("HTTP error: {0}")]
45    HttpError(#[from] reqwest::Error),
46
47    #[error("Cache error: {0}")]
48    CacheError(String),
49
50    #[error("Missing authentication: {0}")]
51    AuthError(String),
52
53    #[error("API error: {0}")]
54    ApiError(String),
55    
56    #[error("Builder error: {0}")]
57    BuilderError(String),
58}
59
60#[derive(Debug)]
61struct CircuitState {
62    is_open: bool,
63    failure_count: u32,
64    last_failure: Option<DateTime<Utc>>,
65}
66
67#[derive(Debug, Deserialize)]
68struct ApiResponse {
69    #[serde(rename = "intervalAllowed")]
70    interval_allowed: i32,
71    flags: Vec<flag::FeatureFlag>,
72}
73
74pub type ErrorCallback = Arc<dyn Fn(&FlagError) + Send + Sync>;
75
76pub struct Client {
77    base_url: String,
78    http_client: reqwest::Client,
79    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
80    max_retries: u32,
81    circuit_state: Arc<RwLock<CircuitState>>,
82    auth: Option<Auth>,
83    refresh_in_progress: Arc<AtomicBool>,
84    error_callback: Option<ErrorCallback>,
85}
86
87impl Client {
88    pub fn builder() -> ClientBuilder {
89        ClientBuilder::new()
90    }
91    
92    fn handle_error(&self, error: &FlagError) {
93        if let Some(ref callback) = self.error_callback {
94            callback(error);
95        }
96    }
97
98    pub fn debug_info(&self) -> String {
99        format!(
100            "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
101            self.base_url, self.max_retries, self.auth
102        )
103    }
104
105    pub fn is(&self, name: &str) -> Flag {
106        Flag {
107            name: name.to_string(),
108            client: self,
109        }
110    }
111    
112    /// Get the enabled status of multiple flags at once.
113    /// This is more efficient than checking flags individually as it only
114    /// requires a single cache lock and potential refresh.
115    /// 
116    /// # Example
117    /// ```no_run
118    /// # use flags_rs::Client;
119    /// # async fn example(client: &Client) {
120    /// let flags = client.get_multiple(&["feature-1", "feature-2", "feature-3"]).await;
121    /// for (name, enabled) in flags {
122    ///     println!("{}: {}", name, enabled);
123    /// }
124    /// # }
125    /// ```
126    pub async fn get_multiple(&self, names: &[&str]) -> HashMap<String, bool> {
127        // Ensure cache is refreshed if needed (only once for all flags)
128        if self.cache.read().await.should_refresh_cache().await {
129            if self.refresh_in_progress.compare_exchange(
130                false, 
131                true, 
132                Ordering::SeqCst, 
133                Ordering::SeqCst
134            ).is_ok() {
135                if let Err(e) = self.refetch().await {
136                    error!("Failed to refetch flags for batch operation: {}", e);
137                    self.handle_error(&e);
138                }
139                self.refresh_in_progress.store(false, Ordering::SeqCst);
140            }
141        }
142
143        // Now get all flags with a single cache lock
144        let cache = self.cache.read().await;
145        let mut results = HashMap::with_capacity(names.len());
146        
147        for &name in names {
148            let normalized_name = name.to_lowercase();
149            match cache.get(&normalized_name).await {
150                Ok((enabled, exists)) => {
151                    results.insert(name.to_string(), exists && enabled);
152                }
153                Err(_) => {
154                    results.insert(name.to_string(), false);
155                }
156            }
157        }
158        
159        results
160    }
161    
162    /// Check if all of the specified flags are enabled.
163    /// 
164    /// # Example
165    /// ```no_run
166    /// # use flags_rs::Client;
167    /// # async fn example(client: &Client) {
168    /// if client.all_enabled(&["feature-1", "feature-2"]).await {
169    ///     // Both features are enabled
170    /// }
171    /// # }
172    /// ```
173    pub async fn all_enabled(&self, names: &[&str]) -> bool {
174        if names.is_empty() {
175            return true;
176        }
177        
178        let flags = self.get_multiple(names).await;
179        names.iter().all(|&name| flags.get(name).copied().unwrap_or(false))
180    }
181    
182    /// Check if any of the specified flags are enabled.
183    /// 
184    /// # Example
185    /// ```no_run
186    /// # use flags_rs::Client;
187    /// # async fn example(client: &Client) {
188    /// if client.any_enabled(&["premium-feature", "beta-feature"]).await {
189    ///     // At least one feature is enabled
190    /// }
191    /// # }
192    /// ```
193    pub async fn any_enabled(&self, names: &[&str]) -> bool {
194        if names.is_empty() {
195            return false;
196        }
197        
198        let flags = self.get_multiple(names).await;
199        names.iter().any(|&name| flags.get(name).copied().unwrap_or(false))
200    }
201
202    pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
203        // Check if cache needs refresh and ensure only one refresh happens
204        if self.cache.read().await.should_refresh_cache().await {
205            // Try to acquire the refresh lock
206            if self.refresh_in_progress.compare_exchange(
207                false, 
208                true, 
209                Ordering::SeqCst, 
210                Ordering::SeqCst
211            ).is_ok() {
212                // We got the lock, perform the refresh
213                if let Err(e) = self.refetch().await {
214                    error!("Failed to refetch flags for list: {}", e);
215                    self.handle_error(&e);
216                }
217                // Release the refresh lock
218                self.refresh_in_progress.store(false, Ordering::SeqCst);
219            }
220            // If we didn't get the lock, another thread is refreshing
221        }
222
223        let cache = self.cache.read().await;
224        cache.get_all().await
225            .map_err(|e| FlagError::CacheError(e.to_string()))
226    }
227
228    async fn is_enabled(&self, name: &str) -> bool {
229        let name = name.to_lowercase();
230
231        // Check if cache needs refresh and ensure only one refresh happens
232        if self.cache.read().await.should_refresh_cache().await {
233            // Try to acquire the refresh lock
234            if self.refresh_in_progress.compare_exchange(
235                false, 
236                true, 
237                Ordering::SeqCst, 
238                Ordering::SeqCst
239            ).is_ok() {
240                // We got the lock, perform the refresh
241                if let Err(e) = self.refetch().await {
242                    error!("Failed to refetch flags: {}", e);
243                    self.handle_error(&e);
244                }
245                // Release the refresh lock
246                self.refresh_in_progress.store(false, Ordering::SeqCst);
247            }
248            // If we didn't get the lock, another thread is refreshing
249        }
250
251        // Check cache (which now contains combined API and local flags with overrides)
252        let cache = self.cache.read().await;
253        match cache.get(&name).await {
254            Ok((enabled, exists)) => {
255                if exists {
256                    enabled
257                } else {
258                    false
259                }
260            }
261            Err(_) => false, // Treat cache errors as flag not found
262        }
263    }
264
265    async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
266        let auth = match &self.auth {
267            Some(auth) => auth,
268            None => return Err(FlagError::AuthError("Authentication is required".to_string())),
269        };
270
271        let mut headers = HeaderMap::new();
272        headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
273        headers.insert("Accept", HeaderValue::from_static("application/json"));
274        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
275        headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id)
276            .map_err(|_| FlagError::AuthError(format!("Invalid project ID: {}", auth.project_id)))?);
277        headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id)
278            .map_err(|_| FlagError::AuthError(format!("Invalid agent ID: {}", auth.agent_id)))?);
279        headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id)
280            .map_err(|_| FlagError::AuthError(format!("Invalid environment ID: {}", auth.environment_id)))?);
281
282        let url = format!("{}/flags", self.base_url);
283        let response = self.http_client
284            .get(&url)
285            .headers(headers)
286            .send()
287            .await?;
288
289        if !response.status().is_success() {
290            return Err(FlagError::ApiError(format!(
291                "Unexpected status code: {}",
292                response.status()
293            )));
294        }
295
296        let api_resp = response.json::<ApiResponse>().await?;
297        Ok(api_resp)
298    }
299
300    async fn refetch(&self) -> Result<(), FlagError> {
301        let mut circuit_state = self.circuit_state.write().await;
302
303        if circuit_state.is_open {
304            if let Some(last_failure) = circuit_state.last_failure {
305                let now = Utc::now();
306                // Keep the circuit open for a bit after failure
307                if (now - last_failure).num_seconds() < 10 { // You can adjust this duration
308                    warn!("Circuit breaker is open, skipping refetch.");
309                    return Ok(());
310                }
311            }
312            // If enough time has passed, attempt to close the circuit
313            warn!("Attempting to close circuit breaker.");
314            circuit_state.is_open = false;
315            circuit_state.failure_count = 0;
316        }
317        drop(circuit_state); // Release the write lock
318
319        let api_resp = match self.fetch_flags().await {
320            Ok(resp) => {
321                let mut circuit_state = self.circuit_state.write().await;
322                circuit_state.failure_count = 0; // Reset failure count on success
323                resp
324            }
325            Err(e) => {
326                let mut circuit_state = self.circuit_state.write().await;
327                circuit_state.failure_count += 1;
328                circuit_state.last_failure = Some(Utc::now());
329                if circuit_state.failure_count >= self.max_retries {
330                    circuit_state.is_open = true;
331                    error!("Refetch failed after {} retries, opening circuit breaker: {}", self.max_retries, e);
332                } else {
333                    warn!("Refetch failed (attempt {}/{}), retrying: {}", circuit_state.failure_count, self.max_retries, e);
334                }
335                self.handle_error(&e);
336                drop(circuit_state); // Release the write lock
337                // If fetching fails, we should still attempt to use local flags and potentially old cache data
338                let local_flags = build_local(); // Build local flags even on API failure
339                let mut cache = self.cache.write().await;
340                // Attempt to refresh cache with only local flags if API failed
341                cache.refresh(&local_flags, 60).await // Use a default interval if API interval is not available
342                    .map_err(|e| FlagError::CacheError(e.to_string()))?;
343                return Err(e); // Propagate the error
344            }
345        };
346
347        let mut api_flags: Vec<flag::FeatureFlag> = api_resp.flags
348            .into_iter()
349            .map(|f| flag::FeatureFlag {
350                enabled: f.enabled,
351                details: flag::Details {
352                    name: f.details.name.to_lowercase(),
353                    id: f.details.id,
354                },
355            })
356            .collect();
357
358        let local_flags = build_local();
359
360        // Combine API flags and local flags, with local overriding API
361        let mut combined_flags = Vec::new();
362        let mut local_flags_map: HashMap<String, FeatureFlag> = local_flags.into_iter().map(|f| (f.details.name.clone(), f)).collect();
363
364        for api_flag in api_flags.drain(..) {
365            if let Some(local_flag) = local_flags_map.remove(&api_flag.details.name) {
366                // Local flag with the same name exists, use the local one
367                combined_flags.push(local_flag);
368            } else {
369                // No local flag with the same name, use the API one
370                combined_flags.push(api_flag);
371            }
372        }
373
374        // Add any remaining local flags that didn't have a corresponding API flag
375        combined_flags.extend(local_flags_map.into_values());
376
377
378        let mut cache = self.cache.write().await;
379        cache.refresh(&combined_flags, api_resp.interval_allowed).await
380            .map_err(|e| FlagError::CacheError(e.to_string()))?;
381
382        Ok(())
383    }
384}
385
386impl Clone for Client {
387    fn clone(&self) -> Self {
388        Client {
389            base_url: self.base_url.clone(),
390            http_client: self.http_client.clone(),
391            cache: Arc::clone(&self.cache),
392            max_retries: self.max_retries,
393            circuit_state: Arc::clone(&self.circuit_state),
394            auth: self.auth.clone(),
395            refresh_in_progress: Arc::clone(&self.refresh_in_progress),
396            error_callback: self.error_callback.clone(),
397        }
398    }
399}
400
401impl<'a> Flag<'a> {
402    pub async fn enabled(&self) -> bool {
403        self.client.is_enabled(&self.name).await
404    }
405}
406
407pub struct ClientBuilder {
408    base_url: String,
409    max_retries: u32,
410    auth: Option<Auth>,
411    use_memory_cache: bool,
412    file_name: Option<String>,
413    error_callback: Option<ErrorCallback>,
414}
415
416impl ClientBuilder {
417    fn new() -> Self {
418        Self {
419            base_url: BASE_URL.to_string(),
420            max_retries: MAX_RETRIES,
421            auth: None,
422            use_memory_cache: false,
423            file_name: None,
424            error_callback: None,
425        }
426    }
427    
428    /// Set a callback function that will be called whenever an error occurs.
429    /// This is useful for logging, monitoring, or custom error handling.
430    /// 
431    /// # Example
432    /// ```no_run
433    /// # use flags_rs::{Client, FlagError};
434    /// let client = Client::builder()
435    ///     .with_error_callback(|error| {
436    ///         eprintln!("Flag error occurred: {}", error);
437    ///         // Send to monitoring service, etc.
438    ///     })
439    ///     .build();
440    /// ```
441    pub fn with_error_callback<F>(mut self, callback: F) -> Self 
442    where
443        F: Fn(&FlagError) + Send + Sync + 'static,
444    {
445        self.error_callback = Some(Arc::new(callback));
446        self
447    }
448
449    pub fn with_base_url(mut self, base_url: &str) -> Self {
450        self.base_url = base_url.to_string();
451        self
452    }
453
454    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
455        self.max_retries = max_retries;
456        self
457    }
458
459    pub fn with_auth(mut self, auth: Auth) -> Self {
460        self.auth = Some(auth);
461        self
462    }
463
464    pub fn with_file_name(mut self, file_name: &str) -> Self {
465        self.file_name = Some(file_name.to_string());
466        self
467    }
468
469    pub fn with_memory_cache(mut self) -> Self {
470        self.use_memory_cache = true;
471        self
472    }
473
474    pub fn build(self) -> Result<Client, FlagError> {
475        // Validate auth if provided
476        if let Some(ref auth) = self.auth {
477            if auth.project_id.trim().is_empty() {
478                return Err(FlagError::BuilderError("Project ID cannot be empty".to_string()));
479            }
480            if auth.agent_id.trim().is_empty() {
481                return Err(FlagError::BuilderError("Agent ID cannot be empty".to_string()));
482            }
483            if auth.environment_id.trim().is_empty() {
484                return Err(FlagError::BuilderError("Environment ID cannot be empty".to_string()));
485            }
486        }
487
488        // Validate base URL
489        if self.base_url.trim().is_empty() {
490            return Err(FlagError::BuilderError("Base URL cannot be empty".to_string()));
491        }
492
493        // Validate max retries is reasonable
494        if self.max_retries > 10 {
495            return Err(FlagError::BuilderError("Max retries cannot exceed 10".to_string()));
496        }
497
498        let cache: Box<dyn Cache + Send + Sync> = Box::new(MemoryCache::new());
499
500        let http_client = reqwest::Client::builder()
501            .timeout(Duration::from_secs(10))
502            .build()
503            .map_err(|e| FlagError::BuilderError(format!("Failed to build HTTP client: {}", e)))?;
504
505        Ok(Client {
506            base_url: self.base_url,
507            http_client,
508            cache: Arc::new(RwLock::new(cache)),
509            max_retries: self.max_retries,
510            circuit_state: Arc::new(RwLock::new(CircuitState {
511                is_open: false,
512                failure_count: 0,
513                last_failure: None,
514            })),
515            auth: self.auth,
516            refresh_in_progress: Arc::new(AtomicBool::new(false)),
517            error_callback: self.error_callback,
518        })
519    }
520}
521
522fn build_local() -> Vec<FeatureFlag> {
523    let mut result = Vec::new();
524
525    for (key, value) in env::vars() {
526        if !key.starts_with("FLAGS_") {
527            continue;
528        }
529
530        let enabled = value == "true";
531        let flag_name_env = key.trim_start_matches("FLAGS_").to_string();
532        let flag_name_lower = flag_name_env.to_lowercase();
533
534        // Create a FeatureFlag for the flag name as it appears in the environment variable (lowercase)
535        result.push(FeatureFlag {
536            enabled,
537            details: Details {
538                name: flag_name_lower.clone(),
539                id: format!("local_{}", flag_name_lower), // Using a simple identifier for local flags
540            },
541        });
542
543        // Optionally, create FeatureFlags for common variations (hyphens and spaces)
544        if flag_name_lower.contains('_') {
545            let flag_name_hyphenated = flag_name_lower.replace('_', "-");
546            result.push(FeatureFlag {
547                enabled,
548                details: Details {
549                    name: flag_name_hyphenated.clone(),
550                    id: format!("local_{}", flag_name_hyphenated),
551                },
552            });
553        }
554
555        if flag_name_lower.contains('_') || flag_name_lower.contains('-') {
556            let flag_name_spaced = flag_name_lower.replace('_', " ").replace('-', " ");
557            result.push(FeatureFlag {
558                enabled,
559                details: Details {
560                    name: flag_name_spaced.clone(),
561                    id: format!("local_{}", flag_name_spaced),
562                },
563            });
564        }
565
566    }
567
568    result
569}
570