1use std::{
2 collections::BTreeMap,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use async_trait::async_trait;
8use tokio::sync::RwLock;
9
10use crate::cache::{CacheKey, CacheResult, CacheStore};
11
/// A single cached value together with its optional expiry deadline.
#[derive(Clone, Debug)]
struct Entry {
    /// Raw bytes stored for the key.
    value: Vec<u8>,
    /// Instant at (or after) which the entry is treated as expired;
    /// `None` means the entry never expires.
    expires_at: Option<Instant>,
}
17
18#[derive(Debug, Clone, Default)]
20pub struct MemoryCacheStore {
21 entries: Arc<RwLock<BTreeMap<String, Entry>>>,
22}
23
24impl MemoryCacheStore {
25 pub fn new() -> Self {
27 Self::default()
28 }
29}
30
31#[async_trait]
32impl CacheStore for MemoryCacheStore {
33 async fn get_raw(&self, key: &CacheKey) -> CacheResult<Option<Vec<u8>>> {
34 let rendered = key.render();
35 let mut entries = self.entries.write().await;
36 let Some(entry) = entries.get(&rendered) else {
37 return Ok(None);
38 };
39 if entry
40 .expires_at
41 .is_some_and(|deadline| deadline <= Instant::now())
42 {
43 entries.remove(&rendered);
44 return Ok(None);
45 }
46 Ok(Some(entry.value.clone()))
47 }
48
49 async fn set_raw(
50 &self,
51 key: &CacheKey,
52 value: Vec<u8>,
53 ttl: Option<Duration>,
54 ) -> CacheResult<()> {
55 let expires_at = ttl.map(|ttl| Instant::now() + ttl);
56 self.entries
57 .write()
58 .await
59 .insert(key.render(), Entry { value, expires_at });
60 Ok(())
61 }
62
63 async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
64 self.entries.write().await.remove(&key.render());
65 Ok(())
66 }
67}
68
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::cache::{CacheKey, CacheStore, MemoryCacheStore};

    /// Round-trips a JSON value with a TTL, then checks `delete` removes it.
    #[tokio::test]
    async fn memory_cache_stores_json_with_ttl() {
        let store = MemoryCacheStore::new();
        let key = CacheKey::new("test", ["user", "1"]);
        store
            .set_json(
                &key,
                &serde_json::json!({"name":"Ada"}),
                Some(Duration::from_secs(1)),
            )
            .await
            .expect("set");
        let value: serde_json::Value = store.get_json(&key).await.expect("get").expect("value");
        assert_eq!(value["name"], "Ada");
        store.delete(&key).await.expect("delete");
        assert!(store.get_raw(&key).await.expect("get").is_none());
    }

    /// A zero TTL yields a deadline of "now", so the entry is already expired
    /// by the time it is read back.
    #[tokio::test]
    async fn memory_cache_expires_entries_after_ttl() {
        let store = MemoryCacheStore::new();
        let key = CacheKey::new("test", ["user", "2"]);
        store
            .set_raw(&key, b"ada".to_vec(), Some(Duration::ZERO))
            .await
            .expect("set");
        assert!(store.get_raw(&key).await.expect("get").is_none());
    }

    /// Entries stored without a TTL are returned unchanged.
    #[tokio::test]
    async fn memory_cache_without_ttl_keeps_value() {
        let store = MemoryCacheStore::new();
        let key = CacheKey::new("test", ["user", "3"]);
        store
            .set_raw(&key, vec![1, 2, 3], None)
            .await
            .expect("set");
        assert_eq!(store.get_raw(&key).await.expect("get"), Some(vec![1, 2, 3]));
    }

    /// Clones share the same underlying map.
    #[tokio::test]
    async fn memory_cache_clones_share_entries() {
        let store = MemoryCacheStore::new();
        let handle = store.clone();
        let key = CacheKey::new("test", ["user", "4"]);
        handle
            .set_raw(&key, vec![7], None)
            .await
            .expect("set");
        assert_eq!(store.get_raw(&key).await.expect("get"), Some(vec![7]));
    }
}