use crate::executor::pipeline::RowBatch;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
#[derive(Debug, Clone)]
pub struct CacheConfig {
pub max_entries: usize,
pub max_memory_bytes: usize,
pub ttl_seconds: u64,
}
impl CacheConfig {
pub fn new(max_entries: usize, max_memory_bytes: usize, ttl_seconds: u64) -> Self {
Self {
max_entries,
max_memory_bytes,
ttl_seconds,
}
}
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 1000,
max_memory_bytes: 100 * 1024 * 1024, ttl_seconds: 300, }
}
}
#[derive(Debug, Clone)]
pub struct CacheEntry {
pub results: Vec<RowBatch>,
pub created_at: Instant,
pub last_accessed: Instant,
pub size_bytes: usize,
pub access_count: u64,
}
impl CacheEntry {
pub fn new(results: Vec<RowBatch>) -> Self {
let size_bytes = Self::estimate_size(&results);
let now = Instant::now();
Self {
results,
created_at: now,
last_accessed: now,
size_bytes,
access_count: 0,
}
}
fn estimate_size(results: &[RowBatch]) -> usize {
results
.iter()
.map(|batch| {
batch.len() * batch.schema.columns.len() * 8 + 1024
})
.sum()
}
pub fn is_expired(&self, ttl: Duration) -> bool {
self.created_at.elapsed() > ttl
}
pub fn mark_accessed(&mut self) {
self.last_accessed = Instant::now();
self.access_count += 1;
}
}
pub struct QueryCache {
entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
lru_order: Arc<RwLock<Vec<String>>>,
config: CacheConfig,
memory_used: Arc<RwLock<usize>>,
stats: Arc<RwLock<CacheStats>>,
}
impl QueryCache {
pub fn new(config: CacheConfig) -> Self {
Self {
entries: Arc::new(RwLock::new(HashMap::new())),
lru_order: Arc::new(RwLock::new(Vec::new())),
config,
memory_used: Arc::new(RwLock::new(0)),
stats: Arc::new(RwLock::new(CacheStats::default())),
}
}
pub fn get(&self, key: &str) -> Option<CacheEntry> {
let mut entries = self.entries.write().ok()?;
let mut lru = self.lru_order.write().ok()?;
let mut stats = self.stats.write().ok()?;
if let Some(entry) = entries.get_mut(key) {
if entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
stats.misses += 1;
return None;
}
if let Some(pos) = lru.iter().position(|k| k == key) {
lru.remove(pos);
}
lru.push(key.to_string());
entry.mark_accessed();
stats.hits += 1;
Some(entry.clone())
} else {
stats.misses += 1;
None
}
}
pub fn insert(&self, key: String, results: Vec<RowBatch>) {
let entry = CacheEntry::new(results);
let entry_size = entry.size_bytes;
let mut entries = self.entries.write().unwrap();
let mut lru = self.lru_order.write().unwrap();
let mut memory = self.memory_used.write().unwrap();
let mut stats = self.stats.write().unwrap();
while (entries.len() >= self.config.max_entries
|| *memory + entry_size > self.config.max_memory_bytes)
&& !lru.is_empty()
{
if let Some(old_key) = lru.first().cloned() {
if let Some(old_entry) = entries.remove(&old_key) {
*memory = memory.saturating_sub(old_entry.size_bytes);
stats.evictions += 1;
}
lru.remove(0);
}
}
entries.insert(key.clone(), entry);
lru.push(key);
*memory += entry_size;
stats.inserts += 1;
}
pub fn remove(&self, key: &str) -> bool {
let mut entries = self.entries.write().unwrap();
let mut lru = self.lru_order.write().unwrap();
let mut memory = self.memory_used.write().unwrap();
if let Some(entry) = entries.remove(key) {
*memory = memory.saturating_sub(entry.size_bytes);
if let Some(pos) = lru.iter().position(|k| k == key) {
lru.remove(pos);
}
true
} else {
false
}
}
pub fn clear(&self) {
let mut entries = self.entries.write().unwrap();
let mut lru = self.lru_order.write().unwrap();
let mut memory = self.memory_used.write().unwrap();
entries.clear();
lru.clear();
*memory = 0;
}
pub fn stats(&self) -> CacheStats {
self.stats.read().unwrap().clone()
}
pub fn memory_used(&self) -> usize {
*self.memory_used.read().unwrap()
}
pub fn len(&self) -> usize {
self.entries.read().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.entries.read().unwrap().is_empty()
}
pub fn clean_expired(&self) {
let ttl = Duration::from_secs(self.config.ttl_seconds);
let mut entries = self.entries.write().unwrap();
let mut lru = self.lru_order.write().unwrap();
let mut memory = self.memory_used.write().unwrap();
let mut stats = self.stats.write().unwrap();
let expired_keys: Vec<_> = entries
.iter()
.filter(|(_, entry)| entry.is_expired(ttl))
.map(|(key, _)| key.clone())
.collect();
for key in expired_keys {
if let Some(entry) = entries.remove(&key) {
*memory = memory.saturating_sub(entry.size_bytes);
if let Some(pos) = lru.iter().position(|k| k == &key) {
lru.remove(pos);
}
stats.evictions += 1;
}
}
}
}
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
pub hits: u64,
pub misses: u64,
pub inserts: u64,
pub evictions: u64,
}
impl CacheStats {
pub fn hit_rate(&self) -> f64 {
let total = self.hits + self.misses;
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
pub fn reset(&mut self) {
self.hits = 0;
self.misses = 0;
self.inserts = 0;
self.evictions = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::executor::plan::{ColumnDef, DataType, QuerySchema};
fn create_test_batch() -> RowBatch {
let schema = QuerySchema::new(vec![ColumnDef {
name: "id".to_string(),
data_type: DataType::Int64,
nullable: false,
}]);
RowBatch::new(schema)
}
#[test]
fn test_cache_insert_and_get() {
let cache = QueryCache::new(CacheConfig::default());
let batch = create_test_batch();
cache.insert("test_key".to_string(), vec![batch.clone()]);
assert_eq!(cache.len(), 1);
let cached = cache.get("test_key");
assert!(cached.is_some());
}
#[test]
fn test_cache_miss() {
let cache = QueryCache::new(CacheConfig::default());
let result = cache.get("nonexistent");
assert!(result.is_none());
let stats = cache.stats();
assert_eq!(stats.misses, 1);
}
#[test]
fn test_cache_eviction() {
let config = CacheConfig {
max_entries: 2,
max_memory_bytes: 1024 * 1024,
ttl_seconds: 300,
};
let cache = QueryCache::new(config);
let batch = create_test_batch();
cache.insert("key1".to_string(), vec![batch.clone()]);
cache.insert("key2".to_string(), vec![batch.clone()]);
cache.insert("key3".to_string(), vec![batch.clone()]);
assert_eq!(cache.len(), 2);
assert!(cache.get("key1").is_none());
}
#[test]
fn test_cache_clear() {
let cache = QueryCache::new(CacheConfig::default());
let batch = create_test_batch();
cache.insert("key1".to_string(), vec![batch.clone()]);
cache.insert("key2".to_string(), vec![batch.clone()]);
cache.clear();
assert_eq!(cache.len(), 0);
assert_eq!(cache.memory_used(), 0);
}
#[test]
fn test_hit_rate() {
let mut stats = CacheStats::default();
stats.hits = 7;
stats.misses = 3;
assert!((stats.hit_rate() - 0.7).abs() < 0.001);
}
}