Skip to main content

juncture_checkpoint/
cache.rs

//! Caching layer for checkpoint storage.
//!
//! Provides in-memory caching with LRU eviction and TTL support.
4
5use crate::error::CheckpointError;
6use lru::LruCache;
7use std::num::NonZeroUsize;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11
/// A single cached payload together with its optional deadline.
struct CacheEntry {
    /// Instant at which the entry stops being valid; `None` means the entry
    /// never expires.
    expires_at: Option<std::time::Instant>,

    /// Raw bytes stored for this entry.
    data: Vec<u8>,
}
20
21impl CacheEntry {
22    /// Check if the entry has expired
23    #[must_use]
24    fn is_expired(&self) -> bool {
25        self.expires_at
26            .is_some_and(|expires_at| std::time::Instant::now() >= expires_at)
27    }
28}
29
30/// Base cache trait for checkpoint caching
31///
32/// Defines the interface for caching checkpoint data to reduce storage load.
33#[async_trait::async_trait]
34pub trait BaseCache: Send + Sync + 'static {
35    /// Get cached data
36    ///
37    /// # Errors
38    ///
39    /// Returns [`CheckpointError::Storage`] if retrieval fails.
40    async fn get(&self, namespace: &str, key: &str) -> Result<Option<Vec<u8>>, CheckpointError>;
41
42    /// Set cached data with optional TTL
43    ///
44    /// # Errors
45    ///
46    /// Returns [`CheckpointError::Storage`] if storage fails.
47    async fn set(
48        &self,
49        namespace: &str,
50        key: &str,
51        value: Vec<u8>,
52        ttl: Option<Duration>,
53    ) -> Result<(), CheckpointError>;
54
55    /// Delete cached data
56    ///
57    /// # Errors
58    ///
59    /// Returns [`CheckpointError::Storage`] if deletion fails.
60    async fn delete(&self, namespace: &str, key: &str) -> Result<(), CheckpointError>;
61
62    /// Clear cache (optionally by namespace)
63    ///
64    /// # Errors
65    ///
66    /// Returns [`CheckpointError::Storage`] if clearing fails.
67    async fn clear(&self, namespace: Option<&str>) -> Result<(), CheckpointError>;
68}
69
70/// In-memory LRU cache with TTL support
71///
72/// Thread-safe in-memory cache using LRU eviction policy.
73/// Suitable for single-process deployments and development environments.
74#[derive(Clone, Debug)]
75pub struct MemoryCache {
76    /// LRU cache storage (namespace:key -> entry)
77    entries: Arc<RwLock<LruCache<String, CacheEntry>>>,
78
79    /// Default TTL for new entries
80    default_ttl: Option<Duration>,
81}
82
83impl MemoryCache {
84    /// Create a new in-memory cache
85    ///
86    /// # Panics
87    ///
88    /// Panics if capacity is zero.
89    #[must_use]
90    pub fn new(capacity: usize) -> Self {
91        Self {
92            entries: Arc::new(RwLock::new(LruCache::new(
93                NonZeroUsize::new(capacity).expect("capacity must be non-zero"),
94            ))),
95            default_ttl: None,
96        }
97    }
98
99    /// Create a new cache with default TTL
100    ///
101    /// # Panics
102    ///
103    /// Panics if capacity is zero.
104    #[must_use]
105    pub fn with_ttl(capacity: usize, default_ttl: Duration) -> Self {
106        Self {
107            entries: Arc::new(RwLock::new(LruCache::new(
108                NonZeroUsize::new(capacity).expect("capacity must be non-zero"),
109            ))),
110            default_ttl: Some(default_ttl),
111        }
112    }
113
114    /// Build a cache key from namespace and key
115    #[must_use]
116    fn build_key(namespace: &str, key: &str) -> String {
117        format!("{namespace}:{key}")
118    }
119
120    /// Remove expired entries
121    ///
122    /// This is called automatically during get/set operations,
123    /// but can be invoked manually for cleanup.
124    async fn purge_expired(&self) {
125        let mut cache = self.entries.write().await;
126        let expired_keys: Vec<String> = cache
127            .iter()
128            .filter(|(_, entry)| entry.is_expired())
129            .map(|(key, _)| key.clone())
130            .collect();
131
132        for key in expired_keys {
133            cache.pop(&key);
134        }
135    }
136
137    /// Get cache statistics
138    ///
139    /// Returns (`current_size`, capacity).
140    pub async fn stats(&self) -> (usize, usize) {
141        let cache = self.entries.read().await;
142        (cache.len(), cache.cap().get())
143    }
144}
145
146impl Default for MemoryCache {
147    fn default() -> Self {
148        Self::new(1000)
149    }
150}
151
152#[async_trait::async_trait]
153impl BaseCache for MemoryCache {
154    async fn get(&self, namespace: &str, key: &str) -> Result<Option<Vec<u8>>, CheckpointError> {
155        // Periodic cleanup of expired entries
156        self.purge_expired().await;
157
158        let cache_key = Self::build_key(namespace, key);
159        {
160            let mut cache = self.entries.write().await;
161
162            if let Some(entry) = cache.get_mut(&cache_key) {
163                if entry.is_expired() {
164                    cache.pop(&cache_key);
165                    drop(cache);
166                    return Ok(None);
167                }
168                let result = Ok(Some(entry.data.clone()));
169                drop(cache);
170                return result;
171            }
172        }
173
174        Ok(None)
175    }
176
177    async fn set(
178        &self,
179        namespace: &str,
180        key: &str,
181        value: Vec<u8>,
182        ttl: Option<Duration>,
183    ) -> Result<(), CheckpointError> {
184        let cache_key = Self::build_key(namespace, key);
185        let ttl = ttl.or(self.default_ttl);
186
187        let entry = CacheEntry {
188            data: value,
189            expires_at: ttl.map(|duration| std::time::Instant::now() + duration),
190        };
191
192        self.entries.write().await.put(cache_key, entry);
193
194        Ok(())
195    }
196
197    async fn delete(&self, namespace: &str, key: &str) -> Result<(), CheckpointError> {
198        let cache_key = Self::build_key(namespace, key);
199        self.entries.write().await.pop(&cache_key);
200        Ok(())
201    }
202
203    async fn clear(&self, namespace: Option<&str>) -> Result<(), CheckpointError> {
204        if let Some(ns) = namespace {
205            // Clear all keys in the namespace
206            let prefix = format!("{ns}:");
207            let mut cache = self.entries.write().await;
208            let keys_to_remove: Vec<String> = cache
209                .iter()
210                .filter(|(key, _)| key.starts_with(&prefix))
211                .map(|(key, _)| key.clone())
212                .collect();
213
214            for key in keys_to_remove {
215                cache.pop(&key);
216            }
217        } else {
218            // Clear all entries
219            self.entries.write().await.clear();
220        }
221
222        Ok(())
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Store `payload` without an explicit TTL, panicking if the cache reports an error.
    async fn store(cache: &MemoryCache, namespace: &str, key: &str, payload: &[u8]) {
        cache
            .set(namespace, key, payload.to_vec(), None)
            .await
            .unwrap();
    }

    /// Look up an entry, panicking if the cache reports an error.
    async fn fetch(cache: &MemoryCache, namespace: &str, key: &str) -> Option<Vec<u8>> {
        cache.get(namespace, key).await.unwrap()
    }

    #[tokio::test]
    async fn test_memory_cache_set_get() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"hello").await;

        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"hello".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_miss() {
        let cache = MemoryCache::new(10);

        assert!(fetch(&cache, "ns1", "nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_delete() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"hello").await;

        cache.delete("ns1", "key1").await.unwrap();

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let cache = MemoryCache::with_ttl(10, Duration::from_millis(100));
        store(&cache, "ns1", "key1", b"hello").await;

        // Still within the default TTL: the entry is served.
        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"hello".to_vec()));

        // Let the default TTL elapse.
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Past the deadline: the entry is reported as missing.
        assert!(fetch(&cache, "ns1", "key1").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_clear_namespace() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns2", "key2", b"data2").await;

        cache.clear(Some("ns1")).await.unwrap();

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
        assert_eq!(fetch(&cache, "ns2", "key2").await, Some(b"data2".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_clear_all() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns2", "key2", b"data2").await;

        cache.clear(None).await.unwrap();

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
        assert!(fetch(&cache, "ns2", "key2").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_lru_eviction() {
        let cache = MemoryCache::new(2);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns1", "key2", b"data2").await;

        // Touch key1 so that key2 becomes the least recently used entry.
        fetch(&cache, "ns1", "key1").await;

        // A third key exceeds the capacity and pushes key2 out.
        store(&cache, "ns1", "key3", b"data3").await;

        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"data1".to_vec()));
        assert!(fetch(&cache, "ns1", "key2").await.is_none());
        assert_eq!(fetch(&cache, "ns1", "key3").await, Some(b"data3".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::new(100);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns1", "key2", b"data2").await;

        let (size, capacity) = cache.stats().await;

        assert_eq!(size, 2);
        assert_eq!(capacity, 100);
    }
}
379
// Rust guideline compliant 2026-05-19