1use lru::LruCache;
29use std::{
30 hash::Hash,
31 num::NonZeroUsize,
32 sync::{Arc, Mutex},
33 time::{Duration, Instant},
34};
35
36use crate::{GraphRAGError, GraphRAGResult};
37
/// Tuning knobs for a `QueryCache`.
///
/// Every TTL the cache stores, whether it comes from `default_ttl` or from an
/// explicit `put_with_ttl` call, is clamped into `[min_ttl, max_ttl]`. If
/// `min_ttl > max_ttl`, clamping yields `max_ttl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryCacheConfig {
    /// Maximum number of entries held. Inserting a new key into a full cache
    /// displaces the least-recently-used entry.
    pub capacity: NonZeroUsize,

    /// TTL used by `put` (still subject to clamping).
    pub default_ttl: Duration,

    /// Lower bound applied to every entry's TTL.
    pub min_ttl: Duration,

    /// Upper bound applied to every entry's TTL.
    pub max_ttl: Duration,
}
56
57impl Default for QueryCacheConfig {
58 fn default() -> Self {
59 Self {
60 capacity: NonZeroUsize::new(1024).expect("1024 is non-zero"),
61 default_ttl: Duration::from_secs(3600), min_ttl: Duration::from_secs(300), max_ttl: Duration::from_secs(86_400), }
65 }
66}
67
/// A cached value together with the bookkeeping used for TTL expiry.
#[derive(Clone, Debug)]
pub struct CacheEntry<V> {
    /// The cached payload.
    pub value: V,
    /// When the entry was written; expiry is measured from this instant.
    pub inserted_at: Instant,
    /// How long the entry stays fresh (the cache stores the clamped TTL here).
    pub ttl: Duration,
    /// Number of successful `get` hits since the entry was last written.
    pub hit_count: u64,
}
80
81impl<V: Clone> CacheEntry<V> {
82 #[inline]
84 pub fn is_fresh(&self) -> bool {
85 self.inserted_at.elapsed() < self.ttl
86 }
87
88 #[inline]
90 pub fn remaining_ttl(&self) -> Duration {
91 let elapsed = self.inserted_at.elapsed();
92 self.ttl.saturating_sub(elapsed)
93 }
94}
95
/// Snapshot of a cache's counters, as returned by `QueryCache::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// `get` calls that returned a fresh value.
    pub hits: u64,
    /// `get` calls for a key that was absent or had expired.
    pub misses: u64,
    /// Expired entries removed from the cache: on `get`, on `remove`, by
    /// `evict_expired`, or when displaced as the least-recently-used entry.
    pub stale_evictions: u64,
    /// Still-fresh entries displaced to make room for a new key in a full cache.
    pub lru_evictions: u64,
    /// Number of stored entries at the last update. This includes expired
    /// entries that have not been purged yet.
    pub live_entries: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
}
112
113impl CacheStats {
114 #[inline]
116 pub fn hit_rate(&self) -> f64 {
117 let total = self.hits + self.misses;
118 if total == 0 {
119 0.0
120 } else {
121 self.hits as f64 / total as f64
122 }
123 }
124}
125
126struct CacheInner<K, V> {
129 lru: LruCache<K, CacheEntry<V>>,
130 stats: CacheStats,
131 config: QueryCacheConfig,
132}
133
134impl<K: Hash + Eq + Clone, V: Clone> CacheInner<K, V> {
135 fn new(config: QueryCacheConfig) -> Self {
136 let capacity = config.capacity;
137 Self {
138 lru: LruCache::new(capacity),
139 stats: CacheStats {
140 capacity: capacity.get(),
141 ..Default::default()
142 },
143 config,
144 }
145 }
146
147 fn clamp_ttl(&self, ttl: Duration) -> Duration {
149 ttl.max(self.config.min_ttl).min(self.config.max_ttl)
150 }
151
152 fn get(&mut self, key: &K) -> Option<V> {
156 let is_stale = match self.lru.peek(key) {
158 Some(entry) => !entry.is_fresh(),
159 None => {
160 self.stats.misses += 1;
161 return None;
162 }
163 };
164
165 if is_stale {
166 self.lru.pop(key);
167 self.stats.stale_evictions += 1;
168 self.stats.misses += 1;
169 self.stats.live_entries = self.lru.len();
170 return None;
171 }
172
173 if let Some(entry) = self.lru.get_mut(key) {
175 entry.hit_count += 1;
176 let value = entry.value.clone();
177 self.stats.hits += 1;
178 Some(value)
179 } else {
180 self.stats.misses += 1;
181 None
182 }
183 }
184
185 fn put_with_ttl(&mut self, key: K, value: V, ttl: Duration) {
187 let ttl = self.clamp_ttl(ttl);
188
189 if self.lru.len() == self.lru.cap().get() {
191 let oldest_stale = self
194 .lru
195 .peek_lru()
196 .map(|(_, e)| !e.is_fresh())
197 .unwrap_or(false);
198 if oldest_stale {
199 self.stats.stale_evictions += 1;
200 } else {
201 self.stats.lru_evictions += 1;
202 }
203 }
204
205 let entry = CacheEntry {
206 value,
207 inserted_at: Instant::now(),
208 ttl,
209 hit_count: 0,
210 };
211 self.lru.put(key, entry);
212 self.stats.live_entries = self.lru.len();
213 }
214
215 fn put(&mut self, key: K, value: V) {
217 let ttl = self.config.default_ttl;
218 self.put_with_ttl(key, value, ttl);
219 }
220
221 fn remove(&mut self, key: &K) -> Option<V> {
223 let entry = self.lru.pop(key)?;
224 self.stats.live_entries = self.lru.len();
225 if entry.is_fresh() {
226 Some(entry.value)
227 } else {
228 self.stats.stale_evictions += 1;
229 None
230 }
231 }
232
233 fn evict_expired(&mut self) -> usize {
235 let stale_keys: Vec<K> = self
236 .lru
237 .iter()
238 .filter(|(_, entry)| !entry.is_fresh())
239 .map(|(k, _)| k.clone())
240 .collect();
241
242 let count = stale_keys.len();
243 for key in stale_keys {
244 self.lru.pop(&key);
245 }
246
247 self.stats.stale_evictions += count as u64;
248 self.stats.live_entries = self.lru.len();
249 count
250 }
251
252 fn peek_entry(&self, key: &K) -> Option<&CacheEntry<V>> {
254 self.lru.peek(key)
255 }
256
257 fn stats(&self) -> CacheStats {
259 self.stats.clone()
260 }
261
262 fn clear(&mut self) {
264 self.lru.clear();
265 self.stats.live_entries = 0;
266 }
267}
268
269#[derive(Clone)]
296pub struct QueryCache<K, V> {
297 inner: Arc<Mutex<CacheInner<K, V>>>,
298}
299
300impl<K, V> std::fmt::Debug for QueryCache<K, V> {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("QueryCache").finish_non_exhaustive()
303 }
304}
305
306impl<K, V> QueryCache<K, V>
307where
308 K: Hash + Eq + Clone + Send + 'static,
309 V: Clone + Send + 'static,
310{
311 pub fn new(config: QueryCacheConfig) -> Self {
313 Self {
314 inner: Arc::new(Mutex::new(CacheInner::new(config))),
315 }
316 }
317
318 pub fn with_defaults() -> Self {
320 Self::new(QueryCacheConfig::default())
321 }
322
323 pub fn get(&self, key: &K) -> GraphRAGResult<Option<V>> {
327 let mut guard = self
328 .inner
329 .lock()
330 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
331 Ok(guard.get(key))
332 }
333
334 pub fn put(&self, key: K, value: V) -> GraphRAGResult<()> {
336 let mut guard = self
337 .inner
338 .lock()
339 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
340 guard.put(key, value);
341 Ok(())
342 }
343
344 pub fn put_with_ttl(&self, key: K, value: V, ttl: Duration) -> GraphRAGResult<()> {
348 let mut guard = self
349 .inner
350 .lock()
351 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
352 guard.put_with_ttl(key, value, ttl);
353 Ok(())
354 }
355
356 pub fn remove(&self, key: &K) -> GraphRAGResult<Option<V>> {
358 let mut guard = self
359 .inner
360 .lock()
361 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
362 Ok(guard.remove(key))
363 }
364
365 pub fn evict_expired(&self) -> GraphRAGResult<usize> {
370 let mut guard = self
371 .inner
372 .lock()
373 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
374 Ok(guard.evict_expired())
375 }
376
377 pub fn stats(&self) -> GraphRAGResult<CacheStats> {
379 let guard = self
380 .inner
381 .lock()
382 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
383 Ok(guard.stats())
384 }
385
386 pub fn peek_entry<F, R>(&self, key: &K, f: F) -> GraphRAGResult<Option<R>>
390 where
391 F: FnOnce(&CacheEntry<V>) -> R,
392 {
393 let guard = self
394 .inner
395 .lock()
396 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
397 match guard.peek_entry(key) {
398 Some(entry) if entry.is_fresh() => Ok(Some(f(entry))),
399 _ => Ok(None),
400 }
401 }
402
403 pub fn clear(&self) -> GraphRAGResult<()> {
405 let mut guard = self
406 .inner
407 .lock()
408 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
409 guard.clear();
410 Ok(())
411 }
412
413 pub fn len(&self) -> GraphRAGResult<usize> {
418 let guard = self
419 .inner
420 .lock()
421 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
422 Ok(guard.lru.len())
423 }
424
425 pub fn is_empty(&self) -> GraphRAGResult<bool> {
427 Ok(self.len()? == 0)
428 }
429
430 pub fn capacity(&self) -> GraphRAGResult<usize> {
432 let guard = self
433 .inner
434 .lock()
435 .map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
436 Ok(guard.lru.cap().get())
437 }
438}
439
// Unit tests for `QueryCache`.
//
// Several tests use real wall-clock time. They insert entries with a 50 ms TTL,
// read them back immediately, then `thread::sleep` for 100 ms before checking
// expiry. The immediate read assumes the host does not stall for more than
// ~50 ms between the `put` and the first `get`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    // Builds a `String -> String` cache with capacity `cap` and a default TTL of
    // `ttl_secs` seconds. `min_ttl` is 1 ms so short explicit TTLs are not
    // clamped upward.
    fn small_cache(cap: usize, ttl_secs: u64) -> QueryCache<String, String> {
        QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(cap).expect("cap is non-zero"),
            default_ttl: Duration::from_secs(ttl_secs),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(86_400),
        })
    }

    // A value written with `put` is returned by `get`.
    #[test]
    fn test_basic_put_get() {
        let cache = small_cache(10, 3600);
        cache.put("key1".to_string(), "value1".to_string()).unwrap();
        let result = cache.get(&"key1".to_string()).unwrap();
        assert_eq!(result, Some("value1".to_string()));
    }

    // Looking up a key that was never inserted yields `None`.
    #[test]
    fn test_miss_on_absent_key() {
        let cache: QueryCache<String, String> = small_cache(10, 3600);
        let result = cache.get(&"absent".to_string()).unwrap();
        assert_eq!(result, None);
    }

    // A second `put` on the same key replaces the stored value.
    #[test]
    fn test_overwrite_key() {
        let cache = small_cache(10, 3600);
        cache.put("k".to_string(), "v1".to_string()).unwrap();
        cache.put("k".to_string(), "v2".to_string()).unwrap();
        let result = cache.get(&"k".to_string()).unwrap();
        assert_eq!(result, Some("v2".to_string()));
    }

    // An entry is served while fresh and treated as absent once its TTL passes.
    #[test]
    fn test_ttl_expiry() {
        let cache = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(10).unwrap(),
            default_ttl: Duration::from_millis(50),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        });
        cache.put("k".to_string(), "v".to_string()).unwrap();
        // Read immediately: still inside the 50 ms TTL.
        assert_eq!(cache.get(&"k".to_string()).unwrap(), Some("v".to_string()));
        // Sleep well past the TTL.
        thread::sleep(Duration::from_millis(100));
        // The expired entry is dropped on access and reported as a miss.
        assert_eq!(cache.get(&"k".to_string()).unwrap(), None);
    }

    // With capacity 3, touching "a" makes "b" the least-recently-used entry,
    // so inserting "d" displaces "b" rather than "a".
    #[test]
    fn test_lru_eviction() {
        let cache = small_cache(3, 3600);
        cache.put("a".to_string(), "1".to_string()).unwrap();
        cache.put("b".to_string(), "2".to_string()).unwrap();
        cache.put("c".to_string(), "3".to_string()).unwrap();

        // Promote "a"; recency order (least recent first) becomes b, c, a.
        let _ = cache.get(&"a".to_string()).unwrap();

        // The cache is full, so this insert displaces "b".
        cache.put("d".to_string(), "4".to_string()).unwrap();

        assert_eq!(
            cache.get(&"b".to_string()).unwrap(),
            None,
            "b should be evicted"
        );
        assert!(
            cache.get(&"a".to_string()).unwrap().is_some(),
            "a should survive"
        );
        assert!(
            cache.get(&"d".to_string()).unwrap().is_some(),
            "d should be present"
        );
    }

    // `remove` returns the fresh value, and the key is gone afterwards.
    #[test]
    fn test_remove() {
        let cache = small_cache(10, 3600);
        cache.put("k".to_string(), "v".to_string()).unwrap();
        let removed = cache.remove(&"k".to_string()).unwrap();
        assert_eq!(removed, Some("v".to_string()));
        assert_eq!(cache.get(&"k".to_string()).unwrap(), None);
    }

    // Five entries with a 50 ms TTL have all expired after 100 ms, and one
    // `evict_expired` call purges every one of them.
    #[test]
    fn test_evict_expired_batch() {
        let cache = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(20).unwrap(),
            default_ttl: Duration::from_millis(50),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        });
        for i in 0..5u32 {
            cache.put(format!("k{}", i), format!("v{}", i)).unwrap();
        }
        thread::sleep(Duration::from_millis(100));
        let evicted = cache.evict_expired().unwrap();
        assert_eq!(evicted, 5);
        assert_eq!(cache.len().unwrap(), 0);
    }

    // One hit plus one miss gives a hit rate of exactly 0.5.
    #[test]
    fn test_stats_hit_rate() {
        let cache = small_cache(10, 3600);
        cache.put("x".to_string(), "1".to_string()).unwrap();
        // Hit: "x" is present and fresh.
        let _ = cache.get(&"x".to_string()).unwrap();
        // Miss: "y" was never inserted.
        let _ = cache.get(&"y".to_string()).unwrap();
        let stats = cache.stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 1e-9);
    }

    // An explicit 50 ms TTL overrides the one-hour default.
    #[test]
    fn test_put_with_explicit_ttl() {
        let cache = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(10).unwrap(),
            default_ttl: Duration::from_secs(3600),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(86_400),
        });
        cache
            .put_with_ttl("k".to_string(), "v".to_string(), Duration::from_millis(50))
            .unwrap();
        assert!(cache.get(&"k".to_string()).unwrap().is_some());
        thread::sleep(Duration::from_millis(100));
        assert_eq!(cache.get(&"k".to_string()).unwrap(), None);
    }

    // `clear` leaves the cache empty.
    #[test]
    fn test_clear() {
        let cache = small_cache(10, 3600);
        cache.put("a".to_string(), "1".to_string()).unwrap();
        cache.put("b".to_string(), "2".to_string()).unwrap();
        cache.clear().unwrap();
        assert_eq!(cache.len().unwrap(), 0);
    }

    // Eight threads each write 32 distinct keys and read each one back right
    // after writing it. 8 * 32 = 256 keys fit exactly in capacity 256, and the
    // TTL is 60 s, so nothing is displaced or expires and every get should hit.
    #[test]
    fn test_thread_safe_concurrent_access() {
        let cache: QueryCache<String, usize> = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(256).unwrap(),
            default_ttl: Duration::from_secs(60),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        });

        let handles: Vec<_> = (0..8_usize)
            .map(|t| {
                // Each thread gets its own handle to the same shared cache.
                let c = cache.clone();
                thread::spawn(move || {
                    for i in 0..32_usize {
                        let key = format!("t{}k{}", t, i);
                        c.put(key.clone(), t * 100 + i).expect("put failed");
                        let _ = c.get(&key).expect("get failed");
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        let stats = cache.stats().unwrap();
        assert!(stats.hits >= 256, "expected hits ≥256, got {}", stats.hits);
    }

    // `peek_entry` exposes metadata without bumping `hit_count`; `get` does bump it.
    #[test]
    fn test_peek_entry_metadata() {
        let cache = small_cache(10, 3600);
        cache.put("k".to_string(), "v".to_string()).unwrap();
        let hit_count = cache.peek_entry(&"k".to_string(), |e| e.hit_count).unwrap();
        assert_eq!(hit_count, Some(0));
        let _ = cache.get(&"k".to_string()).unwrap();
        let hit_count2 = cache.peek_entry(&"k".to_string(), |e| e.hit_count).unwrap();
        assert_eq!(hit_count2, Some(1));
    }

    // A requested 1 ms TTL is raised to the configured 10 s minimum.
    #[test]
    fn test_ttl_clamping() {
        let cache = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(10).unwrap(),
            default_ttl: Duration::from_secs(60),
            min_ttl: Duration::from_secs(10),
            max_ttl: Duration::from_secs(120),
        });
        cache
            .put_with_ttl("k".to_string(), "v".to_string(), Duration::from_millis(1))
            .unwrap();
        let result = cache.peek_entry(&"k".to_string(), |e| e.ttl).unwrap();
        assert_eq!(result, Some(Duration::from_secs(10)));
    }
}