use std::sync::{Arc, RwLock};
use std::time::Instant;
use crate::parser::ast::Statement;
/// Capacity used by `QueryCache::default_sized` (and therefore by `QueryCache::default`).
pub const DEFAULT_CACHE_SIZE: usize = 1000;
/// Result of a successful `QueryCache::get` lookup.
///
/// Holds a clone of the entry's `Arc<Statement>` plus its parameter metadata;
/// the entry's query strings and usage bookkeeping are not copied.
#[derive(Debug, Clone)]
pub struct CachedPlanRef {
/// Shared handle to the cached statement.
pub statement: Arc<Statement>,
/// Copied from the cache entry. Presumably true when the query contains
/// parameter placeholders; the value is whatever the caller passed to `put`.
pub has_params: bool,
/// Copied from the cache entry; the parameter count supplied to `put`.
pub param_count: usize,
}
/// A single entry stored in `QueryCache`.
#[derive(Debug, Clone)]
pub struct CachedQueryPlan {
/// Shared handle to the cached statement.
pub statement: Arc<Statement>,
/// The query text exactly as it was passed to `QueryCache::put`.
pub query_text: String,
/// Set to `Instant::now()` when the entry is constructed. Nothing in this
/// file updates it afterwards, so it effectively records insertion time.
pub last_used: Instant,
/// Starts at 1 when the entry is constructed; not incremented by `QueryCache::get`.
pub usage_count: u64,
/// Caller-supplied flag; presumably true when the query contains parameter placeholders.
pub has_params: bool,
/// Caller-supplied parameter count.
pub param_count: usize,
/// Whitespace-normalized query text; the same string is used as the cache key.
pub normalized_query: String,
}
impl CachedQueryPlan {
pub fn new(
statement: Arc<Statement>,
query_text: String,
has_params: bool,
param_count: usize,
normalized_query: String,
) -> Self {
Self {
statement,
query_text,
last_used: Instant::now(),
usage_count: 1,
has_params,
param_count,
normalized_query,
}
}
}
/// Cache of query plans keyed by whitespace-normalized query text.
///
/// The map sits behind an `RwLock`, so every method takes `&self`.
pub struct QueryCache {
// Entries keyed by the normalized query string.
plans: RwLock<std::collections::HashMap<String, CachedQueryPlan>>,
// Entry count at which inserting triggers a prune pass.
max_size: usize,
// Fraction of `max_size` removed by one prune pass (set to 0.2 in `new`).
prune_factor: f64,
}
impl QueryCache {
    /// Creates an empty cache that holds at most `max_size` plans.
    ///
    /// When a new query is inserted into a full cache, 20% of `max_size`
    /// (rounded up, at least one entry) is evicted in a single batch.
    /// A `max_size` of 0 disables storage: `put` then caches nothing.
    pub fn new(max_size: usize) -> Self {
        Self {
            plans: RwLock::new(std::collections::HashMap::new()),
            max_size,
            prune_factor: 0.2,
        }
    }

    /// Creates a cache with room for `DEFAULT_CACHE_SIZE` plans.
    pub fn default_sized() -> Self {
        Self::new(DEFAULT_CACHE_SIZE)
    }

    /// Looks up the plan cached for `query`.
    ///
    /// The lookup key is `normalize_query(query)`. Only a read lock is taken:
    /// a hit does not modify the stored entry, so its `last_used` and
    /// `usage_count` stay as they were when the entry was inserted.
    ///
    /// Returns `None` on a miss, or when the lock is poisoned.
    pub fn get(&self, query: &str) -> Option<CachedPlanRef> {
        let normalized = normalize_query(query);
        let plans = self.plans.read().ok()?;
        plans.get(normalized.as_ref()).map(|plan| CachedPlanRef {
            statement: Arc::clone(&plan.statement),
            has_params: plan.has_params,
            param_count: plan.param_count,
        })
    }

    /// Stores a plan for `query`, replacing any entry with the same normalized text.
    ///
    /// Eviction only runs when the insert would add a *new* key to a full
    /// cache; overwriting an existing key never evicts other entries.
    ///
    /// The freshly built plan is always returned, even when it could not be
    /// stored (zero-capacity cache or poisoned lock).
    pub fn put(
        &self,
        query: &str,
        statement: Arc<Statement>,
        has_params: bool,
        param_count: usize,
    ) -> CachedQueryPlan {
        let normalized_key = normalize_query(query).into_owned();
        let plan = CachedQueryPlan::new(
            statement,
            query.to_string(),
            has_params,
            param_count,
            normalized_key.clone(),
        );

        // A zero-capacity cache must stay empty; storing here would exceed `max_size`.
        if self.max_size == 0 {
            return plan;
        }

        if let Ok(mut plans) = self.plans.write() {
            // Replacing an existing entry does not grow the map, so no room is needed.
            let adds_entry = !plans.contains_key(&normalized_key);
            if adds_entry && plans.len() >= self.max_size {
                self.prune_cache(&mut plans);
            }
            plans.insert(normalized_key, plan.clone());
        }
        plan
    }

    /// Removes every cached plan. Does nothing if the lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut plans) = self.plans.write() {
            plans.clear();
        }
    }

    /// Number of cached plans, or 0 if the lock is poisoned.
    pub fn size(&self) -> usize {
        self.plans.read().map(|p| p.len()).unwrap_or(0)
    }

    /// Snapshot of the cache's occupancy and usage counters.
    ///
    /// `total_usage` is the sum of every entry's `usage_count`; `avg_usage`
    /// is that sum divided by the entry count (0.0 for an empty cache).
    /// If the lock is poisoned, the snapshot reports an empty cache.
    pub fn stats(&self) -> CacheStats {
        let plans = match self.plans.read() {
            Ok(guard) => guard,
            Err(_) => {
                return CacheStats {
                    size: 0,
                    max_size: self.max_size,
                    total_usage: 0,
                    avg_usage: 0.0,
                }
            }
        };

        let size = plans.len();
        let total_usage: u64 = plans.values().map(|p| p.usage_count).sum();
        let avg_usage = if size > 0 {
            total_usage as f64 / size as f64
        } else {
            0.0
        };

        CacheStats {
            size,
            max_size: self.max_size,
            total_usage,
            avg_usage,
        }
    }

    /// Evicts a batch of entries, oldest `last_used` first and lowest
    /// `usage_count` first among equally old ones.
    ///
    /// The batch is `ceil(max_size * prune_factor)` entries, but never fewer
    /// than one (so the pending insert always gets room) and never more than
    /// the map currently holds.
    fn prune_cache(&self, plans: &mut std::collections::HashMap<String, CachedQueryPlan>) {
        let batch = ((self.max_size as f64) * self.prune_factor).ceil() as usize;
        let num_to_remove = batch.max(1).min(plans.len());
        if num_to_remove == 0 {
            return;
        }

        // Rank by borrowed keys; only the evicted keys get cloned below.
        let mut candidates: Vec<(Instant, u64, &String)> = plans
            .iter()
            .map(|(key, plan)| (plan.last_used, plan.usage_count, key))
            .collect();
        // Ties between identical (last_used, usage_count) pairs were already in
        // arbitrary hash-map order, so an unstable sort loses nothing.
        candidates.sort_unstable_by_key(|&(last_used, usage_count, _)| (last_used, usage_count));

        let victims: Vec<String> = candidates
            .into_iter()
            .take(num_to_remove)
            .map(|(_, _, key)| key.clone())
            .collect();
        for key in &victims {
            plans.remove(key);
        }
    }
}
impl Default for QueryCache {
fn default() -> Self {
Self::default_sized()
}
}
/// Point-in-time counters reported by `QueryCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of entries currently cached.
pub size: usize,
/// The cache's configured `max_size`.
pub max_size: usize,
/// Sum of `usage_count` over all cached entries.
pub total_usage: u64,
/// `total_usage / size`, or 0.0 when the cache is empty.
pub avg_usage: f64,
}
/// Produces the cache key for `query`.
///
/// Leading and trailing whitespace is trimmed, and every run of whitespace
/// outside quotes is replaced by a single ASCII space. "Whitespace" means
/// `char::is_whitespace` throughout, so a lone tab, form feed or non-breaking
/// space is normalized exactly like a run of them.
///
/// Text between a pair of `'` or `"` characters is copied verbatim, because
/// whitespace inside a literal or quoted identifier is significant:
/// `'a  b'` and `'a b'` are different queries and must not share a key.
/// A doubled quote (`''`) closes and immediately reopens the quoted span, so
/// it needs no special handling. Backslash escapes and SQL comments are not
/// interpreted; misjudging a quote boundary can only make the key less
/// normalized, never merge more queries than full whitespace collapsing would.
///
/// Returns `Cow::Borrowed` when the trimmed input is already normalized, and
/// allocates only from the first character that has to change.
fn normalize_query(query: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    let trimmed = query.trim();
    // Stays `None` while the output so far is byte-identical to `trimmed`.
    let mut rewritten: Option<String> = None;
    // The quote character that opened the span we are currently inside, if any.
    let mut open_quote: Option<char> = None;
    let mut last_was_space = false;

    for (idx, c) in trimmed.char_indices() {
        // What this input character contributes to the output (`None` = dropped).
        let emit = match open_quote {
            Some(quote) => {
                if c == quote {
                    open_quote = None;
                }
                last_was_space = false;
                Some(c)
            }
            None if c.is_whitespace() => {
                if last_was_space {
                    None
                } else {
                    last_was_space = true;
                    Some(' ')
                }
            }
            None => {
                if c == '\'' || c == '"' {
                    open_quote = Some(c);
                }
                last_was_space = false;
                Some(c)
            }
        };

        // First divergence from the input: start the owned copy with the
        // unchanged prefix (`idx` is a char boundary from `char_indices`).
        if emit != Some(c) && rewritten.is_none() {
            let mut buf = String::with_capacity(trimmed.len());
            buf.push_str(&trimmed[..idx]);
            rewritten = Some(buf);
        }
        if let (Some(buf), Some(ch)) = (rewritten.as_mut(), emit) {
            buf.push(ch);
        }
    }

    match rewritten {
        Some(owned) => Cow::Owned(owned),
        None => Cow::Borrowed(trimmed),
    }
}
// Unit tests for `QueryCache` and `normalize_query`.
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::ast::{Expression, GroupByClause, SelectStatement, StarExpression};
use crate::parser::token::{Position, Token, TokenType};
// Placeholder `SELECT` keyword token for the hand-built statement below.
fn dummy_token() -> Token {
Token::new(TokenType::Keyword, "SELECT", Position::new(0, 1, 1))
}
// Placeholder `*` operator token for the hand-built statement below.
fn star_token() -> Token {
Token::new(TokenType::Operator, "*", Position::new(0, 1, 1))
}
// Builds a minimal `SELECT *` statement with every optional clause empty.
// The cache never inspects the statement, so all tests share this one shape.
fn create_test_statement() -> Arc<Statement> {
Arc::new(Statement::Select(SelectStatement {
token: dummy_token(),
with: None,
distinct: false,
columns: vec![Expression::Star(StarExpression {
token: star_token(),
})],
table_expr: None,
where_clause: None,
group_by: GroupByClause::default(),
having: None,
window_defs: vec![],
order_by: vec![],
limit: None,
offset: None,
set_operations: vec![],
}))
}
// A stored plan can be read back with its parameter metadata intact.
#[test]
fn test_cache_put_get() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt.clone(), false, 0);
assert_eq!(cache.size(), 1);
let plan = cache.get("SELECT * FROM users");
assert!(plan.is_some());
let plan = plan.unwrap();
assert!(!plan.has_params);
assert_eq!(plan.param_count, 0);
}
// Looking up a query that was never stored yields `None`.
#[test]
fn test_cache_miss() {
let cache = QueryCache::new(100);
let plan = cache.get("SELECT * FROM users");
assert!(plan.is_none());
}
// `get` only takes a read lock and does not bump `usage_count`, so after
// five hits the total is still the initial 1 set when the entry was created.
#[test]
fn test_cache_usage_count() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt, false, 0);
for _ in 0..5 {
cache.get("SELECT * FROM users");
}
let stats = cache.stats();
assert_eq!(stats.total_usage, 1); }
// `clear` empties the cache.
#[test]
fn test_cache_clear() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt, false, 0);
assert_eq!(cache.size(), 1);
cache.clear();
assert_eq!(cache.size(), 0);
}
// Inserting twice as many distinct queries as the capacity must not leave
// the cache above its limit.
#[test]
fn test_cache_pruning() {
let cache = QueryCache::new(5);
let stmt = create_test_statement();
for i in 0..10 {
let query = format!("SELECT * FROM table{}", i);
cache.put(&query, stmt.clone(), false, 0);
}
assert!(cache.size() <= 5);
}
// Outer whitespace is trimmed; newlines, tabs and runs collapse to one space.
#[test]
fn test_normalize_query() {
assert_eq!(
normalize_query(" SELECT * FROM users "),
"SELECT * FROM users"
);
assert_eq!(
normalize_query("SELECT\n*\nFROM\nusers"),
"SELECT * FROM users"
);
assert_eq!(
normalize_query("SELECT\t*\t\tFROM users"),
"SELECT * FROM users"
);
}
// A lookup whose text differs only by surrounding whitespace hits the same entry.
#[test]
fn test_normalized_cache_hit() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt, false, 0);
let plan = cache.get(" SELECT * FROM users ");
assert!(plan.is_some());
}
// Parameter metadata passed to `put` is returned unchanged by `get`.
#[test]
fn test_parameterized_query() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users WHERE id = $1", stmt, true, 1);
let plan = cache.get("SELECT * FROM users WHERE id = $1").unwrap();
assert!(plan.has_params);
assert_eq!(plan.param_count, 1);
}
// Two entries, each created with `usage_count` 1; hits do not change the
// counters, so the total equals the number of entries.
#[test]
fn test_cache_stats() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT 1", stmt.clone(), false, 0);
cache.put("SELECT 2", stmt.clone(), false, 0);
for _ in 0..5 {
cache.get("SELECT 1");
}
let stats = cache.stats();
assert_eq!(stats.size, 2);
assert_eq!(stats.max_size, 100);
assert_eq!(stats.total_usage, 2); }
// Ten reader threads and five writer threads share one cache. The writers
// add 100 distinct queries, well below the capacity of 1000, so the entry
// inserted up front must still be present afterwards.
#[test]
fn test_cache_thread_safety() {
use std::sync::Arc;
use std::thread;
let cache = Arc::new(QueryCache::new(1000));
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt.clone(), false, 0);
let mut handles = vec![];
for _ in 0..10 {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for _ in 0..100 {
cache.get("SELECT * FROM users");
}
}));
}
for i in 0..5 {
let cache = Arc::clone(&cache);
let stmt = stmt.clone();
handles.push(thread::spawn(move || {
for j in 0..20 {
let query = format!("SELECT * FROM table{}_{}", i, j);
cache.put(&query, stmt.clone(), false, 0);
}
}));
}
for handle in handles {
handle.join().unwrap();
}
assert!(cache.get("SELECT * FROM users").is_some());
}
}