use dashmap::DashMap;
use sqlparser::ast::Statement;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Two-tier cache of parsed SQL statements keyed by their SQL text.
///
/// L1 is an in-memory concurrent map whose eviction picks the entry with the
/// fewest recorded hits; L2 is an optional directory of per-query files.
pub struct PlanCache {
    /// In-memory tier: SQL text -> cached plan and its hit counter.
    l1: Arc<DashMap<String, CachedPlan>>,
    /// Size limit checked before inserting into `l1`.
    max_l1_size: usize,
    /// Directory backing the on-disk tier; `None` disables L2.
    l2_dir: Option<PathBuf>,
    /// Lookup / eviction counters for this cache.
    stats: CacheStats,
}
/// One L1 entry: a parsed statement plus how often it was served from L1.
#[derive(Clone, Debug)]
pub struct CachedPlan {
    /// The cached parsed statement, cloned out on every hit.
    pub statement: Statement,
    /// Number of L1 hits; the entry with the lowest count is evicted first.
    pub hit_count: u64,
}
/// Lock-free counters describing cache activity.
///
/// `Default` yields all-zero counters; `Debug` prints the current values.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Successful lookups from either tier (L2 hits are included here too).
    pub hits: AtomicU64,
    /// Lookups that found nothing in either tier.
    pub misses: AtomicU64,
    /// Entries removed from L1 to make room for new ones.
    pub evictions: AtomicU64,
    /// The subset of `hits` that was served from the on-disk tier.
    pub l2_hits: AtomicU64,
}
impl CacheStats {
    /// Builds a counter set with every counter starting at zero.
    fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            hits: zero(),
            misses: zero(),
            evictions: zero(),
            l2_hits: zero(),
        }
    }

    /// Share of lookups that hit: `hits / (hits + misses)`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet. Each counter is read
    /// separately with relaxed ordering, so under concurrent updates the value
    /// is a best-effort snapshot.
    pub fn hit_rate(&self) -> f64 {
        let hit_count = self.hits.load(Ordering::Relaxed);
        let miss_count = self.misses.load(Ordering::Relaxed);
        match hit_count + miss_count {
            0 => 0.0,
            lookups => hit_count as f64 / lookups as f64,
        }
    }

    /// Number of lookups recorded so far (`hits + misses`).
    pub fn total(&self) -> u64 {
        let hit_count = self.hits.load(Ordering::Relaxed);
        let miss_count = self.misses.load(Ordering::Relaxed);
        hit_count + miss_count
    }
}
impl PlanCache {
    /// Creates a cache whose in-memory (L1) tier holds at most `max_size` plans.
    ///
    /// A `max_size` of `0` disables the L1 tier: plans are still written to the
    /// L2 directory (if one is configured) but nothing is kept in memory.
    pub fn new(max_size: usize) -> Self {
        Self {
            l1: Arc::new(DashMap::with_capacity(max_size)),
            l2_dir: None,
            max_l1_size: max_size,
            stats: CacheStats::new(),
        }
    }

    /// Creates a cache with an L1 capacity of 1,000 plans.
    pub fn with_default_size() -> Self {
        Self::new(1_000)
    }

    /// Enables the on-disk L2 tier stored in `dir`, creating the directory if needed.
    ///
    /// Creation is best-effort: on failure the error is ignored, later L2
    /// writes fail silently and L2 lookups miss.
    pub fn with_l2_cache(mut self, dir: PathBuf) -> Self {
        // `create_dir_all` succeeds without doing anything if `dir` already exists.
        let _ = std::fs::create_dir_all(&dir);
        self.l2_dir = Some(dir);
        self
    }

    /// Looks up the plan for `sql`, trying L1 first and then L2.
    ///
    /// An L1 hit increments that entry's hit counter. An L2 hit re-parses `sql`
    /// (see `get_l2`) and promotes the result into L1. Both kinds of hit are
    /// counted in `stats.hits`; L2 hits are additionally counted in
    /// `stats.l2_hits`. A lookup that finds nothing increments `stats.misses`.
    pub fn get(&self, sql: &str) -> Option<Statement> {
        if let Some(mut entry) = self.l1.get_mut(sql) {
            entry.hit_count = entry.hit_count.saturating_add(1);
            self.stats.hits.fetch_add(1, Ordering::Relaxed);
            return Some(entry.statement.clone());
        }
        if let Some(stmt) = self.get_l2(sql) {
            self.stats.l2_hits.fetch_add(1, Ordering::Relaxed);
            self.stats.hits.fetch_add(1, Ordering::Relaxed);
            self.insert_l1(sql.to_string(), stmt.clone());
            return Some(stmt);
        }
        self.stats.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Stores `statement` for `sql` in both tiers.
    pub fn insert(&self, sql: String, statement: Statement) {
        // L2 only needs borrows, so write it first and then move both values into L1.
        self.put_l2(&sql, &statement);
        self.insert_l1(sql, statement);
    }

    /// Inserts into L1, evicting least-frequently-used entries if L1 is full.
    ///
    /// Re-inserting an existing key replaces its statement in place. It keeps
    /// the hit count and never triggers an eviction, because the map does not grow.
    fn insert_l1(&self, sql: String, statement: Statement) {
        if self.max_l1_size == 0 {
            return;
        }
        if let Some(mut existing) = self.l1.get_mut(&sql) {
            existing.statement = statement;
            return;
        }
        // `while` rather than `if`: concurrent inserters can push the map past the
        // limit between the length check and the insert; this brings it back down.
        while self.l1.len() >= self.max_l1_size {
            let Some(victim) = self.find_lfu_key() else {
                break;
            };
            // Another thread may have removed the victim already; only count real evictions.
            if self.l1.remove(&victim).is_some() {
                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.l1.insert(
            sql,
            CachedPlan {
                statement,
                hit_count: 0,
            },
        );
    }

    /// Returns the key of the L1 entry with the fewest hits, or `None` if L1 is empty.
    ///
    /// Ties go to the first such entry in DashMap iteration order. The scan
    /// stops at the first zero-hit entry, since nothing can have fewer hits.
    fn find_lfu_key(&self) -> Option<String> {
        let mut best: Option<(u64, String)> = None;
        for entry in self.l1.iter() {
            let hits = entry.value().hit_count;
            if best.as_ref().is_some_and(|(min, _)| *min <= hits) {
                continue;
            }
            best = Some((hits, entry.key().clone()));
            if hits == 0 {
                break;
            }
        }
        best.map(|(_, key)| key)
    }

    /// Path of the L2 file for `sql`, or `None` when L2 is disabled.
    fn l2_path(&self, sql: &str) -> Option<PathBuf> {
        self.l2_dir
            .as_ref()
            .map(|dir| dir.join(format!("{}.plan", Self::hash_sql(sql))))
    }

    /// Writes the L2 file for `sql` if an L2 directory is configured.
    ///
    /// The file holds the statement's `Debug` rendering. `get_l2` only checks
    /// that the file exists and never reads its contents. Write errors are ignored.
    fn put_l2(&self, sql: &str, statement: &Statement) {
        let Some(path) = self.l2_path(sql) else {
            return;
        };
        let _ = std::fs::write(path, format!("{statement:?}"));
    }

    /// Checks L2 for `sql` and, on a hit, returns a freshly parsed statement.
    ///
    /// The on-disk file is only a presence marker; the plan is rebuilt by
    /// re-parsing `sql` with `GenericDialect`. Consequences:
    /// - An L2 hit may differ from the `Statement` originally passed to `insert`.
    /// - A file-name hash collision yields a spurious hit, though it still
    ///   returns the parse of `sql` itself.
    ///
    /// Returns `None` if L2 is disabled, the file is absent, `sql` fails to
    /// parse, or `sql` contains no statement.
    fn get_l2(&self, sql: &str) -> Option<Statement> {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;

        let path = self.l2_path(sql)?;
        if !path.exists() {
            return None;
        }
        Parser::parse_sql(&GenericDialect {}, sql)
            .ok()?
            .into_iter()
            .next()
    }

    /// 64-bit FNV-1a hash of `sql`, used to name L2 files.
    ///
    /// Hand-rolled so file names stay stable across processes and toolchains;
    /// `std`'s `DefaultHasher` does not guarantee a stable algorithm.
    fn hash_sql(sql: &str) -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        sql.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
    }

    /// Empties L1 and deletes this cache's `*.plan` files from the L2 directory.
    ///
    /// Only `.plan` files are removed, so other files in the L2 directory
    /// survive. Statistics are not reset. L2 cleanup is best-effort: errors are
    /// ignored.
    pub fn clear(&self) {
        self.l1.clear();
        let Some(dir) = &self.l2_dir else {
            return;
        };
        // Recreate the directory if something removed it, so later L2 writes succeed.
        let _ = std::fs::create_dir_all(dir);
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for path in entries.flatten().map(|entry| entry.path()) {
            if path.is_file() && path.extension().is_some_and(|ext| ext == "plan") {
                let _ = std::fs::remove_file(&path);
            }
        }
    }

    /// Number of plans currently held in L1 (L2 files are not counted).
    pub fn len(&self) -> usize {
        self.l1.len()
    }

    /// `true` when L1 holds no plans (L2 is not consulted).
    pub fn is_empty(&self) -> bool {
        self.l1.is_empty()
    }

    /// Live counters for this cache.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// `true` if `sql` is present in L1. Does not consult L2 and does not touch stats.
    pub fn contains(&self, sql: &str) -> bool {
        self.l1.contains_key(sql)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with `GenericDialect` and returns its first statement.
    ///
    /// Panics on invalid SQL, which is a test bug.
    fn parse_one(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        Parser::parse_sql(&dialect, sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Temporary L2 directory, unique per process and removed on drop.
    ///
    /// The process id keeps concurrent test binaries from sharing a directory,
    /// and `Drop` cleans up even when the test panics.
    struct TempDir(PathBuf);

    impl TempDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_plan_cache_basic() {
        let cache = PlanCache::new(10);
        let sql = "SELECT * FROM users";
        let stmt = parse_one(sql);
        cache.insert(sql.to_string(), stmt.clone());
        let cached = cache.get(sql);
        assert!(cached.is_some());
        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 1);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_plan_cache_eviction() {
        let cache = PlanCache::new(2);
        let sql1 = "SELECT * FROM users";
        let sql2 = "SELECT * FROM orders";
        let sql3 = "SELECT * FROM products";
        cache.insert(sql1.to_string(), parse_one(sql1));
        cache.insert(sql2.to_string(), parse_one(sql2));
        cache.insert(sql3.to_string(), parse_one(sql3));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().evictions.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_plan_cache_hit_rate() {
        let cache = PlanCache::new(10);
        let sql = "SELECT * FROM users";
        cache.insert(sql.to_string(), parse_one(sql));
        cache.get(sql);
        cache.get(sql);
        cache.get("SELECT 1");
        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 2);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 1);
        assert_eq!(cache.stats().total(), 3);
        assert!((cache.stats().hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_plan_cache_l2_disk() {
        let tmp_dir = TempDir::new("dbx_plan_cache_test");
        let cache = PlanCache::new(1).with_l2_cache(tmp_dir.0.clone());
        let sql1 = "SELECT * FROM users";
        let sql2 = "SELECT * FROM orders";
        cache.insert(sql1.to_string(), parse_one(sql1));
        cache.insert(sql2.to_string(), parse_one(sql2));
        let result = cache.get(sql1);
        assert!(result.is_some());
        assert_eq!(cache.stats().l2_hits.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_plan_cache_clear() {
        let tmp_dir = TempDir::new("dbx_plan_cache_clear_test");
        let cache = PlanCache::new(10).with_l2_cache(tmp_dir.0.clone());
        let sql = "SELECT * FROM users";
        cache.insert(sql.to_string(), parse_one(sql));
        assert!(cache.contains(sql));
        cache.clear();
        assert!(cache.is_empty());
        assert!(cache.get(sql).is_none());
        assert_eq!(cache.stats().l2_hits.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_plan_cache_contains() {
        let cache = PlanCache::new(10);
        let sql = "SELECT * FROM users";
        assert!(!cache.contains(sql));
        cache.insert(sql.to_string(), parse_one(sql));
        assert!(cache.contains(sql));
    }

    #[test]
    fn test_plan_cache_concurrent_access() {
        use std::thread;
        let cache = Arc::new(PlanCache::new(100));
        let handles: Vec<_> = (0..8)
            .map(|i| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let sql = format!("SELECT * FROM table_{i}");
                    let stmt = parse_one(&sql);
                    cache.insert(sql.clone(), stmt);
                    assert!(cache.get(&sql).is_some());
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(cache.len(), 8);
    }
}