1use crate::Vector;
7use blake3::Hasher;
8use lru::LruCache;
9use parking_lot::RwLock;
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// One cached result list together with the bookkeeping needed for expiry.
#[derive(Clone, Debug)]
struct CachedResult {
    /// The stored `(identifier, score)` pairs, exactly as handed to [`CachedResult::new`].
    results: Vec<(String, f32)>,
    /// Moment at which this entry was created.
    cached_at: Instant,
    /// How many times [`CachedResult::record_hit`] has been called on this entry.
    hit_count: usize,
}

impl CachedResult {
    /// Wraps `results` in a fresh entry stamped with the current time and zero hits.
    fn new(results: Vec<(String, f32)>) -> Self {
        let cached_at = Instant::now();
        Self {
            results,
            cached_at,
            hit_count: 0,
        }
    }

    /// Returns `true` once the entry's age is strictly greater than `ttl`.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Bumps the per-entry hit counter by one.
    fn record_hit(&mut self) {
        self.hit_count += 1;
    }
}
42
/// Tunable settings for the query cache.
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Capacity of the cache, in entries.
    pub max_entries: usize,
    /// Maximum age of an entry before it is treated as expired.
    pub ttl: Duration,
    /// NOTE(review): not read by any code visible in this file — presumably reserved for
    /// approximate-match lookups; confirm before relying on it.
    pub enable_fuzzy_matching: bool,
    /// NOTE(review): not read by any code visible in this file — presumably the similarity
    /// cut-off used together with `enable_fuzzy_matching`; confirm.
    pub fuzzy_threshold: f32,
    /// When `false`, the cache skips all statistics bookkeeping.
    pub enable_stats: bool,
}

impl Default for QueryCacheConfig {
    /// Defaults: 10 000 entries, a five-minute TTL, fuzzy matching off (threshold 0.95),
    /// statistics on.
    fn default() -> Self {
        const DEFAULT_MAX_ENTRIES: usize = 10_000;
        const DEFAULT_TTL_SECS: u64 = 300;
        const DEFAULT_FUZZY_THRESHOLD: f32 = 0.95;

        QueryCacheConfig {
            max_entries: DEFAULT_MAX_ENTRIES,
            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
            enable_fuzzy_matching: false,
            fuzzy_threshold: DEFAULT_FUZZY_THRESHOLD,
            enable_stats: true,
        }
    }
}
69
/// Running counters describing how the query cache has been used.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Number of lookups performed.
    pub total_queries: u64,
    /// Lookups that returned a cached result.
    pub cache_hits: u64,
    /// Lookups that returned nothing (absent or expired entry).
    pub cache_misses: u64,
    /// Insertions counted as having displaced an older entry.
    pub evictions: u64,
    /// Entries dropped because their TTL had elapsed.
    pub expirations: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded, avoiding a division by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
89
90pub struct QueryCache {
92 cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
94 config: QueryCacheConfig,
96 stats: Arc<RwLock<QueryCacheStats>>,
98}
99
100impl QueryCache {
101 pub fn new(config: QueryCacheConfig) -> Self {
103 let capacity =
104 NonZeroUsize::new(config.max_entries).expect("cache max_entries must be non-zero");
105 Self {
106 cache: Arc::new(RwLock::new(LruCache::new(capacity))),
107 config,
108 stats: Arc::new(RwLock::new(QueryCacheStats::default())),
109 }
110 }
111
112 fn generate_key(&self, query: &Vector, k: usize) -> u64 {
114 let mut hasher = Hasher::new();
115
116 let query_f32 = query.as_f32();
118 for &val in &query_f32 {
119 hasher.update(&val.to_le_bytes());
120 }
121
122 hasher.update(&k.to_le_bytes());
124
125 let hash = hasher.finalize();
127 let hash_bytes = hash.as_bytes();
128 u64::from_le_bytes([
129 hash_bytes[0],
130 hash_bytes[1],
131 hash_bytes[2],
132 hash_bytes[3],
133 hash_bytes[4],
134 hash_bytes[5],
135 hash_bytes[6],
136 hash_bytes[7],
137 ])
138 }
139
140 pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
142 if self.config.enable_stats {
143 let mut stats = self.stats.write();
144 stats.total_queries += 1;
145 }
146
147 let key = self.generate_key(query, k);
148 let mut cache = self.cache.write();
149
150 if let Some(cached) = cache.get_mut(&key) {
151 if cached.is_expired(self.config.ttl) {
153 cache.pop(&key);
154 if self.config.enable_stats {
155 let mut stats = self.stats.write();
156 stats.expirations += 1;
157 stats.cache_misses += 1;
158 }
159 return None;
160 }
161
162 cached.record_hit();
164 if self.config.enable_stats {
165 let mut stats = self.stats.write();
166 stats.cache_hits += 1;
167 }
168 return Some(cached.results.clone());
169 }
170
171 if self.config.enable_stats {
172 let mut stats = self.stats.write();
173 stats.cache_misses += 1;
174 }
175 None
176 }
177
178 pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
180 let key = self.generate_key(query, k);
181 let mut cache = self.cache.write();
182
183 let cached_result = CachedResult::new(results);
184
185 if cache.len() >= self.config.max_entries && self.config.enable_stats {
187 let mut stats = self.stats.write();
188 stats.evictions += 1;
189 }
190
191 cache.put(key, cached_result);
192 }
193
194 pub fn clear(&self) {
196 let mut cache = self.cache.write();
197 cache.clear();
198 }
199
200 pub fn get_stats(&self) -> QueryCacheStats {
202 self.stats.read().clone()
203 }
204
205 pub fn reset_stats(&self) {
207 let mut stats = self.stats.write();
208 *stats = QueryCacheStats::default();
209 }
210
211 pub fn len(&self) -> usize {
213 self.cache.read().len()
214 }
215
216 pub fn is_empty(&self) -> bool {
218 self.cache.read().is_empty()
219 }
220
221 pub fn cleanup_expired(&self) -> usize {
223 let mut cache = self.cache.write();
224 let mut expired_keys = Vec::new();
225
226 for (key, cached) in cache.iter() {
228 if cached.is_expired(self.config.ttl) {
229 expired_keys.push(*key);
230 }
231 }
232
233 let count = expired_keys.len();
235 for key in expired_keys {
236 cache.pop(&key);
237 }
238
239 if self.config.enable_stats && count > 0 {
240 let mut stats = self.stats.write();
241 stats.expirations += count as u64;
242 }
243
244 count
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// A miss on an empty cache, then a round trip of one stored result list.
    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(QueryCacheConfig::default());

        let probe = Vector::new(vec![1.0, 2.0, 3.0]);
        let stored = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];

        // Nothing has been inserted yet.
        assert!(cache.get(&probe, 5).is_none());

        cache.put(&probe, 5, stored.clone());

        let fetched = cache.get(&probe, 5).unwrap();
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].0, "uri1");
        assert_eq!(fetched[0].1, 0.9);
    }

    /// An entry is served before its TTL elapses and refused afterwards.
    #[test]
    fn test_query_cache_expiration() {
        let cache = QueryCache::new(QueryCacheConfig {
            ttl: Duration::from_millis(100),
            ..Default::default()
        });

        let probe = Vector::new(vec![1.0, 2.0, 3.0]);
        let stored = vec![("uri1".to_string(), 0.9)];

        cache.put(&probe, 5, stored);

        // Still fresh.
        assert!(cache.get(&probe, 5).is_some());

        // Wait past the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert!(cache.get(&probe, 5).is_none());
    }

    /// One miss followed by two hits is reflected in the counters and the hit rate.
    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(QueryCacheConfig::default());

        let probe = Vector::new(vec![1.0, 2.0, 3.0]);
        let stored = vec![("uri1".to_string(), 0.9)];

        // Miss.
        cache.get(&probe, 5);

        // Two hits after the insert.
        cache.put(&probe, 5, stored);
        cache.get(&probe, 5);
        cache.get(&probe, 5);

        let snapshot = cache.get_stats();
        assert_eq!(snapshot.total_queries, 3);
        assert_eq!(snapshot.cache_hits, 2);
        assert_eq!(snapshot.cache_misses, 1);
        assert_eq!(snapshot.hit_rate(), 2.0 / 3.0);
    }

    /// `cleanup_expired` removes every stale entry and reports how many it removed.
    #[test]
    fn test_query_cache_cleanup() {
        let cache = QueryCache::new(QueryCacheConfig {
            ttl: Duration::from_millis(100),
            ..Default::default()
        });

        for i in 0..5 {
            let probe = Vector::new(vec![i as f32, 0.0, 0.0]);
            cache.put(&probe, 5, vec![(format!("uri{}", i), 0.9)]);
        }

        assert_eq!(cache.len(), 5);

        // Let all five entries outlive the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        let removed = cache.cleanup_expired();
        assert_eq!(removed, 5);
        assert_eq!(cache.len(), 0);
    }

    /// The same query with different `k` values occupies separate cache entries.
    #[test]
    fn test_query_cache_different_k() {
        let cache = QueryCache::new(QueryCacheConfig::default());

        let probe = Vector::new(vec![1.0, 2.0, 3.0]);
        let top5 = vec![("uri1".to_string(), 0.9)];
        let top10 = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];

        cache.put(&probe, 5, top5);
        cache.put(&probe, 10, top10);

        let fetched_k5 = cache.get(&probe, 5).unwrap();
        let fetched_k10 = cache.get(&probe, 10).unwrap();

        assert_eq!(fetched_k5.len(), 1);
        assert_eq!(fetched_k10.len(), 2);
    }
}