#![allow(clippy::doc_markdown)]
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use spg_sql::ast::{ColumnName, Expr, FromClause, FromJoin, JoinKind, SelectStatement, TableRef};
use crate::selectivity;
use crate::statistics::Statistics;
use spg_storage::Catalog;
/// Largest relation count (primary table plus every joined table) for which
/// the reorderer scores every join permutation exhaustively; joins with more
/// relations fall back to a greedy search. At 4 this is at most 4! = 24
/// candidate orders.
pub const FULL_ENUM_MAX: usize = 4;
/// Returns the relation order the optimizer would pick for `stmt`, as indices
/// into `[primary, join_1, …, join_k]`, without modifying `stmt`.
///
/// Exposed for tests. Returns `None` when the FROM clause has no joins,
/// contains a non-INNER join, or cannot be analysed. Unlike
/// [`reorder_joins`], it does not skip statements when `stats` is empty, and
/// it reports the order even when it equals the written order.
pub fn choose_order_for_test(
    stmt: &SelectStatement,
    catalog: &Catalog,
    stats: &Statistics,
) -> Option<Vec<usize>> {
    // The planning core takes `&mut`, so hand it a throwaway copy.
    choose_order_inner(&mut stmt.clone(), catalog, stats)
}
/// Planning core behind `choose_order_for_test`: builds the join graph of
/// `stmt`'s FROM clause and returns the cheapest relation order as indices
/// into `[primary, join_1, …, join_k]`.
///
/// Returns `None` (no opinion) when there is no FROM clause or no join, any
/// join is not INNER, a join lacks an ON clause, an ON conjunct references an
/// unqualified/unknown column or an unsupported expression form, or a
/// relation is missing from `catalog`. `stmt` is only read, despite `&mut`.
fn choose_order_inner(
    stmt: &mut SelectStatement,
    catalog: &Catalog,
    stats: &Statistics,
) -> Option<Vec<usize>> {
    let from = stmt.from.as_mut()?;
    let inner_only = from.joins.iter().all(|j| matches!(j.kind, JoinKind::Inner));
    if from.joins.is_empty() || !inner_only {
        return None;
    }

    // Slot 0 is the primary relation; slot k is the k-th joined relation.
    let mut tables: Vec<TableRef> = Vec::with_capacity(from.joins.len() + 1);
    tables.push(from.primary.clone());
    tables.extend(from.joins.iter().map(|j| j.table.clone()));

    // Map qualifiers to slots. An alias always claims its slot; the bare
    // table name is registered as well unless that key is already taken.
    let mut alias_to_idx: BTreeMap<String, usize> = BTreeMap::new();
    for (slot, tref) in tables.iter().enumerate() {
        if let Some(alias) = &tref.alias {
            alias_to_idx.insert(alias.clone(), slot);
            alias_to_idx.entry(tref.name.clone()).or_insert(slot);
        } else {
            alias_to_idx.insert(tref.name.clone(), slot);
        }
    }

    // One join-graph edge per AND-conjunct of every ON clause.
    let mut edges: Vec<Edge> = Vec::new();
    for join in &from.joins {
        let on = join.on.as_ref()?;
        for conjunct in split_and_conjunctions(on) {
            let mut endpoints: Vec<usize> = Vec::new();
            if !collect_referenced_tables(conjunct, &alias_to_idx, &mut endpoints) {
                return None;
            }
            endpoints.sort_unstable();
            endpoints.dedup();
            let selectivity = estimate_edge_selectivity(conjunct, &tables, catalog, stats);
            edges.push(Edge {
                endpoints,
                predicate: conjunct.clone(),
                selectivity,
            });
        }
    }

    // Row counts per slot; any relation missing from the catalog aborts.
    let sizes: Vec<u64> = tables
        .iter()
        .map(|tref| catalog.get(&tref.name).map(|table| table.rows().len() as u64))
        .collect::<Option<Vec<u64>>>()?;

    let n = tables.len();
    Some(if n > FULL_ENUM_MAX {
        best_order_greedy(n, &sizes, &edges)
    } else {
        best_order_brute(n, &sizes, &edges)
    })
}
/// Reorders the INNER joins in `stmt`'s FROM clause into the cheapest order
/// found by the cost model and rewrites the clause in place.
///
/// Every AND-conjunct of every ON clause becomes an edge of the join graph.
/// Equality conjuncts between two columns are costed from `stats`, falling
/// back to catalog row counts; other conjuncts get a default selectivity.
/// Up to [`FULL_ENUM_MAX`] relations are ordered by exhaustive search; larger
/// joins are ordered greedily. When the chosen order equals the written
/// order, the statement is left untouched.
///
/// This is best-effort: the statement is also left untouched when there is
/// no FROM clause or no join, any join is not INNER, `stats` is empty, a join
/// has no ON clause, an ON conjunct references an unqualified or unknown
/// column or an expression form the join-graph builder does not handle, or a
/// relation is missing from `catalog`.
///
/// NOTE(review): the rewrite changes the left-to-right relation order of the
/// FROM clause. If `SELECT *` is expanded from FROM order downstream, the
/// output column order may change — confirm before relying on this.
pub fn reorder_joins(stmt: &mut SelectStatement, catalog: &Catalog, stats: &Statistics) {
    let Some(from) = stmt.from.as_mut() else {
        return;
    };
    let only_inner = from.joins.iter().all(|j| matches!(j.kind, JoinKind::Inner));
    if from.joins.is_empty() || !only_inner || stats.is_empty() {
        return;
    }

    // Slot 0 is the primary relation; slot k is the k-th joined relation.
    let mut tables: Vec<TableRef> = Vec::with_capacity(from.joins.len() + 1);
    tables.push(from.primary.clone());
    tables.extend(from.joins.iter().map(|j| j.table.clone()));

    // Map qualifiers to slots. An alias always claims its slot; the bare
    // table name is registered as well unless that key is already taken.
    let mut alias_to_idx: BTreeMap<String, usize> = BTreeMap::new();
    for (slot, tref) in tables.iter().enumerate() {
        match &tref.alias {
            Some(alias) => {
                alias_to_idx.insert(alias.clone(), slot);
                alias_to_idx.entry(tref.name.clone()).or_insert(slot);
            }
            None => {
                alias_to_idx.insert(tref.name.clone(), slot);
            }
        }
    }

    // One join-graph edge per AND-conjunct of every ON clause.
    let mut edges: Vec<Edge> = Vec::new();
    for join in &from.joins {
        let Some(on) = join.on.as_ref() else {
            return;
        };
        for conjunct in split_and_conjunctions(on) {
            let mut endpoints: Vec<usize> = Vec::new();
            if !collect_referenced_tables(conjunct, &alias_to_idx, &mut endpoints) {
                return;
            }
            endpoints.sort_unstable();
            endpoints.dedup();
            let selectivity = estimate_edge_selectivity(conjunct, &tables, catalog, stats);
            edges.push(Edge {
                endpoints,
                predicate: conjunct.clone(),
                selectivity,
            });
        }
    }

    // Row counts per slot; any relation missing from the catalog aborts.
    let Some(sizes) = tables
        .iter()
        .map(|tref| catalog.get(&tref.name).map(|table| table.rows().len() as u64))
        .collect::<Option<Vec<u64>>>()
    else {
        return;
    };

    let n = tables.len();
    let order = if n > FULL_ENUM_MAX {
        best_order_greedy(n, &sizes, &edges)
    } else {
        best_order_brute(n, &sizes, &edges)
    };
    // Nothing to rewrite when the cheapest order is the order as written.
    if (0..).zip(&order).all(|(pos, &slot)| pos == slot) {
        return;
    }
    rewrite_from(from, &tables, &edges, &order);
}
/// One AND-conjunct of a join's ON clause, treated as an edge of the join
/// graph used for cost estimation and for rebuilding the FROM clause.
struct Edge {
/// FROM slots (indices into the relation list, primary = 0) referenced by
/// `predicate`, sorted and de-duplicated when built; empty when the conjunct
/// consists only of literals and placeholders.
endpoints: Vec<usize>,
/// The conjunct itself, re-attached to a join's ON clause on rewrite.
predicate: Expr,
/// Estimated fraction of joined rows that satisfy `predicate`.
selectivity: f64,
}
/// Flattens a tree of `AND`s into its conjuncts, in left-to-right order.
/// Any non-`AND` expression yields a one-element vector holding itself.
fn split_and_conjunctions(expr: &Expr) -> Vec<&Expr> {
    use spg_sql::ast::BinOp;
    // Explicit work stack rather than recursion.
    let mut pending: Vec<&Expr> = alloc::vec![expr];
    let mut conjuncts: Vec<&Expr> = Vec::new();
    while let Some(node) = pending.pop() {
        match node {
            Expr::Binary {
                op: BinOp::And,
                lhs,
                rhs,
            } => {
                // LIFO: push the right operand first so the left one is visited next.
                pending.push(rhs);
                pending.push(lhs);
            }
            leaf => conjuncts.push(leaf),
        }
    }
    conjuncts
}
/// Appends to `out` the FROM slot of every qualified column in `expr`.
///
/// Returns `false` when `expr` cannot be safely moved between joins: an
/// unqualified column, a qualifier not found in `alias_to_idx`, or any
/// expression form not listed below. Literals and placeholders reference no
/// slot but are accepted. On `false`, `out` may hold a partial result.
fn collect_referenced_tables(
    expr: &Expr,
    alias_to_idx: &BTreeMap<String, usize>,
    out: &mut Vec<usize>,
) -> bool {
    match expr {
        Expr::Literal(_) | Expr::Placeholder(_) => true,
        Expr::Column(ColumnName { qualifier, .. }) => {
            let Some(&slot) = qualifier.as_ref().and_then(|q| alias_to_idx.get(q)) else {
                return false;
            };
            out.push(slot);
            true
        }
        Expr::Unary { expr: inner, .. } => collect_referenced_tables(inner, alias_to_idx, out),
        Expr::Cast { expr: inner, .. } | Expr::IsNull { expr: inner, .. } => {
            collect_referenced_tables(inner, alias_to_idx, out)
        }
        Expr::Binary { lhs, rhs, .. } => {
            collect_referenced_tables(lhs, alias_to_idx, out)
                && collect_referenced_tables(rhs, alias_to_idx, out)
        }
        Expr::Like {
            expr: subject,
            pattern,
            ..
        } => {
            collect_referenced_tables(subject, alias_to_idx, out)
                && collect_referenced_tables(pattern, alias_to_idx, out)
        }
        Expr::FunctionCall { args, .. } => {
            for arg in args {
                if !collect_referenced_tables(arg, alias_to_idx, out) {
                    return false;
                }
            }
            true
        }
        _ => false,
    }
}
/// Estimated selectivity of one join conjunct.
///
/// For `col_a = col_b` between two column references this is
/// `1 / max(n_distinct(col_a), n_distinct(col_b), 1)`, the textbook equi-join
/// estimate. Every other shape gets `selectivity::DEFAULT_RANGE`.
fn estimate_edge_selectivity(
    on: &Expr,
    tables: &[TableRef],
    catalog: &Catalog,
    stats: &Statistics,
) -> f64 {
    use spg_sql::ast::BinOp;
    let operands = match on {
        Expr::Binary {
            op: BinOp::Eq,
            lhs,
            rhs,
        } => (column_ref(lhs), column_ref(rhs)),
        _ => return selectivity::DEFAULT_RANGE,
    };
    match operands {
        (Some(left), Some(right)) => {
            let widest = column_n_distinct(&left, tables, catalog, stats)
                .max(column_n_distinct(&right, tables, catalog, stats))
                .max(1);
            1.0 / widest as f64
        }
        _ => selectivity::DEFAULT_RANGE,
    }
}
/// Splits a column reference into `(qualifier, column_name)`; returns `None`
/// for any other expression.
fn column_ref(expr: &Expr) -> Option<(Option<String>, String)> {
    match expr {
        Expr::Column(column) => Some((column.qualifier.clone(), column.name.clone())),
        _ => None,
    }
}
/// Estimated number of distinct values of the column `col = (qualifier, name)`.
///
/// The qualifier is matched against each relation's alias or table name, and
/// the first match wins. Returns 0 when the column is unqualified or the
/// qualifier matches no relation. Otherwise it returns the column's
/// `n_distinct` from `stats` (at least 1). Failing that, it returns the
/// table's row count from `catalog` (at least 1), and 1 if the table is absent.
fn column_n_distinct(
    col: &(Option<String>, String),
    tables: &[TableRef],
    catalog: &Catalog,
    stats: &Statistics,
) -> u64 {
    let (qualifier, column) = col;
    let Some(qualifier) = qualifier.as_deref() else {
        return 0;
    };
    let owner = tables
        .iter()
        .find(|t| t.name == qualifier || t.alias.as_deref() == Some(qualifier));
    let Some(owner) = owner else {
        return 0;
    };
    match stats.get(&owner.name, column) {
        Some(col_stats) => col_stats.n_distinct.max(1),
        None => catalog
            .get(&owner.name)
            .map_or(1, |table| (table.rows().len() as u64).max(1)),
    }
}
/// Scores all `n!` orders of `0..n` with `plan_cost` and returns the cheapest.
///
/// Ties keep the earliest order visited by `permute`, which visits the
/// identity first. If no order scores below infinity, the identity is
/// returned. `n` must be at least 1, because `plan_cost` indexes `order[0]`.
fn best_order_brute(n: usize, sizes: &[u64], edges: &[Edge]) -> Vec<usize> {
    let mut scratch: Vec<usize> = (0..n).collect();
    let mut winner: (f64, Vec<usize>) = (f64::INFINITY, scratch.clone());
    permute(&mut scratch, 0, &mut |candidate: &[usize]| {
        let cost = plan_cost(candidate, sizes, edges);
        if cost < winner.0 {
            winner = (cost, candidate.to_vec());
        }
    });
    winner.1
}
/// Calls `visit` once for every permutation of `arr[k..]`, keeping
/// `arr[..k]` fixed. This is the recursive swap scheme. `arr` is restored to
/// its original order on return. When `k >= arr.len()`, `visit` is called
/// exactly once with `arr` unchanged.
fn permute<F: FnMut(&[usize])>(arr: &mut Vec<usize>, k: usize, visit: &mut F) {
    let len = arr.len();
    if k < len {
        let mut pick = k;
        while pick < len {
            // Move `arr[pick]` into position `k`, permute the tail, then undo.
            arr.swap(k, pick);
            permute(arr, k + 1, visit);
            arr.swap(k, pick);
            pick += 1;
        }
    } else {
        visit(arr.as_slice());
    }
}
/// Greedy join order over `0..n` for joins too large to enumerate.
///
/// The search starts from the relation with the fewest rows (the lowest index
/// on ties). It then repeatedly appends the remaining relation that minimises
/// `plan_cost` of the extended prefix, keeping the earliest candidate on ties.
///
/// # Panics
/// Panics when `n == 0`.
fn best_order_greedy(n: usize, sizes: &[u64], edges: &[Edge]) -> Vec<usize> {
    let mut pending: Vec<usize> = (0..n).collect();
    let start = *pending
        .iter()
        .min_by_key(|&&slot| sizes[slot])
        .expect("n > 0");
    let mut order: Vec<usize> = Vec::with_capacity(n);
    order.push(start);
    pending.retain(|&slot| slot != start);

    while let Some(&fallback) = pending.first() {
        let mut pick = fallback;
        let mut pick_cost = f64::INFINITY;
        for &cand in &pending {
            // Score `order ++ [cand]` in place instead of cloning the prefix.
            order.push(cand);
            let cost = plan_cost(&order, sizes, edges);
            order.pop();
            if cost < pick_cost {
                pick_cost = cost;
                pick = cand;
            }
        }
        order.push(pick);
        pending.retain(|&slot| slot != pick);
    }
    order
}
/// Estimated cost of evaluating the relations in `order` as a left-deep join.
///
/// Each step is charged the cross-product size of the running intermediate
/// result and the newly joined relation. The running cardinality is then
/// scaled by the selectivity of every edge whose endpoints are all joined and
/// which mentions the new relation, and is floored at one row. Edges that
/// mention only `order[0]`, or no relation at all, never contribute a factor.
///
/// # Panics
/// Panics if `order` is empty or holds an index outside `sizes`.
fn plan_cost(order: &[usize], sizes: &[u64], edges: &[Edge]) -> f64 {
    let first = order[0];
    let mut rows = sizes[first] as f64;
    let mut joined: Vec<bool> = alloc::vec![false; sizes.len()];
    joined[first] = true;

    let mut total = 0.0_f64;
    for &next in &order[1..] {
        let product = rows * sizes[next] as f64;
        total += product;
        joined[next] = true;
        rows = edges
            .iter()
            .filter(|edge| {
                edge.endpoints.iter().all(|&slot| joined[slot]) && edge.endpoints.contains(&next)
            })
            .fold(product, |acc, edge| acc * edge.selectivity)
            .max(1.0);
    }
    total
}
/// Rewrites `from` so its relations appear in `order`, redistributing every
/// ON-clause conjunct in `edges` onto the first join at which all relations
/// it references are available.
///
/// * `tables`: the relations in their original FROM positions (primary = 0).
/// * `edges`: every AND-conjunct of the original ON clauses. `endpoints`
///   index into `tables`.
/// * `order`: a permutation of `0..tables.len()`. `order[0]` becomes the new
///   primary relation.
///
/// Every rewritten join is INNER, so an ON conjunct acts as a filter over the
/// rows joined so far. It may therefore move to any join whose prefix covers
/// its endpoints. A join that receives no conjunct gets `ON TRUE` (a cross
/// product).
///
/// # Panics
/// Panics if `order` is empty or holds an index outside `tables`.
fn rewrite_from(
    from: &mut FromClause,
    tables: &[TableRef],
    edges: &[Edge],
    order: &[usize],
) {
    from.primary = tables[order[0]].clone();
    from.joins.clear();
    // `in_prefix[t]` becomes true once relation `t` is part of the rewritten tree.
    let mut in_prefix: Vec<bool> = alloc::vec![false; tables.len()];
    in_prefix[order[0]] = true;
    // Each conjunct is emitted exactly once, at the earliest join able to evaluate it.
    let mut edges_used: Vec<bool> = alloc::vec![false; edges.len()];
    for &table_idx in &order[1..] {
        in_prefix[table_idx] = true;
        let mut combined: Option<Expr> = None;
        for (ei, edge) in edges.iter().enumerate() {
            // Attach every not-yet-placed conjunct whose relations are all
            // available now. Do NOT additionally require the edge to mention
            // `table_idx`. Some conjuncts reference only `order[0]`, e.g.
            // `c.flag = 1` once `c` is promoted to primary. Others reference
            // no relation at all, e.g. `1 = 0`. Such conjuncts would then
            // never match any join and would be silently dropped, which
            // changes the query result. They land on the first join; every
            // other edge still lands on the join that completes its endpoints.
            if edges_used[ei] || !edge.endpoints.iter().all(|&e| in_prefix[e]) {
                continue;
            }
            edges_used[ei] = true;
            combined = Some(match combined {
                None => edge.predicate.clone(),
                Some(prev) => Expr::Binary {
                    op: spg_sql::ast::BinOp::And,
                    lhs: alloc::boxed::Box::new(prev),
                    rhs: alloc::boxed::Box::new(edge.predicate.clone()),
                },
            });
        }
        // A join with no applicable conjunct becomes `ON TRUE` (cross product).
        let on = combined.unwrap_or_else(|| Expr::Literal(spg_sql::ast::Literal::Bool(true)));
        from.joins.push(FromJoin {
            kind: JoinKind::Inner,
            table: tables[table_idx].clone(),
            on: Some(on),
        });
    }
    // After the last join every relation is in the prefix, so every conjunct
    // must have been placed; a leftover would be a silently lost predicate.
    debug_assert!(
        edges_used.iter().all(|&used| used),
        "rewrite_from dropped an ON conjunct"
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use spg_sql::parser;

    /// Parses `sql` as a SELECT and runs `reorder_joins` against an empty
    /// catalog and empty statistics. Asserts the statement comes back unchanged.
    fn assert_reorder_is_noop(sql: &str) {
        let spg_sql::ast::Statement::Select(mut stmt) = parser::parse_statement(sql).unwrap()
        else {
            panic!()
        };
        let before = stmt.clone();
        let cat = Catalog::new();
        let stats = Statistics::new();
        reorder_joins(&mut stmt, &cat, &stats);
        assert_eq!(stmt, before);
    }

    #[test]
    fn no_joins_is_noop() {
        assert_reorder_is_noop("SELECT * FROM users");
    }

    #[test]
    fn five_table_star_picks_fact_first() {
        const DIMENSIONS: [&str; 4] = ["big1", "big2", "big3", "big4"];
        let mut e = crate::Engine::new();
        e.execute("CREATE TABLE fact (id INT NOT NULL, k1 INT NOT NULL, k2 INT NOT NULL, k3 INT NOT NULL, k4 INT NOT NULL)").unwrap();
        for tag in DIMENSIONS {
            e.execute(&alloc::format!("CREATE TABLE {tag} (k INT NOT NULL)"))
                .unwrap();
        }
        // A tiny 3-row fact table against four 40-row dimension tables.
        for i in 0..3 {
            e.execute(&alloc::format!(
                "INSERT INTO fact VALUES ({i}, {i}, {i}, {i}, {i})"
            ))
            .unwrap();
        }
        for tag in DIMENSIONS {
            for i in 0..40 {
                e.execute(&alloc::format!("INSERT INTO {tag} VALUES ({i})"))
                    .unwrap();
            }
        }
        e.execute("ANALYZE").unwrap();
        // As written, the dimensions are cross-joined first and fact comes last.
        let prepared = e
            .prepare(
                "SELECT fact.id FROM big1 \
                 INNER JOIN big2 ON 1 = 1 \
                 INNER JOIN big3 ON 1 = 1 \
                 INNER JOIN big4 ON 1 = 1 \
                 INNER JOIN fact ON fact.k1 = big1.k AND fact.k2 = big2.k AND fact.k3 = big3.k AND fact.k4 = big4.k",
            )
            .unwrap();
        let spg_sql::ast::Statement::Select(sel) = prepared else {
            panic!()
        };
        let primary = sel.from.unwrap().primary.name;
        assert_eq!(
            primary, "fact",
            "reorder must put fact first; got primary={:?}",
            primary
        );
    }

    #[test]
    fn left_join_is_skipped() {
        assert_reorder_is_noop(
            "SELECT * FROM a LEFT JOIN b ON a.id = b.id LEFT JOIN c ON b.id = c.id",
        );
    }
}