Skip to main content

prax_query/data_cache/
memory.rs

//! In-memory cache backend with a moka-style configuration surface.
//!
//! The builder options (`max_capacity`, `time_to_live`, `time_to_idle`) use the same names as
//! the [moka](https://github.com/moka-rs/moka) crate, a fast, concurrent cache inspired by
//! Caffeine (Java). The implementation in this file does not call into moka: entries are kept
//! in a `HashMap` guarded by a `parking_lot::RwLock`.
//!
//! # Features
//!
//! - **Concurrent access**: State is guarded by read-write locks, so read-only operations can run in parallel
//! - **Automatic eviction**: Expired entries first, then LRU-based eviction when capacity is reached
//! - **TTL support**: Per-entry time-to-live
//! - **Size-based limits**: Limit by entry count (memory usage is estimated by `stats`, not enforced)
//! - **Async-friendly**: Works seamlessly with async runtimes
//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use prax_query::data_cache::memory::{MemoryCache, MemoryCacheConfig};
18//! use std::time::Duration;
19//!
20//! let cache = MemoryCache::builder()
21//!     .max_capacity(10_000)
22//!     .time_to_live(Duration::from_secs(300))
23//!     .time_to_idle(Duration::from_secs(60))
24//!     .build();
25//!
26//! // Use with CacheManager
27//! let manager = CacheManager::new(cache);
28//! ```
29
30use parking_lot::RwLock;
31use std::collections::{HashMap, HashSet};
32use std::sync::atomic::{AtomicUsize, Ordering};
33use std::time::{Duration, Instant};
34
35use super::backend::{BackendStats, CacheBackend, CacheError, CacheResult};
36use super::invalidation::EntityTag;
37use super::key::{CacheKey, KeyPattern};
38
/// Tunable settings for the in-memory cache.
///
/// Start from [`MemoryCacheConfig::default`] or [`MemoryCacheConfig::new`] and refine the
/// result with the chained `with_*` / `without_*` helpers.
#[derive(Debug, Clone)]
pub struct MemoryCacheConfig {
    /// Upper bound on the number of entries held at once.
    pub max_capacity: u64,
    /// Time-to-live used for entries by default; `None` means no default TTL.
    pub time_to_live: Option<Duration>,
    /// Idle period after which an entry that has not been accessed should be evicted.
    pub time_to_idle: Option<Duration>,
    /// Whether TTLs are tracked at the level of individual entries.
    pub per_entry_ttl: bool,
    /// Whether tag-based invalidation is enabled.
    pub enable_tags: bool,
}

impl Default for MemoryCacheConfig {
    /// 10 000 entries, a 300-second TTL, no idle timeout, per-entry TTL and tags enabled.
    fn default() -> Self {
        const DEFAULT_TTL: Duration = Duration::from_secs(300);

        MemoryCacheConfig {
            enable_tags: true,
            per_entry_ttl: true,
            time_to_idle: None,
            time_to_live: Some(DEFAULT_TTL),
            max_capacity: 10_000,
        }
    }
}

impl MemoryCacheConfig {
    /// Default settings with `max_capacity` replaced by the given value.
    pub fn new(max_capacity: u64) -> Self {
        let mut cfg = Self::default();
        cfg.max_capacity = max_capacity;
        cfg
    }

    /// Returns the config with its default TTL set to `ttl`.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self {
            time_to_live: Some(ttl),
            ..self
        }
    }

    /// Returns the config with its time-to-idle set to `tti`.
    pub fn with_tti(self, tti: Duration) -> Self {
        Self {
            time_to_idle: Some(tti),
            ..self
        }
    }

    /// Returns the config with tag-based invalidation switched off.
    pub fn without_tags(self) -> Self {
        Self {
            enable_tags: false,
            ..self
        }
    }
}
93
94/// Builder for MemoryCache.
95#[derive(Default)]
96pub struct MemoryCacheBuilder {
97    config: MemoryCacheConfig,
98}
99
100impl MemoryCacheBuilder {
101    /// Create a new builder.
102    pub fn new() -> Self {
103        Self::default()
104    }
105
106    /// Set max capacity.
107    pub fn max_capacity(mut self, capacity: u64) -> Self {
108        self.config.max_capacity = capacity;
109        self
110    }
111
112    /// Set TTL.
113    pub fn time_to_live(mut self, ttl: Duration) -> Self {
114        self.config.time_to_live = Some(ttl);
115        self
116    }
117
118    /// Set TTI.
119    pub fn time_to_idle(mut self, tti: Duration) -> Self {
120        self.config.time_to_idle = Some(tti);
121        self
122    }
123
124    /// Enable per-entry TTL.
125    pub fn per_entry_ttl(mut self, enabled: bool) -> Self {
126        self.config.per_entry_ttl = enabled;
127        self
128    }
129
130    /// Enable tags.
131    pub fn enable_tags(mut self, enabled: bool) -> Self {
132        self.config.enable_tags = enabled;
133        self
134    }
135
136    /// Build the cache.
137    pub fn build(self) -> MemoryCache {
138        MemoryCache::new(self.config)
139    }
140}
141
142/// A cached entry with metadata.
143#[derive(Clone)]
144struct CacheEntry {
145    /// Serialized value.
146    data: Vec<u8>,
147    /// When the entry was created.
148    created_at: Instant,
149    /// When the entry expires (if TTL set).
150    expires_at: Option<Instant>,
151    /// Last access time.
152    last_accessed: Instant,
153    /// Associated tags.
154    tags: Vec<EntityTag>,
155}
156
157impl CacheEntry {
158    fn new(data: Vec<u8>, ttl: Option<Duration>, tags: Vec<EntityTag>) -> Self {
159        let now = Instant::now();
160        Self {
161            data,
162            created_at: now,
163            expires_at: ttl.map(|d| now + d),
164            last_accessed: now,
165            tags,
166        }
167    }
168
169    fn is_expired(&self) -> bool {
170        self.expires_at.map_or(false, |exp| Instant::now() >= exp)
171    }
172
173    fn touch(&mut self) {
174        self.last_accessed = Instant::now();
175    }
176}
177
178/// High-performance in-memory cache.
179///
180/// Uses a concurrent HashMap with LRU eviction and TTL support.
181pub struct MemoryCache {
182    config: MemoryCacheConfig,
183    entries: RwLock<HashMap<String, CacheEntry>>,
184    tag_index: RwLock<HashMap<String, HashSet<String>>>,
185    entry_count: AtomicUsize,
186}
187
188impl MemoryCache {
189    /// Create a new memory cache with the given config.
190    pub fn new(config: MemoryCacheConfig) -> Self {
191        Self {
192            entries: RwLock::new(HashMap::with_capacity(config.max_capacity as usize)),
193            tag_index: RwLock::new(HashMap::new()),
194            entry_count: AtomicUsize::new(0),
195            config,
196        }
197    }
198
199    /// Create a builder.
200    pub fn builder() -> MemoryCacheBuilder {
201        MemoryCacheBuilder::new()
202    }
203
204    /// Get the config.
205    pub fn config(&self) -> &MemoryCacheConfig {
206        &self.config
207    }
208
209    /// Evict expired entries.
210    pub fn evict_expired(&self) -> usize {
211        let mut entries = self.entries.write();
212        let before = entries.len();
213
214        let expired_keys: Vec<String> = entries
215            .iter()
216            .filter(|(_, e)| e.is_expired())
217            .map(|(k, _)| k.clone())
218            .collect();
219
220        for key in &expired_keys {
221            if let Some(entry) = entries.remove(key) {
222                self.remove_from_tag_index(key, &entry.tags);
223            }
224        }
225
226        let evicted = before - entries.len();
227        self.entry_count.fetch_sub(evicted, Ordering::Relaxed);
228        evicted
229    }
230
231    /// Evict entries to make room (LRU).
232    fn evict_lru(&self, count: usize) {
233        let mut entries = self.entries.write();
234
235        // Find LRU entries
236        let mut by_access: Vec<_> = entries
237            .iter()
238            .map(|(k, e)| (k.clone(), e.last_accessed))
239            .collect();
240        by_access.sort_by_key(|(_, t)| *t);
241
242        for (key, _) in by_access.into_iter().take(count) {
243            if let Some(entry) = entries.remove(&key) {
244                self.remove_from_tag_index(&key, &entry.tags);
245            }
246        }
247
248        self.entry_count.store(entries.len(), Ordering::Relaxed);
249    }
250
251    /// Add entry to tag index.
252    fn add_to_tag_index(&self, key: &str, tags: &[EntityTag]) {
253        if !self.config.enable_tags || tags.is_empty() {
254            return;
255        }
256
257        let mut index = self.tag_index.write();
258        for tag in tags {
259            index
260                .entry(tag.value().to_string())
261                .or_default()
262                .insert(key.to_string());
263        }
264    }
265
266    /// Remove entry from tag index.
267    fn remove_from_tag_index(&self, key: &str, tags: &[EntityTag]) {
268        if !self.config.enable_tags || tags.is_empty() {
269            return;
270        }
271
272        let mut index = self.tag_index.write();
273        for tag in tags {
274            if let Some(keys) = index.get_mut(tag.value()) {
275                keys.remove(key);
276                if keys.is_empty() {
277                    index.remove(tag.value());
278                }
279            }
280        }
281    }
282}
283
284impl CacheBackend for MemoryCache {
285    async fn get<T>(&self, key: &CacheKey) -> CacheResult<Option<T>>
286    where
287        T: serde::de::DeserializeOwned,
288    {
289        let key_str = key.as_str();
290
291        // Try to get with read lock first
292        {
293            let entries = self.entries.read();
294            if let Some(entry) = entries.get(&key_str) {
295                if entry.is_expired() {
296                    // Entry expired, will be cleaned up later
297                    return Ok(None);
298                }
299
300                // Deserialize
301                let value: T = serde_json::from_slice(&entry.data)
302                    .map_err(|e| CacheError::Deserialization(e.to_string()))?;
303
304                return Ok(Some(value));
305            }
306        }
307
308        // Update last_accessed with write lock
309        {
310            let mut entries = self.entries.write();
311            if let Some(entry) = entries.get_mut(&key_str) {
312                entry.touch();
313            }
314        }
315
316        Ok(None)
317    }
318
319    async fn set<T>(&self, key: &CacheKey, value: &T, ttl: Option<Duration>) -> CacheResult<()>
320    where
321        T: serde::Serialize + Sync,
322    {
323        let key_str = key.as_str();
324
325        // Serialize
326        let data =
327            serde_json::to_vec(value).map_err(|e| CacheError::Serialization(e.to_string()))?;
328
329        let effective_ttl = ttl.or(self.config.time_to_live);
330        let entry = CacheEntry::new(data, effective_ttl, Vec::new());
331
332        // Check capacity
333        let current = self.entry_count.load(Ordering::Relaxed);
334        if current >= self.config.max_capacity as usize {
335            // Evict some entries
336            self.evict_expired();
337            let still_over = self.entry_count.load(Ordering::Relaxed);
338            if still_over >= self.config.max_capacity as usize {
339                self.evict_lru((self.config.max_capacity as usize / 10).max(1));
340            }
341        }
342
343        // Insert
344        {
345            let mut entries = self.entries.write();
346            let is_new = !entries.contains_key(&key_str);
347            entries.insert(key_str.clone(), entry);
348            if is_new {
349                self.entry_count.fetch_add(1, Ordering::Relaxed);
350            }
351        }
352
353        Ok(())
354    }
355
356    async fn delete(&self, key: &CacheKey) -> CacheResult<bool> {
357        let key_str = key.as_str();
358
359        let mut entries = self.entries.write();
360        if let Some(entry) = entries.remove(&key_str) {
361            self.remove_from_tag_index(&key_str, &entry.tags);
362            self.entry_count.fetch_sub(1, Ordering::Relaxed);
363            Ok(true)
364        } else {
365            Ok(false)
366        }
367    }
368
369    async fn exists(&self, key: &CacheKey) -> CacheResult<bool> {
370        let key_str = key.as_str();
371
372        let entries = self.entries.read();
373        if let Some(entry) = entries.get(&key_str) {
374            Ok(!entry.is_expired())
375        } else {
376            Ok(false)
377        }
378    }
379
380    async fn invalidate_pattern(&self, pattern: &KeyPattern) -> CacheResult<u64> {
381        let mut entries = self.entries.write();
382        let before = entries.len();
383
384        let matching_keys: Vec<String> = entries
385            .keys()
386            .filter(|k| pattern.matches_str(k))
387            .cloned()
388            .collect();
389
390        for key in &matching_keys {
391            if let Some(entry) = entries.remove(key) {
392                self.remove_from_tag_index(key, &entry.tags);
393            }
394        }
395
396        let removed = before - entries.len();
397        self.entry_count.fetch_sub(removed, Ordering::Relaxed);
398        Ok(removed as u64)
399    }
400
401    async fn invalidate_tags(&self, tags: &[EntityTag]) -> CacheResult<u64> {
402        if !self.config.enable_tags {
403            return Ok(0);
404        }
405
406        let keys_to_remove: HashSet<String> = {
407            let index = self.tag_index.read();
408            tags.iter()
409                .filter_map(|tag| index.get(tag.value()))
410                .flatten()
411                .cloned()
412                .collect()
413        };
414
415        let mut entries = self.entries.write();
416        let mut removed = 0u64;
417
418        for key in keys_to_remove {
419            if let Some(entry) = entries.remove(&key) {
420                self.remove_from_tag_index(&key, &entry.tags);
421                removed += 1;
422            }
423        }
424
425        self.entry_count
426            .fetch_sub(removed as usize, Ordering::Relaxed);
427        Ok(removed)
428    }
429
430    async fn clear(&self) -> CacheResult<()> {
431        let mut entries = self.entries.write();
432        entries.clear();
433        self.entry_count.store(0, Ordering::Relaxed);
434
435        if self.config.enable_tags {
436            let mut index = self.tag_index.write();
437            index.clear();
438        }
439
440        Ok(())
441    }
442
443    async fn len(&self) -> CacheResult<usize> {
444        Ok(self.entry_count.load(Ordering::Relaxed))
445    }
446
447    async fn stats(&self) -> CacheResult<BackendStats> {
448        let entries = self.entries.read();
449        let memory_estimate: usize = entries
450            .values()
451            .map(|e| e.data.len() + 64) // Data + overhead estimate
452            .sum();
453
454        Ok(BackendStats {
455            entries: entries.len(),
456            memory_bytes: Some(memory_estimate),
457            connections: None,
458            info: Some(format!("MemoryCache (max: {})", self.config.max_capacity)),
459        })
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// A value can be stored, read back, deleted, and is gone afterwards.
    #[tokio::test]
    async fn test_memory_cache_basic() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));
        let key = CacheKey::new("test", "key1");

        cache.set(&key, &"hello", None).await.unwrap();

        let stored: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(stored, Some("hello".to_string()));

        let was_deleted = cache.delete(&key).await.unwrap();
        assert!(was_deleted);

        let after_delete: Option<String> = cache.get(&key).await.unwrap();
        assert!(after_delete.is_none());
    }

    /// An entry stored with the config's default TTL stops being returned once it elapses.
    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let cache = MemoryCache::new(
            MemoryCacheConfig::new(100).with_ttl(Duration::from_millis(50)),
        );
        let key = CacheKey::new("test", "ttl");

        cache.set(&key, &"expires soon", None).await.unwrap();

        // Still inside the TTL window.
        let fresh: Option<String> = cache.get(&key).await.unwrap();
        assert!(fresh.is_some());

        // Sleep past the 50 ms TTL.
        tokio::time::sleep(Duration::from_millis(60)).await;

        let stale: Option<String> = cache.get(&key).await.unwrap();
        assert!(stale.is_none());
    }

    /// Inserting more entries than `max_capacity` keeps the cache within its limit.
    #[tokio::test]
    async fn test_memory_cache_eviction() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(5));

        // Insert twice as many entries as the cache may hold.
        for i in 0..10 {
            let key = CacheKey::new("test", format!("key{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }

        let remaining = cache.len().await.unwrap();
        assert!(remaining <= 5);
    }

    /// Pattern invalidation removes the matching entity's entries and leaves the rest.
    #[tokio::test]
    async fn test_memory_cache_pattern_invalidation() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));

        // Five "User" entries followed by three "Post" entries.
        for i in 0..5 {
            let key = CacheKey::new("User", format!("id:{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }
        for i in 0..3 {
            let key = CacheKey::new("Post", format!("id:{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }

        assert_eq!(cache.len().await.unwrap(), 8);

        let user_pattern = KeyPattern::entity("User");
        let removed = cache.invalidate_pattern(&user_pattern).await.unwrap();

        assert_eq!(removed, 5);
        assert_eq!(cache.len().await.unwrap(), 3);
    }

    /// Builder settings end up in the built cache's config.
    #[tokio::test]
    async fn test_memory_cache_builder() {
        let cache = MemoryCache::builder()
            .max_capacity(1000)
            .time_to_live(Duration::from_secs(60))
            .build();

        let config = cache.config();
        assert_eq!(config.max_capacity, 1000);
        assert_eq!(config.time_to_live, Some(Duration::from_secs(60)));
    }
}