1use crate::Vector;
7use blake3::Hasher;
8use lru::LruCache;
9use parking_lot::RwLock;
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// A single cache entry: the stored result list plus bookkeeping about its age and use.
#[derive(Debug, Clone)]
struct CachedResult {
    /// The `(String, f32)` pairs exactly as they were handed to the cache.
    results: Vec<(String, f32)>,
    /// Moment at which this entry was created; used to decide expiry.
    cached_at: Instant,
    /// How many times `record_hit` has been called on this entry.
    hit_count: usize,
}

impl CachedResult {
    /// Wraps `results` in a fresh entry stamped with the current time and zero hits.
    fn new(results: Vec<(String, f32)>) -> Self {
        let cached_at = Instant::now();
        CachedResult {
            results,
            cached_at,
            hit_count: 0,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Bumps the per-entry hit counter by one.
    fn record_hit(&mut self) {
        self.hit_count += 1;
    }
}
42
/// Settings that control how a `QueryCache` behaves.
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Capacity of the underlying LRU cache; `QueryCache::new` requires it to be non-zero.
    pub max_entries: usize,
    /// Entries older than this are treated as expired.
    pub ttl: Duration,
    /// Fuzzy-matching switch. NOTE(review): not read by any code in this file.
    pub enable_fuzzy_matching: bool,
    /// Fuzzy-matching threshold. NOTE(review): not read by any code in this file.
    pub fuzzy_threshold: f32,
    /// When `false`, the cache does not update its statistics counters.
    pub enable_stats: bool,
}

impl Default for QueryCacheConfig {
    /// Defaults: 10000 entries, a 300-second TTL, fuzzy matching off
    /// (threshold 0.95), and statistics on.
    fn default() -> Self {
        const DEFAULT_TTL: Duration = Duration::from_secs(300);

        QueryCacheConfig {
            max_entries: 10000,
            ttl: DEFAULT_TTL,
            enable_fuzzy_matching: false,
            fuzzy_threshold: 0.95,
            enable_stats: true,
        }
    }
}
69
/// Counters describing how a `QueryCache` has been used.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Number of lookups performed.
    pub total_queries: u64,
    /// Lookups that returned a cached result.
    pub cache_hits: u64,
    /// Lookups that returned nothing (absent or expired entry).
    pub cache_misses: u64,
    /// Insertions that the cache counted as displacing an entry.
    pub evictions: u64,
    /// Entries removed because they outlived the configured TTL.
    pub expirations: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, avoiding a division by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
89
90pub struct QueryCache {
92 cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
94 config: QueryCacheConfig,
96 stats: Arc<RwLock<QueryCacheStats>>,
98}
99
100impl QueryCache {
101 pub fn new(config: QueryCacheConfig) -> Self {
103 let capacity =
104 NonZeroUsize::new(config.max_entries).expect("cache max_entries must be non-zero");
105 Self {
106 cache: Arc::new(RwLock::new(LruCache::new(capacity))),
107 config,
108 stats: Arc::new(RwLock::new(QueryCacheStats::default())),
109 }
110 }
111
112 fn generate_key(&self, query: &Vector, k: usize) -> u64 {
114 let mut hasher = Hasher::new();
115
116 let query_f32 = query.as_f32();
118 for &val in &query_f32 {
119 hasher.update(&val.to_le_bytes());
120 }
121
122 hasher.update(&k.to_le_bytes());
124
125 let hash = hasher.finalize();
127 let hash_bytes = hash.as_bytes();
128 u64::from_le_bytes([
129 hash_bytes[0],
130 hash_bytes[1],
131 hash_bytes[2],
132 hash_bytes[3],
133 hash_bytes[4],
134 hash_bytes[5],
135 hash_bytes[6],
136 hash_bytes[7],
137 ])
138 }
139
140 pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
142 if self.config.enable_stats {
143 let mut stats = self.stats.write();
144 stats.total_queries += 1;
145 }
146
147 let key = self.generate_key(query, k);
148 let mut cache = self.cache.write();
149
150 if let Some(cached) = cache.get_mut(&key) {
151 if cached.is_expired(self.config.ttl) {
153 cache.pop(&key);
154 if self.config.enable_stats {
155 let mut stats = self.stats.write();
156 stats.expirations += 1;
157 stats.cache_misses += 1;
158 }
159 return None;
160 }
161
162 cached.record_hit();
164 if self.config.enable_stats {
165 let mut stats = self.stats.write();
166 stats.cache_hits += 1;
167 }
168 return Some(cached.results.clone());
169 }
170
171 if self.config.enable_stats {
172 let mut stats = self.stats.write();
173 stats.cache_misses += 1;
174 }
175 None
176 }
177
178 pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
180 let key = self.generate_key(query, k);
181 let mut cache = self.cache.write();
182
183 let cached_result = CachedResult::new(results);
184
185 if cache.len() >= self.config.max_entries && self.config.enable_stats {
187 let mut stats = self.stats.write();
188 stats.evictions += 1;
189 }
190
191 cache.put(key, cached_result);
192 }
193
194 pub fn clear(&self) {
196 let mut cache = self.cache.write();
197 cache.clear();
198 }
199
200 pub fn get_stats(&self) -> QueryCacheStats {
202 self.stats.read().clone()
203 }
204
205 pub fn reset_stats(&self) {
207 let mut stats = self.stats.write();
208 *stats = QueryCacheStats::default();
209 }
210
211 pub fn len(&self) -> usize {
213 self.cache.read().len()
214 }
215
216 pub fn is_empty(&self) -> bool {
218 self.cache.read().is_empty()
219 }
220
221 pub fn cleanup_expired(&self) -> usize {
223 let mut cache = self.cache.write();
224 let mut expired_keys = Vec::new();
225
226 for (key, cached) in cache.iter() {
228 if cached.is_expired(self.config.ttl) {
229 expired_keys.push(*key);
230 }
231 }
232
233 let count = expired_keys.len();
235 for key in expired_keys {
236 cache.pop(&key);
237 }
238
239 if self.config.enable_stats && count > 0 {
240 let mut stats = self.stats.write();
241 stats.expirations += count as u64;
242 }
243
244 count
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// The three-component query vector shared by most tests.
    fn sample_query() -> Vector {
        Vector::new(vec![1.0, 2.0, 3.0])
    }

    /// A cache with default settings.
    fn default_cache() -> QueryCache {
        QueryCache::new(QueryCacheConfig::default())
    }

    /// A cache with default settings except for a TTL of `millis` milliseconds.
    fn cache_with_ttl_millis(millis: u64) -> QueryCache {
        QueryCache::new(QueryCacheConfig {
            ttl: Duration::from_millis(millis),
            ..Default::default()
        })
    }

    /// A miss before insertion, then a hit returning the stored pairs.
    #[test]
    fn test_query_cache_basic() -> Result<()> {
        let cache = default_cache();
        let query = sample_query();

        // Nothing has been stored yet.
        assert!(cache.get(&query, 5).is_none());

        let stored = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];
        cache.put(&query, 5, stored);

        let fetched = cache.get(&query, 5).expect("cache should have results");
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].0, "uri1");
        assert_eq!(fetched[0].1, 0.9);
        Ok(())
    }

    /// An entry is served before its TTL elapses and gone afterwards.
    #[test]
    fn test_query_cache_expiration() {
        let cache = cache_with_ttl_millis(100);
        let query = sample_query();

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);

        // Still fresh.
        assert!(cache.get(&query, 5).is_some());

        // Wait past the TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert!(cache.get(&query, 5).is_none());
    }

    /// One miss followed by two hits is reflected in the counters and hit rate.
    #[test]
    fn test_query_cache_stats() {
        let cache = default_cache();
        let query = sample_query();

        // First lookup misses.
        cache.get(&query, 5);

        // After insertion, both lookups hit.
        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        cache.get(&query, 5);
        cache.get(&query, 5);

        let snapshot = cache.get_stats();
        assert_eq!(snapshot.total_queries, 3);
        assert_eq!(snapshot.cache_hits, 2);
        assert_eq!(snapshot.cache_misses, 1);
        assert_eq!(snapshot.hit_rate(), 2.0 / 3.0);
    }

    /// `cleanup_expired` removes every entry once the TTL has elapsed.
    #[test]
    fn test_query_cache_cleanup() {
        let cache = cache_with_ttl_millis(100);

        // Five distinct queries produce five distinct entries.
        for i in 0..5 {
            let query = Vector::new(vec![i as f32, 0.0, 0.0]);
            cache.put(&query, 5, vec![(format!("uri{}", i), 0.9)]);
        }
        assert_eq!(cache.len(), 5);

        std::thread::sleep(Duration::from_millis(150));

        let removed = cache.cleanup_expired();
        assert_eq!(removed, 5);
        assert_eq!(cache.len(), 0);
    }

    /// The same query with different `k` values is cached under separate keys.
    #[test]
    fn test_query_cache_different_k() -> Result<()> {
        let cache = default_cache();
        let query = sample_query();

        cache.put(&query, 5, vec![("uri1".to_string(), 0.9)]);
        cache.put(
            &query,
            10,
            vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)],
        );

        let for_k5 = cache.get(&query, 5).expect("cache k5 should have results");
        let for_k10 = cache
            .get(&query, 10)
            .expect("cache k10 should have results");

        assert_eq!(for_k5.len(), 1);
        assert_eq!(for_k10.len(), 2);
        Ok(())
    }
}