flags_rs/
lib.rs

1// src/lib.rs
2use std::collections::HashMap;
3use std::env;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use log::__private_api::loc;
11use reqwest::header::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15pub mod cache;
16pub mod flag;
17mod tests;
18
19use crate::cache::{Cache, CacheSystem, MemoryCache};
20use crate::flag::{Details, FeatureFlag};
21
/// Root URL of the hosted flags.gg API; `/flags` is appended when fetching.
const BASE_URL: &str = "https://api.flags.gg";
/// Default number of fetch attempts per refresh, which is also the number of
/// consecutive failures after which the circuit breaker opens.
const MAX_RETRIES: u32 = 3;
24
/// Credentials identifying the caller to the flags API.
///
/// Every field must be non-empty; the values are sent as the `X-Project-ID`,
/// `X-Agent-ID` and `X-Environment-ID` request headers.
#[derive(Clone, Debug)]
pub struct Auth {
    /// Project identifier (sent as `X-Project-ID`).
    pub project_id: String,
    /// Agent identifier (sent as `X-Agent-ID`).
    pub agent_id: String,
    /// Environment identifier (sent as `X-Environment-ID`).
    pub environment_id: String,
}
31
32pub struct Flag<'a> {
33    name: String,
34    client: &'a Client,
35}
36
37#[derive(Debug, Error)]
38pub enum FlagError {
39    #[error("HTTP error: {0}")]
40    HttpError(#[from] reqwest::Error),
41
42    #[error("Cache error: {0}")]
43    CacheError(String),
44
45    #[error("Missing authentication: {0}")]
46    AuthError(String),
47
48    #[error("API error: {0}")]
49    ApiError(String),
50}
51
52#[derive(Debug)]
53struct CircuitState {
54    is_open: bool,
55    failure_count: u32,
56    last_failure: Option<DateTime<Utc>>,
57}
58
59#[derive(Debug, Deserialize)]
60struct ApiResponse {
61    #[serde(rename = "intervalAllowed")]
62    interval_allowed: i32,
63    flags: Vec<flag::FeatureFlag>,
64}
65
66pub struct Client {
67    base_url: String,
68    http_client: reqwest::Client,
69    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
70    max_retries: u32,
71    circuit_state: RwLock<CircuitState>,
72    auth: Option<Auth>,
73}
74
75impl Client {
76    pub fn builder() -> ClientBuilder {
77        ClientBuilder::new()
78    }
79
80    pub fn debug_info(&self) -> String {
81        format!(
82            "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
83            self.base_url, self.max_retries, self.auth
84        )
85    }
86
87    pub fn is(&self, name: &str) -> Flag {
88        Flag {
89            name: name.to_string(),
90            client: self,
91        }
92    }
93
94    pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
95        let cache = self.cache.read().unwrap();
96        let mut flags = cache.get_all().await
97            .map_err(|e| FlagError::CacheError(e.to_string()))?;
98        
99        let local_flags = build_local();
100        for (flag, enabled) in local_flags {
101            flags.push(FeatureFlag {
102                enabled,
103                details: Details {
104                    name: flag.to_string(),
105                    id: format!("local_flag-{}", flag),
106                },
107            })
108        }
109        Ok(flags)
110    }
111
112    async fn is_enabled(&self, name: &str) -> bool {
113        let name = name.to_lowercase();
114
115        // Check if cache needs refresh
116        {
117            let cache = self.cache.read().unwrap();
118            if cache.should_refresh_cache().await {
119                drop(cache); // Release the read lock before acquiring write lock
120                if let Err(e) = self.refetch().await {
121                    error!("Failed to refetch flags: {}", e);
122                    return false;
123                }
124            }
125        }
126
127        // Check local environment variables first
128        let local_flags = build_local();
129        if let Some(&enabled) = local_flags.get(&name) {
130            return enabled;
131        }
132
133        // Check cache
134        let cache = self.cache.read().unwrap();
135        match cache.get(&name).await {
136            Ok((enabled, exists)) => {
137                if exists {
138                    enabled
139                } else {
140                    false
141                }
142            }
143            Err(_) => false,
144        }
145    }
146
147    async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
148        let auth = match &self.auth {
149            Some(auth) => auth,
150            None => return Err(FlagError::AuthError("Authentication is required".to_string())),
151        };
152
153        if auth.project_id.is_empty() {
154            return Err(FlagError::AuthError("Project ID is required".to_string()));
155        }
156        if auth.agent_id.is_empty() {
157            return Err(FlagError::AuthError("Agent ID is required".to_string()));
158        }
159        if auth.environment_id.is_empty() {
160            return Err(FlagError::AuthError("Environment ID is required".to_string()));
161        }
162
163        let mut headers = HeaderMap::new();
164        headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
165        headers.insert("Accept", HeaderValue::from_static("application/json"));
166        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
167        headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
168        headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
169        headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
170
171        let url = format!("{}/flags", self.base_url);
172        let response = self.http_client
173            .get(&url)
174            .headers(headers)
175            .send()
176            .await?;
177
178        if !response.status().is_success() {
179            return Err(FlagError::ApiError(format!(
180                "Unexpected status code: {}",
181                response.status()
182            )));
183        }
184
185        let api_resp = response.json::<ApiResponse>().await?;
186        Ok(api_resp)
187    }
188
189    async fn refetch(&self) -> Result<(), FlagError> {
190        let mut circuit_state = self.circuit_state.write().unwrap();
191
192        if circuit_state.is_open {
193            if let Some(last_failure) = circuit_state.last_failure {
194                let now = Utc::now();
195                if (now - last_failure).num_seconds() < 10 {
196                    return Ok(());
197                }
198            }
199            circuit_state.is_open = false;
200            circuit_state.failure_count = 0;
201        }
202        drop(circuit_state);
203
204        let mut api_resp = None;
205        let mut last_error = None;
206
207        for retry in 0..self.max_retries {
208            match self.fetch_flags().await {
209                Ok(resp) => {
210                    api_resp = Some(resp);
211                    let mut circuit_state = self.circuit_state.write().unwrap();
212                    circuit_state.failure_count = 0;
213                    break;
214                }
215                Err(e) => {
216                    last_error = Some(e);
217                    let mut circuit_state = self.circuit_state.write().unwrap();
218                    circuit_state.failure_count += 1;
219
220                    if circuit_state.failure_count >= self.max_retries {
221                        circuit_state.is_open = true;
222                        circuit_state.last_failure = Some(Utc::now());
223                        return Ok(());
224                    }
225                    drop(circuit_state);
226
227                    tokio::time::sleep(Duration::from_secs((retry + 1) as u64)).await;
228                }
229            }
230        }
231
232        let api_resp = match api_resp {
233            Some(resp) => resp,
234            None => return Err(last_error.unwrap()),
235        };
236
237        let flags: Vec<flag::FeatureFlag> = api_resp.flags
238            .into_iter()
239            .map(|f| flag::FeatureFlag {
240                enabled: f.enabled,
241                details: flag::Details {
242                    name: f.details.name.to_lowercase(),
243                    id: f.details.id,
244                },
245            })
246            .collect();
247
248        let mut cache = self.cache.write().unwrap();
249        cache.refresh(&flags, api_resp.interval_allowed).await
250            .map_err(|e| FlagError::CacheError(e.to_string()))?;
251
252        Ok(())
253    }
254}
255
256impl<'a> Flag<'a> {
257    pub async fn enabled(&self) -> bool {
258        self.client.is_enabled(&self.name).await
259    }
260}
261
262pub struct ClientBuilder {
263    base_url: String,
264    max_retries: u32,
265    auth: Option<Auth>,
266    use_memory_cache: bool,
267    file_name: Option<String>,
268}
269
270impl ClientBuilder {
271    fn new() -> Self {
272        Self {
273            base_url: BASE_URL.to_string(),
274            max_retries: MAX_RETRIES,
275            auth: None,
276            use_memory_cache: false,
277            file_name: None,
278        }
279    }
280
281    pub fn with_base_url(mut self, base_url: &str) -> Self {
282        self.base_url = base_url.to_string();
283        self
284    }
285
286    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
287        self.max_retries = max_retries;
288        self
289    }
290
291    pub fn with_auth(mut self, auth: Auth) -> Self {
292        self.auth = Some(auth);
293        self
294    }
295
296    pub fn with_file_name(mut self, file_name: &str) -> Self {
297        self.file_name = Some(file_name.to_string());
298        self
299    }
300
301    pub fn with_memory_cache(mut self) -> Self {
302        self.use_memory_cache = true;
303        self
304    }
305
306    pub fn build(self) -> Client {
307        let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
308            Box::new(MemoryCache::new())
309        } else {
310            #[cfg(feature = "rusqlite")]
311            {
312                if let Some(file_name) = self.file_name {
313                    Box::new(cache::SqliteCache::new(&file_name))
314                } else {
315                    Box::new(MemoryCache::new())
316                }
317            }
318            #[cfg(not(feature = "rusqlite"))]
319            {
320                Box::new(MemoryCache::new())
321            }
322        };
323
324        Client {
325            base_url: self.base_url,
326            http_client: reqwest::Client::builder()
327                .timeout(Duration::from_secs(10))
328                .build()
329                .unwrap(),
330            cache: Arc::new(RwLock::new(cache)),
331            max_retries: self.max_retries,
332            circuit_state: RwLock::new(CircuitState {
333                is_open: false,
334                failure_count: 0,
335                last_failure: None,
336            }),
337            auth: self.auth,
338        }
339    }
340}
341
/// Collects flag overrides from `FLAGS_*` environment variables.
///
/// For each variable `FLAGS_<NAME>`, exactly one `FLAGS_` prefix is removed and `<NAME>`
/// is lower-cased. It is stored under three spellings: as-is, with `_` replaced by `-`,
/// and with `_` replaced by a space. For example, `FLAGS_NEW_UI=true` matches `new_ui`,
/// `new-ui` and `new ui`. A flag is enabled only when its value is exactly `true`; any
/// other value, including one that is not valid Unicode, disables it.
///
/// Variables whose *name* is not valid Unicode are skipped (`env::vars` would panic on
/// them), as is a bare `FLAGS_` with nothing after the prefix.
fn build_local() -> HashMap<String, bool> {
    let mut result = HashMap::new();

    for (key, value) in env::vars_os() {
        let key = match key.to_str() {
            Some(key) => key,
            None => continue,
        };
        // `strip_prefix` removes the prefix once; `trim_start_matches` would also eat a
        // flag literally named `FLAGS_...`.
        let name = match key.strip_prefix("FLAGS_") {
            Some(name) if !name.is_empty() => name.to_lowercase(),
            _ => continue,
        };
        let enabled = value.to_str() == Some("true");

        result.insert(name.replace('_', "-"), enabled);
        result.insert(name.replace('_', " "), enabled);
        result.insert(name, enabled);
    }

    result
}