use alloc::collections::{BTreeMap, VecDeque};
use alloc::string::String;
use alloc::vec::Vec;
use spg_sql::ast::{Expr, FromClause, FromJoin, SelectItem, SelectStatement, Statement, TableRef};
use spg_storage::ColumnSchema;
/// Hard upper bound on the number of cached plans. It is both the default
/// capacity of a new [`PlanCache`] and the ceiling that
/// `PlanCache::set_max_entries` clamps to.
pub(crate) const PLAN_CACHE_MAX_ENTRIES: usize = 256;
/// A prepared statement stored in [`PlanCache`], which keys it by SQL text.
#[derive(Debug, Clone)]
pub struct PreparedPlan {
/// The parsed statement.
pub stmt: Statement,
/// Version tag stored with the plan. `PlanCache` itself never reads it.
/// NOTE(review): presumably callers compare it against the current
/// statistics version to detect stale plans — confirm.
pub statistics_version: u64,
/// Table names this plan depends on. `PlanCache::evict_referencing` removes
/// plans whose list contains a given name (exact string match).
/// Presumably filled from `collect_source_tables` — confirm at the call site.
pub source_tables: Vec<String>,
/// Column schemas associated with the statement.
/// NOTE(review): presumably the result columns reported for describe
/// requests — confirm.
pub describe_columns: Vec<ColumnSchema>,
}
/// Bounded least-recently-used cache of [`PreparedPlan`]s. Entries are keyed
/// by the exact SQL string passed to `insert`/`get`; no normalization is
/// applied here.
#[derive(Debug, Clone)]
pub struct PlanCache {
/// Cached plans by SQL text.
entries: BTreeMap<String, PreparedPlan>,
/// Recency order of the keys in `entries`. The front is least recently used
/// and is evicted first; the back is most recently used.
lru: VecDeque<String>,
/// Capacity. `Default` sets it to `PLAN_CACHE_MAX_ENTRIES`, and
/// `set_max_entries` clamps it to `1..=PLAN_CACHE_MAX_ENTRIES`.
max_entries: usize,
}
impl Default for PlanCache {
fn default() -> Self {
Self {
entries: BTreeMap::new(),
lru: VecDeque::new(),
max_entries: PLAN_CACHE_MAX_ENTRIES,
}
}
}
impl PlanCache {
    /// Creates an empty cache with the default cap of [`PLAN_CACHE_MAX_ENTRIES`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of cached plans, clamped to
    /// `1..=PLAN_CACHE_MAX_ENTRIES`.
    ///
    /// If the cap drops below the current size, the least-recently-used
    /// plans are evicted immediately, so `len() <= max_entries()` holds on
    /// return.
    pub fn set_max_entries(&mut self, n: usize) {
        self.max_entries = n.max(1).min(PLAN_CACHE_MAX_ENTRIES);
        // Without this trim, a shrunk cache would stay oversized. `insert`
        // used to pop only one entry per call, so the size never converged.
        let cap = self.max_entries;
        self.evict_lru_down_to(cap);
    }

    /// Returns the current entry cap.
    pub fn max_entries(&self) -> usize {
        self.max_entries
    }

    /// Returns the number of cached plans.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no plans are cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Looks up the plan for `sql` without changing its recency. This makes
    /// it usable through a shared reference.
    pub fn get_snapshot(&self, sql: &str) -> Option<&PreparedPlan> {
        self.entries.get(sql)
    }

    /// Looks up the plan for `sql`. On a hit, the plan is marked as most
    /// recently used.
    pub fn get(&mut self, sql: &str) -> Option<&PreparedPlan> {
        if !self.entries.contains_key(sql) {
            return None;
        }
        self.promote_to_mru(sql);
        self.entries.get(sql)
    }

    /// Inserts or replaces the plan for `sql` and marks it as most recently
    /// used.
    ///
    /// When a new key is inserted into a full cache, least-recently-used
    /// plans are evicted first, so the size never exceeds `max_entries()`.
    pub fn insert(&mut self, sql: String, plan: PreparedPlan) {
        if self.entries.contains_key(&sql) {
            // Replacing an entry does not change the size; only refresh its recency.
            self.promote_to_mru(&sql);
            self.entries.insert(sql, plan);
            return;
        }
        // Leave room for exactly one more entry. Evicting in a loop, rather
        // than popping once, means an over-cap cache still converges to its cap.
        let room_for_one = self.max_entries.saturating_sub(1);
        self.evict_lru_down_to(room_for_one);
        self.lru.push_back(sql.clone());
        self.entries.insert(sql, plan);
    }

    /// Removes every cached plan.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.lru.clear();
    }

    /// Removes the plan for `sql` and returns it, or returns `None` if it
    /// was not cached.
    pub fn evict(&mut self, sql: &str) -> Option<PreparedPlan> {
        let plan = self.entries.remove(sql)?;
        if let Some(idx) = self.lru.iter().position(|k| k == sql) {
            self.lru.remove(idx);
        }
        Some(plan)
    }

    /// Removes every plan whose `source_tables` contains `table` (exact,
    /// case-sensitive match) and returns how many plans were removed.
    pub fn evict_referencing(&mut self, table: &str) -> usize {
        let before = self.entries.len();
        self.entries
            .retain(|_, plan| !plan.source_tables.iter().any(|t| t == table));
        // Drop the now-dangling keys from the recency queue in one pass.
        // The previous code did a linear scan per evicted key instead.
        let entries = &self.entries;
        self.lru.retain(|key| entries.contains_key(key));
        before - self.entries.len()
    }

    /// Moves `sql` to the most-recently-used end of the recency queue, if
    /// it is present.
    fn promote_to_mru(&mut self, sql: &str) {
        if let Some(idx) = self.lru.iter().position(|k| k == sql) {
            if let Some(key) = self.lru.remove(idx) {
                self.lru.push_back(key);
            }
        }
    }

    /// Evicts least-recently-used plans until at most `limit` remain.
    fn evict_lru_down_to(&mut self, limit: usize) {
        while self.entries.len() > limit {
            match self.lru.pop_front() {
                Some(oldest) => {
                    self.entries.remove(&oldest);
                }
                // `lru` mirrors `entries`, so this branch only runs if that
                // invariant is broken. Stop instead of looping forever.
                None => break,
            }
        }
    }
}
/// Returns the names of the tables `stmt` touches, sorted (byte-wise `String`
/// order) and deduplicated.
///
/// - SELECT: every table reachable from the FROM/JOIN clauses, the
///   expressions, and UNION peers.
/// - EXPLAIN: walked as its inner SELECT.
/// - INSERT: only the target table. NOTE(review): if INSERT can take a SELECT
///   source, those tables are not collected here.
/// - UPDATE/DELETE: the target table plus anything referenced from WHERE.
/// - Any other statement kind yields an empty list.
pub fn collect_source_tables(stmt: &Statement) -> Vec<String> {
    let mut tables: Vec<String> = Vec::new();
    match stmt {
        Statement::Select(select) => collect_from_select(select, &mut tables),
        Statement::Explain(explain) => collect_from_select(&explain.inner, &mut tables),
        Statement::Insert(insert) => push_unique(&mut tables, &insert.table),
        Statement::Update(update) => {
            push_unique(&mut tables, &update.table);
            if let Some(predicate) = &update.where_ {
                collect_expr(predicate, &mut tables);
            }
        }
        Statement::Delete(delete) => {
            push_unique(&mut tables, &delete.table);
            if let Some(predicate) = &delete.where_ {
                collect_expr(predicate, &mut tables);
            }
        }
        _ => {}
    }
    // Equal `String`s are indistinguishable, so an unstable sort yields the
    // same sequence as a stable one.
    tables.sort_unstable();
    tables.dedup();
    tables
}
/// Appends every table referenced by `s` to `out`. The walk covers the FROM
/// clause (including joins), WHERE, HAVING, projected expressions, and each
/// UNION peer, in that order.
///
/// NOTE(review): no other SELECT clause is visited here. If the statement
/// carries GROUP BY / ORDER BY expressions, subqueries inside them are not
/// recorded — confirm whether that is intended.
fn collect_from_select(s: &SelectStatement, out: &mut Vec<String>) {
    s.from.iter().for_each(|from| collect_from_clause(from, out));
    s.where_.iter().for_each(|cond| collect_expr(cond, out));
    s.having.iter().for_each(|cond| collect_expr(cond, out));
    s.items
        .iter()
        .filter_map(|item| match item {
            SelectItem::Expr { expr, .. } => Some(expr),
            _ => None,
        })
        .for_each(|expr| collect_expr(expr, out));
    s.unions
        .iter()
        .for_each(|(_, peer)| collect_from_select(peer, out));
}
/// Appends the primary table of a FROM clause to `out`, followed by each
/// joined table and anything referenced from its ON condition.
fn collect_from_clause(from: &FromClause, out: &mut Vec<String>) {
    collect_table_ref(&from.primary, out);
    from.joins
        .iter()
        .for_each(|join| collect_from_join(join, out));
}
/// Appends the joined table to `out`, then any tables referenced from the
/// join's ON condition (if present).
fn collect_from_join(j: &FromJoin, out: &mut Vec<String>) {
    collect_table_ref(&j.table, out);
    match &j.on {
        Some(condition) => collect_expr(condition, out),
        None => {}
    }
}
/// Records the table named by `t` in `out`, skipping it if already present.
fn collect_table_ref(t: &TableRef, out: &mut Vec<String>) {
    let name: &str = &t.name;
    push_unique(out, name);
}
/// Recursively walks `e` and appends to `out` every table referenced by a
/// subquery nested anywhere inside it. Children are visited left to right.
fn collect_expr(e: &Expr, out: &mut Vec<String>) {
    match e {
        // Leaves contribute nothing; column qualifiers are not collected.
        Expr::Literal(_) | Expr::Column(_) | Expr::Placeholder(_) => {}
        // Subquery forms hand off to a full SELECT walk.
        Expr::ScalarSubquery(inner) => collect_from_select(inner, out),
        Expr::Exists { subquery, .. } => collect_from_select(subquery, out),
        Expr::InSubquery { expr, subquery, .. } => {
            collect_expr(expr, out);
            collect_from_select(subquery, out);
        }
        // Single-child wrappers.
        Expr::Unary { expr, .. } => collect_expr(expr, out),
        Expr::Cast { expr, .. } => collect_expr(expr, out),
        Expr::IsNull { expr, .. } => collect_expr(expr, out),
        Expr::Extract { source, .. } => collect_expr(source, out),
        // Two-child forms.
        Expr::Binary { lhs, rhs, .. } => {
            collect_expr(lhs, out);
            collect_expr(rhs, out);
        }
        Expr::Like { expr, pattern, .. } => {
            collect_expr(expr, out);
            collect_expr(pattern, out);
        }
        Expr::ArraySubscript { target, index } => {
            collect_expr(target, out);
            collect_expr(index, out);
        }
        Expr::AnyAll { expr, array, .. } => {
            collect_expr(expr, out);
            collect_expr(array, out);
        }
        // Variadic forms.
        Expr::FunctionCall { args, .. } => {
            args.iter().for_each(|arg| collect_expr(arg, out));
        }
        Expr::Array(items) => {
            items.iter().for_each(|item| collect_expr(item, out));
        }
        Expr::WindowFunction { args, partition_by, order_by, .. } => {
            args.iter().for_each(|arg| collect_expr(arg, out));
            partition_by.iter().for_each(|part| collect_expr(part, out));
            order_by.iter().for_each(|(key, _)| collect_expr(key, out));
        }
    }
}
/// Appends an owned copy of `s` to `out` unless an equal string is already
/// there, preserving first-insertion order. This is a linear scan; the lists
/// built here are expected to be short.
fn push_unique(out: &mut Vec<String>, s: &str) {
    let absent = out.iter().all(|existing| existing != s);
    if absent {
        out.push(s.into());
    }
}
// Unit tests for `PlanCache` (hits and misses, replacement, LRU eviction, and
// table-based invalidation) and for `collect_source_tables`.
#[cfg(test)]
mod tests {
use super::*;
use alloc::string::ToString;
use spg_sql::parser::parse_statement;
/// Builds a placeholder plan. The statement is always `SELECT 1`; only the
/// statistics version and the source-table list vary between calls.
fn dummy_plan(version: u64, tables: &[&str]) -> PreparedPlan {
let stmt = parse_statement("SELECT 1").expect("trivial SELECT parses");
PreparedPlan {
stmt,
statistics_version: version,
source_tables: tables.iter().map(|s| s.to_string()).collect(),
describe_columns: Vec::new(),
}
}
#[test]
fn new_cache_is_empty() {
let cache = PlanCache::new();
assert!(cache.is_empty());
assert_eq!(cache.len(), 0);
}
#[test]
fn insert_then_get_returns_the_plan() {
let mut cache = PlanCache::new();
cache.insert("SELECT 1".into(), dummy_plan(0, &["t"]));
assert_eq!(cache.len(), 1);
let plan = cache.get("SELECT 1").expect("hit");
assert_eq!(plan.source_tables, alloc::vec!["t".to_string()]);
}
#[test]
fn miss_returns_none() {
let mut cache = PlanCache::new();
cache.insert("SELECT 1".into(), dummy_plan(0, &[]));
assert!(cache.get("SELECT 2").is_none());
}
#[test]
fn replace_overwrites_existing_entry() {
let mut cache = PlanCache::new();
cache.insert("SELECT 1".into(), dummy_plan(1, &["a"]));
// Same key again: the plan is replaced in place and the size stays at 1.
cache.insert("SELECT 1".into(), dummy_plan(2, &["b"]));
assert_eq!(cache.len(), 1);
let plan = cache.get("SELECT 1").expect("hit");
assert_eq!(plan.statistics_version, 2);
}
#[test]
fn lru_evicts_oldest_at_cap() {
let mut cache = PlanCache::new();
// Fill the cache to exactly its default cap.
for i in 0..PLAN_CACHE_MAX_ENTRIES {
cache.insert(alloc::format!("SELECT {i}"), dummy_plan(i as u64, &[]));
}
assert_eq!(cache.len(), PLAN_CACHE_MAX_ENTRIES);
// One more new key must push out the first key inserted.
cache.insert("SELECT new".into(), dummy_plan(999, &[]));
assert_eq!(cache.len(), PLAN_CACHE_MAX_ENTRIES);
assert!(cache.get("SELECT 0").is_none());
assert!(cache.get("SELECT new").is_some());
}
#[test]
fn get_promotes_lru_position() {
let mut cache = PlanCache::new();
cache.insert("a".into(), dummy_plan(0, &[]));
cache.insert("b".into(), dummy_plan(0, &[]));
cache.insert("c".into(), dummy_plan(0, &[]));
// Recency order is now b, c, a.
let _ = cache.get("a");
// 3 + (MAX - 3) fillers = exactly PLAN_CACHE_MAX_ENTRIES entries, so the
// `trigger` insert evicts the least recently used key, `b`.
for i in 0..(PLAN_CACHE_MAX_ENTRIES - 3) {
cache.insert(alloc::format!("filler{i}"), dummy_plan(0, &[]));
}
cache.insert("trigger".into(), dummy_plan(0, &[]));
assert!(cache.get("a").is_some(), "a was MRU after get(); should survive");
assert!(cache.get("b").is_none(), "b should be evicted");
}
#[test]
fn clear_drops_everything() {
let mut cache = PlanCache::new();
cache.insert("a".into(), dummy_plan(0, &[]));
cache.insert("b".into(), dummy_plan(0, &[]));
cache.clear();
assert!(cache.is_empty());
assert!(cache.get("a").is_none());
}
#[test]
fn evict_referencing_drops_only_matching_plans() {
let mut cache = PlanCache::new();
// `a` and `c` reference `users`; `b` does not.
cache.insert("a".into(), dummy_plan(0, &["users"]));
cache.insert("b".into(), dummy_plan(0, &["orders"]));
cache.insert("c".into(), dummy_plan(0, &["users", "orders"]));
let n = cache.evict_referencing("users");
assert_eq!(n, 2);
assert!(cache.get("a").is_none());
assert!(cache.get("b").is_some());
assert!(cache.get("c").is_none());
}
#[test]
fn collect_source_tables_from_simple_select() {
let stmt = parse_statement("SELECT a, b FROM t1 WHERE x = 1").expect("parses");
let tables = collect_source_tables(&stmt);
assert_eq!(tables, alloc::vec!["t1".to_string()]);
}
#[test]
fn collect_source_tables_from_join() {
let stmt = parse_statement(
"SELECT * FROM t1 JOIN t2 ON t1.a = t2.b JOIN t3 ON t2.c = t3.d",
)
.expect("parses");
let tables = collect_source_tables(&stmt);
// Output is sorted, so the expected order is t1, t2, t3.
assert_eq!(
tables,
alloc::vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]
);
}
#[test]
fn collect_source_tables_dedupes_self_join() {
// Both aliases name `t1`, so it must appear only once.
let stmt = parse_statement("SELECT * FROM t1 a JOIN t1 b ON a.x = b.y").expect("parses");
let tables = collect_source_tables(&stmt);
assert_eq!(tables, alloc::vec!["t1".to_string()]);
}
}