use crate::ast::{Assignment, Expr, Literal};
use crate::plan::PlanNode;
use rustc_hash::FxHashMap;
/// Cache of query plan templates keyed by the hash of a canonicalized query.
///
/// A cached plan is a template: on a hit, its literal values are overwritten with
/// the literals of the incoming query (see `PlanCache::get_with_substitution`).
#[derive(Debug)]
pub struct PlanCache {
    /// Plan templates keyed by canonical query hash.
    cache: FxHashMap<u64, PlanNode>,
    /// Size limit that `insert` uses to decide when to evict.
    capacity: usize,
    /// Number of `get_with_substitution` calls that returned a plan.
    pub hits: u64,
    /// Number of `get_with_substitution` calls that returned `None`.
    pub misses: u64,
}
impl PlanCache {
pub fn new(capacity: usize) -> Self {
PlanCache {
cache: FxHashMap::default(),
capacity,
hits: 0,
misses: 0,
}
}
pub fn insert(&mut self, hash: u64, plan: PlanNode) {
if self.cache.len() >= self.capacity && !self.cache.contains_key(&hash) {
self.cache.clear();
}
self.cache.insert(hash, plan);
}
pub fn get_with_substitution(&mut self, hash: u64, literals: &[Literal]) -> Option<PlanNode> {
match self.cache.get(&hash) {
Some(template) => {
self.hits += 1;
let mut plan = template.clone();
let mut idx = 0usize;
substitute_plan(&mut plan, literals, &mut idx);
debug_assert_eq!(
idx,
literals.len(),
"plan substitution consumed {idx} literals but query had {}",
literals.len(),
);
Some(plan)
}
None => {
self.misses += 1;
None
}
}
}
pub fn len(&self) -> usize {
self.cache.len()
}
pub fn is_empty(&self) -> bool {
self.cache.is_empty()
}
}
/// Overwrites every literal slot reachable from `plan` with values from `literals`.
///
/// `idx` is the position of the next literal to consume and is advanced once per
/// slot. A node's child plan(s) are visited before its own expressions. The one
/// exception is a `Limit` whose direct input is an `Offset`: there the offset's
/// input is visited first, then the limit count, then the offset count.
/// `count_plan` must visit the same expressions so that slot counts agree.
///
/// # Panics
///
/// Panics (index out of bounds) if the plan has more literal slots than
/// `literals` has values.
pub(crate) fn substitute_plan(plan: &mut PlanNode, literals: &[Literal], idx: &mut usize) {
    match plan {
        // Nodes without literal slots.
        PlanNode::SeqScan { .. }
        | PlanNode::AliasScan { .. }
        | PlanNode::AlterTable { .. }
        | PlanNode::DropTable { .. }
        | PlanNode::CreateTable { .. }
        | PlanNode::CreateView { .. }
        | PlanNode::RefreshView { .. }
        | PlanNode::DropView { .. } => {}
        PlanNode::IndexScan { key, .. } => substitute_expr(key, literals, idx),
        PlanNode::RangeScan { start, end, .. } => {
            if let Some((lower, _)) = start {
                substitute_expr(lower, literals, idx);
            }
            if let Some((upper, _)) = end {
                substitute_expr(upper, literals, idx);
            }
        }
        PlanNode::Filter { input, predicate } => {
            substitute_plan(input, literals, idx);
            substitute_expr(predicate, literals, idx);
        }
        PlanNode::Project { input, fields } => {
            substitute_plan(input, literals, idx);
            for field in fields.iter_mut() {
                substitute_expr(&mut field.expr, literals, idx);
            }
        }
        PlanNode::Sort { input, .. } => substitute_plan(input, literals, idx),
        PlanNode::Limit { input, count } => match input.as_mut() {
            // LIMIT directly over OFFSET: the limit count is filled before the
            // offset count.
            PlanNode::Offset {
                input: inner,
                count: offset,
            } => {
                substitute_plan(inner, literals, idx);
                substitute_expr(count, literals, idx);
                substitute_expr(offset, literals, idx);
            }
            _ => {
                substitute_plan(input, literals, idx);
                substitute_expr(count, literals, idx);
            }
        },
        PlanNode::Offset { input, count } => {
            substitute_plan(input, literals, idx);
            substitute_expr(count, literals, idx);
        }
        PlanNode::Aggregate { input, .. } => substitute_plan(input, literals, idx),
        PlanNode::NestedLoopJoin {
            left, right, on, ..
        } => {
            substitute_plan(left, literals, idx);
            substitute_plan(right, literals, idx);
            if let Some(condition) = on {
                substitute_expr(condition, literals, idx);
            }
        }
        PlanNode::Distinct { input } => substitute_plan(input, literals, idx),
        PlanNode::GroupBy { input, having, .. } => {
            substitute_plan(input, literals, idx);
            if let Some(condition) = having {
                substitute_expr(condition, literals, idx);
            }
        }
        PlanNode::Insert { assignments, .. } => substitute_assignments(assignments, literals, idx),
        PlanNode::Upsert {
            assignments,
            on_conflict,
            ..
        } => {
            substitute_assignments(assignments, literals, idx);
            substitute_assignments(on_conflict, literals, idx);
        }
        PlanNode::Update {
            input, assignments, ..
        } => {
            substitute_plan(input, literals, idx);
            substitute_assignments(assignments, literals, idx);
        }
        PlanNode::Delete { input, .. } => substitute_plan(input, literals, idx),
        PlanNode::Window { input, windows } => {
            substitute_plan(input, literals, idx);
            for arg in windows.iter_mut().flat_map(|w| w.args.iter_mut()) {
                substitute_expr(arg, literals, idx);
            }
        }
        PlanNode::Union { left, right, .. } => {
            substitute_plan(left, literals, idx);
            substitute_plan(right, literals, idx);
        }
        PlanNode::Explain { input } => substitute_plan(input, literals, idx),
    }
}
/// Fills the literal slots in each assignment's value, in slice order.
fn substitute_assignments(assignments: &mut [Assignment], literals: &[Literal], idx: &mut usize) {
    assignments
        .iter_mut()
        .for_each(|assignment| substitute_expr(&mut assignment.value, literals, idx));
}
/// Returns how many literal values `substitute_plan` consumes for `plan`.
///
/// Useful for checking a query's extracted literals against a template before
/// substituting, since `substitute_plan` panics when it runs out of literals.
pub(crate) fn count_literal_slots(plan: &PlanNode) -> usize {
    let mut slots = 0;
    count_plan(plan, &mut slots);
    slots
}
fn count_plan(plan: &PlanNode, n: &mut usize) {
match plan {
PlanNode::SeqScan { .. } => {}
PlanNode::AliasScan { .. } => {}
PlanNode::IndexScan { key, .. } => count_expr(key, n),
PlanNode::RangeScan { start, end, .. } => {
if let Some((expr, _)) = start {
count_expr(expr, n);
}
if let Some((expr, _)) = end {
count_expr(expr, n);
}
}
PlanNode::Filter { input, predicate } => {
count_plan(input, n);
count_expr(predicate, n);
}
PlanNode::Project { input, fields } => {
count_plan(input, n);
for f in fields {
count_expr(&f.expr, n);
}
}
PlanNode::Sort { input, .. } => count_plan(input, n),
PlanNode::Limit { input, count } => {
if let PlanNode::Offset {
input: inner,
count: off_count,
} = input.as_ref()
{
count_plan(inner, n);
count_expr(count, n);
count_expr(off_count, n);
} else {
count_plan(input, n);
count_expr(count, n);
}
}
PlanNode::Offset { input, count } => {
count_plan(input, n);
count_expr(count, n);
}
PlanNode::Aggregate { input, .. } => count_plan(input, n),
PlanNode::NestedLoopJoin {
left, right, on, ..
} => {
count_plan(left, n);
count_plan(right, n);
if let Some(pred) = on {
count_expr(pred, n);
}
}
PlanNode::Distinct { input } => count_plan(input, n),
PlanNode::GroupBy { input, having, .. } => {
count_plan(input, n);
if let Some(pred) = having {
count_expr(pred, n);
}
}
PlanNode::Insert { assignments, .. } => {
for a in assignments {
count_expr(&a.value, n);
}
}
PlanNode::Upsert {
assignments,
on_conflict,
..
} => {
for a in assignments {
count_expr(&a.value, n);
}
for a in on_conflict {
count_expr(&a.value, n);
}
}
PlanNode::Update {
input, assignments, ..
} => {
count_plan(input, n);
for a in assignments {
count_expr(&a.value, n);
}
}
PlanNode::Delete { input, .. } => count_plan(input, n),
PlanNode::CreateTable { .. } => {}
PlanNode::AlterTable { .. } => {}
PlanNode::DropTable { .. } => {}
PlanNode::CreateView { .. } => {}
PlanNode::RefreshView { .. } => {}
PlanNode::DropView { .. } => {}
PlanNode::Window { input, windows } => {
count_plan(input, n);
for w in windows {
for arg in &w.args {
count_expr(arg, n);
}
}
}
PlanNode::Union { left, right, .. } => {
count_plan(left, n);
count_plan(right, n);
}
PlanNode::Explain { input } => {
count_plan(input, n);
}
}
}
/// Adds to `n` the number of `Expr::Literal` nodes reachable from `expr`.
///
/// `ExistsSubquery` contributes nothing. For `InSubquery`, only the `expr` field
/// is inspected; its other fields are skipped.
fn count_expr(expr: &Expr, n: &mut usize) {
    match expr {
        Expr::Literal(_) => *n += 1,
        Expr::Field(_)
        | Expr::QualifiedField { .. }
        | Expr::Param(_)
        | Expr::ExistsSubquery { .. } => {}
        Expr::BinaryOp(lhs, _, rhs) => {
            count_expr(lhs, n);
            count_expr(rhs, n);
        }
        Expr::Coalesce(first, fallback) => {
            count_expr(first, n);
            count_expr(fallback, n);
        }
        Expr::UnaryOp(_, operand) => count_expr(operand, n),
        Expr::FunctionCall(_, arg) => count_expr(arg, n),
        Expr::Cast(inner, _) => count_expr(inner, n),
        Expr::InSubquery { expr: probe, .. } => count_expr(probe, n),
        Expr::InList {
            expr: probe, list, ..
        } => {
            count_expr(probe, n);
            list.iter().for_each(|item| count_expr(item, n));
        }
        Expr::ScalarFunc(_, args) => args.iter().for_each(|arg| count_expr(arg, n)),
        Expr::Window { args, .. } => args.iter().for_each(|arg| count_expr(arg, n)),
        Expr::Case { whens, else_expr } => {
            for (condition, result) in whens.iter() {
                count_expr(condition, n);
                count_expr(result, n);
            }
            if let Some(fallback) = else_expr {
                count_expr(fallback, n);
            }
        }
    }
}
/// Replaces each `Expr::Literal` under `expr` with the next value from `literals`.
///
/// Children are visited in field order (left operand before right, each `CASE`
/// condition before its result, then `ELSE`), advancing `idx` once per replaced
/// literal. `count_expr` must visit the same literals.
///
/// # Panics
///
/// Panics (index out of bounds) if more literal slots are encountered than
/// `literals` has values.
fn substitute_expr(expr: &mut Expr, literals: &[Literal], idx: &mut usize) {
    match expr {
        Expr::Literal(slot) => {
            *slot = literals[*idx].clone();
            *idx += 1;
        }
        Expr::Field(_)
        | Expr::QualifiedField { .. }
        | Expr::Param(_)
        | Expr::ExistsSubquery { .. } => {}
        Expr::BinaryOp(lhs, _, rhs) => {
            substitute_expr(lhs, literals, idx);
            substitute_expr(rhs, literals, idx);
        }
        Expr::Coalesce(first, fallback) => {
            substitute_expr(first, literals, idx);
            substitute_expr(fallback, literals, idx);
        }
        Expr::UnaryOp(_, operand) => substitute_expr(operand, literals, idx),
        Expr::FunctionCall(_, arg) => substitute_expr(arg, literals, idx),
        Expr::Cast(inner, _) => substitute_expr(inner, literals, idx),
        Expr::InSubquery { expr: probe, .. } => substitute_expr(probe, literals, idx),
        Expr::InList {
            expr: probe, list, ..
        } => {
            substitute_expr(probe, literals, idx);
            for item in list.iter_mut() {
                substitute_expr(item, literals, idx);
            }
        }
        Expr::ScalarFunc(_, args) => {
            for arg in args.iter_mut() {
                substitute_expr(arg, literals, idx);
            }
        }
        Expr::Window { args, .. } => {
            for arg in args.iter_mut() {
                substitute_expr(arg, literals, idx);
            }
        }
        Expr::Case { whens, else_expr } => {
            for (condition, result) in whens.iter_mut() {
                substitute_expr(condition, literals, idx);
                substitute_expr(result, literals, idx);
            }
            if let Some(fallback) = else_expr {
                substitute_expr(fallback, literals, idx);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::canonicalize::canonicalize;
    use crate::planner;

    #[test]
    fn test_cache_hit_substitutes_literal() {
        let mut cache = PlanCache::new(100);
        let q1 = "User filter .id = 42";
        let (h1, lits1) = canonicalize(q1).unwrap();
        let p1 = planner::plan(q1).unwrap();
        cache.insert(h1, p1);
        let q2 = "User filter .id = 99";
        let (h2, lits2) = canonicalize(q2).unwrap();
        assert_eq!(h1, h2, "different literals must hash the same");
        let plan = cache.get_with_substitution(h2, &lits2).expect("hit");
        match plan {
            PlanNode::IndexScan { key, .. } => {
                assert_eq!(key, Expr::Literal(Literal::Int(99)));
            }
            other => panic!("expected IndexScan, got {other:?}"),
        }
        assert_eq!(lits1, vec![Literal::Int(42)]);
        assert_eq!(cache.hits, 1);
        assert_eq!(cache.misses, 0);
    }

    #[test]
    fn test_cache_miss_returns_none_and_bumps_counter() {
        let mut cache = PlanCache::new(100);
        assert!(cache.get_with_substitution(99999, &[]).is_none());
        assert_eq!(cache.misses, 1);
        assert_eq!(cache.hits, 0);
    }

    #[test]
    fn test_multi_literal_filter_substitution() {
        let mut cache = PlanCache::new(100);
        let q1 = r#"User filter .age > 30 and .status = "active" { .name }"#;
        let (h1, _) = canonicalize(q1).unwrap();
        cache.insert(h1, planner::plan(q1).unwrap());
        let q2 = r#"User filter .age > 50 and .status = "pending" { .name }"#;
        let (h2, lits2) = canonicalize(q2).unwrap();
        let plan = cache.get_with_substitution(h2, &lits2).expect("hit");
        let mut found = Vec::new();
        collect_literals_for_test(&plan, &mut found);
        assert_eq!(
            found,
            vec![Literal::Int(50), Literal::String("pending".into()),]
        );
    }

    #[test]
    fn test_update_by_pk_substitution() {
        let mut cache = PlanCache::new(100);
        let q1 = "User filter .id = 1 update { age := 100 }";
        let (h1, _) = canonicalize(q1).unwrap();
        cache.insert(h1, planner::plan(q1).unwrap());
        let q2 = "User filter .id = 7 update { age := 200 }";
        let (h2, lits2) = canonicalize(q2).unwrap();
        let plan = cache.get_with_substitution(h2, &lits2).expect("hit");
        let mut found = Vec::new();
        collect_literals_for_test(&plan, &mut found);
        assert_eq!(found, vec![Literal::Int(7), Literal::Int(200)]);
    }

    #[test]
    fn test_insert_substitution() {
        let mut cache = PlanCache::new(100);
        let q1 = r#"insert User { id := 1, name := "Alice", age := 20 }"#;
        let (h1, _) = canonicalize(q1).unwrap();
        cache.insert(h1, planner::plan(q1).unwrap());
        let q2 = r#"insert User { id := 2, name := "Bob", age := 30 }"#;
        let (h2, lits2) = canonicalize(q2).unwrap();
        let plan = cache.get_with_substitution(h2, &lits2).expect("hit");
        let mut found = Vec::new();
        collect_literals_for_test(&plan, &mut found);
        assert_eq!(
            found,
            vec![
                Literal::Int(2),
                Literal::String("Bob".into()),
                Literal::Int(30),
            ]
        );
    }

    #[test]
    fn test_eviction_on_capacity() {
        let mut cache = PlanCache::new(2);
        let q1 = "User";
        let q2 = "User filter .age > 1";
        let q3 = "User filter .id = 5";
        let (h1, _) = canonicalize(q1).unwrap();
        let (h2, _) = canonicalize(q2).unwrap();
        let (h3, _) = canonicalize(q3).unwrap();
        cache.insert(h1, planner::plan(q1).unwrap());
        cache.insert(h2, planner::plan(q2).unwrap());
        // The cache is full and h3 is a new key, so everything is cleared first.
        cache.insert(h3, planner::plan(q3).unwrap());
        assert!(cache.cache.contains_key(&h3));
        assert_eq!(cache.cache.len(), 1);
    }

    /// Collects every literal in `plan`, in the order `substitute_plan` fills them.
    fn collect_literals_for_test(plan: &PlanNode, out: &mut Vec<Literal>) {
        match plan {
            PlanNode::SeqScan { .. } => {}
            PlanNode::AliasScan { .. } => {}
            PlanNode::IndexScan { key, .. } => collect_expr_literals(key, out),
            PlanNode::RangeScan { start, end, .. } => {
                if let Some((expr, _)) = start {
                    collect_expr_literals(expr, out);
                }
                if let Some((expr, _)) = end {
                    collect_expr_literals(expr, out);
                }
            }
            PlanNode::Filter { input, predicate } => {
                collect_literals_for_test(input, out);
                collect_expr_literals(predicate, out);
            }
            PlanNode::Project { input, fields } => {
                collect_literals_for_test(input, out);
                for f in fields {
                    collect_expr_literals(&f.expr, out);
                }
            }
            PlanNode::Sort { input, .. } => collect_literals_for_test(input, out),
            PlanNode::Limit { input, count } => {
                // Mirror `substitute_plan`: for LIMIT directly over OFFSET, the limit
                // count is filled before the offset count.
                if let PlanNode::Offset {
                    input: inner,
                    count: off_count,
                } = input.as_ref()
                {
                    collect_literals_for_test(inner, out);
                    collect_expr_literals(count, out);
                    collect_expr_literals(off_count, out);
                } else {
                    collect_literals_for_test(input, out);
                    collect_expr_literals(count, out);
                }
            }
            PlanNode::Offset { input, count } => {
                collect_literals_for_test(input, out);
                collect_expr_literals(count, out);
            }
            PlanNode::Aggregate { input, .. } => collect_literals_for_test(input, out),
            PlanNode::NestedLoopJoin {
                left, right, on, ..
            } => {
                collect_literals_for_test(left, out);
                collect_literals_for_test(right, out);
                if let Some(pred) = on {
                    collect_expr_literals(pred, out);
                }
            }
            PlanNode::Insert { assignments, .. } => {
                for a in assignments {
                    collect_expr_literals(&a.value, out);
                }
            }
            PlanNode::Upsert {
                assignments,
                on_conflict,
                ..
            } => {
                for a in assignments {
                    collect_expr_literals(&a.value, out);
                }
                for a in on_conflict {
                    collect_expr_literals(&a.value, out);
                }
            }
            PlanNode::Update {
                input, assignments, ..
            } => {
                collect_literals_for_test(input, out);
                for a in assignments {
                    collect_expr_literals(&a.value, out);
                }
            }
            PlanNode::Distinct { input } => collect_literals_for_test(input, out),
            PlanNode::GroupBy { input, having, .. } => {
                collect_literals_for_test(input, out);
                if let Some(pred) = having {
                    collect_expr_literals(pred, out);
                }
            }
            PlanNode::Delete { input, .. } => collect_literals_for_test(input, out),
            PlanNode::CreateTable { .. } => {}
            PlanNode::AlterTable { .. } => {}
            PlanNode::DropTable { .. } => {}
            PlanNode::CreateView { .. } => {}
            PlanNode::RefreshView { .. } => {}
            PlanNode::DropView { .. } => {}
            PlanNode::Window { input, windows } => {
                collect_literals_for_test(input, out);
                for w in windows {
                    for arg in &w.args {
                        collect_expr_literals(arg, out);
                    }
                }
            }
            PlanNode::Union { left, right, .. } => {
                collect_literals_for_test(left, out);
                collect_literals_for_test(right, out);
            }
            PlanNode::Explain { input } => {
                collect_literals_for_test(input, out);
            }
        }
    }

    /// Collects every literal under `expr`, in the order `substitute_expr` fills them.
    fn collect_expr_literals(expr: &Expr, out: &mut Vec<Literal>) {
        match expr {
            Expr::Literal(l) => out.push(l.clone()),
            Expr::Field(_) | Expr::QualifiedField { .. } | Expr::Param(_) => {}
            Expr::BinaryOp(l, _, r) => {
                collect_expr_literals(l, out);
                collect_expr_literals(r, out);
            }
            Expr::UnaryOp(_, inner) => collect_expr_literals(inner, out),
            Expr::FunctionCall(_, inner) => collect_expr_literals(inner, out),
            Expr::Coalesce(l, r) => {
                collect_expr_literals(l, out);
                collect_expr_literals(r, out);
            }
            Expr::InList { expr, list, .. } => {
                collect_expr_literals(expr, out);
                for item in list {
                    collect_expr_literals(item, out);
                }
            }
            Expr::ScalarFunc(_, args) => {
                for a in args {
                    collect_expr_literals(a, out);
                }
            }
            Expr::Cast(inner, _) => collect_expr_literals(inner, out),
            Expr::Case { whens, else_expr } => {
                for (cond, result) in whens {
                    collect_expr_literals(cond, out);
                    collect_expr_literals(result, out);
                }
                if let Some(e) = else_expr {
                    collect_expr_literals(e, out);
                }
            }
            Expr::InSubquery { expr, .. } => {
                collect_expr_literals(expr, out);
            }
            Expr::ExistsSubquery { .. } => {}
            Expr::Window { args, .. } => {
                for a in args {
                    collect_expr_literals(a, out);
                }
            }
        }
    }
}