1use crate::Vector;
7use blake3::Hasher;
8use lru::LruCache;
9use parking_lot::RwLock;
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// One memoised search answer plus the bookkeeping needed to age it out.
#[derive(Clone, Debug)]
struct CachedResult {
    /// Result pairs exactly as they were handed to `QueryCache::put`
    /// (presumably `(uri, score)`, as in the tests — confirm with callers).
    results: Vec<(String, f32)>,
    /// Creation time of the entry; compared against the configured TTL.
    cached_at: Instant,
    /// Number of successful lookups served from this entry.
    hit_count: usize,
}
24
25impl CachedResult {
26 fn new(results: Vec<(String, f32)>) -> Self {
27 Self {
28 results,
29 cached_at: Instant::now(),
30 hit_count: 0,
31 }
32 }
33
34 fn is_expired(&self, ttl: Duration) -> bool {
35 self.cached_at.elapsed() > ttl
36 }
37
38 fn record_hit(&mut self) {
39 self.hit_count += 1;
40 }
41}
42
/// Tuning knobs for [`QueryCache`].
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Capacity of the underlying LRU; once full, inserting a new key evicts
    /// the least-recently-used entry.
    pub max_entries: usize,
    /// How long an entry remains valid after it was inserted.
    pub ttl: Duration,
    /// Switch intended for approximate (non-exact) query matching.
    ///
    /// NOTE(review): `QueryCache` in this module never reads this flag; lookups
    /// are keyed on an exact hash of the query bytes.
    pub enable_fuzzy_matching: bool,
    /// Similarity cutoff intended for fuzzy matching; likewise not read by
    /// `QueryCache` in this module.
    pub fuzzy_threshold: f32,
    /// When `false`, `QueryCache` skips every update to its [`QueryCacheStats`].
    pub enable_stats: bool,
}
57
58impl Default for QueryCacheConfig {
59 fn default() -> Self {
60 Self {
61 max_entries: 10000,
62 ttl: Duration::from_secs(300), enable_fuzzy_matching: false,
64 fuzzy_threshold: 0.95,
65 enable_stats: true,
66 }
67 }
68}
69
/// Running counters describing how a [`QueryCache`] has been used.
///
/// All counters start at zero (via `Default`).
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Lookups performed through `QueryCache::get`.
    pub total_queries: u64,
    /// Lookups answered from a live (non-expired) entry.
    pub cache_hits: u64,
    /// Lookups that found nothing, or found only an expired entry.
    pub cache_misses: u64,
    /// Entries pushed out by LRU capacity pressure.
    pub evictions: u64,
    /// Entries dropped because their TTL had elapsed.
    pub expirations: u64,
}
79
80impl QueryCacheStats {
81 pub fn hit_rate(&self) -> f64 {
82 if self.total_queries == 0 {
83 0.0
84 } else {
85 self.cache_hits as f64 / self.total_queries as f64
86 }
87 }
88}
89
90pub struct QueryCache {
92 cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
94 config: QueryCacheConfig,
96 stats: Arc<RwLock<QueryCacheStats>>,
98}
99
100impl QueryCache {
101 pub fn new(config: QueryCacheConfig) -> Self {
103 let capacity = NonZeroUsize::new(config.max_entries).unwrap();
104 Self {
105 cache: Arc::new(RwLock::new(LruCache::new(capacity))),
106 config,
107 stats: Arc::new(RwLock::new(QueryCacheStats::default())),
108 }
109 }
110
111 fn generate_key(&self, query: &Vector, k: usize) -> u64 {
113 let mut hasher = Hasher::new();
114
115 let query_f32 = query.as_f32();
117 for &val in &query_f32 {
118 hasher.update(&val.to_le_bytes());
119 }
120
121 hasher.update(&k.to_le_bytes());
123
124 let hash = hasher.finalize();
126 let hash_bytes = hash.as_bytes();
127 u64::from_le_bytes([
128 hash_bytes[0],
129 hash_bytes[1],
130 hash_bytes[2],
131 hash_bytes[3],
132 hash_bytes[4],
133 hash_bytes[5],
134 hash_bytes[6],
135 hash_bytes[7],
136 ])
137 }
138
139 pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
141 if self.config.enable_stats {
142 let mut stats = self.stats.write();
143 stats.total_queries += 1;
144 }
145
146 let key = self.generate_key(query, k);
147 let mut cache = self.cache.write();
148
149 if let Some(cached) = cache.get_mut(&key) {
150 if cached.is_expired(self.config.ttl) {
152 cache.pop(&key);
153 if self.config.enable_stats {
154 let mut stats = self.stats.write();
155 stats.expirations += 1;
156 stats.cache_misses += 1;
157 }
158 return None;
159 }
160
161 cached.record_hit();
163 if self.config.enable_stats {
164 let mut stats = self.stats.write();
165 stats.cache_hits += 1;
166 }
167 return Some(cached.results.clone());
168 }
169
170 if self.config.enable_stats {
171 let mut stats = self.stats.write();
172 stats.cache_misses += 1;
173 }
174 None
175 }
176
177 pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
179 let key = self.generate_key(query, k);
180 let mut cache = self.cache.write();
181
182 let cached_result = CachedResult::new(results);
183
184 if cache.len() >= self.config.max_entries && self.config.enable_stats {
186 let mut stats = self.stats.write();
187 stats.evictions += 1;
188 }
189
190 cache.put(key, cached_result);
191 }
192
193 pub fn clear(&self) {
195 let mut cache = self.cache.write();
196 cache.clear();
197 }
198
199 pub fn get_stats(&self) -> QueryCacheStats {
201 self.stats.read().clone()
202 }
203
204 pub fn reset_stats(&self) {
206 let mut stats = self.stats.write();
207 *stats = QueryCacheStats::default();
208 }
209
210 pub fn len(&self) -> usize {
212 self.cache.read().len()
213 }
214
215 pub fn is_empty(&self) -> bool {
217 self.cache.read().is_empty()
218 }
219
220 pub fn cleanup_expired(&self) -> usize {
222 let mut cache = self.cache.write();
223 let mut expired_keys = Vec::new();
224
225 for (key, cached) in cache.iter() {
227 if cached.is_expired(self.config.ttl) {
228 expired_keys.push(*key);
229 }
230 }
231
232 let count = expired_keys.len();
234 for key in expired_keys {
235 cache.pop(&key);
236 }
237
238 if self.config.enable_stats && count > 0 {
239 let mut stats = self.stats.write();
240 stats.expirations += count as u64;
241 }
242
243 count
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// A cache with default settings except for the entry lifetime.
    fn cache_with_ttl(ttl: Duration) -> QueryCache {
        QueryCache::new(QueryCacheConfig {
            ttl,
            ..Default::default()
        })
    }

    /// The three-component query vector shared by most tests.
    fn sample_query() -> Vector {
        Vector::new(vec![1.0, 2.0, 3.0])
    }

    /// A one-element result list for tests where the payload is irrelevant.
    fn single_hit() -> Vec<(String, f32)> {
        vec![("uri1".to_string(), 0.9)]
    }

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let query = sample_query();

        // Empty cache: the first lookup must miss.
        assert!(cache.get(&query, 5).is_none());

        cache.put(
            &query,
            5,
            vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)],
        );

        let hit = cache.get(&query, 5).expect("entry was just inserted");
        assert_eq!(hit.len(), 2);
        let (first_uri, first_score) = &hit[0];
        assert_eq!(first_uri, "uri1");
        assert_eq!(*first_score, 0.9);
    }

    #[test]
    fn test_query_cache_expiration() {
        let cache = cache_with_ttl(Duration::from_millis(100));
        let query = sample_query();
        cache.put(&query, 5, single_hit());

        // Still within the TTL.
        assert!(cache.get(&query, 5).is_some());

        // Outlive the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert!(cache.get(&query, 5).is_none());
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let query = sample_query();

        cache.get(&query, 5); // miss
        cache.put(&query, 5, single_hit());
        for _ in 0..2 {
            cache.get(&query, 5); // hit
        }

        let stats = cache.get_stats();
        assert_eq!(stats.total_queries, 3);
        assert_eq!(stats.cache_hits, 2);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.hit_rate(), 2.0 / 3.0);
    }

    #[test]
    fn test_query_cache_cleanup() {
        let cache = cache_with_ttl(Duration::from_millis(100));

        for i in 0_u8..5 {
            let query = Vector::new(vec![f32::from(i), 0.0, 0.0]);
            cache.put(&query, 5, vec![(format!("uri{}", i), 0.9)]);
        }
        assert_eq!(cache.len(), 5);

        std::thread::sleep(Duration::from_millis(150));

        assert_eq!(cache.cleanup_expired(), 5);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_query_cache_different_k() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let query = sample_query();

        // Same vector, different k: these must be stored as separate entries.
        cache.put(&query, 5, single_hit());
        cache.put(
            &query,
            10,
            vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)],
        );

        let five = cache.get(&query, 5).expect("k = 5 entry present");
        let ten = cache.get(&query, 10).expect("k = 10 entry present");
        assert_eq!(five.len(), 1);
        assert_eq!(ten.len(), 2);
    }
}