1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::RwLock;
9use std::time::{Duration, Instant};
10
11use super::store::CacheStore;
12use crate::error::FrameworkError;
13
/// A single cached value together with its optional expiry deadline.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// The raw string value as passed to `put_raw` or written by `increment`.
    value: String,
    /// Instant at which the entry stops being visible; `None` means it never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Returns `true` once the current time has reached `expires_at`.
    ///
    /// The deadline itself counts as expired (`>=`), so an entry written with a
    /// zero TTL reads as expired immediately instead of depending on whether the
    /// monotonic clock has ticked since it was stored.
    fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(deadline) if Instant::now() >= deadline)
    }
}
26
27pub struct InMemoryCache {
40 store: RwLock<HashMap<String, CacheEntry>>,
41 prefix: String,
42}
43
44impl InMemoryCache {
45 pub fn new() -> Self {
47 Self {
48 store: RwLock::new(HashMap::new()),
49 prefix: "ferro_cache:".to_string(),
50 }
51 }
52
53 pub fn with_prefix(prefix: impl Into<String>) -> Self {
55 Self {
56 store: RwLock::new(HashMap::new()),
57 prefix: prefix.into(),
58 }
59 }
60
61 fn prefixed_key(&self, key: &str) -> String {
62 format!("{}{}", self.prefix, key)
63 }
64}
65
66impl Default for InMemoryCache {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72#[async_trait]
73impl CacheStore for InMemoryCache {
74 async fn get_raw(&self, key: &str) -> Result<Option<String>, FrameworkError> {
75 let key = self.prefixed_key(key);
76
77 let store = self
78 .store
79 .read()
80 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
81
82 match store.get(&key) {
83 Some(entry) if !entry.is_expired() => Ok(Some(entry.value.clone())),
84 _ => Ok(None),
85 }
86 }
87
88 async fn put_raw(
89 &self,
90 key: &str,
91 value: &str,
92 ttl: Option<Duration>,
93 ) -> Result<(), FrameworkError> {
94 let key = self.prefixed_key(key);
95
96 let entry = CacheEntry {
97 value: value.to_string(),
98 expires_at: ttl.map(|d| Instant::now() + d),
99 };
100
101 let mut store = self
102 .store
103 .write()
104 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
105
106 store.insert(key, entry);
107 Ok(())
108 }
109
110 async fn has(&self, key: &str) -> Result<bool, FrameworkError> {
111 let key = self.prefixed_key(key);
112
113 let store = self
114 .store
115 .read()
116 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
117
118 Ok(store.get(&key).map(|e| !e.is_expired()).unwrap_or(false))
119 }
120
121 async fn forget(&self, key: &str) -> Result<bool, FrameworkError> {
122 let key = self.prefixed_key(key);
123
124 let mut store = self
125 .store
126 .write()
127 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
128
129 Ok(store.remove(&key).is_some())
130 }
131
132 async fn flush(&self) -> Result<(), FrameworkError> {
133 let mut store = self
134 .store
135 .write()
136 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
137
138 store.clear();
139 Ok(())
140 }
141
142 async fn increment(&self, key: &str, amount: i64) -> Result<i64, FrameworkError> {
143 let key = self.prefixed_key(key);
144
145 let mut store = self
146 .store
147 .write()
148 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
149
150 let current: i64 = store
151 .get(&key)
152 .filter(|e| !e.is_expired())
153 .and_then(|e| e.value.parse().ok())
154 .unwrap_or(0);
155
156 let new_value = current + amount;
157
158 store.insert(
159 key,
160 CacheEntry {
161 value: new_value.to_string(),
162 expires_at: None,
163 },
164 );
165
166 Ok(new_value)
167 }
168
169 async fn decrement(&self, key: &str, amount: i64) -> Result<i64, FrameworkError> {
170 self.increment(key, -amount).await
171 }
172
173 async fn expire(&self, key: &str, ttl: Duration) -> Result<bool, FrameworkError> {
174 let key = self.prefixed_key(key);
175
176 let mut store = self
177 .store
178 .write()
179 .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
180
181 if let Some(entry) = store.get_mut(&key) {
182 entry.expires_at = Some(Instant::now() + ttl);
183 Ok(true)
184 } else {
185 Ok(false)
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// A counter given a 1s TTL stays readable until the TTL elapses. After
    /// that it reads as missing, and `increment` restarts from zero.
    #[tokio::test]
    async fn test_expire_sets_ttl() {
        let cache = InMemoryCache::new();
        cache.increment("counter", 1).await.unwrap();

        let applied = cache
            .expire("counter", Duration::from_secs(1))
            .await
            .unwrap();
        assert!(applied, "expire should return true for existing key");

        let before = cache.get_raw("counter").await.unwrap();
        assert_eq!(before.as_deref(), Some("1"));

        // Sleep slightly past the 1s TTL so the deadline has definitely passed.
        tokio::time::sleep(Duration::from_millis(1100)).await;

        let after = cache.get_raw("counter").await.unwrap();
        assert!(after.is_none(), "key should be expired after TTL");

        let restarted = cache.increment("counter", 1).await.unwrap();
        assert_eq!(restarted, 1, "increment on expired key should return 1");
    }

    /// `expire` reports `false` when there is nothing to expire.
    #[tokio::test]
    async fn test_expire_missing_key() {
        let cache = InMemoryCache::new();

        let applied = cache
            .expire("nonexistent", Duration::from_secs(10))
            .await
            .unwrap();
        assert!(!applied, "expire on missing key should return false");
    }

    /// Setting a TTL on a counter must not reset the count accumulated so far.
    #[tokio::test]
    async fn test_increment_then_expire_preserves_value() {
        let cache = InMemoryCache::new();

        for _step in 1..=5 {
            cache.increment("counter", 1).await.unwrap();
        }

        let applied = cache
            .expire("counter", Duration::from_secs(10))
            .await
            .unwrap();
        assert!(applied);

        let total = cache.increment("counter", 1).await.unwrap();
        assert_eq!(total, 6, "expire should not reset the value");
    }
}