Skip to main content

citadel_sql/
planner.rs

//! Query planner: chooses between a sequential scan, a primary-key point lookup, a primary-key range scan, or a secondary-index scan.
2
3use crate::encoding::encode_composite_key;
4use crate::parser::{BinOp, Expr};
5use crate::types::{IndexDef, TableSchema, Value};
6
/// Access path chosen by [`plan_select`] for reading a single table.
#[derive(Debug, Clone)]
pub enum ScanPlan {
    /// Read every row of the table.
    SeqScan,
    /// Point lookup on the full primary key.
    PkLookup {
        /// One value per primary-key column, in `TableSchema::primary_key_columns` order.
        pk_values: Vec<Value>,
    },
    /// Range scan over a (currently single-column) primary key.
    PkRangeScan {
        /// Encoded smallest `>` / `>=` bound on the key column, or empty when there is no
        /// lower bound.
        start_key: Vec<u8>,
        /// Every range condition found on the key column, as `(operator, bound)`.
        range_conds: Vec<(BinOp, Value)>,
        /// Number of primary-key columns covered; the planner currently always emits 1.
        num_pk_cols: usize,
    },
    /// Scan of a secondary index: an equality prefix optionally followed by range conditions.
    IndexScan {
        /// Name of the index as declared in the schema.
        index_name: String,
        /// Storage table name from `TableSchema::index_table_name`.
        idx_table: Vec<u8>,
        /// Encoded equality values for the leading `num_prefix_cols` index columns.
        prefix: Vec<u8>,
        /// How many leading index columns are pinned by equality predicates.
        num_prefix_cols: usize,
        /// Range conditions on the index column immediately after the equality prefix.
        range_conds: Vec<(BinOp, Value)>,
        /// Copied from the index definition's `unique` flag.
        is_unique: bool,
        /// Column positions (into the table schema) covered by the index, in index order.
        index_columns: Vec<u16>,
    },
}
28
/// A normalized `column <op> constant` predicate extracted from a WHERE clause.
struct SimplePredicate {
    /// Position of the column in the table schema (from `TableSchema::column_index`).
    col_idx: usize,
    /// Comparison operator, oriented so the column is the left operand.
    op: BinOp,
    /// The constant (literal or resolved parameter) the column is compared against.
    value: Value,
}
34
35fn flatten_and(expr: &Expr) -> Vec<&Expr> {
36    match expr {
37        Expr::BinaryOp {
38            left,
39            op: BinOp::And,
40            right,
41        } => {
42            let mut v = flatten_and(left);
43            v.extend(flatten_and(right));
44            v
45        }
46        _ => vec![expr],
47    }
48}
49
50fn is_comparison(op: BinOp) -> bool {
51    matches!(
52        op,
53        BinOp::Eq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq
54    )
55}
56
57fn is_range_op(op: BinOp) -> bool {
58    matches!(op, BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq)
59}
60
61fn flip_op(op: BinOp) -> BinOp {
62    match op {
63        BinOp::Lt => BinOp::Gt,
64        BinOp::LtEq => BinOp::GtEq,
65        BinOp::Gt => BinOp::Lt,
66        BinOp::GtEq => BinOp::LtEq,
67        other => other,
68    }
69}
70
71fn resolve_column_name(expr: &Expr) -> Option<&str> {
72    match expr {
73        Expr::Column(name) => Some(name.as_str()),
74        Expr::QualifiedColumn { column, .. } => Some(column.as_str()),
75        _ => None,
76    }
77}
78
79fn resolve_literal(expr: &Expr) -> Option<Value> {
80    match expr {
81        Expr::Literal(v) => Some(v.clone()),
82        Expr::Parameter(n) => crate::eval::resolve_scoped_param(*n).ok(),
83        _ => None,
84    }
85}
86
87fn extract_simple_predicate(expr: &Expr, schema: &TableSchema) -> Option<SimplePredicate> {
88    match expr {
89        Expr::BinaryOp { left, op, right } if is_comparison(*op) => {
90            if let (Some(name), Some(val)) = (resolve_column_name(left), resolve_literal(right)) {
91                let col_idx = schema.column_index(name)?;
92                return Some(SimplePredicate {
93                    col_idx,
94                    op: *op,
95                    value: val,
96                });
97            }
98            if let (Some(val), Some(name)) = (resolve_literal(left), resolve_column_name(right)) {
99                let col_idx = schema.column_index(name)?;
100                return Some(SimplePredicate {
101                    col_idx,
102                    op: flip_op(*op),
103                    value: val,
104                });
105            }
106            None
107        }
108        _ => None,
109    }
110}
111
112/// Decompose BETWEEN into two range predicates for planner use.
113fn flatten_between(expr: &Expr, schema: &TableSchema, out: &mut Vec<SimplePredicate>) {
114    match expr {
115        Expr::Between {
116            expr: col_expr,
117            low,
118            high,
119            negated: false,
120        } => {
121            if let (Some(name), Some(lo), Some(hi)) = (
122                resolve_column_name(col_expr),
123                resolve_literal(low),
124                resolve_literal(high),
125            ) {
126                if let Some(col_idx) = schema.column_index(name) {
127                    out.push(SimplePredicate {
128                        col_idx,
129                        op: BinOp::GtEq,
130                        value: lo,
131                    });
132                    out.push(SimplePredicate {
133                        col_idx,
134                        op: BinOp::LtEq,
135                        value: hi,
136                    });
137                }
138            }
139        }
140        Expr::BinaryOp {
141            left,
142            op: BinOp::And,
143            right,
144        } => {
145            flatten_between(left, schema, out);
146            flatten_between(right, schema, out);
147        }
148        _ => {}
149    }
150}
151
152pub fn plan_select(schema: &TableSchema, where_clause: &Option<Expr>) -> ScanPlan {
153    let where_expr = match where_clause {
154        Some(e) => e,
155        None => return ScanPlan::SeqScan,
156    };
157
158    let predicates = flatten_and(where_expr);
159    let simple: Vec<Option<SimplePredicate>> = predicates
160        .iter()
161        .map(|p| extract_simple_predicate(p, schema))
162        .collect();
163
164    if let Some(plan) = try_pk_lookup(schema, &simple) {
165        return plan;
166    }
167
168    let mut range_preds: Vec<SimplePredicate> = simple
169        .iter()
170        .filter_map(|p| {
171            let p = p.as_ref()?;
172            if is_range_op(p.op) {
173                Some(SimplePredicate {
174                    col_idx: p.col_idx,
175                    op: p.op,
176                    value: p.value.clone(),
177                })
178            } else {
179                None
180            }
181        })
182        .collect();
183    flatten_between(where_expr, schema, &mut range_preds);
184
185    if let Some(plan) = try_pk_range_scan(schema, &range_preds) {
186        return plan;
187    }
188
189    if let Some(plan) = try_best_index(schema, where_expr, &simple) {
190        return plan;
191    }
192
193    ScanPlan::SeqScan
194}
195
196fn try_pk_range_scan(schema: &TableSchema, range_preds: &[SimplePredicate]) -> Option<ScanPlan> {
197    if schema.primary_key_columns.len() != 1 {
198        return None; // Only single-column PK for now
199    }
200    let pk_col = schema.primary_key_columns[0] as usize;
201    let conds: Vec<(BinOp, Value)> = range_preds
202        .iter()
203        .filter(|p| p.col_idx == pk_col)
204        .map(|p| (p.op, p.value.clone()))
205        .collect();
206    if conds.is_empty() {
207        return None;
208    }
209    let start_key = conds
210        .iter()
211        .filter(|(op, _)| matches!(op, BinOp::GtEq | BinOp::Gt))
212        .map(|(_, v)| encode_composite_key(std::slice::from_ref(v)))
213        .min_by(|a, b| a.cmp(b))
214        .unwrap_or_default();
215    Some(ScanPlan::PkRangeScan {
216        start_key,
217        range_conds: conds,
218        num_pk_cols: 1,
219    })
220}
221
222fn try_pk_lookup(schema: &TableSchema, predicates: &[Option<SimplePredicate>]) -> Option<ScanPlan> {
223    let pk_cols = &schema.primary_key_columns;
224    let mut pk_values: Vec<Option<Value>> = vec![None; pk_cols.len()];
225
226    for pred in predicates.iter().flatten() {
227        if pred.op == BinOp::Eq {
228            if let Some(pk_pos) = pk_cols.iter().position(|&c| c == pred.col_idx as u16) {
229                pk_values[pk_pos] = Some(pred.value.clone());
230            }
231        }
232    }
233
234    if pk_values.iter().all(|v| v.is_some()) {
235        let values: Vec<Value> = pk_values.into_iter().map(|v| v.unwrap()).collect();
236        Some(ScanPlan::PkLookup { pk_values: values })
237    } else {
238        None
239    }
240}
241
/// Ranking used to pick between candidate index plans (higher is better).
///
/// The derived `Ord` compares fields lexicographically in declaration order. More
/// equality-matched leading columns wins first, then having a range condition, then
/// uniqueness of the index. Do not reorder the fields.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct IndexScore {
    /// Number of leading index columns pinned by `=` predicates.
    num_equality: usize,
    /// Whether range conditions apply to the column after the equality prefix.
    has_range: bool,
    /// Copied from the index definition's `unique` flag.
    is_unique: bool,
}
248
249fn try_best_index(
250    schema: &TableSchema,
251    where_expr: &Expr,
252    predicates: &[Option<SimplePredicate>],
253) -> Option<ScanPlan> {
254    let mut best_score: Option<IndexScore> = None;
255    let mut best_plan: Option<ScanPlan> = None;
256
257    let conjuncts = flatten_and(where_expr);
258    for idx in &schema.indices {
259        if !partial_predicate_implied(idx, where_expr, &conjuncts) {
260            continue;
261        }
262        if let Some((score, plan)) = try_index_scan(schema, idx, predicates) {
263            if best_score.is_none() || score > *best_score.as_ref().unwrap() {
264                best_score = Some(score);
265                best_plan = Some(plan);
266            }
267        }
268    }
269
270    best_plan
271}
272
273fn partial_predicate_implied(idx: &IndexDef, where_expr: &Expr, conjuncts: &[&Expr]) -> bool {
274    let Some(pred) = idx.predicate_expr.as_ref() else {
275        return true;
276    };
277    if expr_structurally_eq(pred, where_expr) {
278        return true;
279    }
280    if conjuncts.iter().any(|c| expr_structurally_eq(pred, c)) {
281        return true;
282    }
283    if let Expr::IsNotNull(target) = pred {
284        if let Expr::Column(col) = target.as_ref() {
285            return conjuncts.iter().any(|c| conjunct_proves_not_null(c, col));
286        }
287    }
288    false
289}
290
291fn expr_structurally_eq(a: &Expr, b: &Expr) -> bool {
292    format!("{a:?}") == format!("{b:?}")
293}
294
295fn conjunct_proves_not_null(expr: &Expr, col: &str) -> bool {
296    let mentions = |e: &Expr| matches!(e, Expr::Column(n) if n.eq_ignore_ascii_case(col));
297    match expr {
298        Expr::BinaryOp {
299            left,
300            op: BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq,
301            right,
302        } => mentions(left) || mentions(right),
303        Expr::IsNotNull(inner) => mentions(inner),
304        _ => false,
305    }
306}
307
308fn try_index_scan(
309    schema: &TableSchema,
310    idx: &IndexDef,
311    predicates: &[Option<SimplePredicate>],
312) -> Option<(IndexScore, ScanPlan)> {
313    let mut used = Vec::new();
314    let mut equality_values: Vec<Value> = Vec::new();
315    let mut range_conds: Vec<(BinOp, Value)> = Vec::new();
316
317    for &col_idx in &idx.columns {
318        let mut found_eq = false;
319        for (i, pred) in predicates.iter().enumerate() {
320            if used.contains(&i) {
321                continue;
322            }
323            if let Some(sp) = pred {
324                if sp.col_idx == col_idx as usize && sp.op == BinOp::Eq {
325                    equality_values.push(sp.value.clone());
326                    used.push(i);
327                    found_eq = true;
328                    break;
329                }
330            }
331        }
332        if !found_eq {
333            for (i, pred) in predicates.iter().enumerate() {
334                if used.contains(&i) {
335                    continue;
336                }
337                if let Some(sp) = pred {
338                    if sp.col_idx == col_idx as usize && is_range_op(sp.op) {
339                        range_conds.push((sp.op, sp.value.clone()));
340                        used.push(i);
341                    }
342                }
343            }
344            break;
345        }
346    }
347
348    if equality_values.is_empty() && range_conds.is_empty() {
349        return None;
350    }
351
352    let score = IndexScore {
353        num_equality: equality_values.len(),
354        has_range: !range_conds.is_empty(),
355        is_unique: idx.unique,
356    };
357
358    let prefix = encode_composite_key(&equality_values);
359    let idx_table = TableSchema::index_table_name(&schema.name, &idx.name);
360
361    Some((
362        score,
363        ScanPlan::IndexScan {
364            index_name: idx.name.clone(),
365            idx_table,
366            prefix,
367            num_prefix_cols: equality_values.len(),
368            range_conds,
369            is_unique: idx.unique,
370            index_columns: idx.columns.clone(),
371        },
372    ))
373}
374
375pub fn describe_plan(plan: &ScanPlan, table_schema: &TableSchema) -> String {
376    match plan {
377        ScanPlan::SeqScan => String::new(),
378
379        ScanPlan::PkLookup { pk_values } => {
380            let pk_cols: Vec<&str> = table_schema
381                .primary_key_columns
382                .iter()
383                .map(|&idx| table_schema.columns[idx as usize].name.as_str())
384                .collect();
385            let conditions: Vec<String> = pk_cols
386                .iter()
387                .zip(pk_values.iter())
388                .map(|(col, val)| format!("{col} = {}", format_value(val)))
389                .collect();
390            format!("USING PRIMARY KEY ({})", conditions.join(", "))
391        }
392
393        ScanPlan::PkRangeScan { range_conds, .. } => {
394            let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize].name;
395            let conditions: Vec<String> = range_conds
396                .iter()
397                .map(|(op, val)| format!("{pk_col} {} {}", op_symbol(*op), format_value(val)))
398                .collect();
399            format!("USING PRIMARY KEY RANGE ({})", conditions.join(", "))
400        }
401
402        ScanPlan::IndexScan {
403            index_name,
404            num_prefix_cols,
405            range_conds,
406            index_columns,
407            ..
408        } => {
409            let mut conditions = Vec::new();
410            for &col in index_columns.iter().take(*num_prefix_cols) {
411                let col_idx = col as usize;
412                let col_name = &table_schema.columns[col_idx].name;
413                conditions.push(format!("{col_name} = ?"));
414            }
415            if !range_conds.is_empty() && *num_prefix_cols < index_columns.len() {
416                let col_idx = index_columns[*num_prefix_cols] as usize;
417                let col_name = &table_schema.columns[col_idx].name;
418                for (op, _) in range_conds {
419                    conditions.push(format!("{col_name} {} ?", op_symbol(*op)));
420                }
421            }
422            if conditions.is_empty() {
423                format!("USING INDEX {index_name}")
424            } else {
425                format!("USING INDEX {index_name} ({})", conditions.join(", "))
426            }
427        }
428    }
429}
430
431fn format_value(val: &Value) -> String {
432    match val {
433        Value::Null => "NULL".into(),
434        Value::Integer(i) => i.to_string(),
435        Value::Real(f) => format!("{f}"),
436        Value::Text(s) => format!("'{s}'"),
437        Value::Blob(_) => "BLOB".into(),
438        Value::Boolean(b) => b.to_string(),
439        Value::Date(d) => format!("DATE '{}'", crate::datetime::format_date(*d)),
440        Value::Time(t) => format!("TIME '{}'", crate::datetime::format_time(*t)),
441        Value::Timestamp(t) => format!("TIMESTAMP '{}'", crate::datetime::format_timestamp(*t)),
442        Value::Interval {
443            months,
444            days,
445            micros,
446        } => format!(
447            "INTERVAL '{}'",
448            crate::datetime::format_interval(*months, *days, *micros)
449        ),
450    }
451}
452
453fn op_symbol(op: BinOp) -> &'static str {
454    match op {
455        BinOp::Lt => "<",
456        BinOp::LtEq => "<=",
457        BinOp::Gt => ">",
458        BinOp::GtEq => ">=",
459        BinOp::Eq => "=",
460        BinOp::NotEq => "!=",
461        _ => "?",
462    }
463}
464
465#[cfg(test)]
466#[path = "planner_tests.rs"]
467mod tests;