1use async_trait::async_trait;
7use moka::sync::Cache;
8use moka::Expiry;
9use std::time::{Duration, Instant};
10
11use super::store::CacheStore;
12use crate::error::FrameworkError;
13
14#[derive(Clone)]
16struct CacheValue {
17 value: String,
18 ttl: Option<Duration>,
19}
20
21struct CacheTtlExpiry;
23
24impl Expiry<String, CacheValue> for CacheTtlExpiry {
25 fn expire_after_create(
26 &self,
27 _key: &String,
28 value: &CacheValue,
29 _created_at: Instant,
30 ) -> Option<Duration> {
31 value.ttl
32 }
33
34 fn expire_after_read(
35 &self,
36 _key: &String,
37 _value: &CacheValue,
38 _read_at: Instant,
39 duration_until_expiry: Option<Duration>,
40 _last_modified_at: Instant,
41 ) -> Option<Duration> {
42 duration_until_expiry
43 }
44
45 fn expire_after_update(
46 &self,
47 _key: &String,
48 value: &CacheValue,
49 _updated_at: Instant,
50 _duration_until_expiry: Option<Duration>,
51 ) -> Option<Duration> {
52 value.ttl
53 }
54}
55
56pub struct InMemoryCache {
69 cache: Cache<String, CacheValue>,
70 prefix: String,
71}
72
73impl InMemoryCache {
74 pub fn new() -> Self {
76 Self::with_capacity(10_000)
77 }
78
79 pub fn with_capacity(capacity: u64) -> Self {
81 Self {
82 cache: Cache::builder()
83 .max_capacity(capacity)
84 .expire_after(CacheTtlExpiry)
85 .build(),
86 prefix: "ferro_cache:".to_string(),
87 }
88 }
89
90 pub fn with_prefix(prefix: impl Into<String>) -> Self {
92 Self {
93 cache: Cache::builder()
94 .max_capacity(10_000)
95 .expire_after(CacheTtlExpiry)
96 .build(),
97 prefix: prefix.into(),
98 }
99 }
100
101 fn prefixed_key(&self, key: &str) -> String {
102 format!("{}{}", self.prefix, key)
103 }
104}
105
106impl Default for InMemoryCache {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112#[async_trait]
113impl CacheStore for InMemoryCache {
114 async fn get_raw(&self, key: &str) -> Result<Option<String>, FrameworkError> {
115 let key = self.prefixed_key(key);
116 Ok(self.cache.get(&key).map(|cv| cv.value))
117 }
118
119 async fn put_raw(
120 &self,
121 key: &str,
122 value: &str,
123 ttl: Option<Duration>,
124 ) -> Result<(), FrameworkError> {
125 let key = self.prefixed_key(key);
126 self.cache.insert(
127 key,
128 CacheValue {
129 value: value.to_string(),
130 ttl,
131 },
132 );
133 Ok(())
134 }
135
136 async fn has(&self, key: &str) -> Result<bool, FrameworkError> {
137 let key = self.prefixed_key(key);
138 Ok(self.cache.contains_key(&key))
139 }
140
141 async fn forget(&self, key: &str) -> Result<bool, FrameworkError> {
142 let key = self.prefixed_key(key);
143 let existed = self.cache.contains_key(&key);
144 self.cache.remove(&key);
145 Ok(existed)
146 }
147
148 async fn flush(&self) -> Result<(), FrameworkError> {
149 self.cache.invalidate_all();
150 Ok(())
151 }
152
153 async fn increment(&self, key: &str, amount: i64) -> Result<i64, FrameworkError> {
154 let key = self.prefixed_key(key);
155
156 let current: i64 = self
157 .cache
158 .get(&key)
159 .and_then(|cv| cv.value.parse().ok())
160 .unwrap_or(0);
161
162 let new_value = current + amount;
163
164 self.cache.insert(
165 key,
166 CacheValue {
167 value: new_value.to_string(),
168 ttl: None,
169 },
170 );
171
172 Ok(new_value)
173 }
174
175 async fn decrement(&self, key: &str, amount: i64) -> Result<i64, FrameworkError> {
176 self.increment(key, -amount).await
177 }
178
179 async fn expire(&self, key: &str, ttl: Duration) -> Result<bool, FrameworkError> {
180 let key = self.prefixed_key(key);
181
182 match self.cache.get(&key) {
183 Some(cv) => {
184 self.cache.insert(
185 key,
186 CacheValue {
187 value: cv.value,
188 ttl: Some(ttl),
189 },
190 );
191 Ok(true)
192 }
193 None => Ok(false),
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_expire_sets_ttl() {
        let store = InMemoryCache::new();
        store.increment("counter", 1).await.unwrap();

        // `expire` on a live key reports success and leaves the value intact.
        let applied = store
            .expire("counter", Duration::from_secs(1))
            .await
            .unwrap();
        assert!(applied, "expire should return true for existing key");
        let before = store.get_raw("counter").await.unwrap();
        assert_eq!(before.as_deref(), Some("1"));

        // Once the TTL has elapsed the key is gone...
        tokio::time::sleep(Duration::from_millis(1100)).await;
        let after = store.get_raw("counter").await.unwrap();
        assert!(after.is_none(), "key should be expired after TTL");

        // ...so counting starts over from zero.
        let restarted = store.increment("counter", 1).await.unwrap();
        assert_eq!(restarted, 1, "increment on expired key should return 1");
    }

    #[tokio::test]
    async fn test_expire_missing_key() {
        let store = InMemoryCache::new();

        let applied = store
            .expire("nonexistent", Duration::from_secs(10))
            .await
            .unwrap();
        assert!(!applied, "expire on missing key should return false");
    }

    #[tokio::test]
    async fn test_increment_then_expire_preserves_value() {
        let store = InMemoryCache::new();

        for _ in 1..=5 {
            store.increment("counter", 1).await.unwrap();
        }

        let applied = store
            .expire("counter", Duration::from_secs(10))
            .await
            .unwrap();
        assert!(applied);

        let total = store.increment("counter", 1).await.unwrap();
        assert_eq!(total, 6, "expire should not reset the value");
    }

    #[tokio::test]
    async fn test_put_get_forget_flush() {
        let store = InMemoryCache::new();

        // Round-trip a present key.
        store.put_raw("key1", "value1", None).await.unwrap();
        let fetched = store.get_raw("key1").await.unwrap();
        assert_eq!(fetched.as_deref(), Some("value1"));
        assert!(store.has("key1").await.unwrap());

        // Absent keys are reported as such.
        assert_eq!(store.get_raw("missing").await.unwrap(), None);
        assert!(!store.has("missing").await.unwrap());

        // `forget` reports whether something was actually removed.
        assert!(store.forget("key1").await.unwrap());
        assert_eq!(store.get_raw("key1").await.unwrap(), None);
        assert!(!store.forget("key1").await.unwrap());

        // `flush` drops every entry.
        for (k, v) in [("a", "1"), ("b", "2")] {
            store.put_raw(k, v, None).await.unwrap();
        }
        store.flush().await.unwrap();
        for k in ["a", "b"] {
            assert!(!store.has(k).await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_capacity_eviction() {
        let store = InMemoryCache::with_capacity(100);

        for n in 0..200u64 {
            let key = format!("key{n}");
            let val = format!("val{n}");
            store.put_raw(&key, &val, None).await.unwrap();
        }

        // Let moka apply pending evictions before measuring the size.
        store.cache.run_pending_tasks();

        let count = store.cache.entry_count();
        assert!(
            count <= 110,
            "cache should be bounded near capacity, got {count}"
        );
    }

    #[tokio::test]
    async fn test_expired_entries_not_returned() {
        let store = InMemoryCache::new();
        let ttl = Some(Duration::from_millis(100));

        store.put_raw("short-lived", "data", ttl).await.unwrap();

        // Visible while the TTL is still running.
        assert!(store.has("short-lived").await.unwrap());
        let fresh = store.get_raw("short-lived").await.unwrap();
        assert_eq!(fresh.as_deref(), Some("data"));

        tokio::time::sleep(Duration::from_millis(200)).await;

        // Invisible to both lookups after it has expired.
        let stale = store.get_raw("short-lived").await.unwrap();
        assert!(stale.is_none(), "expired entry should not be returned");
        let present = store.has("short-lived").await.unwrap();
        assert!(!present, "has() should return false for expired entry");
    }
}