// rullst/cache.rs

//! # Rullst Cache System (`rullst::cache`)
//!
//! Provides a unified caching API with pluggable drivers.
//!
//! ## Drivers
//! - **In-Memory** (default): `DashMap`-based concurrent store with TTL support. Zero config.
//! - **Redis** (optional): Requires the `cache-redis` feature flag.
//!
//! ## Quick Start
//! ```rust,ignore
//! use rullst::cache::Cache;
//!
//! let cache = Cache::memory();
//!
//! // Store a value with a 60-second TTL
//! cache.put("user:42:name", "Alice", Some(60)).await?;
//!
//! // Retrieve it
//! let name = cache.get("user:42:name").await?; // Some(Arc<String>) holding "Alice"
//!
//! // Cache-aside pattern: fetch from cache or compute + store
//! let value = cache.remember("expensive_key", 300, || async {
//!     Ok("computed_value".to_string())
//! }).await?;
//! ```

use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Instant;

// ─── Error Types ────────────────────────────────────────────────────────────

/// Errors that can occur during cache operations.
///
/// Each variant carries a human-readable message. The `Display` output is the
/// message prefixed with the error category.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheError {
    /// The underlying driver encountered an error.
    Driver(String),
    /// Serialization or deserialization failed.
    Serialization(String),
}

impl std::fmt::Display for CacheError {
    /// Writes `"Cache driver error: <msg>"` or `"Cache serialization error: <msg>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Driver(msg) => write!(f, "Cache driver error: {}", msg),
            Self::Serialization(msg) => write!(f, "Cache serialization error: {}", msg),
        }
    }
}

// No underlying source error is stored, so the default `Error` methods suffice.
impl std::error::Error for CacheError {}

54// ─── Cache Driver Trait ─────────────────────────────────────────────────────
55
56/// Abstraction over cache storage backends.
57///
58/// Implement this trait to add support for new cache backends.
59/// The framework ships with `MemoryDriver` and (optionally) `RedisDriver`.
60#[async_trait]
61pub trait CacheDriver: Send + Sync {
62    /// Retrieve a value by key. Returns `None` if the key doesn't exist or has expired.
63    async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError>;
64    /// Store a value with an optional TTL in seconds.
65    async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError>;
66    /// Remove a key from the cache.
67    async fn forget(&self, key: &str) -> Result<(), CacheError>;
68    /// Remove all keys from the cache.
69    async fn flush(&self) -> Result<(), CacheError>;
70    /// Check if a key exists and is not expired.
71    async fn has(&self, key: &str) -> Result<bool, CacheError>;
72}
73
// ─── In-Memory Driver ───────────────────────────────────────────────────────

/// Cache entry holding the value and optional expiration time.
///
/// Cloning an entry clones the `Arc`, not the underlying string.
#[derive(Clone)]
struct CacheEntry {
    /// The cached payload, shared by reference count.
    value: Arc<String>,
    /// Instant at which the entry expires; `None` means it never expires.
    expires_at: Option<Instant>,
}

83/// In-memory cache driver using `DashMap` for lock-free concurrent access.
84///
85/// Supports TTL-based expiration. Expired entries are lazily cleaned on access.
86/// Perfect for single-instance deployments and development.
87pub struct MemoryDriver {
88    store: DashMap<String, CacheEntry>,
89}
90
91impl MemoryDriver {
92    /// Create a new in-memory cache driver.
93    pub fn new() -> Self {
94        let store: DashMap<String, CacheEntry> = DashMap::new();
95
96        // Spawn active background janitor task to clean up expired cache entries from memory
97        if tokio::runtime::Handle::try_current().is_ok() {
98            let store_clone = store.clone();
99            tokio::spawn(async move {
100                loop {
101                    tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
102                    // Retain only unexpired or eternal entries
103                    store_clone.retain(|_, entry| {
104                        entry.expires_at.map_or(true, |exp| Instant::now() < exp)
105                    });
106                }
107            });
108        }
109
110        Self { store }
111    }
112}
113
114impl Default for MemoryDriver {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120#[async_trait]
121impl CacheDriver for MemoryDriver {
122    async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError> {
123        if let Some(entry) = self.store.get(key) {
124            // Check TTL expiration
125            if let Some(expires_at) = entry.expires_at
126                && Instant::now() > expires_at
127            {
128                // Entry has expired — remove it lazily
129                drop(entry);
130                self.store.remove(key);
131                return Ok(None);
132            }
133            // Cheap pointer clone instead of deep string copy
134            Ok(Some(entry.value.clone()))
135        } else {
136            Ok(None)
137        }
138    }
139
140    async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError> {
141        let expires_at = ttl_secs.map(|secs| Instant::now() + std::time::Duration::from_secs(secs));
142        self.store.insert(
143            key.to_string(),
144            CacheEntry {
145                value: Arc::new(value.to_string()),
146                expires_at,
147            },
148        );
149        Ok(())
150    }
151
152    async fn forget(&self, key: &str) -> Result<(), CacheError> {
153        self.store.remove(key);
154        Ok(())
155    }
156
157    async fn flush(&self) -> Result<(), CacheError> {
158        self.store.clear();
159        Ok(())
160    }
161
162    async fn has(&self, key: &str) -> Result<bool, CacheError> {
163        Ok(self.get(key).await?.is_some())
164    }
165}
166
// ─── Redis Driver (behind feature flag) ─────────────────────────────────────

#[cfg(feature = "cache-redis")]
pub mod redis_driver {
    //! Redis-backed cache driver. Requires the `cache-redis` feature.
    use super::*;

    /// Cache driver backed by Redis.
    ///
    /// Uses `SET`/`SETEX`/`GET` for storage and TTL support. Ideal for distributed
    /// multi-instance deployments where cache must be shared.
    pub struct RedisDriver {
        /// Client handle; a connection is requested from it for every operation.
        client: redis::Client,
        /// Prefix prepended to every key this driver reads or writes.
        prefix: String,
    }

    impl RedisDriver {
        /// Create a new Redis cache driver.
        ///
        /// All keys are prefixed with `rullst:cache:` to avoid collisions.
        ///
        /// # Errors
        /// Returns `CacheError::Driver` if `redis::Client::open` rejects `redis_url`.
        pub fn new(redis_url: &str) -> Result<Self, CacheError> {
            let client = redis::Client::open(redis_url)
                .map_err(|e| CacheError::Driver(format!("Failed to connect to Redis: {}", e)))?;
            Ok(Self {
                client,
                prefix: "rullst:cache:".to_string(),
            })
        }

        /// Build the namespaced Redis key for a caller-supplied cache key.
        fn prefixed_key(&self, key: &str) -> String {
            format!("{}{}", self.prefix, key)
        }

        /// Obtain a multiplexed async connection, mapping failures to `CacheError::Driver`.
        async fn connection(&self) -> Result<redis::aio::MultiplexedConnection, CacheError> {
            self.client
                .get_multiplexed_async_connection()
                .await
                .map_err(|e| CacheError::Driver(format!("Redis connection failed: {}", e)))
        }
    }

    #[async_trait]
    impl CacheDriver for RedisDriver {
        /// Fetch the value stored under the prefixed key via `GET`.
        async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError> {
            let mut con = self.connection().await?;
            let result: Option<String> = redis::cmd("GET")
                .arg(self.prefixed_key(key))
                .query_async(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis GET failed: {}", e)))?;
            Ok(result.map(Arc::new))
        }

        /// Store a value, using `SETEX` when a TTL is given and `SET` otherwise.
        ///
        /// A TTL of `0` is treated as "already expired": Redis rejects a zero
        /// expire time, so the key is removed instead of returning an error.
        async fn put(
            &self,
            key: &str,
            value: &str,
            ttl_secs: Option<u64>,
        ) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            let pk = self.prefixed_key(key);
            match ttl_secs {
                Some(0) => {
                    redis::cmd("UNLINK")
                        .arg(&pk)
                        .query_async::<i64>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis UNLINK failed: {}", e)))?;
                }
                Some(ttl) => {
                    // Pass the u64 as-is; a cast to i64 would wrap very large
                    // TTLs into negative numbers.
                    redis::cmd("SETEX")
                        .arg(&pk)
                        .arg(ttl)
                        .arg(value)
                        .query_async::<()>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis SETEX failed: {}", e)))?;
                }
                None => {
                    redis::cmd("SET")
                        .arg(&pk)
                        .arg(value)
                        .query_async::<()>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis SET failed: {}", e)))?;
                }
            }
            Ok(())
        }

        /// Remove the prefixed key via `UNLINK`; the removed-key count is ignored.
        async fn forget(&self, key: &str) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            redis::cmd("UNLINK")
                .arg(self.prefixed_key(key))
                .query_async::<i64>(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis UNLINK failed: {}", e)))?;
            Ok(())
        }

        /// Remove every key carrying this driver's prefix.
        ///
        /// Iterates with `SCAN ... MATCH <prefix>* COUNT 100` and unlinks each
        /// non-empty batch until the cursor returns to `0`. Keys outside the
        /// prefix are not touched.
        async fn flush(&self) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            let pattern = format!("{}*", self.prefix);
            let mut cursor: u64 = 0;
            loop {
                let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
                    .arg(cursor)
                    .arg("MATCH")
                    .arg(&pattern)
                    .arg("COUNT")
                    .arg(100)
                    .query_async(&mut con)
                    .await
                    .map_err(|e| CacheError::Driver(format!("Redis SCAN failed: {}", e)))?;

                if !keys.is_empty() {
                    redis::cmd("UNLINK")
                        .arg(&keys)
                        .query_async::<i64>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis UNLINK failed: {}", e)))?;
                }

                cursor = next_cursor;
                if cursor == 0 {
                    break;
                }
            }
            Ok(())
        }

        /// Check for the prefixed key via `EXISTS`.
        async fn has(&self, key: &str) -> Result<bool, CacheError> {
            let mut con = self.connection().await?;
            let exists: bool = redis::cmd("EXISTS")
                .arg(self.prefixed_key(key))
                .query_async(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis EXISTS failed: {}", e)))?;
            Ok(exists)
        }
    }
}

313// ─── Cache Facade ───────────────────────────────────────────────────────────
314
315/// The main cache facade for storing and retrieving cached values.
316///
317/// Provides a driver-agnostic API. Create with `Cache::memory()` or `Cache::redis()`.
318///
319/// # Thread Safety
320/// The `Cache` is `Send + Sync` and can be safely shared across async tasks
321/// and Axum handlers via `Arc` or Axum's `State`.
322pub struct Cache {
323    driver: Arc<Box<dyn CacheDriver>>,
324}
325
326impl Cache {
327    /// Create a cache backed by an in-memory `DashMap`. Zero configuration.
328    ///
329    /// Data is lost on process restart. Perfect for development and single-instance apps.
330    pub fn memory() -> Self {
331        Self {
332            driver: Arc::new(Box::new(MemoryDriver::new())),
333        }
334    }
335
336    /// Create a cache backed by Redis. Requires the `cache-redis` feature.
337    ///
338    /// Data persists across restarts and is shared between instances.
339    #[cfg(feature = "cache-redis")]
340    pub fn redis(redis_url: &str) -> Result<Self, CacheError> {
341        let driver = redis_driver::RedisDriver::new(redis_url)?;
342        Ok(Self {
343            driver: Arc::new(Box::new(driver)),
344        })
345    }
346
347    /// Create a cache from any custom driver implementing `CacheDriver`.
348    pub fn custom(driver: Box<dyn CacheDriver>) -> Self {
349        Self {
350            driver: Arc::new(driver),
351        }
352    }
353
354    /// Retrieve a value by key.
355    pub async fn get(&self, key: &str) -> Result<Option<Arc<String>>, CacheError> {
356        self.driver.get(key).await
357    }
358
359    /// Store a value with an optional TTL in seconds.
360    ///
361    /// Pass `None` for TTL to store indefinitely.
362    pub async fn put(
363        &self,
364        key: &str,
365        value: &str,
366        ttl_secs: Option<u64>,
367    ) -> Result<(), CacheError> {
368        self.driver.put(key, value, ttl_secs).await
369    }
370
371    /// Remove a key from the cache.
372    pub async fn forget(&self, key: &str) -> Result<(), CacheError> {
373        self.driver.forget(key).await
374    }
375
376    /// Remove all keys from the cache.
377    pub async fn flush(&self) -> Result<(), CacheError> {
378        self.driver.flush().await
379    }
380
381    /// Check if a key exists and has not expired.
382    pub async fn has(&self, key: &str) -> Result<bool, CacheError> {
383        self.driver.has(key).await
384    }
385
386    /// Retrieve a cached value, or compute it with the provided closure and cache the result.
387    ///
388    /// This is the **cache-aside** (or "remember") pattern — the most common caching strategy.
389    ///
390    /// # Example
391    /// ```rust,ignore
392    /// let bio = cache.remember("user:42:bio", 300, || async {
393    ///     let user = User::find(42).await.map_err(|e| CacheError::Driver(e.to_string()))?;
394    ///     Ok(user.bio)
395    /// }).await?;
396    /// ```
397    pub async fn remember<F, Fut>(
398        &self,
399        key: &str,
400        ttl_secs: u64,
401        f: F,
402    ) -> Result<Arc<String>, CacheError>
403    where
404        F: FnOnce() -> Fut,
405        Fut: std::future::Future<Output = Result<String, CacheError>>,
406    {
407        // Try the cache first
408        if let Some(cached) = self.get(key).await? {
409            return Ok(cached);
410        }
411        // Cache miss — compute the value
412        let value = f().await?;
413        // Store in cache
414        self.put(key, &value, Some(ttl_secs)).await?;
415        Ok(Arc::new(value))
416    }
417}
418
// ─── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
// Tests unwrap freely: a failed cache call should abort the test immediately.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_cache_put_get() {
        let cache = Cache::memory();
        cache.put("key1", "value1", None).await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result, Some(Arc::new("value1".to_string())));
    }

    #[tokio::test]
    async fn test_memory_cache_miss() {
        let cache = Cache::memory();
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_forget() {
        let cache = Cache::memory();
        cache.put("key1", "value1", None).await.unwrap();
        cache.forget("key1").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_flush() {
        let cache = Cache::memory();
        cache.put("a", "1", None).await.unwrap();
        cache.put("b", "2", None).await.unwrap();
        cache.flush().await.unwrap();
        assert!(cache.get("a").await.unwrap().is_none());
        assert!(cache.get("b").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_has() {
        let cache = Cache::memory();
        assert!(!cache.has("key1").await.unwrap());
        cache.put("key1", "value1", None).await.unwrap();
        assert!(cache.has("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_cache_ttl_expiry() {
        let cache = Cache::memory();
        // A zero-second TTL puts the deadline at the moment of insertion.
        cache.put("ephemeral", "soon_gone", Some(0)).await.unwrap();
        // Let the clock move strictly past the deadline before reading.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(cache.get("ephemeral").await.unwrap().is_none());
        assert!(!cache.has("ephemeral").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_cache_ttl_not_yet_expired() {
        let cache = Cache::memory();
        cache.put("durable", "still_here", Some(60)).await.unwrap();
        assert_eq!(
            cache.get("durable").await.unwrap(),
            Some(Arc::new("still_here".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_cache_remember_miss() {
        let cache = Cache::memory();
        let value = cache
            .remember("computed", 60, || async { Ok("hello".to_string()) })
            .await
            .unwrap();
        assert_eq!(*value, "hello");
        // Should be cached now
        let cached = cache.get("computed").await.unwrap();
        assert_eq!(cached, Some(Arc::new("hello".to_string())));
    }

    #[tokio::test]
    async fn test_memory_cache_remember_hit() {
        let cache = Cache::memory();
        cache
            .put("existing", "already_cached", Some(300))
            .await
            .unwrap();
        let value = cache
            .remember("existing", 60, || async {
                panic!("This closure should NOT be called on cache hit");
            })
            .await
            .unwrap();
        assert_eq!(*value, "already_cached");
    }

    #[tokio::test]
    async fn test_memory_cache_overwrite() {
        let cache = Cache::memory();
        cache.put("key", "v1", None).await.unwrap();
        cache.put("key", "v2", None).await.unwrap();
        assert_eq!(
            cache.get("key").await.unwrap(),
            Some(Arc::new("v2".to_string()))
        );
    }

    /// Driver stub that answers every read with a fixed value and accepts every write.
    struct MockDriver;
    #[async_trait]
    impl CacheDriver for MockDriver {
        async fn get(&self, _key: &str) -> Result<Option<Arc<String>>, CacheError> {
            Ok(Some(Arc::new("mocked".to_string())))
        }
        async fn put(&self, _k: &str, _v: &str, _t: Option<u64>) -> Result<(), CacheError> {
            Ok(())
        }
        async fn forget(&self, _k: &str) -> Result<(), CacheError> {
            Ok(())
        }
        async fn flush(&self) -> Result<(), CacheError> {
            Ok(())
        }
        async fn has(&self, _k: &str) -> Result<bool, CacheError> {
            Ok(true)
        }
    }

    #[tokio::test]
    async fn test_custom_cache_driver() {
        let cache = Cache::custom(Box::new(MockDriver));
        let result = cache.get("anything").await.unwrap();
        assert_eq!(result, Some(Arc::new("mocked".to_string())));
    }

    #[cfg(feature = "cache-redis")]
    #[test]
    fn test_redis_cache_initialization() {
        // Just verify that the constructor exists and returns a Result
        // We use an invalid URL so it fails parsing the connection string
        let result = Cache::redis("invalid-url-format://host:9999");
        assert!(result.is_err());
    }
}