oxirs_vec/hnsw/
query_cache.rs

//! Query result caching for HNSW index
//!
//! This module caches HNSW query results so that repeated queries can be
//! answered without re-running the search. Lookups are exact: a cached entry is
//! reused only for the same query vector (compared bit-for-bit) and the same `k`.
5
6use crate::Vector;
7use blake3::Hasher;
8use lru::LruCache;
9use parking_lot::RwLock;
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// A cached search result plus the bookkeeping needed for TTL checks.
#[derive(Clone, Debug)]
struct CachedResult {
    /// Search results as `(URI, similarity score)` pairs.
    results: Vec<(String, f32)>,
    /// Moment the entry was created; the TTL is measured from here.
    cached_at: Instant,
    /// How many times `record_hit` has been called on this entry.
    hit_count: usize,
}

impl CachedResult {
    /// Wrap `results`, timestamping them with the current instant.
    fn new(results: Vec<(String, f32)>) -> Self {
        let cached_at = Instant::now();
        Self {
            results,
            cached_at,
            hit_count: 0,
        }
    }

    /// `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Count one more lookup served from this entry.
    fn record_hit(&mut self) {
        let hits = &mut self.hit_count;
        *hits += 1;
    }
}
42
/// Tuning knobs for [`QueryCache`].
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Upper bound on the number of cached queries.
    pub max_entries: usize,
    /// How long a cached result stays valid after it is stored.
    pub ttl: Duration,
    /// Enable similarity-based cache lookup (find similar queries).
    ///
    /// NOTE(review): `QueryCache` in this module never reads this flag; lookups
    /// are exact-match only.
    pub enable_fuzzy_matching: bool,
    /// Similarity threshold for fuzzy matching (0.0-1.0).
    ///
    /// NOTE(review): not consulted by `QueryCache` in this module.
    pub fuzzy_threshold: f32,
    /// Whether hit/miss/eviction/expiration counters are maintained.
    pub enable_stats: bool,
}

impl Default for QueryCacheConfig {
    /// 10 000 entries, a five-minute TTL, fuzzy matching off, stats on.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        Self {
            max_entries: 10_000,
            ttl: FIVE_MINUTES,
            enable_fuzzy_matching: false,
            fuzzy_threshold: 0.95,
            enable_stats: true,
        }
    }
}
69
/// Counters describing how a query cache has been used.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Number of lookups recorded.
    pub total_queries: u64,
    /// Lookups answered from the cache.
    pub cache_hits: u64,
    /// Lookups that found no usable entry (absent or expired).
    pub cache_misses: u64,
    /// Entries pushed out to make room for new ones.
    pub evictions: u64,
    /// Entries dropped because their TTL had elapsed.
    pub expirations: u64,
}

impl QueryCacheStats {
    /// Ratio of `cache_hits` to `total_queries`.
    ///
    /// Returns `0.0` when no queries have been recorded, avoiding a division by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
89
90/// High-performance query result cache for HNSW index
91pub struct QueryCache {
92    /// LRU cache for query results
93    cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
94    /// Cache configuration
95    config: QueryCacheConfig,
96    /// Cache statistics
97    stats: Arc<RwLock<QueryCacheStats>>,
98}
99
100impl QueryCache {
101    /// Create a new query cache
102    pub fn new(config: QueryCacheConfig) -> Self {
103        let capacity = NonZeroUsize::new(config.max_entries).unwrap();
104        Self {
105            cache: Arc::new(RwLock::new(LruCache::new(capacity))),
106            config,
107            stats: Arc::new(RwLock::new(QueryCacheStats::default())),
108        }
109    }
110
111    /// Generate cache key from query vector and parameters
112    fn generate_key(&self, query: &Vector, k: usize) -> u64 {
113        let mut hasher = Hasher::new();
114
115        // Hash the query vector
116        let query_f32 = query.as_f32();
117        for &val in &query_f32 {
118            hasher.update(&val.to_le_bytes());
119        }
120
121        // Hash the k parameter
122        hasher.update(&k.to_le_bytes());
123
124        // Get first 8 bytes as u64
125        let hash = hasher.finalize();
126        let hash_bytes = hash.as_bytes();
127        u64::from_le_bytes([
128            hash_bytes[0],
129            hash_bytes[1],
130            hash_bytes[2],
131            hash_bytes[3],
132            hash_bytes[4],
133            hash_bytes[5],
134            hash_bytes[6],
135            hash_bytes[7],
136        ])
137    }
138
139    /// Get cached results for a query
140    pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
141        if self.config.enable_stats {
142            let mut stats = self.stats.write();
143            stats.total_queries += 1;
144        }
145
146        let key = self.generate_key(query, k);
147        let mut cache = self.cache.write();
148
149        if let Some(cached) = cache.get_mut(&key) {
150            // Check expiration
151            if cached.is_expired(self.config.ttl) {
152                cache.pop(&key);
153                if self.config.enable_stats {
154                    let mut stats = self.stats.write();
155                    stats.expirations += 1;
156                    stats.cache_misses += 1;
157                }
158                return None;
159            }
160
161            // Record hit and return results
162            cached.record_hit();
163            if self.config.enable_stats {
164                let mut stats = self.stats.write();
165                stats.cache_hits += 1;
166            }
167            return Some(cached.results.clone());
168        }
169
170        if self.config.enable_stats {
171            let mut stats = self.stats.write();
172            stats.cache_misses += 1;
173        }
174        None
175    }
176
177    /// Cache query results
178    pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
179        let key = self.generate_key(query, k);
180        let mut cache = self.cache.write();
181
182        let cached_result = CachedResult::new(results);
183
184        // Check if we're evicting an entry
185        if cache.len() >= self.config.max_entries && self.config.enable_stats {
186            let mut stats = self.stats.write();
187            stats.evictions += 1;
188        }
189
190        cache.put(key, cached_result);
191    }
192
193    /// Clear all cached results
194    pub fn clear(&self) {
195        let mut cache = self.cache.write();
196        cache.clear();
197    }
198
199    /// Get cache statistics
200    pub fn get_stats(&self) -> QueryCacheStats {
201        self.stats.read().clone()
202    }
203
204    /// Reset cache statistics
205    pub fn reset_stats(&self) {
206        let mut stats = self.stats.write();
207        *stats = QueryCacheStats::default();
208    }
209
210    /// Get current cache size
211    pub fn len(&self) -> usize {
212        self.cache.read().len()
213    }
214
215    /// Check if cache is empty
216    pub fn is_empty(&self) -> bool {
217        self.cache.read().is_empty()
218    }
219
220    /// Remove expired entries (maintenance operation)
221    pub fn cleanup_expired(&self) -> usize {
222        let mut cache = self.cache.write();
223        let mut expired_keys = Vec::new();
224
225        // Find expired entries
226        for (key, cached) in cache.iter() {
227            if cached.is_expired(self.config.ttl) {
228                expired_keys.push(*key);
229            }
230        }
231
232        // Remove expired entries
233        let count = expired_keys.len();
234        for key in expired_keys {
235            cache.pop(&key);
236        }
237
238        if self.config.enable_stats && count > 0 {
239            let mut stats = self.stats.write();
240            stats.expirations += count as u64;
241        }
242
243        count
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a three-dimensional query vector.
    fn query3(x: f32, y: f32, z: f32) -> Vector {
        Vector::new(vec![x, y, z])
    }

    /// A cache whose entries expire after 100 ms.
    fn short_lived_cache() -> QueryCache {
        QueryCache::new(QueryCacheConfig {
            ttl: Duration::from_millis(100),
            ..Default::default()
        })
    }

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let query = query3(1.0, 2.0, 3.0);

        // Nothing has been stored yet.
        assert!(cache.get(&query, 5).is_none());

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)]);

        let hit = cache.get(&query, 5).expect("entry should be cached");
        assert_eq!(hit.len(), 2);
        let (first_uri, first_score) = &hit[0];
        assert_eq!(first_uri, "uri1");
        assert_eq!(*first_score, 0.9);
    }

    #[test]
    fn test_query_cache_expiration() {
        let cache = short_lived_cache();
        let query = query3(1.0, 2.0, 3.0);

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        assert!(cache.get(&query, 5).is_some(), "fresh entry should hit");

        std::thread::sleep(Duration::from_millis(150));
        assert!(cache.get(&query, 5).is_none(), "entry should have expired");
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let query = query3(1.0, 2.0, 3.0);

        let _ = cache.get(&query, 5); // miss
        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        for _ in 0..2 {
            let _ = cache.get(&query, 5); // hit
        }

        let stats = cache.get_stats();
        assert_eq!(stats.total_queries, 3);
        assert_eq!(stats.cache_hits, 2);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.hit_rate(), 2.0 / 3.0);
    }

    #[test]
    fn test_query_cache_cleanup() {
        let cache = short_lived_cache();

        (0..5).for_each(|i| {
            cache.put(&query3(i as f32, 0.0, 0.0), 5, vec![(format!("uri{i}"), 0.9)]);
        });
        assert_eq!(cache.len(), 5);

        std::thread::sleep(Duration::from_millis(150));

        assert_eq!(cache.cleanup_expired(), 5);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_query_cache_different_k() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let query = query3(1.0, 2.0, 3.0);

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        cache.put(&query, 10, vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)]);

        // `k` is part of the key, so the two entries must not overwrite each other.
        let cached_k5 = cache.get(&query, 5).unwrap();
        let cached_k10 = cache.get(&query, 10).unwrap();
        assert_eq!(cached_k5.len(), 1);
        assert_eq!(cached_k10.len(), 2);
    }
}