1use crate::types::{DistanceMetric, SearchResult};
47use serde::{Deserialize, Serialize};
48use std::collections::{HashMap, VecDeque};
49use std::hash::{Hash, Hasher};
50use std::sync::{Arc, RwLock};
51use std::time::{Duration, Instant};
52
/// Tuning knobs for [`QueryCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum number of entries held; inserting a new key into a full cache evicts the
    /// least recently used entry first.
    pub max_entries: usize,
    /// How long an entry stays valid, measured from the moment it was inserted
    /// (reading an entry does not extend it).
    pub ttl: Duration,
    /// Minimum cosine similarity (inclusive) a cached query must have with a new query
    /// for its results to be reused when approximate matching is enabled.
    pub similarity_threshold: f32,
    /// Whether a lookup with no exact entry may fall back to the most similar cached query.
    pub enable_approximate_matching: bool,
}
65
66impl Default for CacheConfig {
67 fn default() -> Self {
68 Self {
69 max_entries: 1000,
70 ttl: Duration::from_secs(300), similarity_threshold: 0.99, enable_approximate_matching: false,
73 }
74 }
75}
76
77impl CacheConfig {
78 pub fn high_hit_rate() -> Self {
80 Self {
81 max_entries: 10_000,
82 ttl: Duration::from_secs(3600), similarity_threshold: 0.95,
84 enable_approximate_matching: true,
85 }
86 }
87
88 pub fn low_memory() -> Self {
90 Self {
91 max_entries: 100,
92 ttl: Duration::from_secs(60), similarity_threshold: 0.99,
94 enable_approximate_matching: false,
95 }
96 }
97
98 pub fn exact_match_only() -> Self {
100 Self {
101 max_entries: 1000,
102 ttl: Duration::from_secs(300),
103 similarity_threshold: 1.0,
104 enable_approximate_matching: false,
105 }
106 }
107}
108
109#[derive(Debug, Clone, PartialEq)]
111struct CacheKey {
112 query_hash: u64,
113 metric: DistanceMetric,
114 k: usize,
115}
116
117impl Hash for CacheKey {
118 fn hash<H: Hasher>(&self, state: &mut H) {
119 self.query_hash.hash(state);
120 std::mem::discriminant(&self.metric).hash(state);
122 self.k.hash(state);
123 }
124}
125
126impl Eq for CacheKey {}
127
128#[derive(Debug, Clone)]
130struct CacheEntry {
131 results: Vec<SearchResult>,
132 inserted_at: Instant,
133 last_accessed: Instant,
134 access_count: u64,
135 query: Vec<f32>, }
137
138impl CacheEntry {
139 fn new(query: Vec<f32>, results: Vec<SearchResult>) -> Self {
140 let now = Instant::now();
141 Self {
142 results,
143 inserted_at: now,
144 last_accessed: now,
145 access_count: 0,
146 query,
147 }
148 }
149
150 fn is_expired(&self, ttl: Duration) -> bool {
151 self.inserted_at.elapsed() > ttl
152 }
153
154 fn touch(&mut self) {
155 self.last_accessed = Instant::now();
156 self.access_count += 1;
157 }
158}
159
160pub struct QueryCache {
162 config: CacheConfig,
163 cache: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
164 access_order: Arc<RwLock<VecDeque<CacheKey>>>,
165 stats: Arc<RwLock<CacheStats>>,
166}
167
168impl QueryCache {
169 pub fn new(config: CacheConfig) -> Self {
171 Self {
172 config,
173 cache: Arc::new(RwLock::new(HashMap::new())),
174 access_order: Arc::new(RwLock::new(VecDeque::new())),
175 stats: Arc::new(RwLock::new(CacheStats::default())),
176 }
177 }
178
179 pub fn get(
183 &self,
184 query: &[f32],
185 metric: DistanceMetric,
186 k: usize,
187 ) -> Option<Vec<SearchResult>> {
188 let key = self.make_key(query, metric, k);
189
190 if let Some(entry) = self.get_exact(&key) {
192 return Some(entry);
193 }
194
195 if self.config.enable_approximate_matching {
197 if let Some(entry) = self.get_approximate(query, metric, k) {
198 return Some(entry);
199 }
200 }
201
202 if let Ok(mut stats) = self.stats.write() {
204 stats.misses += 1;
205 }
206
207 None
208 }
209
210 fn get_exact(&self, key: &CacheKey) -> Option<Vec<SearchResult>> {
212 let mut cache = self.cache.write().ok()?;
213 let mut access_order = self.access_order.write().ok()?;
214
215 if let Some(entry) = cache.get_mut(key) {
216 if entry.is_expired(self.config.ttl) {
218 cache.remove(key);
219 access_order.retain(|k| k != key);
220 if let Ok(mut stats) = self.stats.write() {
221 stats.expirations += 1;
222 }
223 return None;
224 }
225
226 entry.touch();
228
229 access_order.retain(|k| k != key);
231 access_order.push_back(key.clone());
232
233 if let Ok(mut stats) = self.stats.write() {
235 stats.hits += 1;
236 }
237
238 return Some(entry.results.clone());
239 }
240
241 None
242 }
243
244 fn get_approximate(
246 &self,
247 query: &[f32],
248 metric: DistanceMetric,
249 k: usize,
250 ) -> Option<Vec<SearchResult>> {
251 let best_key = {
252 let cache = self.cache.read().ok()?;
253
254 let mut best_match: Option<(CacheKey, f32)> = None;
256
257 for (cache_key, entry) in cache.iter() {
258 if cache_key.metric != metric || cache_key.k != k {
260 continue;
261 }
262
263 if entry.is_expired(self.config.ttl) {
265 continue;
266 }
267
268 let similarity = cosine_similarity(&entry.query, query);
270
271 if similarity >= self.config.similarity_threshold {
272 if let Some((_, best_sim)) = &best_match {
273 if similarity > *best_sim {
274 best_match = Some((cache_key.clone(), similarity));
275 }
276 } else {
277 best_match = Some((cache_key.clone(), similarity));
278 }
279 }
280 }
281
282 best_match.map(|(key, _)| key)
283 }; if let Some(key) = best_key {
286 return self.get_exact(&key);
287 }
288
289 None
290 }
291
292 pub fn put(
294 &mut self,
295 query: &[f32],
296 metric: DistanceMetric,
297 k: usize,
298 results: Vec<SearchResult>,
299 ) {
300 let key = self.make_key(query, metric, k);
301 let entry = CacheEntry::new(query.to_vec(), results);
302
303 let mut cache = match self.cache.write() {
304 Ok(c) => c,
305 Err(_) => return,
306 };
307
308 let mut access_order = match self.access_order.write() {
309 Ok(a) => a,
310 Err(_) => return,
311 };
312
313 if cache.len() >= self.config.max_entries && !cache.contains_key(&key) {
315 if let Some(oldest_key) = access_order.pop_front() {
316 cache.remove(&oldest_key);
317 if let Ok(mut stats) = self.stats.write() {
318 stats.evictions += 1;
319 }
320 }
321 }
322
323 cache.insert(key.clone(), entry);
324 access_order.push_back(key);
325
326 if let Ok(mut stats) = self.stats.write() {
327 stats.inserts += 1;
328 }
329 }
330
331 pub fn clear(&mut self) {
333 if let Ok(mut cache) = self.cache.write() {
334 cache.clear();
335 }
336 if let Ok(mut access_order) = self.access_order.write() {
337 access_order.clear();
338 }
339 if let Ok(mut stats) = self.stats.write() {
340 *stats = CacheStats::default();
341 }
342 }
343
344 pub fn evict_expired(&mut self) -> usize {
346 let mut cache = match self.cache.write() {
347 Ok(c) => c,
348 Err(_) => return 0,
349 };
350
351 let mut access_order = match self.access_order.write() {
352 Ok(a) => a,
353 Err(_) => return 0,
354 };
355
356 let mut expired_keys = Vec::new();
357
358 for (key, entry) in cache.iter() {
359 if entry.is_expired(self.config.ttl) {
360 expired_keys.push(key.clone());
361 }
362 }
363
364 let count = expired_keys.len();
365
366 for key in expired_keys {
367 cache.remove(&key);
368 access_order.retain(|k| k != &key);
369 }
370
371 if let Ok(mut stats) = self.stats.write() {
372 stats.expirations += count as u64;
373 }
374
375 count
376 }
377
378 pub fn stats(&self) -> CacheStats {
380 self.stats.read().unwrap().clone()
381 }
382
383 pub fn len(&self) -> usize {
385 self.cache.read().unwrap().len()
386 }
387
388 pub fn is_empty(&self) -> bool {
390 self.len() == 0
391 }
392
393 fn make_key(&self, query: &[f32], metric: DistanceMetric, k: usize) -> CacheKey {
395 CacheKey {
396 query_hash: hash_f32_slice(query),
397 metric,
398 k,
399 }
400 }
401}
402
/// Counters describing how a [`QueryCache`] has been used.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Lookups answered from the cache (exact or approximate match).
    pub hits: u64,
    /// Lookups that returned no results.
    pub misses: u64,
    /// Entries stored via `QueryCache::put`.
    pub inserts: u64,
    /// Entries removed to make room for a new key.
    pub evictions: u64,
    /// Entries dropped because their TTL had passed (on lookup or via `evict_expired`).
    pub expirations: u64,
}
417
418impl CacheStats {
419 pub fn hit_rate(&self) -> f64 {
421 let total = self.hits + self.misses;
422 if total == 0 {
423 0.0
424 } else {
425 (self.hits as f64 / total as f64) * 100.0
426 }
427 }
428
429 pub fn miss_rate(&self) -> f64 {
431 100.0 - self.hit_rate()
432 }
433}
434
/// Fingerprints a query vector by its exact bit pattern: the length first, then every
/// element's `to_bits()`.
///
/// Consequently `0.0` and `-0.0` hash differently, while NaNs with identical bits hash
/// equally. `DefaultHasher::new` is only guaranteed stable within one build of the standard
/// library, so these values should not be persisted.
fn hash_f32_slice(slice: &[f32]) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    slice.len().hash(&mut state);
    slice
        .iter()
        .map(|x| x.to_bits())
        .for_each(|bits| bits.hash(&mut state));
    state.finish()
}
451
/// Cosine of the angle between `a` and `b`: `a·b / (|a| |b|)`.
///
/// Returns `0.0` when the slices differ in length or either vector has zero norm
/// (which includes empty slices). Accumulation is in `f32`, element by element in order.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let degenerate = norm_a == 0.0 || norm_b == 0.0;
    if degenerate {
        0.0
    } else {
        dot / (norm_a.sqrt() * norm_b.sqrt())
    }
}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// One-element result list used wherever a test just needs something to cache.
    fn sample_results() -> Vec<SearchResult> {
        vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.95,
            distance: 0.05,
            rank: 1,
        }]
    }

    /// Fresh cache using the default configuration.
    fn default_cache() -> QueryCache {
        QueryCache::new(CacheConfig::default())
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.ttl, Duration::from_secs(300));
        assert!(!config.enable_approximate_matching);
    }

    #[test]
    fn test_cache_config_presets() {
        let high_hit = CacheConfig::high_hit_rate();
        assert_eq!(high_hit.max_entries, 10_000);
        assert!(high_hit.enable_approximate_matching);

        let low_mem = CacheConfig::low_memory();
        assert_eq!(low_mem.max_entries, 100);
        assert_eq!(low_mem.ttl, Duration::from_secs(60));

        let exact = CacheConfig::exact_match_only();
        assert_eq!(exact.similarity_threshold, 1.0);
        assert!(!exact.enable_approximate_matching);
    }

    #[test]
    fn test_query_cache_basic() {
        let mut cache = default_cache();
        let query = vec![1.0, 2.0, 3.0];

        assert!(cache.is_empty());

        cache.put(&query, DistanceMetric::Cosine, 10, sample_results());
        assert_eq!(cache.len(), 1);

        let cached = cache.get(&query, DistanceMetric::Cosine, 10);
        assert_eq!(cached.map(|hits| hits.len()), Some(1));
    }

    #[test]
    fn test_query_cache_miss() {
        let cache = default_cache();
        let query = vec![1.0, 2.0, 3.0];

        assert!(cache.get(&query, DistanceMetric::Cosine, 10).is_none());

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses), (0, 1));
    }

    #[test]
    fn test_query_cache_different_k() {
        let mut cache = default_cache();
        let query = vec![1.0, 2.0, 3.0];
        cache.put(&query, DistanceMetric::Cosine, 10, sample_results());

        // Same query, different k: must not be served from the k = 10 entry.
        assert!(cache.get(&query, DistanceMetric::Cosine, 20).is_none());
    }

    #[test]
    fn test_query_cache_different_metric() {
        let mut cache = default_cache();
        let query = vec![1.0, 2.0, 3.0];
        cache.put(&query, DistanceMetric::Cosine, 10, sample_results());

        // Same query, different metric: must not be served from the cosine entry.
        assert!(cache.get(&query, DistanceMetric::Euclidean, 10).is_none());
    }

    #[test]
    fn test_query_cache_lru_eviction() {
        let mut cache = QueryCache::new(CacheConfig {
            max_entries: 2,
            ..Default::default()
        });

        cache.put(&[1.0], DistanceMetric::Cosine, 10, sample_results());
        cache.put(&[2.0], DistanceMetric::Cosine, 10, sample_results());
        // Third insert exceeds capacity and should push out the oldest entry ([1.0]).
        cache.put(&[3.0], DistanceMetric::Cosine, 10, sample_results());

        assert_eq!(cache.len(), 2);
        assert!(cache.get(&[1.0], DistanceMetric::Cosine, 10).is_none());
        assert!(cache.get(&[2.0], DistanceMetric::Cosine, 10).is_some());
        assert!(cache.get(&[3.0], DistanceMetric::Cosine, 10).is_some());
    }

    #[test]
    fn test_query_cache_clear() {
        let mut cache = default_cache();
        let query = vec![1.0, 2.0, 3.0];

        cache.put(&query, DistanceMetric::Cosine, 10, sample_results());
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_query_cache_stats() {
        let mut cache = default_cache();
        let query = vec![1.0, 2.0, 3.0];

        cache.put(&query, DistanceMetric::Cosine, 10, sample_results());
        assert_eq!(cache.stats().inserts, 1);

        let _ = cache.get(&query, DistanceMetric::Cosine, 10);
        assert_eq!(cache.stats().hits, 1);

        let _ = cache.get(&[9.0], DistanceMetric::Cosine, 10);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);

        assert_eq!(stats.hit_rate(), 50.0);
        assert_eq!(stats.miss_rate(), 50.0);
    }

    #[test]
    fn test_hash_f32_slice() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let c = vec![1.0, 2.0, 3.1];

        assert_eq!(hash_f32_slice(&a), hash_f32_slice(&b));
        assert_ne!(hash_f32_slice(&a), hash_f32_slice(&c));
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.01);
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            inserts: 100,
            ..Default::default()
        };

        assert_eq!(stats.hit_rate(), 75.0);
        assert_eq!(stats.miss_rate(), 25.0);
    }

    #[test]
    fn test_cache_entry_expiration() {
        let entry = CacheEntry::new(vec![1.0, 2.0, 3.0], sample_results());

        assert!(!entry.is_expired(Duration::from_secs(1)));

        std::thread::sleep(Duration::from_millis(10));
        assert!(entry.is_expired(Duration::from_millis(1)));
    }

    #[test]
    fn test_approximate_matching_disabled() {
        let mut cache = QueryCache::new(CacheConfig::exact_match_only());
        let query1 = vec![1.0, 0.0, 0.0];
        // Nearly identical direction, but exact-only mode must not reuse query1's results.
        let query2 = vec![0.99, 0.01, 0.0];

        cache.put(&query1, DistanceMetric::Cosine, 10, sample_results());

        assert!(cache.get(&query2, DistanceMetric::Cosine, 10).is_none());
    }

    #[test]
    fn test_approximate_matching_enabled() {
        let mut cache = QueryCache::new(CacheConfig {
            enable_approximate_matching: true,
            similarity_threshold: 0.95,
            ..Default::default()
        });
        let query1 = vec![1.0, 0.0, 0.0];
        // Cosine similarity with query1 is about 0.99, above the 0.95 threshold.
        let query2 = vec![0.99, 0.14, 0.0];

        cache.put(&query1, DistanceMetric::Cosine, 10, sample_results());

        assert!(cache.get(&query2, DistanceMetric::Cosine, 10).is_some());
    }
}