1use std::collections::HashMap;
3use std::env;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use log::{error, warn};
10use reqwest::header::{HeaderMap, HeaderValue};
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
14pub mod cache;
15pub mod flag;
16mod tests;
17
18use crate::cache::{Cache, CacheSystem, MemoryCache};
19
/// Default API root used by `ClientBuilder` when no base URL is configured.
const BASE_URL: &'static str = "https://api.flags.gg";
/// Default number of fetch attempts per refresh; also the consecutive-failure
/// count at which the circuit breaker opens.
const MAX_RETRIES: u32 = 3;
22
/// Credentials attached to every flags API request.
///
/// Each field is sent as its own request header; all three must be non-empty
/// or fetching fails with an authentication error.
#[derive(Clone, Debug)]
pub struct Auth {
    /// Project identifier, sent as `X-Project-ID`.
    pub project_id: String,
    /// Agent identifier, sent as `X-Agent-ID`.
    pub agent_id: String,
    /// Environment identifier, sent as `X-Environment-ID`.
    pub environment_id: String,
}
29
30pub struct Flag<'a> {
31 name: String,
32 client: &'a Client,
33}
34
35#[derive(Debug, Error)]
36pub enum FlagError {
37 #[error("HTTP error: {0}")]
38 HttpError(#[from] reqwest::Error),
39
40 #[error("Cache error: {0}")]
41 CacheError(String),
42
43 #[error("Missing authentication: {0}")]
44 AuthError(String),
45
46 #[error("API error: {0}")]
47 ApiError(String),
48}
49
50#[derive(Debug)]
51struct CircuitState {
52 is_open: bool,
53 failure_count: u32,
54 last_failure: Option<DateTime<Utc>>,
55}
56
57#[derive(Debug, Deserialize)]
58struct ApiResponse {
59 #[serde(rename = "intervalAllowed")]
60 interval_allowed: i32,
61 flags: Vec<flag::FeatureFlag>,
62}
63
64pub struct Client {
65 base_url: String,
66 http_client: reqwest::Client,
67 cache: Arc<RwLock<Box<dyn Cache + Send + Sync>>>,
68 max_retries: u32,
69 circuit_state: RwLock<CircuitState>,
70 auth: Option<Auth>,
71}
72
73impl Client {
74 pub fn builder() -> ClientBuilder {
75 ClientBuilder::new()
76 }
77
78 pub fn debug_info(&self) -> String {
79 format!(
80 "Client {{ base_url: {}, max_retries: {}, auth: {:?} }}",
81 self.base_url, self.max_retries, self.auth
82 )
83 }
84
85 pub fn is(&self, name: &str) -> Flag {
86 Flag {
87 name: name.to_string(),
88 client: self,
89 }
90 }
91
92 pub async fn list(&self) -> Result<Vec<flag::FeatureFlag>, FlagError> {
93 let cache = self.cache.read().unwrap();
94 let flags = cache.get_all().await
95 .map_err(|e| FlagError::CacheError(e.to_string()))?;
96 Ok(flags)
97 }
98
99 async fn is_enabled(&self, name: &str) -> bool {
100 let name = name.to_lowercase();
101
102 {
104 let cache = self.cache.read().unwrap();
105 if cache.should_refresh_cache().await {
106 drop(cache); if let Err(e) = self.refetch().await {
108 error!("Failed to refetch flags: {}", e);
109 return false;
110 }
111 }
112 }
113
114 let local_flags = build_local();
116 if let Some(&enabled) = local_flags.get(&name) {
117 return enabled;
118 }
119
120 let cache = self.cache.read().unwrap();
122 match cache.get(&name).await {
123 Ok((enabled, exists)) => {
124 if exists {
125 enabled
126 } else {
127 false
128 }
129 }
130 Err(_) => false,
131 }
132 }
133
134 async fn fetch_flags(&self) -> Result<ApiResponse, FlagError> {
135 let auth = match &self.auth {
136 Some(auth) => auth,
137 None => return Err(FlagError::AuthError("Authentication is required".to_string())),
138 };
139
140 if auth.project_id.is_empty() {
141 return Err(FlagError::AuthError("Project ID is required".to_string()));
142 }
143 if auth.agent_id.is_empty() {
144 return Err(FlagError::AuthError("Agent ID is required".to_string()));
145 }
146 if auth.environment_id.is_empty() {
147 return Err(FlagError::AuthError("Environment ID is required".to_string()));
148 }
149
150 let mut headers = HeaderMap::new();
151 headers.insert("User-Agent", HeaderValue::from_static("Flags-Rust"));
152 headers.insert("Accept", HeaderValue::from_static("application/json"));
153 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
154 headers.insert("X-Project-ID", HeaderValue::from_str(&auth.project_id).unwrap());
155 headers.insert("X-Agent-ID", HeaderValue::from_str(&auth.agent_id).unwrap());
156 headers.insert("X-Environment-ID", HeaderValue::from_str(&auth.environment_id).unwrap());
157
158 let url = format!("{}/flags", self.base_url);
159 let response = self.http_client
160 .get(&url)
161 .headers(headers)
162 .send()
163 .await?;
164
165 if !response.status().is_success() {
166 return Err(FlagError::ApiError(format!(
167 "Unexpected status code: {}",
168 response.status()
169 )));
170 }
171
172 let api_resp = response.json::<ApiResponse>().await?;
173 Ok(api_resp)
174 }
175
176 async fn refetch(&self) -> Result<(), FlagError> {
177 let mut circuit_state = self.circuit_state.write().unwrap();
178
179 if circuit_state.is_open {
180 if let Some(last_failure) = circuit_state.last_failure {
181 let now = Utc::now();
182 if (now - last_failure).num_seconds() < 10 {
183 return Ok(());
184 }
185 }
186 circuit_state.is_open = false;
187 circuit_state.failure_count = 0;
188 }
189 drop(circuit_state);
190
191 let mut api_resp = None;
192 let mut last_error = None;
193
194 for retry in 0..self.max_retries {
195 match self.fetch_flags().await {
196 Ok(resp) => {
197 api_resp = Some(resp);
198 let mut circuit_state = self.circuit_state.write().unwrap();
199 circuit_state.failure_count = 0;
200 break;
201 }
202 Err(e) => {
203 last_error = Some(e);
204 let mut circuit_state = self.circuit_state.write().unwrap();
205 circuit_state.failure_count += 1;
206
207 if circuit_state.failure_count >= self.max_retries {
208 circuit_state.is_open = true;
209 circuit_state.last_failure = Some(Utc::now());
210 return Ok(());
211 }
212 drop(circuit_state);
213
214 tokio::time::sleep(Duration::from_secs((retry + 1) as u64)).await;
215 }
216 }
217 }
218
219 let api_resp = match api_resp {
220 Some(resp) => resp,
221 None => return Err(last_error.unwrap()),
222 };
223
224 let flags: Vec<flag::FeatureFlag> = api_resp.flags
225 .into_iter()
226 .map(|f| flag::FeatureFlag {
227 enabled: f.enabled,
228 details: flag::Details {
229 name: f.details.name.to_lowercase(),
230 id: f.details.id,
231 },
232 })
233 .collect();
234
235 let mut cache = self.cache.write().unwrap();
236 cache.refresh(&flags, api_resp.interval_allowed).await
237 .map_err(|e| FlagError::CacheError(e.to_string()))?;
238
239 Ok(())
240 }
241}
242
243impl<'a> Flag<'a> {
244 pub async fn enabled(&self) -> bool {
245 self.client.is_enabled(&self.name).await
246 }
247}
248
249pub struct ClientBuilder {
250 base_url: String,
251 max_retries: u32,
252 auth: Option<Auth>,
253 use_memory_cache: bool,
254 file_name: Option<String>,
255}
256
257impl ClientBuilder {
258 fn new() -> Self {
259 Self {
260 base_url: BASE_URL.to_string(),
261 max_retries: MAX_RETRIES,
262 auth: None,
263 use_memory_cache: false,
264 file_name: None,
265 }
266 }
267
268 pub fn with_base_url(mut self, base_url: &str) -> Self {
269 self.base_url = base_url.to_string();
270 self
271 }
272
273 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
274 self.max_retries = max_retries;
275 self
276 }
277
278 pub fn with_auth(mut self, auth: Auth) -> Self {
279 self.auth = Some(auth);
280 self
281 }
282
283 pub fn with_file_name(mut self, file_name: &str) -> Self {
284 self.file_name = Some(file_name.to_string());
285 self
286 }
287
288 pub fn with_memory_cache(mut self) -> Self {
289 self.use_memory_cache = true;
290 self
291 }
292
293 pub fn build(self) -> Client {
294 let cache: Box<dyn Cache + Send + Sync> = if self.use_memory_cache {
295 Box::new(MemoryCache::new())
296 } else {
297 #[cfg(feature = "rusqlite")]
298 {
299 if let Some(file_name) = self.file_name {
300 Box::new(cache::SqliteCache::new(&file_name))
301 } else {
302 Box::new(MemoryCache::new())
303 }
304 }
305 #[cfg(not(feature = "rusqlite"))]
306 {
307 Box::new(MemoryCache::new())
308 }
309 };
310
311 Client {
312 base_url: self.base_url,
313 http_client: reqwest::Client::builder()
314 .timeout(Duration::from_secs(10))
315 .build()
316 .unwrap(),
317 cache: Arc::new(RwLock::new(cache)),
318 max_retries: self.max_retries,
319 circuit_state: RwLock::new(CircuitState {
320 is_open: false,
321 failure_count: 0,
322 last_failure: None,
323 }),
324 auth: self.auth,
325 }
326 }
327}
328
/// Collects local flag overrides from `FLAGS_*` environment variables.
///
/// For each variable `FLAGS_<NAME>`, the flag is enabled only when the value is
/// exactly `"true"`; any other value disables it. `<NAME>` is lowercased and stored
/// under three spellings so lookups match regardless of separator:
/// `my_flag`, `my-flag` and `my flag`.
///
/// Only one leading `FLAGS_` is removed, so `FLAGS_FLAGS_X` defines `flags_x`.
/// Variables whose name or value is not valid Unicode are skipped; `env::vars()`
/// would panic on them.
fn build_local() -> HashMap<String, bool> {
    let mut result = HashMap::new();

    for (key, value) in env::vars_os() {
        let (key, value) = match (key.into_string(), value.into_string()) {
            (Ok(key), Ok(value)) => (key, value),
            _ => continue,
        };
        let flag_name = match key.strip_prefix("FLAGS_") {
            Some(rest) => rest.to_lowercase(),
            None => continue,
        };
        let enabled = value == "true";

        result.insert(flag_name.replace('_', "-"), enabled);
        result.insert(flag_name.replace('_', " "), enabled);
        result.insert(flag_name, enabled);
    }

    result
}