1use crate::cache::CacheStore;
4use crate::error::Error;
5use async_trait::async_trait;
6use dashmap::DashMap;
7use moka::future::Cache as MokaCache;
8use moka::policy::Expiry;
9use std::collections::HashSet;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
/// A cached payload together with the time-to-live it was stored with.
///
/// The TTL lives next to the bytes so the expiry policy can read it back
/// whenever the entry is created or overwritten.
#[derive(Clone)]
struct CacheValue {
    /// Raw serialized bytes handed to `put_raw`.
    data: Vec<u8>,
    /// Lifetime requested for this particular entry.
    ttl: Duration,
}
19
20struct PerEntryExpiry;
22
23impl Expiry<String, CacheValue> for PerEntryExpiry {
24 fn expire_after_create(
25 &self,
26 _key: &String,
27 value: &CacheValue,
28 _created_at: Instant,
29 ) -> Option<Duration> {
30 Some(value.ttl)
31 }
32
33 fn expire_after_update(
34 &self,
35 _key: &String,
36 value: &CacheValue,
37 _updated_at: Instant,
38 _duration_until_expiry: Option<Duration>,
39 ) -> Option<Duration> {
40 Some(value.ttl)
41 }
42
43 fn expire_after_read(
44 &self,
45 _key: &String,
46 _value: &CacheValue,
47 _read_at: Instant,
48 duration_until_expiry: Option<Duration>,
49 _last_modified_at: Instant,
50 ) -> Option<Duration> {
51 duration_until_expiry
52 }
53}
54
55pub struct MemoryStore {
57 cache: MokaCache<String, CacheValue>,
58 tags: Arc<DashMap<String, HashSet<String>>>,
59 counters: MokaCache<String, i64>,
60}
61
62impl Default for MemoryStore {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl MemoryStore {
69 pub fn new() -> Self {
71 Self::with_capacity(10_000)
72 }
73
74 pub fn with_capacity(capacity: u64) -> Self {
76 let tags: Arc<DashMap<String, HashSet<String>>> = Arc::new(DashMap::new());
77 let tags_clone = tags.clone();
78
79 let cache = MokaCache::builder()
80 .max_capacity(capacity)
81 .expire_after(PerEntryExpiry)
82 .eviction_listener(move |key: Arc<String>, _value, _cause| {
83 tags_clone.retain(|_tag, members| {
84 members.remove(key.as_str());
85 !members.is_empty()
86 });
87 })
88 .build();
89
90 let counters = MokaCache::builder().max_capacity(capacity).build();
91
92 Self {
93 cache,
94 tags,
95 counters,
96 }
97 }
98}
99
100#[async_trait]
101impl CacheStore for MemoryStore {
102 async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
103 Ok(self.cache.get(key).await.map(|cv| cv.data))
104 }
105
106 async fn put_raw(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<(), Error> {
107 let cv = CacheValue { data: value, ttl };
108 self.cache.insert(key.to_string(), cv).await;
109 Ok(())
110 }
111
112 async fn has(&self, key: &str) -> Result<bool, Error> {
113 Ok(self.cache.contains_key(key))
114 }
115
116 async fn forget(&self, key: &str) -> Result<bool, Error> {
117 let existed = self.cache.contains_key(key);
118 self.cache.remove(key).await;
119 self.counters.remove(key).await;
120 Ok(existed)
121 }
122
123 async fn flush(&self) -> Result<(), Error> {
124 self.cache.invalidate_all();
125 self.tags.clear();
126 self.counters.invalidate_all();
127 Ok(())
128 }
129
130 async fn increment(&self, key: &str, value: i64) -> Result<i64, Error> {
131 let current = self.counters.get(key).await.unwrap_or(0);
132 let new_val = current + value;
133 self.counters.insert(key.to_string(), new_val).await;
134 Ok(new_val)
135 }
136
137 async fn decrement(&self, key: &str, value: i64) -> Result<i64, Error> {
138 let current = self.counters.get(key).await.unwrap_or(0);
139 let new_val = current - value;
140 self.counters.insert(key.to_string(), new_val).await;
141 Ok(new_val)
142 }
143
144 async fn tag_add(&self, tag: &str, key: &str) -> Result<(), Error> {
145 self.tags
146 .entry(tag.to_string())
147 .or_default()
148 .insert(key.to_string());
149 Ok(())
150 }
151
152 async fn tag_members(&self, tag: &str) -> Result<Vec<String>, Error> {
153 let Some(mut entry) = self.tags.get_mut(tag) else {
154 return Ok(Vec::new());
155 };
156 entry.retain(|k| self.cache.contains_key(k));
158 let members: Vec<String> = entry.iter().cloned().collect();
159 let is_empty = entry.is_empty();
160 drop(entry);
161 if is_empty {
162 self.tags.remove(tag);
163 }
164 Ok(members)
165 }
166
167 async fn tag_flush(&self, tag: &str) -> Result<(), Error> {
168 if let Some((_, keys)) = self.tags.remove(tag) {
169 for key in keys {
170 self.cache.remove(&key).await;
171 }
172 }
173 Ok(())
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    // A stored value can be read back unchanged.
    #[tokio::test]
    async fn test_memory_store_put_get() {
        let store = MemoryStore::new();

        store
            .put_raw("key", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let value = store.get_raw("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
    }

    // `has` is false for an unknown key and true once a value is stored.
    #[tokio::test]
    async fn test_memory_store_has() {
        let store = MemoryStore::new();

        assert!(!store.has("missing").await.unwrap());

        store
            .put_raw("exists", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        assert!(store.has("exists").await.unwrap());
    }

    // `forget` reports a successful removal and the key is gone afterwards.
    #[tokio::test]
    async fn test_memory_store_forget() {
        let store = MemoryStore::new();

        store
            .put_raw("key", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let removed = store.forget("key").await.unwrap();
        assert!(removed);
        assert!(!store.has("key").await.unwrap());
    }

    // Counters start at 0 and accumulate across increment/decrement calls.
    #[tokio::test]
    async fn test_memory_store_increment_decrement() {
        let store = MemoryStore::new();

        let val = store.increment("counter", 5).await.unwrap();
        assert_eq!(val, 5);

        let val = store.increment("counter", 3).await.unwrap();
        assert_eq!(val, 8);

        let val = store.decrement("counter", 2).await.unwrap();
        assert_eq!(val, 6);
    }

    // Tagged keys are listed by `tag_members` and removed by `tag_flush`.
    #[tokio::test]
    async fn test_memory_store_tags() {
        let store = MemoryStore::new();

        store
            .put_raw("user:1", b"alice".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        store
            .put_raw("user:2", b"bob".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        store.tag_add("users", "user:1").await.unwrap();
        store.tag_add("users", "user:2").await.unwrap();

        let members = store.tag_members("users").await.unwrap();
        assert_eq!(members.len(), 2);

        store.tag_flush("users").await.unwrap();

        assert!(!store.has("user:1").await.unwrap());
        assert!(!store.has("user:2").await.unwrap());
    }

    // `flush` removes every stored value.
    #[tokio::test]
    async fn test_memory_store_flush() {
        let store = MemoryStore::new();

        store
            .put_raw("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        store
            .put_raw("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        store.flush().await.unwrap();

        assert!(!store.has("key1").await.unwrap());
        assert!(!store.has("key2").await.unwrap());
    }

    // A short per-entry TTL makes the value disappear once it has elapsed.
    #[tokio::test]
    async fn test_per_entry_ttl_respected() {
        let store = MemoryStore::new();

        store
            .put_raw("short", b"data".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        // Still present immediately after insertion.
        assert!(store.has("short").await.unwrap());

        // Wait past the 100 ms TTL, then let moka process its pending work.
        tokio::time::sleep(Duration::from_millis(200)).await;
        store.cache.run_pending_tasks().await;

        assert!(store.get_raw("short").await.unwrap().is_none());
    }

    // Tagging the same key twice must yield a single membership.
    #[tokio::test]
    async fn test_tag_deduplication() {
        let store = MemoryStore::new();

        store
            .put_raw("item", b"val".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // Second call is deliberately identical to the first.
        store.tag_add("dup-tag", "item").await.unwrap();
        store.tag_add("dup-tag", "item").await.unwrap();

        let members = store.tag_members("dup-tag").await.unwrap();
        assert_eq!(members.len(), 1, "duplicate tag entries must be prevented");
    }

    // Expired keys drop out of their tag, and an emptied tag is removed entirely.
    #[tokio::test]
    async fn test_eviction_cleans_tags() {
        let store = MemoryStore::new();

        for i in 0..5u64 {
            // Five short-lived keys, all under the same tag.
            let key = format!("ephemeral{i}");
            store
                .put_raw(&key, b"v".to_vec(), Duration::from_millis(100))
                .await
                .unwrap();
            store.tag_add("temp", &key).await.unwrap();
        }

        assert_eq!(store.tag_members("temp").await.unwrap().len(), 5);

        // Outlive the TTL and flush moka's pending expirations/notifications.
        tokio::time::sleep(Duration::from_millis(200)).await;
        store.cache.run_pending_tasks().await;

        let members = store.tag_members("temp").await.unwrap();
        assert!(
            members.is_empty(),
            "stale tag references should be cleaned on read, got {} members",
            members.len()
        );

        assert_eq!(store.tags.len(), 0, "empty tag sets should be removed");
    }

    // Removing a key directly from the moka cache (bypassing `forget`) still
    // prunes it from its tags via the eviction listener.
    #[tokio::test]
    async fn test_eviction_listener_on_explicit_remove() {
        let store = MemoryStore::new();

        store
            .put_raw("tagged-key", b"data".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        store.tag_add("group", "tagged-key").await.unwrap();

        store.cache.remove("tagged-key").await;
        store.cache.run_pending_tasks().await;

        // Inspect the raw tag map rather than `tag_members`, which would prune
        // stale keys itself and so hide a listener failure.
        let raw_members: Vec<String> = store
            .tags
            .get("group")
            .map(|s| s.iter().cloned().collect())
            .unwrap_or_default();
        assert!(
            raw_members.is_empty(),
            "eviction listener should remove key from tags on explicit removal"
        );
    }

    // The counter cache honours the configured capacity.
    #[tokio::test]
    async fn test_counters_bounded() {
        let store = MemoryStore::with_capacity(50);

        for i in 0..200u64 {
            // Four times the capacity, each under a distinct key.
            store.increment(&format!("c{i}"), 1).await.unwrap();
        }
        store.counters.run_pending_tasks().await;

        // Allow some slack above 50 since capacity enforcement is approximate.
        let count = store.counters.entry_count();
        assert!(
            count <= 60,
            "counter count should be bounded near capacity, got {count}"
        );
    }
}