Skip to main content

oxidite_cache/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use thiserror::Error;
8use tokio::sync::RwLock;
9pub mod redis;
10pub use crate::redis::RedisCache;
11
12#[derive(Debug, Error)]
13pub enum CacheError {
14    #[error("invalid cache key: {0}")]
15    InvalidKey(String),
16    #[error("invalid TTL: duration must be greater than zero")]
17    InvalidTtl,
18    #[error("serialization error: {0}")]
19    Serde(#[from] serde_json::Error),
20    #[error("redis error: {0}")]
21    Redis(#[from] ::redis::RedisError),
22}
23
24pub type Result<T> = std::result::Result<T, CacheError>;
25
26/// Cache trait
27#[async_trait]
28pub trait Cache: Send + Sync {
29    async fn get<T>(&self, key: &str) -> Result<Option<T>>
30    where
31        T: for<'de> Deserialize<'de> + Send;
32
33    async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
34    where
35        T: Serialize + Send + Sync;
36
37    async fn delete(&self, key: &str) -> Result<()>;
38    
39    async fn exists(&self, key: &str) -> Result<bool>;
40    
41    async fn flush(&self) -> Result<()>;
42}
43
/// A stored cache value together with its optional expiry deadline.
#[derive(Clone)]
struct CacheEntry {
    /// Raw bytes of the stored value.
    data: Vec<u8>,
    /// Instant after which the entry counts as expired; `None` means the
    /// entry never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Builds an entry that expires `ttl` from now, or never when `ttl` is
    /// `None`.
    ///
    /// A TTL so large that the deadline cannot be represented as an
    /// [`Instant`] (for example `Duration::MAX`) is treated as "never
    /// expires" instead of panicking.
    fn new(data: Vec<u8>, ttl: Option<Duration>) -> Self {
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`
        // instead, which is the "no deadline" representation used here.
        let expires_at = ttl.and_then(|ttl| Instant::now().checked_add(ttl));
        Self { data, expires_at }
    }

    /// Returns `true` once the current time is strictly past the deadline.
    /// Entries without a deadline never expire.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() > deadline,
            None => false,
        }
    }
}
61
62/// In-memory cache implementation
63pub struct MemoryCache {
64    store: Arc<RwLock<HashMap<String, CacheEntry>>>,
65    default_ttl: Option<Duration>,
66    hits: Arc<AtomicU64>,
67    misses: Arc<AtomicU64>,
68    sets: Arc<AtomicU64>,
69    deletes: Arc<AtomicU64>,
70}
71
/// Point-in-time copy of the operation counters kept by `MemoryCache`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups that returned a value.
    pub hits: u64,
    /// Lookups that found no live value.
    pub misses: u64,
    /// Completed store operations.
    pub sets: u64,
    /// Delete operations performed.
    pub deletes: u64,
}
79
/// Wrapper that prefixes every key with a namespace so that separate domains
/// sharing one backing cache do not collide.
pub struct NamespacedCache<C> {
    /// Prefix placed in front of every key.
    namespace: String,
    /// The wrapped cache that actually stores the data.
    inner: C,
}

impl<C> NamespacedCache<C> {
    /// Wraps `inner` so that all keys are stored under `namespace`.
    pub fn new(namespace: impl Into<String>, inner: C) -> Self {
        let namespace = namespace.into();
        Self { namespace, inner }
    }

    /// Builds the backing-cache key: `<namespace>:<key>`.
    fn key(&self, key: &str) -> String {
        [self.namespace.as_str(), key].join(":")
    }
}
98
99impl MemoryCache {
100    pub fn new() -> Self {
101        Self {
102            store: Arc::new(RwLock::new(HashMap::new())),
103            default_ttl: Some(Duration::from_secs(3600)),
104            hits: Arc::new(AtomicU64::new(0)),
105            misses: Arc::new(AtomicU64::new(0)),
106            sets: Arc::new(AtomicU64::new(0)),
107            deletes: Arc::new(AtomicU64::new(0)),
108        }
109    }
110
111    pub fn with_default_ttl(ttl: Duration) -> Self {
112        Self {
113            store: Arc::new(RwLock::new(HashMap::new())),
114            default_ttl: Some(ttl),
115            hits: Arc::new(AtomicU64::new(0)),
116            misses: Arc::new(AtomicU64::new(0)),
117            sets: Arc::new(AtomicU64::new(0)),
118            deletes: Arc::new(AtomicU64::new(0)),
119        }
120    }
121
122    /// Remember a value, executing the closure if not cached
123    pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, f: F) -> Result<T>
124    where
125        T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
126        F: FnOnce() -> Fut + Send,
127        Fut: std::future::Future<Output = Result<T>> + Send,
128    {
129        validate_cache_key(key)?;
130        validate_ttl(Some(ttl))?;
131
132        // Try to get from cache
133        if let Some(value) = self.get::<T>(key).await? {
134            return Ok(value);
135        }
136
137        // Execute closure and cache result
138        let value = f().await?;
139        self.set(key, &value, Some(ttl)).await?;
140        Ok(value)
141    }
142
143    /// Clean expired entries
144    async fn cleanup(&self) {
145        let mut store = self.store.write().await;
146        store.retain(|_, entry| !entry.is_expired());
147    }
148
149    /// Snapshot in-memory cache operation counters.
150    pub fn stats(&self) -> CacheStats {
151        CacheStats {
152            hits: self.hits.load(Ordering::Relaxed),
153            misses: self.misses.load(Ordering::Relaxed),
154            sets: self.sets.load(Ordering::Relaxed),
155            deletes: self.deletes.load(Ordering::Relaxed),
156        }
157    }
158
159    /// Reset cache operation counters.
160    pub fn reset_stats(&self) {
161        self.hits.store(0, Ordering::Relaxed);
162        self.misses.store(0, Ordering::Relaxed);
163        self.sets.store(0, Ordering::Relaxed);
164        self.deletes.store(0, Ordering::Relaxed);
165    }
166}
167
168impl Default for MemoryCache {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174#[async_trait]
175impl<C: Cache> Cache for NamespacedCache<C> {
176    async fn get<T>(&self, key: &str) -> Result<Option<T>>
177    where
178        T: for<'de> Deserialize<'de> + Send,
179    {
180        self.inner.get(&self.key(key)).await
181    }
182
183    async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
184    where
185        T: Serialize + Send + Sync,
186    {
187        self.inner.set(&self.key(key), value, ttl).await
188    }
189
190    async fn delete(&self, key: &str) -> Result<()> {
191        self.inner.delete(&self.key(key)).await
192    }
193
194    async fn exists(&self, key: &str) -> Result<bool> {
195        self.inner.exists(&self.key(key)).await
196    }
197
198    async fn flush(&self) -> Result<()> {
199        self.inner.flush().await
200    }
201}
202
203#[async_trait]
204impl Cache for MemoryCache {
205    async fn get<T>(&self, key: &str) -> Result<Option<T>>
206    where
207        T: for<'de> Deserialize<'de> + Send,
208    {
209        validate_cache_key(key)?;
210        self.cleanup().await;
211        let store = self.store.read().await;
212        
213        if let Some(entry) = store.get(key) {
214            if entry.is_expired() {
215                self.misses.fetch_add(1, Ordering::Relaxed);
216                return Ok(None);
217            }
218            
219            let value: T = serde_json::from_slice(&entry.data)?;
220            self.hits.fetch_add(1, Ordering::Relaxed);
221            Ok(Some(value))
222        } else {
223            self.misses.fetch_add(1, Ordering::Relaxed);
224            Ok(None)
225        }
226    }
227
228    async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
229    where
230        T: Serialize + Send + Sync,
231    {
232        validate_cache_key(key)?;
233        validate_ttl(ttl)?;
234        self.cleanup().await;
235        let data = serde_json::to_vec(value)?;
236        let ttl = ttl.or(self.default_ttl);
237        let entry = CacheEntry::new(data, ttl);
238
239        let mut store = self.store.write().await;
240        store.insert(key.to_string(), entry);
241        self.sets.fetch_add(1, Ordering::Relaxed);
242        
243        Ok(())
244    }
245
246    async fn delete(&self, key: &str) -> Result<()> {
247        validate_cache_key(key)?;
248        let mut store = self.store.write().await;
249        store.remove(key);
250        self.deletes.fetch_add(1, Ordering::Relaxed);
251        Ok(())
252    }
253
254    async fn exists(&self, key: &str) -> Result<bool> {
255        validate_cache_key(key)?;
256        self.cleanup().await;
257        let store = self.store.read().await;
258        Ok(store.get(key).map(|e| !e.is_expired()).unwrap_or(false))
259    }
260
261    async fn flush(&self) -> Result<()> {
262        let mut store = self.store.write().await;
263        store.clear();
264        Ok(())
265    }
266}
267
268pub(crate) fn validate_cache_key(key: &str) -> Result<()> {
269    if key.trim().is_empty() {
270        return Err(CacheError::InvalidKey(
271            "key cannot be empty".to_string(),
272        ));
273    }
274    if key.chars().any(char::is_control) {
275        return Err(CacheError::InvalidKey(
276            "key cannot contain control characters".to_string(),
277        ));
278    }
279    Ok(())
280}
281
282pub(crate) fn validate_ttl(ttl: Option<Duration>) -> Result<()> {
283    if matches!(ttl, Some(d) if d.is_zero()) {
284        return Err(CacheError::InvalidTtl);
285    }
286    Ok(())
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored value can be read back unchanged.
    #[tokio::test]
    async fn test_set_and_get() {
        let cache = MemoryCache::new();

        cache.set("key1", &"value1", None).await.unwrap();
        let value: Option<String> = cache.get("key1").await.unwrap();

        assert_eq!(value, Some("value1".to_string()));
    }

    /// An entry is no longer returned once its TTL has elapsed.
    #[tokio::test]
    async fn test_expiration() {
        let cache = MemoryCache::new();

        cache
            .set("key1", &"value1", Some(Duration::from_millis(100)))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(200)).await;

        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    /// `remember` runs the closure on a miss and skips it on a hit.
    #[tokio::test]
    async fn test_remember() {
        let cache = MemoryCache::new();
        // An atomic counter is used so the closures only need a shared borrow;
        // a future returned from a closure cannot hold a mutable borrow of a
        // variable the closure captured.
        let call_count = AtomicU64::new(0);

        let value = cache
            .remember("key1", Duration::from_secs(60), || async {
                call_count.fetch_add(1, Ordering::Relaxed);
                Ok::<_, CacheError>("computed".to_string())
            })
            .await
            .unwrap();

        assert_eq!(value, "computed");
        assert_eq!(call_count.load(Ordering::Relaxed), 1);

        // Second call should use cache
        let value2 = cache
            .remember("key1", Duration::from_secs(60), || async {
                call_count.fetch_add(1, Ordering::Relaxed);
                Ok::<_, CacheError>("computed".to_string())
            })
            .await
            .unwrap();

        assert_eq!(value2, "computed");
        assert_eq!(call_count.load(Ordering::Relaxed), 1); // Should not increment
    }

    /// An empty key is rejected.
    #[tokio::test]
    async fn test_reject_empty_key() {
        let cache = MemoryCache::new();
        let result = cache.set("", &"value", None).await;
        assert!(result.is_err());
    }

    /// An explicit zero TTL is rejected.
    #[tokio::test]
    async fn test_reject_zero_ttl() {
        let cache = MemoryCache::new();
        let result = cache
            .set("k", &"value", Some(Duration::from_secs(0)))
            .await;
        assert!(result.is_err());
    }

    /// Values written through the wrapper land under `<namespace>:<key>` in
    /// the inner cache, and not under the bare key.
    #[tokio::test]
    async fn test_namespaced_cache_prefixes_keys() {
        let base = MemoryCache::new();
        let scoped = NamespacedCache::new("users", base);

        scoped.set("1", &"Alice", None).await.expect("set");
        let value: Option<String> = scoped.get("1").await.expect("get");
        assert_eq!(value.as_deref(), Some("Alice"));

        // Look at the inner cache directly to confirm the key was prefixed.
        let prefixed: Option<String> = scoped.inner.get("users:1").await.expect("inner get");
        assert_eq!(prefixed.as_deref(), Some("Alice"));
        let bare: Option<String> = scoped.inner.get("1").await.expect("inner get");
        assert_eq!(bare, None);
    }

    /// Each operation bumps exactly its own counter.
    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::new();
        cache.set("k", &"v", None).await.expect("set");
        let _v: Option<String> = cache.get("k").await.expect("get");
        let _missing: Option<String> = cache.get("missing").await.expect("get");
        cache.delete("k").await.expect("delete");

        let stats = cache.stats();
        assert_eq!(stats.sets, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.deletes, 1);
    }
}