1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use thiserror::Error;
8use tokio::sync::RwLock;
9pub mod redis;
10pub use crate::redis::RedisCache;
11
12#[derive(Debug, Error)]
13pub enum CacheError {
14 #[error("invalid cache key: {0}")]
15 InvalidKey(String),
16 #[error("invalid TTL: duration must be greater than zero")]
17 InvalidTtl,
18 #[error("serialization error: {0}")]
19 Serde(#[from] serde_json::Error),
20 #[error("redis error: {0}")]
21 Redis(#[from] ::redis::RedisError),
22}
23
24pub type Result<T> = std::result::Result<T, CacheError>;
25
26#[async_trait]
28pub trait Cache: Send + Sync {
29 async fn get<T>(&self, key: &str) -> Result<Option<T>>
30 where
31 T: for<'de> Deserialize<'de> + Send;
32
33 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
34 where
35 T: Serialize + Send + Sync;
36
37 async fn delete(&self, key: &str) -> Result<()>;
38
39 async fn exists(&self, key: &str) -> Result<bool>;
40
41 async fn flush(&self) -> Result<()>;
42}
43
/// A serialized value together with its optional expiry deadline.
#[derive(Clone)]
struct CacheEntry {
    /// Encoded value bytes (JSON, as written by `MemoryCache::set`).
    data: Vec<u8>,
    /// Deadline after which the entry is treated as absent; `None` means it never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Creates an entry that expires `ttl` from now, or never when `ttl` is `None`.
    ///
    /// A `ttl` too large to add to the current `Instant` (e.g. `Duration::MAX`)
    /// is treated as "never expires". The plain `Instant + Duration` operator
    /// would panic on overflow and take the caller down with it.
    fn new(data: Vec<u8>, ttl: Option<Duration>) -> Self {
        let expires_at = ttl.and_then(|d| Instant::now().checked_add(d));
        Self { data, expires_at }
    }

    /// Returns `true` once the current time is strictly past the deadline.
    fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(deadline) if Instant::now() > deadline)
    }
}
61
62pub struct MemoryCache {
64 store: Arc<RwLock<HashMap<String, CacheEntry>>>,
65 default_ttl: Option<Duration>,
66 hits: Arc<AtomicU64>,
67 misses: Arc<AtomicU64>,
68 sets: Arc<AtomicU64>,
69 deletes: Arc<AtomicU64>,
70}
71
/// Point-in-time snapshot of a [`MemoryCache`]'s operation counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups that returned a value.
    pub hits: u64,
    /// Lookups that found nothing, or only an expired entry.
    pub misses: u64,
    /// Successful writes.
    pub sets: u64,
    /// Delete calls.
    pub deletes: u64,
}
79
/// Wraps another cache and prefixes every key with `"{namespace}:"`.
///
/// Several logical caches can share one backend this way. Choose namespaces
/// carefully: namespace `"a:b"` with key `"c"` and namespace `"a"` with key
/// `"b:c"` both map to `"a:b:c"`.
pub struct NamespacedCache<C> {
    namespace: String,
    inner: C,
}

impl<C> NamespacedCache<C> {
    /// Creates a namespaced view over `inner`.
    pub fn new(namespace: impl Into<String>, inner: C) -> Self {
        let namespace = namespace.into();
        NamespacedCache { namespace, inner }
    }

    /// Builds the fully-qualified `"{namespace}:{key}"` string handed to the inner cache.
    fn key(&self, key: &str) -> String {
        let mut qualified = String::with_capacity(self.namespace.len() + 1 + key.len());
        qualified.push_str(&self.namespace);
        qualified.push(':');
        qualified.push_str(key);
        qualified
    }
}
97}
98
99impl MemoryCache {
100 pub fn new() -> Self {
101 Self {
102 store: Arc::new(RwLock::new(HashMap::new())),
103 default_ttl: Some(Duration::from_secs(3600)),
104 hits: Arc::new(AtomicU64::new(0)),
105 misses: Arc::new(AtomicU64::new(0)),
106 sets: Arc::new(AtomicU64::new(0)),
107 deletes: Arc::new(AtomicU64::new(0)),
108 }
109 }
110
111 pub fn with_default_ttl(ttl: Duration) -> Self {
112 Self {
113 store: Arc::new(RwLock::new(HashMap::new())),
114 default_ttl: Some(ttl),
115 hits: Arc::new(AtomicU64::new(0)),
116 misses: Arc::new(AtomicU64::new(0)),
117 sets: Arc::new(AtomicU64::new(0)),
118 deletes: Arc::new(AtomicU64::new(0)),
119 }
120 }
121
122 pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, f: F) -> Result<T>
124 where
125 T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
126 F: FnOnce() -> Fut + Send,
127 Fut: std::future::Future<Output = Result<T>> + Send,
128 {
129 validate_cache_key(key)?;
130 validate_ttl(Some(ttl))?;
131
132 if let Some(value) = self.get::<T>(key).await? {
134 return Ok(value);
135 }
136
137 let value = f().await?;
139 self.set(key, &value, Some(ttl)).await?;
140 Ok(value)
141 }
142
143 async fn cleanup(&self) {
145 let mut store = self.store.write().await;
146 store.retain(|_, entry| !entry.is_expired());
147 }
148
149 pub fn stats(&self) -> CacheStats {
151 CacheStats {
152 hits: self.hits.load(Ordering::Relaxed),
153 misses: self.misses.load(Ordering::Relaxed),
154 sets: self.sets.load(Ordering::Relaxed),
155 deletes: self.deletes.load(Ordering::Relaxed),
156 }
157 }
158
159 pub fn reset_stats(&self) {
161 self.hits.store(0, Ordering::Relaxed);
162 self.misses.store(0, Ordering::Relaxed);
163 self.sets.store(0, Ordering::Relaxed);
164 self.deletes.store(0, Ordering::Relaxed);
165 }
166}
167
168impl Default for MemoryCache {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174#[async_trait]
175impl<C: Cache> Cache for NamespacedCache<C> {
176 async fn get<T>(&self, key: &str) -> Result<Option<T>>
177 where
178 T: for<'de> Deserialize<'de> + Send,
179 {
180 self.inner.get(&self.key(key)).await
181 }
182
183 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
184 where
185 T: Serialize + Send + Sync,
186 {
187 self.inner.set(&self.key(key), value, ttl).await
188 }
189
190 async fn delete(&self, key: &str) -> Result<()> {
191 self.inner.delete(&self.key(key)).await
192 }
193
194 async fn exists(&self, key: &str) -> Result<bool> {
195 self.inner.exists(&self.key(key)).await
196 }
197
198 async fn flush(&self) -> Result<()> {
199 self.inner.flush().await
200 }
201}
202
203#[async_trait]
204impl Cache for MemoryCache {
205 async fn get<T>(&self, key: &str) -> Result<Option<T>>
206 where
207 T: for<'de> Deserialize<'de> + Send,
208 {
209 validate_cache_key(key)?;
210 self.cleanup().await;
211 let store = self.store.read().await;
212
213 if let Some(entry) = store.get(key) {
214 if entry.is_expired() {
215 self.misses.fetch_add(1, Ordering::Relaxed);
216 return Ok(None);
217 }
218
219 let value: T = serde_json::from_slice(&entry.data)?;
220 self.hits.fetch_add(1, Ordering::Relaxed);
221 Ok(Some(value))
222 } else {
223 self.misses.fetch_add(1, Ordering::Relaxed);
224 Ok(None)
225 }
226 }
227
228 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
229 where
230 T: Serialize + Send + Sync,
231 {
232 validate_cache_key(key)?;
233 validate_ttl(ttl)?;
234 self.cleanup().await;
235 let data = serde_json::to_vec(value)?;
236 let ttl = ttl.or(self.default_ttl);
237 let entry = CacheEntry::new(data, ttl);
238
239 let mut store = self.store.write().await;
240 store.insert(key.to_string(), entry);
241 self.sets.fetch_add(1, Ordering::Relaxed);
242
243 Ok(())
244 }
245
246 async fn delete(&self, key: &str) -> Result<()> {
247 validate_cache_key(key)?;
248 let mut store = self.store.write().await;
249 store.remove(key);
250 self.deletes.fetch_add(1, Ordering::Relaxed);
251 Ok(())
252 }
253
254 async fn exists(&self, key: &str) -> Result<bool> {
255 validate_cache_key(key)?;
256 self.cleanup().await;
257 let store = self.store.read().await;
258 Ok(store.get(key).map(|e| !e.is_expired()).unwrap_or(false))
259 }
260
261 async fn flush(&self) -> Result<()> {
262 let mut store = self.store.write().await;
263 store.clear();
264 Ok(())
265 }
266}
267
268pub(crate) fn validate_cache_key(key: &str) -> Result<()> {
269 if key.trim().is_empty() {
270 return Err(CacheError::InvalidKey(
271 "key cannot be empty".to_string(),
272 ));
273 }
274 if key.chars().any(char::is_control) {
275 return Err(CacheError::InvalidKey(
276 "key cannot contain control characters".to_string(),
277 ));
278 }
279 Ok(())
280}
281
282pub(crate) fn validate_ttl(ttl: Option<Duration>) -> Result<()> {
283 if matches!(ttl, Some(d) if d.is_zero()) {
284 return Err(CacheError::InvalidTtl);
285 }
286 Ok(())
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: a value written with the default TTL can be read back.
    #[tokio::test]
    async fn test_set_and_get() {
        let cache = MemoryCache::new();

        cache.set("key1", &"value1", None).await.unwrap();
        let value: Option<String> = cache.get("key1").await.unwrap();

        assert_eq!(value, Some("value1".to_string()));
    }

    // An entry with a 100 ms TTL must read as absent once 200 ms have passed.
    // Entries are timed with `std::time::Instant`, so this waits in real time.
    #[tokio::test]
    async fn test_expiration() {
        let cache = MemoryCache::new();

        cache.set("key1", &"value1", Some(Duration::from_millis(100))).await.unwrap();
        tokio::time::sleep(Duration::from_millis(200)).await;

        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    // The first `remember` call computes and caches the value. The second must
    // be served from the cache without invoking its closure, so `call_count`
    // stays at 1.
    #[tokio::test]
    async fn test_remember() {
        let cache = MemoryCache::new();
        let mut call_count = 0;

        let value = cache.remember("key1", Duration::from_secs(60), || async {
            call_count += 1;
            Ok::<_, CacheError>("computed".to_string())
        }).await.unwrap();

        assert_eq!(value, "computed");
        assert_eq!(call_count, 1);

        let value2 = cache.remember("key1", Duration::from_secs(60), || async {
            call_count += 1;
            Ok::<_, CacheError>("computed".to_string())
        }).await.unwrap();

        assert_eq!(value2, "computed");
        assert_eq!(call_count, 1);
    }

    // Empty keys are rejected by key validation.
    #[tokio::test]
    async fn test_reject_empty_key() {
        let cache = MemoryCache::new();
        let result = cache.set("", &"value", None).await;
        assert!(result.is_err());
    }

    // An explicit zero TTL is rejected by TTL validation.
    #[tokio::test]
    async fn test_reject_zero_ttl() {
        let cache = MemoryCache::new();
        let result = cache
            .set("k", &"value", Some(Duration::from_secs(0)))
            .await;
        assert!(result.is_err());
    }

    // Values written through the namespaced wrapper can be read back through it.
    // The prefixed key stored in the inner cache is not inspected here.
    #[tokio::test]
    async fn test_namespaced_cache_prefixes_keys() {
        let base = MemoryCache::new();
        let scoped = NamespacedCache::new("users", base);

        scoped.set("1", &"Alice", None).await.expect("set");
        let value: Option<String> = scoped.get("1").await.expect("get");
        assert_eq!(value.as_deref(), Some("Alice"));
    }

    // One set, one hit, one miss (absent key) and one delete are each counted once.
    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::new();
        cache.set("k", &"v", None).await.expect("set");
        let _v: Option<String> = cache.get("k").await.expect("get");
        let _missing: Option<String> = cache.get("missing").await.expect("get");
        cache.delete("k").await.expect("delete");

        let stats = cache.stats();
        assert_eq!(stats.sets, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.deletes, 1);
    }
}