flags_rs/
lib.rs

1// src/lib.rs
2use std::collections::HashMap;
3use std::env;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use reqwest::header::{HeaderMap, HeaderValue};
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
14pub mod cache;
15pub mod flag;
16mod tests;
17
18use crate::cache::{Cache, CacheSystem, MemoryCache};
19
/// Default root URL of the flags.gg API used by [`ClientBuilder`].
const BASE_URL: &str = "https://api.flags.gg";
/// Default number of fetch attempts a cache refresh makes.
const MAX_RETRIES: u32 = 3;
22
/// Credentials identifying this agent to the flags.gg API.
///
/// All three fields must be non-empty for a fetch to be attempted; they are
/// sent as the `X-Project-ID`, `X-Agent-ID` and `X-Environment-ID` headers.
/// `Default` yields empty strings, which is convenient with struct-update
/// syntax but is rejected at fetch time until every field is filled in.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct Auth {
    pub project_id: String,
    pub agent_id: String,
    pub environment_id: String,
}
29
30pub struct Flag<'a> {
31    name: String,
32    client: &'a Client,
33}
34
35#[derive(Debug, Error)]
36pub enum FlagError {
37    #[error("HTTP error: {0}")]
38    HttpError(#[from] reqwest::Error),
39
40    #[error("Cache error: {0}")]
41    CacheError(String),
42
43    #[error("Missing authentication: {0}")]
44    AuthError(String),
45
46    #[error("API error: {0}")]
47    ApiError(String),
48}
49
50#[derive(Debug)]
51struct CircuitState {
52    is_open: bool,
53    failure_count: u32,
54    last_failure: Option<DateTime<Utc>>,
55}
56
57#[derive(Debug, Deserialize)]
58struct ApiResponse {
59    #[serde(rename = "intervalAllowed")]
60    interval_allowed: i32,
61    flags: Vec<flag::FeatureFlag>,
62}
63
64pub struct Client {
65    base_url: String,
66    http_client: reqwest::Client,
67    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
68    max_retries: u32,
69    circuit_state: RwLock<CircuitState>,
70    auth: Option<Auth>,
71}
72
73impl Client {
74    pub fn builder() -> ClientBuilder {
75        ClientBuilder::new()
76    }
77
78    pub fn debug_info(&self) -> String {
79        format!(
80            "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
81            self.base_url, self.max_retries, self.auth
82        )
83    }
84
85    pub fn is(&self, name: &str) -> Flag {
86        Flag {
87            name: name.to_string(),
88            client: self,
89        }
90    }
91
92    pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
93        let cache = self.cache.read().unwrap();
94        let flags = cache.get_all().await
95            .map_err(|e| FlagError::CacheError(e.to_string()))?;
96        Ok(flags)
97    }
98
99    async fn is_enabled(&self, name: &str) -> bool {
100        let name = name.to_lowercase();
101
102        // Check if cache needs refresh
103        {
104            let cache = self.cache.read().unwrap();
105            if cache.should_refresh_cache().await {
106                drop(cache); // Release the read lock before acquiring write lock
107                if let Err(e) = self.refetch().await {
108                    error!("Failed to refetch flags: {}", e);
109                    return false;
110                }
111            }
112        }
113
114        // Check local environment variables first
115        let local_flags = build_local();
116        if let Some(&enabled) = local_flags.get(&name) {
117            return enabled;
118        }
119
120        // Check cache
121        let cache = self.cache.read().unwrap();
122        match cache.get(&name).await {
123            Ok((enabled, exists)) => {
124                if exists {
125                    enabled
126                } else {
127                    false
128                }
129            }
130            Err(_) => false,
131        }
132    }
133
134    async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
135        let auth = match &self.auth {
136            Some(auth) => auth,
137            None => return Err(FlagError::AuthError("Authentication is required".to_string())),
138        };
139
140        if auth.project_id.is_empty() {
141            return Err(FlagError::AuthError("Project ID is required".to_string()));
142        }
143        if auth.agent_id.is_empty() {
144            return Err(FlagError::AuthError("Agent ID is required".to_string()));
145        }
146        if auth.environment_id.is_empty() {
147            return Err(FlagError::AuthError("Environment ID is required".to_string()));
148        }
149
150        let mut headers = HeaderMap::new();
151        headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
152        headers.insert("Accept", HeaderValue::from_static("application/json"));
153        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
154        headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
155        headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
156        headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
157
158        let url = format!("{}/flags", self.base_url);
159        let response = self.http_client
160            .get(&url)
161            .headers(headers)
162            .send()
163            .await?;
164
165        if !response.status().is_success() {
166            return Err(FlagError::ApiError(format!(
167                "Unexpected status code: {}",
168                response.status()
169            )));
170        }
171
172        let api_resp = response.json::<ApiResponse>().await?;
173        Ok(api_resp)
174    }
175
176    async fn refetch(&self) -> Result<(), FlagError> {
177        let mut circuit_state = self.circuit_state.write().unwrap();
178
179        if circuit_state.is_open {
180            if let Some(last_failure) = circuit_state.last_failure {
181                let now = Utc::now();
182                if (now - last_failure).num_seconds() < 10 {
183                    return Ok(());
184                }
185            }
186            circuit_state.is_open = false;
187            circuit_state.failure_count = 0;
188        }
189        drop(circuit_state);
190
191        let mut api_resp = None;
192        let mut last_error = None;
193
194        for retry in 0..self.max_retries {
195            match self.fetch_flags().await {
196                Ok(resp) => {
197                    api_resp = Some(resp);
198                    let mut circuit_state = self.circuit_state.write().unwrap();
199                    circuit_state.failure_count = 0;
200                    break;
201                }
202                Err(e) => {
203                    last_error = Some(e);
204                    let mut circuit_state = self.circuit_state.write().unwrap();
205                    circuit_state.failure_count += 1;
206
207                    if circuit_state.failure_count >= self.max_retries {
208                        circuit_state.is_open = true;
209                        circuit_state.last_failure = Some(Utc::now());
210                        return Ok(());
211                    }
212                    drop(circuit_state);
213
214                    tokio::time::sleep(Duration::from_secs((retry + 1) as u64)).await;
215                }
216            }
217        }
218
219        let api_resp = match api_resp {
220            Some(resp) => resp,
221            None => return Err(last_error.unwrap()),
222        };
223
224        let flags: Vec<flag::FeatureFlag> = api_resp.flags
225            .into_iter()
226            .map(|f| flag::FeatureFlag {
227                enabled: f.enabled,
228                details: flag::Details {
229                    name: f.details.name.to_lowercase(),
230                    id: f.details.id,
231                },
232            })
233            .collect();
234
235        let mut cache = self.cache.write().unwrap();
236        cache.refresh(&flags, api_resp.interval_allowed).await
237            .map_err(|e| FlagError::CacheError(e.to_string()))?;
238
239        Ok(())
240    }
241}
242
243impl<'a> Flag<'a> {
244    pub async fn enabled(&self) -> bool {
245        self.client.is_enabled(&self.name).await
246    }
247}
248
249pub struct ClientBuilder {
250    base_url: String,
251    max_retries: u32,
252    auth: Option<Auth>,
253    use_memory_cache: bool,
254    file_name: Option<String>,
255}
256
257impl ClientBuilder {
258    fn new() -> Self {
259        Self {
260            base_url: BASE_URL.to_string(),
261            max_retries: MAX_RETRIES,
262            auth: None,
263            use_memory_cache: false,
264            file_name: None,
265        }
266    }
267
268    pub fn with_base_url(mut self, base_url: &str) -> Self {
269        self.base_url = base_url.to_string();
270        self
271    }
272
273    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
274        self.max_retries = max_retries;
275        self
276    }
277
278    pub fn with_auth(mut self, auth: Auth) -> Self {
279        self.auth = Some(auth);
280        self
281    }
282
283    pub fn with_file_name(mut self, file_name: &str) -> Self {
284        self.file_name = Some(file_name.to_string());
285        self
286    }
287
288    pub fn with_memory_cache(mut self) -> Self {
289        self.use_memory_cache = true;
290        self
291    }
292
293    pub fn build(self) -> Client {
294        let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
295            Box::new(MemoryCache::new())
296        } else {
297            #[cfg(feature = "rusqlite")]
298            {
299                if let Some(file_name) = self.file_name {
300                    Box::new(cache::SqliteCache::new(&file_name))
301                } else {
302                    Box::new(MemoryCache::new())
303                }
304            }
305            #[cfg(not(feature = "rusqlite"))]
306            {
307                Box::new(MemoryCache::new())
308            }
309        };
310
311        Client {
312            base_url: self.base_url,
313            http_client: reqwest::Client::builder()
314                .timeout(Duration::from_secs(10))
315                .build()
316                .unwrap(),
317            cache: Arc::new(RwLock::new(cache)),
318            max_retries: self.max_retries,
319            circuit_state: RwLock::new(CircuitState {
320                is_open: false,
321                failure_count: 0,
322                last_failure: None,
323            }),
324            auth: self.auth,
325        }
326    }
327}
328
/// Collects local flag overrides from `FLAGS_*` environment variables.
///
/// For `FLAGS_MY_FEATURE=true`, the lower-cased suffix is registered under
/// three spellings — `my_feature`, `my-feature` and `my feature` — so the
/// flag can be looked up by any of them. Only the exact value `"true"`
/// enables a flag; any other value records it as disabled.
///
/// Variables whose name or value is not valid UTF-8 are skipped;
/// `std::env::vars` would panic on them.
fn build_local() -> HashMap<String, bool> {
    let mut result = HashMap::new();

    for (raw_key, raw_value) in env::vars_os() {
        let (key, value) = match (raw_key.into_string(), raw_value.into_string()) {
            (Ok(key), Ok(value)) => (key, value),
            _ => continue,
        };

        // `strip_prefix` removes exactly one `FLAGS_`. `trim_start_matches`
        // stripped every repetition, so `FLAGS_FLAGS_X` wrongly became `x`.
        let name = match key.strip_prefix("FLAGS_") {
            Some(rest) => rest.to_lowercase(),
            None => continue,
        };
        let enabled = value == "true";

        result.insert(name.replace('_', "-"), enabled);
        result.insert(name.replace('_', " "), enabled);
        result.insert(name, enabled);
    }

    result
}