1use std::collections::HashMap;
3use std::env;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use log::__private_api::loc;
11use reqwest::header::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15pub mod cache;
16pub mod flag;
17mod tests;
18
19use crate::cache::{Cache, CacheSystem, MemoryCache};
20use crate::flag::{Details, FeatureFlag};
21
/// Default flags API endpoint, used unless `ClientBuilder::with_base_url` overrides it.
const BASE_URL: &str = "https://api.flags.gg";

/// Default number of consecutive fetch failures after which the circuit breaker
/// opens (overridable with `ClientBuilder::with_max_retries`).
const MAX_RETRIES: u32 = 3;
24
/// Credentials identifying the caller to the flags API.
///
/// All three IDs must be non-empty before flags can be fetched; each one is
/// sent as its own request header.
#[derive(Clone, Debug)]
pub struct Auth {
    /// Sent as the `X-Project-ID` header.
    pub project_id: String,
    /// Sent as the `X-Agent-ID` header.
    pub agent_id: String,
    /// Sent as the `X-Environment-ID` header.
    pub environment_id: String,
}
31
32pub struct Flag<'a> {
33 name: String,
34 client: &'a Client,
35}
36
37#[derive(Debug, Error)]
38pub enum FlagError {
39 #[error("HTTP error: {0}")]
40 HttpError(#[from] reqwest::Error),
41
42 #[error("Cache error: {0}")]
43 CacheError(String),
44
45 #[error("Missing authentication: {0}")]
46 AuthError(String),
47
48 #[error("API error: {0}")]
49 ApiError(String),
50}
51
52#[derive(Debug)]
53struct CircuitState {
54 is_open: bool,
55 failure_count: u32,
56 last_failure: Option<DateTime<Utc>>,
57}
58
59#[derive(Debug, Deserialize)]
60struct ApiResponse {
61 #[serde(rename = "intervalAllowed")]
62 interval_allowed: i32,
63 flags: Vec<flag::FeatureFlag>,
64}
65
66pub struct Client {
67 base_url: String,
68 http_client: reqwest::Client,
69 cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
70 max_retries: u32,
71 circuit_state: RwLock<CircuitState>,
72 auth: Option<Auth>,
73}
74
75impl Client {
76 pub fn builder() -> ClientBuilder {
77 ClientBuilder::new()
78 }
79
80 pub fn debug_info(&self) -> String {
81 format!(
82 "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
83 self.base_url, self.max_retries, self.auth
84 )
85 }
86
87 pub fn is(&self, name: &str) -> Flag {
88 Flag {
89 name: name.to_string(),
90 client: self,
91 }
92 }
93
94 pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
95 {
97 let cache = self.cache.read().unwrap();
98 if cache.should_refresh_cache().await {
99 drop(cache); if let Err(e) = self.refetch().await {
101 error!("Failed to refetch flags for list: {}", e);
102 }
104 }
105 }
106
107 let cache = self.cache.read().unwrap();
108 cache.get_all().await
109 .map_err(|e| FlagError::CacheError(e.to_string()))
110 }
111
112 async fn is_enabled(&self, name: &str) -> bool {
113 let name = name.to_lowercase();
114
115 {
117 let cache = self.cache.read().unwrap();
118 if cache.should_refresh_cache().await {
119 drop(cache); if let Err(e) = self.refetch().await {
121 error!("Failed to refetch flags: {}", e);
122 }
124 }
125 }
126
127 let cache = self.cache.read().unwrap();
129 match cache.get(&name).await {
130 Ok((enabled, exists)) => {
131 if exists {
132 enabled
133 } else {
134 false
135 }
136 }
137 Err(_) => false, }
139 }
140
141 async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
142 let auth = match &self.auth {
143 Some(auth) => auth,
144 None => return Err(FlagError::AuthError("Authentication is required".to_string())),
145 };
146
147 if auth.project_id.is_empty() {
148 return Err(FlagError::AuthError("Project ID is required".to_string()));
149 }
150 if auth.agent_id.is_empty() {
151 return Err(FlagError::AuthError("Agent ID is required".to_string()));
152 }
153 if auth.environment_id.is_empty() {
154 return Err(FlagError::AuthError("Environment ID is required".to_string()));
155 }
156
157 let mut headers = HeaderMap::new();
158 headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
159 headers.insert("Accept", HeaderValue::from_static("application/json"));
160 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
161 headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
162 headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
163 headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
164
165 let url = format!("{}/flags", self.base_url);
166 let response = self.http_client
167 .get(&url)
168 .headers(headers)
169 .send()
170 .await?;
171
172 if !response.status().is_success() {
173 return Err(FlagError::ApiError(format!(
174 "Unexpected status code: {}",
175 response.status()
176 )));
177 }
178
179 let api_resp = response.json::<ApiResponse>().await?;
180 Ok(api_resp)
181 }
182
183 async fn refetch(&self) -> Result<(), FlagError> {
184 let mut circuit_state = self.circuit_state.write().unwrap();
185
186 if circuit_state.is_open {
187 if let Some(last_failure) = circuit_state.last_failure {
188 let now = Utc::now();
189 if (now - last_failure).num_seconds() < 10 { warn!("Circuit breaker is open, skipping refetch.");
192 return Ok(());
193 }
194 }
195 warn!("Attempting to close circuit breaker.");
197 circuit_state.is_open = false;
198 circuit_state.failure_count = 0;
199 }
200 drop(circuit_state); let api_resp = match self.fetch_flags().await {
203 Ok(resp) => {
204 let mut circuit_state = self.circuit_state.write().unwrap();
205 circuit_state.failure_count = 0; resp
207 }
208 Err(e) => {
209 let mut circuit_state = self.circuit_state.write().unwrap();
210 circuit_state.failure_count += 1;
211 circuit_state.last_failure = Some(Utc::now());
212 if circuit_state.failure_count >= self.max_retries {
213 circuit_state.is_open = true;
214 error!("Refetch failed after {} retries, opening circuit breaker: {}", self.max_retries, e);
215 } else {
216 warn!("Refetch failed (attempt {}/{}), retrying: {}", circuit_state.failure_count, self.max_retries, e);
217 }
218 drop(circuit_state); let local_flags = build_local(); let mut cache = self.cache.write().unwrap();
222 cache.refresh(&local_flags, 60).await .map_err(|e| FlagError::CacheError(e.to_string()))?;
225 return Err(e); }
227 };
228
229 let mut api_flags: Vec<flag::FeatureFlag> = api_resp.flags
230 .into_iter()
231 .map(|f| flag::FeatureFlag {
232 enabled: f.enabled,
233 details: flag::Details {
234 name: f.details.name.to_lowercase(),
235 id: f.details.id,
236 },
237 })
238 .collect();
239
240 let local_flags = build_local();
241
242 let mut combined_flags = Vec::new();
244 let mut local_flags_map: HashMap<String, FeatureFlag> = local_flags.into_iter().map(|f| (f.details.name.clone(), f)).collect();
245
246 for api_flag in api_flags.drain(..) {
247 if let Some(local_flag) = local_flags_map.remove(&api_flag.details.name) {
248 combined_flags.push(local_flag);
250 } else {
251 combined_flags.push(api_flag);
253 }
254 }
255
256 combined_flags.extend(local_flags_map.into_values());
258
259
260 let mut cache = self.cache.write().unwrap();
261 cache.refresh(&combined_flags, api_resp.interval_allowed).await
262 .map_err(|e| FlagError::CacheError(e.to_string()))?;
263
264 Ok(())
265 }
266}
267
268impl<'a> Flag<'a> {
269 pub async fn enabled(&self) -> bool {
270 self.client.is_enabled(&self.name).await
271 }
272}
273
274pub struct ClientBuilder {
275 base_url: String,
276 max_retries: u32,
277 auth: Option<Auth>,
278 use_memory_cache: bool,
279 file_name: Option<String>,
280}
281
282impl ClientBuilder {
283 fn new() -> Self {
284 Self {
285 base_url: BASE_URL.to_string(),
286 max_retries: MAX_RETRIES,
287 auth: None,
288 use_memory_cache: false,
289 file_name: None,
290 }
291 }
292
293 pub fn with_base_url(mut self, base_url: &str) -> Self {
294 self.base_url = base_url.to_string();
295 self
296 }
297
298 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
299 self.max_retries = max_retries;
300 self
301 }
302
303 pub fn with_auth(mut self, auth: Auth) -> Self {
304 self.auth = Some(auth);
305 self
306 }
307
308 pub fn with_file_name(mut self, file_name: &str) -> Self {
309 self.file_name = Some(file_name.to_string());
310 self
311 }
312
313 pub fn with_memory_cache(mut self) -> Self {
314 self.use_memory_cache = true;
315 self
316 }
317
318 pub fn build(self) -> Client {
319 let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
320 Box::new(MemoryCache::new())
321 } else {
322 #[cfg(feature = "rusqlite")]
323 {
324 if let Some(file_name) = self.file_name {
325 Box::new(cache::SqliteCache::new(&file_name))
326 } else {
327 Box::new(MemoryCache::new())
328 }
329 }
330 #[cfg(not(feature = "rusqlite"))]
331 {
332 Box::new(MemoryCache::new())
333 }
334 };
335
336 Client {
337 base_url: self.base_url,
338 http_client: reqwest::Client::builder()
339 .timeout(Duration::from_secs(10))
340 .build()
341 .unwrap(),
342 cache: Arc::new(RwLock::new(cache)),
343 max_retries: self.max_retries,
344 circuit_state: RwLock::new(CircuitState {
345 is_open: false,
346 failure_count: 0,
347 last_failure: None,
348 }),
349 auth: self.auth,
350 }
351 }
352}
353
354fn build_local() -> Vec<FeatureFlag> {
355 let mut result = Vec::new();
356
357 for (key, value) in env::vars() {
358 if !key.starts_with("FLAGS_") {
359 continue;
360 }
361
362 let enabled = value == "true";
363 let flag_name_env = key.trim_start_matches("FLAGS_").to_string();
364 let flag_name_lower = flag_name_env.to_lowercase();
365
366 result.push(FeatureFlag {
368 enabled,
369 details: Details {
370 name: flag_name_lower.clone(),
371 id: format!("local_{}", flag_name_lower), },
373 });
374
375 if flag_name_lower.contains('_') {
377 let flag_name_hyphenated = flag_name_lower.replace('_', "-");
378 result.push(FeatureFlag {
379 enabled,
380 details: Details {
381 name: flag_name_hyphenated.clone(),
382 id: format!("local_{}", flag_name_hyphenated),
383 },
384 });
385 }
386
387 if flag_name_lower.contains('_') || flag_name_lower.contains('-') {
388 let flag_name_spaced = flag_name_lower.replace('_', " ").replace('-', " ");
389 result.push(FeatureFlag {
390 enabled,
391 details: Details {
392 name: flag_name_spaced.clone(),
393 id: format!("local_{}", flag_name_spaced),
394 },
395 });
396 }
397
398 }
399
400 result
401}
402