Skip to main content

ares/rag/
cache.rs

1//! Embedding Cache for RAG Pipeline
2//!
3//! This module provides caching for text embeddings to avoid re-computing
4//! vectors for unchanged content. This is especially valuable for:
5//!
6//! - Large document re-indexing
7//! - Frequently accessed documents
8//! - Multi-collection setups with shared documents
9//!
10//! # Cache Key Strategy
11//!
//! Cache keys are computed as SHA-256 hashes of the text, a `|` separator, and the model name to ensure:
13//! - Unique keys for different content
14//! - Model-specific embeddings (different models produce different vectors)
15//! - Consistent keys across restarts
16//!
17//! # Implementation
18//!
19//! Uses the `lru` crate for O(1) get/put operations with proper LRU eviction.
20//! The cache is thread-safe via `parking_lot::Mutex`.
21//!
22//! # Example
23//!
24//! ```ignore
25//! use ares::rag::cache::{EmbeddingCache, LruEmbeddingCache, CacheConfig};
26//!
27//! // Create a cache with 512MB max size
28//! let cache = LruEmbeddingCache::new(CacheConfig {
29//!     max_size_bytes: 512 * 1024 * 1024,
30//!     ..Default::default()
31//! });
32//!
//! // Check cache before computing embedding (cache calls are synchronous)
//! let key = cache.compute_key("hello world", "bge-small-en-v1.5");
//! if let Some(embedding) = cache.get(&key) {
//!     // Use cached embedding
//! } else {
//!     // Compute and cache
//!     let embedding = embed("hello world").await?;
//!     cache.set(&key, embedding.clone(), None)?;
//! }
42//! ```
43
44use std::num::NonZeroUsize;
45use std::sync::atomic::{AtomicU64, Ordering};
46use std::time::{Duration, Instant};
47
48use lru::LruCache;
49use parking_lot::Mutex;
50use serde::{Deserialize, Serialize};
51use sha2::{Digest, Sha256};
52
53use crate::types::Result;
54
55// ============================================================================
56// Cache Types
57// ============================================================================
58
/// Statistics for cache performance monitoring.
///
/// A plain snapshot of counters; it is serializable so it can be exported
/// as-is. Use [`CacheStats::hit_rate`] to derive the hit percentage.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of lookups that returned a cached embedding
    pub hits: u64,
    /// Number of lookups that returned nothing (absent or expired entry)
    pub misses: u64,
    /// Current size in bytes (approximate; only embedding payloads are counted)
    pub size_bytes: u64,
    /// Number of entries currently held in the cache
    pub entry_count: usize,
    /// Number of entries dropped to make room for new ones (capacity evictions)
    pub evictions: u64,
}
73
74impl CacheStats {
75    /// Calculate hit rate as a percentage
76    pub fn hit_rate(&self) -> f64 {
77        let total = self.hits + self.misses;
78        if total == 0 {
79            0.0
80        } else {
81            (self.hits as f64 / total as f64) * 100.0
82        }
83    }
84}
85
/// Configuration for the embedding cache.
///
/// Every field has a serde default, so a partially specified (or empty)
/// configuration document deserializes to the same values as
/// `CacheConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum cache size in bytes (default: 256MB).
    ///
    /// `LruEmbeddingCache::new` turns this budget into an entry-count limit
    /// rather than enforcing it byte-for-byte.
    #[serde(default = "default_max_size_bytes")]
    pub max_size_bytes: u64,

    /// Default TTL for cache entries (None = no expiry).
    ///
    /// Applied by `LruEmbeddingCache::set` when the caller passes no TTL.
    #[serde(default)]
    pub default_ttl: Option<Duration>,

    /// Whether the cache is enabled.
    ///
    /// When `false`, `LruEmbeddingCache` returns `None` from `get` and
    /// silently ignores `set`.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
}
101
/// Fallback byte budget used when the configuration omits `max_size_bytes`.
fn default_max_size_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    256 * MIB // 256 MB
}
105
/// Fallback for the `enabled` flag when the configuration omits it: caching is on.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
109
110impl Default for CacheConfig {
111    fn default() -> Self {
112        Self {
113            max_size_bytes: default_max_size_bytes(),
114            default_ttl: None,
115            enabled: default_enabled(),
116        }
117    }
118}
119
120// ============================================================================
121// Cache Trait
122// ============================================================================
123
124/// Trait for embedding cache implementations
125///
126/// This trait defines the interface for caching embeddings. Implementations
127/// can use different backends (in-memory, Redis, disk, etc.).
128pub trait EmbeddingCache: Send + Sync {
129    /// Get an embedding from the cache
130    fn get(&self, key: &str) -> Option<Vec<f32>>;
131
132    /// Store an embedding in the cache with optional TTL
133    fn set(&self, key: &str, embedding: Vec<f32>, ttl: Option<Duration>) -> Result<()>;
134
135    /// Remove an entry from the cache
136    fn invalidate(&self, key: &str) -> Result<()>;
137
138    /// Clear all entries from the cache
139    fn clear(&self) -> Result<()>;
140
141    /// Get cache statistics
142    fn stats(&self) -> CacheStats;
143
144    /// Compute a cache key for the given text and model
145    fn compute_key(&self, text: &str, model: &str) -> String {
146        let mut hasher = Sha256::new();
147        hasher.update(text.as_bytes());
148        hasher.update(b"|");
149        hasher.update(model.as_bytes());
150        format!("{:x}", hasher.finalize())
151    }
152
153    /// Check if the cache is enabled
154    fn is_enabled(&self) -> bool;
155}
156
157// ============================================================================
158// LRU Cache Entry
159// ============================================================================
160
/// A cache entry with metadata for expiration
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached embedding vector
    embedding: Vec<f32>,
    /// Optional expiry time (`None` means the entry never expires)
    expires_at: Option<Instant>,
    /// Size in bytes (approximate: only the `f32` payload is counted)
    size_bytes: usize,
}

impl CacheEntry {
    /// Wrap `embedding` in an entry that expires `ttl` from now.
    ///
    /// With `ttl == None` the entry never expires. A TTL so large that the
    /// deadline cannot be represented as an `Instant` (e.g. `Duration::MAX`)
    /// is likewise treated as "never expires" rather than panicking.
    fn new(embedding: Vec<f32>, ttl: Option<Duration>) -> Self {
        let size_bytes = embedding.len() * std::mem::size_of::<f32>();
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`
        // instead, which maps naturally onto "no expiry".
        let expires_at = ttl.and_then(|d| Instant::now().checked_add(d));
        Self {
            embedding,
            expires_at,
            size_bytes,
        }
    }

    /// Whether the entry's deadline lies strictly in the past.
    ///
    /// Entries without a deadline are never expired.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() > deadline,
            None => false,
        }
    }
}
189
190// ============================================================================
191// LRU Embedding Cache
192// ============================================================================
193
/// Default maximum number of entries in the LRU cache.
///
/// Used as the fallback capacity when a requested entry count is zero.
const DEFAULT_MAX_ENTRIES: usize = 10_000;

/// In-memory LRU cache for embeddings
///
/// Uses the `lru` crate for O(1) get/put operations with proper LRU eviction.
/// Thread-safe via `parking_lot::Mutex`.
///
/// # Memory Management
///
/// The cache limits entries by count (not bytes) for simplicity and O(1) operations.
/// The `max_size_bytes` config is used to estimate max entries based on average
/// embedding size (assuming 384-dimensional embeddings = 1536 bytes each).
/// Larger embeddings therefore let the real memory use exceed `max_size_bytes`.
pub struct LruEmbeddingCache {
    /// The LRU cache storage (key -> CacheEntry)
    cache: Mutex<LruCache<String, CacheEntry>>,
    /// Configuration (enabled flag, default TTL, byte budget)
    config: CacheConfig,
    /// Current size in bytes (approximate; sum of the stored entries' `size_bytes`)
    current_size: AtomicU64,
    /// Cache hit counter
    hits: AtomicU64,
    /// Cache miss counter (includes lookups that found an expired entry)
    misses: AtomicU64,
    /// Eviction counter (entries pushed out by capacity pressure)
    evictions: AtomicU64,
}
221
222impl LruEmbeddingCache {
223    /// Create a new LRU embedding cache with the given configuration
224    pub fn new(config: CacheConfig) -> Self {
225        // Estimate max entries from max_size_bytes
226        // Assume average embedding is 384 dimensions = 1536 bytes
227        let avg_entry_size = 384 * std::mem::size_of::<f32>(); // 1536 bytes
228        let max_entries = (config.max_size_bytes as usize / avg_entry_size).max(100);
229        let capacity = NonZeroUsize::new(max_entries)
230            .unwrap_or(NonZeroUsize::new(DEFAULT_MAX_ENTRIES).unwrap());
231
232        Self {
233            cache: Mutex::new(LruCache::new(capacity)),
234            config,
235            current_size: AtomicU64::new(0),
236            hits: AtomicU64::new(0),
237            misses: AtomicU64::new(0),
238            evictions: AtomicU64::new(0),
239        }
240    }
241
242    /// Create a cache with default configuration
243    pub fn with_defaults() -> Self {
244        Self::new(CacheConfig::default())
245    }
246
247    /// Create a cache with a specific max size in bytes
248    pub fn with_max_size(max_size_bytes: u64) -> Self {
249        Self::new(CacheConfig {
250            max_size_bytes,
251            ..Default::default()
252        })
253    }
254
255    /// Create a cache with a specific max entry count
256    pub fn with_max_entries(max_entries: usize) -> Self {
257        let capacity = NonZeroUsize::new(max_entries)
258            .unwrap_or(NonZeroUsize::new(DEFAULT_MAX_ENTRIES).unwrap());
259        Self {
260            cache: Mutex::new(LruCache::new(capacity)),
261            config: CacheConfig::default(),
262            current_size: AtomicU64::new(0),
263            hits: AtomicU64::new(0),
264            misses: AtomicU64::new(0),
265            evictions: AtomicU64::new(0),
266        }
267    }
268
269    /// Remove expired entries from the cache
270    pub fn cleanup_expired(&self) {
271        let mut cache = self.cache.lock();
272        let mut expired_keys = Vec::new();
273
274        // Collect expired keys (can't remove while iterating)
275        for (key, entry) in cache.iter() {
276            if entry.is_expired() {
277                expired_keys.push(key.clone());
278            }
279        }
280
281        // Remove expired entries
282        for key in expired_keys {
283            if let Some(entry) = cache.pop(&key) {
284                self.current_size
285                    .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
286            }
287        }
288    }
289
290    /// Get the current cache size in bytes
291    pub fn size_bytes(&self) -> u64 {
292        self.current_size.load(Ordering::Relaxed)
293    }
294
295    /// Get the number of entries in the cache
296    pub fn len(&self) -> usize {
297        self.cache.lock().len()
298    }
299
300    /// Check if the cache is empty
301    pub fn is_empty(&self) -> bool {
302        self.cache.lock().is_empty()
303    }
304}
305
306impl EmbeddingCache for LruEmbeddingCache {
307    fn get(&self, key: &str) -> Option<Vec<f32>> {
308        if !self.config.enabled {
309            return None;
310        }
311
312        let mut cache = self.cache.lock();
313
314        // get() in lru crate automatically promotes to most recently used
315        if let Some(entry) = cache.get(key) {
316            if entry.is_expired() {
317                // Remove expired entry
318                let entry = cache.pop(key).unwrap();
319                self.current_size
320                    .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
321                self.misses.fetch_add(1, Ordering::Relaxed);
322                return None;
323            }
324            self.hits.fetch_add(1, Ordering::Relaxed);
325            Some(entry.embedding.clone())
326        } else {
327            self.misses.fetch_add(1, Ordering::Relaxed);
328            None
329        }
330    }
331
332    fn set(&self, key: &str, embedding: Vec<f32>, ttl: Option<Duration>) -> Result<()> {
333        if !self.config.enabled {
334            return Ok(());
335        }
336
337        let entry = CacheEntry::new(embedding, ttl.or(self.config.default_ttl));
338        let entry_size = entry.size_bytes;
339
340        let mut cache = self.cache.lock();
341
342        // Remove old entry if exists (to update size tracking)
343        if let Some(old_entry) = cache.pop(key) {
344            self.current_size
345                .fetch_sub(old_entry.size_bytes as u64, Ordering::Relaxed);
346        }
347
348        // Check if cache is at capacity before push
349        let was_at_capacity = cache.len() == cache.cap().get();
350
351        // Push new entry (LRU eviction happens automatically if at capacity)
352        if let Some((_, evicted)) = cache.push(key.to_string(), entry) {
353            // An entry was evicted
354            self.current_size
355                .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
356            self.evictions.fetch_add(1, Ordering::Relaxed);
357        } else if was_at_capacity {
358            // We were at capacity but push didn't return evicted (shouldn't happen)
359            // but handle it just in case
360            self.evictions.fetch_add(1, Ordering::Relaxed);
361        }
362
363        // Update size
364        self.current_size
365            .fetch_add(entry_size as u64, Ordering::Relaxed);
366
367        Ok(())
368    }
369
370    fn invalidate(&self, key: &str) -> Result<()> {
371        let mut cache = self.cache.lock();
372        if let Some(entry) = cache.pop(key) {
373            self.current_size
374                .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
375        }
376        Ok(())
377    }
378
379    fn clear(&self) -> Result<()> {
380        let mut cache = self.cache.lock();
381        cache.clear();
382        self.current_size.store(0, Ordering::Relaxed);
383        Ok(())
384    }
385
386    fn stats(&self) -> CacheStats {
387        CacheStats {
388            hits: self.hits.load(Ordering::Relaxed),
389            misses: self.misses.load(Ordering::Relaxed),
390            size_bytes: self.current_size.load(Ordering::Relaxed),
391            entry_count: self.cache.lock().len(),
392            evictions: self.evictions.load(Ordering::Relaxed),
393        }
394    }
395
396    fn is_enabled(&self) -> bool {
397        self.config.enabled
398    }
399}
400
401// ============================================================================
402// No-Op Cache
403// ============================================================================
404
/// Cache implementation that never stores anything.
///
/// Drop it in wherever an `EmbeddingCache` is expected to turn caching off
/// without restructuring the calling code.
#[derive(Debug, Default)]
pub struct NoOpCache;

impl NoOpCache {
    /// Construct the (stateless) no-op cache.
    pub fn new() -> Self {
        NoOpCache
    }
}
417
418impl EmbeddingCache for NoOpCache {
419    fn get(&self, _key: &str) -> Option<Vec<f32>> {
420        None
421    }
422
423    fn set(&self, _key: &str, _embedding: Vec<f32>, _ttl: Option<Duration>) -> Result<()> {
424        Ok(())
425    }
426
427    fn invalidate(&self, _key: &str) -> Result<()> {
428        Ok(())
429    }
430
431    fn clear(&self) -> Result<()> {
432        Ok(())
433    }
434
435    fn stats(&self) -> CacheStats {
436        CacheStats::default()
437    }
438
439    fn is_enabled(&self) -> bool {
440        false
441    }
442}
443
444// ============================================================================
445// Tests
446// ============================================================================
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys are deterministic and depend on both the text and the model.
    #[test]
    fn test_cache_key_computation() {
        let cache = LruEmbeddingCache::with_defaults();

        let key1 = cache.compute_key("hello world", "bge-small-en-v1.5");
        let key2 = cache.compute_key("hello world", "bge-small-en-v1.5");
        let key3 = cache.compute_key("hello world", "bge-base-en-v1.5");
        let key4 = cache.compute_key("different text", "bge-small-en-v1.5");

        // Same input should produce same key
        assert_eq!(key1, key2);
        // Different model should produce different key
        assert_ne!(key1, key3);
        // Different text should produce different key
        assert_ne!(key1, key4);
    }

    #[test]
    fn test_cache_set_and_get() {
        let cache = LruEmbeddingCache::with_defaults();
        let key = "test_key";
        let embedding = vec![1.0, 2.0, 3.0, 4.0];

        // Initially empty
        assert!(cache.get(key).is_none());
        assert_eq!(cache.stats().misses, 1);

        // Set and get
        cache.set(key, embedding.clone(), None).unwrap();
        let retrieved = cache.get(key);

        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), embedding);
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = LruEmbeddingCache::with_defaults();
        let key = "test_key";
        let embedding = vec![1.0, 2.0, 3.0];

        cache.set(key, embedding, None).unwrap();
        assert!(cache.get(key).is_some());

        cache.invalidate(key).unwrap();
        assert!(cache.get(key).is_none());
        assert_eq!(cache.size_bytes(), 0);
    }

    #[test]
    fn test_cache_clear() {
        let cache = LruEmbeddingCache::with_defaults();

        cache.set("key1", vec![1.0, 2.0], None).unwrap();
        cache.set("key2", vec![3.0, 4.0], None).unwrap();

        assert_eq!(cache.len(), 2);
        assert!(cache.size_bytes() > 0);

        cache.clear().unwrap();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.size_bytes(), 0);
    }

    #[test]
    fn test_cache_lru_eviction() {
        // Create a small cache with exactly 2 entries to test LRU eviction
        let cache = LruEmbeddingCache::with_max_entries(2);

        let embedding1 = vec![1.0, 2.0, 3.0, 4.0];
        let embedding2 = vec![5.0, 6.0, 7.0, 8.0];
        let embedding3 = vec![9.0, 10.0, 11.0, 12.0];

        cache.set("key1", embedding1.clone(), None).unwrap();
        cache.set("key2", embedding2.clone(), None).unwrap();

        // Both should be present
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_some());

        // Adding a third should evict the LRU (key1, since key2 was accessed more recently)
        cache.set("key3", embedding3.clone(), None).unwrap();

        // key1 should be evicted
        assert!(cache.get("key1").is_none());
        // key2 and key3 should exist
        assert!(cache.get("key2").is_some());
        assert!(cache.get("key3").is_some());

        // Exactly one eviction, and the size tracks the two surviving
        // 4-element embeddings.
        assert_eq!(cache.stats().evictions, 1);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.size_bytes(), 2 * 4 * std::mem::size_of::<f32>() as u64);
    }

    #[test]
    fn test_cache_ttl_expiry() {
        let cache = LruEmbeddingCache::with_defaults();
        let key = "test_key";
        let embedding = vec![1.0, 2.0, 3.0];

        // Set with a 1ns TTL so the entry expires almost immediately
        cache
            .set(key, embedding, Some(Duration::from_nanos(1)))
            .unwrap();

        // Sleep briefly to ensure expiry
        std::thread::sleep(Duration::from_millis(1));

        // Should be expired, and the expired lookup removes the entry
        assert!(cache.get(key).is_none());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.size_bytes(), 0);
    }

    #[test]
    fn test_cleanup_expired() {
        let cache = LruEmbeddingCache::with_defaults();

        cache
            .set("short_lived", vec![1.0, 2.0], Some(Duration::from_nanos(1)))
            .unwrap();
        cache.set("permanent", vec![3.0, 4.0, 5.0], None).unwrap();

        // Sleep briefly to ensure the short-lived entry has expired
        std::thread::sleep(Duration::from_millis(1));

        cache.cleanup_expired();

        // Only the entry without a TTL survives the sweep
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.size_bytes(), 3 * std::mem::size_of::<f32>() as u64);
        assert!(cache.get("permanent").is_some());
        assert!(cache.get("short_lived").is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = LruEmbeddingCache::with_defaults();

        // Generate some activity
        cache.set("key1", vec![1.0, 2.0], None).unwrap();
        let _ = cache.get("key1"); // hit
        let _ = cache.get("key2"); // miss
        let _ = cache.get("key3"); // miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.entry_count, 1);
        assert!(stats.size_bytes > 0);
    }

    #[test]
    fn test_cache_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            size_bytes: 0,
            entry_count: 0,
            evictions: 0,
        };

        assert!((stats.hit_rate() - 75.0).abs() < 0.001);
    }

    #[test]
    fn test_cache_hit_rate_no_lookups() {
        // No hits and no misses must yield 0.0 rather than NaN
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_noop_cache() {
        let cache = NoOpCache::new();

        // Set should succeed but not store
        cache.set("key", vec![1.0, 2.0], None).unwrap();

        // Get should always return None
        assert!(cache.get("key").is_none());

        // Stats should be empty
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert!(!cache.is_enabled());
    }

    #[test]
    fn test_cache_disabled() {
        let cache = LruEmbeddingCache::new(CacheConfig {
            enabled: false,
            ..Default::default()
        });

        // Set should succeed but not store
        cache.set("key", vec![1.0, 2.0], None).unwrap();
        assert!(cache.is_empty());

        // Get should return None when disabled
        assert!(cache.get("key").is_none());
        assert!(!cache.is_enabled());
    }

    #[test]
    fn test_cache_update_existing() {
        let cache = LruEmbeddingCache::with_defaults();
        let key = "test_key";

        cache.set(key, vec![1.0, 2.0], None).unwrap();
        let size1 = cache.size_bytes();

        // Update with different embedding
        cache.set(key, vec![3.0, 4.0, 5.0, 6.0], None).unwrap();
        let size2 = cache.size_bytes();

        // Size should have changed (old removed, new added)
        assert!(size2 > size1);
        assert_eq!(cache.len(), 1);

        // Should get the new value
        let retrieved = cache.get(key).unwrap();
        assert_eq!(retrieved, vec![3.0, 4.0, 5.0, 6.0]);
    }
}