oxify_vector/
cache.rs

1//! Query result caching for improved search performance.
2//!
3//! This module provides caching mechanisms for vector search results to avoid
4//! recomputing identical or similar queries. Particularly useful for production
5//! RAG systems with repeated queries.
6//!
7//! ## Features
8//!
9//! - LRU (Least Recently Used) eviction
10//! - TTL (Time-To-Live) expiration
//! - Approximate query matching above a configurable cosine-similarity threshold
12//! - Cache statistics and monitoring
13//! - Thread-safe concurrent access
14//!
15//! ## Example
16//!
17//! ```rust
18//! use oxify_vector::cache::{QueryCache, CacheConfig};
19//! use oxify_vector::{SearchResult, DistanceMetric};
20//!
21//! # fn example() -> anyhow::Result<()> {
22//! let config = CacheConfig::default();
23//! let mut cache = QueryCache::new(config);
24//!
25//! let query = vec![1.0, 2.0, 3.0];
26//! let results = vec![
27//!     SearchResult {
28//!         entity_id: "doc1".to_string(),
29//!         score: 0.95,
30//!         distance: 0.05,
31//!         rank: 1,
32//!     },
33//! ];
34//!
35//! // Cache the results
36//! cache.put(&query, DistanceMetric::Cosine, 10, results.clone());
37//!
38//! // Retrieve from cache
39//! if let Some(cached) = cache.get(&query, DistanceMetric::Cosine, 10) {
40//!     println!("Cache hit! Found {} results", cached.len());
41//! }
42//! # Ok(())
43//! # }
44//! ```
45
46use crate::types::{DistanceMetric, SearchResult};
47use serde::{Deserialize, Serialize};
48use std::collections::{HashMap, VecDeque};
49use std::hash::{Hash, Hasher};
50use std::sync::{Arc, RwLock};
51use std::time::{Duration, Instant};
52
53/// Configuration for query result caching.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct CacheConfig {
56    /// Maximum number of cached queries.
57    pub max_entries: usize,
58    /// Time-to-live for cached results.
59    pub ttl: Duration,
60    /// Tolerance for approximate query matching (0.0 = exact match).
61    pub similarity_threshold: f32,
62    /// Whether to enable approximate matching.
63    pub enable_approximate_matching: bool,
64}
65
66impl Default for CacheConfig {
67    fn default() -> Self {
68        Self {
69            max_entries: 1000,
70            ttl: Duration::from_secs(300), // 5 minutes
71            similarity_threshold: 0.99,    // 99% similar
72            enable_approximate_matching: false,
73        }
74    }
75}
76
77impl CacheConfig {
78    /// Create a config optimized for high hit rate (more entries, longer TTL).
79    pub fn high_hit_rate() -> Self {
80        Self {
81            max_entries: 10_000,
82            ttl: Duration::from_secs(3600), // 1 hour
83            similarity_threshold: 0.95,
84            enable_approximate_matching: true,
85        }
86    }
87
88    /// Create a config optimized for low memory (fewer entries, shorter TTL).
89    pub fn low_memory() -> Self {
90        Self {
91            max_entries: 100,
92            ttl: Duration::from_secs(60), // 1 minute
93            similarity_threshold: 0.99,
94            enable_approximate_matching: false,
95        }
96    }
97
98    /// Create a config with exact matching only.
99    pub fn exact_match_only() -> Self {
100        Self {
101            max_entries: 1000,
102            ttl: Duration::from_secs(300),
103            similarity_threshold: 1.0,
104            enable_approximate_matching: false,
105        }
106    }
107}
108
109/// Query cache key combining query vector, metric, and k.
110#[derive(Debug, Clone, PartialEq)]
111struct CacheKey {
112    query_hash: u64,
113    metric: DistanceMetric,
114    k: usize,
115}
116
117impl Hash for CacheKey {
118    fn hash<H: Hasher>(&self, state: &mut H) {
119        self.query_hash.hash(state);
120        // Hash the metric by discriminant
121        std::mem::discriminant(&self.metric).hash(state);
122        self.k.hash(state);
123    }
124}
125
126impl Eq for CacheKey {}
127
128/// Cached query entry with results and metadata.
129#[derive(Debug, Clone)]
130struct CacheEntry {
131    results: Vec<SearchResult>,
132    inserted_at: Instant,
133    last_accessed: Instant,
134    access_count: u64,
135    query: Vec<f32>, // Store for approximate matching
136}
137
138impl CacheEntry {
139    fn new(query: Vec<f32>, results: Vec<SearchResult>) -> Self {
140        let now = Instant::now();
141        Self {
142            results,
143            inserted_at: now,
144            last_accessed: now,
145            access_count: 0,
146            query,
147        }
148    }
149
150    fn is_expired(&self, ttl: Duration) -> bool {
151        self.inserted_at.elapsed() > ttl
152    }
153
154    fn touch(&mut self) {
155        self.last_accessed = Instant::now();
156        self.access_count += 1;
157    }
158}
159
160/// Thread-safe query result cache with LRU eviction.
161pub struct QueryCache {
162    config: CacheConfig,
163    cache: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
164    access_order: Arc<RwLock<VecDeque<CacheKey>>>,
165    stats: Arc<RwLock<CacheStats>>,
166}
167
168impl QueryCache {
169    /// Create a new query cache with the given configuration.
170    pub fn new(config: CacheConfig) -> Self {
171        Self {
172            config,
173            cache: Arc::new(RwLock::new(HashMap::new())),
174            access_order: Arc::new(RwLock::new(VecDeque::new())),
175            stats: Arc::new(RwLock::new(CacheStats::default())),
176        }
177    }
178
179    /// Get cached results for a query.
180    ///
181    /// Returns `None` if the query is not cached or has expired.
182    pub fn get(
183        &self,
184        query: &[f32],
185        metric: DistanceMetric,
186        k: usize,
187    ) -> Option<Vec<SearchResult>> {
188        let key = self.make_key(query, metric, k);
189
190        // Try exact match first
191        if let Some(entry) = self.get_exact(&key) {
192            return Some(entry);
193        }
194
195        // Try approximate match if enabled
196        if self.config.enable_approximate_matching {
197            if let Some(entry) = self.get_approximate(query, metric, k) {
198                return Some(entry);
199            }
200        }
201
202        // Update stats for miss
203        if let Ok(mut stats) = self.stats.write() {
204            stats.misses += 1;
205        }
206
207        None
208    }
209
210    /// Get exact cache match.
211    fn get_exact(&self, key: &CacheKey) -> Option<Vec<SearchResult>> {
212        let mut cache = self.cache.write().ok()?;
213        let mut access_order = self.access_order.write().ok()?;
214
215        if let Some(entry) = cache.get_mut(key) {
216            // Check if expired
217            if entry.is_expired(self.config.ttl) {
218                cache.remove(key);
219                access_order.retain(|k| k != key);
220                if let Ok(mut stats) = self.stats.write() {
221                    stats.expirations += 1;
222                }
223                return None;
224            }
225
226            // Update access info
227            entry.touch();
228
229            // Move to front of access order (LRU)
230            access_order.retain(|k| k != key);
231            access_order.push_back(key.clone());
232
233            // Update stats
234            if let Ok(mut stats) = self.stats.write() {
235                stats.hits += 1;
236            }
237
238            return Some(entry.results.clone());
239        }
240
241        None
242    }
243
244    /// Get approximate cache match based on similarity threshold.
245    fn get_approximate(
246        &self,
247        query: &[f32],
248        metric: DistanceMetric,
249        k: usize,
250    ) -> Option<Vec<SearchResult>> {
251        let best_key = {
252            let cache = self.cache.read().ok()?;
253
254            // Find most similar cached query
255            let mut best_match: Option<(CacheKey, f32)> = None;
256
257            for (cache_key, entry) in cache.iter() {
258                // Only consider same metric and k
259                if cache_key.metric != metric || cache_key.k != k {
260                    continue;
261                }
262
263                // Skip expired entries
264                if entry.is_expired(self.config.ttl) {
265                    continue;
266                }
267
268                // Compute similarity
269                let similarity = cosine_similarity(&entry.query, query);
270
271                if similarity >= self.config.similarity_threshold {
272                    if let Some((_, best_sim)) = &best_match {
273                        if similarity > *best_sim {
274                            best_match = Some((cache_key.clone(), similarity));
275                        }
276                    } else {
277                        best_match = Some((cache_key.clone(), similarity));
278                    }
279                }
280            }
281
282            best_match.map(|(key, _)| key)
283        }; // cache read lock is released here
284
285        if let Some(key) = best_key {
286            return self.get_exact(&key);
287        }
288
289        None
290    }
291
292    /// Store query results in the cache.
293    pub fn put(
294        &mut self,
295        query: &[f32],
296        metric: DistanceMetric,
297        k: usize,
298        results: Vec<SearchResult>,
299    ) {
300        let key = self.make_key(query, metric, k);
301        let entry = CacheEntry::new(query.to_vec(), results);
302
303        let mut cache = match self.cache.write() {
304            Ok(c) => c,
305            Err(_) => return,
306        };
307
308        let mut access_order = match self.access_order.write() {
309            Ok(a) => a,
310            Err(_) => return,
311        };
312
313        // Evict if at capacity
314        if cache.len() >= self.config.max_entries && !cache.contains_key(&key) {
315            if let Some(oldest_key) = access_order.pop_front() {
316                cache.remove(&oldest_key);
317                if let Ok(mut stats) = self.stats.write() {
318                    stats.evictions += 1;
319                }
320            }
321        }
322
323        cache.insert(key.clone(), entry);
324        access_order.push_back(key);
325
326        if let Ok(mut stats) = self.stats.write() {
327            stats.inserts += 1;
328        }
329    }
330
331    /// Clear all cached entries.
332    pub fn clear(&mut self) {
333        if let Ok(mut cache) = self.cache.write() {
334            cache.clear();
335        }
336        if let Ok(mut access_order) = self.access_order.write() {
337            access_order.clear();
338        }
339        if let Ok(mut stats) = self.stats.write() {
340            *stats = CacheStats::default();
341        }
342    }
343
344    /// Remove expired entries from the cache.
345    pub fn evict_expired(&mut self) -> usize {
346        let mut cache = match self.cache.write() {
347            Ok(c) => c,
348            Err(_) => return 0,
349        };
350
351        let mut access_order = match self.access_order.write() {
352            Ok(a) => a,
353            Err(_) => return 0,
354        };
355
356        let mut expired_keys = Vec::new();
357
358        for (key, entry) in cache.iter() {
359            if entry.is_expired(self.config.ttl) {
360                expired_keys.push(key.clone());
361            }
362        }
363
364        let count = expired_keys.len();
365
366        for key in expired_keys {
367            cache.remove(&key);
368            access_order.retain(|k| k != &key);
369        }
370
371        if let Ok(mut stats) = self.stats.write() {
372            stats.expirations += count as u64;
373        }
374
375        count
376    }
377
378    /// Get cache statistics.
379    pub fn stats(&self) -> CacheStats {
380        self.stats.read().unwrap().clone()
381    }
382
383    /// Get current cache size.
384    pub fn len(&self) -> usize {
385        self.cache.read().unwrap().len()
386    }
387
388    /// Check if cache is empty.
389    pub fn is_empty(&self) -> bool {
390        self.len() == 0
391    }
392
393    /// Create a cache key from query parameters.
394    fn make_key(&self, query: &[f32], metric: DistanceMetric, k: usize) -> CacheKey {
395        CacheKey {
396            query_hash: hash_f32_slice(query),
397            metric,
398            k,
399        }
400    }
401}
402
403/// Statistics for cache performance monitoring.
404#[derive(Debug, Clone, Default, Serialize, Deserialize)]
405pub struct CacheStats {
406    /// Number of cache hits.
407    pub hits: u64,
408    /// Number of cache misses.
409    pub misses: u64,
410    /// Number of cache inserts.
411    pub inserts: u64,
412    /// Number of cache evictions (LRU).
413    pub evictions: u64,
414    /// Number of cache expirations (TTL).
415    pub expirations: u64,
416}
417
418impl CacheStats {
419    /// Calculate hit rate as a percentage.
420    pub fn hit_rate(&self) -> f64 {
421        let total = self.hits + self.misses;
422        if total == 0 {
423            0.0
424        } else {
425            (self.hits as f64 / total as f64) * 100.0
426        }
427    }
428
429    /// Calculate miss rate as a percentage.
430    pub fn miss_rate(&self) -> f64 {
431        100.0 - self.hit_rate()
432    }
433}
434
/// Fingerprint a float slice for use as part of a cache key.
///
/// The hasher receives the length first, then each element's raw IEEE-754 bit pattern.
/// Equality is therefore bitwise, not numeric: `0.0` and `-0.0` produce different
/// fingerprints.
fn hash_f32_slice(slice: &[f32]) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    slice.len().hash(&mut state);
    slice
        .iter()
        .map(|value| value.to_bits())
        .for_each(|bits| bits.hash(&mut state));
    state.finish()
}
451
/// Cosine similarity of `a` and `b`, i.e. `a·b / (|a| |b|)`.
///
/// Returns `0.0` when the slices differ in length or either has zero magnitude, so
/// mismatched or degenerate vectors score as orthogonal.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        0.0
    } else {
        dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())
    }
}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// One-element result list used as the cached payload throughout these tests.
    fn doc1_results() -> Vec<SearchResult> {
        vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.95,
            distance: 0.05,
            rank: 1,
        }]
    }

    #[test]
    fn test_cache_config_default() {
        let defaults = CacheConfig::default();
        assert_eq!(defaults.max_entries, 1000);
        assert_eq!(defaults.ttl, Duration::from_secs(300));
        assert!(!defaults.enable_approximate_matching);
    }

    #[test]
    fn test_cache_config_presets() {
        let high_hit = CacheConfig::high_hit_rate();
        assert_eq!(high_hit.max_entries, 10_000);
        assert!(high_hit.enable_approximate_matching);

        let low_mem = CacheConfig::low_memory();
        assert_eq!(low_mem.max_entries, 100);
        assert_eq!(low_mem.ttl, Duration::from_secs(60));

        let exact = CacheConfig::exact_match_only();
        assert_eq!(exact.similarity_threshold, 1.0);
        assert!(!exact.enable_approximate_matching);
    }

    #[test]
    fn test_query_cache_basic() {
        let mut cache = QueryCache::new(CacheConfig::default());
        let query = vec![1.0, 2.0, 3.0];

        // A fresh cache holds nothing.
        assert!(cache.is_empty());

        cache.put(&query, DistanceMetric::Cosine, 10, doc1_results());
        assert_eq!(cache.len(), 1);

        let hit = cache
            .get(&query, DistanceMetric::Cosine, 10)
            .expect("query was just cached");
        assert_eq!(hit.len(), 1);
    }

    #[test]
    fn test_query_cache_miss() {
        let cache = QueryCache::new(CacheConfig::default());

        assert!(cache.get(&[1.0, 2.0, 3.0], DistanceMetric::Cosine, 10).is_none());

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_query_cache_different_k() {
        let mut cache = QueryCache::new(CacheConfig::default());
        let query = vec![1.0, 2.0, 3.0];

        cache.put(&query, DistanceMetric::Cosine, 10, doc1_results());

        // Identical vector, but a different k must not be served from the cache.
        assert!(cache.get(&query, DistanceMetric::Cosine, 20).is_none());
    }

    #[test]
    fn test_query_cache_different_metric() {
        let mut cache = QueryCache::new(CacheConfig::default());
        let query = vec![1.0, 2.0, 3.0];

        cache.put(&query, DistanceMetric::Cosine, 10, doc1_results());

        // Identical vector, but a different metric must not be served from the cache.
        assert!(cache.get(&query, DistanceMetric::Euclidean, 10).is_none());
    }

    #[test]
    fn test_query_cache_lru_eviction() {
        let mut cache = QueryCache::new(CacheConfig {
            max_entries: 2,
            ..Default::default()
        });

        // Three inserts into a two-slot cache: the oldest must be evicted.
        for q in [1.0_f32, 2.0, 3.0] {
            cache.put(&[q], DistanceMetric::Cosine, 10, doc1_results());
        }
        assert_eq!(cache.len(), 2);

        assert!(cache.get(&[1.0], DistanceMetric::Cosine, 10).is_none());
        assert!(cache.get(&[2.0], DistanceMetric::Cosine, 10).is_some());
        assert!(cache.get(&[3.0], DistanceMetric::Cosine, 10).is_some());
    }

    #[test]
    fn test_query_cache_clear() {
        let mut cache = QueryCache::new(CacheConfig::default());

        cache.put(&[1.0, 2.0, 3.0], DistanceMetric::Cosine, 10, doc1_results());
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_query_cache_stats() {
        let mut cache = QueryCache::new(CacheConfig::default());
        let query = vec![1.0, 2.0, 3.0];

        cache.put(&query, DistanceMetric::Cosine, 10, doc1_results());
        assert_eq!(cache.stats().inserts, 1);

        // One hit...
        let _ = cache.get(&query, DistanceMetric::Cosine, 10);
        assert_eq!(cache.stats().hits, 1);

        // ...and one miss.
        let _ = cache.get(&[9.0], DistanceMetric::Cosine, 10);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);

        assert_eq!(stats.hit_rate(), 50.0);
        assert_eq!(stats.miss_rate(), 50.0);
    }

    #[test]
    fn test_hash_f32_slice() {
        let original = vec![1.0, 2.0, 3.0];
        let same = vec![1.0, 2.0, 3.0];
        let tweaked = vec![1.0, 2.0, 3.1];

        assert_eq!(hash_f32_slice(&original), hash_f32_slice(&same));
        assert_ne!(hash_f32_slice(&original), hash_f32_slice(&tweaked));
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let x_axis_again = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        assert!((cosine_similarity(&x_axis, &x_axis_again) - 1.0).abs() < 0.01);
        assert!((cosine_similarity(&x_axis, &y_axis) - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            inserts: 100,
            evictions: 0,
            expirations: 0,
        };

        assert_eq!(stats.hit_rate(), 75.0);
        assert_eq!(stats.miss_rate(), 25.0);
    }

    #[test]
    fn test_cache_entry_expiration() {
        let entry = CacheEntry::new(vec![1.0, 2.0, 3.0], doc1_results());

        // Fresh entries are well inside a one-second TTL.
        assert!(!entry.is_expired(Duration::from_secs(1)));

        // After a short sleep, a 1 ms TTL has definitely elapsed.
        std::thread::sleep(Duration::from_millis(10));
        assert!(entry.is_expired(Duration::from_millis(1)));
    }

    #[test]
    fn test_approximate_matching_disabled() {
        let mut cache = QueryCache::new(CacheConfig::exact_match_only());

        let stored = vec![1.0, 0.0, 0.0];
        let nearby = vec![0.99, 0.01, 0.0]; // very similar, but not identical

        cache.put(&stored, DistanceMetric::Cosine, 10, doc1_results());

        // Without approximate matching, a near-duplicate query must miss.
        assert!(cache.get(&nearby, DistanceMetric::Cosine, 10).is_none());
    }

    #[test]
    fn test_approximate_matching_enabled() {
        let mut cache = QueryCache::new(CacheConfig {
            enable_approximate_matching: true,
            similarity_threshold: 0.95,
            ..Default::default()
        });

        let stored = vec![1.0, 0.0, 0.0];
        let nearby = vec![0.99, 0.14, 0.0]; // cosine similarity above 0.95

        cache.put(&stored, DistanceMetric::Cosine, 10, doc1_results());

        // With approximate matching, the near-duplicate query is served from cache.
        assert!(cache.get(&nearby, DistanceMetric::Cosine, 10).is_some());
    }
}