use lru::LruCache;
use std::{
hash::Hash,
num::NonZeroUsize,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use crate::{GraphRAGError, GraphRAGResult};
/// Capacity and TTL policy for a [`QueryCache`].
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Maximum number of entries kept before the least recently used one is evicted.
    pub capacity: NonZeroUsize,
    /// TTL applied by `put` when the caller does not supply one.
    pub default_ttl: Duration,
    /// Lower bound every requested TTL is raised to.
    pub min_ttl: Duration,
    /// Upper bound every requested TTL is lowered to.
    pub max_ttl: Duration,
}

impl Default for QueryCacheConfig {
    /// 1024 entries, 1 h default TTL, clamped to the window [5 min, 24 h].
    fn default() -> Self {
        const FIVE_MINUTES_SECS: u64 = 5 * 60;
        const ONE_HOUR_SECS: u64 = 60 * 60;
        const ONE_DAY_SECS: u64 = 24 * 60 * 60;

        let capacity = NonZeroUsize::new(1024).expect("1024 is non-zero");
        QueryCacheConfig {
            capacity,
            default_ttl: Duration::from_secs(ONE_HOUR_SECS),
            min_ttl: Duration::from_secs(FIVE_MINUTES_SECS),
            max_ttl: Duration::from_secs(ONE_DAY_SECS),
        }
    }
}
/// A cached value together with the bookkeeping needed for TTL expiry.
#[derive(Debug, Clone)]
pub struct CacheEntry<V> {
    /// The cached payload.
    pub value: V,
    /// Moment the entry was inserted; freshness is measured from here.
    pub inserted_at: Instant,
    /// Time-to-live, measured from `inserted_at`.
    pub ttl: Duration,
    /// Number of lookups that were served from this entry.
    pub hit_count: u64,
}

// No `V: Clone` bound: these helpers only read `inserted_at` and `ttl`, so they
// are usable for any payload type.
impl<V> CacheEntry<V> {
    /// Returns `true` while strictly less than `ttl` has elapsed since insertion.
    ///
    /// An entry with a zero TTL is therefore never fresh.
    #[inline]
    pub fn is_fresh(&self) -> bool {
        self.inserted_at.elapsed() < self.ttl
    }

    /// Time left before the entry expires; saturates at zero once it has expired.
    #[inline]
    pub fn remaining_ttl(&self) -> Duration {
        self.ttl.saturating_sub(self.inserted_at.elapsed())
    }
}
/// Snapshot of the cache's running counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a fresh value.
    pub hits: u64,
    /// Lookups that found no entry or only an expired one.
    pub misses: u64,
    /// Entries dropped because their TTL had elapsed.
    pub stale_evictions: u64,
    /// Still-fresh entries displaced to respect the capacity limit.
    pub lru_evictions: u64,
    /// Stored entry count as of the last mutation (may include expired entries not yet purged).
    pub live_entries: usize,
    /// Maximum number of entries the cache holds.
    pub capacity: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, or `0.0` when no lookup has happened yet.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
/// Unsynchronized core of [`QueryCache`]: LRU storage of [`CacheEntry`] values
/// plus running [`CacheStats`].
///
/// Methods take `&self` / `&mut self` directly; the public wrapper serializes
/// access by holding this behind a `Mutex`.
struct CacheInner<K, V> {
    /// Recency-ordered storage of entries keyed by `K`.
    lru: LruCache<K, CacheEntry<V>>,
    /// Running counters; cloned out by `stats()`.
    stats: CacheStats,
    /// Capacity and TTL clamp policy this cache was built with.
    config: QueryCacheConfig,
}

impl<K: Hash + Eq + Clone, V: Clone> CacheInner<K, V> {
    /// Creates an empty cache sized by `config.capacity`.
    fn new(config: QueryCacheConfig) -> Self {
        let capacity = config.capacity;
        Self {
            lru: LruCache::new(capacity),
            stats: CacheStats {
                capacity: capacity.get(),
                ..Default::default()
            },
            config,
        }
    }

    /// Clamps `ttl` into `[min_ttl, max_ttl]`.
    ///
    /// Unlike `Ord::clamp` this never panics; if the config has
    /// `min_ttl > max_ttl`, the result is `max_ttl`.
    fn clamp_ttl(&self, ttl: Duration) -> Duration {
        ttl.max(self.config.min_ttl).min(self.config.max_ttl)
    }

    /// Returns a clone of the value for `key` if it is present and fresh.
    ///
    /// - Hit: bumps the entry's `hit_count`, marks it most recently used, and
    ///   counts a hit.
    /// - Absent key: counts a miss.
    /// - Expired entry: removes it and counts a stale eviction plus a miss.
    fn get(&mut self, key: &K) -> Option<V> {
        // `peek` leaves recency untouched, so an expired entry is never promoted
        // just before it is dropped.
        let is_stale = match self.lru.peek(key) {
            Some(entry) => !entry.is_fresh(),
            None => {
                self.stats.misses += 1;
                return None;
            }
        };
        if is_stale {
            self.lru.pop(key);
            self.stats.stale_evictions += 1;
            self.stats.misses += 1;
            self.stats.live_entries = self.lru.len();
            return None;
        }
        if let Some(entry) = self.lru.get_mut(key) {
            entry.hit_count += 1;
            let value = entry.value.clone();
            self.stats.hits += 1;
            Some(value)
        } else {
            // Unreachable in practice (the key was just peeked), kept defensively.
            self.stats.misses += 1;
            None
        }
    }

    /// Inserts or replaces `key`, storing `value` with `ttl` clamped to the
    /// configured `[min_ttl, max_ttl]` window.
    ///
    /// Inserting a *new* key while the cache is full evicts the least recently
    /// used entry. That eviction is recorded as a stale eviction if the victim
    /// had already expired, and as an LRU eviction otherwise. Replacing an
    /// existing key displaces nothing, so no eviction is recorded.
    fn put_with_ttl(&mut self, key: K, value: V, ttl: Duration) {
        let ttl = self.clamp_ttl(ttl);
        // `contains` does not change recency. An overwrite updates in place, so
        // only a brand-new key at full capacity pushes another entry out.
        let evicts = !self.lru.contains(&key) && self.lru.len() >= self.lru.cap().get();
        if evicts {
            let oldest_stale = self
                .lru
                .peek_lru()
                .map(|(_, e)| !e.is_fresh())
                .unwrap_or(false);
            if oldest_stale {
                self.stats.stale_evictions += 1;
            } else {
                self.stats.lru_evictions += 1;
            }
        }
        let entry = CacheEntry {
            value,
            inserted_at: Instant::now(),
            ttl,
            hit_count: 0,
        };
        self.lru.put(key, entry);
        self.stats.live_entries = self.lru.len();
    }

    /// Inserts or replaces `key` using the configured `default_ttl`.
    fn put(&mut self, key: K, value: V) {
        let ttl = self.config.default_ttl;
        self.put_with_ttl(key, value, ttl);
    }

    /// Removes `key` and returns its value only if the entry was still fresh.
    ///
    /// An expired entry is still removed, but yields `None` and counts as a
    /// stale eviction.
    fn remove(&mut self, key: &K) -> Option<V> {
        let entry = self.lru.pop(key)?;
        self.stats.live_entries = self.lru.len();
        if entry.is_fresh() {
            Some(entry.value)
        } else {
            self.stats.stale_evictions += 1;
            None
        }
    }

    /// Removes every expired entry and returns how many were removed.
    fn evict_expired(&mut self) -> usize {
        // Keys are collected first because the cache cannot be mutated while
        // it is being iterated.
        let stale_keys: Vec<K> = self
            .lru
            .iter()
            .filter(|(_, entry)| !entry.is_fresh())
            .map(|(k, _)| k.clone())
            .collect();
        let count = stale_keys.len();
        for key in stale_keys {
            self.lru.pop(&key);
        }
        self.stats.stale_evictions += count as u64;
        self.stats.live_entries = self.lru.len();
        count
    }

    /// Borrows the entry for `key` (fresh or not) without changing recency.
    fn peek_entry(&self, key: &K) -> Option<&CacheEntry<V>> {
        self.lru.peek(key)
    }

    /// Returns a copy of the current counters.
    fn stats(&self) -> CacheStats {
        self.stats.clone()
    }

    /// Drops all entries.
    ///
    /// Hit, miss, and eviction counters are preserved; only `live_entries` is reset.
    fn clear(&mut self) {
        self.lru.clear();
        self.stats.live_entries = 0;
    }
}
/// Thread-safe handle to a TTL-aware LRU cache.
///
/// Cloning the handle is cheap and every clone shares the same storage,
/// because the state lives behind an `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct QueryCache<K, V> {
    inner: Arc<Mutex<CacheInner<K, V>>>,
}

// Deliberately opaque: the output shows no contents, which avoids taking the
// lock and avoids requiring `K: Debug` or `V: Debug`.
impl<K, V> std::fmt::Debug for QueryCache<K, V> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("QueryCache");
        builder.finish_non_exhaustive()
    }
}
impl<K, V> QueryCache<K, V>
where
K: Hash + Eq + Clone + Send + 'static,
V: Clone + Send + 'static,
{
pub fn new(config: QueryCacheConfig) -> Self {
Self {
inner: Arc::new(Mutex::new(CacheInner::new(config))),
}
}
pub fn with_defaults() -> Self {
Self::new(QueryCacheConfig::default())
}
pub fn get(&self, key: &K) -> GraphRAGResult<Option<V>> {
let mut guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
Ok(guard.get(key))
}
pub fn put(&self, key: K, value: V) -> GraphRAGResult<()> {
let mut guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
guard.put(key, value);
Ok(())
}
pub fn put_with_ttl(&self, key: K, value: V, ttl: Duration) -> GraphRAGResult<()> {
let mut guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
guard.put_with_ttl(key, value, ttl);
Ok(())
}
pub fn remove(&self, key: &K) -> GraphRAGResult<Option<V>> {
let mut guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
Ok(guard.remove(key))
}
pub fn evict_expired(&self) -> GraphRAGResult<usize> {
let mut guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
Ok(guard.evict_expired())
}
pub fn stats(&self) -> GraphRAGResult<CacheStats> {
let guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
Ok(guard.stats())
}
pub fn peek_entry<F, R>(&self, key: &K, f: F) -> GraphRAGResult<Option<R>>
where
F: FnOnce(&CacheEntry<V>) -> R,
{
let guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
match guard.peek_entry(key) {
Some(entry) if entry.is_fresh() => Ok(Some(f(entry))),
_ => Ok(None),
}
}
pub fn clear(&self) -> GraphRAGResult<()> {
let mut guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
guard.clear();
Ok(())
}
pub fn len(&self) -> GraphRAGResult<usize> {
let guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
Ok(guard.lru.len())
}
pub fn is_empty(&self) -> GraphRAGResult<bool> {
Ok(self.len()? == 0)
}
pub fn capacity(&self) -> GraphRAGResult<usize> {
let guard = self
.inner
.lock()
.map_err(|_| GraphRAGError::InternalError("cache mutex poisoned".to_string()))?;
Ok(guard.lru.cap().get())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Owned-`String` shorthand so the test bodies read as data, not conversions.
    fn s(text: &str) -> String {
        text.to_string()
    }

    /// Cache with `cap` slots, a `ttl_secs` default TTL, and a permissive 1 ms–24 h clamp.
    fn small_cache(cap: usize, ttl_secs: u64) -> QueryCache<String, String> {
        let capacity = NonZeroUsize::new(cap).expect("cap is non-zero");
        QueryCache::new(QueryCacheConfig {
            capacity,
            default_ttl: Duration::from_secs(ttl_secs),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(86_400),
        })
    }

    /// Cache whose default TTL is `ttl_ms` milliseconds, clamped to the window 1 ms–1 h.
    fn short_lived_cache(cap: usize, ttl_ms: u64) -> QueryCache<String, String> {
        QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(cap).expect("should succeed"),
            default_ttl: Duration::from_millis(ttl_ms),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        })
    }

    #[test]
    fn test_basic_put_get() {
        let cache = small_cache(10, 3600);
        cache.put(s("key1"), s("value1")).expect("should succeed");
        let got = cache.get(&s("key1")).expect("should succeed");
        assert_eq!(got, Some(s("value1")));
    }

    #[test]
    fn test_miss_on_absent_key() {
        let cache: QueryCache<String, String> = small_cache(10, 3600);
        let got = cache.get(&s("absent")).expect("should succeed");
        assert_eq!(got, None);
    }

    #[test]
    fn test_overwrite_key() {
        let cache = small_cache(10, 3600);
        for v in ["v1", "v2"].iter() {
            cache.put(s("k"), s(v)).expect("should succeed");
        }
        let got = cache.get(&s("k")).expect("should succeed");
        assert_eq!(got, Some(s("v2")));
    }

    #[test]
    fn test_ttl_expiry() {
        let cache = short_lived_cache(10, 50);
        cache.put(s("k"), s("v")).expect("should succeed");
        assert_eq!(cache.get(&s("k")).expect("should succeed"), Some(s("v")));
        thread::sleep(Duration::from_millis(100));
        assert_eq!(cache.get(&s("k")).expect("should succeed"), None);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = small_cache(3, 3600);
        for (k, v) in [("a", "1"), ("b", "2"), ("c", "3")].iter() {
            cache.put(s(k), s(v)).expect("should succeed");
        }
        // Touch "a" so that "b" becomes the least recently used entry.
        let _ = cache.get(&s("a")).expect("should succeed");
        cache.put(s("d"), s("4")).expect("should succeed");

        assert_eq!(
            cache.get(&s("b")).expect("should succeed"),
            None,
            "b should be evicted"
        );
        assert!(
            cache.get(&s("a")).expect("should succeed").is_some(),
            "a should survive"
        );
        assert!(
            cache.get(&s("d")).expect("should succeed").is_some(),
            "d should be present"
        );
    }

    #[test]
    fn test_remove() {
        let cache = small_cache(10, 3600);
        cache.put(s("k"), s("v")).expect("should succeed");
        let removed = cache.remove(&s("k")).expect("should succeed");
        assert_eq!(removed, Some(s("v")));
        assert_eq!(cache.get(&s("k")).expect("should succeed"), None);
    }

    #[test]
    fn test_evict_expired_batch() {
        let cache = short_lived_cache(20, 50);
        (0..5u32).for_each(|i| {
            cache
                .put(format!("k{}", i), format!("v{}", i))
                .expect("should succeed");
        });
        thread::sleep(Duration::from_millis(100));
        assert_eq!(cache.evict_expired().expect("should succeed"), 5);
        assert_eq!(cache.len().expect("should succeed"), 0);
    }

    #[test]
    fn test_stats_hit_rate() {
        let cache = small_cache(10, 3600);
        cache.put(s("x"), s("1")).expect("should succeed");
        let _ = cache.get(&s("x")).expect("should succeed"); // hit
        let _ = cache.get(&s("y")).expect("should succeed"); // miss

        let stats = cache.stats().expect("should succeed");
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_put_with_explicit_ttl() {
        let cache = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(10).expect("should succeed"),
            default_ttl: Duration::from_secs(3600),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(86_400),
        });
        cache
            .put_with_ttl(s("k"), s("v"), Duration::from_millis(50))
            .expect("should succeed");
        assert!(cache.get(&s("k")).expect("should succeed").is_some());
        thread::sleep(Duration::from_millis(100));
        assert_eq!(cache.get(&s("k")).expect("should succeed"), None);
    }

    #[test]
    fn test_clear() {
        let cache = small_cache(10, 3600);
        cache.put(s("a"), s("1")).expect("should succeed");
        cache.put(s("b"), s("2")).expect("should succeed");
        cache.clear().expect("should succeed");
        assert_eq!(cache.len().expect("should succeed"), 0);
    }

    #[test]
    fn test_thread_safe_concurrent_access() {
        let cache: QueryCache<String, usize> = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(256).expect("should succeed"),
            default_ttl: Duration::from_secs(60),
            min_ttl: Duration::from_millis(1),
            max_ttl: Duration::from_secs(3600),
        });

        let mut handles = Vec::with_capacity(8);
        for t in 0..8_usize {
            let worker_cache = cache.clone();
            handles.push(thread::spawn(move || {
                for i in 0..32_usize {
                    let key = format!("t{}k{}", t, i);
                    worker_cache.put(key.clone(), t * 100 + i).expect("put failed");
                    let _ = worker_cache.get(&key).expect("get failed");
                }
            }));
        }
        handles
            .into_iter()
            .for_each(|h| h.join().expect("thread panicked"));

        let stats = cache.stats().expect("should succeed");
        assert!(stats.hits >= 256, "expected hits ≥256, got {}", stats.hits);
    }

    #[test]
    fn test_peek_entry_metadata() {
        let cache = small_cache(10, 3600);
        cache.put(s("k"), s("v")).expect("should succeed");

        let hit_count = || {
            cache
                .peek_entry(&s("k"), |e| e.hit_count)
                .expect("should succeed")
        };
        assert_eq!(hit_count(), Some(0));
        let _ = cache.get(&s("k")).expect("should succeed");
        assert_eq!(hit_count(), Some(1));
    }

    #[test]
    fn test_ttl_clamping() {
        let cache = QueryCache::new(QueryCacheConfig {
            capacity: NonZeroUsize::new(10).expect("should succeed"),
            default_ttl: Duration::from_secs(60),
            min_ttl: Duration::from_secs(10),
            max_ttl: Duration::from_secs(120),
        });
        // The requested 1 ms is below `min_ttl`, so it must be raised to 10 s.
        cache
            .put_with_ttl(s("k"), s("v"), Duration::from_millis(1))
            .expect("should succeed");
        let stored_ttl = cache
            .peek_entry(&s("k"), |e| e.ttl)
            .expect("should succeed");
        assert_eq!(stored_ttl, Some(Duration::from_secs(10)));
    }
}