1use std::collections::HashMap;
2use std::env;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use tokio::sync::RwLock;
6use std::time::Duration;
7
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use reqwest::header::{HeaderMap, HeaderValue};
11use serde::Deserialize;
12use thiserror::Error;
13
14pub mod cache;
15pub mod flag;
16mod tests;
17
18#[cfg(feature = "tower-middleware")]
19pub mod middleware;
20
21#[cfg(all(test, feature = "tower-middleware"))]
22mod middleware_tests;
23
24use crate::cache::{Cache, MemoryCache};
25use crate::flag::{Details, FeatureFlag};
26
/// Default API endpoint used by `ClientBuilder::new`; override it with
/// `ClientBuilder::with_base_url`.
const BASE_URL: &str = "https://api.flags.gg";
/// Default number of fetch attempts per refresh used by `ClientBuilder::new`;
/// `ClientBuilder::build` rejects values above 10.
const MAX_RETRIES: u32 = 3;
29
/// Credentials for the flags API.
///
/// Each value is sent as a request header (`X-Project-ID`, `X-Agent-ID`,
/// `X-Environment-ID`), and `ClientBuilder::build` rejects blank values.
#[derive(Debug, Clone)]
pub struct Auth {
    pub project_id: String,
    pub agent_id: String,
    pub environment_id: String,
}
36
/// Handle to a single named flag, created by [`Client::is`].
pub struct Flag<'a> {
    // Flag name exactly as passed to `Client::is`; it is lowercased at lookup time.
    name: String,
    // Client the flag is evaluated against.
    client: &'a Client,
}
41
/// Errors produced by the flags client and its builder.
#[derive(Debug, Error)]
pub enum FlagError {
    /// Transport or response-decoding failure reported by `reqwest`.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// The flag cache failed; carries the cache error's message.
    #[error("Cache error: {0}")]
    CacheError(String),

    /// Credentials are missing or cannot be encoded as HTTP header values.
    #[error("Missing authentication: {0}")]
    AuthError(String),

    /// The API answered with a non-success status code.
    #[error("API error: {0}")]
    ApiError(String),

    /// Invalid configuration passed to `ClientBuilder::build`.
    #[error("Builder error: {0}")]
    BuilderError(String),
}
59
/// Circuit-breaker bookkeeping for API refetches, shared between client clones.
#[derive(Debug)]
struct CircuitState {
    // While `true` and `last_failure` is less than 10 seconds old, refetches are skipped.
    is_open: bool,
    // Refetches that failed after exhausting all retries; reset to 0 on a
    // successful fetch or when the breaker is closed again.
    failure_count: u32,
    // Time of the most recent refetch that exhausted its retries.
    last_failure: Option<DateTime<Utc>>,
}
66
/// JSON body of a successful `GET {base_url}/flags` response.
#[derive(Debug, Deserialize)]
struct ApiResponse {
    // Refresh interval passed straight to `Cache::refresh`.
    // NOTE(review): units are not visible here — presumably seconds; confirm in the cache module.
    #[serde(rename = "intervalAllowed")]
    interval_allowed: i32,
    // Flags as returned by the API (names are lowercased before caching).
    flags: Vec<flag::FeatureFlag>,
}
73
/// Callback invoked with errors the client reports; held in an `Arc` so client
/// clones share the same callback.
pub type ErrorCallback = Arc<dyn Fn(&FlagError) + Send + Sync>;
75
/// Feature-flag client backed by the flags API and/or `FLAGS_*` environment
/// variables.
///
/// Clones share the cache, circuit-breaker state and in-progress refresh flag,
/// which are all held in `Arc`s.
pub struct Client {
    // API base URL; requests go to `{base_url}/flags`.
    base_url: String,
    // HTTP client for API requests (built with a 10 second timeout).
    http_client: reqwest::Client,
    // Flag storage, shared between clones.
    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
    // Fetch attempts per refresh; a value of 0 is treated as 1.
    max_retries: u32,
    // Circuit-breaker state, shared between clones.
    circuit_state: Arc<RwLock<CircuitState>>,
    // API credentials; when `None` only `FLAGS_*` environment flags are served.
    auth: Option<Auth>,
    // Set while a task refreshes the cache so concurrent callers skip a duplicate refresh.
    refresh_in_progress: Arc<AtomicBool>,
    // Optional hook that receives errors encountered while refreshing.
    error_callback: Option<ErrorCallback>,
}
86
87impl Client {
88 pub fn builder() -> ClientBuilder {
89 ClientBuilder::new()
90 }
91
92 fn handle_error(&self, error: &FlagError) {
93 if let Some(ref callback) = self.error_callback {
94 callback(error);
95 }
96 }
97
98 pub fn debug_info(&self) -> String {
99 format!(
100 "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
101 self.base_url, self.max_retries, self.auth
102 )
103 }
104
105 pub fn is(&self, name: &str) -> Flag<'_> {
106 Flag {
107 name: name.to_string(),
108 client: self,
109 }
110 }
111
112 pub async fn get_multiple(&self, names: &[&str]) -> HashMap<String, bool> {
127 if self.cache.read().await.should_refresh_cache().await {
129 if self.refresh_in_progress.compare_exchange(
130 false,
131 true,
132 Ordering::SeqCst,
133 Ordering::SeqCst
134 ).is_ok() {
135 if let Err(e) = self.refetch().await {
136 error!("Failed to refetch flags for batch operation: {}", e);
137 self.handle_error(&e);
138 }
139 self.refresh_in_progress.store(false, Ordering::SeqCst);
140 }
141 }
142
143 let cache = self.cache.read().await;
145 let mut results = HashMap::with_capacity(names.len());
146
147 for &name in names {
148 let normalized_name = name.to_lowercase();
149 match cache.get(&normalized_name).await {
150 Ok((enabled, exists)) => {
151 results.insert(name.to_string(), exists && enabled);
152 }
153 Err(_) => {
154 results.insert(name.to_string(), false);
155 }
156 }
157 }
158
159 results
160 }
161
162 pub async fn all_enabled(&self, names: &[&str]) -> bool {
174 if names.is_empty() {
175 return true;
176 }
177
178 let flags = self.get_multiple(names).await;
179 names.iter().all(|&name| flags.get(name).copied().unwrap_or(false))
180 }
181
182 pub async fn any_enabled(&self, names: &[&str]) -> bool {
194 if names.is_empty() {
195 return false;
196 }
197
198 let flags = self.get_multiple(names).await;
199 names.iter().any(|&name| flags.get(name).copied().unwrap_or(false))
200 }
201
202 pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
203 if self.cache.read().await.should_refresh_cache().await {
205 if self.refresh_in_progress.compare_exchange(
207 false,
208 true,
209 Ordering::SeqCst,
210 Ordering::SeqCst
211 ).is_ok() {
212 if let Err(e) = self.refetch().await {
214 error!("Failed to refetch flags for list: {}", e);
215 self.handle_error(&e);
216 }
217 self.refresh_in_progress.store(false, Ordering::SeqCst);
219 }
220 }
222
223 let cache = self.cache.read().await;
224 cache.get_all().await
225 .map_err(|e| FlagError::CacheError(e.to_string()))
226 }
227
228 async fn is_enabled(&self, name: &str) -> bool {
229 let name = name.to_lowercase();
230
231 if self.cache.read().await.should_refresh_cache().await {
233 if self.refresh_in_progress.compare_exchange(
235 false,
236 true,
237 Ordering::SeqCst,
238 Ordering::SeqCst
239 ).is_ok() {
240 if let Err(e) = self.refetch().await {
242 error!("Failed to refetch flags: {}", e);
243 self.handle_error(&e);
244 }
245 self.refresh_in_progress.store(false, Ordering::SeqCst);
247 }
248 }
250
251 let cache = self.cache.read().await;
253 match cache.get(&name).await {
254 Ok((enabled, exists)) => {
255 if exists {
256 enabled
257 } else {
258 false
259 }
260 }
261 Err(_) => false, }
263 }
264
265 async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
266 let auth = match &self.auth {
267 Some(auth) => auth,
268 None => return Err(FlagError::AuthError("Authentication is required".to_string())),
269 };
270
271 let mut headers = HeaderMap::new();
272 headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
273 headers.insert("Accept", HeaderValue::from_static("application/json"));
274 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
275 headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id)
276 .map_err(|_| FlagError::AuthError(format!("Invalid project ID: {}", auth.project_id)))?);
277 headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id)
278 .map_err(|_| FlagError::AuthError(format!("Invalid agent ID: {}", auth.agent_id)))?);
279 headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id)
280 .map_err(|_| FlagError::AuthError(format!("Invalid environment ID: {}", auth.environment_id)))?);
281
282 let url = format!("{}/flags", self.base_url);
283 let response = self.http_client
284 .get(&url)
285 .headers(headers)
286 .send()
287 .await?;
288
289 if !response.status().is_success() {
290 return Err(FlagError::ApiError(format!(
291 "Unexpected status code: {}",
292 response.status()
293 )));
294 }
295
296 let api_resp = response.json::<ApiResponse>().await?;
297 Ok(api_resp)
298 }
299
300 async fn refetch(&self) -> Result<(), FlagError> {
301 if self.auth.is_none() {
303 let local_flags = build_local();
304 let mut cache = self.cache.write().await;
305 cache
307 .refresh(&local_flags, 60)
308 .await
309 .map_err(|e| FlagError::CacheError(e.to_string()))?;
310 return Ok(());
311 }
312
313 let mut circuit_state = self.circuit_state.write().await;
314
315 if circuit_state.is_open {
316 if let Some(last_failure) = circuit_state.last_failure {
317 let now = Utc::now();
318 if (now - last_failure).num_seconds() < 10 { warn!("Circuit breaker is open, skipping refetch.");
321 return Ok(());
322 }
323 }
324 warn!("Attempting to close circuit breaker.");
326 circuit_state.is_open = false;
327 circuit_state.failure_count = 0;
328 }
329 drop(circuit_state); let api_resp = {
334 let max = self.max_retries.max(1);
335 let mut attempt: u32 = 1;
336 loop {
337 match self.fetch_flags().await {
338 Ok(resp) => {
339 let mut circuit_state = self.circuit_state.write().await;
340 circuit_state.failure_count = 0; break resp;
342 }
343 Err(e) => {
344 if attempt < max {
345 warn!("Refetch failed (attempt {}/{}), retrying...", attempt, max);
346 self.handle_error(&e);
347 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
348 attempt += 1;
349 continue;
350 }
351 let mut cs = self.circuit_state.write().await;
353 cs.failure_count += 1;
354 cs.last_failure = Some(Utc::now());
355 if cs.failure_count >= self.max_retries.max(1) {
356 }
360 error!("Refetch failed after {} internal retries: {}", max, e);
361 self.handle_error(&e);
362 drop(cs);
363 let local_flags = build_local();
365 let mut cache = self.cache.write().await;
366 cache
367 .refresh(&local_flags, 60)
368 .await
369 .map_err(|e| FlagError::CacheError(e.to_string()))?;
370 return Err(e);
372 }
373 }
374 }
375 };
376
377 let mut api_flags: Vec<flag::FeatureFlag> = api_resp.flags
378 .into_iter()
379 .map(|f| flag::FeatureFlag {
380 enabled: f.enabled,
381 details: flag::Details {
382 name: f.details.name.to_lowercase(),
383 id: f.details.id,
384 },
385 })
386 .collect();
387
388 let local_flags = build_local();
389
390 let mut combined_flags = Vec::new();
392 let mut local_flags_map: HashMap<String, FeatureFlag> = local_flags.into_iter().map(|f| (f.details.name.clone(), f)).collect();
393
394 for api_flag in api_flags.drain(..) {
395 if let Some(local_flag) = local_flags_map.remove(&api_flag.details.name) {
396 combined_flags.push(local_flag);
398 } else {
399 combined_flags.push(api_flag);
401 }
402 }
403
404 combined_flags.extend(local_flags_map.into_values());
406
407
408 let mut cache = self.cache.write().await;
409 cache.refresh(&combined_flags, api_resp.interval_allowed).await
410 .map_err(|e| FlagError::CacheError(e.to_string()))?;
411
412 Ok(())
413 }
414}
415
416impl Clone for Client {
417 fn clone(&self) -> Self {
418 Client {
419 base_url: self.base_url.clone(),
420 http_client: self.http_client.clone(),
421 cache: Arc::clone(&self.cache),
422 max_retries: self.max_retries,
423 circuit_state: Arc::clone(&self.circuit_state),
424 auth: self.auth.clone(),
425 refresh_in_progress: Arc::clone(&self.refresh_in_progress),
426 error_callback: self.error_callback.clone(),
427 }
428 }
429}
430
431impl<'a> Flag<'a> {
432 pub async fn enabled(&self) -> bool {
433 self.client.is_enabled(&self.name).await
434 }
435}
436
/// Builder for [`Client`]; obtain one with [`Client::builder`].
pub struct ClientBuilder {
    // API base URL (defaults to `BASE_URL`).
    base_url: String,
    // Fetch attempts per refresh (defaults to `MAX_RETRIES`; `build` allows at most 10).
    max_retries: u32,
    // Optional API credentials.
    auth: Option<Auth>,
    // Set by `with_memory_cache`; `build` does not currently read it.
    use_memory_cache: bool,
    // Set by `with_file_name`; `build` does not currently read it.
    file_name: Option<String>,
    // Optional callback for errors reported by the client.
    error_callback: Option<ErrorCallback>,
}
445
446impl ClientBuilder {
447 fn new() -> Self {
448 Self {
449 base_url: BASE_URL.to_string(),
450 max_retries: MAX_RETRIES,
451 auth: None,
452 use_memory_cache: false,
453 file_name: None,
454 error_callback: None,
455 }
456 }
457
458 pub fn with_error_callback<F>(mut self, callback: F) -> Self
472 where
473 F: Fn(&FlagError) + Send + Sync + 'static,
474 {
475 self.error_callback = Some(Arc::new(callback));
476 self
477 }
478
479 pub fn with_base_url(mut self, base_url: &str) -> Self {
480 self.base_url = base_url.to_string();
481 self
482 }
483
484 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
485 self.max_retries = max_retries;
486 self
487 }
488
489 pub fn with_auth(mut self, auth: Auth) -> Self {
490 self.auth = Some(auth);
491 self
492 }
493
494 pub fn with_file_name(mut self, file_name: &str) -> Self {
495 self.file_name = Some(file_name.to_string());
496 self
497 }
498
499 pub fn with_memory_cache(mut self) -> Self {
500 self.use_memory_cache = true;
501 self
502 }
503
504 pub fn build(self) -> Result<Client, FlagError> {
505 if let Some(ref auth) = self.auth {
507 if auth.project_id.trim().is_empty() {
508 return Err(FlagError::BuilderError("Project ID cannot be empty".to_string()));
509 }
510 if auth.agent_id.trim().is_empty() {
511 return Err(FlagError::BuilderError("Agent ID cannot be empty".to_string()));
512 }
513 if auth.environment_id.trim().is_empty() {
514 return Err(FlagError::BuilderError("Environment ID cannot be empty".to_string()));
515 }
516 }
517
518 if self.base_url.trim().is_empty() {
520 return Err(FlagError::BuilderError("Base URL cannot be empty".to_string()));
521 }
522
523 if self.max_retries > 10 {
525 return Err(FlagError::BuilderError("Max retries cannot exceed 10".to_string()));
526 }
527
528 let cache: Box<dyn Cache + Send + Sync> = Box::new(MemoryCache::new());
529
530 let http_client = reqwest::Client::builder()
531 .timeout(Duration::from_secs(10))
532 .build()
533 .map_err(|e| FlagError::BuilderError(format!("Failed to build HTTP client: {}", e)))?;
534
535 Ok(Client {
536 base_url: self.base_url,
537 http_client,
538 cache: Arc::new(RwLock::new(cache)),
539 max_retries: self.max_retries,
540 circuit_state: Arc::new(RwLock::new(CircuitState {
541 is_open: false,
542 failure_count: 0,
543 last_failure: None,
544 })),
545 auth: self.auth,
546 refresh_in_progress: Arc::new(AtomicBool::new(false)),
547 error_callback: self.error_callback,
548 })
549 }
550}
551
552fn build_local() -> Vec<FeatureFlag> {
553 let mut result = Vec::new();
554
555 for (key, value) in env::vars() {
556 if !key.starts_with("FLAGS_") {
557 continue;
558 }
559
560 let enabled = value == "true";
561 let flag_name_env = key.trim_start_matches("FLAGS_").to_string();
562 let flag_name_lower = flag_name_env.to_lowercase();
563
564 result.push(FeatureFlag {
566 enabled,
567 details: Details {
568 name: flag_name_lower.clone(),
569 id: format!("local_{}", flag_name_lower), },
571 });
572
573 if flag_name_lower.contains('_') {
575 let flag_name_hyphenated = flag_name_lower.replace('_', "-");
576 result.push(FeatureFlag {
577 enabled,
578 details: Details {
579 name: flag_name_hyphenated.clone(),
580 id: format!("local_{}", flag_name_hyphenated),
581 },
582 });
583 }
584
585 if flag_name_lower.contains('_') || flag_name_lower.contains('-') {
586 let flag_name_spaced = flag_name_lower.replace('_', " ").replace('-', " ");
587 result.push(FeatureFlag {
588 enabled,
589 details: Details {
590 name: flag_name_spaced.clone(),
591 id: format!("local_{}", flag_name_spaced),
592 },
593 });
594 }
595
596 }
597
598 result
599}
600