use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::RwLock;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::retrievers::BaseRetriever;
/// A single cached retrieval result plus the bookkeeping needed for
/// TTL expiry and LRU eviction.
#[derive(Debug)]
pub struct CacheEntry {
// Documents returned by the inner retriever for this query key.
pub documents: Vec<Document>,
// Insertion time; drives TTL expiry via `is_expired`.
pub created_at: Instant,
// Intended as the LRU eviction key (`evict_lru` removes the minimum).
// NOTE(review): in this file it is only ever set at creation — cache hits
// do not refresh it, so eviction order degrades to FIFO. Confirm intent.
pub last_accessed: Instant,
// Hits served from this entry; atomic so it can be bumped through a
// shared reference while the cache is held under a read lock.
pub hit_count: AtomicUsize,
}
impl CacheEntry {
fn new(documents: Vec<Document>) -> Self {
let now = Instant::now();
Self {
documents,
created_at: now,
last_accessed: now,
hit_count: AtomicUsize::new(0),
}
}
fn is_expired(&self, ttl: Duration) -> bool {
self.created_at.elapsed() > ttl
}
}
/// Tunables for `CachingRetriever`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
// Maximum number of cached queries; overflow is evicted by `evict_lru`.
pub max_entries: usize,
// Time-to-live after which a cached entry is considered stale.
pub ttl: Duration,
// When true, cache keys are trimmed and lowercased before lookup/insert.
pub normalize_queries: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 1000,
ttl: Duration::from_secs(300),
normalize_queries: true,
}
}
}
impl CacheConfig {
    /// Returns `self` with the entry cap replaced.
    pub fn with_max_entries(self, max_entries: usize) -> Self {
        Self { max_entries, ..self }
    }

    /// Returns `self` with the time-to-live replaced.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Returns `self` with query normalization toggled.
    pub fn with_normalize_queries(self, normalize: bool) -> Self {
        Self {
            normalize_queries: normalize,
            ..self
        }
    }
}
/// Snapshot of the cache's counters, as returned by `cache_stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
// Requests served from the cache.
pub hits: usize,
// Requests that fell through to the inner retriever.
pub misses: usize,
// Entries removed by TTL expiry, LRU overflow, invalidation, or clear.
pub evictions: usize,
// Current number of cached entries; computed at snapshot time.
pub size: usize,
}
/// Decorator that memoizes `BaseRetriever::get_relevant_documents`
/// results, keyed by the (optionally normalized) query string.
pub struct CachingRetriever {
// Wrapped retriever consulted on every cache miss.
inner: Arc<dyn BaseRetriever>,
// Query key -> cached entry.
cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
// Immutable after construction; read without locking.
config: CacheConfig,
// Hit/miss/eviction counters; `size` is filled in on demand.
stats: Arc<RwLock<CacheStats>>,
}
impl CachingRetriever {
    /// Wraps `inner` with an empty cache configured by `config`.
    pub fn new(inner: Arc<dyn BaseRetriever>, config: CacheConfig) -> Self {
        Self {
            inner,
            cache: Arc::new(RwLock::new(HashMap::new())),
            config,
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Wraps `inner` using `CacheConfig::default()`.
    pub fn with_defaults(inner: Arc<dyn BaseRetriever>) -> Self {
        Self::new(inner, CacheConfig::default())
    }

    /// Canonicalizes `query` into a cache key: trimmed and lowercased when
    /// `normalize_queries` is set, verbatim otherwise.
    fn normalize_query(&self, query: &str) -> String {
        if self.config.normalize_queries {
            query.trim().to_lowercase()
        } else {
            query.to_string()
        }
    }

    /// Removes the cached entry for `query` (after normalization),
    /// counting the removal as an eviction. Unknown keys are a no-op.
    pub async fn invalidate(&self, query: &str) {
        let key = self.normalize_query(query);
        let mut cache = self.cache.write().await;
        if cache.remove(&key).is_some() {
            let mut stats = self.stats.write().await;
            stats.evictions += 1;
        }
    }

    /// Drops every cached entry; each removal counts as an eviction.
    pub async fn clear_cache(&self) {
        let mut cache = self.cache.write().await;
        let mut stats = self.stats.write().await;
        stats.evictions += cache.len();
        cache.clear();
    }

    /// Returns a snapshot of the counters plus the current cache size.
    ///
    /// BUGFIX: locks are now acquired in the same order as every other
    /// method in this impl (`cache` before `stats`). The previous
    /// `stats`-then-`cache` order could deadlock against `evict_lru` /
    /// `evict_expired` / `clear_cache`, which hold the `cache` write lock
    /// while waiting for the `stats` write lock: with tokio's fair RwLock,
    /// one task holding `stats.read` and awaiting `cache.read` plus one
    /// task holding `cache.write` and awaiting `stats.write` wait on each
    /// other forever.
    pub async fn cache_stats(&self) -> CacheStats {
        let cache = self.cache.read().await;
        let stats = self.stats.read().await;
        CacheStats {
            hits: stats.hits,
            misses: stats.misses,
            evictions: stats.evictions,
            size: cache.len(),
        }
    }

    /// Removes all entries older than the configured TTL, recording the
    /// removals as evictions. Lock order: cache, then stats.
    async fn evict_expired(&self) {
        let mut cache = self.cache.write().await;
        let mut stats = self.stats.write().await;
        let before = cache.len();
        cache.retain(|_, entry| !entry.is_expired(self.config.ttl));
        stats.evictions += before - cache.len();
    }

    /// Shrinks the cache to at most `max_entries`, repeatedly removing the
    /// entry with the oldest `last_accessed`. The O(n) scan per removal is
    /// acceptable because normally only one overflow entry is evicted per
    /// insert. Lock order: cache, then stats.
    async fn evict_lru(&self) {
        let mut cache = self.cache.write().await;
        if cache.len() <= self.config.max_entries {
            return;
        }
        let mut stats = self.stats.write().await;
        while cache.len() > self.config.max_entries {
            let lru_key = cache
                .iter()
                .min_by_key(|(_, entry)| entry.last_accessed)
                .map(|(key, _)| key.clone());
            match lru_key {
                Some(key) => {
                    cache.remove(&key);
                    stats.evictions += 1;
                }
                // Unreachable while len > 0, but safer than looping forever.
                None => break,
            }
        }
    }
}
#[async_trait]
impl BaseRetriever for CachingRetriever {
async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
let key = self.normalize_query(query);
{
let cache = self.cache.read().await;
if let Some(entry) = cache.get(&key) {
if !entry.is_expired(self.config.ttl) {
entry.hit_count.fetch_add(1, Ordering::Relaxed);
let docs = entry.documents.clone();
drop(cache);
let mut stats = self.stats.write().await;
stats.hits += 1;
return Ok(docs);
}
}
}
let docs = self.inner.get_relevant_documents(query).await?;
self.evict_expired().await;
{
let mut cache = self.cache.write().await;
cache.insert(key, CacheEntry::new(docs.clone()));
}
self.evict_lru().await;
{
let mut stats = self.stats.write().await;
stats.misses += 1;
}
Ok(docs)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicUsize as StdAtomicUsize;
// Stub retriever that always returns the same documents and counts how
// many times it was actually invoked, so tests can observe cache hits.
struct CountingRetriever {
docs: Vec<Document>,
call_count: StdAtomicUsize,
}
impl CountingRetriever {
fn new(docs: Vec<Document>) -> Self {
Self {
docs,
call_count: StdAtomicUsize::new(0),
}
}
// Number of times `get_relevant_documents` reached this inner retriever.
fn calls(&self) -> usize {
self.call_count.load(Ordering::Relaxed)
}
}
#[async_trait]
impl BaseRetriever for CountingRetriever {
async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
self.call_count.fetch_add(1, Ordering::Relaxed);
Ok(self.docs.clone())
}
}
// Convenience: build Documents from plain string slices.
fn make_docs(contents: &[&str]) -> Vec<Document> {
contents.iter().map(|c| Document::new(*c)).collect()
}
// A repeated identical query must be served from cache (inner called once).
#[tokio::test]
async fn test_cache_hit_avoids_inner_call() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let caching = CachingRetriever::with_defaults(inner.clone());
let docs1 = caching.get_relevant_documents("hello").await.unwrap();
let docs2 = caching.get_relevant_documents("hello").await.unwrap();
assert_eq!(docs1, docs2);
assert_eq!(inner.calls(), 1); }
// Distinct queries must each reach the inner retriever.
#[tokio::test]
async fn test_different_queries_cause_separate_calls() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let caching = CachingRetriever::with_defaults(inner.clone());
caching.get_relevant_documents("query_a").await.unwrap();
caching.get_relevant_documents("query_b").await.unwrap();
assert_eq!(inner.calls(), 2);
}
// With normalization on, case differences map to the same cache key.
#[tokio::test]
async fn test_query_normalization_lowercase() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let config = CacheConfig::default().with_normalize_queries(true);
let caching = CachingRetriever::new(inner.clone(), config);
caching.get_relevant_documents("Hello World").await.unwrap();
caching.get_relevant_documents("hello world").await.unwrap();
assert_eq!(inner.calls(), 1);
}
// With normalization on, surrounding whitespace is ignored for the key.
#[tokio::test]
async fn test_query_normalization_trim() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let config = CacheConfig::default().with_normalize_queries(true);
let caching = CachingRetriever::new(inner.clone(), config);
caching.get_relevant_documents(" hello ").await.unwrap();
caching.get_relevant_documents("hello").await.unwrap();
assert_eq!(inner.calls(), 1);
}
// With normalization off, case-differing queries are separate keys.
#[tokio::test]
async fn test_normalization_disabled() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let config = CacheConfig::default().with_normalize_queries(false);
let caching = CachingRetriever::new(inner.clone(), config);
caching.get_relevant_documents("Hello").await.unwrap();
caching.get_relevant_documents("hello").await.unwrap();
assert_eq!(inner.calls(), 2); }
// An entry older than the TTL must not be served; real sleep (100ms)
// comfortably exceeds the 50ms TTL.
#[tokio::test]
async fn test_ttl_expiration() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let config = CacheConfig::default().with_ttl(Duration::from_millis(50));
let caching = CachingRetriever::new(inner.clone(), config);
caching.get_relevant_documents("query").await.unwrap();
assert_eq!(inner.calls(), 1);
tokio::time::sleep(Duration::from_millis(100)).await;
caching.get_relevant_documents("query").await.unwrap();
assert_eq!(inner.calls(), 2); }
// invalidate() forces the next identical query back to the inner retriever.
#[tokio::test]
async fn test_invalidate_removes_entry() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let caching = CachingRetriever::with_defaults(inner.clone());
caching.get_relevant_documents("query").await.unwrap();
assert_eq!(inner.calls(), 1);
caching.invalidate("query").await;
caching.get_relevant_documents("query").await.unwrap();
assert_eq!(inner.calls(), 2);
}
// clear_cache() empties everything, so both queries miss again.
#[tokio::test]
async fn test_clear_cache() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let caching = CachingRetriever::with_defaults(inner.clone());
caching.get_relevant_documents("a").await.unwrap();
caching.get_relevant_documents("b").await.unwrap();
assert_eq!(inner.calls(), 2);
caching.clear_cache().await;
caching.get_relevant_documents("a").await.unwrap();
caching.get_relevant_documents("b").await.unwrap();
assert_eq!(inner.calls(), 4);
}
// Four requests over two keys: 2 misses ("a" first time, "b"), 2 hits
// ("a" twice more), and both keys resident in the cache.
#[tokio::test]
async fn test_cache_stats_hits_and_misses() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let caching = CachingRetriever::with_defaults(inner.clone());
caching.get_relevant_documents("a").await.unwrap(); caching.get_relevant_documents("a").await.unwrap(); caching.get_relevant_documents("b").await.unwrap(); caching.get_relevant_documents("a").await.unwrap();
let stats = caching.cache_stats().await;
assert_eq!(stats.misses, 2);
assert_eq!(stats.hits, 2);
assert_eq!(stats.size, 2);
}
// With max_entries = 2, inserting a third key must evict one entry.
// (Checks size/eviction count only, not WHICH entry was evicted.)
#[tokio::test]
async fn test_lru_eviction() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let config = CacheConfig::default().with_max_entries(2);
let caching = CachingRetriever::new(inner.clone(), config);
caching.get_relevant_documents("a").await.unwrap();
caching.get_relevant_documents("b").await.unwrap();
caching.get_relevant_documents("c").await.unwrap();
let stats = caching.cache_stats().await;
assert_eq!(stats.size, 2);
assert!(stats.evictions >= 1);
}
// Cached content must match what the inner retriever produced, in order.
#[tokio::test]
async fn test_cache_returns_correct_documents() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["alpha", "beta"])));
let caching = CachingRetriever::with_defaults(inner.clone());
let docs = caching.get_relevant_documents("test").await.unwrap();
assert_eq!(docs.len(), 2);
assert_eq!(docs[0].page_content, "alpha");
assert_eq!(docs[1].page_content, "beta");
let docs2 = caching.get_relevant_documents("test").await.unwrap();
assert_eq!(docs, docs2);
}
// An empty result set is a valid cacheable value (negative caching).
#[tokio::test]
async fn test_empty_results_are_cached() {
let inner = Arc::new(CountingRetriever::new(vec![]));
let caching = CachingRetriever::with_defaults(inner.clone());
let docs1 = caching.get_relevant_documents("empty").await.unwrap();
let docs2 = caching.get_relevant_documents("empty").await.unwrap();
assert!(docs1.is_empty());
assert!(docs2.is_empty());
assert_eq!(inner.calls(), 1); }
// Invalidating an absent key must not count as an eviction.
#[tokio::test]
async fn test_invalidate_nonexistent_key_is_noop() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let caching = CachingRetriever::with_defaults(inner.clone());
caching.invalidate("nonexistent").await;
let stats = caching.cache_stats().await;
assert_eq!(stats.evictions, 0);
}
// After the TTL elapses, the next miss path sweeps both stale entries,
// leaving only the fresh "c" entry resident.
#[tokio::test]
async fn test_evict_expired_cleans_up() {
let inner = Arc::new(CountingRetriever::new(make_docs(&["doc1"])));
let config = CacheConfig::default().with_ttl(Duration::from_millis(50));
let caching = CachingRetriever::new(inner.clone(), config);
caching.get_relevant_documents("a").await.unwrap();
caching.get_relevant_documents("b").await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
caching.get_relevant_documents("c").await.unwrap();
let stats = caching.cache_stats().await;
assert_eq!(stats.size, 1);
assert!(stats.evictions >= 2);
}
}