Skip to main content

oxirs_vec/hnsw/
query_cache.rs

1//! Query result caching for HNSW index
2//!
3//! This module provides high-performance caching of query results
4//! to dramatically improve performance for repeated or similar queries.
5
6use crate::Vector;
7use blake3::Hasher;
8use lru::LruCache;
9use parking_lot::RwLock;
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14/// Cached query result with expiration
15#[derive(Clone, Debug)]
16struct CachedResult {
17    /// Search results (URI, similarity score)
18    results: Vec<(String, f32)>,
19    /// Time when this result was cached
20    cached_at: Instant,
21    /// Number of times this result has been accessed
22    hit_count: usize,
23}
24
25impl CachedResult {
26    fn new(results: Vec<(String, f32)>) -> Self {
27        Self {
28            results,
29            cached_at: Instant::now(),
30            hit_count: 0,
31        }
32    }
33
34    fn is_expired(&self, ttl: Duration) -> bool {
35        self.cached_at.elapsed() > ttl
36    }
37
38    fn record_hit(&mut self) {
39        self.hit_count += 1;
40    }
41}
42
43/// Query cache configuration
44#[derive(Debug, Clone)]
45pub struct QueryCacheConfig {
46    /// Maximum number of cached queries
47    pub max_entries: usize,
48    /// Time-to-live for cached results
49    pub ttl: Duration,
50    /// Enable similarity-based cache lookup (find similar queries)
51    pub enable_fuzzy_matching: bool,
52    /// Similarity threshold for fuzzy matching (0.0-1.0)
53    pub fuzzy_threshold: f32,
54    /// Enable cache statistics tracking
55    pub enable_stats: bool,
56}
57
58impl Default for QueryCacheConfig {
59    fn default() -> Self {
60        Self {
61            max_entries: 10000,
62            ttl: Duration::from_secs(300), // 5 minutes
63            enable_fuzzy_matching: false,
64            fuzzy_threshold: 0.95,
65            enable_stats: true,
66        }
67    }
68}
69
70/// Query cache statistics
71#[derive(Debug, Clone, Default)]
72pub struct QueryCacheStats {
73    pub total_queries: u64,
74    pub cache_hits: u64,
75    pub cache_misses: u64,
76    pub evictions: u64,
77    pub expirations: u64,
78}
79
80impl QueryCacheStats {
81    pub fn hit_rate(&self) -> f64 {
82        if self.total_queries == 0 {
83            0.0
84        } else {
85            self.cache_hits as f64 / self.total_queries as f64
86        }
87    }
88}
89
90/// High-performance query result cache for HNSW index
91pub struct QueryCache {
92    /// LRU cache for query results
93    cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
94    /// Cache configuration
95    config: QueryCacheConfig,
96    /// Cache statistics
97    stats: Arc<RwLock<QueryCacheStats>>,
98}
99
100impl QueryCache {
101    /// Create a new query cache
102    pub fn new(config: QueryCacheConfig) -> Self {
103        let capacity =
104            NonZeroUsize::new(config.max_entries).expect("cache max_entries must be non-zero");
105        Self {
106            cache: Arc::new(RwLock::new(LruCache::new(capacity))),
107            config,
108            stats: Arc::new(RwLock::new(QueryCacheStats::default())),
109        }
110    }
111
112    /// Generate cache key from query vector and parameters
113    fn generate_key(&self, query: &Vector, k: usize) -> u64 {
114        let mut hasher = Hasher::new();
115
116        // Hash the query vector
117        let query_f32 = query.as_f32();
118        for &val in &query_f32 {
119            hasher.update(&val.to_le_bytes());
120        }
121
122        // Hash the k parameter
123        hasher.update(&k.to_le_bytes());
124
125        // Get first 8 bytes as u64
126        let hash = hasher.finalize();
127        let hash_bytes = hash.as_bytes();
128        u64::from_le_bytes([
129            hash_bytes[0],
130            hash_bytes[1],
131            hash_bytes[2],
132            hash_bytes[3],
133            hash_bytes[4],
134            hash_bytes[5],
135            hash_bytes[6],
136            hash_bytes[7],
137        ])
138    }
139
140    /// Get cached results for a query
141    pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
142        if self.config.enable_stats {
143            let mut stats = self.stats.write();
144            stats.total_queries += 1;
145        }
146
147        let key = self.generate_key(query, k);
148        let mut cache = self.cache.write();
149
150        if let Some(cached) = cache.get_mut(&key) {
151            // Check expiration
152            if cached.is_expired(self.config.ttl) {
153                cache.pop(&key);
154                if self.config.enable_stats {
155                    let mut stats = self.stats.write();
156                    stats.expirations += 1;
157                    stats.cache_misses += 1;
158                }
159                return None;
160            }
161
162            // Record hit and return results
163            cached.record_hit();
164            if self.config.enable_stats {
165                let mut stats = self.stats.write();
166                stats.cache_hits += 1;
167            }
168            return Some(cached.results.clone());
169        }
170
171        if self.config.enable_stats {
172            let mut stats = self.stats.write();
173            stats.cache_misses += 1;
174        }
175        None
176    }
177
178    /// Cache query results
179    pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
180        let key = self.generate_key(query, k);
181        let mut cache = self.cache.write();
182
183        let cached_result = CachedResult::new(results);
184
185        // Check if we're evicting an entry
186        if cache.len() >= self.config.max_entries && self.config.enable_stats {
187            let mut stats = self.stats.write();
188            stats.evictions += 1;
189        }
190
191        cache.put(key, cached_result);
192    }
193
194    /// Clear all cached results
195    pub fn clear(&self) {
196        let mut cache = self.cache.write();
197        cache.clear();
198    }
199
200    /// Get cache statistics
201    pub fn get_stats(&self) -> QueryCacheStats {
202        self.stats.read().clone()
203    }
204
205    /// Reset cache statistics
206    pub fn reset_stats(&self) {
207        let mut stats = self.stats.write();
208        *stats = QueryCacheStats::default();
209    }
210
211    /// Get current cache size
212    pub fn len(&self) -> usize {
213        self.cache.read().len()
214    }
215
216    /// Check if cache is empty
217    pub fn is_empty(&self) -> bool {
218        self.cache.read().is_empty()
219    }
220
221    /// Remove expired entries (maintenance operation)
222    pub fn cleanup_expired(&self) -> usize {
223        let mut cache = self.cache.write();
224        let mut expired_keys = Vec::new();
225
226        // Find expired entries
227        for (key, cached) in cache.iter() {
228            if cached.is_expired(self.config.ttl) {
229                expired_keys.push(*key);
230            }
231        }
232
233        // Remove expired entries
234        let count = expired_keys.len();
235        for key in expired_keys {
236            cache.pop(&key);
237        }
238
239        if self.config.enable_stats && count > 0 {
240            let mut stats = self.stats.write();
241            stats.expirations += count as u64;
242        }
243
244        count
245    }
246}
247
248#[cfg(test)]
249mod tests {
250    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
251    use super::*;
252
253    #[test]
254    fn test_query_cache_basic() -> Result<()> {
255        let config = QueryCacheConfig::default();
256        let cache = QueryCache::new(config);
257
258        let query = Vector::new(vec![1.0, 2.0, 3.0]);
259        let results = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];
260
261        // Cache miss on first access
262        assert!(cache.get(&query, 5).is_none());
263
264        // Put results in cache
265        cache.put(&query, 5, results.clone());
266
267        // Cache hit on second access
268        let cached = cache.get(&query, 5).expect("cache should have results");
269        assert_eq!(cached.len(), 2);
270        assert_eq!(cached[0].0, "uri1");
271        assert_eq!(cached[0].1, 0.9);
272        Ok(())
273    }
274
275    #[test]
276    fn test_query_cache_expiration() {
277        let config = QueryCacheConfig {
278            ttl: Duration::from_millis(100),
279            ..Default::default()
280        };
281        let cache = QueryCache::new(config);
282
283        let query = Vector::new(vec![1.0, 2.0, 3.0]);
284        let results = vec![("uri1".to_string(), 0.9)];
285
286        cache.put(&query, 5, results);
287
288        // Should hit immediately
289        assert!(cache.get(&query, 5).is_some());
290
291        // Wait for expiration
292        std::thread::sleep(Duration::from_millis(150));
293
294        // Should miss after expiration
295        assert!(cache.get(&query, 5).is_none());
296    }
297
298    #[test]
299    fn test_query_cache_stats() {
300        let config = QueryCacheConfig::default();
301        let cache = QueryCache::new(config);
302
303        let query = Vector::new(vec![1.0, 2.0, 3.0]);
304        let results = vec![("uri1".to_string(), 0.9)];
305
306        // Miss
307        cache.get(&query, 5);
308
309        // Put and hit
310        cache.put(&query, 5, results);
311        cache.get(&query, 5);
312        cache.get(&query, 5);
313
314        let stats = cache.get_stats();
315        assert_eq!(stats.total_queries, 3);
316        assert_eq!(stats.cache_hits, 2);
317        assert_eq!(stats.cache_misses, 1);
318        assert_eq!(stats.hit_rate(), 2.0 / 3.0);
319    }
320
321    #[test]
322    fn test_query_cache_cleanup() {
323        let config = QueryCacheConfig {
324            ttl: Duration::from_millis(100),
325            ..Default::default()
326        };
327        let cache = QueryCache::new(config);
328
329        // Add multiple entries
330        for i in 0..5 {
331            let query = Vector::new(vec![i as f32, 0.0, 0.0]);
332            let results = vec![(format!("uri{}", i), 0.9)];
333            cache.put(&query, 5, results);
334        }
335
336        assert_eq!(cache.len(), 5);
337
338        // Wait for expiration
339        std::thread::sleep(Duration::from_millis(150));
340
341        // Cleanup expired entries
342        let expired = cache.cleanup_expired();
343        assert_eq!(expired, 5);
344        assert_eq!(cache.len(), 0);
345    }
346
347    #[test]
348    fn test_query_cache_different_k() -> Result<()> {
349        let config = QueryCacheConfig::default();
350        let cache = QueryCache::new(config);
351
352        let query = Vector::new(vec![1.0, 2.0, 3.0]);
353        let results_k5 = vec![("uri1".to_string(), 0.9)];
354        let results_k10 = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];
355
356        // Cache with k=5
357        cache.put(&query, 5, results_k5);
358
359        // Cache with k=10
360        cache.put(&query, 10, results_k10);
361
362        // Different k values should have different cache entries
363        let cached_k5 = cache.get(&query, 5).expect("cache k5 should have results");
364        let cached_k10 = cache
365            .get(&query, 10)
366            .expect("cache k10 should have results");
367
368        assert_eq!(cached_k5.len(), 1);
369        assert_eq!(cached_k10.len(), 2);
370        Ok(())
371    }
372}