Skip to main content

powdb_query/
planner.rs

1use crate::ast::*;
2use crate::parser::{parse, ParseError};
3use crate::plan::*;
4
5/// (column_name, lower_bound, upper_bound) — used by range-index extraction.
6type RangeBound = (String, Option<(Expr, bool)>, Option<(Expr, bool)>);
7
8/// Plan-phase error — wraps ParseError for the full lex→parse→plan chain.
9#[derive(Debug)]
10pub enum PlanError {
11    /// Error originated in the parser (or lexer, via ParseError::Lex).
12    Parse(ParseError),
13}
14
15impl PlanError {
16    /// Convenience: human-readable message for any variant.
17    pub fn message(&self) -> String {
18        self.to_string()
19    }
20}
21
22impl std::fmt::Display for PlanError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            Self::Parse(e) => write!(f, "{e}"),
26        }
27    }
28}
29
30impl std::error::Error for PlanError {
31    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32        match self {
33            Self::Parse(e) => Some(e),
34        }
35    }
36}
37
38impl From<ParseError> for PlanError {
39    fn from(e: ParseError) -> Self {
40        PlanError::Parse(e)
41    }
42}
43
44pub fn plan(input: &str) -> Result<PlanNode, PlanError> {
45    let stmt = parse(input)?;
46    plan_statement(stmt)
47}
48
49pub fn plan_statement(stmt: Statement) -> Result<PlanNode, PlanError> {
50    match stmt {
51        Statement::Query(q) => plan_query(q),
52        Statement::Insert(ins) => plan_insert(ins),
53        Statement::UpdateQuery(upd) => plan_update(upd),
54        Statement::DeleteQuery(del) => plan_delete(del),
55        Statement::CreateType(ct) => plan_create_type(ct),
56        Statement::AlterTable(at) => Ok(PlanNode::AlterTable {
57            table: at.table,
58            action: at.action,
59        }),
60        Statement::DropTable(dt) => Ok(PlanNode::DropTable { name: dt.table }),
61        Statement::CreateView(cv) => Ok(PlanNode::CreateView {
62            name: cv.name,
63            query_text: cv.query_text,
64        }),
65        Statement::RefreshView(rv) => Ok(PlanNode::RefreshView { name: rv.name }),
66        Statement::DropView(dv) => Ok(PlanNode::DropView { name: dv.name }),
67        Statement::Union(u) => {
68            let left = plan_statement(*u.left)?;
69            let right = plan_statement(*u.right)?;
70            Ok(PlanNode::Union {
71                left: Box::new(left),
72                right: Box::new(right),
73                all: u.all,
74            })
75        }
76        Statement::Upsert(ups) => plan_upsert(ups),
77        Statement::Begin => Ok(PlanNode::Begin),
78        Statement::Commit => Ok(PlanNode::Commit),
79        Statement::Rollback => Ok(PlanNode::Rollback),
80        Statement::Explain(inner) => {
81            let inner_plan = plan_statement(*inner)?;
82            Ok(PlanNode::Explain {
83                input: Box::new(inner_plan),
84            })
85        }
86    }
87}
88
89fn plan_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
90    // Mission E1.2: if the query has joins, build a left-deep nested-loop
91    // plan. Correctness first — hash-join optimization is E1.3. We also
92    // don't try to fold an IndexScan under a joined query yet (the
93    // leaf-level fast paths all match on `PlanNode::SeqScan { .. }`
94    // literally, so mixing them into a join plan would silently break).
95    if !q.joins.is_empty() {
96        return plan_joined_query(q);
97    }
98    // Try to fold `filter .col = literal` into an IndexScan. The executor
99    // decides at run time whether the column actually has an index — if not,
100    // it transparently falls back to a sequential scan with the same predicate,
101    // so this rewrite is always safe.
102    //
103    // We only rewrite the *simple* eq case: `filter .col = literal`. Conjunctions
104    // like `filter .col = 1 and .other > 5` fall through to SeqScan + Filter.
105    // Extending this to split conjunctions is a future optimization.
106    let (source, filter) = match q.filter {
107        Some(pred) => match try_extract_eq_index_key(&q.source, &pred) {
108            Some(index_scan) => (index_scan, None),
109            None => match try_extract_range_index_keys(&q.source, &pred) {
110                Some(range_scan) => (range_scan, None),
111                None => (
112                    PlanNode::SeqScan {
113                        table: q.source.clone(),
114                    },
115                    Some(pred),
116                ),
117            },
118        },
119        None => (
120            PlanNode::SeqScan {
121                table: q.source.clone(),
122            },
123            None,
124        ),
125    };
126    let mut node = source;
127
128    if let Some(pred) = filter {
129        node = PlanNode::Filter {
130            input: Box::new(node),
131            predicate: pred,
132        };
133    }
134
135    // Mission E2b: GROUP BY path — insert GroupBy + Project before
136    // order/limit/offset/distinct.
137    if let Some(group) = q.group_by {
138        let mut proj_fields: Vec<ProjectField> = q
139            .projection
140            .map(|proj| {
141                proj.into_iter()
142                    .map(|pf| ProjectField {
143                        alias: pf.alias,
144                        expr: pf.expr,
145                    })
146                    .collect()
147            })
148            .unwrap_or_default();
149        let mut having = group.having;
150        let aggregates = extract_aggregates(&mut proj_fields, &mut having);
151
152        node = PlanNode::GroupBy {
153            input: Box::new(node),
154            keys: group.keys,
155            aggregates,
156            having,
157        };
158
159        if !proj_fields.is_empty() {
160            node = PlanNode::Project {
161                input: Box::new(node),
162                fields: proj_fields,
163            };
164        }
165
166        if let Some(order) = q.order {
167            node = PlanNode::Sort {
168                input: Box::new(node),
169                keys: order
170                    .keys
171                    .into_iter()
172                    .map(|k| SortKey {
173                        field: k.field,
174                        descending: k.descending,
175                    })
176                    .collect(),
177            };
178        }
179        // Offset must be applied *before* Limit: skip M rows, then take N.
180        // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
181        // and Limit wraps it (outer).
182        if let Some(off) = q.offset {
183            node = PlanNode::Offset {
184                input: Box::new(node),
185                count: off,
186            };
187        }
188        if let Some(lim) = q.limit {
189            node = PlanNode::Limit {
190                input: Box::new(node),
191                count: lim,
192            };
193        }
194        if q.distinct {
195            node = PlanNode::Distinct {
196                input: Box::new(node),
197            };
198        }
199        return Ok(node);
200    }
201
202    if let Some(order) = q.order {
203        node = PlanNode::Sort {
204            input: Box::new(node),
205            keys: order
206                .keys
207                .into_iter()
208                .map(|k| SortKey {
209                    field: k.field,
210                    descending: k.descending,
211                })
212                .collect(),
213        };
214    }
215
216    // Offset must be applied *before* Limit: skip M rows, then take N.
217    // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
218    // and Limit wraps it (outer).
219    if let Some(off) = q.offset {
220        node = PlanNode::Offset {
221            input: Box::new(node),
222            count: off,
223        };
224    }
225
226    if let Some(lim) = q.limit {
227        node = PlanNode::Limit {
228            input: Box::new(node),
229            count: lim,
230        };
231    }
232
233    if let Some(proj) = q.projection {
234        let mut fields: Vec<ProjectField> = proj
235            .into_iter()
236            .map(|pf| ProjectField {
237                alias: pf.alias,
238                expr: pf.expr,
239            })
240            .collect();
241        let windows = extract_windows(&mut fields);
242        if !windows.is_empty() {
243            node = PlanNode::Window {
244                input: Box::new(node),
245                windows,
246            };
247        }
248        node = PlanNode::Project {
249            input: Box::new(node),
250            fields,
251        };
252    }
253
254    if q.distinct {
255        node = PlanNode::Distinct {
256            input: Box::new(node),
257        };
258    }
259
260    if let Some(agg) = q.aggregation {
261        node = PlanNode::Aggregate {
262            input: Box::new(node),
263            function: agg.function,
264            field: agg.field,
265        };
266    }
267
268    Ok(node)
269}
270
271/// Build a left-deep nested-loop join plan for a query with 1+ join clauses.
272///
273/// The plan shape for `T1 as a [inner|left|cross] join T2 as b on <pred> ...` is:
274///
275///   Project? (optional, from q.projection)
276///   └─ Offset? / Limit? / Sort?
277///      └─ Filter? (the top-level q.filter, using qualified columns)
278///         └─ NestedLoopJoin { kind, on }
279///            ├─ AliasScan { T1, a }
280///            └─ AliasScan { T2, b }
281///
282/// Multi-join chains extend left-deep: a third join adds a second
283/// `NestedLoopJoin` on top, with the first join's output as its `left`.
284///
285/// Aliases default to the source table name when the query didn't write
286/// `as <name>` explicitly — that way users can always write `T.field`
287/// without being forced to alias every source.
288///
289/// RightOuter is rewritten into LeftOuter with inputs swapped — the two
290/// differ only in which side survives non-matching rows, and swapping
291/// inputs lets the executor ship a single LeftOuter path.
292fn plan_joined_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
293    let primary_alias = q.alias.clone().unwrap_or_else(|| q.source.clone());
294    let mut node = PlanNode::AliasScan {
295        table: q.source.clone(),
296        alias: primary_alias,
297    };
298
299    for join in q.joins {
300        let right_alias = join.alias.unwrap_or_else(|| join.source.clone());
301        let right = PlanNode::AliasScan {
302            table: join.source,
303            alias: right_alias,
304        };
305        match join.kind {
306            JoinKind::Inner | JoinKind::LeftOuter | JoinKind::Cross => {
307                node = PlanNode::NestedLoopJoin {
308                    left: Box::new(node),
309                    right: Box::new(right),
310                    on: join.on,
311                    kind: join.kind,
312                };
313            }
314            JoinKind::RightOuter => {
315                // `a RIGHT OUTER JOIN b ON <p>` ≡ `b LEFT OUTER JOIN a ON <p>`.
316                node = PlanNode::NestedLoopJoin {
317                    left: Box::new(right),
318                    right: Box::new(node),
319                    on: join.on,
320                    kind: JoinKind::LeftOuter,
321                };
322            }
323        }
324    }
325
326    if let Some(pred) = q.filter {
327        node = PlanNode::Filter {
328            input: Box::new(node),
329            predicate: pred,
330        };
331    }
332
333    if let Some(order) = q.order {
334        node = PlanNode::Sort {
335            input: Box::new(node),
336            keys: order
337                .keys
338                .into_iter()
339                .map(|k| SortKey {
340                    field: k.field,
341                    descending: k.descending,
342                })
343                .collect(),
344        };
345    }
346
347    // Offset must be applied *before* Limit: skip M rows, then take N.
348    // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
349    // and Limit wraps it (outer).
350    if let Some(off) = q.offset {
351        node = PlanNode::Offset {
352            input: Box::new(node),
353            count: off,
354        };
355    }
356
357    if let Some(lim) = q.limit {
358        node = PlanNode::Limit {
359            input: Box::new(node),
360            count: lim,
361        };
362    }
363
364    // Mission E2b: GROUP BY path for joined queries.
365    if let Some(group) = q.group_by {
366        let mut proj_fields: Vec<ProjectField> = q
367            .projection
368            .map(|proj| {
369                proj.into_iter()
370                    .map(|pf| ProjectField {
371                        alias: pf.alias,
372                        expr: pf.expr,
373                    })
374                    .collect()
375            })
376            .unwrap_or_default();
377        let mut having = group.having;
378        let aggregates = extract_aggregates(&mut proj_fields, &mut having);
379
380        node = PlanNode::GroupBy {
381            input: Box::new(node),
382            keys: group.keys,
383            aggregates,
384            having,
385        };
386
387        if !proj_fields.is_empty() {
388            node = PlanNode::Project {
389                input: Box::new(node),
390                fields: proj_fields,
391            };
392        }
393        if q.distinct {
394            node = PlanNode::Distinct {
395                input: Box::new(node),
396            };
397        }
398        return Ok(node);
399    }
400
401    if let Some(proj) = q.projection {
402        let mut fields: Vec<ProjectField> = proj
403            .into_iter()
404            .map(|pf| ProjectField {
405                alias: pf.alias,
406                expr: pf.expr,
407            })
408            .collect();
409        let windows = extract_windows(&mut fields);
410        if !windows.is_empty() {
411            node = PlanNode::Window {
412                input: Box::new(node),
413                windows,
414            };
415        }
416        node = PlanNode::Project {
417            input: Box::new(node),
418            fields,
419        };
420    }
421
422    if q.distinct {
423        node = PlanNode::Distinct {
424            input: Box::new(node),
425        };
426    }
427
428    if let Some(agg) = q.aggregation {
429        node = PlanNode::Aggregate {
430            input: Box::new(node),
431            function: agg.function,
432            field: agg.field,
433        };
434    }
435
436    Ok(node)
437}
438
439fn plan_insert(ins: InsertExpr) -> Result<PlanNode, PlanError> {
440    Ok(PlanNode::Insert {
441        table: ins.target,
442        assignments: ins.assignments,
443    })
444}
445
446fn plan_update(upd: UpdateExpr) -> Result<PlanNode, PlanError> {
447    // Mirror the read-side IndexScan fold: when the update filter is a simple
448    // `.col = literal`, emit `Update(IndexScan)` so the executor's index-lookup
449    // mutation fast path fires. The executor falls back to a scan if the
450    // column happens to lack an index, so this is always safe.
451    let source = match upd.filter {
452        Some(pred) => match try_extract_eq_index_key(&upd.source, &pred) {
453            Some(index_scan) => index_scan,
454            None => match try_extract_range_index_keys(&upd.source, &pred) {
455                Some(range_scan) => range_scan,
456                None => PlanNode::Filter {
457                    input: Box::new(PlanNode::SeqScan {
458                        table: upd.source.clone(),
459                    }),
460                    predicate: pred,
461                },
462            },
463        },
464        None => PlanNode::SeqScan {
465            table: upd.source.clone(),
466        },
467    };
468    Ok(PlanNode::Update {
469        input: Box::new(source),
470        table: upd.source,
471        assignments: upd.assignments,
472    })
473}
474
475fn plan_delete(del: DeleteExpr) -> Result<PlanNode, PlanError> {
476    let source = match del.filter {
477        Some(pred) => match try_extract_eq_index_key(&del.source, &pred) {
478            Some(index_scan) => index_scan,
479            None => match try_extract_range_index_keys(&del.source, &pred) {
480                Some(range_scan) => range_scan,
481                None => PlanNode::Filter {
482                    input: Box::new(PlanNode::SeqScan {
483                        table: del.source.clone(),
484                    }),
485                    predicate: pred,
486                },
487            },
488        },
489        None => PlanNode::SeqScan {
490            table: del.source.clone(),
491        },
492    };
493    Ok(PlanNode::Delete {
494        input: Box::new(source),
495        table: del.source,
496    })
497}
498
499fn plan_upsert(ups: UpsertExpr) -> Result<PlanNode, PlanError> {
500    Ok(PlanNode::Upsert {
501        table: ups.target,
502        key_column: ups.key_column,
503        assignments: ups.assignments,
504        on_conflict: ups.on_conflict,
505    })
506}
507
508fn plan_create_type(ct: CreateTypeExpr) -> Result<PlanNode, PlanError> {
509    let fields = ct
510        .fields
511        .into_iter()
512        .map(|f| (f.name, f.type_name, f.required))
513        .collect();
514    Ok(PlanNode::CreateTable {
515        name: ct.name,
516        fields,
517    })
518}
519
520/// If the predicate is a simple `.field = literal` (or `literal = .field`),
521/// return a corresponding IndexScan plan node. Otherwise return None so the
522/// caller can fall through to SeqScan + Filter.
523///
524/// The executor decides at run time whether the named column actually has a
525/// B-tree index — if not, IndexScan transparently falls back to a scan +
526/// equality filter on that column. That means this rewrite is always safe
527/// regardless of schema/index state; it just unlocks the fast path when an
528/// index happens to exist.
529fn try_extract_eq_index_key(table: &str, pred: &Expr) -> Option<PlanNode> {
530    let (lhs, op, rhs) = match pred {
531        Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
532        _ => return None,
533    };
534    if op != BinOp::Eq {
535        return None;
536    }
537    let (column, key) = match (lhs, rhs) {
538        (Expr::Field(name), Expr::Literal(_)) => (name.clone(), rhs.clone()),
539        (Expr::Literal(_), Expr::Field(name)) => (name.clone(), lhs.clone()),
540        _ => return None,
541    };
542    Some(PlanNode::IndexScan {
543        table: table.to_string(),
544        column,
545        key,
546    })
547}
548
549/// Extract a single range bound from a simple inequality predicate.
550/// Returns `(column, lower_bound, upper_bound)` where at most one bound is set.
551fn extract_single_bound(pred: &Expr) -> Option<RangeBound> {
552    let (lhs, op, rhs) = match pred {
553        Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
554        _ => return None,
555    };
556    match op {
557        // .col > literal  →  lower=(literal, exclusive)
558        BinOp::Gt => match (lhs, rhs) {
559            (Expr::Field(name), Expr::Literal(_)) => {
560                Some((name.clone(), Some((rhs.clone(), false)), None))
561            }
562            (Expr::Literal(_), Expr::Field(name)) => {
563                // literal > .col  →  col < literal  →  upper=(literal, exclusive)
564                Some((name.clone(), None, Some((lhs.clone(), false))))
565            }
566            _ => None,
567        },
568        // .col >= literal  →  lower=(literal, inclusive)
569        BinOp::Gte => match (lhs, rhs) {
570            (Expr::Field(name), Expr::Literal(_)) => {
571                Some((name.clone(), Some((rhs.clone(), true)), None))
572            }
573            (Expr::Literal(_), Expr::Field(name)) => {
574                Some((name.clone(), None, Some((lhs.clone(), true))))
575            }
576            _ => None,
577        },
578        // .col < literal  →  upper=(literal, exclusive)
579        BinOp::Lt => match (lhs, rhs) {
580            (Expr::Field(name), Expr::Literal(_)) => {
581                Some((name.clone(), None, Some((rhs.clone(), false))))
582            }
583            (Expr::Literal(_), Expr::Field(name)) => {
584                Some((name.clone(), Some((lhs.clone(), false)), None))
585            }
586            _ => None,
587        },
588        // .col <= literal  →  upper=(literal, inclusive)
589        BinOp::Lte => match (lhs, rhs) {
590            (Expr::Field(name), Expr::Literal(_)) => {
591                Some((name.clone(), None, Some((rhs.clone(), true))))
592            }
593            (Expr::Literal(_), Expr::Field(name)) => {
594                Some((name.clone(), Some((lhs.clone(), true)), None))
595            }
596            _ => None,
597        },
598        _ => None,
599    }
600}
601
602/// If the predicate is an inequality or a conjunction of two inequalities
603/// on the same indexed column, return a RangeScan plan node.
604/// Handles: `.col > lit`, `.col >= lit`, `.col < lit`, `.col <= lit`,
605/// and AND-conjunctions like `.col >= low AND .col <= high` (BETWEEN pattern).
606fn try_extract_range_index_keys(table: &str, pred: &Expr) -> Option<PlanNode> {
607    // Case 1: AND conjunction — try to merge two bounds on the same column.
608    if let Expr::BinaryOp(lhs, BinOp::And, rhs) = pred {
609        if let (Some((col1, s1, e1)), Some((col2, s2, e2))) =
610            (extract_single_bound(lhs), extract_single_bound(rhs))
611        {
612            if col1 == col2 {
613                let start = s1.or(s2);
614                let end = e1.or(e2);
615                if start.is_some() || end.is_some() {
616                    return Some(PlanNode::RangeScan {
617                        table: table.to_string(),
618                        column: col1,
619                        start,
620                        end,
621                    });
622                }
623            }
624        }
625    }
626
627    // Case 2: single inequality.
628    if let Some((col, start, end)) = extract_single_bound(pred) {
629        return Some(PlanNode::RangeScan {
630            table: table.to_string(),
631            column: col,
632            start,
633            end,
634        });
635    }
636
637    None
638}
639
640/// Walk projection fields, replacing every `Expr::Window { .. }` with
641/// `Expr::Field("__win_N")` and collecting the corresponding `WindowDef`
642/// descriptors. Returns the list of window definitions to insert as a
643/// `PlanNode::Window` before the `Project` node.
644fn extract_windows(proj_fields: &mut [ProjectField]) -> Vec<WindowDef> {
645    let mut defs = Vec::new();
646    let mut counter = 0usize;
647    for f in proj_fields.iter_mut() {
648        if let Expr::Window {
649            function,
650            args,
651            partition_by,
652            order_by,
653        } = &f.expr
654        {
655            let output_name = format!("__win_{counter}");
656            defs.push(WindowDef {
657                function: *function,
658                args: args.clone(),
659                partition_by: partition_by.clone(),
660                order_by: order_by
661                    .iter()
662                    .map(|k| SortKey {
663                        field: k.field.clone(),
664                        descending: k.descending,
665                    })
666                    .collect(),
667                output_name: output_name.clone(),
668            });
669            f.expr = Expr::Field(output_name);
670            counter += 1;
671        }
672    }
673    defs
674}
675
676/// Walk projection fields and HAVING expression, replacing every
677/// `Expr::FunctionCall(func, Field(col))` with `Expr::Field("__agg_N")`
678/// and collecting the corresponding `GroupAgg` descriptors. Deduplicates:
679/// if the same (func, field) pair appears in both projection and HAVING,
680/// they share a single `GroupAgg` entry.
681fn extract_aggregates(
682    proj_fields: &mut [ProjectField],
683    having: &mut Option<Expr>,
684) -> Vec<GroupAgg> {
685    let mut aggs: Vec<GroupAgg> = Vec::new();
686    let mut counter = 0usize;
687    for f in proj_fields.iter_mut() {
688        rewrite_agg_expr(&mut f.expr, &mut aggs, &mut counter);
689    }
690    if let Some(h) = having {
691        rewrite_agg_expr(h, &mut aggs, &mut counter);
692    }
693    aggs
694}
695
696fn rewrite_agg_expr(expr: &mut Expr, aggs: &mut Vec<GroupAgg>, counter: &mut usize) {
697    match expr {
698        Expr::FunctionCall(func, inner) => {
699            if let Expr::Field(name) = inner.as_ref() {
700                let output = find_or_insert_agg(aggs, *func, name, counter);
701                *expr = Expr::Field(output);
702            }
703        }
704        Expr::BinaryOp(l, _, r) => {
705            rewrite_agg_expr(l, aggs, counter);
706            rewrite_agg_expr(r, aggs, counter);
707        }
708        Expr::UnaryOp(_, inner) => rewrite_agg_expr(inner, aggs, counter),
709        Expr::Coalesce(l, r) => {
710            rewrite_agg_expr(l, aggs, counter);
711            rewrite_agg_expr(r, aggs, counter);
712        }
713        Expr::InList { expr: e, list, .. } => {
714            rewrite_agg_expr(e, aggs, counter);
715            for item in list {
716                rewrite_agg_expr(item, aggs, counter);
717            }
718        }
719        Expr::InSubquery { expr: e, .. } => {
720            rewrite_agg_expr(e, aggs, counter);
721        }
722        _ => {}
723    }
724}
725
726fn find_or_insert_agg(
727    aggs: &mut Vec<GroupAgg>,
728    func: AggFunc,
729    field: &str,
730    counter: &mut usize,
731) -> String {
732    for existing in aggs.iter() {
733        if existing.function == func && existing.field == field {
734            return existing.output_name.clone();
735        }
736    }
737    let output_name = format!("__agg_{counter}");
738    aggs.push(GroupAgg {
739        function: func,
740        field: field.to_string(),
741        output_name: output_name.clone(),
742    });
743    *counter += 1;
744    output_name
745}
746
747#[cfg(test)]
748mod tests {
749    use super::*;
750    use crate::plan::PlanNode;
751
752    #[test]
753    fn test_plan_simple_scan() {
754        let plan = plan("User").unwrap();
755        assert!(matches!(plan, PlanNode::SeqScan { table } if table == "User"));
756    }
757
758    #[test]
759    fn test_plan_filter() {
760        let plan = plan("User filter .age > 30").unwrap();
761        assert!(matches!(plan, PlanNode::RangeScan { .. }));
762    }
763
764    #[test]
765    fn test_plan_filter_with_projection() {
766        let plan = plan("User filter .age > 30 { name, email }").unwrap();
767        assert!(matches!(plan, PlanNode::Project { .. }));
768    }
769
770    #[test]
771    fn test_plan_insert() {
772        let plan = plan(r#"insert User { name := "Alice", age := 30 }"#).unwrap();
773        assert!(matches!(plan, PlanNode::Insert { .. }));
774    }
775
776    #[test]
777    fn test_plan_order_limit() {
778        let plan = plan("User order .name limit 10").unwrap();
779        match plan {
780            PlanNode::Limit { input, .. } => {
781                assert!(matches!(*input, PlanNode::Sort { .. }));
782            }
783            _ => panic!("expected Limit(Sort(SeqScan))"),
784        }
785    }
786
787    #[test]
788    fn test_plan_count() {
789        let plan = plan("count(User)").unwrap();
790        assert!(matches!(plan, PlanNode::Aggregate { .. }));
791    }
792
793    #[test]
794    fn test_plan_eq_becomes_index_scan() {
795        // `filter .col = literal` should fold into an IndexScan — the executor
796        // falls back to a scan if the column happens to lack an index.
797        let plan = plan("User filter .id = 42").unwrap();
798        match plan {
799            PlanNode::IndexScan { table, column, key } => {
800                assert_eq!(table, "User");
801                assert_eq!(column, "id");
802                assert!(matches!(key, Expr::Literal(Literal::Int(42))));
803            }
804            other => panic!("expected IndexScan, got {other:?}"),
805        }
806    }
807
808    #[test]
809    fn test_plan_eq_reversed_becomes_index_scan() {
810        // Literal-on-the-left form should fold the same way.
811        let plan = plan(r#"User filter "NYC" = .city"#).unwrap();
812        assert!(matches!(plan, PlanNode::IndexScan { .. }));
813    }
814
815    #[test]
816    fn test_plan_non_eq_stays_filter() {
817        // `>` now emits a RangeScan instead of SeqScan+Filter.
818        let plan = plan("User filter .age > 30").unwrap();
819        match plan {
820            PlanNode::RangeScan {
821                column, start, end, ..
822            } => {
823                assert_eq!(column, "age");
824                assert!(start.is_some(), "expected lower bound");
825                assert!(end.is_none(), "expected no upper bound");
826                let (_, inclusive) = start.unwrap();
827                assert!(!inclusive, "expected exclusive lower bound for >");
828            }
829            other => panic!("expected RangeScan, got {other:?}"),
830        }
831    }
832
833    #[test]
834    fn test_plan_index_scan_with_projection() {
835        // Projection on top of an IndexScan should layer correctly.
836        let plan = plan("User filter .id = 1 { .name }").unwrap();
837        match plan {
838            PlanNode::Project { input, .. } => {
839                assert!(matches!(*input, PlanNode::IndexScan { .. }));
840            }
841            other => panic!("expected Project(IndexScan), got {other:?}"),
842        }
843    }
844
845    #[test]
846    fn test_plan_update_by_pk_becomes_index_scan() {
847        // `.id = literal` update should fold to Update(IndexScan), not
848        // Update(Filter(SeqScan)).
849        let plan = plan("User filter .id = 42 update { age := 31 }").unwrap();
850        match plan {
851            PlanNode::Update { input, .. } => {
852                assert!(
853                    matches!(*input, PlanNode::IndexScan { .. }),
854                    "expected Update(IndexScan), got {input:?}"
855                );
856            }
857            other => panic!("expected Update, got {other:?}"),
858        }
859    }
860
861    #[test]
862    fn test_plan_update_range_stays_range_scan() {
863        let plan = plan("User filter .age > 30 update { age := 31 }").unwrap();
864        match plan {
865            PlanNode::Update { input, .. } => {
866                assert!(
867                    matches!(*input, PlanNode::RangeScan { .. }),
868                    "expected Update(RangeScan), got {input:?}"
869                );
870            }
871            other => panic!("expected Update, got {other:?}"),
872        }
873    }
874
875    #[test]
876    fn test_plan_delete_by_pk_becomes_index_scan() {
877        let plan = plan("User filter .id = 7 delete").unwrap();
878        match plan {
879            PlanNode::Delete { input, .. } => {
880                assert!(matches!(*input, PlanNode::IndexScan { .. }));
881            }
882            other => panic!("expected Delete, got {other:?}"),
883        }
884    }
885
886    #[test]
887    fn test_plan_inner_join_builds_nested_loop() {
888        // Mission E1.2: a join query should plan to NestedLoopJoin with
889        // AliasScan leaves on both sides.
890        let plan = plan("User as u join Order as o on u.id = o.user_id").unwrap();
891        match plan {
892            PlanNode::NestedLoopJoin {
893                left,
894                right,
895                on,
896                kind,
897            } => {
898                assert_eq!(kind, JoinKind::Inner);
899                assert!(on.is_some());
900                assert!(matches!(*left, PlanNode::AliasScan { .. }));
901                assert!(matches!(*right, PlanNode::AliasScan { .. }));
902            }
903            other => panic!("expected NestedLoopJoin, got {other:?}"),
904        }
905    }
906
907    #[test]
908    fn test_plan_right_join_rewritten_as_left_with_swapped_inputs() {
909        let plan = plan("User as u right join Order as o on u.id = o.user_id").unwrap();
910        match plan {
911            PlanNode::NestedLoopJoin {
912                left, right, kind, ..
913            } => {
914                assert_eq!(kind, JoinKind::LeftOuter);
915                // Swapped: Order is now on the left, User on the right.
916                match *left {
917                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Order"),
918                    other => panic!("expected AliasScan(Order), got {other:?}"),
919                }
920                match *right {
921                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "User"),
922                    other => panic!("expected AliasScan(User), got {other:?}"),
923                }
924            }
925            other => panic!("expected NestedLoopJoin, got {other:?}"),
926        }
927    }
928
929    #[test]
930    fn test_plan_multi_join_is_left_deep() {
931        // Three sources → two NestedLoopJoins, left-deep.
932        let plan = plan(
933            "User as u join Order as o on u.id = o.user_id \
934             join Product as p on o.product_id = p.id",
935        )
936        .unwrap();
937        match plan {
938            PlanNode::NestedLoopJoin { left, right, .. } => {
939                // Outer (Product) join: right is AliasScan(Product)
940                match *right {
941                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Product"),
942                    other => panic!("expected AliasScan(Product), got {other:?}"),
943                }
944                // Outer.left is inner (Order) NestedLoopJoin
945                assert!(matches!(*left, PlanNode::NestedLoopJoin { .. }));
946            }
947            other => panic!("expected NestedLoopJoin, got {other:?}"),
948        }
949    }
950
951    #[test]
952    fn test_plan_join_with_filter_tail_wraps_filter_on_top() {
953        let plan =
954            plan("User as u join Order as o on u.id = o.user_id filter o.total > 100").unwrap();
955        match plan {
956            PlanNode::Filter { input, .. } => {
957                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
958            }
959            other => panic!("expected Filter(NestedLoopJoin), got {other:?}"),
960        }
961    }
962
963    #[test]
964    fn test_plan_group_by_builds_groupby_node() {
965        let plan = plan("User group .status { .status, n: count(.name) }").unwrap();
966        // Should be Project(GroupBy(SeqScan)).
967        match plan {
968            PlanNode::Project { input, fields } => {
969                assert_eq!(fields.len(), 2);
970                match *input {
971                    PlanNode::GroupBy {
972                        input: inner,
973                        keys,
974                        aggregates,
975                        having,
976                    } => {
977                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
978                        assert_eq!(keys, vec!["status"]);
979                        assert_eq!(aggregates.len(), 1);
980                        assert_eq!(aggregates[0].function, AggFunc::Count);
981                        assert_eq!(aggregates[0].field, "name");
982                        assert!(having.is_none());
983                    }
984                    other => panic!("expected GroupBy, got {other:?}"),
985                }
986            }
987            other => panic!("expected Project, got {other:?}"),
988        }
989    }
990
991    #[test]
992    fn test_plan_group_by_having_rewrites_agg_in_having() {
993        let plan = plan("User group .status having count(.name) > 1 { .status }").unwrap();
994        match plan {
995            PlanNode::Project { input, .. } => {
996                match *input {
997                    PlanNode::GroupBy {
998                        having, aggregates, ..
999                    } => {
1000                        // The planner should have extracted count(.name) into
1001                        // aggregates and rewritten the HAVING to reference __agg_0.
1002                        assert_eq!(aggregates.len(), 1);
1003                        assert_eq!(aggregates[0].output_name, "__agg_0");
1004                        let h = having.expect("having should be Some");
1005                        match h {
1006                            Expr::BinaryOp(l, BinOp::Gt, _) => {
1007                                assert!(
1008                                    matches!(*l, Expr::Field(ref name) if name == "__agg_0"),
1009                                    "expected Field(__agg_0), got {l:?}"
1010                                );
1011                            }
1012                            other => panic!("expected BinaryOp, got {other:?}"),
1013                        }
1014                    }
1015                    other => panic!("expected GroupBy, got {other:?}"),
1016                }
1017            }
1018            other => panic!("expected Project, got {other:?}"),
1019        }
1020    }
1021
1022    #[test]
1023    fn test_plan_window_inserts_window_node_before_project() {
1024        let plan = plan("User { .name, rn: row_number() over (order .age) }").unwrap();
1025        // Expected shape: Project(Window(SeqScan))
1026        match plan {
1027            PlanNode::Project { input, fields } => {
1028                assert_eq!(fields.len(), 2);
1029                // The window expr should have been replaced with Field("__win_0")
1030                assert!(
1031                    matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"),
1032                    "expected Field(__win_0), got {:?}",
1033                    fields[1].expr
1034                );
1035                match *input {
1036                    PlanNode::Window {
1037                        input: inner,
1038                        windows,
1039                    } => {
1040                        assert_eq!(windows.len(), 1);
1041                        assert_eq!(windows[0].output_name, "__win_0");
1042                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
1043                    }
1044                    other => panic!("expected Window, got {other:?}"),
1045                }
1046            }
1047            other => panic!("expected Project, got {other:?}"),
1048        }
1049    }
1050
1051    #[test]
1052    fn test_plan_multiple_windows() {
1053        let plan = plan(
1054            "User { .name, rn: row_number() over (order .age), s: sum(.salary) over (partition .dept order .salary) }"
1055        ).unwrap();
1056        match plan {
1057            PlanNode::Project { input, fields } => {
1058                assert_eq!(fields.len(), 3);
1059                assert!(matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"));
1060                assert!(matches!(&fields[2].expr, Expr::Field(name) if name == "__win_1"));
1061                match *input {
1062                    PlanNode::Window { windows, .. } => {
1063                        assert_eq!(windows.len(), 2);
1064                        assert_eq!(windows[0].output_name, "__win_0");
1065                        assert_eq!(windows[1].output_name, "__win_1");
1066                    }
1067                    other => panic!("expected Window, got {other:?}"),
1068                }
1069            }
1070            other => panic!("expected Project, got {other:?}"),
1071        }
1072    }
1073
1074    #[test]
1075    fn test_plan_no_window_without_over() {
1076        // Plain aggregate in projection should not create a Window node.
1077        let plan = plan("User group .dept { .dept, total: sum(.salary) }").unwrap();
1078        match plan {
1079            PlanNode::Project { input, .. } => {
1080                // Input should be GroupBy, not Window.
1081                assert!(
1082                    matches!(*input, PlanNode::GroupBy { .. }),
1083                    "expected GroupBy under Project, got {:?}",
1084                    input
1085                );
1086            }
1087            other => panic!("expected Project, got {other:?}"),
1088        }
1089    }
1090
1091    #[test]
1092    fn test_plan_explain_wraps_inner() {
1093        let plan = plan("explain User filter .age > 30").unwrap();
1094        match plan {
1095            PlanNode::Explain { input } => {
1096                assert!(
1097                    matches!(*input, PlanNode::RangeScan { .. }),
1098                    "expected Explain(RangeScan), got {:?}",
1099                    input
1100                );
1101            }
1102            other => panic!("expected Explain, got {other:?}"),
1103        }
1104    }
1105
1106    #[test]
1107    fn test_plan_explain_simple_scan() {
1108        let plan = plan("explain User").unwrap();
1109        match plan {
1110            PlanNode::Explain { input } => {
1111                assert!(matches!(*input, PlanNode::SeqScan { .. }));
1112            }
1113            other => panic!("expected Explain(SeqScan), got {other:?}"),
1114        }
1115    }
1116
1117    #[test]
1118    fn test_plan_explain_join() {
1119        let plan = plan("explain User as u join Order as o on u.id = o.user_id").unwrap();
1120        match plan {
1121            PlanNode::Explain { input } => {
1122                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
1123            }
1124            other => panic!("expected Explain(NestedLoopJoin), got {other:?}"),
1125        }
1126    }
1127}