prax_query/data_cache/memory.rs

//! High-performance in-memory cache backend.
//!
//! This module provides an in-memory [`CacheBackend`] implementation built on a
//! `parking_lot::RwLock`-protected `HashMap`, with capacity, TTL, and TTI semantics
//! modeled after [moka](https://github.com/moka-rs/moka) and Caffeine (Java).
//!
//! # Features
//!
//! - **Concurrent access**: Shared read locks for lookups, exclusive locks for writes
//! - **Automatic eviction**: LRU-based eviction when capacity is reached
//! - **TTL support**: Per-entry time-to-live
//! - **Size-based limits**: Limit by entry count
//! - **Tag-based invalidation**: Invalidate groups of entries via [`EntityTag`]s
//! - **Async-friendly**: Works seamlessly with async runtimes
//!
//! # Example
//!
//! ```rust,ignore
//! use prax_query::data_cache::memory::MemoryCache;
//! use std::time::Duration;
//!
//! let cache = MemoryCache::builder()
//!     .max_capacity(10_000)
//!     .time_to_live(Duration::from_secs(300))
//!     .time_to_idle(Duration::from_secs(60))
//!     .build();
//!
//! // Use with CacheManager
//! let manager = CacheManager::new(cache);
//! ```

use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};

use super::backend::{BackendStats, CacheBackend, CacheError, CacheResult};
use super::invalidation::EntityTag;
use super::key::{CacheKey, KeyPattern};

/// Configuration for the in-memory cache.
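///
/// # Example
///
/// A minimal construction sketch (the values here are illustrative):
///
/// ```rust,ignore
/// let config = MemoryCacheConfig::new(20_000)
///     .with_ttl(Duration::from_secs(600))
///     .with_tti(Duration::from_secs(120));
/// ```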
#[derive(Debug, Clone)]
pub struct MemoryCacheConfig {
    /// Maximum number of entries.
    pub max_capacity: u64,
    /// Default time-to-live for entries.
    pub time_to_live: Option<Duration>,
    /// Time-to-idle (evict if not accessed).
    pub time_to_idle: Option<Duration>,
    /// Enable entry-level TTL tracking.
    pub per_entry_ttl: bool,
    /// Enable tag-based invalidation.
    pub enable_tags: bool,
}

impl Default for MemoryCacheConfig {
    fn default() -> Self {
        Self {
            max_capacity: 10_000,
            time_to_live: Some(Duration::from_secs(300)),
            time_to_idle: None,
            per_entry_ttl: true,
            enable_tags: true,
        }
    }
}

impl MemoryCacheConfig {
    /// Create a new config with the given capacity.
    pub fn new(max_capacity: u64) -> Self {
        Self {
            max_capacity,
            ..Default::default()
        }
    }

    /// Set the default TTL.
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.time_to_live = Some(ttl);
        self
    }

    /// Set the time-to-idle.
    pub fn with_tti(mut self, tti: Duration) -> Self {
        self.time_to_idle = Some(tti);
        self
    }

    /// Disable tags.
    pub fn without_tags(mut self) -> Self {
        self.enable_tags = false;
        self
    }
}

/// Builder for [`MemoryCache`].
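///
/// # Example
///
/// A construction sketch covering the remaining knobs (values are illustrative):
///
/// ```rust,ignore
/// let cache = MemoryCache::builder()
///     .max_capacity(50_000)
///     .per_entry_ttl(true)
///     .enable_tags(false)
///     .build();
/// ```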
#[derive(Default)]
pub struct MemoryCacheBuilder {
    config: MemoryCacheConfig,
}

impl MemoryCacheBuilder {
    /// Create a new builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set max capacity.
    pub fn max_capacity(mut self, capacity: u64) -> Self {
        self.config.max_capacity = capacity;
        self
    }

    /// Set TTL.
    pub fn time_to_live(mut self, ttl: Duration) -> Self {
        self.config.time_to_live = Some(ttl);
        self
    }

    /// Set TTI.
    pub fn time_to_idle(mut self, tti: Duration) -> Self {
        self.config.time_to_idle = Some(tti);
        self
    }

    /// Enable per-entry TTL.
    pub fn per_entry_ttl(mut self, enabled: bool) -> Self {
        self.config.per_entry_ttl = enabled;
        self
    }

    /// Enable tags.
    pub fn enable_tags(mut self, enabled: bool) -> Self {
        self.config.enable_tags = enabled;
        self
    }

    /// Build the cache.
    pub fn build(self) -> MemoryCache {
        MemoryCache::new(self.config)
    }
}

/// A cached entry with metadata.
#[derive(Clone)]
struct CacheEntry {
    /// Serialized value.
    data: Vec<u8>,
    /// When the entry was created.
    created_at: Instant,
    /// When the entry expires (if TTL set).
    expires_at: Option<Instant>,
    /// Last access time.
    last_accessed: Instant,
    /// Associated tags.
    tags: Vec<EntityTag>,
}

impl CacheEntry {
    fn new(data: Vec<u8>, ttl: Option<Duration>, tags: Vec<EntityTag>) -> Self {
        let now = Instant::now();
        Self {
            data,
            created_at: now,
            expires_at: ttl.map(|d| now + d),
            last_accessed: now,
            tags,
        }
    }

    fn is_expired(&self) -> bool {
        self.expires_at.map_or(false, |exp| Instant::now() >= exp)
    }

    fn touch(&mut self) {
        self.last_accessed = Instant::now();
    }
}

/// High-performance in-memory cache.
///
/// Uses an `RwLock`-protected `HashMap` with LRU eviction and TTL support.
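///
/// # Example
///
/// A usage sketch through the [`CacheBackend`] trait (assumes an async runtime
/// such as tokio):
///
/// ```rust,ignore
/// let cache = MemoryCache::new(MemoryCacheConfig::default());
/// let key = CacheKey::new("User", "id:42");
///
/// cache.set(&key, &"alice", None).await.unwrap();
/// let name: Option<String> = cache.get(&key).await.unwrap();
/// assert_eq!(name, Some("alice".to_string()));
/// ```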
pub struct MemoryCache {
    config: MemoryCacheConfig,
    entries: RwLock<HashMap<String, CacheEntry>>,
    tag_index: RwLock<HashMap<String, HashSet<String>>>,
    entry_count: AtomicUsize,
}

impl MemoryCache {
    /// Create a new memory cache with the given config.
    pub fn new(config: MemoryCacheConfig) -> Self {
        Self {
            entries: RwLock::new(HashMap::with_capacity(config.max_capacity as usize)),
            tag_index: RwLock::new(HashMap::new()),
            entry_count: AtomicUsize::new(0),
            config,
        }
    }

    /// Create a builder.
    pub fn builder() -> MemoryCacheBuilder {
        MemoryCacheBuilder::new()
    }

    /// Get the config.
    pub fn config(&self) -> &MemoryCacheConfig {
        &self.config
    }

    /// Evict expired entries.
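    ///
    /// Returns the number of entries removed. Expired entries are otherwise only
    /// removed lazily (for example, on capacity pressure in `set`), so long-lived
    /// caches may want to run this periodically. A sketch, assuming a tokio runtime
    /// and an `Arc`-wrapped cache:
    ///
    /// ```rust,ignore
    /// let cache = Arc::new(MemoryCache::builder().max_capacity(10_000).build());
    /// let maintenance = Arc::clone(&cache);
    /// tokio::spawn(async move {
    ///     let mut tick = tokio::time::interval(Duration::from_secs(30));
    ///     loop {
    ///         tick.tick().await;
    ///         let _evicted = maintenance.evict_expired();
    ///     }
    /// });
    /// ```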
    pub fn evict_expired(&self) -> usize {
        let mut entries = self.entries.write();
        let before = entries.len();

        let expired_keys: Vec<String> = entries
            .iter()
            .filter(|(_, e)| e.is_expired())
            .map(|(k, _)| k.clone())
            .collect();

        for key in &expired_keys {
            if let Some(entry) = entries.remove(key) {
                self.remove_from_tag_index(key, &entry.tags);
            }
        }

        let evicted = before - entries.len();
        self.entry_count.fetch_sub(evicted, Ordering::Relaxed);
        evicted
    }

    /// Evict entries to make room (LRU).
    fn evict_lru(&self, count: usize) {
        let mut entries = self.entries.write();

        // Find LRU entries
        let mut by_access: Vec<_> = entries
            .iter()
            .map(|(k, e)| (k.clone(), e.last_accessed))
            .collect();
        by_access.sort_by_key(|(_, t)| *t);

        for (key, _) in by_access.into_iter().take(count) {
            if let Some(entry) = entries.remove(&key) {
                self.remove_from_tag_index(&key, &entry.tags);
            }
        }

        self.entry_count.store(entries.len(), Ordering::Relaxed);
    }

    /// Add entry to tag index.
    fn add_to_tag_index(&self, key: &str, tags: &[EntityTag]) {
        if !self.config.enable_tags || tags.is_empty() {
            return;
        }

        let mut index = self.tag_index.write();
        for tag in tags {
            index
                .entry(tag.value().to_string())
                .or_default()
                .insert(key.to_string());
        }
    }

    /// Remove entry from tag index.
    fn remove_from_tag_index(&self, key: &str, tags: &[EntityTag]) {
        if !self.config.enable_tags || tags.is_empty() {
            return;
        }

        let mut index = self.tag_index.write();
        for tag in tags {
            if let Some(keys) = index.get_mut(tag.value()) {
                keys.remove(key);
                if keys.is_empty() {
                    index.remove(tag.value());
                }
            }
        }
    }
}

impl CacheBackend for MemoryCache {
    async fn get<T>(&self, key: &CacheKey) -> CacheResult<Option<T>>
    where
        T: serde::de::DeserializeOwned,
    {
        let key_str = key.as_str();

        // Look up the entry under the read lock and copy the bytes out, so the
        // lock is released before deserializing.
        let data = {
            let entries = self.entries.read();
            match entries.get(&key_str) {
                // Expired entries count as misses; they are removed lazily by
                // `evict_expired`.
                Some(entry) if !entry.is_expired() => entry.data.clone(),
                _ => return Ok(None),
            }
        };

        // Record the access under a short write lock so LRU ordering reflects reads.
        {
            let mut entries = self.entries.write();
            if let Some(entry) = entries.get_mut(&key_str) {
                entry.touch();
            }
        }

        let value: T = serde_json::from_slice(&data)
            .map_err(|e| CacheError::Deserialization(e.to_string()))?;

        Ok(Some(value))
    }

    async fn set<T>(
        &self,
        key: &CacheKey,
        value: &T,
        ttl: Option<Duration>,
    ) -> CacheResult<()>
    where
        T: serde::Serialize + Sync,
    {
        let key_str = key.as_str();

        // Serialize
        let data = serde_json::to_vec(value)
            .map_err(|e| CacheError::Serialization(e.to_string()))?;

        let effective_ttl = ttl.or(self.config.time_to_live);
        let entry = CacheEntry::new(data, effective_ttl, Vec::new());

        // Check capacity
        let current = self.entry_count.load(Ordering::Relaxed);
        if current >= self.config.max_capacity as usize {
            // Evict some entries
            self.evict_expired();
            let still_over = self.entry_count.load(Ordering::Relaxed);
            if still_over >= self.config.max_capacity as usize {
                self.evict_lru((self.config.max_capacity as usize / 10).max(1));
            }
        }

        // Insert
        {
            let mut entries = self.entries.write();
            let is_new = !entries.contains_key(&key_str);
            entries.insert(key_str.clone(), entry);
            if is_new {
                self.entry_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        Ok(())
    }

    async fn delete(&self, key: &CacheKey) -> CacheResult<bool> {
        let key_str = key.as_str();

        let mut entries = self.entries.write();
        if let Some(entry) = entries.remove(&key_str) {
            self.remove_from_tag_index(&key_str, &entry.tags);
            self.entry_count.fetch_sub(1, Ordering::Relaxed);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    async fn exists(&self, key: &CacheKey) -> CacheResult<bool> {
        let key_str = key.as_str();

        let entries = self.entries.read();
        if let Some(entry) = entries.get(&key_str) {
            Ok(!entry.is_expired())
        } else {
            Ok(false)
        }
    }

    async fn invalidate_pattern(&self, pattern: &KeyPattern) -> CacheResult<u64> {
        let mut entries = self.entries.write();
        let before = entries.len();

        let matching_keys: Vec<String> = entries
            .keys()
            .filter(|k| pattern.matches_str(k))
            .cloned()
            .collect();

        for key in &matching_keys {
            if let Some(entry) = entries.remove(key) {
                self.remove_from_tag_index(key, &entry.tags);
            }
        }

        let removed = before - entries.len();
        self.entry_count.fetch_sub(removed, Ordering::Relaxed);
        Ok(removed as u64)
    }

    async fn invalidate_tags(&self, tags: &[EntityTag]) -> CacheResult<u64> {
        if !self.config.enable_tags {
            return Ok(0);
        }

        let keys_to_remove: HashSet<String> = {
            let index = self.tag_index.read();
            tags.iter()
                .filter_map(|tag| index.get(tag.value()))
                .flatten()
                .cloned()
                .collect()
        };

        let mut entries = self.entries.write();
        let mut removed = 0u64;

        for key in keys_to_remove {
            if let Some(entry) = entries.remove(&key) {
                self.remove_from_tag_index(&key, &entry.tags);
                removed += 1;
            }
        }

        self.entry_count.fetch_sub(removed as usize, Ordering::Relaxed);
        Ok(removed)
    }

    async fn clear(&self) -> CacheResult<()> {
        let mut entries = self.entries.write();
        entries.clear();
        self.entry_count.store(0, Ordering::Relaxed);

        if self.config.enable_tags {
            let mut index = self.tag_index.write();
            index.clear();
        }

        Ok(())
    }

    async fn len(&self) -> CacheResult<usize> {
        Ok(self.entry_count.load(Ordering::Relaxed))
    }

    async fn stats(&self) -> CacheResult<BackendStats> {
        let entries = self.entries.read();
        let memory_estimate: usize = entries
            .values()
            .map(|e| e.data.len() + 64) // Data + overhead estimate
            .sum();

        Ok(BackendStats {
            entries: entries.len(),
            memory_bytes: Some(memory_estimate),
            connections: None,
            info: Some(format!("MemoryCache (max: {})", self.config.max_capacity)),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_cache_basic() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));

        let key = CacheKey::new("test", "key1");

        // Set
        cache.set(&key, &"hello", None).await.unwrap();

        // Get
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(value, Some("hello".to_string()));

        // Delete
        assert!(cache.delete(&key).await.unwrap());

        // Should be gone
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert!(value.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let config = MemoryCacheConfig::new(100).with_ttl(Duration::from_millis(50));
        let cache = MemoryCache::new(config);

        let key = CacheKey::new("test", "ttl");
        cache.set(&key, &"expires soon", None).await.unwrap();

        // Should exist
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert!(value.is_some());

        // Wait for expiration
        tokio::time::sleep(Duration::from_millis(60)).await;

        // Should be expired
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert!(value.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_eviction() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(5));

        // Fill cache
        for i in 0..10 {
            let key = CacheKey::new("test", format!("key{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }

        // Should have evicted some
        let len = cache.len().await.unwrap();
        assert!(len <= 5);
    }

    #[tokio::test]
    async fn test_memory_cache_pattern_invalidation() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));

        // Add some entries
        for i in 0..5 {
            let key = CacheKey::new("User", format!("id:{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }
        for i in 0..3 {
            let key = CacheKey::new("Post", format!("id:{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }

        assert_eq!(cache.len().await.unwrap(), 8);

        // Invalidate all User entries
        let removed = cache
            .invalidate_pattern(&KeyPattern::entity("User"))
            .await
            .unwrap();
        assert_eq!(removed, 5);
        assert_eq!(cache.len().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn test_memory_cache_builder() {
        let cache = MemoryCache::builder()
            .max_capacity(1000)
            .time_to_live(Duration::from_secs(60))
            .build();

        assert_eq!(cache.config().max_capacity, 1000);
        assert_eq!(cache.config().time_to_live, Some(Duration::from_secs(60)));
    }
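
    // Additional coverage sketch: exercises `exists`, `stats`, and `clear`
    // round-trips using only APIs defined in this module.
    #[tokio::test]
    async fn test_memory_cache_exists_stats_clear() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));
        let key = CacheKey::new("test", "stats");

        cache.set(&key, &"value", None).await.unwrap();
        assert!(cache.exists(&key).await.unwrap());

        // Stats should report one entry and a non-zero memory estimate.
        let stats = cache.stats().await.unwrap();
        assert_eq!(stats.entries, 1);
        assert!(stats.memory_bytes.unwrap_or(0) > 0);

        // Clearing removes everything.
        cache.clear().await.unwrap();
        assert_eq!(cache.len().await.unwrap(), 0);
        assert!(!cache.exists(&key).await.unwrap());
    }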
}