use super::types::{RerankRequest, RerankResponse};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// In-memory cache of rerank responses, keyed by a string derived from the
/// request (see `cache_key`).
///
/// Entries are stored in a `HashMap` behind an async read/write lock, and each
/// entry records its own creation time and TTL.
pub struct RerankCache {
/// Cache key -> cached entry, guarded by an async `RwLock`.
entries: tokio::sync::RwLock<HashMap<String, CacheEntry>>,
/// Entry count at which `set` starts making room before inserting.
max_size: usize,
/// TTL assigned to every entry created by `set`.
default_ttl: Duration,
}
/// A cached response plus the timing data used to decide whether it is still
/// fresh.
struct CacheEntry {
/// The stored rerank response; cloned out to the caller on a cache hit.
response: RerankResponse,
/// Moment the entry was inserted.
created_at: Instant,
/// The entry counts as valid while the time elapsed since `created_at` is
/// strictly less than this duration.
ttl: Duration,
}
impl RerankCache {
    /// Creates an empty cache.
    ///
    /// * `max_size` - maximum number of entries kept at once; `0` disables
    ///   storing altogether.
    /// * `default_ttl` - lifetime given to every entry inserted through
    ///   [`RerankCache::set`].
    pub fn new(max_size: usize, default_ttl: Duration) -> Self {
        Self {
            entries: tokio::sync::RwLock::new(HashMap::new()),
            max_size,
            default_ttl,
        }
    }

    /// Derives the cache key for `request`.
    ///
    /// The key is the hex form of a 64-bit hash over the model, the query, the
    /// text of every document (in order) and `top_n`, prefixed with `rerank:`.
    /// Request fields other than these do not take part in the key.
    ///
    /// NOTE(review): only the hash is stored, so two distinct requests whose
    /// hashes collide would share one entry. Confirm that this is acceptable
    /// for the deployment, or store the request alongside the response.
    fn cache_key(request: &RerankRequest) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        request.model.hash(&mut hasher);
        request.query.hash(&mut hasher);
        for doc in &request.documents {
            doc.get_text().hash(&mut hasher);
        }
        request.top_n.hash(&mut hasher);
        format!("rerank:{:x}", hasher.finish())
    }

    /// Looks up the cached response for `request`.
    ///
    /// Returns a clone of the stored response when an entry exists and its age
    /// is still below its TTL, and `None` otherwise. Only a read lock is
    /// taken, so expired entries are left in place here; they are purged by
    /// [`RerankCache::set`] when it needs room.
    pub async fn get(&self, request: &RerankRequest) -> Option<RerankResponse> {
        let key = Self::cache_key(request);
        let entries = self.entries.read().await;
        entries
            .get(&key)
            .filter(|entry| entry.created_at.elapsed() < entry.ttl)
            .map(|entry| entry.response.clone())
    }

    /// Stores a clone of `response` under the key derived from `request`,
    /// stamped with the current time and the cache's default TTL.
    ///
    /// An existing entry for the same key is overwritten in place. When a new
    /// key would push the cache past `max_size`, expired entries are purged
    /// first; if that frees nothing, the oldest entry is evicted. A cache
    /// created with `max_size == 0` stores nothing.
    pub async fn set(&self, request: &RerankRequest, response: &RerankResponse) {
        // A zero-capacity cache must stay empty; without this guard the insert
        // below would always leave one entry behind.
        if self.max_size == 0 {
            return;
        }

        let key = Self::cache_key(request);
        let mut entries = self.entries.write().await;

        // Overwriting an existing key does not grow the map, so room is only
        // needed when a *new* key arrives at capacity.
        if entries.len() >= self.max_size && !entries.contains_key(&key) {
            entries.retain(|_, entry| entry.created_at.elapsed() < entry.ttl);

            // Nothing had expired: evict the entry inserted longest ago rather
            // than whichever one the hash map happens to iterate first.
            if entries.len() >= self.max_size
                && let Some(oldest_key) = entries
                    .iter()
                    .min_by_key(|(_, entry)| entry.created_at)
                    .map(|(existing_key, _)| existing_key.clone())
            {
                entries.remove(&oldest_key);
            }
        }

        entries.insert(
            key,
            CacheEntry {
                response: response.clone(),
                created_at: Instant::now(),
                ttl: self.default_ttl,
            },
        );
    }

    /// Removes every entry from the cache.
    pub async fn clear(&self) {
        self.entries.write().await.clear();
    }

    /// Returns a snapshot of the cache's occupancy.
    ///
    /// `total_entries` counts everything currently stored, including expired
    /// entries that have not been purged yet; `valid_entries` counts only those
    /// whose age is still below their TTL.
    pub async fn stats(&self) -> RerankCacheStats {
        let entries = self.entries.read().await;
        let valid_entries = entries
            .values()
            .filter(|entry| entry.created_at.elapsed() < entry.ttl)
            .count();
        RerankCacheStats {
            total_entries: entries.len(),
            valid_entries,
            max_size: self.max_size,
        }
    }
}
/// Point-in-time occupancy figures for a rerank cache.
///
/// Implements `PartialEq`/`Eq` so that snapshots can be compared directly,
/// e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RerankCacheStats {
    /// Number of entries currently stored, whether or not they have expired.
    pub total_entries: usize,
    /// Number of stored entries that had not yet expired when the snapshot
    /// was taken.
    pub valid_entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}