use crate::common::time_compat::Instant;
use std::borrow::Cow;
use std::sync::{Arc, RwLock};
use crate::common::SmartString;
use rustc_hash::FxHashMap;
use crate::common::CompactArc;
use crate::core::Schema;
use crate::parser::ast::Statement;
/// Returns `s` lowercased, borrowing the input when lowercasing would not change it.
///
/// The result always equals `s.to_lowercase()`. The borrow decision is made per
/// character with the same Unicode case mapping that `str::to_lowercase` applies, so
/// non-ASCII uppercase letters (for example `É`) are lowercased even when the string
/// contains no ASCII uppercase letter.
#[inline]
fn to_lowercase_cow(s: &str) -> Cow<'_, str> {
    // A character is "already lowercase" when its lowercase mapping is itself.
    // ASCII is checked directly because it is by far the most common input.
    let is_already_lower = |c: char| {
        if c.is_ascii() {
            !c.is_ascii_uppercase()
        } else {
            let mut lowered = c.to_lowercase();
            lowered.next() == Some(c) && lowered.next().is_none()
        }
    };

    if s.chars().all(is_already_lower) {
        Cow::Borrowed(s)
    } else {
        Cow::Owned(s.to_lowercase())
    }
}
/// Describes where a compiled primary-key operation takes its key value from.
///
/// Stored as the `pk_value_source` field of `CompiledPkLookup`, `CompiledPkUpdate`
/// and `CompiledPkDelete`.
#[derive(Debug, Clone)]
pub enum PkValueSource {
/// The key comes from a statement parameter identified by a `usize`.
/// NOTE(review): presumably a positional index; whether it is 0- or 1-based is
/// not visible in this file — confirm against the executor.
Parameter(usize),
/// The key is an integer constant stored in the plan itself.
Literal(i64),
/// The key comes from a parameter identified by name.
NamedParameter(SmartString),
}
/// Pre-compiled state for a primary-key lookup (`CompiledExecution::PkLookup`).
#[derive(Debug, Clone)]
pub struct CompiledPkLookup {
/// Target table. `QueryCache::invalidate_table` compares this against a lowercased
/// table name, so it is presumably stored lowercased — TODO confirm at the producer.
pub table_name: SmartString,
/// Shared handle to the table schema.
pub schema: CompactArc<Schema>,
/// Shared list of column names.
/// NOTE(review): presumably the result-set column names; confirm against the executor.
pub column_names: CompactArc<Vec<String>>,
/// Where the primary-key value is read from at execution time.
pub pk_value_source: PkValueSource,
/// NOTE(review): presumably the schema/catalog epoch at compile time, used to detect
/// stale plans; the comparison is not in this file.
pub cached_epoch: u64,
}
/// One column assignment of a compiled primary-key update (see `CompiledPkUpdate::updates`).
#[derive(Debug, Clone)]
pub struct CompiledUpdateColumn {
/// Index of the column being assigned.
/// NOTE(review): presumably an index into the table schema's column list — confirm.
pub column_idx: usize,
/// Data type recorded for the column.
pub column_type: crate::core::DataType,
/// Where the new value comes from.
pub value_source: UpdateValueSource,
}
/// Describes where a compiled update takes the new value of a column from.
#[derive(Debug, Clone)]
pub enum UpdateValueSource {
/// A constant value stored in the plan itself.
Literal(crate::core::Value),
/// A statement parameter identified by a `usize`.
/// NOTE(review): presumably a positional index; its base is not visible here.
Parameter(usize),
/// A parameter identified by name.
NamedParameter(SmartString),
}
/// Pre-compiled state for an update addressed by primary key (`CompiledExecution::PkUpdate`).
#[derive(Debug, Clone)]
pub struct CompiledPkUpdate {
/// Target table.
pub table_name: SmartString,
/// Shared handle to the table schema.
pub schema: CompactArc<Schema>,
/// Name of the primary-key column.
pub pk_column_name: SmartString,
/// Where the primary-key value is read from at execution time.
pub pk_value_source: PkValueSource,
/// The column assignments to apply.
pub updates: Vec<CompiledUpdateColumn>,
/// NOTE(review): presumably the schema/catalog epoch at compile time, used to detect
/// stale plans; the comparison is not in this file.
pub cached_epoch: u64,
}
/// Pre-compiled state for a delete addressed by primary key (`CompiledExecution::PkDelete`).
#[derive(Debug, Clone)]
pub struct CompiledPkDelete {
/// Target table.
pub table_name: SmartString,
/// Shared handle to the table schema.
pub schema: CompactArc<Schema>,
/// Name of the primary-key column.
pub pk_column_name: SmartString,
/// Where the primary-key value is read from at execution time.
pub pk_value_source: PkValueSource,
/// NOTE(review): presumably the schema/catalog epoch at compile time, used to detect
/// stale plans; the comparison is not in this file.
pub cached_epoch: u64,
}
/// Pre-compiled state for an insert (`CompiledExecution::Insert`).
///
/// All vector fields are behind `Arc`, so cloning this struct shares them rather
/// than copying.
///
/// NOTE(review): nothing in this file reads these fields; the per-field notes below
/// are inferred from names and types and should be confirmed against the code that
/// builds and executes this plan.
#[derive(Debug, Clone)]
pub struct CompiledInsert {
/// Target table.
pub table_name: SmartString,
/// Presumably the schema positions of the columns named by the INSERT.
pub column_indices: Arc<Vec<usize>>,
/// Presumably the data types of those columns, parallel to `column_indices`.
pub column_types: Arc<Vec<crate::core::DataType>>,
/// Presumably per-column vector dimensions, parallel to `column_indices`.
pub column_vector_dims: Arc<Vec<u16>>,
/// Column names (which set of columns is not visible here).
pub column_names: Arc<Vec<SmartString>>,
/// Presumably the data types of every column of the table.
pub all_column_types: Arc<Vec<crate::core::DataType>>,
/// Presumably a full row of default values used as the starting point for each insert.
pub default_row_template: Arc<Vec<crate::core::Value>>,
/// Presumably CHECK constraints; the meaning of each tuple element is not visible here.
pub check_exprs: Arc<Vec<(usize, SmartString, SmartString)>>,
/// Presumably the schema/catalog epoch at compile time, used to detect stale plans.
pub cached_epoch: u64,
}
/// Pre-compiled state for a distinct-count query (`CompiledExecution::CountDistinct`).
#[derive(Debug, Clone)]
pub struct CompiledCountDistinct {
/// Target table. `QueryCache::invalidate_table` compares this against a lowercased
/// table name, so it is presumably stored lowercased — TODO confirm at the producer.
pub table_name: SmartString,
/// Column whose distinct values are counted.
pub column_name: SmartString,
/// Name given to the single result column.
pub result_column_name: String,
/// NOTE(review): presumably the schema/catalog epoch at compile time, used to detect
/// stale plans; the comparison is not in this file.
pub cached_epoch: u64,
}
/// Pre-compiled state for a row-count query (`CompiledExecution::CountStar`).
#[derive(Debug, Clone)]
pub struct CompiledCountStar {
/// Target table. `QueryCache::invalidate_table` compares this against a lowercased
/// table name, so it is presumably stored lowercased — TODO confirm at the producer.
pub table_name: SmartString,
/// Name given to the single result column.
pub result_column_name: String,
/// NOTE(review): presumably the schema/catalog epoch at compile time, used to detect
/// stale plans; the comparison is not in this file.
pub cached_epoch: u64,
}
/// Compiled fast-path state attached to a cached plan.
///
/// Every cache entry starts as `Unknown` (see `CachedQueryPlan::new` and
/// `QueryCache::put`). The state lives behind an `Arc<RwLock<..>>` that is shared
/// between the cache entry and every `CachedPlanRef` handed out for it.
#[derive(Debug, Clone, Default)]
pub enum CompiledExecution {
/// Initial state of every new cache entry.
#[default]
Unknown,
/// The statement has no fast path.
/// NOTE(review): the `u64` is presumably the epoch at which that was decided — confirm.
NotOptimizable(u64),
/// Lookup by primary key.
PkLookup(CompiledPkLookup),
/// Update by primary key.
PkUpdate(CompiledPkUpdate),
/// Delete by primary key.
PkDelete(CompiledPkDelete),
/// Insert.
Insert(CompiledInsert),
/// Distinct count over one column.
CountDistinct(CompiledCountDistinct),
/// Row count.
CountStar(CompiledCountStar),
}
/// Capacity (number of plans) used by `QueryCache::default_sized` and `QueryCache::default`.
pub const DEFAULT_CACHE_SIZE: usize = 1000;
/// Lightweight handle to a cached plan, returned by `QueryCache::get` and `QueryCache::put`.
///
/// It holds only `Arc` clones and copied scalars, so it stays valid after the cache
/// lock is released and after the entry is evicted.
#[derive(Debug, Clone)]
pub struct CachedPlanRef {
/// The parsed statement, shared with the cache entry.
pub statement: Arc<Statement>,
/// Value of `has_params` that was passed to `QueryCache::put`.
pub has_params: bool,
/// Value of `param_count` that was passed to `QueryCache::put`.
pub param_count: usize,
/// Compiled fast-path state, shared with the cache entry: a write through this
/// handle is visible to later `get` calls for the same entry.
pub compiled: Arc<RwLock<CompiledExecution>>,
}
/// One entry of the `QueryCache`.
#[derive(Debug, Clone)]
pub struct CachedQueryPlan {
/// The parsed statement.
pub statement: Arc<Statement>,
/// The query text exactly as it was passed to `QueryCache::put`.
pub query_text: SmartString,
/// Set to `Instant::now()` when the entry is created; `QueryCache::get` does not
/// update it. Primary sort key for eviction (oldest first).
pub last_used: Instant,
/// Starts at 1 when the entry is created; `QueryCache::get` does not update it.
/// Tie-breaker for eviction and the input to `CacheStats::total_usage`.
pub usage_count: u64,
/// Value of `has_params` supplied by the caller.
pub has_params: bool,
/// Value of `param_count` supplied by the caller.
pub param_count: usize,
/// Whitespace-normalized query text; identical to the key this entry is stored under.
pub normalized_query: SmartString,
/// Compiled fast-path state, shared with every `CachedPlanRef` handed out for this entry.
pub compiled: Arc<RwLock<CompiledExecution>>,
}
impl CachedQueryPlan {
pub fn new(
statement: Arc<Statement>,
query_text: SmartString,
has_params: bool,
param_count: usize,
normalized_query: SmartString,
) -> Self {
Self {
statement,
query_text,
last_used: Instant::now(),
usage_count: 1,
has_params,
param_count,
normalized_query,
compiled: Arc::new(RwLock::new(CompiledExecution::Unknown)),
}
}
}
/// Cache of parsed statements keyed by whitespace-normalized query text.
///
/// All access goes through an internal `RwLock`, so methods take `&self`.
pub struct QueryCache {
// Entries keyed by the output of `normalize_query`.
plans: RwLock<FxHashMap<SmartString, CachedQueryPlan>>,
// Entry count at which `put` starts evicting.
max_size: usize,
// Fraction of `max_size` evicted per pruning pass (set to 0.2 by `new`).
prune_factor: f64,
}
impl QueryCache {
pub fn new(max_size: usize) -> Self {
Self {
plans: RwLock::new(FxHashMap::default()),
max_size,
prune_factor: 0.2, }
}
pub fn default_sized() -> Self {
Self::new(DEFAULT_CACHE_SIZE)
}
pub fn get(&self, query: &str) -> Option<CachedPlanRef> {
let normalized = normalize_query(query);
let plans = self.plans.read().ok()?;
let plan = plans.get(normalized.as_ref())?;
Some(CachedPlanRef {
statement: plan.statement.clone(),
has_params: plan.has_params,
param_count: plan.param_count,
compiled: plan.compiled.clone(), })
}
pub fn put(
&self,
query: &str,
statement: Arc<Statement>,
has_params: bool,
param_count: usize,
) -> CachedPlanRef {
let normalized = normalize_query(query);
let normalized_key: SmartString = match normalized {
Cow::Borrowed(s) => SmartString::new(s),
Cow::Owned(s) => SmartString::new(&s),
};
let compiled = Arc::new(RwLock::new(CompiledExecution::Unknown));
if let Ok(mut plans) = self.plans.write() {
if plans.len() >= self.max_size {
self.prune_cache(&mut plans);
}
let key_for_insert = normalized_key.clone();
plans.insert(
key_for_insert,
CachedQueryPlan {
statement: statement.clone(),
query_text: SmartString::new(query),
last_used: Instant::now(),
usage_count: 1,
has_params,
param_count,
normalized_query: normalized_key, compiled: compiled.clone(), },
);
}
CachedPlanRef {
statement,
has_params,
param_count,
compiled,
}
}
pub fn clear(&self) {
if let Ok(mut plans) = self.plans.write() {
plans.clear();
}
}
pub fn invalidate_table(&self, table_name: &str) {
let table_lower = to_lowercase_cow(table_name);
if let Ok(mut plans) = self.plans.write() {
plans.retain(|_key, plan| {
if let Ok(compiled) = plan.compiled.read() {
match &*compiled {
CompiledExecution::PkLookup(lookup) => {
if lookup.table_name == *table_lower {
return false; }
}
CompiledExecution::CountDistinct(cd) => {
if cd.table_name == *table_lower {
return false; }
}
CompiledExecution::CountStar(cs) => {
if cs.table_name == *table_lower {
return false; }
}
_ => {}
}
}
let query_lower = to_lowercase_cow(&plan.query_text);
!query_lower.contains(&format!(" {} ", &*table_lower))
&& !query_lower.contains(&format!(" {}\n", &*table_lower))
&& !query_lower.contains(&format!(" {};", &*table_lower))
&& !query_lower.contains(&format!("from {}", &*table_lower))
&& !query_lower.contains(&format!("join {}", &*table_lower))
&& !query_lower.contains(&format!("into {}", &*table_lower))
&& !query_lower.contains(&format!("update {}", &*table_lower))
});
}
}
pub fn size(&self) -> usize {
self.plans.read().map(|p| p.len()).unwrap_or(0)
}
pub fn stats(&self) -> CacheStats {
let plans = match self.plans.read() {
Ok(p) => p,
Err(_) => {
return CacheStats {
size: 0,
max_size: self.max_size,
total_usage: 0,
avg_usage: 0.0,
}
}
};
let size = plans.len();
let total_usage: u64 = plans.values().map(|p| p.usage_count).sum();
let avg_usage = if size > 0 {
total_usage as f64 / size as f64
} else {
0.0
};
CacheStats {
size,
max_size: self.max_size,
total_usage,
avg_usage,
}
}
fn prune_cache(&self, plans: &mut FxHashMap<SmartString, CachedQueryPlan>) {
let num_to_remove = ((self.max_size as f64) * self.prune_factor).ceil() as usize;
let num_to_remove = num_to_remove.max(1);
if plans.len() <= num_to_remove {
return;
}
let mut entries: Vec<(&SmartString, Instant, u64)> = plans
.iter()
.map(|(k, p)| (k, p.last_used, p.usage_count))
.collect();
entries.sort_unstable_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2)));
let keys_to_remove: Vec<SmartString> = entries
.into_iter()
.take(num_to_remove)
.map(|(k, _, _)| k.clone())
.collect();
for key in keys_to_remove {
plans.remove(&key);
}
}
}
impl Default for QueryCache {
fn default() -> Self {
Self::default_sized()
}
}
/// Snapshot of cache occupancy and usage, produced by `QueryCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of plans currently cached (0 if the cache lock was poisoned).
pub size: usize,
/// Capacity the cache was created with.
pub max_size: usize,
/// Sum of `usage_count` over all cached plans.
pub total_usage: u64,
/// `total_usage / size`, or 0.0 when the cache is empty.
pub avg_usage: f64,
}
/// Produces the cache key for a query by normalizing insignificant whitespace.
///
/// * Leading and trailing ASCII whitespace is removed.
/// * Outside quoted regions, every run of ASCII whitespace becomes a single space.
/// * Inside a region delimited by `'`, `"` or `` ` `` the text is kept verbatim.
///   Whitespace there is part of a literal or quoted name, so collapsing it would
///   give two different queries (`x = 'a  b'` and `x = 'a b'`) the same cache key.
///   A doubled quote (`''`) closes and immediately reopens the region, so it needs
///   no special handling. An unterminated quote preserves the rest of the text.
///
/// Returns a borrowed slice of `query` when no rewriting is needed, so the common
/// already-normalized case does not allocate.
#[inline]
fn normalize_query(query: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    fn is_quote(c: char) -> bool {
        matches!(c, '\'' | '"' | '`')
    }

    let trimmed = query.trim_matches(|c: char| c.is_ascii_whitespace());
    if trimmed.is_empty() {
        return Cow::Borrowed("");
    }

    // Pass 1: decide whether anything outside quotes needs rewriting. Scanning
    // bytes is UTF-8 safe because quotes and whitespace are ASCII and never occur
    // inside a multi-byte sequence.
    let mut open_quote: Option<u8> = None;
    let mut prev_ws = false;
    let mut needs_rewrite = false;
    for &b in trimmed.as_bytes() {
        if let Some(q) = open_quote {
            if b == q {
                open_quote = None;
            }
            continue;
        }
        if b.is_ascii_whitespace() {
            if prev_ws || b != b' ' {
                needs_rewrite = true;
                break;
            }
            prev_ws = true;
        } else {
            prev_ws = false;
            if is_quote(char::from(b)) {
                open_quote = Some(b);
            }
        }
    }
    if !needs_rewrite {
        return Cow::Borrowed(trimmed);
    }

    // Pass 2: rebuild the text with the same state machine as pass 1.
    let mut result = String::with_capacity(trimmed.len());
    let mut open_quote: Option<char> = None;
    let mut last_was_space = false;
    for c in trimmed.chars() {
        if let Some(q) = open_quote {
            result.push(c);
            if c == q {
                open_quote = None;
            }
            continue;
        }
        if c.is_ascii_whitespace() {
            if !last_was_space {
                result.push(' ');
                last_was_space = true;
            }
        } else {
            result.push(c);
            last_was_space = false;
            if is_quote(c) {
                open_quote = Some(c);
            }
        }
    }
    Cow::Owned(result)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::ast::{Expression, GroupByClause, SelectStatement, StarExpression};
use crate::parser::token::{Position, Token, TokenType};
fn dummy_token() -> Token {
Token::new(TokenType::Keyword, "SELECT", Position::new(0, 1, 1))
}
fn star_token() -> Token {
Token::new(TokenType::Operator, "*", Position::new(0, 1, 1))
}
fn create_test_statement() -> Arc<Statement> {
Arc::new(Statement::Select(SelectStatement {
token: dummy_token(),
with: None,
distinct: false,
distinct_on: vec![],
columns: vec![Expression::Star(StarExpression {
token: star_token(),
})],
table_expr: None,
where_clause: None,
group_by: GroupByClause::default(),
having: None,
window_defs: vec![],
order_by: vec![],
limit: None,
offset: None,
set_operations: vec![],
}))
}
#[test]
fn test_cache_put_get() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt.clone(), false, 0);
assert_eq!(cache.size(), 1);
let plan = cache.get("SELECT * FROM users");
assert!(plan.is_some());
let plan = plan.unwrap();
assert!(!plan.has_params);
assert_eq!(plan.param_count, 0);
}
#[test]
fn test_cache_miss() {
let cache = QueryCache::new(100);
let plan = cache.get("SELECT * FROM users");
assert!(plan.is_none());
}
#[test]
fn test_cache_usage_count() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt, false, 0);
for _ in 0..5 {
cache.get("SELECT * FROM users");
}
let stats = cache.stats();
assert_eq!(stats.total_usage, 1); }
#[test]
fn test_cache_clear() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt, false, 0);
assert_eq!(cache.size(), 1);
cache.clear();
assert_eq!(cache.size(), 0);
}
#[test]
fn test_cache_pruning() {
let cache = QueryCache::new(5);
let stmt = create_test_statement();
for i in 0..10 {
let query = format!("SELECT * FROM table{}", i);
cache.put(&query, stmt.clone(), false, 0);
}
assert!(cache.size() <= 5);
}
#[test]
fn test_normalize_query() {
assert_eq!(
normalize_query(" SELECT * FROM users "),
"SELECT * FROM users"
);
assert_eq!(
normalize_query("SELECT\n*\nFROM\nusers"),
"SELECT * FROM users"
);
assert_eq!(
normalize_query("SELECT\t*\t\tFROM users"),
"SELECT * FROM users"
);
}
#[test]
fn test_normalize_query_utf8() {
assert_eq!(
normalize_query("SELECT * FROM t WHERE name = '日本語'"),
"SELECT * FROM t WHERE name = '日本語'"
);
assert_eq!(
normalize_query("SELECT * FROM t WHERE name = '日本語'"),
"SELECT * FROM t WHERE name = '日本語'"
);
assert_eq!(
normalize_query("SELECT\t*\tFROM t WHERE city = '東京' AND country = '中国'"),
"SELECT * FROM t WHERE city = '東京' AND country = '中国'"
);
assert_eq!(
normalize_query("SELECT * FROM t WHERE emoji = '🎉'"),
"SELECT * FROM t WHERE emoji = '🎉'"
);
}
#[test]
fn test_normalized_cache_hit() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt, false, 0);
let plan = cache.get(" SELECT * FROM users ");
assert!(plan.is_some());
}
#[test]
fn test_parameterized_query() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT * FROM users WHERE id = $1", stmt, true, 1);
let plan = cache.get("SELECT * FROM users WHERE id = $1").unwrap();
assert!(plan.has_params);
assert_eq!(plan.param_count, 1);
}
#[test]
fn test_cache_stats() {
let cache = QueryCache::new(100);
let stmt = create_test_statement();
cache.put("SELECT 1", stmt.clone(), false, 0);
cache.put("SELECT 2", stmt.clone(), false, 0);
for _ in 0..5 {
cache.get("SELECT 1");
}
let stats = cache.stats();
assert_eq!(stats.size, 2);
assert_eq!(stats.max_size, 100);
assert_eq!(stats.total_usage, 2); }
#[test]
fn test_cache_thread_safety() {
use std::sync::Arc;
use std::thread;
let cache = Arc::new(QueryCache::new(1000));
let stmt = create_test_statement();
cache.put("SELECT * FROM users", stmt.clone(), false, 0);
let mut handles = vec![];
for _ in 0..10 {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for _ in 0..100 {
cache.get("SELECT * FROM users");
}
}));
}
for i in 0..5 {
let cache = Arc::clone(&cache);
let stmt = stmt.clone();
handles.push(thread::spawn(move || {
for j in 0..20 {
let query = format!("SELECT * FROM table{}_{}", i, j);
cache.put(&query, stmt.clone(), false, 0);
}
}));
}
for handle in handles {
handle.join().unwrap();
}
assert!(cache.get("SELECT * FROM users").is_some());
}
}