1use anyhow::Result;
42use serde::{Deserialize, Serialize};
43use std::collections::{HashMap, VecDeque};
44use std::sync::{Arc, RwLock};
45use std::time::{Duration, SystemTime};
46
47use crate::cache::CacheCoordinator;
49
50pub struct QueryResultCache {
52 entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
54 lru_queue: Arc<RwLock<VecDeque<String>>>,
56 config: CacheConfig,
58 stats: Arc<RwLock<CacheStatistics>>,
60 invalidation_coordinator: Option<Arc<CacheCoordinator>>,
62 invalidated_entries: Arc<RwLock<std::collections::HashSet<String>>>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct CacheConfig {
69 pub max_entries: usize,
71 pub ttl: Duration,
73 pub enable_compression: bool,
75 pub max_result_size: usize,
77 pub enable_stats: bool,
79 pub eviction_batch_size: usize,
81}
82
83impl Default for CacheConfig {
84 fn default() -> Self {
85 Self {
86 max_entries: 10_000,
87 ttl: Duration::from_secs(3600), enable_compression: false,
89 max_result_size: 10 * 1024 * 1024, enable_stats: true,
91 eviction_batch_size: 100,
92 }
93 }
94}
95
96impl CacheConfig {
97 pub fn with_max_entries(mut self, max: usize) -> Self {
99 self.max_entries = max;
100 self
101 }
102
103 pub fn with_ttl(mut self, ttl: Duration) -> Self {
105 self.ttl = ttl;
106 self
107 }
108
109 pub fn with_compression(mut self, enabled: bool) -> Self {
111 self.enable_compression = enabled;
112 self
113 }
114
115 pub fn with_max_result_size(mut self, size: usize) -> Self {
117 self.max_result_size = size;
118 self
119 }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124struct CacheEntry {
125 fingerprint_hash: String,
127 results: Vec<u8>,
129 original_size: usize,
131 created_at: SystemTime,
133 last_accessed: SystemTime,
135 access_count: u64,
137 is_compressed: bool,
139}
140
141#[derive(Debug, Clone, Default, Serialize, Deserialize)]
143pub struct CacheStatistics {
144 pub hits: u64,
146 pub misses: u64,
148 pub puts: u64,
150 pub evictions: u64,
152 pub invalidations: u64,
154 pub size_bytes: usize,
156 pub entry_count: usize,
158 pub hit_rate: f64,
160 pub avg_result_size: usize,
162 pub compression_ratio: f64,
164}
165
166impl CacheStatistics {
167 fn calculate_hit_rate(&mut self) {
169 let total = self.hits + self.misses;
170 self.hit_rate = if total > 0 {
171 self.hits as f64 / total as f64
172 } else {
173 0.0
174 };
175 }
176}
177
178impl QueryResultCache {
179 pub fn new(config: CacheConfig) -> Self {
181 Self {
182 entries: Arc::new(RwLock::new(HashMap::new())),
183 lru_queue: Arc::new(RwLock::new(VecDeque::new())),
184 config,
185 stats: Arc::new(RwLock::new(CacheStatistics::default())),
186 invalidation_coordinator: None,
187 invalidated_entries: Arc::new(RwLock::new(std::collections::HashSet::new())),
188 }
189 }
190
191 pub fn with_invalidation_coordinator(
193 config: CacheConfig,
194 coordinator: Arc<CacheCoordinator>,
195 ) -> Self {
196 Self {
197 entries: Arc::new(RwLock::new(HashMap::new())),
198 lru_queue: Arc::new(RwLock::new(VecDeque::new())),
199 config,
200 stats: Arc::new(RwLock::new(CacheStatistics::default())),
201 invalidation_coordinator: Some(coordinator),
202 invalidated_entries: Arc::new(RwLock::new(std::collections::HashSet::new())),
203 }
204 }
205
206 pub fn attach_coordinator(&mut self, coordinator: Arc<CacheCoordinator>) {
208 self.invalidation_coordinator = Some(coordinator);
209 }
210
211 pub fn put(&self, fingerprint_hash: String, results: Vec<u8>) -> Result<()> {
213 if results.len() > self.config.max_result_size {
215 return Ok(()); }
217
218 let mut entries = self.entries.write().expect("lock poisoned");
219 let mut lru = self.lru_queue.write().expect("lock poisoned");
220
221 if entries.len() >= self.config.max_entries {
223 self.evict_lru(&mut entries, &mut lru)?;
224 }
225
226 let (stored_results, is_compressed) = if self.config.enable_compression {
228 match self.compress_results(&results) {
229 Ok(compressed) => (compressed, true),
230 Err(_) => (results.clone(), false),
231 }
232 } else {
233 (results.clone(), false)
234 };
235
236 let entry = CacheEntry {
237 fingerprint_hash: fingerprint_hash.clone(),
238 results: stored_results.clone(),
239 original_size: results.len(),
240 created_at: SystemTime::now(),
241 last_accessed: SystemTime::now(),
242 access_count: 0,
243 is_compressed,
244 };
245
246 entries.insert(fingerprint_hash.clone(), entry);
248 lru.push_back(fingerprint_hash);
249
250 if self.config.enable_stats {
252 let mut stats = self.stats.write().expect("lock poisoned");
253 stats.puts += 1;
254 stats.entry_count = entries.len();
255 stats.size_bytes += stored_results.len();
256 stats.avg_result_size = if stats.entry_count > 0 {
257 stats.size_bytes / stats.entry_count
258 } else {
259 0
260 };
261 }
262
263 Ok(())
264 }
265
266 pub fn get(&self, fingerprint_hash: &str) -> Option<Vec<u8>> {
268 {
270 let invalidated = self.invalidated_entries.read().expect("lock poisoned");
271 if invalidated.contains(fingerprint_hash) {
272 if self.config.enable_stats {
274 let mut stats = self.stats.write().expect("lock poisoned");
275 stats.misses += 1;
276 stats.invalidations += 1;
277 stats.calculate_hit_rate();
278 }
279 return None;
280 }
281 }
282
283 let mut entries = self.entries.write().expect("lock poisoned");
284 let mut lru = self.lru_queue.write().expect("lock poisoned");
285
286 if let Some(entry) = entries.get_mut(fingerprint_hash) {
287 if let Ok(elapsed) = entry.created_at.elapsed() {
289 if elapsed > self.config.ttl {
290 entries.remove(fingerprint_hash);
292 lru.retain(|k| k != fingerprint_hash);
293
294 if self.config.enable_stats {
296 let mut stats = self.stats.write().expect("lock poisoned");
297 stats.misses += 1;
298 stats.evictions += 1;
299 stats.calculate_hit_rate();
300 }
301 return None;
302 }
303 }
304
305 entry.last_accessed = SystemTime::now();
307 entry.access_count += 1;
308
309 lru.retain(|k| k != fingerprint_hash);
311 lru.push_back(fingerprint_hash.to_string());
312
313 let results = if entry.is_compressed {
315 self.decompress_results(&entry.results).ok()?
316 } else {
317 entry.results.clone()
318 };
319
320 if self.config.enable_stats {
322 let mut stats = self.stats.write().expect("lock poisoned");
323 stats.hits += 1;
324 stats.calculate_hit_rate();
325 }
326
327 Some(results)
328 } else {
329 if self.config.enable_stats {
331 let mut stats = self.stats.write().expect("lock poisoned");
332 stats.misses += 1;
333 stats.calculate_hit_rate();
334 }
335 None
336 }
337 }
338
339 pub fn invalidate(&self, fingerprint_hash: &str) -> Result<()> {
341 {
343 let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
344 invalidated.insert(fingerprint_hash.to_string());
345 }
346
347 let mut entries = self.entries.write().expect("lock poisoned");
348 let mut lru = self.lru_queue.write().expect("lock poisoned");
349
350 if entries.remove(fingerprint_hash).is_some() {
351 lru.retain(|k| k != fingerprint_hash);
352
353 if self.config.enable_stats {
354 let mut stats = self.stats.write().expect("lock poisoned");
355 stats.invalidations += 1;
356 stats.entry_count = entries.len();
357 }
358 }
359
360 Ok(())
361 }
362
363 pub fn mark_invalidated(&self, fingerprint_hash: &str) -> Result<()> {
365 let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
366 invalidated.insert(fingerprint_hash.to_string());
367
368 if self.config.enable_stats {
369 let mut stats = self.stats.write().expect("lock poisoned");
370 stats.invalidations += 1;
371 }
372
373 Ok(())
374 }
375
376 pub fn invalidate_all(&self) -> Result<()> {
378 let mut entries = self.entries.write().expect("lock poisoned");
379 let mut lru = self.lru_queue.write().expect("lock poisoned");
380 let mut invalidated = self.invalidated_entries.write().expect("lock poisoned");
381
382 let count = entries.len();
383
384 for key in entries.keys() {
386 invalidated.insert(key.clone());
387 }
388
389 entries.clear();
390 lru.clear();
391
392 if self.config.enable_stats {
393 let mut stats = self.stats.write().expect("lock poisoned");
394 stats.invalidations += count as u64;
395 stats.entry_count = 0;
396 stats.size_bytes = 0;
397 }
398
399 Ok(())
400 }
401
402 pub fn statistics(&self) -> CacheStatistics {
404 self.stats.read().expect("lock poisoned").clone()
405 }
406
407 pub fn size(&self) -> usize {
409 self.entries.read().expect("lock poisoned").len()
410 }
411
412 pub fn contains(&self, fingerprint_hash: &str) -> bool {
414 self.entries
415 .read()
416 .expect("lock poisoned")
417 .contains_key(fingerprint_hash)
418 }
419
420 fn evict_lru(
422 &self,
423 entries: &mut HashMap<String, CacheEntry>,
424 lru: &mut VecDeque<String>,
425 ) -> Result<()> {
426 let batch_size = self.config.eviction_batch_size.min(entries.len() / 10 + 1);
427
428 for _ in 0..batch_size {
429 if let Some(oldest) = lru.pop_front() {
430 if let Some(entry) = entries.remove(&oldest) {
431 if self.config.enable_stats {
432 let mut stats = self.stats.write().expect("lock poisoned");
433 stats.evictions += 1;
434 stats.size_bytes = stats.size_bytes.saturating_sub(entry.results.len());
435 stats.entry_count = entries.len();
436 }
437 }
438 }
439 }
440
441 Ok(())
442 }
443
444 fn compress_results(&self, results: &[u8]) -> Result<Vec<u8>> {
446 use flate2::write::GzEncoder;
447 use flate2::Compression;
448 use std::io::Write;
449
450 let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
451 encoder.write_all(results)?;
452 Ok(encoder.finish()?)
453 }
454
455 fn decompress_results(&self, compressed: &[u8]) -> Result<Vec<u8>> {
457 use flate2::read::GzDecoder;
458 use std::io::Read;
459
460 let mut decoder = GzDecoder::new(compressed);
461 let mut decompressed = Vec::new();
462 decoder.read_to_end(&mut decompressed)?;
463 Ok(decompressed)
464 }
465}
466
467pub struct QueryResultCacheBuilder {
469 config: CacheConfig,
470}
471
472impl QueryResultCacheBuilder {
473 pub fn new() -> Self {
475 Self {
476 config: CacheConfig::default(),
477 }
478 }
479
480 pub fn max_entries(mut self, max: usize) -> Self {
482 self.config.max_entries = max;
483 self
484 }
485
486 pub fn ttl(mut self, ttl: Duration) -> Self {
488 self.config.ttl = ttl;
489 self
490 }
491
492 pub fn compression(mut self, enabled: bool) -> Self {
494 self.config.enable_compression = enabled;
495 self
496 }
497
498 pub fn build(self) -> QueryResultCache {
500 QueryResultCache::new(self.config)
501 }
502}
503
504impl Default for QueryResultCacheBuilder {
505 fn default() -> Self {
506 Self::new()
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored payload is returned unchanged and counted as one put and one hit.
    #[test]
    fn test_cache_basic_operations() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let key = String::from("test_hash_123");
        let payload: Vec<u8> = vec![1, 2, 3, 4, 5];

        cache.put(key.clone(), payload.clone()).unwrap();

        let fetched = cache.get(&key).unwrap();
        assert_eq!(payload, fetched);

        let snapshot = cache.statistics();
        assert_eq!(snapshot.puts, 1);
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 0);
    }

    /// Looking up an unknown key returns nothing and is counted as a miss.
    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(CacheConfig::default());

        assert!(cache.get("nonexistent").is_none());
        assert_eq!(cache.statistics().misses, 1);
    }

    /// Invalidating a key removes its entry and is counted once.
    #[test]
    fn test_cache_invalidation() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let key = String::from("test_hash");

        cache.put(key.clone(), vec![1, 2, 3]).unwrap();
        assert!(cache.contains(&key));

        cache.invalidate(&key).unwrap();
        assert!(!cache.contains(&key));

        assert_eq!(cache.statistics().invalidations, 1);
    }

    /// Inserting beyond `max_entries` evicts the oldest key.
    #[test]
    fn test_lru_eviction() {
        let cache = QueryResultCache::new(CacheConfig::default().with_max_entries(3));

        // Fill the cache to capacity.
        for (name, byte) in [("hash1", 1u8), ("hash2", 2), ("hash3", 3)] {
            cache.put(name.to_string(), vec![byte]).unwrap();
        }

        // One more insertion forces an eviction.
        cache.put("hash4".to_string(), vec![4]).unwrap();

        assert!(!cache.contains("hash1"));
        assert!(cache.contains("hash4"));
    }

    /// A compressible payload round-trips through the compressed cache.
    #[test]
    fn test_cache_compression() {
        let cache = QueryResultCache::new(CacheConfig::default().with_compression(true));
        let key = String::from("compressed_hash");
        let zeros = vec![0u8; 10_000];

        cache.put(key.clone(), zeros.clone()).unwrap();

        let fetched = cache.get(&key).unwrap();
        assert_eq!(zeros, fetched);

        let snapshot = cache.statistics();
        assert!(snapshot.compression_ratio > 1.0 || snapshot.size_bytes < zeros.len());
    }

    /// An entry is served before its TTL elapses and refused afterwards.
    #[test]
    fn test_cache_ttl_expiration() {
        let short_lived = CacheConfig::default().with_ttl(Duration::from_millis(100));
        let cache = QueryResultCache::new(short_lived);
        let key = String::from("expiring_hash");

        cache.put(key.clone(), vec![1, 2, 3]).unwrap();
        assert!(cache.get(&key).is_some());

        // Wait past the TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert!(cache.get(&key).is_none());
    }

    /// Builder setters end up in the built cache's configuration.
    #[test]
    fn test_cache_builder() {
        let built = QueryResultCacheBuilder::new()
            .max_entries(5000)
            .ttl(Duration::from_secs(1800))
            .compression(true)
            .build();

        assert_eq!(built.config.max_entries, 5000);
        assert_eq!(built.config.ttl, Duration::from_secs(1800));
        assert!(built.config.enable_compression);
    }

    /// Counters reflect exactly the operations performed.
    #[test]
    fn test_cache_statistics_accuracy() {
        let cache = QueryResultCache::new(CacheConfig::default());

        cache.put("h1".to_string(), vec![1]).unwrap();
        cache.put("h2".to_string(), vec![2]).unwrap();
        cache.get("h1"); // hit
        cache.get("h3"); // miss
        cache.invalidate("h1").unwrap();

        let snapshot = cache.statistics();
        assert_eq!(snapshot.puts, 2);
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.invalidations, 1);
        assert_eq!(snapshot.hit_rate, 0.5);
    }

    /// Payloads above `max_result_size` are not cached.
    #[test]
    fn test_cache_max_result_size() {
        let cache = QueryResultCache::new(CacheConfig::default().with_max_result_size(100));

        cache.put("small".to_string(), vec![1; 50]).unwrap();
        assert!(cache.contains("small"));

        cache.put("large".to_string(), vec![1; 200]).unwrap();
        assert!(!cache.contains("large"));
    }

    /// Every successful lookup increments the entry's access counter.
    #[test]
    fn test_cache_access_tracking() {
        let cache = QueryResultCache::new(CacheConfig::default());
        let key = String::from("tracked");

        cache.put(key.clone(), vec![1, 2, 3]).unwrap();
        (0..5).for_each(|_| {
            cache.get(&key);
        });

        let guard = cache.entries.read().unwrap();
        let tracked = guard.get(&key).unwrap();
        assert_eq!(tracked.access_count, 5);
    }
}