// File: oxirs_vec/hnsw/query_cache.rs

//! Query result caching for the HNSW index.
//!
//! This module caches the results of earlier searches so that a repeated
//! query (the same vector with the same `k`) can be answered from memory
//! instead of searching the index again.
5
6use crate::Vector;
7use blake3::Hasher;
8use lru::LruCache;
9use parking_lot::RwLock;
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// A stored result list together with the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
struct CachedResult {
    /// Search results (URI, similarity score)
    results: Vec<(String, f32)>,
    /// Instant at which this entry was created; the TTL is measured from here.
    cached_at: Instant,
    /// How many times `record_hit` has been called on this entry.
    hit_count: usize,
}

impl CachedResult {
    /// Wraps `results` in a fresh entry stamped with the current time and a
    /// hit count of zero.
    fn new(results: Vec<(String, f32)>) -> Self {
        let cached_at = Instant::now();
        Self {
            results,
            cached_at,
            hit_count: 0,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since the
    /// entry was created.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Counts one more access to this entry.
    fn record_hit(&mut self) {
        self.hit_count += 1;
    }
}
42
/// Settings that control the behaviour of a query cache.
#[derive(Clone, Debug)]
pub struct QueryCacheConfig {
    /// Maximum number of cached queries
    pub max_entries: usize,
    /// Time-to-live for cached results
    pub ttl: Duration,
    /// Enable similarity-based cache lookup (find similar queries)
    ///
    /// NOTE(review): no code in this file appears to read this flag —
    /// confirm whether fuzzy lookup is implemented elsewhere.
    pub enable_fuzzy_matching: bool,
    /// Similarity threshold for fuzzy matching (0.0-1.0)
    ///
    /// NOTE(review): like `enable_fuzzy_matching`, this does not appear to be
    /// consulted anywhere in this file.
    pub fuzzy_threshold: f32,
    /// Enable cache statistics tracking
    pub enable_stats: bool,
}

impl Default for QueryCacheConfig {
    /// Defaults: 10 000 entries, a five-minute TTL, fuzzy matching off with a
    /// 0.95 threshold, and statistics tracking on.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        Self {
            max_entries: 10_000,
            ttl: five_minutes,
            enable_fuzzy_matching: false,
            fuzzy_threshold: 0.95,
            enable_stats: true,
        }
    }
}
69
/// Counters describing how a query cache has been used.
#[derive(Clone, Debug, Default)]
pub struct QueryCacheStats {
    /// Number of lookups recorded.
    pub total_queries: u64,
    /// Number of lookups recorded as hits.
    pub cache_hits: u64,
    /// Number of lookups recorded as misses.
    pub cache_misses: u64,
    /// Number of evictions recorded.
    pub evictions: u64,
    /// Number of expirations recorded.
    pub expirations: u64,
}

impl QueryCacheStats {
    /// Returns `cache_hits / total_queries` as a floating-point ratio, or
    /// `0.0` when no queries have been recorded (avoiding a division by zero).
    pub fn hit_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
89
90/// High-performance query result cache for HNSW index
91pub struct QueryCache {
92    /// LRU cache for query results
93    cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
94    /// Cache configuration
95    config: QueryCacheConfig,
96    /// Cache statistics
97    stats: Arc<RwLock<QueryCacheStats>>,
98}
99
100impl QueryCache {
101    /// Create a new query cache
102    pub fn new(config: QueryCacheConfig) -> Self {
103        let capacity =
104            NonZeroUsize::new(config.max_entries).expect("cache max_entries must be non-zero");
105        Self {
106            cache: Arc::new(RwLock::new(LruCache::new(capacity))),
107            config,
108            stats: Arc::new(RwLock::new(QueryCacheStats::default())),
109        }
110    }
111
112    /// Generate cache key from query vector and parameters
113    fn generate_key(&self, query: &Vector, k: usize) -> u64 {
114        let mut hasher = Hasher::new();
115
116        // Hash the query vector
117        let query_f32 = query.as_f32();
118        for &val in &query_f32 {
119            hasher.update(&val.to_le_bytes());
120        }
121
122        // Hash the k parameter
123        hasher.update(&k.to_le_bytes());
124
125        // Get first 8 bytes as u64
126        let hash = hasher.finalize();
127        let hash_bytes = hash.as_bytes();
128        u64::from_le_bytes([
129            hash_bytes[0],
130            hash_bytes[1],
131            hash_bytes[2],
132            hash_bytes[3],
133            hash_bytes[4],
134            hash_bytes[5],
135            hash_bytes[6],
136            hash_bytes[7],
137        ])
138    }
139
140    /// Get cached results for a query
141    pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
142        if self.config.enable_stats {
143            let mut stats = self.stats.write();
144            stats.total_queries += 1;
145        }
146
147        let key = self.generate_key(query, k);
148        let mut cache = self.cache.write();
149
150        if let Some(cached) = cache.get_mut(&key) {
151            // Check expiration
152            if cached.is_expired(self.config.ttl) {
153                cache.pop(&key);
154                if self.config.enable_stats {
155                    let mut stats = self.stats.write();
156                    stats.expirations += 1;
157                    stats.cache_misses += 1;
158                }
159                return None;
160            }
161
162            // Record hit and return results
163            cached.record_hit();
164            if self.config.enable_stats {
165                let mut stats = self.stats.write();
166                stats.cache_hits += 1;
167            }
168            return Some(cached.results.clone());
169        }
170
171        if self.config.enable_stats {
172            let mut stats = self.stats.write();
173            stats.cache_misses += 1;
174        }
175        None
176    }
177
178    /// Cache query results
179    pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
180        let key = self.generate_key(query, k);
181        let mut cache = self.cache.write();
182
183        let cached_result = CachedResult::new(results);
184
185        // Check if we're evicting an entry
186        if cache.len() >= self.config.max_entries && self.config.enable_stats {
187            let mut stats = self.stats.write();
188            stats.evictions += 1;
189        }
190
191        cache.put(key, cached_result);
192    }
193
194    /// Clear all cached results
195    pub fn clear(&self) {
196        let mut cache = self.cache.write();
197        cache.clear();
198    }
199
200    /// Get cache statistics
201    pub fn get_stats(&self) -> QueryCacheStats {
202        self.stats.read().clone()
203    }
204
205    /// Reset cache statistics
206    pub fn reset_stats(&self) {
207        let mut stats = self.stats.write();
208        *stats = QueryCacheStats::default();
209    }
210
211    /// Get current cache size
212    pub fn len(&self) -> usize {
213        self.cache.read().len()
214    }
215
216    /// Check if cache is empty
217    pub fn is_empty(&self) -> bool {
218        self.cache.read().is_empty()
219    }
220
221    /// Remove expired entries (maintenance operation)
222    pub fn cleanup_expired(&self) -> usize {
223        let mut cache = self.cache.write();
224        let mut expired_keys = Vec::new();
225
226        // Find expired entries
227        for (key, cached) in cache.iter() {
228            if cached.is_expired(self.config.ttl) {
229                expired_keys.push(*key);
230            }
231        }
232
233        // Remove expired entries
234        let count = expired_keys.len();
235        for key in expired_keys {
236            cache.pop(&key);
237        }
238
239        if self.config.enable_stats && count > 0 {
240            let mut stats = self.stats.write();
241            stats.expirations += count as u64;
242        }
243
244        count
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Query vector shared by the single-query tests.
    fn sample_query() -> Vector {
        Vector::new(vec![1.0, 2.0, 3.0])
    }

    /// Cache with default settings.
    fn default_cache() -> QueryCache {
        QueryCache::new(QueryCacheConfig::default())
    }

    /// Cache whose entries go stale after 100 ms.
    fn short_ttl_cache() -> QueryCache {
        QueryCache::new(QueryCacheConfig {
            ttl: Duration::from_millis(100),
            ..Default::default()
        })
    }

    /// A lookup misses before `put` and returns the stored results after it.
    #[test]
    fn test_query_cache_basic() {
        let cache = default_cache();
        let query = sample_query();
        let stored = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];

        // Nothing has been stored yet.
        assert!(cache.get(&query, 5).is_none());

        cache.put(&query, 5, stored.clone());

        // The same query now hits and yields what was stored.
        let fetched = cache.get(&query, 5).unwrap();
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].0, "uri1");
        assert_eq!(fetched[0].1, 0.9);
    }

    /// An entry is served while fresh and dropped once its TTL has passed.
    #[test]
    fn test_query_cache_expiration() {
        let cache = short_ttl_cache();
        let query = sample_query();

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);

        // Still within the TTL.
        assert!(cache.get(&query, 5).is_some());

        // Sleep past the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert!(cache.get(&query, 5).is_none());
    }

    /// One miss followed by two hits is reflected in the counters.
    #[test]
    fn test_query_cache_stats() {
        let cache = default_cache();
        let query = sample_query();

        // First lookup: miss.
        cache.get(&query, 5);

        // Store, then look up twice: two hits.
        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        cache.get(&query, 5);
        cache.get(&query, 5);

        let snapshot = cache.get_stats();
        assert_eq!(snapshot.total_queries, 3);
        assert_eq!(snapshot.cache_hits, 2);
        assert_eq!(snapshot.cache_misses, 1);
        assert_eq!(snapshot.hit_rate(), 2.0 / 3.0);
    }

    /// `cleanup_expired` removes every stale entry and reports the count.
    #[test]
    fn test_query_cache_cleanup() {
        let cache = short_ttl_cache();

        // Five distinct queries, one entry each.
        (0..5).for_each(|i| {
            let query = Vector::new(vec![i as f32, 0.0, 0.0]);
            cache.put(&query, 5, vec![(format!("uri{}", i), 0.9)]);
        });
        assert_eq!(cache.len(), 5);

        // Sleep past the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        let removed = cache.cleanup_expired();
        assert_eq!(removed, 5);
        assert_eq!(cache.len(), 0);
    }

    /// The same vector with different `k` values maps to separate entries.
    #[test]
    fn test_query_cache_different_k() {
        let cache = default_cache();
        let query = sample_query();

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        cache.put(
            &query,
            10,
            vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)],
        );

        let for_k5 = cache.get(&query, 5).unwrap();
        let for_k10 = cache.get(&query, 10).unwrap();

        assert_eq!(for_k5.len(), 1);
        assert_eq!(for_k10.len(), 2);
    }
}