rag_plusplus_core/cache/
query_cache.rs

1//! Query Cache Implementation
2//!
3//! LRU cache with TTL expiration for query results.
4
5use crate::retrieval::engine::QueryResponse;
6use ahash::AHashMap;
7use parking_lot::RwLock;
8use std::collections::VecDeque;
9use std::hash::{Hash, Hasher};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12
/// Tunable limits and policies for the query cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of cached entries.
    pub max_entries: usize,
    /// How long an entry stays valid after it is created.
    pub ttl: Duration,
    /// Whether queries that carry a filter should be cached.
    pub cache_filtered: bool,
}

impl Default for CacheConfig {
    /// Defaults: 10 000 entries, a five-minute TTL, filtered queries cached.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(300);

        Self {
            cache_filtered: true,
            ttl: five_minutes,
            max_entries: 10_000,
        }
    }
}

impl CacheConfig {
    /// Build a configuration holding the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Return this configuration with `max_entries` replaced by `max`.
    #[must_use]
    pub const fn with_max_entries(self, max: usize) -> Self {
        Self {
            max_entries: max,
            ..self
        }
    }

    /// Return this configuration with the time-to-live replaced by `ttl`.
    #[must_use]
    pub const fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Return this configuration with `cache_filtered` replaced by `cache`.
    #[must_use]
    pub const fn with_cache_filtered(self, cache: bool) -> Self {
        Self {
            cache_filtered: cache,
            ..self
        }
    }
}
62
63/// Cache key for queries.
64#[derive(Debug, Clone, PartialEq, Eq, Hash)]
65pub struct CacheKey {
66    /// Hash of embedding vector
67    embedding_hash: u64,
68    /// Number of results
69    k: usize,
70    /// Hash of filter (0 if no filter)
71    filter_hash: u64,
72    /// Hash of index names (0 if all indexes)
73    indexes_hash: u64,
74}
75
76impl CacheKey {
77    /// Create a cache key from query parameters.
78    #[must_use]
79    pub fn new(
80        embedding: &[f32],
81        k: usize,
82        filter_hash: Option<u64>,
83        indexes: Option<&[String]>,
84    ) -> Self {
85        Self {
86            embedding_hash: Self::hash_embedding(embedding),
87            k,
88            filter_hash: filter_hash.unwrap_or(0),
89            indexes_hash: indexes.map(Self::hash_indexes).unwrap_or(0),
90        }
91    }
92
93    /// Hash an embedding vector.
94    fn hash_embedding(embedding: &[f32]) -> u64 {
95        let mut hasher = xxhash_rust::xxh64::Xxh64::new(0);
96
97        for &value in embedding {
98            hasher.write(&value.to_le_bytes());
99        }
100
101        hasher.finish()
102    }
103
104    /// Hash index names.
105    fn hash_indexes(indexes: &[String]) -> u64 {
106        let mut hasher = xxhash_rust::xxh64::Xxh64::new(0);
107
108        for name in indexes {
109            hasher.write(name.as_bytes());
110        }
111
112        hasher.finish()
113    }
114}
115
116/// Cached entry with metadata.
117#[derive(Debug, Clone)]
118pub struct CacheEntry {
119    /// Cached response
120    pub response: QueryResponse,
121    /// When the entry was created
122    pub created_at: Instant,
123    /// Number of times this entry was accessed
124    pub access_count: u64,
125}
126
127impl CacheEntry {
128    /// Check if entry is expired.
129    #[must_use]
130    pub fn is_expired(&self, ttl: Duration) -> bool {
131        self.created_at.elapsed() > ttl
132    }
133}
134
/// Point-in-time counters describing cache activity.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached response.
    pub hits: u64,
    /// Lookups that returned nothing.
    pub misses: u64,
    /// Entries held by the cache when the snapshot was taken.
    pub entries: usize,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries removed because they had expired.
    pub expirations: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    #[must_use]
    pub fn hit_ratio(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
162
163/// LRU query cache with TTL expiration.
164///
165/// Thread-safe cache for query results.
166///
167/// # Example
168///
169/// ```ignore
170/// use rag_plusplus_core::cache::{QueryCache, CacheConfig, CacheKey};
171///
172/// let cache = QueryCache::new(CacheConfig::default());
173///
174/// let key = CacheKey::new(&embedding, 10, None, None);
175///
176/// // Try cache first
177/// if let Some(response) = cache.get(&key) {
178///     return response;
179/// }
180///
181/// // Execute query
182/// let response = engine.query(request)?;
183///
184/// // Cache result
185/// cache.put(key, response.clone());
186/// ```
187pub struct QueryCache {
188    config: CacheConfig,
189    /// Cache entries
190    entries: RwLock<AHashMap<CacheKey, CacheEntry>>,
191    /// LRU order (front = oldest)
192    order: RwLock<VecDeque<CacheKey>>,
193    /// Statistics
194    hits: AtomicU64,
195    misses: AtomicU64,
196    evictions: AtomicU64,
197    expirations: AtomicU64,
198}
199
200impl std::fmt::Debug for QueryCache {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        f.debug_struct("QueryCache")
203            .field("config", &self.config)
204            .field("entries", &self.entries.read().len())
205            .finish()
206    }
207}
208
209impl QueryCache {
210    /// Create a new cache.
211    #[must_use]
212    pub fn new(config: CacheConfig) -> Self {
213        Self {
214            config,
215            entries: RwLock::new(AHashMap::new()),
216            order: RwLock::new(VecDeque::new()),
217            hits: AtomicU64::new(0),
218            misses: AtomicU64::new(0),
219            evictions: AtomicU64::new(0),
220            expirations: AtomicU64::new(0),
221        }
222    }
223
224    /// Create with default config.
225    #[must_use]
226    pub fn default_cache() -> Self {
227        Self::new(CacheConfig::default())
228    }
229
230    /// Get a cached response.
231    ///
232    /// Returns `Some(response)` if found and not expired, `None` otherwise.
233    pub fn get(&self, key: &CacheKey) -> Option<QueryResponse> {
234        // Check for entry
235        let entries = self.entries.read();
236
237        if let Some(entry) = entries.get(key) {
238            // Check expiration
239            if entry.is_expired(self.config.ttl) {
240                drop(entries);
241                self.remove(key);
242                self.expirations.fetch_add(1, Ordering::Relaxed);
243                self.misses.fetch_add(1, Ordering::Relaxed);
244                return None;
245            }
246
247            self.hits.fetch_add(1, Ordering::Relaxed);
248
249            // Move to back of LRU (update access)
250            drop(entries);
251            self.touch(key);
252
253            // Re-read after touch
254            let entries = self.entries.read();
255            entries.get(key).map(|e| e.response.clone())
256        } else {
257            self.misses.fetch_add(1, Ordering::Relaxed);
258            None
259        }
260    }
261
262    /// Put a response in the cache.
263    pub fn put(&self, key: CacheKey, response: QueryResponse) {
264        // Evict if necessary
265        self.maybe_evict();
266
267        let entry = CacheEntry {
268            response,
269            created_at: Instant::now(),
270            access_count: 1,
271        };
272
273        {
274            let mut entries = self.entries.write();
275            let mut order = self.order.write();
276
277            // Remove old entry if exists
278            if entries.contains_key(&key) {
279                order.retain(|k| k != &key);
280            }
281
282            entries.insert(key.clone(), entry);
283            order.push_back(key);
284        }
285    }
286
287    /// Remove an entry.
288    pub fn remove(&self, key: &CacheKey) -> Option<CacheEntry> {
289        let mut entries = self.entries.write();
290        let mut order = self.order.write();
291
292        order.retain(|k| k != key);
293        entries.remove(key)
294    }
295
296    /// Touch an entry (move to back of LRU).
297    fn touch(&self, key: &CacheKey) {
298        let mut order = self.order.write();
299
300        // Remove from current position
301        order.retain(|k| k != key);
302        // Add to back
303        order.push_back(key.clone());
304    }
305
306    /// Evict oldest entries if over capacity.
307    fn maybe_evict(&self) {
308        let entries = self.entries.read();
309        let current_size = entries.len();
310        drop(entries);
311
312        if current_size >= self.config.max_entries {
313            // Evict 10% of entries
314            let to_evict = self.config.max_entries / 10;
315            self.evict_oldest(to_evict.max(1));
316        }
317    }
318
319    /// Evict the n oldest entries.
320    fn evict_oldest(&self, n: usize) {
321        let mut entries = self.entries.write();
322        let mut order = self.order.write();
323
324        for _ in 0..n {
325            if let Some(key) = order.pop_front() {
326                entries.remove(&key);
327                self.evictions.fetch_add(1, Ordering::Relaxed);
328            } else {
329                break;
330            }
331        }
332    }
333
334    /// Clear all entries.
335    pub fn clear(&self) {
336        let mut entries = self.entries.write();
337        let mut order = self.order.write();
338
339        entries.clear();
340        order.clear();
341    }
342
343    /// Remove expired entries.
344    pub fn cleanup_expired(&self) {
345        let entries_snapshot: Vec<CacheKey> = {
346            let entries = self.entries.read();
347            entries
348                .iter()
349                .filter(|(_, entry)| entry.is_expired(self.config.ttl))
350                .map(|(key, _)| key.clone())
351                .collect()
352        };
353
354        for key in entries_snapshot {
355            self.remove(&key);
356            self.expirations.fetch_add(1, Ordering::Relaxed);
357        }
358    }
359
360    /// Get cache statistics.
361    #[must_use]
362    pub fn stats(&self) -> CacheStats {
363        CacheStats {
364            hits: self.hits.load(Ordering::Relaxed),
365            misses: self.misses.load(Ordering::Relaxed),
366            entries: self.entries.read().len(),
367            evictions: self.evictions.load(Ordering::Relaxed),
368            expirations: self.expirations.load(Ordering::Relaxed),
369        }
370    }
371
372    /// Get current size.
373    #[must_use]
374    pub fn len(&self) -> usize {
375        self.entries.read().len()
376    }
377
378    /// Check if empty.
379    #[must_use]
380    pub fn is_empty(&self) -> bool {
381        self.entries.read().is_empty()
382    }
383}
384
385impl Default for QueryCache {
386    fn default() -> Self {
387        Self::default_cache()
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::engine::RetrievedRecord;
    use crate::stats::OutcomeStats;
    use crate::types::{MemoryRecord, RecordStatus};

    /// Build a response holding `result_count` synthetic records.
    fn create_test_response(result_count: usize) -> QueryResponse {
        let results: Vec<RetrievedRecord> = (0..result_count)
            .map(|i| RetrievedRecord {
                record: MemoryRecord {
                    id: format!("rec-{i}").into(),
                    embedding: vec![1.0],
                    context: format!("ctx-{i}"),
                    outcome: 0.5,
                    metadata: Default::default(),
                    created_at: 0,
                    status: RecordStatus::Active,
                    stats: OutcomeStats::new(1),
                },
                score: 0.9 - (i as f32 * 0.1),
                rank: i + 1,
                source_index: "test".into(),
            })
            .collect();

        QueryResponse {
            results,
            priors: None,
            latency: Duration::from_millis(10),
            indexes_searched: 1,
            candidates_considered: result_count,
        }
    }

    /// Key for a one-component embedding, `k = 1`, no filter, all indexes.
    fn key_for(seed: f32) -> CacheKey {
        CacheKey::new(&[seed], 1, None, None)
    }

    /// Equal parameters give equal keys; a different embedding does not.
    #[test]
    fn test_cache_key() {
        let key1 = CacheKey::new(&[1.0, 2.0, 3.0], 10, None, None);
        let key2 = CacheKey::new(&[1.0, 2.0, 3.0], 10, None, None);
        let key3 = CacheKey::new(&[1.0, 2.0, 4.0], 10, None, None);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    /// A stored response can be read back.
    #[test]
    fn test_put_and_get() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0, 2.0], 5, None, None);
        let response = create_test_response(5);

        cache.put(key.clone(), response);

        let cached = cache.get(&key).expect("entry was just inserted");
        assert_eq!(cached.results.len(), 5);
    }

    /// Looking up an unknown key records a miss and no hit.
    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0, 2.0], 5, None, None);

        let cached = cache.get(&key);
        assert!(cached.is_none());

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    /// Looking up a stored key records a hit.
    #[test]
    fn test_cache_hit() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0, 2.0], 5, None, None);

        cache.put(key.clone(), create_test_response(5));
        assert!(cache.get(&key).is_some());

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
    }

    /// An entry older than the TTL is dropped and counted as an expiration.
    #[test]
    fn test_ttl_expiration() {
        let config = CacheConfig::new().with_ttl(Duration::from_millis(50));
        let cache = QueryCache::new(config);

        let key = CacheKey::new(&[1.0], 5, None, None);
        cache.put(key.clone(), create_test_response(5));

        // Should hit
        assert!(cache.get(&key).is_some());

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(60));

        // Should miss (expired)
        assert!(cache.get(&key).is_none());

        let stats = cache.stats();
        assert_eq!(stats.expirations, 1);
    }

    /// Overflowing a full cache evicts exactly the oldest entry.
    #[test]
    fn test_lru_eviction() {
        let config = CacheConfig::new().with_max_entries(5);
        let cache = QueryCache::new(config);

        // Fill cache
        for i in 0..5 {
            cache.put(key_for(i as f32), create_test_response(1));
        }

        assert_eq!(cache.len(), 5);

        // Add one more (should trigger eviction)
        cache.put(key_for(100.0), create_test_response(1));

        // Cache should not exceed max
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.stats().evictions, 1);

        // The first key inserted is the one that went; the rest survive.
        assert!(cache.get(&key_for(0.0)).is_none());
        assert!(cache.get(&key_for(1.0)).is_some());
        assert!(cache.get(&key_for(100.0)).is_some());
    }

    /// A recently read entry is spared; the next-oldest is evicted instead.
    #[test]
    fn test_lru_eviction_respects_recency() {
        let config = CacheConfig::new().with_max_entries(5);
        let cache = QueryCache::new(config);

        for i in 0..5 {
            cache.put(key_for(i as f32), create_test_response(1));
        }

        // Reading key 0 makes it the most recently used entry.
        assert!(cache.get(&key_for(0.0)).is_some());

        cache.put(key_for(100.0), create_test_response(1));

        assert!(cache.get(&key_for(0.0)).is_some());
        assert!(cache.get(&key_for(1.0)).is_none());
    }

    /// `clear` empties the cache.
    #[test]
    fn test_clear() {
        let cache = QueryCache::default_cache();

        for i in 0..10 {
            cache.put(key_for(i as f32), create_test_response(1));
        }

        assert_eq!(cache.len(), 10);

        cache.clear();

        assert!(cache.is_empty());
    }

    /// Three hits and one miss give a hit ratio of 0.75.
    #[test]
    fn test_hit_ratio() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0], 5, None, None);

        cache.put(key.clone(), create_test_response(5));

        // 3 hits
        for _ in 0..3 {
            assert!(cache.get(&key).is_some());
        }

        // 1 miss
        assert!(cache.get(&CacheKey::new(&[999.0], 5, None, None)).is_none());

        let stats = cache.stats();
        assert!((stats.hit_ratio() - 0.75).abs() < 0.01);
    }

    /// A removed entry can no longer be read.
    #[test]
    fn test_remove() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0], 5, None, None);

        cache.put(key.clone(), create_test_response(5));
        assert!(cache.get(&key).is_some());

        assert!(cache.remove(&key).is_some());
        assert!(cache.get(&key).is_none());
    }
}