Skip to main content

ferro_cache/stores/
memory.rs

1//! In-memory cache store using moka.
2
3use crate::cache::CacheStore;
4use crate::error::Error;
5use async_trait::async_trait;
6use dashmap::DashMap;
7use moka::future::Cache as MokaCache;
8use moka::policy::Expiry;
9use std::collections::HashSet;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
/// A cached payload bundled with the time-to-live it was stored under.
///
/// Keeping the TTL next to the bytes lets the expiry policy read each
/// entry's own lifetime back out of the value.
#[derive(Clone)]
struct CacheValue {
    /// Raw bytes supplied by the caller.
    data: Vec<u8>,
    /// Lifetime requested for this particular entry.
    ttl: Duration,
}
19
20/// Per-entry expiry policy: each entry expires after its own TTL.
21struct PerEntryExpiry;
22
23impl Expiry<String, CacheValue> for PerEntryExpiry {
24    fn expire_after_create(
25        &self,
26        _key: &String,
27        value: &CacheValue,
28        _created_at: Instant,
29    ) -> Option<Duration> {
30        Some(value.ttl)
31    }
32
33    fn expire_after_update(
34        &self,
35        _key: &String,
36        value: &CacheValue,
37        _updated_at: Instant,
38        _duration_until_expiry: Option<Duration>,
39    ) -> Option<Duration> {
40        Some(value.ttl)
41    }
42
43    fn expire_after_read(
44        &self,
45        _key: &String,
46        _value: &CacheValue,
47        _read_at: Instant,
48        duration_until_expiry: Option<Duration>,
49        _last_modified_at: Instant,
50    ) -> Option<Duration> {
51        duration_until_expiry
52    }
53}
54
55/// In-memory cache store.
56pub struct MemoryStore {
57    cache: MokaCache<String, CacheValue>,
58    tags: Arc<DashMap<String, HashSet<String>>>,
59    counters: MokaCache<String, i64>,
60}
61
62impl Default for MemoryStore {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl MemoryStore {
69    /// Create a new memory store.
70    pub fn new() -> Self {
71        Self::with_capacity(10_000)
72    }
73
74    /// Create with custom capacity.
75    pub fn with_capacity(capacity: u64) -> Self {
76        let tags: Arc<DashMap<String, HashSet<String>>> = Arc::new(DashMap::new());
77        let tags_clone = tags.clone();
78
79        let cache = MokaCache::builder()
80            .max_capacity(capacity)
81            .expire_after(PerEntryExpiry)
82            .eviction_listener(move |key: Arc<String>, _value, _cause| {
83                tags_clone.retain(|_tag, members| {
84                    members.remove(key.as_str());
85                    !members.is_empty()
86                });
87            })
88            .build();
89
90        let counters = MokaCache::builder().max_capacity(capacity).build();
91
92        Self {
93            cache,
94            tags,
95            counters,
96        }
97    }
98}
99
100#[async_trait]
101impl CacheStore for MemoryStore {
102    async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
103        Ok(self.cache.get(key).await.map(|cv| cv.data))
104    }
105
106    async fn put_raw(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<(), Error> {
107        let cv = CacheValue { data: value, ttl };
108        self.cache.insert(key.to_string(), cv).await;
109        Ok(())
110    }
111
112    async fn has(&self, key: &str) -> Result<bool, Error> {
113        Ok(self.cache.contains_key(key))
114    }
115
116    async fn forget(&self, key: &str) -> Result<bool, Error> {
117        let existed = self.cache.contains_key(key);
118        self.cache.remove(key).await;
119        self.counters.remove(key).await;
120        Ok(existed)
121    }
122
123    async fn flush(&self) -> Result<(), Error> {
124        self.cache.invalidate_all();
125        self.tags.clear();
126        self.counters.invalidate_all();
127        Ok(())
128    }
129
130    async fn increment(&self, key: &str, value: i64) -> Result<i64, Error> {
131        let current = self.counters.get(key).await.unwrap_or(0);
132        let new_val = current + value;
133        self.counters.insert(key.to_string(), new_val).await;
134        Ok(new_val)
135    }
136
137    async fn decrement(&self, key: &str, value: i64) -> Result<i64, Error> {
138        let current = self.counters.get(key).await.unwrap_or(0);
139        let new_val = current - value;
140        self.counters.insert(key.to_string(), new_val).await;
141        Ok(new_val)
142    }
143
144    async fn tag_add(&self, tag: &str, key: &str) -> Result<(), Error> {
145        self.tags
146            .entry(tag.to_string())
147            .or_default()
148            .insert(key.to_string());
149        Ok(())
150    }
151
152    async fn tag_members(&self, tag: &str) -> Result<Vec<String>, Error> {
153        let Some(mut entry) = self.tags.get_mut(tag) else {
154            return Ok(Vec::new());
155        };
156        // Lazy cleanup: remove keys no longer present in the cache.
157        entry.retain(|k| self.cache.contains_key(k));
158        let members: Vec<String> = entry.iter().cloned().collect();
159        let is_empty = entry.is_empty();
160        drop(entry);
161        if is_empty {
162            self.tags.remove(tag);
163        }
164        Ok(members)
165    }
166
167    async fn tag_flush(&self, tag: &str) -> Result<(), Error> {
168        if let Some((_, keys)) = self.tags.remove(tag) {
169            for key in keys {
170                self.cache.remove(&key).await;
171            }
172        }
173        Ok(())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL long enough that an entry cannot lapse during a test.
    const LONG_TTL: Duration = Duration::from_secs(60);
    /// TTL short enough to lapse while a test sleeps.
    const SHORT_TTL: Duration = Duration::from_millis(100);
    /// Sleep that outlasts `SHORT_TTL`.
    const PAST_SHORT_TTL: Duration = Duration::from_millis(200);

    /// Store `value` under `key`, panicking if the store reports an error.
    async fn seed(store: &MemoryStore, key: &str, value: &[u8], ttl: Duration) {
        store.put_raw(key, value.to_vec(), ttl).await.unwrap();
    }

    /// A stored value can be read back unchanged.
    #[tokio::test]
    async fn test_memory_store_put_get() {
        let store = MemoryStore::new();
        seed(&store, "key", b"value", LONG_TTL).await;

        let fetched = store.get_raw("key").await.unwrap();
        assert_eq!(fetched, Some(b"value".to_vec()));
    }

    /// `has` is false for unknown keys and true once a value is stored.
    #[tokio::test]
    async fn test_memory_store_has() {
        let store = MemoryStore::new();

        assert!(!store.has("missing").await.unwrap());

        seed(&store, "exists", b"value", LONG_TTL).await;

        assert!(store.has("exists").await.unwrap());
    }

    /// `forget` reports the removal and the key is gone afterwards.
    #[tokio::test]
    async fn test_memory_store_forget() {
        let store = MemoryStore::new();
        seed(&store, "key", b"value", LONG_TTL).await;

        let removed = store.forget("key").await.unwrap();
        assert!(removed);
        assert!(!store.has("key").await.unwrap());
    }

    /// Counters start at zero and accumulate across calls.
    #[tokio::test]
    async fn test_memory_store_increment_decrement() {
        let store = MemoryStore::new();

        let after_first = store.increment("counter", 5).await.unwrap();
        assert_eq!(after_first, 5);

        let after_second = store.increment("counter", 3).await.unwrap();
        assert_eq!(after_second, 8);

        let after_decrement = store.decrement("counter", 2).await.unwrap();
        assert_eq!(after_decrement, 6);
    }

    /// Tagged keys are listed by `tag_members` and evicted by `tag_flush`.
    #[tokio::test]
    async fn test_memory_store_tags() {
        let store = MemoryStore::new();

        seed(&store, "user:1", b"alice", LONG_TTL).await;
        seed(&store, "user:2", b"bob", LONG_TTL).await;

        for key in ["user:1", "user:2"] {
            store.tag_add("users", key).await.unwrap();
        }

        let members = store.tag_members("users").await.unwrap();
        assert_eq!(members.len(), 2);

        store.tag_flush("users").await.unwrap();

        for key in ["user:1", "user:2"] {
            assert!(!store.has(key).await.unwrap());
        }
    }

    /// `flush` empties the value cache.
    #[tokio::test]
    async fn test_memory_store_flush() {
        let store = MemoryStore::new();

        seed(&store, "key1", b"value1", LONG_TTL).await;
        seed(&store, "key2", b"value2", LONG_TTL).await;

        store.flush().await.unwrap();

        for key in ["key1", "key2"] {
            assert!(!store.has(key).await.unwrap());
        }
    }

    /// An entry disappears once its own TTL has elapsed.
    #[tokio::test]
    async fn test_per_entry_ttl_respected() {
        let store = MemoryStore::new();
        seed(&store, "short", b"data", SHORT_TTL).await;

        // Entry exists immediately.
        assert!(store.has("short").await.unwrap());

        tokio::time::sleep(PAST_SHORT_TTL).await;
        store.cache.run_pending_tasks().await;

        // Entry expired after its TTL.
        let fetched = store.get_raw("short").await.unwrap();
        assert!(fetched.is_none());
    }

    /// Tagging the same key twice yields a single membership.
    #[tokio::test]
    async fn test_tag_deduplication() {
        let store = MemoryStore::new();
        seed(&store, "item", b"val", LONG_TTL).await;

        // Add same key to same tag twice.
        for _ in 0..2 {
            store.tag_add("dup-tag", "item").await.unwrap();
        }

        let members = store.tag_members("dup-tag").await.unwrap();
        assert_eq!(members.len(), 1, "duplicate tag entries must be prevented");
    }

    /// Expired keys vanish from their tag, and the emptied tag is pruned.
    #[tokio::test]
    async fn test_eviction_cleans_tags() {
        let store = MemoryStore::new();

        // Insert entries with short TTL and tag them.
        for i in 0..5u64 {
            let key = format!("ephemeral{i}");
            seed(&store, &key, b"v", SHORT_TTL).await;
            store.tag_add("temp", &key).await.unwrap();
        }

        let tagged_before = store.tag_members("temp").await.unwrap();
        assert_eq!(tagged_before.len(), 5);

        // Wait for TTL expiry.
        tokio::time::sleep(PAST_SHORT_TTL).await;
        store.cache.run_pending_tasks().await;

        // tag_members performs lazy cleanup of stale references.
        let members = store.tag_members("temp").await.unwrap();
        assert!(
            members.is_empty(),
            "stale tag references should be cleaned on read, got {} members",
            members.len()
        );

        // Empty tag set should be pruned.
        assert_eq!(store.tags.len(), 0, "empty tag sets should be removed");
    }

    /// Removing a key straight from the moka cache also untags it.
    #[tokio::test]
    async fn test_eviction_listener_on_explicit_remove() {
        let store = MemoryStore::new();

        seed(&store, "tagged-key", b"data", LONG_TTL).await;
        store.tag_add("group", "tagged-key").await.unwrap();

        // Explicit removal triggers the eviction listener.
        store.cache.remove("tagged-key").await;
        store.cache.run_pending_tasks().await;

        // Listener should have cleaned the key from tag sets. Read the raw
        // index so `tag_members`' own lazy cleanup cannot mask a failure.
        let raw_members: Vec<String> = match store.tags.get("group") {
            Some(set) => set.iter().cloned().collect(),
            None => Vec::new(),
        };
        assert!(
            raw_members.is_empty(),
            "eviction listener should remove key from tags on explicit removal"
        );
    }

    /// The counter cache honours the configured capacity.
    #[tokio::test]
    async fn test_counters_bounded() {
        let store = MemoryStore::with_capacity(50);

        // Insert many unique counter keys.
        for i in 0..200u64 {
            let name = format!("c{i}");
            store.increment(&name, 1).await.unwrap();
        }
        store.counters.run_pending_tasks().await;

        let count = store.counters.entry_count();
        assert!(
            count <= 60,
            "counter count should be bounded near capacity, got {count}"
        );
    }
}