1use std::collections::HashMap;
2use std::env;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use log::__private_api::loc;
11use reqwest::header::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15pub mod cache;
16pub mod flag;
17mod tests;
18
19use crate::cache::{Cache, CacheSystem, MemoryCache};
20use crate::flag::{Details, FeatureFlag};
21
/// Default API root used by `ClientBuilder` unless overridden with `with_base_url`.
const BASE_URL: &str = "https://api.flags.gg";
/// Default number of consecutive failed fetches before the circuit breaker opens.
const MAX_RETRIES: u32 = 3;
24
/// Credentials that select which project, agent and environment flags are fetched for.
///
/// Every field is sent as a request header on each fetch, and a fetch is rejected
/// up front if any of them is empty.
#[derive(Clone, Debug)]
pub struct Auth {
    /// Sent as the `X-Project-ID` header.
    pub project_id: String,
    /// Sent as the `X-Agent-ID` header.
    pub agent_id: String,
    /// Sent as the `X-Environment-ID` header.
    pub environment_id: String,
}
31
32pub struct Flag<'a> {
33 name: String,
34 client: &'a Client,
35}
36
37#[derive(Debug, Error)]
38pub enum FlagError {
39 #[error("HTTP error: {0}")]
40 HttpError(#[from] reqwest::Error),
41
42 #[error("Cache error: {0}")]
43 CacheError(String),
44
45 #[error("Missing authentication: {0}")]
46 AuthError(String),
47
48 #[error("API error: {0}")]
49 ApiError(String),
50}
51
52#[derive(Debug)]
53struct CircuitState {
54 is_open: bool,
55 failure_count: u32,
56 last_failure: Option<DateTime<Utc>>,
57}
58
59#[derive(Debug, Deserialize)]
60struct ApiResponse {
61 #[serde(rename = "intervalAllowed")]
62 interval_allowed: i32,
63 flags: Vec<flag::FeatureFlag>,
64}
65
66pub struct Client {
67 base_url: String,
68 http_client: reqwest::Client,
69 cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
70 max_retries: u32,
71 circuit_state: RwLock<CircuitState>,
72 auth: Option<Auth>,
73}
74
75impl Client {
76 pub fn builder() -> ClientBuilder {
77 ClientBuilder::new()
78 }
79
80 pub fn debug_info(&self) -> String {
81 format!(
82 "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
83 self.base_url, self.max_retries, self.auth
84 )
85 }
86
87 pub fn is(&self, name: &str) -> Flag {
88 Flag {
89 name: name.to_string(),
90 client: self,
91 }
92 }
93
94 pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
95 {
97 let cache = self.cache.read().await;
98 if cache.should_refresh_cache().await {
99 drop(cache); if let Err(e) = self.refetch().await {
101 error!("Failed to refetch flags for list: {}", e);
102 }
104 }
105 }
106
107 let cache = self.cache.read().await;
108 cache.get_all().await
109 .map_err(|e| FlagError::CacheError(e.to_string()))
110 }
111
112 async fn is_enabled(&self, name: &str) -> bool {
113 let name = name.to_lowercase();
114
115 {
117 let cache = self.cache.read().await;
118 if cache.should_refresh_cache().await {
119 drop(cache); if let Err(e) = self.refetch().await {
121 error!("Failed to refetch flags: {}", e);
122 }
124 }
125 }
126
127 let cache = self.cache.read().await;
129 match cache.get(&name).await {
130 Ok((enabled, exists)) => {
131 if exists {
132 enabled
133 } else {
134 false
135 }
136 }
137 Err(_) => false, }
139 }
140
141 async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
142 let auth = match &self.auth {
143 Some(auth) => auth,
144 None => return Err(FlagError::AuthError("Authentication is required".to_string())),
145 };
146
147 if auth.project_id.is_empty() {
148 return Err(FlagError::AuthError("Project ID is required".to_string()));
149 }
150 if auth.agent_id.is_empty() {
151 return Err(FlagError::AuthError("Agent ID is required".to_string()));
152 }
153 if auth.environment_id.is_empty() {
154 return Err(FlagError::AuthError("Environment ID is required".to_string()));
155 }
156
157 let mut headers = HeaderMap::new();
158 headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
159 headers.insert("Accept", HeaderValue::from_static("application/json"));
160 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
161 headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
162 headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
163 headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
164
165 let url = format!("{}/flags", self.base_url);
166 let response = self.http_client
167 .get(&url)
168 .headers(headers)
169 .send()
170 .await?;
171
172 if !response.status().is_success() {
173 return Err(FlagError::ApiError(format!(
174 "Unexpected status code: {}",
175 response.status()
176 )));
177 }
178
179 let api_resp = response.json::<ApiResponse>().await?;
180 Ok(api_resp)
181 }
182
183 async fn refetch(&self) -> Result<(), FlagError> {
184 let mut circuit_state = self.circuit_state.write().await;
185
186 if circuit_state.is_open {
187 if let Some(last_failure) = circuit_state.last_failure {
188 let now = Utc::now();
189 if (now - last_failure).num_seconds() < 10 { warn!("Circuit breaker is open, skipping refetch.");
192 return Ok(());
193 }
194 }
195 warn!("Attempting to close circuit breaker.");
197 circuit_state.is_open = false;
198 circuit_state.failure_count = 0;
199 }
200 drop(circuit_state); let api_resp = match self.fetch_flags().await {
203 Ok(resp) => {
204 let mut circuit_state = self.circuit_state.write().await;
205 circuit_state.failure_count = 0; resp
207 }
208 Err(e) => {
209 let mut circuit_state = self.circuit_state.write().await;
210 circuit_state.failure_count += 1;
211 circuit_state.last_failure = Some(Utc::now());
212 if circuit_state.failure_count >= self.max_retries {
213 circuit_state.is_open = true;
214 error!("Refetch failed after {} retries, opening circuit breaker: {}", self.max_retries, e);
215 } else {
216 warn!("Refetch failed (attempt {}/{}), retrying: {}", circuit_state.failure_count, self.max_retries, e);
217 }
218 drop(circuit_state); let local_flags = build_local(); let mut cache = self.cache.write().await;
222 cache.refresh(&local_flags, 60).await .map_err(|e| FlagError::CacheError(e.to_string()))?;
225 return Err(e); }
227 };
228
229 let mut api_flags: Vec<flag::FeatureFlag> = api_resp.flags
230 .into_iter()
231 .map(|f| flag::FeatureFlag {
232 enabled: f.enabled,
233 details: flag::Details {
234 name: f.details.name.to_lowercase(),
235 id: f.details.id,
236 },
237 })
238 .collect();
239
240 let local_flags = build_local();
241
242 let mut combined_flags = Vec::new();
244 let mut local_flags_map: HashMap<String, FeatureFlag> = local_flags.into_iter().map(|f| (f.details.name.clone(), f)).collect();
245
246 for api_flag in api_flags.drain(..) {
247 if let Some(local_flag) = local_flags_map.remove(&api_flag.details.name) {
248 combined_flags.push(local_flag);
250 } else {
251 combined_flags.push(api_flag);
253 }
254 }
255
256 combined_flags.extend(local_flags_map.into_values());
258
259
260 let mut cache = self.cache.write().await;
261 cache.refresh(&combined_flags, api_resp.interval_allowed).await
262 .map_err(|e| FlagError::CacheError(e.to_string()))?;
263
264 Ok(())
265 }
266}
267
268impl Clone for Client {
269 fn clone(&self) -> Self {
270 Client {
271 base_url: self.base_url.clone(),
272 http_client: self.http_client.clone(),
273 cache: Arc::clone(&self.cache),
274 max_retries: self.max_retries,
275 circuit_state: RwLock::new(CircuitState{
276 is_open: self.circuit_state.blocking_read().is_open,
277 failure_count: self.circuit_state.blocking_read().failure_count,
278 last_failure: self.circuit_state.blocking_read().last_failure,
279 }),
280 auth: self.auth.clone(),
281 }
282 }
283}
284
285impl<'a> Flag<'a> {
286 pub async fn enabled(&self) -> bool {
287 self.client.is_enabled(&self.name).await
288 }
289}
290
291pub struct ClientBuilder {
292 base_url: String,
293 max_retries: u32,
294 auth: Option<Auth>,
295 use_memory_cache: bool,
296 file_name: Option<String>,
297}
298
299impl ClientBuilder {
300 fn new() -> Self {
301 Self {
302 base_url: BASE_URL.to_string(),
303 max_retries: MAX_RETRIES,
304 auth: None,
305 use_memory_cache: false,
306 file_name: None,
307 }
308 }
309
310 pub fn with_base_url(mut self, base_url: &str) -> Self {
311 self.base_url = base_url.to_string();
312 self
313 }
314
315 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
316 self.max_retries = max_retries;
317 self
318 }
319
320 pub fn with_auth(mut self, auth: Auth) -> Self {
321 self.auth = Some(auth);
322 self
323 }
324
325 pub fn with_file_name(mut self, file_name: &str) -> Self {
326 self.file_name = Some(file_name.to_string());
327 self
328 }
329
330 pub fn with_memory_cache(mut self) -> Self {
331 self.use_memory_cache = true;
332 self
333 }
334
335 pub fn build(self) -> Client {
336 let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
337 Box::new(MemoryCache::new())
338 } else {
339 #[cfg(feature = "rusqlite")]
340 {
341 if let Some(file_name) = self.file_name {
342 Box::new(cache::SqliteCache::new(&file_name))
343 } else {
344 Box::new(MemoryCache::new())
345 }
346 }
347 #[cfg(not(feature = "rusqlite"))]
348 {
349 Box::new(MemoryCache::new())
350 }
351 };
352
353 Client {
354 base_url: self.base_url,
355 http_client: reqwest::Client::builder()
356 .timeout(Duration::from_secs(10))
357 .build()
358 .unwrap(),
359 cache: Arc::new(RwLock::new(cache)),
360 max_retries: self.max_retries,
361 circuit_state: RwLock::new(CircuitState {
362 is_open: false,
363 failure_count: 0,
364 last_failure: None,
365 }),
366 auth: self.auth,
367 }
368 }
369}
370
371fn build_local() -> Vec<FeatureFlag> {
372 let mut result = Vec::new();
373
374 for (key, value) in env::vars() {
375 if !key.starts_with("FLAGS_") {
376 continue;
377 }
378
379 let enabled = value == "true";
380 let flag_name_env = key.trim_start_matches("FLAGS_").to_string();
381 let flag_name_lower = flag_name_env.to_lowercase();
382
383 result.push(FeatureFlag {
385 enabled,
386 details: Details {
387 name: flag_name_lower.clone(),
388 id: format!("local_{}", flag_name_lower), },
390 });
391
392 if flag_name_lower.contains('_') {
394 let flag_name_hyphenated = flag_name_lower.replace('_', "-");
395 result.push(FeatureFlag {
396 enabled,
397 details: Details {
398 name: flag_name_hyphenated.clone(),
399 id: format!("local_{}", flag_name_hyphenated),
400 },
401 });
402 }
403
404 if flag_name_lower.contains('_') || flag_name_lower.contains('-') {
405 let flag_name_spaced = flag_name_lower.replace('_', " ").replace('-', " ");
406 result.push(FeatureFlag {
407 enabled,
408 details: Details {
409 name: flag_name_spaced.clone(),
410 id: format!("local_{}", flag_name_spaced),
411 },
412 });
413 }
414
415 }
416
417 result
418}
419