1use std::collections::HashMap;
3use std::sync::Mutex;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use log::error;
9
10use crate::flag::FeatureFlag;
11
12#[async_trait]
13pub trait Cache {
14 async fn get(&self, name: &str) -> Result<(bool, bool), Box<dyn std::error::Error + Send + Sync>>;
15 async fn get_all(&self) -> Result<Vec<FeatureFlag>, Box<dyn std::error::Error + Send + Sync>>;
16 async fn refresh(&mut self, flags: &[FeatureFlag], interval_allowed: i32) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
17 async fn should_refresh_cache(&self) -> bool;
18 async fn init(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
19}
20
21pub struct CacheSystem {
22 cache: Box<dyn Cache + Send + Sync>,
23}
24
25impl CacheSystem {
26 pub fn new(cache: Box<dyn Cache + Send + Sync>) -> Self {
27 Self { cache }
28 }
29}
30
31pub struct MemoryCache {
32 flags: Mutex<HashMap<String, FeatureFlag>>,
33 cache_ttl: i64,
34 next_refresh: Mutex<DateTime<Utc>>,
35}
36
37impl MemoryCache {
38 pub fn new() -> Self {
39 let cache = Self {
40 flags: Mutex::new(HashMap::new()),
41 cache_ttl: 60,
42 next_refresh: Mutex::new(Utc::now() - chrono::Duration::seconds(90)), };
44
45 cache
46 }
47}
48
49
50#[async_trait]
51impl Cache for MemoryCache {
52 async fn get(&self, name: &str) -> Result<(bool, bool), Box<dyn std::error::Error + Send + Sync>> {
53 let flags = self.flags.lock().unwrap();
54 if let Some(flag) = flags.get(name) {
55 Ok((flag.enabled, true))
56 } else {
57 Ok((false, false))
58 }
59 }
60
61 async fn get_all(&self) -> Result<Vec<FeatureFlag>, Box<dyn std::error::Error + Send + Sync>> {
62 let flags = self.flags.lock().unwrap();
63 let all_flags: Vec<FeatureFlag> = flags.values().cloned().collect();
64 Ok(all_flags)
65 }
66
67 async fn refresh(&mut self, flags: &[FeatureFlag], interval_allowed: i32) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
68 let mut flag_map = self.flags.lock().unwrap();
69 flag_map.clear();
70
71 for flag in flags {
72 flag_map.insert(flag.details.name.clone(), flag.clone());
73 }
74
75 self.cache_ttl = interval_allowed as i64;
76 let mut next_refresh = self.next_refresh.lock().unwrap();
77 *next_refresh = Utc::now() + chrono::Duration::seconds(self.cache_ttl);
78
79 Ok(())
80 }
81
82 async fn should_refresh_cache(&self) -> bool {
83 let next_refresh = self.next_refresh.lock().unwrap();
84 Utc::now() > *next_refresh
85 }
86
87 async fn init(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
88 self.cache_ttl = 60;
89 let mut next_refresh = self.next_refresh.lock().unwrap();
90 *next_refresh = Utc::now() - chrono::Duration::seconds(90);
91 Ok(())
92 }
93}
94
/// Feature-flag cache backed by a SQLite database.
///
/// Only compiled when the `rusqlite` feature is enabled.
#[cfg(feature = "rusqlite")]
pub struct SqliteCache {
    /// Database connection; the mutex serialises all access to it.
    conn: Mutex<rusqlite::Connection>,
    /// Seconds added to "now" when the next refresh time is scheduled.
    cache_ttl: i64,
    /// Instant after which `should_refresh_cache` starts returning `true`.
    next_refresh: Mutex<DateTime<Utc>>,
}
101
#[cfg(feature = "rusqlite")]
impl SqliteCache {
    /// Opens (or creates) the SQLite database at `file_name` and wraps it in a
    /// cache.
    ///
    /// The TTL starts at 60 seconds and the next-refresh time is backdated by
    /// 90 seconds, so a new cache immediately reports that it needs a refresh.
    /// The `flags` table is not created here; that happens in `init`.
    ///
    /// # Panics
    ///
    /// Logs the underlying error and panics if the database cannot be opened.
    pub fn new(file_name: &str) -> Self {
        let conn = match rusqlite::Connection::open(file_name) {
            Ok(conn) => conn,
            Err(e) => {
                error!("Failed to open SQLite database: {}", e);
                panic!("Failed to open SQLite database");
            }
        };

        // Build the backdated timestamp up front. Locking `next_refresh` on an
        // already constructed value and then returning that value would move
        // it while the guard still borrows it, which the borrow checker rejects.
        let already_due = Utc::now() - chrono::Duration::seconds(90);

        Self {
            conn: Mutex::new(conn),
            cache_ttl: 60,
            next_refresh: Mutex::new(already_due),
        }
    }
}
126
#[cfg(feature = "rusqlite")]
#[async_trait]
impl Cache for SqliteCache {
    /// Looks up `name` in the `flags` table.
    ///
    /// Returns `(enabled, true)` when a row exists and `(false, false)`
    /// otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any SQLite error from preparing or running the query, for
    /// example when the `flags` table has not been created yet.
    async fn get(
        &self,
        name: &str,
    ) -> Result<(bool, bool), Box<dyn std::error::Error + Send + Sync>> {
        let conn = self.conn.lock().unwrap();

        let mut stmt = conn.prepare("SELECT enabled FROM flags WHERE name = ?")?;
        let mut rows = stmt.query(&[name])?;

        let lookup = match rows.next()? {
            Some(row) => {
                let enabled: bool = row.get(0)?;
                (enabled, true)
            }
            None => (false, false),
        };
        Ok(lookup)
    }

    /// Reads every row of the `flags` table into a `FeatureFlag`.
    ///
    /// # Errors
    ///
    /// Propagates any SQLite error; the first row that fails to decode aborts
    /// the whole call.
    async fn get_all(&self) -> Result<Vec<FeatureFlag>, Box<dyn std::error::Error + Send + Sync>> {
        let conn = self.conn.lock().unwrap();

        let mut stmt = conn.prepare("SELECT name, id, enabled FROM flags")?;
        let rows = stmt.query_map([], |row| {
            Ok(FeatureFlag {
                enabled: row.get(2)?,
                details: crate::flag::Details {
                    name: row.get(0)?,
                    id: row.get(1)?,
                },
            })
        })?;

        // Collecting into `Result` stops at the first failing row.
        let flags = rows.collect::<Result<Vec<FeatureFlag>, rusqlite::Error>>()?;
        Ok(flags)
    }

    /// Replaces the contents of the `flags` table with `flags` inside a single
    /// transaction, then reschedules the next refresh.
    ///
    /// `interval_allowed` becomes the new TTL in seconds; the next refresh is
    /// due that many seconds from now.
    ///
    /// # Errors
    ///
    /// Propagates any SQLite error. On failure the transaction is dropped
    /// without being committed and the refresh schedule is left untouched.
    /// Because `name` is the table's primary key, two flags with the same
    /// name in `flags` make the insert fail.
    async fn refresh(
        &mut self,
        flags: &[FeatureFlag],
        interval_allowed: i32,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        {
            // `Connection::transaction` takes `&mut self`, so the guard must be
            // bound mutably. The inner scope releases the connection lock
            // before the schedule is touched.
            let mut conn = self.conn.lock().unwrap();

            let tx = conn.transaction()?;
            tx.execute("DELETE FROM flags", [])?;

            for flag in flags {
                // `params!` erases each value to `&dyn ToSql`, which is needed
                // because the three columns have different Rust types.
                tx.execute(
                    "INSERT INTO flags (name, id, enabled) VALUES (?, ?, ?)",
                    rusqlite::params![flag.details.name, flag.details.id, flag.enabled],
                )?;
            }

            tx.commit()?;
        }

        self.cache_ttl = i64::from(interval_allowed);
        let mut next_refresh = self.next_refresh.lock().unwrap();
        *next_refresh = Utc::now() + chrono::Duration::seconds(self.cache_ttl);

        Ok(())
    }

    /// Returns `true` once the current time is past the scheduled refresh time.
    async fn should_refresh_cache(&self) -> bool {
        let next_refresh = self.next_refresh.lock().unwrap();
        Utc::now() > *next_refresh
    }

    /// Creates the `flags` table if it does not exist, resets the TTL to 60
    /// seconds and backdates the next refresh by 90 seconds so the cache
    /// reports itself as due for a refresh.
    ///
    /// # Errors
    ///
    /// Propagates any SQLite error from the `CREATE TABLE` statement.
    async fn init(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        {
            let conn = self.conn.lock().unwrap();

            conn.execute(
                "CREATE TABLE IF NOT EXISTS flags (
                name TEXT PRIMARY KEY,
                id TEXT NOT NULL,
                enabled BOOLEAN NOT NULL
            )",
                [],
            )?;
        }

        self.cache_ttl = 60;
        let mut next_refresh = self.next_refresh.lock().unwrap();
        *next_refresh = Utc::now() - chrono::Duration::seconds(90);

        Ok(())
    }
}