use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
use crate::cache::CacheCoordinator;
/// Thread-safe in-memory cache mapping query fingerprint hashes to serialized
/// query results, with LRU eviction, TTL expiry on read, optional gzip
/// compression, and hit/miss statistics.
pub struct QueryResultCache {
    // Stored entries keyed by fingerprint hash.
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
    // Access order: front = least recently used, back = most recently used.
    lru_queue: Arc<RwLock<VecDeque<String>>>,
    // Tuning knobs (capacity, TTL, compression, stats toggle, batch size).
    config: CacheConfig,
    // Counters updated on put/get/evict/invalidate when `enable_stats` is set.
    stats: Arc<RwLock<CacheStatistics>>,
    // Optional external coordinator; stored via the constructors or
    // `attach_coordinator` but not consulted by the logic visible in this file.
    invalidation_coordinator: Option<Arc<CacheCoordinator>>,
    // Tombstones: hashes that `get` must treat as misses.
    invalidated_entries: Arc<RwLock<std::collections::HashSet<String>>>,
}
/// Tuning knobs for [`QueryResultCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum number of entries before LRU eviction kicks in.
    pub max_entries: usize,
    /// Time-to-live; entries older than this are dropped when accessed.
    pub ttl: Duration,
    /// Gzip-compress stored results when true.
    pub enable_compression: bool,
    /// Results larger than this many bytes are not cached at all.
    pub max_result_size: usize,
    /// Maintain [`CacheStatistics`] counters when true.
    pub enable_stats: bool,
    /// Upper bound on how many entries one eviction pass may remove.
    pub eviction_batch_size: usize,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 10_000,
ttl: Duration::from_secs(3600), enable_compression: false,
max_result_size: 10 * 1024 * 1024, enable_stats: true,
eviction_batch_size: 100,
}
}
}
impl CacheConfig {
    /// Returns this configuration with the entry-count limit replaced.
    pub fn with_max_entries(self, max: usize) -> Self {
        Self { max_entries: max, ..self }
    }

    /// Returns this configuration with the time-to-live replaced.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Returns this configuration with compression switched on or off.
    pub fn with_compression(self, enabled: bool) -> Self {
        Self { enable_compression: enabled, ..self }
    }

    /// Returns this configuration with the per-result size cap replaced.
    pub fn with_max_result_size(self, size: usize) -> Self {
        Self { max_result_size: size, ..self }
    }
}
/// One cached result plus the bookkeeping needed for LRU/TTL decisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CacheEntry {
    // Key this entry is stored under (duplicated from the map key).
    fingerprint_hash: String,
    // Result bytes; gzip-compressed when `is_compressed` is true.
    results: Vec<u8>,
    // Byte length of the results before any compression.
    original_size: usize,
    // Insertion time; compared against the configured TTL on reads.
    created_at: SystemTime,
    // Updated on every successful read.
    last_accessed: SystemTime,
    // Number of successful reads of this entry.
    access_count: u64,
    // Whether `results` holds gzip-compressed bytes.
    is_compressed: bool,
}
/// Aggregate cache counters; a snapshot is returned by
/// [`QueryResultCache::statistics`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStatistics {
    /// Successful lookups.
    pub hits: u64,
    /// Failed lookups (absent, expired, or invalidated keys).
    pub misses: u64,
    /// Successful inserts.
    pub puts: u64,
    /// Entries removed by LRU eviction or TTL expiry.
    pub evictions: u64,
    /// Explicit invalidations.
    pub invalidations: u64,
    /// Total bytes of stored (possibly compressed) payloads.
    pub size_bytes: usize,
    /// Current number of entries.
    pub entry_count: usize,
    /// hits / (hits + misses); 0.0 before any lookup.
    pub hit_rate: f64,
    /// size_bytes / entry_count, 0 when the cache is empty.
    pub avg_result_size: usize,
    /// Original-to-stored size ratio of compressed payloads.
    /// NOTE(review): never written by the code visible in this file, so it
    /// stays at its 0.0 default — confirm intent.
    pub compression_ratio: f64,
}
impl CacheStatistics {
    /// Recomputes `hit_rate` as hits / (hits + misses); 0.0 when there have
    /// been no lookups yet (avoids a division by zero).
    fn calculate_hit_rate(&mut self) {
        let total_lookups = self.hits + self.misses;
        if total_lookups == 0 {
            self.hit_rate = 0.0;
        } else {
            self.hit_rate = self.hits as f64 / total_lookups as f64;
        }
    }
}
impl QueryResultCache {
    /// Creates a cache with the given configuration and no coordinator.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            entries: Arc::new(RwLock::new(HashMap::new())),
            lru_queue: Arc::new(RwLock::new(VecDeque::new())),
            config,
            stats: Arc::new(RwLock::new(CacheStatistics::default())),
            invalidation_coordinator: None,
            invalidated_entries: Arc::new(RwLock::new(std::collections::HashSet::new())),
        }
    }

    /// Creates a cache wired to an external invalidation coordinator.
    pub fn with_invalidation_coordinator(
        config: CacheConfig,
        coordinator: Arc<CacheCoordinator>,
    ) -> Self {
        // Delegate to `new` so the two constructors cannot drift apart.
        let mut cache = Self::new(config);
        cache.invalidation_coordinator = Some(coordinator);
        cache
    }

    /// Attaches (or replaces) the invalidation coordinator after construction.
    pub fn attach_coordinator(&mut self, coordinator: Arc<CacheCoordinator>) {
        self.invalidation_coordinator = Some(coordinator);
    }

    /// Stores `results` under `fingerprint_hash`, compressing when enabled.
    ///
    /// Results larger than `max_result_size` are silently skipped. Putting a
    /// key clears any pending invalidation tombstone for it, and overwriting
    /// an existing key replaces its LRU slot and byte accounting instead of
    /// duplicating them.
    pub fn put(&self, fingerprint_hash: String, results: Vec<u8>) -> Result<()> {
        if results.len() > self.config.max_result_size {
            return Ok(());
        }
        // Bug fix: a fresh put supersedes an earlier `invalidate`; without
        // this, a key could never be cached again once invalidated, and the
        // tombstone set would grow without bound.
        {
            let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
            invalidated.remove(&fingerprint_hash);
        }
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut lru = self.lru_queue.write().expect("lock poisoned");
        // Only evict when inserting a genuinely new key at capacity;
        // overwriting an existing key does not grow the map.
        if entries.len() >= self.config.max_entries && !entries.contains_key(&fingerprint_hash) {
            self.evict_lru(&mut entries, &mut lru)?;
        }
        let original_size = results.len();
        let (stored_results, is_compressed) = if self.config.enable_compression {
            match self.compress_results(&results) {
                Ok(compressed) => (compressed, true),
                // Best effort: fall back to storing uncompressed on failure.
                Err(_) => (results, false),
            }
        } else {
            (results, false)
        };
        let stored_size = stored_results.len();
        let now = SystemTime::now();
        let entry = CacheEntry {
            fingerprint_hash: fingerprint_hash.clone(),
            results: stored_results,
            original_size,
            created_at: now,
            last_accessed: now,
            access_count: 0,
            is_compressed,
        };
        // Bug fix: replacing an entry must not leave a duplicate key in the
        // LRU queue or count the old payload's bytes twice.
        let previous = entries.insert(fingerprint_hash.clone(), entry);
        if previous.is_some() {
            lru.retain(|k| k != &fingerprint_hash);
        }
        lru.push_back(fingerprint_hash);
        if self.config.enable_stats {
            let mut stats = self.stats.write().expect("lock poisoned");
            stats.puts += 1;
            stats.entry_count = entries.len();
            if let Some(old) = previous {
                stats.size_bytes = stats.size_bytes.saturating_sub(old.results.len());
            }
            stats.size_bytes += stored_size;
            stats.avg_result_size = stats.size_bytes.checked_div(stats.entry_count).unwrap_or(0);
            if is_compressed && stored_size > 0 {
                // Ratio (original / stored) of the most recent compressed put.
                stats.compression_ratio = original_size as f64 / stored_size as f64;
            }
        }
        Ok(())
    }

    /// Looks up `fingerprint_hash`, returning a decompressed copy of the
    /// stored results.
    ///
    /// Absent, expired (TTL), and invalidated keys count as misses; a hit
    /// refreshes the entry's LRU position and access metadata.
    pub fn get(&self, fingerprint_hash: &str) -> Option<Vec<u8>> {
        {
            let invalidated = self.invalidated_entries.read().expect("lock poisoned");
            if invalidated.contains(fingerprint_hash) {
                if self.config.enable_stats {
                    let mut stats = self.stats.write().expect("lock poisoned");
                    // Bug fix: the invalidation itself was counted when it was
                    // requested; counting it again on every lookup inflated
                    // `invalidations`. Only the miss is recorded here.
                    stats.misses += 1;
                    stats.calculate_hit_rate();
                }
                return None;
            }
        }
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut lru = self.lru_queue.write().expect("lock poisoned");
        if let Some(entry) = entries.get_mut(fingerprint_hash) {
            // Lazy TTL expiry: an entry older than `ttl` is removed on access.
            if matches!(entry.created_at.elapsed(), Ok(elapsed) if elapsed > self.config.ttl) {
                let removed = entries.remove(fingerprint_hash);
                lru.retain(|k| k != fingerprint_hash);
                if self.config.enable_stats {
                    let mut stats = self.stats.write().expect("lock poisoned");
                    stats.misses += 1;
                    stats.evictions += 1;
                    // Bug fix: release the expired payload's bytes from the
                    // running total and keep the entry count in sync.
                    if let Some(old) = removed {
                        stats.size_bytes = stats.size_bytes.saturating_sub(old.results.len());
                    }
                    stats.entry_count = entries.len();
                    stats.calculate_hit_rate();
                }
                return None;
            }
            entry.last_accessed = SystemTime::now();
            entry.access_count += 1;
            // Refresh LRU position: move the key to the most-recently-used end.
            lru.retain(|k| k != fingerprint_hash);
            lru.push_back(fingerprint_hash.to_string());
            let results = if entry.is_compressed {
                match self.decompress_results(&entry.results) {
                    Ok(bytes) => bytes,
                    // Bug fix: a payload that fails to decompress previously
                    // returned None without recording a miss; record it.
                    Err(_) => {
                        if self.config.enable_stats {
                            let mut stats = self.stats.write().expect("lock poisoned");
                            stats.misses += 1;
                            stats.calculate_hit_rate();
                        }
                        return None;
                    }
                }
            } else {
                entry.results.clone()
            };
            if self.config.enable_stats {
                let mut stats = self.stats.write().expect("lock poisoned");
                stats.hits += 1;
                stats.calculate_hit_rate();
            }
            Some(results)
        } else {
            if self.config.enable_stats {
                let mut stats = self.stats.write().expect("lock poisoned");
                stats.misses += 1;
                stats.calculate_hit_rate();
            }
            None
        }
    }

    /// Invalidates a single key: subsequent `get`s miss until it is re-`put`.
    pub fn invalidate(&self, fingerprint_hash: &str) -> Result<()> {
        // Record the tombstone first so concurrent readers miss immediately.
        {
            let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
            invalidated.insert(fingerprint_hash.to_string());
        }
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut lru = self.lru_queue.write().expect("lock poisoned");
        if let Some(removed) = entries.remove(fingerprint_hash) {
            lru.retain(|k| k != fingerprint_hash);
            if self.config.enable_stats {
                let mut stats = self.stats.write().expect("lock poisoned");
                stats.invalidations += 1;
                stats.entry_count = entries.len();
                // Bug fix: release the removed payload's bytes from the total.
                stats.size_bytes = stats.size_bytes.saturating_sub(removed.results.len());
            }
        }
        Ok(())
    }

    /// Marks a key invalidated without touching the stored entry itself
    /// (used when the invalidation originates elsewhere).
    pub fn mark_invalidated(&self, fingerprint_hash: &str) -> Result<()> {
        let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
        invalidated.insert(fingerprint_hash.to_string());
        if self.config.enable_stats {
            let mut stats = self.stats.write().expect("lock poisoned");
            stats.invalidations += 1;
        }
        Ok(())
    }

    /// Invalidates and removes every entry in the cache.
    pub fn invalidate_all(&self) -> Result<()> {
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut lru = self.lru_queue.write().expect("lock poisoned");
        let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
        let count = entries.len();
        // Tombstone every key so later `get`s miss until the keys are re-put;
        // `drain` hands the keys over without cloning.
        invalidated.extend(entries.drain().map(|(key, _)| key));
        lru.clear();
        if self.config.enable_stats {
            let mut stats = self.stats.write().expect("lock poisoned");
            stats.invalidations += count as u64;
            stats.entry_count = 0;
            stats.size_bytes = 0;
        }
        Ok(())
    }

    /// Returns a snapshot of the current statistics.
    pub fn statistics(&self) -> CacheStatistics {
        self.stats.read().expect("lock poisoned").clone()
    }

    /// Number of entries currently stored (ignores TTL).
    pub fn size(&self) -> usize {
        self.entries.read().expect("lock poisoned").len()
    }

    /// Whether `fingerprint_hash` currently has a stored entry (ignores TTL
    /// and pending tombstones).
    pub fn contains(&self, fingerprint_hash: &str) -> bool {
        self.entries
            .read()
            .expect("lock poisoned")
            .contains_key(fingerprint_hash)
    }

    /// Evicts a batch of least-recently-used entries. Called with the
    /// `entries` and `lru` write guards already held so the pass is atomic.
    fn evict_lru(
        &self,
        entries: &mut HashMap<String, CacheEntry>,
        lru: &mut VecDeque<String>,
    ) -> Result<()> {
        // Evict at most `eviction_batch_size`, capped at ~10% of the cache
        // plus one, so a single put never clears a large fraction of it.
        let batch_size = self.config.eviction_batch_size.min(entries.len() / 10 + 1);
        let mut evicted = 0u64;
        let mut freed_bytes = 0usize;
        for _ in 0..batch_size {
            if let Some(oldest) = lru.pop_front() {
                if let Some(entry) = entries.remove(&oldest) {
                    evicted += 1;
                    freed_bytes += entry.results.len();
                }
            } else {
                break;
            }
        }
        // Take the stats lock once per batch instead of once per eviction.
        if self.config.enable_stats && evicted > 0 {
            let mut stats = self.stats.write().expect("lock poisoned");
            stats.evictions += evicted;
            stats.size_bytes = stats.size_bytes.saturating_sub(freed_bytes);
            stats.entry_count = entries.len();
        }
        Ok(())
    }

    /// Gzip-compresses `results` at the fastest compression level.
    ///
    /// # Errors
    /// Propagates any I/O error from the encoder.
    fn compress_results(&self, results: &[u8]) -> Result<Vec<u8>> {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;
        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(results)?;
        Ok(encoder.finish()?)
    }

    /// Reverses [`Self::compress_results`].
    ///
    /// # Errors
    /// Propagates any I/O error from the decoder (e.g. corrupt input).
    fn decompress_results(&self, compressed: &[u8]) -> Result<Vec<u8>> {
        use flate2::read::GzDecoder;
        use std::io::Read;
        let mut decoder = GzDecoder::new(compressed);
        let mut decompressed = Vec::new();
        decoder.read_to_end(&mut decompressed)?;
        Ok(decompressed)
    }
}
/// Fluent builder producing a [`QueryResultCache`] from a [`CacheConfig`].
pub struct QueryResultCacheBuilder {
    // Configuration accumulated by the builder methods.
    config: CacheConfig,
}
impl QueryResultCacheBuilder {
pub fn new() -> Self {
Self {
config: CacheConfig::default(),
}
}
pub fn max_entries(mut self, max: usize) -> Self {
self.config.max_entries = max;
self
}
pub fn ttl(mut self, ttl: Duration) -> Self {
self.config.ttl = ttl;
self
}
pub fn compression(mut self, enabled: bool) -> Self {
self.config.enable_compression = enabled;
self
}
pub fn build(self) -> QueryResultCache {
QueryResultCache::new(self.config)
}
}
impl Default for QueryResultCacheBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Put then get round-trips the bytes and records one put/one hit.
    #[test]
    fn test_cache_basic_operations() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let hash = "test_hash_123".to_string();
        let results = vec![1, 2, 3, 4, 5];
        cache.put(hash.clone(), results.clone()).unwrap();
        let retrieved = cache.get(&hash).unwrap();
        assert_eq!(results, retrieved);
        let stats = cache.statistics();
        assert_eq!(stats.puts, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    // A lookup of an unknown key returns None and counts as a miss.
    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let result = cache.get("nonexistent");
        assert!(result.is_none());
        let stats = cache.statistics();
        assert_eq!(stats.misses, 1);
    }

    // Invalidating a key removes its entry and bumps the invalidation count.
    #[test]
    fn test_cache_invalidation() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let hash = "test_hash".to_string();
        let results = vec![1, 2, 3];
        cache.put(hash.clone(), results).unwrap();
        assert!(cache.contains(&hash));
        cache.invalidate(&hash).unwrap();
        assert!(!cache.contains(&hash));
        let stats = cache.statistics();
        assert_eq!(stats.invalidations, 1);
    }

    // With capacity 3, the fourth put evicts the least-recently-used key.
    #[test]
    fn test_lru_eviction() {
        let config = CacheConfig::default().with_max_entries(3);
        let cache = QueryResultCache::new(config);
        cache.put("hash1".to_string(), vec![1]).unwrap();
        cache.put("hash2".to_string(), vec![2]).unwrap();
        cache.put("hash3".to_string(), vec![3]).unwrap();
        cache.put("hash4".to_string(), vec![4]).unwrap();
        assert!(!cache.contains("hash1"));
        assert!(cache.contains("hash4"));
    }

    // Compressed entries round-trip and shrink the stored byte total
    // (10 000 zero bytes compress very well under gzip).
    #[test]
    fn test_cache_compression() {
        let config = CacheConfig::default().with_compression(true);
        let cache = QueryResultCache::new(config);
        let hash = "compressed_hash".to_string();
        let large_results = vec![0u8; 10_000];
        cache.put(hash.clone(), large_results.clone()).unwrap();
        let retrieved = cache.get(&hash).unwrap();
        assert_eq!(large_results, retrieved);
        let stats = cache.statistics();
        assert!(stats.compression_ratio > 1.0 || stats.size_bytes < large_results.len());
    }

    // An entry is readable before its TTL and gone after it elapses.
    #[test]
    fn test_cache_ttl_expiration() {
        use std::thread;
        let config = CacheConfig::default().with_ttl(Duration::from_millis(100));
        let cache = QueryResultCache::new(config);
        let hash = "expiring_hash".to_string();
        cache.put(hash.clone(), vec![1, 2, 3]).unwrap();
        assert!(cache.get(&hash).is_some());
        thread::sleep(Duration::from_millis(150));
        assert!(cache.get(&hash).is_none());
    }

    // Builder methods flow through to the cache's configuration.
    #[test]
    fn test_cache_builder() {
        let cache = QueryResultCacheBuilder::new()
            .max_entries(5000)
            .ttl(Duration::from_secs(1800))
            .compression(true)
            .build();
        assert_eq!(cache.config.max_entries, 5000);
        assert_eq!(cache.config.ttl, Duration::from_secs(1800));
        assert!(cache.config.enable_compression);
    }

    // Mixed put/hit/miss/invalidate sequence produces the expected counters
    // and a 0.5 hit rate (1 hit, 1 miss).
    #[test]
    fn test_cache_statistics_accuracy() {
        let cache = QueryResultCache::new(CacheConfig::default());
        cache.put("h1".to_string(), vec![1]).unwrap();
        cache.put("h2".to_string(), vec![2]).unwrap();
        cache.get("h1");
        cache.get("h3");
        cache.invalidate("h1").unwrap();
        let stats = cache.statistics();
        assert_eq!(stats.puts, 2);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.invalidations, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    // Results above `max_result_size` are silently not cached.
    #[test]
    fn test_cache_max_result_size() {
        let config = CacheConfig::default().with_max_result_size(100);
        let cache = QueryResultCache::new(config);
        cache.put("small".to_string(), vec![1; 50]).unwrap();
        assert!(cache.contains("small"));
        cache.put("large".to_string(), vec![1; 200]).unwrap();
        assert!(!cache.contains("large"));
    }

    // Each successful get increments the entry's access counter.
    #[test]
    fn test_cache_access_tracking() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let hash = "tracked".to_string();
        cache.put(hash.clone(), vec![1, 2, 3]).unwrap();
        for _ in 0..5 {
            cache.get(&hash);
        }
        let entries = cache.entries.read().unwrap();
        let entry = entries.get(&hash).unwrap();
        assert_eq!(entry.access_count, 5);
    }
}