Skip to main content

powdb_query/
planner.rs

1use crate::ast::*;
2use crate::parser::{parse, ParseError};
3use crate::plan::*;
4
5/// (column_name, lower_bound, upper_bound) — used by range-index extraction.
6type RangeBound = (String, Option<(Expr, bool)>, Option<(Expr, bool)>);
7
8/// Plan-phase error — wraps ParseError for the full lex→parse→plan chain.
9#[derive(Debug)]
10pub enum PlanError {
11    /// Error originated in the parser (or lexer, via ParseError::Lex).
12    Parse(ParseError),
13}
14
15impl PlanError {
16    /// Convenience: human-readable message for any variant.
17    pub fn message(&self) -> String {
18        self.to_string()
19    }
20}
21
22impl std::fmt::Display for PlanError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            Self::Parse(e) => write!(f, "{e}"),
26        }
27    }
28}
29
30impl std::error::Error for PlanError {
31    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32        match self {
33            Self::Parse(e) => Some(e),
34        }
35    }
36}
37
38impl From<ParseError> for PlanError {
39    fn from(e: ParseError) -> Self {
40        PlanError::Parse(e)
41    }
42}
43
44pub fn plan(input: &str) -> Result<PlanNode, PlanError> {
45    let stmt = parse(input)?;
46    plan_statement(stmt)
47}
48
49pub fn plan_statement(stmt: Statement) -> Result<PlanNode, PlanError> {
50    match stmt {
51        Statement::Query(q) => plan_query(q),
52        Statement::Insert(ins) => plan_insert(ins),
53        Statement::UpdateQuery(upd) => plan_update(upd),
54        Statement::DeleteQuery(del) => plan_delete(del),
55        Statement::CreateType(ct) => plan_create_type(ct),
56        Statement::AlterTable(at) => Ok(PlanNode::AlterTable {
57            table: at.table,
58            action: at.action,
59        }),
60        Statement::DropTable(dt) => Ok(PlanNode::DropTable { name: dt.table }),
61        Statement::CreateView(cv) => Ok(PlanNode::CreateView {
62            name: cv.name,
63            query_text: cv.query_text,
64        }),
65        Statement::RefreshView(rv) => Ok(PlanNode::RefreshView { name: rv.name }),
66        Statement::DropView(dv) => Ok(PlanNode::DropView { name: dv.name }),
67        Statement::Union(u) => {
68            let left = plan_statement(*u.left)?;
69            let right = plan_statement(*u.right)?;
70            Ok(PlanNode::Union {
71                left: Box::new(left),
72                right: Box::new(right),
73                all: u.all,
74            })
75        }
76        Statement::Upsert(ups) => plan_upsert(ups),
77        Statement::Begin => Ok(PlanNode::Begin),
78        Statement::Commit => Ok(PlanNode::Commit),
79        Statement::Rollback => Ok(PlanNode::Rollback),
80        Statement::Explain(inner) => {
81            let inner_plan = plan_statement(*inner)?;
82            Ok(PlanNode::Explain {
83                input: Box::new(inner_plan),
84            })
85        }
86    }
87}
88
89fn plan_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
90    // Mission E1.2: if the query has joins, build a left-deep nested-loop
91    // plan. Correctness first — hash-join optimization is E1.3. We also
92    // don't try to fold an IndexScan under a joined query yet (the
93    // leaf-level fast paths all match on `PlanNode::SeqScan { .. }`
94    // literally, so mixing them into a join plan would silently break).
95    if !q.joins.is_empty() {
96        return plan_joined_query(q);
97    }
98    // Try to fold `filter .col = literal` into an IndexScan. The executor
99    // decides at run time whether the column actually has an index — if not,
100    // it transparently falls back to a sequential scan with the same predicate,
101    // so this rewrite is always safe.
102    //
103    // We only rewrite the *simple* eq case: `filter .col = literal`. Conjunctions
104    // like `filter .col = 1 and .other > 5` fall through to SeqScan + Filter.
105    // Extending this to split conjunctions is a future optimization.
106    let (source, filter) = match q.filter {
107        Some(pred) => match try_extract_eq_index_key(&q.source, &pred) {
108            Some(index_scan) => (index_scan, None),
109            None => match try_extract_range_index_keys(&q.source, &pred) {
110                Some(range_scan) => (range_scan, None),
111                None => (
112                    PlanNode::SeqScan {
113                        table: q.source.clone(),
114                    },
115                    Some(pred),
116                ),
117            },
118        },
119        None => (
120            PlanNode::SeqScan {
121                table: q.source.clone(),
122            },
123            None,
124        ),
125    };
126    let mut node = source;
127
128    if let Some(pred) = filter {
129        node = PlanNode::Filter {
130            input: Box::new(node),
131            predicate: pred,
132        };
133    }
134
135    // Mission E2b: GROUP BY path — insert GroupBy + Project before
136    // order/limit/offset/distinct.
137    if let Some(group) = q.group_by {
138        let mut proj_fields: Vec<ProjectField> = q
139            .projection
140            .map(|proj| {
141                proj.into_iter()
142                    .map(|pf| ProjectField {
143                        alias: pf.alias,
144                        expr: pf.expr,
145                    })
146                    .collect()
147            })
148            .unwrap_or_default();
149        let mut having = group.having;
150        let aggregates = extract_aggregates(&mut proj_fields, &mut having);
151
152        node = PlanNode::GroupBy {
153            input: Box::new(node),
154            keys: group.keys,
155            aggregates,
156            having,
157        };
158
159        if !proj_fields.is_empty() {
160            node = PlanNode::Project {
161                input: Box::new(node),
162                fields: proj_fields,
163            };
164        }
165
166        if let Some(order) = q.order {
167            node = PlanNode::Sort {
168                input: Box::new(node),
169                keys: order
170                    .keys
171                    .into_iter()
172                    .map(|k| SortKey {
173                        field: k.field,
174                        descending: k.descending,
175                    })
176                    .collect(),
177            };
178        }
179        // Offset must be applied *before* Limit: skip M rows, then take N.
180        // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
181        // and Limit wraps it (outer).
182        if let Some(off) = q.offset {
183            node = PlanNode::Offset {
184                input: Box::new(node),
185                count: off,
186            };
187        }
188        if let Some(lim) = q.limit {
189            node = PlanNode::Limit {
190                input: Box::new(node),
191                count: lim,
192            };
193        }
194        if q.distinct {
195            node = PlanNode::Distinct {
196                input: Box::new(node),
197            };
198        }
199        return Ok(node);
200    }
201
202    if let Some(order) = q.order {
203        node = PlanNode::Sort {
204            input: Box::new(node),
205            keys: order
206                .keys
207                .into_iter()
208                .map(|k| SortKey {
209                    field: k.field,
210                    descending: k.descending,
211                })
212                .collect(),
213        };
214    }
215
216    // Offset must be applied *before* Limit: skip M rows, then take N.
217    // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
218    // and Limit wraps it (outer).
219    if let Some(off) = q.offset {
220        node = PlanNode::Offset {
221            input: Box::new(node),
222            count: off,
223        };
224    }
225
226    if let Some(lim) = q.limit {
227        node = PlanNode::Limit {
228            input: Box::new(node),
229            count: lim,
230        };
231    }
232
233    if let Some(proj) = q.projection {
234        let mut fields: Vec<ProjectField> = proj
235            .into_iter()
236            .map(|pf| ProjectField {
237                alias: pf.alias,
238                expr: pf.expr,
239            })
240            .collect();
241        let windows = extract_windows(&mut fields);
242        if !windows.is_empty() {
243            node = PlanNode::Window {
244                input: Box::new(node),
245                windows,
246            };
247        }
248        node = PlanNode::Project {
249            input: Box::new(node),
250            fields,
251        };
252    }
253
254    if q.distinct {
255        node = PlanNode::Distinct {
256            input: Box::new(node),
257        };
258    }
259
260    if let Some(agg) = q.aggregation {
261        node = PlanNode::Aggregate {
262            input: Box::new(node),
263            function: agg.function,
264            field: agg.field,
265        };
266    }
267
268    Ok(node)
269}
270
271/// Build a left-deep nested-loop join plan for a query with 1+ join clauses.
272///
273/// The plan shape for `T1 as a [inner|left|cross] join T2 as b on <pred> ...` is:
274///
275///   Project? (optional, from q.projection)
276///   └─ Offset? / Limit? / Sort?
277///      └─ Filter? (the top-level q.filter, using qualified columns)
278///         └─ NestedLoopJoin { kind, on }
279///            ├─ AliasScan { T1, a }
280///            └─ AliasScan { T2, b }
281///
282/// Multi-join chains extend left-deep: a third join adds a second
283/// `NestedLoopJoin` on top, with the first join's output as its `left`.
284///
285/// Aliases default to the source table name when the query didn't write
286/// `as <name>` explicitly — that way users can always write `T.field`
287/// without being forced to alias every source.
288///
289/// RightOuter is rewritten into LeftOuter with inputs swapped — the two
290/// differ only in which side survives non-matching rows, and swapping
291/// inputs lets the executor ship a single LeftOuter path.
292fn plan_joined_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
293    let primary_alias = q.alias.clone().unwrap_or_else(|| q.source.clone());
294    let mut node = PlanNode::AliasScan {
295        table: q.source.clone(),
296        alias: primary_alias,
297    };
298
299    for join in q.joins {
300        let right_alias = join.alias.unwrap_or_else(|| join.source.clone());
301        let right = PlanNode::AliasScan {
302            table: join.source,
303            alias: right_alias,
304        };
305        match join.kind {
306            JoinKind::Inner | JoinKind::LeftOuter | JoinKind::Cross => {
307                node = PlanNode::NestedLoopJoin {
308                    left: Box::new(node),
309                    right: Box::new(right),
310                    on: join.on,
311                    kind: join.kind,
312                };
313            }
314            JoinKind::RightOuter => {
315                // `a RIGHT OUTER JOIN b ON <p>` ≡ `b LEFT OUTER JOIN a ON <p>`.
316                node = PlanNode::NestedLoopJoin {
317                    left: Box::new(right),
318                    right: Box::new(node),
319                    on: join.on,
320                    kind: JoinKind::LeftOuter,
321                };
322            }
323        }
324    }
325
326    if let Some(pred) = q.filter {
327        node = PlanNode::Filter {
328            input: Box::new(node),
329            predicate: pred,
330        };
331    }
332
333    if let Some(order) = q.order {
334        node = PlanNode::Sort {
335            input: Box::new(node),
336            keys: order
337                .keys
338                .into_iter()
339                .map(|k| SortKey {
340                    field: k.field,
341                    descending: k.descending,
342                })
343                .collect(),
344        };
345    }
346
347    // Offset must be applied *before* Limit: skip M rows, then take N.
348    // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
349    // and Limit wraps it (outer).
350    if let Some(off) = q.offset {
351        node = PlanNode::Offset {
352            input: Box::new(node),
353            count: off,
354        };
355    }
356
357    if let Some(lim) = q.limit {
358        node = PlanNode::Limit {
359            input: Box::new(node),
360            count: lim,
361        };
362    }
363
364    // Mission E2b: GROUP BY path for joined queries.
365    if let Some(group) = q.group_by {
366        let mut proj_fields: Vec<ProjectField> = q
367            .projection
368            .map(|proj| {
369                proj.into_iter()
370                    .map(|pf| ProjectField {
371                        alias: pf.alias,
372                        expr: pf.expr,
373                    })
374                    .collect()
375            })
376            .unwrap_or_default();
377        let mut having = group.having;
378        let aggregates = extract_aggregates(&mut proj_fields, &mut having);
379
380        node = PlanNode::GroupBy {
381            input: Box::new(node),
382            keys: group.keys,
383            aggregates,
384            having,
385        };
386
387        if !proj_fields.is_empty() {
388            node = PlanNode::Project {
389                input: Box::new(node),
390                fields: proj_fields,
391            };
392        }
393        if q.distinct {
394            node = PlanNode::Distinct {
395                input: Box::new(node),
396            };
397        }
398        return Ok(node);
399    }
400
401    if let Some(proj) = q.projection {
402        let mut fields: Vec<ProjectField> = proj
403            .into_iter()
404            .map(|pf| ProjectField {
405                alias: pf.alias,
406                expr: pf.expr,
407            })
408            .collect();
409        let windows = extract_windows(&mut fields);
410        if !windows.is_empty() {
411            node = PlanNode::Window {
412                input: Box::new(node),
413                windows,
414            };
415        }
416        node = PlanNode::Project {
417            input: Box::new(node),
418            fields,
419        };
420    }
421
422    if q.distinct {
423        node = PlanNode::Distinct {
424            input: Box::new(node),
425        };
426    }
427
428    if let Some(agg) = q.aggregation {
429        node = PlanNode::Aggregate {
430            input: Box::new(node),
431            function: agg.function,
432            field: agg.field,
433        };
434    }
435
436    Ok(node)
437}
438
439fn plan_insert(ins: InsertExpr) -> Result<PlanNode, PlanError> {
440    Ok(PlanNode::Insert {
441        table: ins.target,
442        rows: ins.rows,
443    })
444}
445
446fn plan_update(upd: UpdateExpr) -> Result<PlanNode, PlanError> {
447    // Mirror the read-side IndexScan fold: when the update filter is a simple
448    // `.col = literal`, emit `Update(IndexScan)` so the executor's index-lookup
449    // mutation fast path fires. The executor falls back to a scan if the
450    // column happens to lack an index, so this is always safe.
451    let source = match upd.filter {
452        Some(pred) => match try_extract_eq_index_key(&upd.source, &pred) {
453            Some(index_scan) => index_scan,
454            None => match try_extract_range_index_keys(&upd.source, &pred) {
455                Some(range_scan) => range_scan,
456                None => PlanNode::Filter {
457                    input: Box::new(PlanNode::SeqScan {
458                        table: upd.source.clone(),
459                    }),
460                    predicate: pred,
461                },
462            },
463        },
464        None => PlanNode::SeqScan {
465            table: upd.source.clone(),
466        },
467    };
468    Ok(PlanNode::Update {
469        input: Box::new(source),
470        table: upd.source,
471        assignments: upd.assignments,
472    })
473}
474
475fn plan_delete(del: DeleteExpr) -> Result<PlanNode, PlanError> {
476    let source = match del.filter {
477        Some(pred) => match try_extract_eq_index_key(&del.source, &pred) {
478            Some(index_scan) => index_scan,
479            None => match try_extract_range_index_keys(&del.source, &pred) {
480                Some(range_scan) => range_scan,
481                None => PlanNode::Filter {
482                    input: Box::new(PlanNode::SeqScan {
483                        table: del.source.clone(),
484                    }),
485                    predicate: pred,
486                },
487            },
488        },
489        None => PlanNode::SeqScan {
490            table: del.source.clone(),
491        },
492    };
493    Ok(PlanNode::Delete {
494        input: Box::new(source),
495        table: del.source,
496    })
497}
498
499fn plan_upsert(ups: UpsertExpr) -> Result<PlanNode, PlanError> {
500    Ok(PlanNode::Upsert {
501        table: ups.target,
502        key_column: ups.key_column,
503        assignments: ups.assignments,
504        on_conflict: ups.on_conflict,
505    })
506}
507
508fn plan_create_type(ct: CreateTypeExpr) -> Result<PlanNode, PlanError> {
509    let fields = ct
510        .fields
511        .into_iter()
512        .map(|f| crate::plan::CreateField {
513            name: f.name,
514            type_name: f.type_name,
515            required: f.required,
516            unique: f.unique,
517        })
518        .collect();
519    Ok(PlanNode::CreateTable {
520        name: ct.name,
521        fields,
522    })
523}
524
525/// If the predicate is a simple `.field = literal` (or `literal = .field`),
526/// return a corresponding IndexScan plan node. Otherwise return None so the
527/// caller can fall through to SeqScan + Filter.
528///
529/// The executor decides at run time whether the named column actually has a
530/// B-tree index — if not, IndexScan transparently falls back to a scan +
531/// equality filter on that column. That means this rewrite is always safe
532/// regardless of schema/index state; it just unlocks the fast path when an
533/// index happens to exist.
534fn try_extract_eq_index_key(table: &str, pred: &Expr) -> Option<PlanNode> {
535    let (lhs, op, rhs) = match pred {
536        Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
537        _ => return None,
538    };
539    if op != BinOp::Eq {
540        return None;
541    }
542    let (column, key) = match (lhs, rhs) {
543        (Expr::Field(name), Expr::Literal(_)) => (name.clone(), rhs.clone()),
544        (Expr::Literal(_), Expr::Field(name)) => (name.clone(), lhs.clone()),
545        _ => return None,
546    };
547    Some(PlanNode::IndexScan {
548        table: table.to_string(),
549        column,
550        key,
551    })
552}
553
554/// Extract a single range bound from a simple inequality predicate.
555/// Returns `(column, lower_bound, upper_bound)` where at most one bound is set.
556fn extract_single_bound(pred: &Expr) -> Option<RangeBound> {
557    let (lhs, op, rhs) = match pred {
558        Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
559        _ => return None,
560    };
561    match op {
562        // .col > literal  →  lower=(literal, exclusive)
563        BinOp::Gt => match (lhs, rhs) {
564            (Expr::Field(name), Expr::Literal(_)) => {
565                Some((name.clone(), Some((rhs.clone(), false)), None))
566            }
567            (Expr::Literal(_), Expr::Field(name)) => {
568                // literal > .col  →  col < literal  →  upper=(literal, exclusive)
569                Some((name.clone(), None, Some((lhs.clone(), false))))
570            }
571            _ => None,
572        },
573        // .col >= literal  →  lower=(literal, inclusive)
574        BinOp::Gte => match (lhs, rhs) {
575            (Expr::Field(name), Expr::Literal(_)) => {
576                Some((name.clone(), Some((rhs.clone(), true)), None))
577            }
578            (Expr::Literal(_), Expr::Field(name)) => {
579                Some((name.clone(), None, Some((lhs.clone(), true))))
580            }
581            _ => None,
582        },
583        // .col < literal  →  upper=(literal, exclusive)
584        BinOp::Lt => match (lhs, rhs) {
585            (Expr::Field(name), Expr::Literal(_)) => {
586                Some((name.clone(), None, Some((rhs.clone(), false))))
587            }
588            (Expr::Literal(_), Expr::Field(name)) => {
589                Some((name.clone(), Some((lhs.clone(), false)), None))
590            }
591            _ => None,
592        },
593        // .col <= literal  →  upper=(literal, inclusive)
594        BinOp::Lte => match (lhs, rhs) {
595            (Expr::Field(name), Expr::Literal(_)) => {
596                Some((name.clone(), None, Some((rhs.clone(), true))))
597            }
598            (Expr::Literal(_), Expr::Field(name)) => {
599                Some((name.clone(), Some((lhs.clone(), true)), None))
600            }
601            _ => None,
602        },
603        _ => None,
604    }
605}
606
607/// If the predicate is an inequality or a conjunction of two inequalities
608/// on the same indexed column, return a RangeScan plan node.
609/// Handles: `.col > lit`, `.col >= lit`, `.col < lit`, `.col <= lit`,
610/// and AND-conjunctions like `.col >= low AND .col <= high` (BETWEEN pattern).
611fn try_extract_range_index_keys(table: &str, pred: &Expr) -> Option<PlanNode> {
612    // Case 1: AND conjunction — try to merge two bounds on the same column.
613    if let Expr::BinaryOp(lhs, BinOp::And, rhs) = pred {
614        if let (Some((col1, s1, e1)), Some((col2, s2, e2))) =
615            (extract_single_bound(lhs), extract_single_bound(rhs))
616        {
617            if col1 == col2 {
618                let start = s1.or(s2);
619                let end = e1.or(e2);
620                if start.is_some() || end.is_some() {
621                    return Some(PlanNode::RangeScan {
622                        table: table.to_string(),
623                        column: col1,
624                        start,
625                        end,
626                    });
627                }
628            }
629        }
630    }
631
632    // Case 2: single inequality.
633    if let Some((col, start, end)) = extract_single_bound(pred) {
634        return Some(PlanNode::RangeScan {
635            table: table.to_string(),
636            column: col,
637            start,
638            end,
639        });
640    }
641
642    None
643}
644
645/// Walk projection fields, replacing every `Expr::Window { .. }` with
646/// `Expr::Field("__win_N")` and collecting the corresponding `WindowDef`
647/// descriptors. Returns the list of window definitions to insert as a
648/// `PlanNode::Window` before the `Project` node.
649fn extract_windows(proj_fields: &mut [ProjectField]) -> Vec<WindowDef> {
650    let mut defs = Vec::new();
651    let mut counter = 0usize;
652    for f in proj_fields.iter_mut() {
653        if let Expr::Window {
654            function,
655            args,
656            partition_by,
657            order_by,
658        } = &f.expr
659        {
660            let output_name = format!("__win_{counter}");
661            defs.push(WindowDef {
662                function: *function,
663                args: args.clone(),
664                partition_by: partition_by.clone(),
665                order_by: order_by
666                    .iter()
667                    .map(|k| SortKey {
668                        field: k.field.clone(),
669                        descending: k.descending,
670                    })
671                    .collect(),
672                output_name: output_name.clone(),
673            });
674            f.expr = Expr::Field(output_name);
675            counter += 1;
676        }
677    }
678    defs
679}
680
681/// Walk projection fields and HAVING expression, replacing every
682/// `Expr::FunctionCall(func, Field(col))` with `Expr::Field("__agg_N")`
683/// and collecting the corresponding `GroupAgg` descriptors. Deduplicates:
684/// if the same (func, field) pair appears in both projection and HAVING,
685/// they share a single `GroupAgg` entry.
686fn extract_aggregates(
687    proj_fields: &mut [ProjectField],
688    having: &mut Option<Expr>,
689) -> Vec<GroupAgg> {
690    let mut aggs: Vec<GroupAgg> = Vec::new();
691    let mut counter = 0usize;
692    for f in proj_fields.iter_mut() {
693        rewrite_agg_expr(&mut f.expr, &mut aggs, &mut counter);
694    }
695    if let Some(h) = having {
696        rewrite_agg_expr(h, &mut aggs, &mut counter);
697    }
698    aggs
699}
700
701fn rewrite_agg_expr(expr: &mut Expr, aggs: &mut Vec<GroupAgg>, counter: &mut usize) {
702    match expr {
703        Expr::FunctionCall(func, inner) => {
704            if let Expr::Field(name) = inner.as_ref() {
705                let output = find_or_insert_agg(aggs, *func, name, counter);
706                *expr = Expr::Field(output);
707            }
708        }
709        Expr::BinaryOp(l, _, r) => {
710            rewrite_agg_expr(l, aggs, counter);
711            rewrite_agg_expr(r, aggs, counter);
712        }
713        Expr::UnaryOp(_, inner) => rewrite_agg_expr(inner, aggs, counter),
714        Expr::Coalesce(l, r) => {
715            rewrite_agg_expr(l, aggs, counter);
716            rewrite_agg_expr(r, aggs, counter);
717        }
718        Expr::InList { expr: e, list, .. } => {
719            rewrite_agg_expr(e, aggs, counter);
720            for item in list {
721                rewrite_agg_expr(item, aggs, counter);
722            }
723        }
724        Expr::InSubquery { expr: e, .. } => {
725            rewrite_agg_expr(e, aggs, counter);
726        }
727        _ => {}
728    }
729}
730
731fn find_or_insert_agg(
732    aggs: &mut Vec<GroupAgg>,
733    func: AggFunc,
734    field: &str,
735    counter: &mut usize,
736) -> String {
737    for existing in aggs.iter() {
738        if existing.function == func && existing.field == field {
739            return existing.output_name.clone();
740        }
741    }
742    let output_name = format!("__agg_{counter}");
743    aggs.push(GroupAgg {
744        function: func,
745        field: field.to_string(),
746        output_name: output_name.clone(),
747    });
748    *counter += 1;
749    output_name
750}
751
752#[cfg(test)]
753mod tests {
754    use super::*;
755    use crate::plan::PlanNode;
756
757    #[test]
758    fn test_plan_simple_scan() {
759        let plan = plan("User").unwrap();
760        assert!(matches!(plan, PlanNode::SeqScan { table } if table == "User"));
761    }
762
763    #[test]
764    fn test_plan_filter() {
765        let plan = plan("User filter .age > 30").unwrap();
766        assert!(matches!(plan, PlanNode::RangeScan { .. }));
767    }
768
769    #[test]
770    fn test_plan_filter_with_projection() {
771        let plan = plan("User filter .age > 30 { name, email }").unwrap();
772        assert!(matches!(plan, PlanNode::Project { .. }));
773    }
774
775    #[test]
776    fn test_plan_insert() {
777        let plan = plan(r#"insert User { name := "Alice", age := 30 }"#).unwrap();
778        assert!(matches!(plan, PlanNode::Insert { .. }));
779    }
780
781    #[test]
782    fn test_plan_order_limit() {
783        let plan = plan("User order .name limit 10").unwrap();
784        match plan {
785            PlanNode::Limit { input, .. } => {
786                assert!(matches!(*input, PlanNode::Sort { .. }));
787            }
788            _ => panic!("expected Limit(Sort(SeqScan))"),
789        }
790    }
791
792    #[test]
793    fn test_plan_count() {
794        let plan = plan("count(User)").unwrap();
795        assert!(matches!(plan, PlanNode::Aggregate { .. }));
796    }
797
798    #[test]
799    fn test_plan_eq_becomes_index_scan() {
800        // `filter .col = literal` should fold into an IndexScan — the executor
801        // falls back to a scan if the column happens to lack an index.
802        let plan = plan("User filter .id = 42").unwrap();
803        match plan {
804            PlanNode::IndexScan { table, column, key } => {
805                assert_eq!(table, "User");
806                assert_eq!(column, "id");
807                assert!(matches!(key, Expr::Literal(Literal::Int(42))));
808            }
809            other => panic!("expected IndexScan, got {other:?}"),
810        }
811    }
812
813    #[test]
814    fn test_plan_eq_reversed_becomes_index_scan() {
815        // Literal-on-the-left form should fold the same way.
816        let plan = plan(r#"User filter "NYC" = .city"#).unwrap();
817        assert!(matches!(plan, PlanNode::IndexScan { .. }));
818    }
819
820    #[test]
821    fn test_plan_non_eq_stays_filter() {
822        // `>` now emits a RangeScan instead of SeqScan+Filter.
823        let plan = plan("User filter .age > 30").unwrap();
824        match plan {
825            PlanNode::RangeScan {
826                column, start, end, ..
827            } => {
828                assert_eq!(column, "age");
829                assert!(start.is_some(), "expected lower bound");
830                assert!(end.is_none(), "expected no upper bound");
831                let (_, inclusive) = start.unwrap();
832                assert!(!inclusive, "expected exclusive lower bound for >");
833            }
834            other => panic!("expected RangeScan, got {other:?}"),
835        }
836    }
837
838    #[test]
839    fn test_plan_index_scan_with_projection() {
840        // Projection on top of an IndexScan should layer correctly.
841        let plan = plan("User filter .id = 1 { .name }").unwrap();
842        match plan {
843            PlanNode::Project { input, .. } => {
844                assert!(matches!(*input, PlanNode::IndexScan { .. }));
845            }
846            other => panic!("expected Project(IndexScan), got {other:?}"),
847        }
848    }
849
850    #[test]
851    fn test_plan_update_by_pk_becomes_index_scan() {
852        // `.id = literal` update should fold to Update(IndexScan), not
853        // Update(Filter(SeqScan)).
854        let plan = plan("User filter .id = 42 update { age := 31 }").unwrap();
855        match plan {
856            PlanNode::Update { input, .. } => {
857                assert!(
858                    matches!(*input, PlanNode::IndexScan { .. }),
859                    "expected Update(IndexScan), got {input:?}"
860                );
861            }
862            other => panic!("expected Update, got {other:?}"),
863        }
864    }
865
866    #[test]
867    fn test_plan_update_range_stays_range_scan() {
868        let plan = plan("User filter .age > 30 update { age := 31 }").unwrap();
869        match plan {
870            PlanNode::Update { input, .. } => {
871                assert!(
872                    matches!(*input, PlanNode::RangeScan { .. }),
873                    "expected Update(RangeScan), got {input:?}"
874                );
875            }
876            other => panic!("expected Update, got {other:?}"),
877        }
878    }
879
880    #[test]
881    fn test_plan_delete_by_pk_becomes_index_scan() {
882        let plan = plan("User filter .id = 7 delete").unwrap();
883        match plan {
884            PlanNode::Delete { input, .. } => {
885                assert!(matches!(*input, PlanNode::IndexScan { .. }));
886            }
887            other => panic!("expected Delete, got {other:?}"),
888        }
889    }
890
891    #[test]
892    fn test_plan_inner_join_builds_nested_loop() {
893        // Mission E1.2: a join query should plan to NestedLoopJoin with
894        // AliasScan leaves on both sides.
895        let plan = plan("User as u join Order as o on u.id = o.user_id").unwrap();
896        match plan {
897            PlanNode::NestedLoopJoin {
898                left,
899                right,
900                on,
901                kind,
902            } => {
903                assert_eq!(kind, JoinKind::Inner);
904                assert!(on.is_some());
905                assert!(matches!(*left, PlanNode::AliasScan { .. }));
906                assert!(matches!(*right, PlanNode::AliasScan { .. }));
907            }
908            other => panic!("expected NestedLoopJoin, got {other:?}"),
909        }
910    }
911
912    #[test]
913    fn test_plan_right_join_rewritten_as_left_with_swapped_inputs() {
914        let plan = plan("User as u right join Order as o on u.id = o.user_id").unwrap();
915        match plan {
916            PlanNode::NestedLoopJoin {
917                left, right, kind, ..
918            } => {
919                assert_eq!(kind, JoinKind::LeftOuter);
920                // Swapped: Order is now on the left, User on the right.
921                match *left {
922                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Order"),
923                    other => panic!("expected AliasScan(Order), got {other:?}"),
924                }
925                match *right {
926                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "User"),
927                    other => panic!("expected AliasScan(User), got {other:?}"),
928                }
929            }
930            other => panic!("expected NestedLoopJoin, got {other:?}"),
931        }
932    }
933
934    #[test]
935    fn test_plan_multi_join_is_left_deep() {
936        // Three sources → two NestedLoopJoins, left-deep.
937        let plan = plan(
938            "User as u join Order as o on u.id = o.user_id \
939             join Product as p on o.product_id = p.id",
940        )
941        .unwrap();
942        match plan {
943            PlanNode::NestedLoopJoin { left, right, .. } => {
944                // Outer (Product) join: right is AliasScan(Product)
945                match *right {
946                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Product"),
947                    other => panic!("expected AliasScan(Product), got {other:?}"),
948                }
949                // Outer.left is inner (Order) NestedLoopJoin
950                assert!(matches!(*left, PlanNode::NestedLoopJoin { .. }));
951            }
952            other => panic!("expected NestedLoopJoin, got {other:?}"),
953        }
954    }
955
956    #[test]
957    fn test_plan_join_with_filter_tail_wraps_filter_on_top() {
958        let plan =
959            plan("User as u join Order as o on u.id = o.user_id filter o.total > 100").unwrap();
960        match plan {
961            PlanNode::Filter { input, .. } => {
962                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
963            }
964            other => panic!("expected Filter(NestedLoopJoin), got {other:?}"),
965        }
966    }
967
968    #[test]
969    fn test_plan_group_by_builds_groupby_node() {
970        let plan = plan("User group .status { .status, n: count(.name) }").unwrap();
971        // Should be Project(GroupBy(SeqScan)).
972        match plan {
973            PlanNode::Project { input, fields } => {
974                assert_eq!(fields.len(), 2);
975                match *input {
976                    PlanNode::GroupBy {
977                        input: inner,
978                        keys,
979                        aggregates,
980                        having,
981                    } => {
982                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
983                        assert_eq!(keys, vec!["status"]);
984                        assert_eq!(aggregates.len(), 1);
985                        assert_eq!(aggregates[0].function, AggFunc::Count);
986                        assert_eq!(aggregates[0].field, "name");
987                        assert!(having.is_none());
988                    }
989                    other => panic!("expected GroupBy, got {other:?}"),
990                }
991            }
992            other => panic!("expected Project, got {other:?}"),
993        }
994    }
995
996    #[test]
997    fn test_plan_group_by_having_rewrites_agg_in_having() {
998        let plan = plan("User group .status having count(.name) > 1 { .status }").unwrap();
999        match plan {
1000            PlanNode::Project { input, .. } => {
1001                match *input {
1002                    PlanNode::GroupBy {
1003                        having, aggregates, ..
1004                    } => {
1005                        // The planner should have extracted count(.name) into
1006                        // aggregates and rewritten the HAVING to reference __agg_0.
1007                        assert_eq!(aggregates.len(), 1);
1008                        assert_eq!(aggregates[0].output_name, "__agg_0");
1009                        let h = having.expect("having should be Some");
1010                        match h {
1011                            Expr::BinaryOp(l, BinOp::Gt, _) => {
1012                                assert!(
1013                                    matches!(*l, Expr::Field(ref name) if name == "__agg_0"),
1014                                    "expected Field(__agg_0), got {l:?}"
1015                                );
1016                            }
1017                            other => panic!("expected BinaryOp, got {other:?}"),
1018                        }
1019                    }
1020                    other => panic!("expected GroupBy, got {other:?}"),
1021                }
1022            }
1023            other => panic!("expected Project, got {other:?}"),
1024        }
1025    }
1026
1027    #[test]
1028    fn test_plan_window_inserts_window_node_before_project() {
1029        let plan = plan("User { .name, rn: row_number() over (order .age) }").unwrap();
1030        // Expected shape: Project(Window(SeqScan))
1031        match plan {
1032            PlanNode::Project { input, fields } => {
1033                assert_eq!(fields.len(), 2);
1034                // The window expr should have been replaced with Field("__win_0")
1035                assert!(
1036                    matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"),
1037                    "expected Field(__win_0), got {:?}",
1038                    fields[1].expr
1039                );
1040                match *input {
1041                    PlanNode::Window {
1042                        input: inner,
1043                        windows,
1044                    } => {
1045                        assert_eq!(windows.len(), 1);
1046                        assert_eq!(windows[0].output_name, "__win_0");
1047                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
1048                    }
1049                    other => panic!("expected Window, got {other:?}"),
1050                }
1051            }
1052            other => panic!("expected Project, got {other:?}"),
1053        }
1054    }
1055
1056    #[test]
1057    fn test_plan_multiple_windows() {
1058        let plan = plan(
1059            "User { .name, rn: row_number() over (order .age), s: sum(.salary) over (partition .dept order .salary) }"
1060        ).unwrap();
1061        match plan {
1062            PlanNode::Project { input, fields } => {
1063                assert_eq!(fields.len(), 3);
1064                assert!(matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"));
1065                assert!(matches!(&fields[2].expr, Expr::Field(name) if name == "__win_1"));
1066                match *input {
1067                    PlanNode::Window { windows, .. } => {
1068                        assert_eq!(windows.len(), 2);
1069                        assert_eq!(windows[0].output_name, "__win_0");
1070                        assert_eq!(windows[1].output_name, "__win_1");
1071                    }
1072                    other => panic!("expected Window, got {other:?}"),
1073                }
1074            }
1075            other => panic!("expected Project, got {other:?}"),
1076        }
1077    }
1078
1079    #[test]
1080    fn test_plan_no_window_without_over() {
1081        // Plain aggregate in projection should not create a Window node.
1082        let plan = plan("User group .dept { .dept, total: sum(.salary) }").unwrap();
1083        match plan {
1084            PlanNode::Project { input, .. } => {
1085                // Input should be GroupBy, not Window.
1086                assert!(
1087                    matches!(*input, PlanNode::GroupBy { .. }),
1088                    "expected GroupBy under Project, got {:?}",
1089                    input
1090                );
1091            }
1092            other => panic!("expected Project, got {other:?}"),
1093        }
1094    }
1095
1096    #[test]
1097    fn test_plan_explain_wraps_inner() {
1098        let plan = plan("explain User filter .age > 30").unwrap();
1099        match plan {
1100            PlanNode::Explain { input } => {
1101                assert!(
1102                    matches!(*input, PlanNode::RangeScan { .. }),
1103                    "expected Explain(RangeScan), got {:?}",
1104                    input
1105                );
1106            }
1107            other => panic!("expected Explain, got {other:?}"),
1108        }
1109    }
1110
1111    #[test]
1112    fn test_plan_explain_simple_scan() {
1113        let plan = plan("explain User").unwrap();
1114        match plan {
1115            PlanNode::Explain { input } => {
1116                assert!(matches!(*input, PlanNode::SeqScan { .. }));
1117            }
1118            other => panic!("expected Explain(SeqScan), got {other:?}"),
1119        }
1120    }
1121
1122    #[test]
1123    fn test_plan_explain_join() {
1124        let plan = plan("explain User as u join Order as o on u.id = o.user_id").unwrap();
1125        match plan {
1126            PlanNode::Explain { input } => {
1127                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
1128            }
1129            other => panic!("expected Explain(NestedLoopJoin), got {other:?}"),
1130        }
1131    }
1132}