use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use tracing::{debug, info};
use uuid::Uuid;
use crate::content_vectorizer::SemanticSearchResult;
/// Tunable settings for a `QueryCache`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryCacheConfig {
/// Entry count at which `QueryCache::cache_results` evicts one entry before inserting.
pub max_cache_size: usize,
/// Age in minutes after which a cached entry is treated as expired.
pub ttl_minutes: i64,
/// Minimum cosine similarity a cached query vector must reach to count as a similar-query hit.
pub similarity_threshold: f32,
/// When `true`, `QueryCache::search` records the query text in its pattern maps on a miss.
pub enable_prefetching: bool,
/// NOTE(review): not read anywhere in the visible code; presumably a cap on
/// prefetched query variations. Confirm against callers.
pub max_prefetch_variations: usize,
/// NOTE(review): not read anywhere in the visible code; statistics are
/// currently recorded regardless of this flag. Confirm intended use.
pub enable_stats: bool,
}
impl Default for QueryCacheConfig {
fn default() -> Self {
Self {
max_cache_size: 1000,
ttl_minutes: 30,
similarity_threshold: 0.85,
enable_prefetching: true,
max_prefetch_variations: 5,
enable_stats: true,
}
}
}
/// One cached query together with its results and usage bookkeeping.
///
/// The bookkeeping fields are atomics so they can be updated through a shared
/// reference (see `mark_accessed`, which takes `&self`).
#[derive(Debug, Serialize, Deserialize)]
pub struct CachedQuery {
/// Identifier of this entry; `CachedQuery::new` assigns a random v4 UUID.
pub id: Uuid,
/// Query text exactly as supplied when the entry was created.
pub query_text: String,
/// Embedding of the query, compared by cosine similarity in `similarity_with`.
pub query_vector: Vec<f32>,
/// Search results stored for this query.
pub results: Vec<SemanticSearchResult>,
/// Creation time; `is_expired` measures the entry's age from this instant.
pub cached_at: DateTime<Utc>,
// Unix timestamp (seconds) of creation or of the latest `mark_accessed` call.
last_accessed: AtomicU64,
// Number of `mark_accessed` calls so far; starts at 0.
access_count: AtomicU64,
/// Caller-supplied hash; lookups only consider entries whose hash is equal.
/// NOTE(review): presumably a hash of the search parameters rather than of
/// the query text. Confirm against callers.
pub params_hash: u64,
/// Session that produced the entry, if any; used by `QueryCache::invalidate_session`.
pub session_id: Option<Uuid>,
// Bit pattern of an `f32` efficiency score, widened to `u64` for atomic storage.
efficiency_score_bits: AtomicU64,
}
impl CachedQuery {
pub fn new(
query_text: String,
query_vector: Vec<f32>,
results: Vec<SemanticSearchResult>,
params_hash: u64,
session_id: Option<Uuid>,
) -> Self {
let now = Utc::now();
let now_timestamp = now.timestamp() as u64;
Self {
id: Uuid::new_v4(),
query_text,
query_vector,
results,
cached_at: now,
last_accessed: AtomicU64::new(now_timestamp),
access_count: AtomicU64::new(0),
params_hash,
session_id,
efficiency_score_bits: AtomicU64::new(1.0f32.to_bits() as u64),
}
}
pub fn is_expired(&self, ttl_minutes: i64) -> bool {
let ttl_duration = Duration::minutes(ttl_minutes);
Utc::now() - self.cached_at > ttl_duration
}
pub fn mark_accessed(&self) {
let now = Utc::now();
let now_timestamp = now.timestamp() as u64;
self.last_accessed.store(now_timestamp, Ordering::Relaxed);
self.access_count.fetch_add(1, Ordering::Relaxed);
let hours_since_cached = (now - self.cached_at).num_hours().max(1) as f32;
let recency_factor = 1.0 / (1.0 + hours_since_cached / 24.0);
loop {
let old_bits = self.efficiency_score_bits.load(Ordering::Relaxed);
let count = self.access_count.load(Ordering::Relaxed);
let frequency_factor = (count as f32).ln().max(1.0);
let score = recency_factor * frequency_factor;
if self
.efficiency_score_bits
.compare_exchange_weak(
old_bits,
score.to_bits() as u64,
Ordering::Relaxed,
Ordering::Relaxed,
)
.is_ok()
{
break;
}
}
}
pub fn efficiency_score(&self) -> f32 {
f32::from_bits(self.efficiency_score_bits.load(Ordering::Relaxed) as u32)
}
pub fn similarity_with(&self, other_vector: &[f32]) -> f32 {
if self.query_vector.len() != other_vector.len() {
return 0.0;
}
let dot_product: f32 = self
.query_vector
.iter()
.zip(other_vector.iter())
.map(|(a, b)| a * b)
.sum();
let norm_a: f32 = self.query_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = other_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
0.0
} else {
dot_product / (norm_a * norm_b)
}
}
}
/// Live counters and gauges describing cache activity.
///
/// Counters are atomics so they can be updated through a shared reference.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueryCacheStats {
/// Incremented once per `QueryCache::search` call.
pub total_queries: AtomicU64,
/// Incremented by `record_hit`.
pub cache_hits: AtomicU64,
/// Incremented by `record_miss`.
pub cache_misses: AtomicU64,
/// Total entries removed because they were expired.
pub expired_removed: AtomicU64,
/// Total entries removed to make room for a new entry.
pub evicted_entries: AtomicU64,
// Bit pattern of the running mean hit similarity (`f32`), widened to `u64`.
avg_hit_similarity_bits: AtomicU64,
/// Entry count as of the last insert, cleanup, clear or session invalidation.
pub current_cache_size: AtomicUsize,
/// Rough memory estimate in bytes, refreshed together with `current_cache_size`.
pub estimated_memory_bytes: AtomicUsize,
/// Initialised to `0.0` and not written again in the visible code;
/// `snapshot` computes the hit rate itself instead of reading this field.
pub hit_rate: f32,
/// Initialised to `0.0` and not written again in the visible code;
/// `snapshot` reports a fixed value instead of reading this field.
pub avg_time_saved_ms: f32,
}
impl Default for QueryCacheStats {
fn default() -> Self {
Self::new()
}
}
impl QueryCacheStats {
pub fn new() -> Self {
Self {
total_queries: AtomicU64::new(0),
cache_hits: AtomicU64::new(0),
cache_misses: AtomicU64::new(0),
expired_removed: AtomicU64::new(0),
evicted_entries: AtomicU64::new(0),
avg_hit_similarity_bits: AtomicU64::new(0.0f32.to_bits() as u64),
current_cache_size: AtomicUsize::new(0),
estimated_memory_bytes: AtomicUsize::new(0),
hit_rate: 0.0,
avg_time_saved_ms: 0.0,
}
}
pub fn record_hit(&self, similarity: f32) {
self.cache_hits.fetch_add(1, Ordering::Relaxed);
loop {
let old_bits = self.avg_hit_similarity_bits.load(Ordering::Relaxed);
let hits = self.cache_hits.load(Ordering::Relaxed);
let current_avg = f32::from_bits(old_bits as u32);
let new_avg = if hits == 1 {
similarity
} else {
((current_avg * (hits as f32 - 1.0)) + similarity) / hits as f32
};
if self
.avg_hit_similarity_bits
.compare_exchange_weak(
old_bits,
new_avg.to_bits() as u64,
Ordering::Relaxed,
Ordering::Relaxed,
)
.is_ok()
{
break;
}
}
}
pub fn record_miss(&self) {
self.cache_misses.fetch_add(1, Ordering::Relaxed);
}
pub fn avg_hit_similarity(&self) -> f32 {
f32::from_bits(self.avg_hit_similarity_bits.load(Ordering::Relaxed) as u32)
}
pub fn snapshot(&self) -> QueryCacheStatsSnapshot {
let total = self.total_queries.load(Ordering::Relaxed);
let hits = self.cache_hits.load(Ordering::Relaxed);
let hit_rate = if total > 0 {
(hits as f32 / total as f32) * 100.0
} else {
0.0
};
QueryCacheStatsSnapshot {
total_queries: total,
cache_hits: hits,
cache_misses: self.cache_misses.load(Ordering::Relaxed),
expired_removed: self.expired_removed.load(Ordering::Relaxed),
evicted_entries: self.evicted_entries.load(Ordering::Relaxed),
avg_hit_similarity: self.avg_hit_similarity(),
current_cache_size: self.current_cache_size.load(Ordering::Relaxed),
estimated_memory_bytes: self.estimated_memory_bytes.load(Ordering::Relaxed),
hit_rate,
avg_time_saved_ms: 150.0, }
}
}
/// Plain-value copy of `QueryCacheStats` taken at one point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCacheStatsSnapshot {
/// Number of `QueryCache::search` calls recorded.
pub total_queries: u64,
/// Number of recorded hits.
pub cache_hits: u64,
/// Number of recorded misses.
pub cache_misses: u64,
/// Entries removed because they were expired.
pub expired_removed: u64,
/// Entries removed to make room for a new entry.
pub evicted_entries: u64,
/// Running mean similarity of recorded hits.
pub avg_hit_similarity: f32,
/// Entry count gauge at snapshot time.
pub current_cache_size: usize,
/// Memory estimate gauge (bytes) at snapshot time.
pub estimated_memory_bytes: usize,
/// Hits as a percentage of total queries; `0.0` when there were no queries.
pub hit_rate: f32,
/// Always `150.0` when produced by `QueryCacheStats::snapshot` (hard-coded there).
pub avg_time_saved_ms: f32,
}
/// Usage record for one lowercased query text, maintained by
/// `QueryCache::update_query_patterns`.
#[derive(Debug)]
struct QueryPattern {
// How many times this query text has been recorded.
frequency: AtomicU64,
// Unix timestamp (seconds) of the most recent recording.
last_seen: AtomicU64,
}
/// Cache of semantic-search results with exact and similarity-based lookup.
pub struct QueryCache {
// Cached entries keyed by `CachedQuery::id`.
cache: Arc<DashMap<Uuid, CachedQuery>>,
// Capacity, TTL and matching settings.
config: QueryCacheConfig,
// Hit/miss counters and size gauges.
stats: Arc<QueryCacheStats>,
// Per-query usage records keyed by lowercased query text; written on a
// cache miss when prefetching is enabled, never read in the visible code.
patterns: Arc<DashMap<String, QueryPattern>>,
// Lowercased query text -> Unix timestamp (seconds) of its latest recorded
// miss; written alongside `patterns`, never read in the visible code.
recent_queries: Arc<DashMap<String, AtomicU64>>,
}
impl QueryCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: QueryCacheConfig) -> Self {
        info!(
            "Initializing query cache with max size: {}",
            config.max_cache_size
        );
        Self {
            cache: Arc::new(DashMap::new()),
            config,
            stats: Arc::new(QueryCacheStats::new()),
            patterns: Arc::new(DashMap::new()),
            recent_queries: Arc::new(DashMap::new()),
        }
    }

    /// Looks up cached results for a query.
    ///
    /// Lookup order:
    /// 1. an unexpired entry with the same `params_hash` **and** the same
    ///    query text (recorded as a hit with similarity `1.0`);
    /// 2. otherwise the unexpired entry with the same `params_hash` whose
    ///    vector is most similar to `query_vector`, provided the cosine
    ///    similarity reaches `similarity_threshold`.
    ///
    /// Returns `None` on a miss. A miss is counted in the statistics and,
    /// when prefetching is enabled, the query text is recorded in the
    /// pattern maps.
    pub fn search(
        &self,
        query_text: &str,
        query_vector: &[f32],
        params_hash: u64,
    ) -> Option<Vec<SemanticSearchResult>> {
        self.stats.total_queries.fetch_add(1, Ordering::Relaxed);

        if let Some(results) = self.find_exact_match(query_text, params_hash) {
            self.stats.record_hit(1.0);
            return Some(results);
        }

        if let Some((results, similarity)) = self.find_similar_query(query_vector, params_hash) {
            self.stats.record_hit(similarity);
            return Some(results);
        }

        self.stats.record_miss();
        if self.config.enable_prefetching {
            self.update_query_patterns(query_text);
        }
        None
    }

    /// Stores `results` for a query.
    ///
    /// Expired entries are purged first. If the cache is still at or above
    /// `max_cache_size`, the entry with the lowest efficiency score is
    /// evicted before the new entry is inserted. The size gauges are
    /// refreshed afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any error from the expiry cleanup; the current
    /// implementation never produces one.
    pub fn cache_results(
        &self,
        query_text: String,
        query_vector: Vec<f32>,
        results: Vec<SemanticSearchResult>,
        params_hash: u64,
        session_id: Option<Uuid>,
    ) -> Result<()> {
        self.cleanup_expired()?;

        let cached_query = CachedQuery::new(
            query_text.clone(),
            query_vector,
            results,
            params_hash,
            session_id,
        );
        let query_id = cached_query.id;

        if self.cache.len() >= self.config.max_cache_size {
            self.evict_least_efficient();
        }
        self.cache.insert(query_id, cached_query);
        self.refresh_size_gauges();

        debug!("Cached query results for: {}", query_text);
        Ok(())
    }

    /// Returns the results of the first unexpired entry that has the same
    /// `params_hash` and exactly the same query text, marking it accessed.
    ///
    /// The text comparison matters: `params_hash` is shared by different
    /// queries (the similarity lookup relies on that), so matching on the
    /// hash alone would hand one query's results to an unrelated query.
    fn find_exact_match(
        &self,
        query_text: &str,
        params_hash: u64,
    ) -> Option<Vec<SemanticSearchResult>> {
        for entry in self.cache.iter() {
            let cached = entry.value();
            if cached.params_hash == params_hash
                && cached.query_text.as_str() == query_text
                && !cached.is_expired(self.config.ttl_minutes)
            {
                cached.mark_accessed();
                return Some(cached.results.clone());
            }
        }
        None
    }

    /// Finds the unexpired entry with matching `params_hash` whose vector is
    /// most similar to `query_vector`, and returns its results together with
    /// the similarity.
    ///
    /// A candidate must reach `similarity_threshold` and have a similarity
    /// strictly greater than zero. Returns `None` when nothing qualifies, or
    /// when the winning entry was removed by another thread before its
    /// results could be copied.
    fn find_similar_query(
        &self,
        query_vector: &[f32],
        params_hash: u64,
    ) -> Option<(Vec<SemanticSearchResult>, f32)> {
        let mut best_id: Option<Uuid> = None;
        let mut best_similarity = 0.0f32;

        // Only the winner's id is tracked during the scan so that the result
        // vector is cloned once, not once per improving candidate.
        for entry in self.cache.iter() {
            let cached = entry.value();
            if cached.params_hash != params_hash
                || cached.is_expired(self.config.ttl_minutes)
            {
                continue;
            }
            let similarity = cached.similarity_with(query_vector);
            if similarity >= self.config.similarity_threshold && similarity > best_similarity {
                best_similarity = similarity;
                best_id = Some(cached.id);
            }
        }

        // The scan's iterator is gone by now, so taking a fresh reference
        // into the map here does not overlap with it.
        let winner = self.cache.get(&best_id?)?;
        winner.mark_accessed();
        Some((winner.results.clone(), best_similarity))
    }

    /// Removes every expired entry and updates the statistics accordingly.
    fn cleanup_expired(&self) -> Result<()> {
        let ttl_minutes = self.config.ttl_minutes;
        let mut removed_count = 0u64;
        self.cache.retain(|_, cached| {
            let expired = cached.is_expired(ttl_minutes);
            if expired {
                removed_count += 1;
            }
            !expired
        });

        if removed_count > 0 {
            self.stats
                .expired_removed
                .fetch_add(removed_count, Ordering::Relaxed);
            self.refresh_size_gauges();
            debug!("Removed {} expired cache entries", removed_count);
        }
        Ok(())
    }

    /// Removes the entry with the lowest efficiency score, if any.
    fn evict_least_efficient(&self) {
        if self.cache.is_empty() {
            return;
        }

        let mut worst_id: Option<Uuid> = None;
        let mut worst_score = f32::INFINITY;
        for entry in self.cache.iter() {
            let score = entry.value().efficiency_score();
            if score < worst_score {
                worst_score = score;
                worst_id = Some(*entry.key());
            }
        }

        // Removal happens only after the scan has finished, so no reference
        // from the iterator is alive while the map is modified.
        if let Some(id) = worst_id {
            if self.cache.remove(&id).is_some() {
                self.stats.evicted_entries.fetch_add(1, Ordering::Relaxed);
                debug!(
                    "Evicted cache entry with efficiency score: {:.3}",
                    worst_score
                );
            }
        }
    }

    /// Records a missed query (lowercased) in the recent-query and pattern
    /// maps, bumping its frequency and last-seen timestamp.
    ///
    /// NOTE(review): neither map is pruned anywhere in the visible code, so
    /// they grow with the number of distinct query texts.
    fn update_query_patterns(&self, query_text: &str) {
        let now = Utc::now().timestamp() as u64;
        let query_lower = query_text.to_lowercase();

        self.recent_queries
            .insert(query_lower.clone(), AtomicU64::new(now));
        self.patterns
            .entry(query_lower)
            .and_modify(|pattern| {
                pattern.frequency.fetch_add(1, Ordering::Relaxed);
                pattern.last_seen.store(now, Ordering::Relaxed);
            })
            .or_insert_with(|| QueryPattern {
                frequency: AtomicU64::new(1),
                last_seen: AtomicU64::new(now),
            });
    }

    /// Returns a point-in-time copy of the cache statistics.
    pub fn get_stats(&self) -> QueryCacheStatsSnapshot {
        self.stats.snapshot()
    }

    /// Removes every cached entry and zeroes the size gauges.
    ///
    /// Hit/miss/eviction counters are left untouched. Always returns `Ok`.
    pub fn clear(&self) -> Result<()> {
        let old_size = self.cache.len();
        self.cache.clear();
        self.stats.current_cache_size.store(0, Ordering::Relaxed);
        self.stats
            .estimated_memory_bytes
            .store(0, Ordering::Relaxed);
        info!("Query cache cleared ({} entries)", old_size);
        Ok(())
    }

    /// Removes every entry that was cached for `session_id` and refreshes the
    /// size gauges. Always returns `Ok`.
    pub fn invalidate_session(&self, session_id: Uuid) -> Result<()> {
        let mut invalidated_count = 0usize;
        self.cache.retain(|_, cached| {
            let belongs_to_session = cached.session_id == Some(session_id);
            if belongs_to_session {
                invalidated_count += 1;
            }
            !belongs_to_session
        });

        let new_size = self.refresh_size_gauges();
        if invalidated_count > 0 {
            debug!(
                "Invalidated {} cache entries for session {} (remaining: {})",
                invalidated_count, session_id, new_size
            );
        }
        Ok(())
    }

    /// Publishes the current entry count and memory estimate to the
    /// statistics gauges and returns the entry count.
    fn refresh_size_gauges(&self) -> usize {
        let size = self.cache.len();
        self.stats.current_cache_size.store(size, Ordering::Relaxed);
        self.stats
            .estimated_memory_bytes
            .store(self.estimate_memory_usage(), Ordering::Relaxed);
        size
    }

    /// Rough estimate of the bytes held by all cached entries: text length,
    /// vector payload, a flat allowance per result and a flat allowance per
    /// entry.
    fn estimate_memory_usage(&self) -> usize {
        // Flat per-item allowances; these are heuristics, not measurements.
        const BYTES_PER_RESULT: usize = 200;
        const ENTRY_OVERHEAD_BYTES: usize = 100;

        self.cache
            .iter()
            .map(|entry| {
                let query = entry.value();
                query.query_text.len()
                    + query.query_vector.len() * std::mem::size_of::<f32>()
                    + query.results.len() * BYTES_PER_RESULT
                    + ENTRY_OVERHEAD_BYTES
            })
            .sum()
    }

    /// Returns headline metrics keyed by name: `hit_rate` (percent),
    /// `avg_hit_similarity`, `cache_utilization` (percent of
    /// `max_cache_size`) and `avg_time_saved_ms`.
    ///
    /// `cache_utilization` is reported as `0.0` when `max_cache_size` is
    /// zero, so the map never contains NaN or infinity from a zero divisor.
    pub fn get_efficiency_metrics(&self) -> HashMap<String, f32> {
        let stats = self.get_stats();
        let cache_size = self.cache.len();
        let cache_utilization = if self.config.max_cache_size == 0 {
            0.0
        } else {
            cache_size as f32 / self.config.max_cache_size as f32 * 100.0
        };

        let mut metrics = HashMap::new();
        metrics.insert("hit_rate".to_string(), stats.hit_rate);
        metrics.insert("avg_hit_similarity".to_string(), stats.avg_hit_similarity);
        metrics.insert("cache_utilization".to_string(), cache_utilization);
        metrics.insert("avg_time_saved_ms".to_string(), stats.avg_time_saved_ms);
        metrics
    }
}
impl Default for QueryCache {
fn default() -> Self {
Self::new(QueryCacheConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built cache reports no queries and no hits.
    #[test]
    fn test_query_cache_creation() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let snapshot = cache.get_stats();
        assert_eq!(snapshot.total_queries, 0);
        assert_eq!(snapshot.cache_hits, 0);
    }

    /// Results cached for a query are returned when the same query is searched.
    #[test]
    fn test_cache_and_retrieve() {
        let cache = QueryCache::default();
        let text = String::from("test query");
        let vector = vec![0.1, 0.2, 0.3];
        let hash = 12345u64;

        cache
            .cache_results(text.clone(), vector.clone(), vec![], hash, None)
            .unwrap();

        assert!(cache.search(&text, &vector, hash).is_some());
    }

    /// A nearby vector with the same params hash hits; a different params
    /// hash does not.
    #[test]
    fn test_similarity_matching() {
        let cache = QueryCache::default();
        let stored_vector = vec![1.0, 0.0, 0.0];
        // Close to `stored_vector`, but not identical.
        let probe_vector = vec![0.9, 0.1, 0.0];
        let hash = 123;

        cache
            .cache_results("query1".to_string(), stored_vector, vec![], hash, None)
            .unwrap();

        let same_params = cache.search("query2", &probe_vector, hash);
        assert!(same_params.is_some());

        let different_params = cache.search("query2", &probe_vector, 456);
        assert!(
            different_params.is_none(),
            "Bug fix verification: different params_hash should not return cached results"
        );
    }

    /// With a zero-minute TTL an entry is already expired when searched.
    #[test]
    fn test_cache_expiration() {
        let cache = QueryCache::new(QueryCacheConfig {
            ttl_minutes: 0,
            ..Default::default()
        });

        cache
            .cache_results("test".to_string(), vec![1.0, 0.0], vec![], 123, None)
            .unwrap();

        assert!(cache.search("test", &[1.0, 0.0], 123).is_none());
    }
}