1use async_trait::async_trait;
28use dashmap::DashMap;
29use std::sync::Arc;
30use std::time::Instant;
31
/// Error type returned by every cache operation.
#[derive(Debug)]
pub enum CacheError {
    /// The storage backend failed; drivers in this file use it for connection,
    /// URL and command failures. The payload is a human-readable detail message.
    Driver(String),
    /// Reserved for (de)serialization failures; the payload is a detail message.
    /// NOTE(review): no driver in this file constructs this variant.
    Serialization(String),
}

impl std::fmt::Display for CacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render a fixed, variant-specific prefix followed by the detail message.
        match self {
            Self::Driver(detail) => write!(f, "Cache driver error: {}", detail),
            Self::Serialization(detail) => write!(f, "Cache serialization error: {}", detail),
        }
    }
}

/// Allows `CacheError` to be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for CacheError {}
53
54#[async_trait]
61pub trait CacheDriver: Send + Sync {
62 async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError>;
64 async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError>;
66 async fn forget(&self, key: &str) -> Result<(), CacheError>;
68 async fn flush(&self) -> Result<(), CacheError>;
70 async fn has(&self, key: &str) -> Result<bool, CacheError>;
72}
73
/// One value held by the in-memory driver, together with its optional deadline.
#[derive(Clone)]
struct CacheEntry {
    /// Deadline after which the entry is treated as expired; `None` never expires.
    expires_at: Option<Instant>,
    /// Reference-counted so reads can hand the value out without copying the string.
    value: Arc<String>,
}
82
83pub struct MemoryDriver {
88 store: DashMap<String, CacheEntry>,
89}
90
91impl MemoryDriver {
92 pub fn new() -> Self {
94 let store: DashMap<String, CacheEntry> = DashMap::new();
95
96 if tokio::runtime::Handle::try_current().is_ok() {
98 let store_clone = store.clone();
99 tokio::spawn(async move {
100 loop {
101 tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
102 store_clone.retain(|_, entry| {
104 entry.expires_at.map_or(true, |exp| Instant::now() < exp)
105 });
106 }
107 });
108 }
109
110 Self { store }
111 }
112}
113
114impl Default for MemoryDriver {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120#[async_trait]
121impl CacheDriver for MemoryDriver {
122 async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError> {
123 if let Some(entry) = self.store.get(key) {
124 if let Some(expires_at) = entry.expires_at
126 && Instant::now() > expires_at
127 {
128 drop(entry);
130 self.store.remove(key);
131 return Ok(None);
132 }
133 Ok(Some(entry.value.clone()))
135 } else {
136 Ok(None)
137 }
138 }
139
140 async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError> {
141 let expires_at = ttl_secs.map(|secs| Instant::now() + std::time::Duration::from_secs(secs));
142 self.store.insert(
143 key.to_string(),
144 CacheEntry {
145 value: Arc::new(value.to_string()),
146 expires_at,
147 },
148 );
149 Ok(())
150 }
151
152 async fn forget(&self, key: &str) -> Result<(), CacheError> {
153 self.store.remove(key);
154 Ok(())
155 }
156
157 async fn flush(&self) -> Result<(), CacheError> {
158 self.store.clear();
159 Ok(())
160 }
161
162 async fn has(&self, key: &str) -> Result<bool, CacheError> {
163 Ok(self.get(key).await?.is_some())
164 }
165}
166
/// Redis-backed cache driver, compiled only with the `cache-redis` feature.
#[cfg(feature = "cache-redis")]
pub mod redis_driver {
    use super::*;

    /// Cache driver that stores values in Redis under the `rullst:cache:` key prefix.
    ///
    /// NOTE(review): each operation opens a fresh multiplexed connection through
    /// `connection()`. If connection setup shows up in profiles, consider holding
    /// a reconnecting connection (e.g. redis `ConnectionManager`) instead.
    pub struct RedisDriver {
        client: redis::Client,
        prefix: String,
    }

    impl RedisDriver {
        /// Builds a driver for `redis_url`.
        ///
        /// # Errors
        /// Returns [`CacheError::Driver`] if `redis::Client::open` rejects the URL.
        pub fn new(redis_url: &str) -> Result<Self, CacheError> {
            let client = redis::Client::open(redis_url)
                .map_err(|e| CacheError::Driver(format!("Failed to connect to Redis: {}", e)))?;
            Ok(Self {
                client,
                prefix: "rullst:cache:".to_string(),
            })
        }

        /// Namespaces `key` with this driver's prefix.
        fn prefixed_key(&self, key: &str) -> String {
            format!("{}{}", self.prefix, key)
        }

        /// Opens a multiplexed async connection, mapping failures to `CacheError::Driver`.
        async fn connection(&self) -> Result<redis::aio::MultiplexedConnection, CacheError> {
            self.client
                .get_multiplexed_async_connection()
                .await
                .map_err(|e| CacheError::Driver(format!("Redis connection failed: {}", e)))
        }
    }

    #[async_trait]
    impl CacheDriver for RedisDriver {
        /// Reads the prefixed key with `GET`; a missing key yields `Ok(None)`.
        async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError> {
            let mut con = self.connection().await?;
            let result: Option<String> = redis::cmd("GET")
                .arg(self.prefixed_key(key))
                .query_async(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis GET failed: {}", e)))?;
            Ok(result.map(Arc::new))
        }

        /// Writes the prefixed key: `SET` without a TTL, `SETEX` with a positive TTL.
        ///
        /// A TTL of `Some(0)` means the value is expired as soon as it is written.
        /// Redis rejects `SETEX` with 0, so the key is deleted instead, and a later
        /// `get` misses, just as it would for an entry that has expired.
        async fn put(
            &self,
            key: &str,
            value: &str,
            ttl_secs: Option<u64>,
        ) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            let pk = self.prefixed_key(key);
            match ttl_secs {
                Some(0) => {
                    redis::cmd("UNLINK")
                        .arg(&pk)
                        .query_async::<i64>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis UNLINK failed: {}", e)))?;
                }
                Some(ttl) => {
                    // Passed through as `u64` (no lossy cast); out-of-range TTLs are
                    // left for Redis to reject with an error.
                    redis::cmd("SETEX")
                        .arg(&pk)
                        .arg(ttl)
                        .arg(value)
                        .query_async::<()>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis SETEX failed: {}", e)))?;
                }
                None => {
                    redis::cmd("SET")
                        .arg(&pk)
                        .arg(value)
                        .query_async::<()>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis SET failed: {}", e)))?;
                }
            }
            Ok(())
        }

        /// Deletes the prefixed key with `UNLINK`; missing keys are not an error.
        async fn forget(&self, key: &str) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            redis::cmd("UNLINK")
                .arg(self.prefixed_key(key))
                .query_async::<i64>(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis UNLINK failed: {}", e)))?;
            Ok(())
        }

        /// Deletes every key under this driver's prefix.
        ///
        /// Iterates with `SCAN ... MATCH <prefix>* COUNT 100` and `UNLINK`s each
        /// non-empty batch until the cursor returns to 0. Keys outside the prefix
        /// are left untouched.
        async fn flush(&self) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            let pattern = format!("{}*", self.prefix);
            let mut cursor: u64 = 0;
            loop {
                let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
                    .arg(cursor)
                    .arg("MATCH")
                    .arg(&pattern)
                    .arg("COUNT")
                    .arg(100)
                    .query_async(&mut con)
                    .await
                    .map_err(|e| CacheError::Driver(format!("Redis SCAN failed: {}", e)))?;

                if !keys.is_empty() {
                    redis::cmd("UNLINK")
                        .arg(&keys)
                        .query_async::<i64>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis UNLINK failed: {}", e)))?;
                }

                cursor = next_cursor;
                if cursor == 0 {
                    break;
                }
            }
            Ok(())
        }

        /// Checks the prefixed key with `EXISTS`.
        async fn has(&self, key: &str) -> Result<bool, CacheError> {
            let mut con = self.connection().await?;
            let exists: bool = redis::cmd("EXISTS")
                .arg(self.prefixed_key(key))
                .query_async(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis EXISTS failed: {}", e)))?;
            Ok(exists)
        }
    }
}
312
313pub struct Cache {
323 driver: Arc<Box<dyn CacheDriver>>,
324}
325
326impl Cache {
327 pub fn memory() -> Self {
331 Self {
332 driver: Arc::new(Box::new(MemoryDriver::new())),
333 }
334 }
335
336 #[cfg(feature = "cache-redis")]
340 pub fn redis(redis_url: &str) -> Result<Self, CacheError> {
341 let driver = redis_driver::RedisDriver::new(redis_url)?;
342 Ok(Self {
343 driver: Arc::new(Box::new(driver)),
344 })
345 }
346
347 pub fn custom(driver: Box<dyn CacheDriver>) -> Self {
349 Self {
350 driver: Arc::new(driver),
351 }
352 }
353
354 pub async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError> {
356 self.driver.get(key).await
357 }
358
359 pub async fn put(
363 &self,
364 key: &str,
365 value: &str,
366 ttl_secs: Option<u64>,
367 ) -> Result<(), CacheError> {
368 self.driver.put(key, value, ttl_secs).await
369 }
370
371 pub async fn forget(&self, key: &str) -> Result<(), CacheError> {
373 self.driver.forget(key).await
374 }
375
376 pub async fn flush(&self) -> Result<(), CacheError> {
378 self.driver.flush().await
379 }
380
381 pub async fn has(&self, key: &str) -> Result<bool, CacheError> {
383 self.driver.has(key).await
384 }
385
386 pub async fn remember<F, Fut>(
398 &self,
399 key: &str,
400 ttl_secs: u64,
401 f: F,
402 ) -> Result<Arc<String>, CacheError>
403 where
404 F: FnOnce() -> Fut,
405 Fut: std::future::Future<Output = Result<String, CacheError>>,
406 {
407 if let Some(cached) = self.get(key).await? {
409 return Ok(cached);
410 }
411 let value = f().await?;
413 self.put(key, &value, Some(ttl_secs)).await?;
415 Ok(Arc::new(value))
416 }
417}
418
#[cfg(test)]
// Tests unwrap deliberately: any unexpected error should fail the test loudly.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds the `Option<Arc<String>>` shape that `Cache::get` returns on a hit.
    fn hit(text: &str) -> Option<Arc<String>> {
        Some(Arc::new(text.to_owned()))
    }

    #[tokio::test]
    async fn test_memory_cache_put_get() {
        let store = Cache::memory();
        store.put("key1", "value1", None).await.unwrap();
        let fetched = store.get("key1").await.unwrap();
        assert_eq!(fetched, hit("value1"));
    }

    #[tokio::test]
    async fn test_memory_cache_miss() {
        let store = Cache::memory();
        let fetched = store.get("nonexistent").await.unwrap();
        assert_eq!(fetched, None);
    }

    #[tokio::test]
    async fn test_memory_cache_forget() {
        let store = Cache::memory();
        store.put("key1", "value1", None).await.unwrap();
        store.forget("key1").await.unwrap();
        assert_eq!(store.get("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_memory_cache_flush() {
        let store = Cache::memory();
        for (k, v) in [("a", "1"), ("b", "2")] {
            store.put(k, v, None).await.unwrap();
        }
        store.flush().await.unwrap();
        for k in ["a", "b"] {
            assert_eq!(store.get(k).await.unwrap(), None, "key {k} survived flush");
        }
    }

    #[tokio::test]
    async fn test_memory_cache_has() {
        let store = Cache::memory();
        let before = store.has("key1").await.unwrap();
        store.put("key1", "value1", None).await.unwrap();
        let after = store.has("key1").await.unwrap();
        assert!(!before);
        assert!(after);
    }

    #[tokio::test]
    async fn test_memory_cache_remember_miss() {
        let store = Cache::memory();
        let computed = store
            .remember("computed", 60, || async { Ok("hello".to_string()) })
            .await
            .unwrap();
        assert_eq!(computed.as_str(), "hello");
        // The freshly computed value must also have been written through.
        assert_eq!(store.get("computed").await.unwrap(), hit("hello"));
    }

    #[tokio::test]
    async fn test_memory_cache_remember_hit() {
        let store = Cache::memory();
        store
            .put("existing", "already_cached", Some(300))
            .await
            .unwrap();
        let value = store
            .remember("existing", 60, || async {
                panic!("This closure should NOT be called on cache hit");
            })
            .await
            .unwrap();
        assert_eq!(value.as_str(), "already_cached");
    }

    #[tokio::test]
    async fn test_memory_cache_overwrite() {
        let store = Cache::memory();
        store.put("key", "v1", None).await.unwrap();
        store.put("key", "v2", None).await.unwrap();
        assert_eq!(store.get("key").await.unwrap(), hit("v2"));
    }

    /// Driver stub: every read hits with `"mocked"`, every write is a no-op.
    struct MockDriver;

    #[async_trait]
    impl CacheDriver for MockDriver {
        async fn get(&self, _key: &str) -> Result<Option<Arc<String>>, CacheError> {
            Ok(hit("mocked"))
        }
        async fn put(&self, _k: &str, _v: &str, _t: Option<u64>) -> Result<(), CacheError> {
            Ok(())
        }
        async fn forget(&self, _k: &str) -> Result<(), CacheError> {
            Ok(())
        }
        async fn flush(&self) -> Result<(), CacheError> {
            Ok(())
        }
        async fn has(&self, _k: &str) -> Result<bool, CacheError> {
            Ok(true)
        }
    }

    #[tokio::test]
    async fn test_custom_cache_driver() {
        let store = Cache::custom(Box::new(MockDriver));
        assert_eq!(store.get("anything").await.unwrap(), hit("mocked"));
    }

    #[cfg(feature = "cache-redis")]
    #[test]
    fn test_redis_cache_initialization() {
        // A URL with an unrecognised scheme must be rejected by the constructor.
        let outcome = Cache::redis("invalid-url-format://host:9999");
        assert!(outcome.is_err());
    }
}