use crate::executor::scan::RecordBatch;
use crate::parser::ast::Statement;
use blake3::Hash;
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// In-memory cache that maps parsed SQL statements to their result batches,
/// bounded by an estimated byte budget and a time-to-live (see `CacheConfig`).
pub struct QueryCache {
// Cached results keyed by a blake3 hash of the statement's `Debug` text.
entries: DashMap<Hash, CacheEntry>,
// Size budget, TTL and on/off switch.
config: CacheConfig,
// Hit/miss/insert/eviction/clear counters, updated by the cache operations.
stats: Arc<RwLock<CacheStatistics>>,
}
/// Tuning knobs for the query cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Byte budget that eviction tries to keep the summed entry-size estimates under.
    pub max_size_bytes: usize,
    /// How long an entry remains usable after it was inserted.
    pub ttl: Duration,
    /// When `false`, lookups return `None` and inserts are ignored.
    pub enabled: bool,
}

impl Default for CacheConfig {
    /// A 1 GiB budget, a five-minute TTL, and caching switched on.
    fn default() -> Self {
        const ONE_GIB: usize = 1 << 30;
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        CacheConfig {
            enabled: true,
            ttl: FIVE_MINUTES,
            max_size_bytes: ONE_GIB,
        }
    }
}
/// One cached query result plus the bookkeeping used for expiry and eviction.
#[derive(Clone)]
struct CacheEntry {
// The cached result batches, behind an `Arc`.
result: Arc<Vec<RecordBatch>>,
// Insertion time; compared against the configured TTL.
created_at: Instant,
// Rough size estimate (from `estimate_size`) counted against the byte budget.
size_bytes: usize,
// Number of hits served from this entry; used as an eviction sort key.
access_count: usize,
}
impl CacheEntry {
fn new(result: Vec<RecordBatch>) -> Self {
let size_bytes = Self::estimate_size(&result);
Self {
result: Arc::new(result),
created_at: Instant::now(),
size_bytes,
access_count: 0,
}
}
fn estimate_size(batches: &[RecordBatch]) -> usize {
batches
.iter()
.map(|batch| batch.num_rows * 100)
.sum::<usize>()
}
fn is_expired(&self, ttl: Duration) -> bool {
self.created_at.elapsed() > ttl
}
}
impl QueryCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            entries: DashMap::new(),
            config,
            stats: Arc::new(RwLock::new(CacheStatistics::default())),
        }
    }

    /// Returns a copy of the cached result for `query`, if a live entry exists.
    ///
    /// Returns `None` without touching statistics when the cache is disabled.
    /// Otherwise a live entry counts as a hit and has its access count bumped.
    /// A missing or expired entry counts as a miss, and an expired entry is removed.
    pub fn get(&self, query: &Statement) -> Option<Vec<RecordBatch>> {
        if !self.config.enabled {
            return None;
        }
        let key = self.compute_key(query);
        let ttl = self.config.ttl;

        // `Some(None)` means "present but expired". The shard lock taken by
        // `get_mut` is released when the closure returns, before any other lock.
        let lookup = self.entries.get_mut(&key).map(|mut entry| {
            if entry.is_expired(ttl) {
                None
            } else {
                entry.access_count += 1;
                Some(Arc::clone(&entry.result))
            }
        });

        let found = match lookup {
            Some(Some(result)) => Some(result),
            Some(None) => {
                // Re-check expiry under the shard lock: another thread may have
                // inserted a fresh entry for this key since the lookup above.
                self.entries.remove_if(&key, |_, entry| entry.is_expired(ttl));
                None
            }
            None => None,
        };

        {
            let mut stats = self.stats.write();
            if found.is_some() {
                stats.hits += 1;
            } else {
                stats.misses += 1;
            }
        }

        // Deep-copy outside of every lock; the `Arc` keeps the batches alive.
        found.map(|batches| (*batches).clone())
    }

    /// Caches `result` as the answer to `query`, replacing any previous entry.
    ///
    /// Does nothing when the cache is disabled. A result whose estimated size
    /// exceeds `max_size_bytes` can never fit, so it is not cached. Any older
    /// entry for the same query is dropped so a superseded answer is not served.
    /// Otherwise other entries are evicted as needed to stay within budget.
    pub fn put(&self, query: &Statement, result: Vec<RecordBatch>) {
        if !self.config.enabled {
            return;
        }
        let key = self.compute_key(query);
        let entry = CacheEntry::new(result);
        if entry.size_bytes > self.config.max_size_bytes {
            self.entries.remove(&key);
            return;
        }
        self.evict_if_needed(&key, entry.size_bytes);
        self.entries.insert(key, entry);
        self.stats.write().inserts += 1;
    }

    /// Removes the cached entry for `query`, if any. Statistics are not updated.
    pub fn invalidate(&self, query: &Statement) {
        let key = self.compute_key(query);
        self.entries.remove(&key);
    }

    /// Removes every entry and increments the `clears` counter.
    pub fn clear(&self) {
        self.entries.clear();
        self.stats.write().clears += 1;
    }

    /// Returns a snapshot of the current counters.
    pub fn statistics(&self) -> CacheStatistics {
        *self.stats.read()
    }

    /// Hashes the statement's `Debug` rendering with blake3; statements with
    /// identical `Debug` output share a cache key.
    fn compute_key(&self, query: &Statement) -> Hash {
        let query_string = format!("{:?}", query);
        blake3::hash(query_string.as_bytes())
    }

    /// Evicts entries until `incoming_size` more bytes fit within `max_size_bytes`.
    ///
    /// The entry currently stored under `incoming_key` (if any) is neither counted
    /// nor evicted, because the upcoming insert replaces it. Eviction order is:
    /// 1. expired entries;
    /// 2. the oldest entries (by whole seconds of age);
    /// 3. among equal ages, the entries with the fewest accesses.
    fn evict_if_needed(&self, incoming_key: &Hash, incoming_size: usize) {
        let limit = self.config.max_size_bytes;
        let fits = |used: usize| used.saturating_add(incoming_size) <= limit;

        // Cheap allocation-free pass for the common "already fits" case.
        let used: usize = self
            .entries
            .iter()
            .filter(|entry| entry.key() != incoming_key)
            .map(|entry| entry.value().size_bytes)
            .sum();
        if fits(used) {
            return;
        }

        let mut candidates: Vec<(Hash, Instant, usize, usize)> = self
            .entries
            .iter()
            .filter(|entry| entry.key() != incoming_key)
            .map(|entry| {
                let value = entry.value();
                (*entry.key(), value.created_at, value.access_count, value.size_bytes)
            })
            .collect();
        // Recompute from this snapshot so the subtraction below stays consistent
        // even if other threads changed the map since the first pass.
        let mut used: usize = candidates.iter().map(|&(_, _, _, size)| size).sum();

        // One reference time for every key, so the sort order cannot drift.
        let now = Instant::now();
        let ttl = self.config.ttl;
        candidates.sort_by_key(|&(_, created_at, access_count, _)| {
            let age = now.saturating_duration_since(created_at);
            // `false` (expired) sorts first; larger ages sort before smaller ones.
            (age <= ttl, std::cmp::Reverse(age.as_secs()), access_count)
        });

        let mut evicted: u64 = 0;
        for (key, _, _, size) in candidates {
            if fits(used) {
                break;
            }
            if self.entries.remove(&key).is_some() {
                evicted += 1;
            }
            // Subtract even if another thread already removed it: it no longer
            // occupies space either way.
            used = used.saturating_sub(size);
        }
        if evicted > 0 {
            self.stats.write().evictions += evicted;
        }
    }
}
/// Snapshot of cache activity counters, as returned by `QueryCache::statistics`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct CacheStatistics {
/// Lookups (while the cache is enabled) served from a live entry.
pub hits: u64,
/// Lookups (while the cache is enabled) that found no entry or an expired one.
pub misses: u64,
/// Results stored by `QueryCache::put`.
pub inserts: u64,
/// Entries removed to make room for an incoming result.
pub evictions: u64,
/// Calls to `QueryCache::clear`.
pub clears: u64,
}
impl CacheStatistics {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        Self::ratio(self.hits, self.lookups())
    }

    /// Fraction of lookups that were misses, in `[0.0, 1.0]`; `0.0` before any lookup.
    ///
    /// Computed from `misses` directly rather than as `1.0 - hit_rate()`, so an
    /// idle cache does not report a 100% miss rate.
    pub fn miss_rate(&self) -> f64 {
        Self::ratio(self.misses, self.lookups())
    }

    /// Total lookups recorded (hits plus misses), saturating instead of overflowing.
    fn lookups(&self) -> u64 {
        self.hits.saturating_add(self.misses)
    }

    /// `part / total`, defined as `0.0` when `total` is zero.
    fn ratio(part: u64, total: u64) -> f64 {
        if total == 0 {
            0.0
        } else {
            part as f64 / total as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{ColumnData, DataType, Field, Schema};
    use crate::parser::sql::parse_sql;

    /// `SELECT * FROM test`, falling back to an empty SELECT if the parser rejects it.
    fn sample_query() -> Statement {
        parse_sql("SELECT * FROM test").ok().unwrap_or_else(|| {
            Statement::Select(crate::parser::ast::SelectStatement {
                projection: vec![],
                from: None,
                selection: None,
                group_by: vec![],
                having: None,
                order_by: vec![],
                limit: None,
                offset: None,
            })
        })
    }

    /// A single two-row batch with one `Int64` column named `id`.
    ///
    /// Panics if the batch cannot be built, so tests fail loudly instead of
    /// silently skipping their assertions.
    fn sample_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "id".to_string(),
            DataType::Int64,
            false,
        )]));
        let columns = vec![ColumnData::Int64(vec![Some(1), Some(2)])];
        match RecordBatch::new(schema, columns, 2) {
            Ok(batch) => batch,
            Err(_) => panic!("a two-row Int64 batch should be valid"),
        }
    }

    #[test]
    fn test_cache_put_get() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = sample_query();
        cache.put(&query, vec![sample_batch()]);
        let cached = cache
            .get(&query)
            .expect("an entry that was just inserted should be a hit");
        assert_eq!(cached.len(), 1);
        assert_eq!(cached[0].num_rows, 2);
    }

    #[test]
    fn test_cache_statistics() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = sample_query();

        assert!(cache.get(&query).is_none());
        let stats = cache.statistics();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        cache.put(&query, vec![sample_batch()]);
        assert!(cache.get(&query).is_some());
        let stats = cache.statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.inserts, 1);
        assert!((stats.hit_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = sample_query();
        cache.put(&query, vec![sample_batch()]);
        cache.invalidate(&query);
        assert!(cache.get(&query).is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = sample_query();
        cache.put(&query, vec![sample_batch()]);
        cache.clear();
        assert!(cache.get(&query).is_none());
        assert_eq!(cache.statistics().clears, 1);
    }

    #[test]
    fn test_cache_disabled() {
        let config = CacheConfig {
            enabled: false,
            ..CacheConfig::default()
        };
        let cache = QueryCache::new(config);
        let query = sample_query();
        cache.put(&query, vec![sample_batch()]);
        assert!(cache.get(&query).is_none());
        let stats = cache.statistics();
        assert_eq!(stats.inserts, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }
}