flags_rs/
lib.rs

1// src/lib.rs
2use std::collections::HashMap;
3use std::env;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use log::__private_api::loc;
11use reqwest::header::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15pub mod cache;
16pub mod flag;
17mod tests;
18
19use crate::cache::{Cache, CacheSystem, MemoryCache};
20use crate::flag::{Details, FeatureFlag};
21
/// Default flags.gg API endpoint; `ClientBuilder::with_base_url` overrides it.
const BASE_URL: &str = "https://api.flags.gg";
/// Default number of consecutive failed fetches after which the circuit breaker opens
/// (`ClientBuilder::with_max_retries` overrides it).
const MAX_RETRIES: u32 = 3;
24
/// Credentials that select which flags.gg project, agent and environment to query.
///
/// Every value is sent verbatim as a request header and must be non-empty, otherwise
/// fetching flags fails with `FlagError::AuthError`.
#[derive(Debug, Clone)]
pub struct Auth {
    /// Project identifier, sent as the `X-Project-ID` header.
    pub project_id: String,
    /// Agent identifier, sent as the `X-Agent-ID` header.
    pub agent_id: String,
    /// Environment identifier, sent as the `X-Environment-ID` header.
    pub environment_id: String,
}
31
32pub struct Flag<'a> {
33    name: String,
34    client: &'a Client,
35}
36
37#[derive(Debug, Error)]
38pub enum FlagError {
39    #[error("HTTP error: {0}")]
40    HttpError(#[from] reqwest::Error),
41
42    #[error("Cache error: {0}")]
43    CacheError(String),
44
45    #[error("Missing authentication: {0}")]
46    AuthError(String),
47
48    #[error("API error: {0}")]
49    ApiError(String),
50}
51
52#[derive(Debug)]
53struct CircuitState {
54    is_open: bool,
55    failure_count: u32,
56    last_failure: Option<DateTime<Utc>>,
57}
58
59#[derive(Debug, Deserialize)]
60struct ApiResponse {
61    #[serde(rename = "intervalAllowed")]
62    interval_allowed: i32,
63    flags: Vec<flag::FeatureFlag>,
64}
65
66pub struct Client {
67    base_url: String,
68    http_client: reqwest::Client,
69    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
70    max_retries: u32,
71    circuit_state: RwLock<CircuitState>,
72    auth: Option<Auth>,
73}
74
75impl Client {
76    pub fn builder() -> ClientBuilder {
77        ClientBuilder::new()
78    }
79
80    pub fn debug_info(&self) -> String {
81        format!(
82            "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
83            self.base_url, self.max_retries, self.auth
84        )
85    }
86
87    pub fn is(&self, name: &str) -> Flag {
88        Flag {
89            name: name.to_string(),
90            client: self,
91        }
92    }
93
94    pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
95        // Check if cache needs refresh before listing
96        {
97            let cache = self.cache.read().unwrap();
98            if cache.should_refresh_cache().await {
99                drop(cache); // Release the read lock before acquiring write lock
100                if let Err(e) = self.refetch().await {
101                    error!("Failed to refetch flags for list: {}", e);
102                    // Continue with potentially stale cache data if refetch fails
103                }
104            }
105        }
106
107        let cache = self.cache.read().unwrap();
108        cache.get_all().await
109            .map_err(|e| FlagError::CacheError(e.to_string()))
110    }
111
112    async fn is_enabled(&self, name: &str) -> bool {
113        let name = name.to_lowercase();
114
115        // Check if cache needs refresh
116        {
117            let cache = self.cache.read().unwrap();
118            if cache.should_refresh_cache().await {
119                drop(cache); // Release the read lock before acquiring write lock
120                if let Err(e) = self.refetch().await {
121                    error!("Failed to refetch flags: {}", e);
122                    // If refetch fails, continue to check the potentially stale cache.
123                }
124            }
125        }
126
127        // Check cache (which now contains combined API and local flags with overrides)
128        let cache = self.cache.read().unwrap();
129        match cache.get(&name).await {
130            Ok((enabled, exists)) => {
131                if exists {
132                    enabled
133                } else {
134                    false
135                }
136            }
137            Err(_) => false, // Treat cache errors as flag not found
138        }
139    }
140
141    async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
142        let auth = match &self.auth {
143            Some(auth) => auth,
144            None => return Err(FlagError::AuthError("Authentication is required".to_string())),
145        };
146
147        if auth.project_id.is_empty() {
148            return Err(FlagError::AuthError("Project ID is required".to_string()));
149        }
150        if auth.agent_id.is_empty() {
151            return Err(FlagError::AuthError("Agent ID is required".to_string()));
152        }
153        if auth.environment_id.is_empty() {
154            return Err(FlagError::AuthError("Environment ID is required".to_string()));
155        }
156
157        let mut headers = HeaderMap::new();
158        headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
159        headers.insert("Accept", HeaderValue::from_static("application/json"));
160        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
161        headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
162        headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
163        headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
164
165        let url = format!("{}/flags", self.base_url);
166        let response = self.http_client
167            .get(&url)
168            .headers(headers)
169            .send()
170            .await?;
171
172        if !response.status().is_success() {
173            return Err(FlagError::ApiError(format!(
174                "Unexpected status code: {}",
175                response.status()
176            )));
177        }
178
179        let api_resp = response.json::<ApiResponse>().await?;
180        Ok(api_resp)
181    }
182
183    async fn refetch(&self) -> Result<(), FlagError> {
184        let mut circuit_state = self.circuit_state.write().unwrap();
185
186        if circuit_state.is_open {
187            if let Some(last_failure) = circuit_state.last_failure {
188                let now = Utc::now();
189                // Keep the circuit open for a bit after failure
190                if (now - last_failure).num_seconds() < 10 { // You can adjust this duration
191                    warn!("Circuit breaker is open, skipping refetch.");
192                    return Ok(());
193                }
194            }
195            // If enough time has passed, attempt to close the circuit
196            warn!("Attempting to close circuit breaker.");
197            circuit_state.is_open = false;
198            circuit_state.failure_count = 0;
199        }
200        drop(circuit_state); // Release the write lock
201
202        let api_resp = match self.fetch_flags().await {
203            Ok(resp) => {
204                let mut circuit_state = self.circuit_state.write().unwrap();
205                circuit_state.failure_count = 0; // Reset failure count on success
206                resp
207            }
208            Err(e) => {
209                let mut circuit_state = self.circuit_state.write().unwrap();
210                circuit_state.failure_count += 1;
211                circuit_state.last_failure = Some(Utc::now());
212                if circuit_state.failure_count >= self.max_retries {
213                    circuit_state.is_open = true;
214                    error!("Refetch failed after {} retries, opening circuit breaker: {}", self.max_retries, e);
215                } else {
216                    warn!("Refetch failed (attempt {}/{}), retrying: {}", circuit_state.failure_count, self.max_retries, e);
217                }
218                drop(circuit_state); // Release the write lock
219                // If fetching fails, we should still attempt to use local flags and potentially old cache data
220                let local_flags = build_local(); // Build local flags even on API failure
221                let mut cache = self.cache.write().unwrap();
222                // Attempt to refresh cache with only local flags if API failed
223                cache.refresh(&local_flags, 60).await // Use a default interval if API interval is not available
224                    .map_err(|e| FlagError::CacheError(e.to_string()))?;
225                return Err(e); // Propagate the error
226            }
227        };
228
229        let mut api_flags: Vec<flag::FeatureFlag> = api_resp.flags
230            .into_iter()
231            .map(|f| flag::FeatureFlag {
232                enabled: f.enabled,
233                details: flag::Details {
234                    name: f.details.name.to_lowercase(),
235                    id: f.details.id,
236                },
237            })
238            .collect();
239
240        let local_flags = build_local();
241
242        // Combine API flags and local flags, with local overriding API
243        let mut combined_flags = Vec::new();
244        let mut local_flags_map: HashMap<String, FeatureFlag> = local_flags.into_iter().map(|f| (f.details.name.clone(), f)).collect();
245
246        for api_flag in api_flags.drain(..) {
247            if let Some(local_flag) = local_flags_map.remove(&api_flag.details.name) {
248                // Local flag with the same name exists, use the local one
249                combined_flags.push(local_flag);
250            } else {
251                // No local flag with the same name, use the API one
252                combined_flags.push(api_flag);
253            }
254        }
255
256        // Add any remaining local flags that didn't have a corresponding API flag
257        combined_flags.extend(local_flags_map.into_values());
258
259
260        let mut cache = self.cache.write().unwrap();
261        cache.refresh(&combined_flags, api_resp.interval_allowed).await
262            .map_err(|e| FlagError::CacheError(e.to_string()))?;
263
264        Ok(())
265    }
266}
267
268impl<'a> Flag<'a> {
269    pub async fn enabled(&self) -> bool {
270        self.client.is_enabled(&self.name).await
271    }
272}
273
274pub struct ClientBuilder {
275    base_url: String,
276    max_retries: u32,
277    auth: Option<Auth>,
278    use_memory_cache: bool,
279    file_name: Option<String>,
280}
281
282impl ClientBuilder {
283    fn new() -> Self {
284        Self {
285            base_url: BASE_URL.to_string(),
286            max_retries: MAX_RETRIES,
287            auth: None,
288            use_memory_cache: false,
289            file_name: None,
290        }
291    }
292
293    pub fn with_base_url(mut self, base_url: &str) -> Self {
294        self.base_url = base_url.to_string();
295        self
296    }
297
298    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
299        self.max_retries = max_retries;
300        self
301    }
302
303    pub fn with_auth(mut self, auth: Auth) -> Self {
304        self.auth = Some(auth);
305        self
306    }
307
308    pub fn with_file_name(mut self, file_name: &str) -> Self {
309        self.file_name = Some(file_name.to_string());
310        self
311    }
312
313    pub fn with_memory_cache(mut self) -> Self {
314        self.use_memory_cache = true;
315        self
316    }
317
318    pub fn build(self) -> Client {
319        let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
320            Box::new(MemoryCache::new())
321        } else {
322            #[cfg(feature = "rusqlite")]
323            {
324                if let Some(file_name) = self.file_name {
325                    Box::new(cache::SqliteCache::new(&file_name))
326                } else {
327                    Box::new(MemoryCache::new())
328                }
329            }
330            #[cfg(not(feature = "rusqlite"))]
331            {
332                Box::new(MemoryCache::new())
333            }
334        };
335
336        Client {
337            base_url: self.base_url,
338            http_client: reqwest::Client::builder()
339                .timeout(Duration::from_secs(10))
340                .build()
341                .unwrap(),
342            cache: Arc::new(RwLock::new(cache)),
343            max_retries: self.max_retries,
344            circuit_state: RwLock::new(CircuitState {
345                is_open: false,
346                failure_count: 0,
347                last_failure: None,
348            }),
349            auth: self.auth,
350        }
351    }
352}
353
354fn build_local() -> Vec<FeatureFlag> {
355    let mut result = Vec::new();
356
357    for (key, value) in env::vars() {
358        if !key.starts_with("FLAGS_") {
359            continue;
360        }
361
362        let enabled = value == "true";
363        let flag_name_env = key.trim_start_matches("FLAGS_").to_string();
364        let flag_name_lower = flag_name_env.to_lowercase();
365
366        // Create a FeatureFlag for the flag name as it appears in the environment variable (lowercase)
367        result.push(FeatureFlag {
368            enabled,
369            details: Details {
370                name: flag_name_lower.clone(),
371                id: format!("local_{}", flag_name_lower), // Using a simple identifier for local flags
372            },
373        });
374
375        // Optionally, create FeatureFlags for common variations (hyphens and spaces)
376        if flag_name_lower.contains('_') {
377            let flag_name_hyphenated = flag_name_lower.replace('_', "-");
378            result.push(FeatureFlag {
379                enabled,
380                details: Details {
381                    name: flag_name_hyphenated.clone(),
382                    id: format!("local_{}", flag_name_hyphenated),
383                },
384            });
385        }
386
387        if flag_name_lower.contains('_') || flag_name_lower.contains('-') {
388            let flag_name_spaced = flag_name_lower.replace('_', " ").replace('-', " ");
389            result.push(FeatureFlag {
390                enabled,
391                details: Details {
392                    name: flag_name_spaced.clone(),
393                    id: format!("local_{}", flag_name_spaced),
394                },
395            });
396        }
397
398    }
399
400    result
401}
402