oxidite_cache/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7pub mod redis;
8pub use crate::redis::RedisCache;
/// Result type used throughout this crate.
///
/// The error side is a boxed trait object that is `Send + Sync`, so errors can be
/// moved between threads and held across `.await` points.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Sync + Send>>;

12/// Cache trait
13#[async_trait]
14pub trait Cache: Send + Sync {
15    async fn get<T>(&self, key: &str) -> Result<Option<T>>
16    where
17        T: for<'de> Deserialize<'de> + Send;
18
19    async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
20    where
21        T: Serialize + Send + Sync;
22
23    async fn delete(&self, key: &str) -> Result<()>;
24    
25    async fn exists(&self, key: &str) -> Result<bool>;
26    
27    async fn flush(&self) -> Result<()>;
28}
29
/// A serialized cache value together with its optional expiry deadline.
#[derive(Clone)]
struct CacheEntry {
    /// Serialized value bytes.
    data: Vec<u8>,
    /// Instant at (and after) which the entry counts as expired; `None` means it never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Creates an entry that expires `ttl` from now, or never when `ttl` is `None`.
    ///
    /// A TTL so large that the deadline cannot be represented as an `Instant`
    /// (e.g. `Duration::MAX`) is treated as "never expires" instead of panicking.
    fn new(data: Vec<u8>, ttl: Option<Duration>) -> Self {
        // `Instant + Duration` panics on overflow; `checked_add` yields `None` instead,
        // which we interpret as "no deadline".
        let expires_at = ttl.and_then(|d| Instant::now().checked_add(d));
        Self { data, expires_at }
    }

    /// Returns `true` once the current time has reached the entry's deadline.
    ///
    /// The deadline itself counts as expired, so a zero TTL produces an entry that
    /// is expired immediately.
    fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(deadline) if Instant::now() >= deadline)
    }
}

48/// In-memory cache implementation
49pub struct MemoryCache {
50    store: Arc<RwLock<HashMap<String, CacheEntry>>>,
51    default_ttl: Option<Duration>,
52}
53
54impl MemoryCache {
55    pub fn new() -> Self {
56        Self {
57            store: Arc::new(RwLock::new(HashMap::new())),
58            default_ttl: Some(Duration::from_secs(3600)),
59        }
60    }
61
62    pub fn with_default_ttl(ttl: Duration) -> Self {
63        Self {
64            store: Arc::new(RwLock::new(HashMap::new())),
65            default_ttl: Some(ttl),
66        }
67    }
68
69    /// Remember a value, executing the closure if not cached
70    pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, f: F) -> Result<T>
71    where
72        T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
73        F: FnOnce() -> Fut + Send,
74        Fut: std::future::Future<Output = Result<T>> + Send,
75    {
76        // Try to get from cache
77        if let Some(value) = self.get::<T>(key).await? {
78            return Ok(value);
79        }
80
81        // Execute closure and cache result
82        let value = f().await?;
83        self.set(key, &value, Some(ttl)).await?;
84        Ok(value)
85    }
86
87    /// Clean expired entries
88    async fn cleanup(&self) {
89        let mut store = self.store.write().await;
90        store.retain(|_, entry| !entry.is_expired());
91    }
92}
93
94impl Default for MemoryCache {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100#[async_trait]
101impl Cache for MemoryCache {
102    async fn get<T>(&self, key: &str) -> Result<Option<T>>
103    where
104        T: for<'de> Deserialize<'de> + Send,
105    {
106        let store = self.store.read().await;
107        
108        if let Some(entry) = store.get(key) {
109            if entry.is_expired() {
110                return Ok(None);
111            }
112            
113            let value: T = serde_json::from_slice(&entry.data)?;
114            Ok(Some(value))
115        } else {
116            Ok(None)
117        }
118    }
119
120    async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
121    where
122        T: Serialize + Send + Sync,
123    {
124        let data = serde_json::to_vec(value)?;
125        let ttl = ttl.or(self.default_ttl);
126        let entry = CacheEntry::new(data, ttl);
127
128        let mut store = self.store.write().await;
129        store.insert(key.to_string(), entry);
130        
131        Ok(())
132    }
133
134    async fn delete(&self, key: &str) -> Result<()> {
135        let mut store = self.store.write().await;
136        store.remove(key);
137        Ok(())
138    }
139
140    async fn exists(&self, key: &str) -> Result<bool> {
141        let store = self.store.read().await;
142        Ok(store.get(key).map(|e| !e.is_expired()).unwrap_or(false))
143    }
144
145    async fn flush(&self) -> Result<()> {
146        let mut store = self.store.write().await;
147        store.clear();
148        Ok(())
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_set_and_get() {
        let cache = MemoryCache::new();

        cache.set("key1", &"value1", None).await.unwrap();
        let value: Option<String> = cache.get("key1").await.unwrap();

        assert_eq!(value, Some("value1".to_string()));
    }

    #[tokio::test]
    async fn test_expiration() {
        let cache = MemoryCache::new();

        cache.set("key1", &"value1", Some(Duration::from_millis(100))).await.unwrap();
        tokio::time::sleep(Duration::from_millis(200)).await;

        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_default_ttl_applies_when_none_given() {
        let cache = MemoryCache::with_default_ttl(Duration::from_millis(50));

        cache.set("key1", &"value1", None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;

        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_exists_delete_and_flush() {
        let cache = MemoryCache::new();
        assert!(!cache.exists("a").await.unwrap());

        cache.set("a", &1u32, None).await.unwrap();
        cache.set("b", &2u32, None).await.unwrap();
        assert!(cache.exists("a").await.unwrap());

        cache.delete("a").await.unwrap();
        assert!(!cache.exists("a").await.unwrap());
        // Deleting a key that is not present is not an error.
        cache.delete("a").await.unwrap();

        cache.flush().await.unwrap();
        assert!(!cache.exists("b").await.unwrap());
    }

    #[tokio::test]
    async fn test_get_with_wrong_type_is_error() {
        let cache = MemoryCache::new();

        cache.set("k", &"text", None).await.unwrap();
        assert!(cache.get::<u32>("k").await.is_err());
    }

    #[tokio::test]
    async fn test_remember() {
        let cache = MemoryCache::new();
        let mut call_count = 0;

        let value = cache.remember("key1", Duration::from_secs(60), || async {
            call_count += 1;
            Ok::<_, Box<dyn std::error::Error + Send + Sync>>("computed".to_string())
        }).await.unwrap();

        assert_eq!(value, "computed");
        assert_eq!(call_count, 1);

        // Second call should use cache
        let value2 = cache.remember("key1", Duration::from_secs(60), || async {
            call_count += 1;
            Ok::<_, Box<dyn std::error::Error + Send + Sync>>("computed".to_string())
        }).await.unwrap();

        assert_eq!(value2, "computed");
        assert_eq!(call_count, 1); // Should not increment
    }
}