1use crate::retrieval::engine::QueryResponse;
6use ahash::AHashMap;
7use parking_lot::RwLock;
8use std::collections::VecDeque;
9use std::hash::{Hash, Hasher};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12
/// Tuning knobs for [`QueryCache`].
///
/// Start from [`CacheConfig::new`] (equivalently [`Default::default`]) and refine
/// it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of cached responses.
    pub max_entries: usize,
    /// How long an entry remains valid after it was inserted.
    pub ttl: Duration,
    /// Whether queries that carry a filter should be cached.
    ///
    /// NOTE(review): `QueryCache` never reads this flag itself; presumably the caller
    /// consults it before building a filtered `CacheKey` — confirm.
    pub cache_filtered: bool,
}

impl Default for CacheConfig {
    /// 10 000 entries, a 300-second TTL, and `cache_filtered` enabled.
    fn default() -> Self {
        Self::new()
    }
}

impl CacheConfig {
    /// Returns the default configuration: 10 000 entries, 300 s TTL, filtered
    /// queries cached.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_entries: 10_000,
            ttl: Duration::from_secs(300),
            cache_filtered: true,
        }
    }

    /// Replaces the maximum number of entries.
    #[must_use]
    pub const fn with_max_entries(self, max: usize) -> Self {
        Self {
            max_entries: max,
            ..self
        }
    }

    /// Replaces the time-to-live applied to every entry.
    #[must_use]
    pub const fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Replaces the `cache_filtered` flag.
    #[must_use]
    pub const fn with_cache_filtered(self, cache: bool) -> Self {
        Self {
            cache_filtered: cache,
            ..self
        }
    }
}
62
63#[derive(Debug, Clone, PartialEq, Eq, Hash)]
65pub struct CacheKey {
66 embedding_hash: u64,
68 k: usize,
70 filter_hash: u64,
72 indexes_hash: u64,
74}
75
76impl CacheKey {
77 #[must_use]
79 pub fn new(
80 embedding: &[f32],
81 k: usize,
82 filter_hash: Option<u64>,
83 indexes: Option<&[String]>,
84 ) -> Self {
85 Self {
86 embedding_hash: Self::hash_embedding(embedding),
87 k,
88 filter_hash: filter_hash.unwrap_or(0),
89 indexes_hash: indexes.map(Self::hash_indexes).unwrap_or(0),
90 }
91 }
92
93 fn hash_embedding(embedding: &[f32]) -> u64 {
95 let mut hasher = xxhash_rust::xxh64::Xxh64::new(0);
96
97 for &value in embedding {
98 hasher.write(&value.to_le_bytes());
99 }
100
101 hasher.finish()
102 }
103
104 fn hash_indexes(indexes: &[String]) -> u64 {
106 let mut hasher = xxhash_rust::xxh64::Xxh64::new(0);
107
108 for name in indexes {
109 hasher.write(name.as_bytes());
110 }
111
112 hasher.finish()
113 }
114}
115
116#[derive(Debug, Clone)]
118pub struct CacheEntry {
119 pub response: QueryResponse,
121 pub created_at: Instant,
123 pub access_count: u64,
125}
126
127impl CacheEntry {
128 #[must_use]
130 pub fn is_expired(&self, ttl: Duration) -> bool {
131 self.created_at.elapsed() > ttl
132 }
133}
134
/// Point-in-time snapshot of the counters kept by [`QueryCache`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached response.
    pub hits: u64,
    /// Lookups that returned nothing, including those that found only an expired entry.
    pub misses: u64,
    /// Number of entries held when the snapshot was taken.
    pub entries: usize,
    /// Entries dropped to make room for new ones.
    pub evictions: u64,
    /// Entries dropped because their TTL had passed.
    pub expirations: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookup has happened yet.
    #[must_use]
    pub fn hit_ratio(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
162
163pub struct QueryCache {
188 config: CacheConfig,
189 entries: RwLock<AHashMap<CacheKey, CacheEntry>>,
191 order: RwLock<VecDeque<CacheKey>>,
193 hits: AtomicU64,
195 misses: AtomicU64,
196 evictions: AtomicU64,
197 expirations: AtomicU64,
198}
199
200impl std::fmt::Debug for QueryCache {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("QueryCache")
203 .field("config", &self.config)
204 .field("entries", &self.entries.read().len())
205 .finish()
206 }
207}
208
209impl QueryCache {
210 #[must_use]
212 pub fn new(config: CacheConfig) -> Self {
213 Self {
214 config,
215 entries: RwLock::new(AHashMap::new()),
216 order: RwLock::new(VecDeque::new()),
217 hits: AtomicU64::new(0),
218 misses: AtomicU64::new(0),
219 evictions: AtomicU64::new(0),
220 expirations: AtomicU64::new(0),
221 }
222 }
223
224 #[must_use]
226 pub fn default_cache() -> Self {
227 Self::new(CacheConfig::default())
228 }
229
230 pub fn get(&self, key: &CacheKey) -> Option<QueryResponse> {
234 let entries = self.entries.read();
236
237 if let Some(entry) = entries.get(key) {
238 if entry.is_expired(self.config.ttl) {
240 drop(entries);
241 self.remove(key);
242 self.expirations.fetch_add(1, Ordering::Relaxed);
243 self.misses.fetch_add(1, Ordering::Relaxed);
244 return None;
245 }
246
247 self.hits.fetch_add(1, Ordering::Relaxed);
248
249 drop(entries);
251 self.touch(key);
252
253 let entries = self.entries.read();
255 entries.get(key).map(|e| e.response.clone())
256 } else {
257 self.misses.fetch_add(1, Ordering::Relaxed);
258 None
259 }
260 }
261
262 pub fn put(&self, key: CacheKey, response: QueryResponse) {
264 self.maybe_evict();
266
267 let entry = CacheEntry {
268 response,
269 created_at: Instant::now(),
270 access_count: 1,
271 };
272
273 {
274 let mut entries = self.entries.write();
275 let mut order = self.order.write();
276
277 if entries.contains_key(&key) {
279 order.retain(|k| k != &key);
280 }
281
282 entries.insert(key.clone(), entry);
283 order.push_back(key);
284 }
285 }
286
287 pub fn remove(&self, key: &CacheKey) -> Option<CacheEntry> {
289 let mut entries = self.entries.write();
290 let mut order = self.order.write();
291
292 order.retain(|k| k != key);
293 entries.remove(key)
294 }
295
296 fn touch(&self, key: &CacheKey) {
298 let mut order = self.order.write();
299
300 order.retain(|k| k != key);
302 order.push_back(key.clone());
304 }
305
306 fn maybe_evict(&self) {
308 let entries = self.entries.read();
309 let current_size = entries.len();
310 drop(entries);
311
312 if current_size >= self.config.max_entries {
313 let to_evict = self.config.max_entries / 10;
315 self.evict_oldest(to_evict.max(1));
316 }
317 }
318
319 fn evict_oldest(&self, n: usize) {
321 let mut entries = self.entries.write();
322 let mut order = self.order.write();
323
324 for _ in 0..n {
325 if let Some(key) = order.pop_front() {
326 entries.remove(&key);
327 self.evictions.fetch_add(1, Ordering::Relaxed);
328 } else {
329 break;
330 }
331 }
332 }
333
334 pub fn clear(&self) {
336 let mut entries = self.entries.write();
337 let mut order = self.order.write();
338
339 entries.clear();
340 order.clear();
341 }
342
343 pub fn cleanup_expired(&self) {
345 let entries_snapshot: Vec<CacheKey> = {
346 let entries = self.entries.read();
347 entries
348 .iter()
349 .filter(|(_, entry)| entry.is_expired(self.config.ttl))
350 .map(|(key, _)| key.clone())
351 .collect()
352 };
353
354 for key in entries_snapshot {
355 self.remove(&key);
356 self.expirations.fetch_add(1, Ordering::Relaxed);
357 }
358 }
359
360 #[must_use]
362 pub fn stats(&self) -> CacheStats {
363 CacheStats {
364 hits: self.hits.load(Ordering::Relaxed),
365 misses: self.misses.load(Ordering::Relaxed),
366 entries: self.entries.read().len(),
367 evictions: self.evictions.load(Ordering::Relaxed),
368 expirations: self.expirations.load(Ordering::Relaxed),
369 }
370 }
371
372 #[must_use]
374 pub fn len(&self) -> usize {
375 self.entries.read().len()
376 }
377
378 #[must_use]
380 pub fn is_empty(&self) -> bool {
381 self.entries.read().is_empty()
382 }
383}
384
385impl Default for QueryCache {
386 fn default() -> Self {
387 Self::default_cache()
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::engine::RetrievedRecord;
    use crate::stats::OutcomeStats;
    use crate::types::{MemoryRecord, RecordStatus};

    /// Builds a response with `result_count` synthetic records and descending scores.
    fn create_test_response(result_count: usize) -> QueryResponse {
        let results: Vec<RetrievedRecord> = (0..result_count)
            .map(|i| RetrievedRecord {
                record: MemoryRecord {
                    id: format!("rec-{i}").into(),
                    embedding: vec![1.0],
                    context: format!("ctx-{i}"),
                    outcome: 0.5,
                    metadata: Default::default(),
                    created_at: 0,
                    status: RecordStatus::Active,
                    stats: OutcomeStats::new(1),
                },
                score: 0.9 - (i as f32 * 0.1),
                rank: i + 1,
                source_index: "test".into(),
            })
            .collect();

        QueryResponse {
            results,
            priors: None,
            latency: Duration::from_millis(10),
            indexes_searched: 1,
            candidates_considered: result_count,
        }
    }

    #[test]
    fn test_cache_key() {
        let key1 = CacheKey::new(&[1.0, 2.0, 3.0], 10, None, None);
        let key2 = CacheKey::new(&[1.0, 2.0, 3.0], 10, None, None);
        let key3 = CacheKey::new(&[1.0, 2.0, 4.0], 10, None, None);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cache_key_components() {
        let base = CacheKey::new(&[1.0, 2.0], 10, None, None);
        let indexes = vec!["docs".to_string()];

        assert_ne!(base, CacheKey::new(&[1.0, 2.0], 11, None, None));
        assert_ne!(base, CacheKey::new(&[1.0, 2.0], 10, Some(42), None));
        assert_ne!(
            base,
            CacheKey::new(&[1.0, 2.0], 10, None, Some(indexes.as_slice()))
        );
    }

    #[test]
    fn test_put_and_get() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0, 2.0], 5, None, None);
        let response = create_test_response(5);

        cache.put(key.clone(), response);

        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().results.len(), 5);
    }

    #[test]
    fn test_put_replaces_existing_key() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0], 5, None, None);

        cache.put(key.clone(), create_test_response(1));
        cache.put(key.clone(), create_test_response(3));

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.stats().evictions, 0);
        assert_eq!(cache.get(&key).unwrap().results.len(), 3);
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0, 2.0], 5, None, None);

        let cached = cache.get(&key);
        assert!(cached.is_none());

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_cache_hit() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0, 2.0], 5, None, None);

        cache.put(key.clone(), create_test_response(5));
        cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
    }

    #[test]
    fn test_ttl_expiration() {
        let config = CacheConfig::new().with_ttl(Duration::from_millis(50));
        let cache = QueryCache::new(config);

        let key = CacheKey::new(&[1.0], 5, None, None);
        cache.put(key.clone(), create_test_response(5));

        // Still fresh.
        assert!(cache.get(&key).is_some());

        // Sleep past the 50 ms TTL.
        std::thread::sleep(Duration::from_millis(60));

        // Now expired: reported as a miss and counted as an expiration.
        assert!(cache.get(&key).is_none());

        let stats = cache.stats();
        assert_eq!(stats.expirations, 1);
    }

    #[test]
    fn test_cleanup_expired() {
        let config = CacheConfig::new().with_ttl(Duration::from_millis(1));
        let cache = QueryCache::new(config);

        for i in 0..3 {
            let key = CacheKey::new(&[i as f32], 1, None, None);
            cache.put(key, create_test_response(1));
        }

        std::thread::sleep(Duration::from_millis(5));
        cache.cleanup_expired();

        assert!(cache.is_empty());
        assert_eq!(cache.stats().expirations, 3);
    }

    #[test]
    fn test_lru_eviction() {
        let config = CacheConfig::new().with_max_entries(5);
        let cache = QueryCache::new(config);

        let keys: Vec<CacheKey> = (0..5)
            .map(|i| CacheKey::new(&[i as f32], 1, None, None))
            .collect();
        for key in &keys {
            cache.put(key.clone(), create_test_response(1));
        }

        assert_eq!(cache.len(), 5);

        // Touch the oldest key so the second-oldest becomes the LRU victim.
        assert!(cache.get(&keys[0]).is_some());

        // Inserting into the full cache must evict exactly one entry (5 / 10 rounds up to 1).
        cache.put(CacheKey::new(&[100.0], 1, None, None), create_test_response(1));

        assert_eq!(cache.len(), 5);
        assert_eq!(cache.stats().evictions, 1);
        assert!(cache.get(&keys[0]).is_some(), "recently used key must survive");
        assert!(cache.get(&keys[1]).is_none(), "least recently used key must be evicted");
    }

    #[test]
    fn test_clear() {
        let cache = QueryCache::default_cache();

        for i in 0..10 {
            let key = CacheKey::new(&[i as f32], 1, None, None);
            cache.put(key, create_test_response(1));
        }

        assert_eq!(cache.len(), 10);

        cache.clear();

        assert!(cache.is_empty());
    }

    #[test]
    fn test_hit_ratio() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0], 5, None, None);

        cache.put(key.clone(), create_test_response(5));

        // Three hits...
        cache.get(&key);
        cache.get(&key);
        cache.get(&key);

        // ...and one miss.
        cache.get(&CacheKey::new(&[999.0], 5, None, None));

        let stats = cache.stats();
        assert!((stats.hit_ratio() - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_remove() {
        let cache = QueryCache::default_cache();
        let key = CacheKey::new(&[1.0], 5, None, None);

        cache.put(key.clone(), create_test_response(5));
        assert!(cache.get(&key).is_some());

        cache.remove(&key);
        assert!(cache.get(&key).is_none());
    }
}