1use lru::LruCache;
29use std::{
30 hash::Hash,
31 num::NonZeroUsize,
32 sync::{Arc, Mutex},
33 time::{Duration, Instant},
34};
35
36use crate::{GraphRAGError, GraphRAGResult};
37
/// Tunable parameters for a query cache.
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Maximum number of entries the underlying LRU store may hold.
    pub capacity: NonZeroUsize,

    /// TTL used when an entry is inserted without an explicit TTL.
    pub default_ttl: Duration,

    /// Lower bound every requested TTL is raised to.
    pub min_ttl: Duration,

    /// Upper bound every requested TTL is lowered to.
    pub max_ttl: Duration,
}

impl Default for QueryCacheConfig {
    /// 1024 entries, a one-hour default TTL, and TTL bounds of five minutes
    /// and one day.
    fn default() -> Self {
        let capacity = NonZeroUsize::new(1024).expect("1024 is non-zero");
        let default_ttl = Duration::from_secs(3600);
        let min_ttl = Duration::from_secs(300);
        let max_ttl = Duration::from_secs(86_400);

        Self {
            capacity,
            default_ttl,
            min_ttl,
            max_ttl,
        }
    }
}
67
/// A cached value together with the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
pub struct CacheEntry<V> {
    /// The cached payload.
    pub value: V,
    /// Instant at which the entry was created.
    pub inserted_at: Instant,
    /// Lifetime granted to the entry, measured from `inserted_at`.
    pub ttl: Duration,
    /// Number of successful lookups served from this entry.
    pub hit_count: u64,
}

// Neither helper touches `value`, so no `Clone` bound is required on `V`.
impl<V> CacheEntry<V> {
    /// Returns `true` while less than `ttl` has elapsed since `inserted_at`.
    ///
    /// An entry whose TTL is zero is never fresh.
    #[inline]
    pub fn is_fresh(&self) -> bool {
        self.inserted_at.elapsed() < self.ttl
    }

    /// Returns the time left before the entry expires, or `Duration::ZERO`
    /// once the TTL has already elapsed.
    #[inline]
    pub fn remaining_ttl(&self) -> Duration {
        self.ttl.saturating_sub(self.inserted_at.elapsed())
    }
}
95
/// Point-in-time counters describing cache activity.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a fresh value.
    pub hits: u64,
    /// Lookups that found no entry, or only a stale one.
    pub misses: u64,
    /// Entries dropped or displaced after their TTL had elapsed.
    pub stale_evictions: u64,
    /// Still-fresh entries displaced because the cache was full.
    pub lru_evictions: u64,
    /// Number of stored entries as of the last mutation (may include entries
    /// that have expired but have not been purged yet).
    pub live_entries: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        // Sum in u128: `hits + misses` in u64 can overflow, which panics in
        // debug builds and wraps (possibly to zero) in release builds.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
125
126struct CacheInner<K, V> {
129 lru: LruCache<K, CacheEntry<V>>,
130 stats: CacheStats,
131 config: QueryCacheConfig,
132}
133
134impl<K: Hash + Eq + Clone, V: Clone> CacheInner<K, V> {
135 fn new(config: QueryCacheConfig) -> Self {
136 let capacity = config.capacity;
137 Self {
138 lru: LruCache::new(capacity),
139 stats: CacheStats {
140 capacity: capacity.get(),
141 ..Default::default()
142 },
143 config,
144 }
145 }
146
147 fn clamp_ttl(&self, ttl: Duration) -> Duration {
149 ttl.max(self.config.min_ttl).min(self.config.max_ttl)
150 }
151
152 fn get(&mut self, key: &K) -> Option<V> {
156 let is_stale = match self.lru.peek(key) {
158 Some(entry) => !entry.is_fresh(),
159 None => {
160 self.stats.misses += 1;
161 return None;
162 }
163 };
164
165 if is_stale {
166 self.lru.pop(key);
167 self.stats.stale_evictions += 1;
168 self.stats.misses += 1;
169 self.stats.live_entries = self.lru.len();
170 return None;
171 }
172
173 if let Some(entry) = self.lru.get_mut(key) {
175 entry.hit_count += 1;
176 let value = entry.value.clone();
177 self.stats.hits += 1;
178 Some(value)
179 } else {
180 self.stats.misses += 1;
181 None
182 }
183 }
184
185 fn put_with_ttl(&mut self, key: K, value: V, ttl: Duration) {
187 let ttl = self.clamp_ttl(ttl);
188
189 if self.lru.len() == self.lru.cap().get() {
191 let oldest_stale = self
194 .lru
195 .peek_lru()
196 .map(|(_, e)| !e.is_fresh())
197 .unwrap_or(false);
198 if oldest_stale {
199 self.stats.stale_evictions += 1;
200 } else {
201 self.stats.lru_evictions += 1;
202 }
203 }
204
205 let entry = CacheEntry {
206 value,
207 inserted_at: Instant::now(),
208 ttl,
209 hit_count: 0,
210 };
211 self.lru.put(key, entry);
212 self.stats.live_entries = self.lru.len();
213 }
214
215 fn put(&mut self, key: K, value: V) {
217 let ttl = self.config.default_ttl;
218 self.put_with_ttl(key, value, ttl);
219 }
220
221 fn remove(&mut self, key: &K) -> Option<V> {
223 let entry = self.lru.pop(key)?;
224 self.stats.live_entries = self.lru.len();
225 if entry.is_fresh() {
226 Some(entry.value)
227 } else {
228 self.stats.stale_evictions += 1;
229 None
230 }
231 }
232
233 fn evict_expired(&mut self) -> usize {
235 let stale_keys: Vec<K> = self
236 .lru
237 .iter()
238 .filter(|(_, entry)| !entry.is_fresh())
239 .map(|(k, _)| k.clone())
240 .collect();
241
242 let count = stale_keys.len();
243 for key in stale_keys {
244 self.lru.pop(&key);
245 }
246
247 self.stats.stale_evictions += count as u64;
248 self.stats.live_entries = self.lru.len();
249 count
250 }
251
252 fn peek_entry(&self, key: &K) -> Option<&CacheEntry<V>> {
254 self.lru.peek(key)
255 }
256
257 fn stats(&self) -> CacheStats {
259 self.stats.clone()
260 }
261
262 fn clear(&mut self) {
264 self.lru.clear();
265 self.stats.live_entries = 0;
266 }
267}
268
269#[derive(Clone)]
296pub struct QueryCache<K, V> {
297 inner: Arc<Mutex<CacheInner<K, V>>>,
298}
299
300impl<K, V> std::fmt::Debug for QueryCache<K, V> {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("QueryCache").finish_non_exhaustive()
303 }
304}
305
306impl<K, V> QueryCache<K, V>
307where
308 K: Hash + Eq + Clone + Send + 'static,
309 V: Clone + Send + 'static,
310{
311 pub fn new(config: QueryCacheConfig) -> Self {
313 Self {
314 inner: Arc::new(Mutex::new(CacheInner::new(config))),
315 }
316 }
317
318 pub fn with_defaults() -> Self {
320 Self::new(QueryCacheConfig::default())
321 }
322
323 pub fn get(&self, key: &K) -> GraphRAGResult<Option<V>> {
327 let mut guard = self
328 .inner
329 .lock()
330 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
331 Ok(guard.get(key))
332 }
333
334 pub fn put(&self, key: K, value: V) -> GraphRAGResult<()> {
336 let mut guard = self
337 .inner
338 .lock()
339 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
340 guard.put(key, value);
341 Ok(())
342 }
343
344 pub fn put_with_ttl(&self, key: K, value: V, ttl: Duration) -> GraphRAGResult<()> {
348 let mut guard = self
349 .inner
350 .lock()
351 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
352 guard.put_with_ttl(key, value, ttl);
353 Ok(())
354 }
355
356 pub fn remove(&self, key: &K) -> GraphRAGResult<Option<V>> {
358 let mut guard = self
359 .inner
360 .lock()
361 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
362 Ok(guard.remove(key))
363 }
364
365 pub fn evict_expired(&self) -> GraphRAGResult<usize> {
370 let mut guard = self
371 .inner
372 .lock()
373 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
374 Ok(guard.evict_expired())
375 }
376
377 pub fn stats(&self) -> GraphRAGResult<CacheStats> {
379 let guard = self
380 .inner
381 .lock()
382 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
383 Ok(guard.stats())
384 }
385
386 pub fn peek_entry<F, R>(&self, key: &K, f: F) -> GraphRAGResult<Option<R>>
390 where
391 F: FnOnce(&CacheEntry<V>) -> R,
392 {
393 let guard = self
394 .inner
395 .lock()
396 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
397 match guard.peek_entry(key) {
398 Some(entry) if entry.is_fresh() => Ok(Some(f(entry))),
399 _ => Ok(None),
400 }
401 }
402
403 pub fn clear(&self) -> GraphRAGResult<()> {
405 let mut guard = self
406 .inner
407 .lock()
408 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
409 guard.clear();
410 Ok(())
411 }
412
413 pub fn len(&self) -> GraphRAGResult<usize> {
418 let guard = self
419 .inner
420 .lock()
421 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
422 Ok(guard.lru.len())
423 }
424
425 pub fn is_empty(&self) -> GraphRAGResult<bool> {
427 Ok(self.len()? == 0)
428 }
429
430 pub fn capacity(&self) -> GraphRAGResult<usize> {
432 let guard = self
433 .inner
434 .lock()
435 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
436 Ok(guard.lru.cap().get())
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Builds a string-to-string cache with `cap` slots and a default TTL of
    /// `ttl_secs` seconds; the minimum TTL is 1 ms so short TTLs survive clamping.
    fn small_cache(cap: usize, ttl_secs: u64) -> QueryCache<String, String> {
        let config = QueryCacheConfig {
            capacity: NonZeroUsize::new(cap).expect("cap is non-zero"),
            default_ttl: Duration::from_secs(ttl_secs),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(86_400),
        };
        QueryCache::new(config)
    }

    /// A stored value can be read back.
    #[test]
    fn test_basic_put_get() {
        let cache = small_cache(10, 3600);
        let key = "key1".to_string();
        cache
            .put(key.clone(), "value1".to_string())
            .expect("should succeed");
        let result = cache.get(&key).expect("should succeed");
        assert_eq!(result, Some("value1".to_string()));
    }

    /// Looking up a key that was never inserted yields `None`.
    #[test]
    fn test_miss_on_absent_key() {
        let cache: QueryCache<String, String> = small_cache(10, 3600);
        let key = "absent".to_string();
        let result = cache.get(&key).expect("should succeed");
        assert_eq!(result, None);
    }

    /// A second `put` on the same key replaces the first value.
    #[test]
    fn test_overwrite_key() {
        let cache = small_cache(10, 3600);
        let key = "k".to_string();
        cache
            .put(key.clone(), "v1".to_string())
            .expect("should succeed");
        cache
            .put(key.clone(), "v2".to_string())
            .expect("should succeed");
        let result = cache.get(&key).expect("should succeed");
        assert_eq!(result, Some("v2".to_string()));
    }

    /// An entry is served before its TTL elapses and gone afterwards.
    #[test]
    fn test_ttl_expiry() {
        let config = QueryCacheConfig {
            capacity: NonZeroUsize::new(10).expect("should succeed"),
            default_ttl: Duration::from_millis(50),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        };
        let cache = QueryCache::new(config);
        let key = "k".to_string();

        cache
            .put(key.clone(), "v".to_string())
            .expect("should succeed");
        let before = cache.get(&key).expect("should succeed");
        assert_eq!(before, Some("v".to_string()));

        thread::sleep(Duration::from_millis(100));
        let after = cache.get(&key).expect("should succeed");
        assert_eq!(after, None);
    }

    /// With three slots, touching "a" makes "b" the victim when "d" arrives.
    #[test]
    fn test_lru_eviction() {
        let cache = small_cache(3, 3600);
        let (a, b, c, d) = (
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        );

        cache
            .put(a.clone(), "1".to_string())
            .expect("should succeed");
        cache
            .put(b.clone(), "2".to_string())
            .expect("should succeed");
        cache
            .put(c.clone(), "3".to_string())
            .expect("should succeed");

        // Refresh "a" so that "b" becomes the least recently used entry.
        let _ = cache.get(&a).expect("should succeed");

        cache
            .put(d.clone(), "4".to_string())
            .expect("should succeed");

        let b_lookup = cache.get(&b).expect("should succeed");
        assert_eq!(b_lookup, None, "b should be evicted");

        let a_lookup = cache.get(&a).expect("should succeed");
        assert!(a_lookup.is_some(), "a should survive");

        let d_lookup = cache.get(&d).expect("should succeed");
        assert!(d_lookup.is_some(), "d should be present");
    }

    /// `remove` hands back the value and leaves the key absent.
    #[test]
    fn test_remove() {
        let cache = small_cache(10, 3600);
        let key = "k".to_string();
        cache
            .put(key.clone(), "v".to_string())
            .expect("should succeed");

        let removed = cache.remove(&key).expect("should succeed");
        assert_eq!(removed, Some("v".to_string()));

        let after = cache.get(&key).expect("should succeed");
        assert_eq!(after, None);
    }

    /// `evict_expired` purges every expired entry in one call.
    #[test]
    fn test_evict_expired_batch() {
        let config = QueryCacheConfig {
            capacity: NonZeroUsize::new(20).expect("should succeed"),
            default_ttl: Duration::from_millis(50),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        };
        let cache = QueryCache::new(config);

        (0..5u32).for_each(|i| {
            cache
                .put(format!("k{}", i), format!("v{}", i))
                .expect("should succeed");
        });

        thread::sleep(Duration::from_millis(100));
        let evicted = cache.evict_expired().expect("should succeed");
        assert_eq!(evicted, 5);

        let remaining = cache.len().expect("should succeed");
        assert_eq!(remaining, 0);
    }

    /// One hit and one miss give a 50 % hit rate.
    #[test]
    fn test_stats_hit_rate() {
        let cache = small_cache(10, 3600);
        let present = "x".to_string();
        let missing = "y".to_string();
        cache
            .put(present.clone(), "1".to_string())
            .expect("should succeed");

        let _ = cache.get(&present).expect("should succeed");
        let _ = cache.get(&missing).expect("should succeed");

        let stats = cache.stats().expect("should succeed");
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 1e-9);
    }

    /// An explicit TTL overrides the much longer default.
    #[test]
    fn test_put_with_explicit_ttl() {
        let config = QueryCacheConfig {
            capacity: NonZeroUsize::new(10).expect("should succeed"),
            default_ttl: Duration::from_secs(3600),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(86_400),
        };
        let cache = QueryCache::new(config);
        let key = "k".to_string();

        cache
            .put_with_ttl(key.clone(), "v".to_string(), Duration::from_millis(50))
            .expect("should succeed");
        let before = cache.get(&key).expect("should succeed");
        assert!(before.is_some());

        thread::sleep(Duration::from_millis(100));
        let after = cache.get(&key).expect("should succeed");
        assert_eq!(after, None);
    }

    /// `clear` leaves the cache empty.
    #[test]
    fn test_clear() {
        let cache = small_cache(10, 3600);
        cache
            .put("a".to_string(), "1".to_string())
            .expect("should succeed");
        cache
            .put("b".to_string(), "2".to_string())
            .expect("should succeed");

        cache.clear().expect("should succeed");

        let remaining = cache.len().expect("should succeed");
        assert_eq!(remaining, 0);
    }

    /// Eight threads each insert and read back 32 distinct keys through clones
    /// of the same cache; every read should be counted as a hit.
    #[test]
    fn test_thread_safe_concurrent_access() {
        let config = QueryCacheConfig {
            capacity: NonZeroUsize::new(256).expect("should succeed"),
            default_ttl: Duration::from_secs(60),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        };
        let cache: QueryCache<String, usize> = QueryCache::new(config);

        let mut workers = Vec::new();
        for t in 0..8_usize {
            let shared = cache.clone();
            workers.push(thread::spawn(move || {
                for i in 0..32_usize {
                    let key = format!("t{}k{}", t, i);
                    shared.put(key.clone(), t * 100 + i).expect("put failed");
                    let _ = shared.get(&key).expect("get failed");
                }
            }));
        }

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        let stats = cache.stats().expect("should succeed");
        assert!(stats.hits >= 256, "expected hits ≥256, got {}", stats.hits);
    }

    /// `peek_entry` exposes metadata without itself counting as a hit.
    #[test]
    fn test_peek_entry_metadata() {
        let cache = small_cache(10, 3600);
        let key = "k".to_string();
        cache
            .put(key.clone(), "v".to_string())
            .expect("should succeed");

        let hit_count = cache
            .peek_entry(&key, |entry| entry.hit_count)
            .expect("should succeed");
        assert_eq!(hit_count, Some(0));

        let _ = cache.get(&key).expect("should succeed");

        let hit_count2 = cache
            .peek_entry(&key, |entry| entry.hit_count)
            .expect("should succeed");
        assert_eq!(hit_count2, Some(1));
    }

    /// A requested TTL below `min_ttl` is raised to `min_ttl`.
    #[test]
    fn test_ttl_clamping() {
        let config = QueryCacheConfig {
            capacity: NonZeroUsize::new(10).expect("should succeed"),
            default_ttl: Duration::from_secs(60),
            min_ttl: Duration::from_secs(10),
            max_ttl: Duration::from_secs(120),
        };
        let cache = QueryCache::new(config);
        let key = "k".to_string();

        cache
            .put_with_ttl(key.clone(), "v".to_string(), Duration::from_millis(1))
            .expect("should succeed");

        let result = cache
            .peek_entry(&key, |entry| entry.ttl)
            .expect("should succeed");
        assert_eq!(result, Some(Duration::from_secs(10)));
    }
}