use crate::Vector;
use blake3::Hasher;
use lru::LruCache;
use parking_lot::RwLock;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// One cache entry: a stored result list plus bookkeeping about its age and use.
#[derive(Clone, Debug)]
struct CachedResult {
    /// The stored `(String, f32)` pairs, returned verbatim on a cache hit.
    /// NOTE(review): presumably (result identifier, score) — confirm with callers.
    results: Vec<(String, f32)>,
    /// Moment this entry was created; the reference point for TTL checks.
    cached_at: Instant,
    /// How many times `record_hit` has been called on this entry.
    hit_count: usize,
}

impl CachedResult {
    /// Wraps `results` in a fresh entry stamped with the current time and a
    /// hit counter of zero.
    fn new(results: Vec<(String, f32)>) -> Self {
        let cached_at = Instant::now();
        Self {
            results,
            cached_at,
            hit_count: 0,
        }
    }

    /// Returns `true` once the entry has existed for strictly longer than `ttl`.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Bumps the per-entry hit counter by one.
    fn record_hit(&mut self) {
        self.hit_count += 1;
    }
}
/// Tunable settings for a `QueryCache`.
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Capacity of the cache, in entries. `QueryCache::new` panics if this is zero.
    pub max_entries: usize,
    /// Maximum age of an entry; anything older is treated as expired.
    pub ttl: Duration,
    /// NOTE(review): not read anywhere in the visible code — presumably
    /// reserved for approximate query matching; confirm before relying on it.
    pub enable_fuzzy_matching: bool,
    /// NOTE(review): not read anywhere in the visible code — presumably the
    /// similarity cut-off used when fuzzy matching is enabled; confirm.
    pub fuzzy_threshold: f32,
    /// When `false`, the cache skips all statistics counter updates.
    pub enable_stats: bool,
}

impl Default for QueryCacheConfig {
    /// Defaults: 10 000 entries, a 300-second TTL, fuzzy matching off
    /// (threshold 0.95) and statistics on.
    fn default() -> Self {
        const DEFAULT_MAX_ENTRIES: usize = 10000;
        const DEFAULT_TTL_SECS: u64 = 300;
        const DEFAULT_FUZZY_THRESHOLD: f32 = 0.95;

        QueryCacheConfig {
            max_entries: DEFAULT_MAX_ENTRIES,
            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
            enable_fuzzy_matching: false,
            fuzzy_threshold: DEFAULT_FUZZY_THRESHOLD,
            enable_stats: true,
        }
    }
}
/// Counters describing how a `QueryCache` has been used.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Number of lookups performed.
    pub total_queries: u64,
    /// Lookups answered from the cache.
    pub cache_hits: u64,
    /// Lookups that found no usable entry (absent or expired).
    pub cache_misses: u64,
    /// Insertions counted as having displaced an entry from a full cache.
    pub evictions: u64,
    /// Entries discarded because they outlived the configured TTL.
    pub expirations: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, i.e. `cache_hits / total_queries`.
    ///
    /// Returns `0.0` when no lookups have been recorded, avoiding a division
    /// by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
/// TTL-bounded LRU cache for query results.
///
/// Entries are keyed by a 64-bit digest of the query vector's `f32` components
/// together with the requested `k` (see `generate_key`). The entry store and
/// the statistics are each guarded by their own `RwLock`.
pub struct QueryCache {
    /// LRU store mapping a query digest to its cached results.
    cache: Arc<RwLock<LruCache<u64, CachedResult>>>,
    /// Settings captured at construction time.
    config: QueryCacheConfig,
    /// Usage counters; only touched when `config.enable_stats` is set.
    stats: Arc<RwLock<QueryCacheStats>>,
}

impl QueryCache {
    /// Creates an empty cache with room for `config.max_entries` entries.
    ///
    /// # Panics
    ///
    /// Panics if `config.max_entries` is zero.
    pub fn new(config: QueryCacheConfig) -> Self {
        let capacity =
            NonZeroUsize::new(config.max_entries).expect("cache max_entries must be non-zero");
        Self {
            cache: Arc::new(RwLock::new(LruCache::new(capacity))),
            config,
            stats: Arc::new(RwLock::new(QueryCacheStats::default())),
        }
    }

    /// Derives the cache key for `(query, k)`.
    ///
    /// The little-endian bytes of every `f32` component, followed by the
    /// little-endian bytes of `k`, are fed to BLAKE3; the first eight bytes of
    /// the digest are read as a little-endian `u64`. Because the digest is
    /// truncated to 64 bits, two distinct inputs can in principle share a key.
    fn generate_key(&self, query: &Vector, k: usize) -> u64 {
        let mut hasher = Hasher::new();
        let components = query.as_f32();
        for &component in &components {
            hasher.update(&component.to_le_bytes());
        }
        hasher.update(&k.to_le_bytes());

        let digest = hasher.finalize();
        let mut prefix = [0u8; 8];
        prefix.copy_from_slice(&digest.as_bytes()[..8]);
        u64::from_le_bytes(prefix)
    }

    /// Looks up the cached results for `(query, k)`.
    ///
    /// Returns a clone of the stored results on a hit. Returns `None` when
    /// there is no entry, or when the entry is older than the configured TTL,
    /// in which case it is removed from the cache.
    ///
    /// When statistics are enabled, each call adds one to `total_queries` and
    /// to exactly one of `cache_hits` / `cache_misses`; an expired entry also
    /// adds one to `expirations`.
    pub fn get(&self, query: &Vector, k: usize) -> Option<Vec<(String, f32)>> {
        let key = self.generate_key(query, k);

        let (found, expired) = {
            let mut cache = self.cache.write();
            let outcome = match cache.get_mut(&key) {
                Some(entry) => {
                    if entry.is_expired(self.config.ttl) {
                        (None, true)
                    } else {
                        entry.record_hit();
                        (Some(entry.results.clone()), false)
                    }
                }
                None => (None, false),
            };
            if outcome.1 {
                cache.pop(&key);
            }
            outcome
        };

        // Update every counter for this lookup under one lock acquisition so
        // that `total_queries` never disagrees with hits + misses.
        if self.config.enable_stats {
            let mut stats = self.stats.write();
            stats.total_queries += 1;
            if found.is_some() {
                stats.cache_hits += 1;
            } else {
                stats.cache_misses += 1;
            }
            if expired {
                stats.expirations += 1;
            }
        }

        found
    }

    /// Stores `results` for `(query, k)`, replacing any existing entry for the
    /// same key.
    ///
    /// When statistics are enabled, `evictions` is incremented only if the
    /// insertion displaces another entry: the cache is full and the key is not
    /// already present. Overwriting an existing key evicts nothing.
    pub fn put(&self, query: &Vector, k: usize, results: Vec<(String, f32)>) {
        let key = self.generate_key(query, k);
        let entry = CachedResult::new(results);

        let displaced = {
            let mut cache = self.cache.write();
            let displaced = cache.len() >= self.config.max_entries && !cache.contains(&key);
            cache.put(key, entry);
            displaced
        };

        if displaced && self.config.enable_stats {
            self.stats.write().evictions += 1;
        }
    }

    /// Removes every entry. Statistics are left untouched.
    pub fn clear(&self) {
        self.cache.write().clear();
    }

    /// Returns a snapshot (clone) of the current statistics.
    pub fn get_stats(&self) -> QueryCacheStats {
        self.stats.read().clone()
    }

    /// Resets all statistics counters to zero.
    pub fn reset_stats(&self) {
        *self.stats.write() = QueryCacheStats::default();
    }

    /// Number of entries currently stored, including expired entries that
    /// have not been removed yet.
    pub fn len(&self) -> usize {
        self.cache.read().len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.cache.read().is_empty()
    }

    /// Removes every entry older than the configured TTL and returns how many
    /// were removed.
    ///
    /// When statistics are enabled, the removed count is added to
    /// `expirations`.
    pub fn cleanup_expired(&self) -> usize {
        let ttl = self.config.ttl;

        let removed = {
            let mut cache = self.cache.write();
            // Collect first: the cache cannot be mutated while it is iterated.
            let stale: Vec<u64> = cache
                .iter()
                .filter(|(_, entry)| entry.is_expired(ttl))
                .map(|(key, _)| *key)
                .collect();
            for key in &stale {
                cache.pop(key);
            }
            stale.len()
        };

        if self.config.enable_stats && removed > 0 {
            self.stats.write().expirations += removed as u64;
        }
        removed
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed-error result so individual tests can end with `Ok(())`.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Builds a cache whose entries expire after 100 ms; all other settings
    /// are the defaults.
    fn short_lived_cache() -> QueryCache {
        QueryCache::new(QueryCacheConfig {
            ttl: Duration::from_millis(100),
            ..Default::default()
        })
    }

    /// Sleeps long enough (150 ms) for entries of `short_lived_cache` to expire.
    fn wait_past_ttl() {
        std::thread::sleep(Duration::from_millis(150));
    }

    /// A miss before insertion, then a hit returning the stored results.
    #[test]
    fn test_query_cache_basic() -> Result<()> {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let probe = Vector::new(vec![1.0, 2.0, 3.0]);
        let stored = vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)];

        assert!(cache.get(&probe, 5).is_none());

        cache.put(&probe, 5, stored.clone());
        let fetched = cache.get(&probe, 5).expect("cache should have results");

        assert_eq!(fetched.len(), 2);
        let (first_uri, first_score) = &fetched[0];
        assert_eq!(first_uri.as_str(), "uri1");
        assert_eq!(*first_score, 0.9);
        Ok(())
    }

    /// An entry is served while fresh and dropped once its TTL has elapsed.
    #[test]
    fn test_query_cache_expiration() {
        let cache = short_lived_cache();
        let probe = Vector::new(vec![1.0, 2.0, 3.0]);

        cache.put(&probe, 5, vec![("uri1".to_string(), 0.9)]);
        assert!(cache.get(&probe, 5).is_some());

        wait_past_ttl();
        assert!(cache.get(&probe, 5).is_none());
    }

    /// One miss followed by two hits is reflected in the counters and hit rate.
    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let probe = Vector::new(vec![1.0, 2.0, 3.0]);

        // First lookup misses; the two lookups after the insert hit.
        cache.get(&probe, 5);
        cache.put(&probe, 5, vec![("uri1".to_string(), 0.9)]);
        cache.get(&probe, 5);
        cache.get(&probe, 5);

        let snapshot = cache.get_stats();
        assert_eq!(snapshot.total_queries, 3);
        assert_eq!(snapshot.cache_hits, 2);
        assert_eq!(snapshot.cache_misses, 1);
        assert_eq!(snapshot.hit_rate(), 2.0 / 3.0);
    }

    /// `cleanup_expired` removes every stale entry and reports the count.
    #[test]
    fn test_query_cache_cleanup() {
        let cache = short_lived_cache();

        for index in 0..5 {
            let probe = Vector::new(vec![index as f32, 0.0, 0.0]);
            cache.put(&probe, 5, vec![(format!("uri{}", index), 0.9)]);
        }
        assert_eq!(cache.len(), 5);

        wait_past_ttl();

        let removed = cache.cleanup_expired();
        assert_eq!(removed, 5);
        assert_eq!(cache.len(), 0);
    }

    /// The same query with different `k` values occupies separate entries.
    #[test]
    fn test_query_cache_different_k() -> Result<()> {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let probe = Vector::new(vec![1.0, 2.0, 3.0]);

        cache.put(&probe, 5, vec![("uri1".to_string(), 0.9)]);
        cache.put(
            &probe,
            10,
            vec![("uri1".to_string(), 0.9), ("uri2".to_string(), 0.8)],
        );

        let for_k5 = cache.get(&probe, 5).expect("cache k5 should have results");
        let for_k10 = cache
            .get(&probe, 10)
            .expect("cache k10 should have results");

        assert_eq!(for_k5.len(), 1);
        assert_eq!(for_k10.len(), 2);
        Ok(())
    }
}