1use std::collections::HashMap;
3use std::env;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use log::__private_api::loc;
11use reqwest::header::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15pub mod cache;
16pub mod flag;
17mod tests;
18
19use crate::cache::{Cache, CacheSystem, MemoryCache};
20use crate::flag::{Details, FeatureFlag};
21
/// Default API endpoint; overridable with `ClientBuilder::with_base_url`.
const BASE_URL: &str = "https://api.flags.gg";
/// Default for `ClientBuilder::with_max_retries`: how many fetch attempts `Client::refetch`
/// makes, and the failure count at which its circuit breaker opens.
const MAX_RETRIES: u32 = 3;
24
/// Credentials identifying which project, agent and environment flags are fetched for.
///
/// Every value is sent as a request header when fetching flags; fetching fails with
/// `FlagError::AuthError` if any of them is empty.
#[derive(Debug, Clone)]
pub struct Auth {
    /// Sent as the `X-Project-ID` header.
    pub project_id: String,
    /// Sent as the `X-Agent-ID` header.
    pub agent_id: String,
    /// Sent as the `X-Environment-ID` header.
    pub environment_id: String,
}
31
/// Lazy handle to one named flag, created by `Client::is`.
///
/// It borrows the client it came from; nothing is evaluated until `enabled()` is awaited.
pub struct Flag<'a> {
    // Flag name exactly as the caller passed it (lower-cased later, during evaluation).
    name: String,
    // Client that resolves the flag.
    client: &'a Client,
}
36
/// Errors produced while fetching or caching flags.
#[derive(Debug, Error)]
pub enum FlagError {
    /// Transport failure or undecodable response body from `reqwest`.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// The cache backend failed; carries the backend's error message.
    #[error("Cache error: {0}")]
    CacheError(String),

    /// Credentials are absent or unusable.
    #[error("Missing authentication: {0}")]
    AuthError(String),

    /// The API answered with a non-success status.
    #[error("API error: {0}")]
    ApiError(String),
}
51
/// Circuit-breaker bookkeeping used by `Client::refetch`.
#[derive(Debug)]
struct CircuitState {
    // While `true`, `refetch` skips the API until 10 seconds have passed since `last_failure`.
    is_open: bool,
    // Failed fetch attempts; reset to 0 on success or when an open circuit closes again.
    failure_count: u32,
    // Wall-clock time (UTC) at which the circuit was last opened.
    last_failure: Option<DateTime<Utc>>,
}
58
/// JSON body returned by `GET {base_url}/flags`.
#[derive(Debug, Deserialize)]
struct ApiResponse {
    // Forwarded to `Cache::refresh`; presumably the allowed refresh interval, but its unit
    // is not visible in this file. NOTE(review): confirm against the cache implementation.
    #[serde(rename = "intervalAllowed")]
    interval_allowed: i32,
    // Remote flag definitions; names are lower-cased before they are cached.
    flags: Vec<flag::FeatureFlag>,
}
65
/// Feature-flag client: fetches remote flags, caches them, and applies local
/// `FLAGS_*` environment overrides. Construct it with [`Client::builder`].
pub struct Client {
    // API root; requests go to `{base_url}/flags`.
    base_url: String,
    // Shared HTTP client (configured with a 10-second timeout by the builder).
    http_client: reqwest::Client,
    // Flag cache backend (in-memory, or SQLite when that feature is enabled).
    // NOTE(review): this is a `std::sync::RwLock` whose guards are held across `.await`
    // in `list`/`is_enabled`/`refetch`; `tokio::sync::RwLock` would avoid `!Send` futures.
    cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
    // Fetch attempts per refetch, and the circuit-breaker failure threshold.
    max_retries: u32,
    // Circuit-breaker state guarding `refetch`.
    circuit_state: RwLock<CircuitState>,
    // Credentials; `None` makes every remote fetch fail with `FlagError::AuthError`.
    auth: Option<Auth>,
}
74
75impl Client {
76 pub fn builder() -> ClientBuilder {
77 ClientBuilder::new()
78 }
79
80 pub fn debug_info(&self) -> String {
81 format!(
82 "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
83 self.base_url, self.max_retries, self.auth
84 )
85 }
86
87 pub fn is(&self, name: &str) -> Flag {
88 Flag {
89 name: name.to_string(),
90 client: self,
91 }
92 }
93
94 pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
95 let cache = self.cache.read().unwrap();
96 let mut flags = cache.get_all().await
97 .map_err(|e| FlagError::CacheError(e.to_string()))?;
98
99 let local_flags = build_local();
100 for (flag, enabled) in local_flags {
101 flags.push(FeatureFlag {
102 enabled,
103 details: Details {
104 name: flag.to_string(),
105 id: format!("local_flag-{}", flag),
106 },
107 })
108 }
109 Ok(flags)
110 }
111
112 async fn is_enabled(&self, name: &str) -> bool {
113 let name = name.to_lowercase();
114
115 {
117 let cache = self.cache.read().unwrap();
118 if cache.should_refresh_cache().await {
119 drop(cache); if let Err(e) = self.refetch().await {
121 error!("Failed to refetch flags: {}", e);
122 return false;
123 }
124 }
125 }
126
127 let local_flags = build_local();
129 if let Some(&enabled) = local_flags.get(&name) {
130 return enabled;
131 }
132
133 let cache = self.cache.read().unwrap();
135 match cache.get(&name).await {
136 Ok((enabled, exists)) => {
137 if exists {
138 enabled
139 } else {
140 false
141 }
142 }
143 Err(_) => false,
144 }
145 }
146
147 async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
148 let auth = match &self.auth {
149 Some(auth) => auth,
150 None => return Err(FlagError::AuthError("Authentication is required".to_string())),
151 };
152
153 if auth.project_id.is_empty() {
154 return Err(FlagError::AuthError("Project ID is required".to_string()));
155 }
156 if auth.agent_id.is_empty() {
157 return Err(FlagError::AuthError("Agent ID is required".to_string()));
158 }
159 if auth.environment_id.is_empty() {
160 return Err(FlagError::AuthError("Environment ID is required".to_string()));
161 }
162
163 let mut headers = HeaderMap::new();
164 headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
165 headers.insert("Accept", HeaderValue::from_static("application/json"));
166 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
167 headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
168 headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
169 headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
170
171 let url = format!("{}/flags", self.base_url);
172 let response = self.http_client
173 .get(&url)
174 .headers(headers)
175 .send()
176 .await?;
177
178 if !response.status().is_success() {
179 return Err(FlagError::ApiError(format!(
180 "Unexpected status code: {}",
181 response.status()
182 )));
183 }
184
185 let api_resp = response.json::<ApiResponse>().await?;
186 Ok(api_resp)
187 }
188
189 async fn refetch(&self) -> Result<(), FlagError> {
190 let mut circuit_state = self.circuit_state.write().unwrap();
191
192 if circuit_state.is_open {
193 if let Some(last_failure) = circuit_state.last_failure {
194 let now = Utc::now();
195 if (now - last_failure).num_seconds() < 10 {
196 return Ok(());
197 }
198 }
199 circuit_state.is_open = false;
200 circuit_state.failure_count = 0;
201 }
202 drop(circuit_state);
203
204 let mut api_resp = None;
205 let mut last_error = None;
206
207 for retry in 0..self.max_retries {
208 match self.fetch_flags().await {
209 Ok(resp) => {
210 api_resp = Some(resp);
211 let mut circuit_state = self.circuit_state.write().unwrap();
212 circuit_state.failure_count = 0;
213 break;
214 }
215 Err(e) => {
216 last_error = Some(e);
217 let mut circuit_state = self.circuit_state.write().unwrap();
218 circuit_state.failure_count += 1;
219
220 if circuit_state.failure_count >= self.max_retries {
221 circuit_state.is_open = true;
222 circuit_state.last_failure = Some(Utc::now());
223 return Ok(());
224 }
225 drop(circuit_state);
226
227 tokio::time::sleep(Duration::from_secs((retry + 1) as u64)).await;
228 }
229 }
230 }
231
232 let api_resp = match api_resp {
233 Some(resp) => resp,
234 None => return Err(last_error.unwrap()),
235 };
236
237 let flags: Vec<flag::FeatureFlag> = api_resp.flags
238 .into_iter()
239 .map(|f| flag::FeatureFlag {
240 enabled: f.enabled,
241 details: flag::Details {
242 name: f.details.name.to_lowercase(),
243 id: f.details.id,
244 },
245 })
246 .collect();
247
248 let mut cache = self.cache.write().unwrap();
249 cache.refresh(&flags, api_resp.interval_allowed).await
250 .map_err(|e| FlagError::CacheError(e.to_string()))?;
251
252 Ok(())
253 }
254}
255
256impl<'a> Flag<'a> {
257 pub async fn enabled(&self) -> bool {
258 self.client.is_enabled(&self.name).await
259 }
260}
261
/// Builder for [`Client`]; obtain one via [`Client::builder`].
pub struct ClientBuilder {
    // API root; defaults to `BASE_URL`.
    base_url: String,
    // Fetch attempts per refetch; defaults to `MAX_RETRIES`.
    max_retries: u32,
    // Credentials; `None` unless `with_auth` is called.
    auth: Option<Auth>,
    // Forces the in-memory cache even when a file name is configured.
    use_memory_cache: bool,
    // SQLite cache file; only consulted when the `rusqlite` feature is enabled.
    file_name: Option<String>,
}
269
270impl ClientBuilder {
271 fn new() -> Self {
272 Self {
273 base_url: BASE_URL.to_string(),
274 max_retries: MAX_RETRIES,
275 auth: None,
276 use_memory_cache: false,
277 file_name: None,
278 }
279 }
280
281 pub fn with_base_url(mut self, base_url: &str) -> Self {
282 self.base_url = base_url.to_string();
283 self
284 }
285
286 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
287 self.max_retries = max_retries;
288 self
289 }
290
291 pub fn with_auth(mut self, auth: Auth) -> Self {
292 self.auth = Some(auth);
293 self
294 }
295
296 pub fn with_file_name(mut self, file_name: &str) -> Self {
297 self.file_name = Some(file_name.to_string());
298 self
299 }
300
301 pub fn with_memory_cache(mut self) -> Self {
302 self.use_memory_cache = true;
303 self
304 }
305
306 pub fn build(self) -> Client {
307 let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
308 Box::new(MemoryCache::new())
309 } else {
310 #[cfg(feature = "rusqlite")]
311 {
312 if let Some(file_name) = self.file_name {
313 Box::new(cache::SqliteCache::new(&file_name))
314 } else {
315 Box::new(MemoryCache::new())
316 }
317 }
318 #[cfg(not(feature = "rusqlite"))]
319 {
320 Box::new(MemoryCache::new())
321 }
322 };
323
324 Client {
325 base_url: self.base_url,
326 http_client: reqwest::Client::builder()
327 .timeout(Duration::from_secs(10))
328 .build()
329 .unwrap(),
330 cache: Arc::new(RwLock::new(cache)),
331 max_retries: self.max_retries,
332 circuit_state: RwLock::new(CircuitState {
333 is_open: false,
334 failure_count: 0,
335 last_failure: None,
336 }),
337 auth: self.auth,
338 }
339 }
340}
341
/// Collects local flag overrides from `FLAGS_*` environment variables.
///
/// For a variable `FLAGS_<NAME>`, the flag is enabled only when the value is exactly
/// `"true"`. `<NAME>` is lower-cased and registered under three spellings, so lookups
/// match whichever separator the caller uses: `FLAGS_MY_FLAG` yields `my_flag`, `my-flag`
/// and `my flag`.
///
/// Only one leading `FLAGS_` is removed, so `FLAGS_FLAGS_X` yields `flags_x`. Variables
/// whose name or value is not valid Unicode are skipped; `std::env::vars` would panic on
/// them, even on ones unrelated to flags.
fn build_local() -> HashMap<String, bool> {
    let mut result = HashMap::new();

    for (raw_key, raw_value) in env::vars_os() {
        let (key, value) = match (raw_key.to_str(), raw_value.to_str()) {
            (Some(key), Some(value)) => (key, value),
            _ => continue,
        };
        let name = match key.strip_prefix("FLAGS_") {
            Some(rest) => rest.to_lowercase(),
            None => continue,
        };

        let enabled = value == "true";
        // All three spellings carry the same value, so their insertion order is irrelevant.
        result.insert(name.replace('_', "-"), enabled);
        result.insert(name.replace('_', " "), enabled);
        result.insert(name, enabled);
    }

    result
}