ferro_cache/stores/
memory.rs

1//! In-memory cache store using moka.
2
3use crate::cache::CacheStore;
4use crate::error::Error;
5use async_trait::async_trait;
6use dashmap::DashMap;
7use moka::future::Cache as MokaCache;
8use std::sync::Arc;
9use std::time::Duration;
10
11/// In-memory cache store.
12pub struct MemoryStore {
13    cache: MokaCache<String, Vec<u8>>,
14    tags: Arc<DashMap<String, Vec<String>>>,
15    counters: Arc<DashMap<String, i64>>,
16}
17
18impl Default for MemoryStore {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl MemoryStore {
25    /// Create a new memory store.
26    pub fn new() -> Self {
27        Self {
28            cache: MokaCache::builder().max_capacity(10_000).build(),
29            tags: Arc::new(DashMap::new()),
30            counters: Arc::new(DashMap::new()),
31        }
32    }
33
34    /// Create with custom capacity.
35    pub fn with_capacity(capacity: u64) -> Self {
36        Self {
37            cache: MokaCache::builder().max_capacity(capacity).build(),
38            tags: Arc::new(DashMap::new()),
39            counters: Arc::new(DashMap::new()),
40        }
41    }
42}
43
44#[async_trait]
45impl CacheStore for MemoryStore {
46    async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
47        Ok(self.cache.get(key).await)
48    }
49
50    async fn put_raw(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<(), Error> {
51        self.cache.insert(key.to_string(), value).await;
52
53        // Moka handles TTL through expiration, but we need to set it per-entry
54        // For simplicity, we'll use the builder-level TTL
55        // In production, you'd want per-entry TTL using a wrapper or different approach
56        let _ = ttl; // TTL is handled at cache level for moka
57
58        Ok(())
59    }
60
61    async fn has(&self, key: &str) -> Result<bool, Error> {
62        Ok(self.cache.contains_key(key))
63    }
64
65    async fn forget(&self, key: &str) -> Result<bool, Error> {
66        let existed = self.cache.contains_key(key);
67        self.cache.remove(key).await;
68        self.counters.remove(key);
69        Ok(existed)
70    }
71
72    async fn flush(&self) -> Result<(), Error> {
73        self.cache.invalidate_all();
74        self.tags.clear();
75        self.counters.clear();
76        Ok(())
77    }
78
79    async fn increment(&self, key: &str, value: i64) -> Result<i64, Error> {
80        let mut entry = self.counters.entry(key.to_string()).or_insert(0);
81        *entry += value;
82        Ok(*entry)
83    }
84
85    async fn decrement(&self, key: &str, value: i64) -> Result<i64, Error> {
86        let mut entry = self.counters.entry(key.to_string()).or_insert(0);
87        *entry -= value;
88        Ok(*entry)
89    }
90
91    async fn tag_add(&self, tag: &str, key: &str) -> Result<(), Error> {
92        self.tags
93            .entry(tag.to_string())
94            .or_default()
95            .push(key.to_string());
96        Ok(())
97    }
98
99    async fn tag_members(&self, tag: &str) -> Result<Vec<String>, Error> {
100        Ok(self.tags.get(tag).map(|v| v.clone()).unwrap_or_default())
101    }
102
103    async fn tag_flush(&self, tag: &str) -> Result<(), Error> {
104        if let Some((_, keys)) = self.tags.remove(tag) {
105            for key in keys {
106                self.cache.remove(&key).await;
107            }
108        }
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL attached to every entry written by these tests.
    const TTL: Duration = Duration::from_secs(60);

    /// Store `value` under `key`, panicking if the store reports an error.
    async fn put(store: &MemoryStore, key: &str, value: &str) {
        store
            .put_raw(key, value.as_bytes().to_vec(), TTL)
            .await
            .unwrap();
    }

    /// A value written with `put_raw` is returned unchanged by `get_raw`.
    #[tokio::test]
    async fn test_memory_store_put_get() {
        let store = MemoryStore::new();
        put(&store, "key", "value").await;

        let fetched = store.get_raw("key").await.unwrap();
        assert_eq!(fetched, Some(b"value".to_vec()));
    }

    /// `has` is false for an unknown key and true once the key is written.
    #[tokio::test]
    async fn test_memory_store_has() {
        let store = MemoryStore::new();
        assert!(!store.has("missing").await.unwrap());

        put(&store, "exists", "value").await;
        assert!(store.has("exists").await.unwrap());
    }

    /// `forget` reports the removal and the key is gone afterwards.
    #[tokio::test]
    async fn test_memory_store_forget() {
        let store = MemoryStore::new();
        put(&store, "key", "value").await;

        let removed = store.forget("key").await.unwrap();
        assert!(removed);
        assert!(!store.has("key").await.unwrap());
    }

    /// Counters start at zero and accumulate increments and decrements.
    #[tokio::test]
    async fn test_memory_store_increment_decrement() {
        let store = MemoryStore::new();

        let after_first = store.increment("counter", 5).await.unwrap();
        assert_eq!(after_first, 5);

        let after_second = store.increment("counter", 3).await.unwrap();
        assert_eq!(after_second, 8);

        let after_decrement = store.decrement("counter", 2).await.unwrap();
        assert_eq!(after_decrement, 6);
    }

    /// Flushing a tag removes every cached value registered under it.
    #[tokio::test]
    async fn test_memory_store_tags() {
        let store = MemoryStore::new();
        put(&store, "user:1", "alice").await;
        put(&store, "user:2", "bob").await;

        store.tag_add("users", "user:1").await.unwrap();
        store.tag_add("users", "user:2").await.unwrap();

        let members = store.tag_members("users").await.unwrap();
        assert_eq!(members.len(), 2);

        store.tag_flush("users").await.unwrap();

        assert!(!store.has("user:1").await.unwrap());
        assert!(!store.has("user:2").await.unwrap());
    }

    /// `flush` empties the whole store.
    #[tokio::test]
    async fn test_memory_store_flush() {
        let store = MemoryStore::new();
        put(&store, "key1", "value1").await;
        put(&store, "key2", "value2").await;

        store.flush().await.unwrap();

        assert!(!store.has("key1").await.unwrap());
        assert!(!store.has("key2").await.unwrap());
    }
}