Skip to main content

powdb_query/
planner.rs

1use crate::ast::*;
2use crate::parser::{parse, ParseError};
3use crate::plan::*;
4
5/// (column_name, lower_bound, upper_bound) — used by range-index extraction.
6type RangeBound = (String, Option<(Expr, bool)>, Option<(Expr, bool)>);
7
8/// Plan-phase error — wraps ParseError for the full lex→parse→plan chain.
9#[derive(Debug)]
10pub enum PlanError {
11    /// Error originated in the parser (or lexer, via ParseError::Lex).
12    Parse(ParseError),
13}
14
15impl PlanError {
16    /// Convenience: human-readable message for any variant.
17    pub fn message(&self) -> String {
18        self.to_string()
19    }
20}
21
22impl std::fmt::Display for PlanError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            Self::Parse(e) => write!(f, "{e}"),
26        }
27    }
28}
29
30impl std::error::Error for PlanError {
31    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32        match self {
33            Self::Parse(e) => Some(e),
34        }
35    }
36}
37
38impl From<ParseError> for PlanError {
39    fn from(e: ParseError) -> Self {
40        PlanError::Parse(e)
41    }
42}
43
44pub fn plan(input: &str) -> Result<PlanNode, PlanError> {
45    let stmt = parse(input)?;
46    plan_statement(stmt)
47}
48
49pub fn plan_statement(stmt: Statement) -> Result<PlanNode, PlanError> {
50    match stmt {
51        Statement::Query(q) => plan_query(q),
52        Statement::Insert(ins) => plan_insert(ins),
53        Statement::UpdateQuery(upd) => plan_update(upd),
54        Statement::DeleteQuery(del) => plan_delete(del),
55        Statement::CreateType(ct) => plan_create_type(ct),
56        Statement::AlterTable(at) => Ok(PlanNode::AlterTable {
57            table: at.table,
58            action: at.action,
59        }),
60        Statement::DropTable(dt) => Ok(PlanNode::DropTable { name: dt.table }),
61        Statement::CreateView(cv) => Ok(PlanNode::CreateView {
62            name: cv.name,
63            query_text: cv.query_text,
64        }),
65        Statement::RefreshView(rv) => Ok(PlanNode::RefreshView { name: rv.name }),
66        Statement::DropView(dv) => Ok(PlanNode::DropView { name: dv.name }),
67        Statement::Union(u) => {
68            let left = plan_statement(*u.left)?;
69            let right = plan_statement(*u.right)?;
70            Ok(PlanNode::Union {
71                left: Box::new(left),
72                right: Box::new(right),
73                all: u.all,
74            })
75        }
76        Statement::Upsert(ups) => plan_upsert(ups),
77        Statement::Explain(inner) => {
78            let inner_plan = plan_statement(*inner)?;
79            Ok(PlanNode::Explain {
80                input: Box::new(inner_plan),
81            })
82        }
83    }
84}
85
86fn plan_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
87    // Mission E1.2: if the query has joins, build a left-deep nested-loop
88    // plan. Correctness first — hash-join optimization is E1.3. We also
89    // don't try to fold an IndexScan under a joined query yet (the
90    // leaf-level fast paths all match on `PlanNode::SeqScan { .. }`
91    // literally, so mixing them into a join plan would silently break).
92    if !q.joins.is_empty() {
93        return plan_joined_query(q);
94    }
95    // Try to fold `filter .col = literal` into an IndexScan. The executor
96    // decides at run time whether the column actually has an index — if not,
97    // it transparently falls back to a sequential scan with the same predicate,
98    // so this rewrite is always safe.
99    //
100    // We only rewrite the *simple* eq case: `filter .col = literal`. Conjunctions
101    // like `filter .col = 1 and .other > 5` fall through to SeqScan + Filter.
102    // Extending this to split conjunctions is a future optimization.
103    let (source, filter) = match q.filter {
104        Some(pred) => match try_extract_eq_index_key(&q.source, &pred) {
105            Some(index_scan) => (index_scan, None),
106            None => match try_extract_range_index_keys(&q.source, &pred) {
107                Some(range_scan) => (range_scan, None),
108                None => (
109                    PlanNode::SeqScan {
110                        table: q.source.clone(),
111                    },
112                    Some(pred),
113                ),
114            },
115        },
116        None => (
117            PlanNode::SeqScan {
118                table: q.source.clone(),
119            },
120            None,
121        ),
122    };
123    let mut node = source;
124
125    if let Some(pred) = filter {
126        node = PlanNode::Filter {
127            input: Box::new(node),
128            predicate: pred,
129        };
130    }
131
132    // Mission E2b: GROUP BY path — insert GroupBy + Project before
133    // order/limit/offset/distinct.
134    if let Some(group) = q.group_by {
135        let mut proj_fields: Vec<ProjectField> = q
136            .projection
137            .map(|proj| {
138                proj.into_iter()
139                    .map(|pf| ProjectField {
140                        alias: pf.alias,
141                        expr: pf.expr,
142                    })
143                    .collect()
144            })
145            .unwrap_or_default();
146        let mut having = group.having;
147        let aggregates = extract_aggregates(&mut proj_fields, &mut having);
148
149        node = PlanNode::GroupBy {
150            input: Box::new(node),
151            keys: group.keys,
152            aggregates,
153            having,
154        };
155
156        if !proj_fields.is_empty() {
157            node = PlanNode::Project {
158                input: Box::new(node),
159                fields: proj_fields,
160            };
161        }
162
163        if let Some(order) = q.order {
164            node = PlanNode::Sort {
165                input: Box::new(node),
166                keys: order
167                    .keys
168                    .into_iter()
169                    .map(|k| SortKey {
170                        field: k.field,
171                        descending: k.descending,
172                    })
173                    .collect(),
174            };
175        }
176        // Offset must be applied *before* Limit: skip M rows, then take N.
177        // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
178        // and Limit wraps it (outer).
179        if let Some(off) = q.offset {
180            node = PlanNode::Offset {
181                input: Box::new(node),
182                count: off,
183            };
184        }
185        if let Some(lim) = q.limit {
186            node = PlanNode::Limit {
187                input: Box::new(node),
188                count: lim,
189            };
190        }
191        if q.distinct {
192            node = PlanNode::Distinct {
193                input: Box::new(node),
194            };
195        }
196        return Ok(node);
197    }
198
199    if let Some(order) = q.order {
200        node = PlanNode::Sort {
201            input: Box::new(node),
202            keys: order
203                .keys
204                .into_iter()
205                .map(|k| SortKey {
206                    field: k.field,
207                    descending: k.descending,
208                })
209                .collect(),
210        };
211    }
212
213    // Offset must be applied *before* Limit: skip M rows, then take N.
214    // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
215    // and Limit wraps it (outer).
216    if let Some(off) = q.offset {
217        node = PlanNode::Offset {
218            input: Box::new(node),
219            count: off,
220        };
221    }
222
223    if let Some(lim) = q.limit {
224        node = PlanNode::Limit {
225            input: Box::new(node),
226            count: lim,
227        };
228    }
229
230    if let Some(proj) = q.projection {
231        let mut fields: Vec<ProjectField> = proj
232            .into_iter()
233            .map(|pf| ProjectField {
234                alias: pf.alias,
235                expr: pf.expr,
236            })
237            .collect();
238        let windows = extract_windows(&mut fields);
239        if !windows.is_empty() {
240            node = PlanNode::Window {
241                input: Box::new(node),
242                windows,
243            };
244        }
245        node = PlanNode::Project {
246            input: Box::new(node),
247            fields,
248        };
249    }
250
251    if q.distinct {
252        node = PlanNode::Distinct {
253            input: Box::new(node),
254        };
255    }
256
257    if let Some(agg) = q.aggregation {
258        node = PlanNode::Aggregate {
259            input: Box::new(node),
260            function: agg.function,
261            field: agg.field,
262        };
263    }
264
265    Ok(node)
266}
267
268/// Build a left-deep nested-loop join plan for a query with 1+ join clauses.
269///
270/// The plan shape for `T1 as a [inner|left|cross] join T2 as b on <pred> ...` is:
271///
272///   Project? (optional, from q.projection)
273///   └─ Offset? / Limit? / Sort?
274///      └─ Filter? (the top-level q.filter, using qualified columns)
275///         └─ NestedLoopJoin { kind, on }
276///            ├─ AliasScan { T1, a }
277///            └─ AliasScan { T2, b }
278///
279/// Multi-join chains extend left-deep: a third join adds a second
280/// `NestedLoopJoin` on top, with the first join's output as its `left`.
281///
282/// Aliases default to the source table name when the query didn't write
283/// `as <name>` explicitly — that way users can always write `T.field`
284/// without being forced to alias every source.
285///
286/// RightOuter is rewritten into LeftOuter with inputs swapped — the two
287/// differ only in which side survives non-matching rows, and swapping
288/// inputs lets the executor ship a single LeftOuter path.
289fn plan_joined_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
290    let primary_alias = q.alias.clone().unwrap_or_else(|| q.source.clone());
291    let mut node = PlanNode::AliasScan {
292        table: q.source.clone(),
293        alias: primary_alias,
294    };
295
296    for join in q.joins {
297        let right_alias = join.alias.unwrap_or_else(|| join.source.clone());
298        let right = PlanNode::AliasScan {
299            table: join.source,
300            alias: right_alias,
301        };
302        match join.kind {
303            JoinKind::Inner | JoinKind::LeftOuter | JoinKind::Cross => {
304                node = PlanNode::NestedLoopJoin {
305                    left: Box::new(node),
306                    right: Box::new(right),
307                    on: join.on,
308                    kind: join.kind,
309                };
310            }
311            JoinKind::RightOuter => {
312                // `a RIGHT OUTER JOIN b ON <p>` ≡ `b LEFT OUTER JOIN a ON <p>`.
313                node = PlanNode::NestedLoopJoin {
314                    left: Box::new(right),
315                    right: Box::new(node),
316                    on: join.on,
317                    kind: JoinKind::LeftOuter,
318                };
319            }
320        }
321    }
322
323    if let Some(pred) = q.filter {
324        node = PlanNode::Filter {
325            input: Box::new(node),
326            predicate: pred,
327        };
328    }
329
330    if let Some(order) = q.order {
331        node = PlanNode::Sort {
332            input: Box::new(node),
333            keys: order
334                .keys
335                .into_iter()
336                .map(|k| SortKey {
337                    field: k.field,
338                    descending: k.descending,
339                })
340                .collect(),
341        };
342    }
343
344    // Offset must be applied *before* Limit: skip M rows, then take N.
345    // Plan shape is Limit(Offset(...)), so Offset is built first (inner)
346    // and Limit wraps it (outer).
347    if let Some(off) = q.offset {
348        node = PlanNode::Offset {
349            input: Box::new(node),
350            count: off,
351        };
352    }
353
354    if let Some(lim) = q.limit {
355        node = PlanNode::Limit {
356            input: Box::new(node),
357            count: lim,
358        };
359    }
360
361    // Mission E2b: GROUP BY path for joined queries.
362    if let Some(group) = q.group_by {
363        let mut proj_fields: Vec<ProjectField> = q
364            .projection
365            .map(|proj| {
366                proj.into_iter()
367                    .map(|pf| ProjectField {
368                        alias: pf.alias,
369                        expr: pf.expr,
370                    })
371                    .collect()
372            })
373            .unwrap_or_default();
374        let mut having = group.having;
375        let aggregates = extract_aggregates(&mut proj_fields, &mut having);
376
377        node = PlanNode::GroupBy {
378            input: Box::new(node),
379            keys: group.keys,
380            aggregates,
381            having,
382        };
383
384        if !proj_fields.is_empty() {
385            node = PlanNode::Project {
386                input: Box::new(node),
387                fields: proj_fields,
388            };
389        }
390        if q.distinct {
391            node = PlanNode::Distinct {
392                input: Box::new(node),
393            };
394        }
395        return Ok(node);
396    }
397
398    if let Some(proj) = q.projection {
399        let mut fields: Vec<ProjectField> = proj
400            .into_iter()
401            .map(|pf| ProjectField {
402                alias: pf.alias,
403                expr: pf.expr,
404            })
405            .collect();
406        let windows = extract_windows(&mut fields);
407        if !windows.is_empty() {
408            node = PlanNode::Window {
409                input: Box::new(node),
410                windows,
411            };
412        }
413        node = PlanNode::Project {
414            input: Box::new(node),
415            fields,
416        };
417    }
418
419    if q.distinct {
420        node = PlanNode::Distinct {
421            input: Box::new(node),
422        };
423    }
424
425    if let Some(agg) = q.aggregation {
426        node = PlanNode::Aggregate {
427            input: Box::new(node),
428            function: agg.function,
429            field: agg.field,
430        };
431    }
432
433    Ok(node)
434}
435
436fn plan_insert(ins: InsertExpr) -> Result<PlanNode, PlanError> {
437    Ok(PlanNode::Insert {
438        table: ins.target,
439        assignments: ins.assignments,
440    })
441}
442
443fn plan_update(upd: UpdateExpr) -> Result<PlanNode, PlanError> {
444    // Mirror the read-side IndexScan fold: when the update filter is a simple
445    // `.col = literal`, emit `Update(IndexScan)` so the executor's index-lookup
446    // mutation fast path fires. The executor falls back to a scan if the
447    // column happens to lack an index, so this is always safe.
448    let source = match upd.filter {
449        Some(pred) => match try_extract_eq_index_key(&upd.source, &pred) {
450            Some(index_scan) => index_scan,
451            None => match try_extract_range_index_keys(&upd.source, &pred) {
452                Some(range_scan) => range_scan,
453                None => PlanNode::Filter {
454                    input: Box::new(PlanNode::SeqScan {
455                        table: upd.source.clone(),
456                    }),
457                    predicate: pred,
458                },
459            },
460        },
461        None => PlanNode::SeqScan {
462            table: upd.source.clone(),
463        },
464    };
465    Ok(PlanNode::Update {
466        input: Box::new(source),
467        table: upd.source,
468        assignments: upd.assignments,
469    })
470}
471
472fn plan_delete(del: DeleteExpr) -> Result<PlanNode, PlanError> {
473    let source = match del.filter {
474        Some(pred) => match try_extract_eq_index_key(&del.source, &pred) {
475            Some(index_scan) => index_scan,
476            None => match try_extract_range_index_keys(&del.source, &pred) {
477                Some(range_scan) => range_scan,
478                None => PlanNode::Filter {
479                    input: Box::new(PlanNode::SeqScan {
480                        table: del.source.clone(),
481                    }),
482                    predicate: pred,
483                },
484            },
485        },
486        None => PlanNode::SeqScan {
487            table: del.source.clone(),
488        },
489    };
490    Ok(PlanNode::Delete {
491        input: Box::new(source),
492        table: del.source,
493    })
494}
495
496fn plan_upsert(ups: UpsertExpr) -> Result<PlanNode, PlanError> {
497    Ok(PlanNode::Upsert {
498        table: ups.target,
499        key_column: ups.key_column,
500        assignments: ups.assignments,
501        on_conflict: ups.on_conflict,
502    })
503}
504
505fn plan_create_type(ct: CreateTypeExpr) -> Result<PlanNode, PlanError> {
506    let fields = ct
507        .fields
508        .into_iter()
509        .map(|f| (f.name, f.type_name, f.required))
510        .collect();
511    Ok(PlanNode::CreateTable {
512        name: ct.name,
513        fields,
514    })
515}
516
517/// If the predicate is a simple `.field = literal` (or `literal = .field`),
518/// return a corresponding IndexScan plan node. Otherwise return None so the
519/// caller can fall through to SeqScan + Filter.
520///
521/// The executor decides at run time whether the named column actually has a
522/// B-tree index — if not, IndexScan transparently falls back to a scan +
523/// equality filter on that column. That means this rewrite is always safe
524/// regardless of schema/index state; it just unlocks the fast path when an
525/// index happens to exist.
526fn try_extract_eq_index_key(table: &str, pred: &Expr) -> Option<PlanNode> {
527    let (lhs, op, rhs) = match pred {
528        Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
529        _ => return None,
530    };
531    if op != BinOp::Eq {
532        return None;
533    }
534    let (column, key) = match (lhs, rhs) {
535        (Expr::Field(name), Expr::Literal(_)) => (name.clone(), rhs.clone()),
536        (Expr::Literal(_), Expr::Field(name)) => (name.clone(), lhs.clone()),
537        _ => return None,
538    };
539    Some(PlanNode::IndexScan {
540        table: table.to_string(),
541        column,
542        key,
543    })
544}
545
546/// Extract a single range bound from a simple inequality predicate.
547/// Returns `(column, lower_bound, upper_bound)` where at most one bound is set.
548fn extract_single_bound(pred: &Expr) -> Option<RangeBound> {
549    let (lhs, op, rhs) = match pred {
550        Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
551        _ => return None,
552    };
553    match op {
554        // .col > literal  →  lower=(literal, exclusive)
555        BinOp::Gt => match (lhs, rhs) {
556            (Expr::Field(name), Expr::Literal(_)) => {
557                Some((name.clone(), Some((rhs.clone(), false)), None))
558            }
559            (Expr::Literal(_), Expr::Field(name)) => {
560                // literal > .col  →  col < literal  →  upper=(literal, exclusive)
561                Some((name.clone(), None, Some((lhs.clone(), false))))
562            }
563            _ => None,
564        },
565        // .col >= literal  →  lower=(literal, inclusive)
566        BinOp::Gte => match (lhs, rhs) {
567            (Expr::Field(name), Expr::Literal(_)) => {
568                Some((name.clone(), Some((rhs.clone(), true)), None))
569            }
570            (Expr::Literal(_), Expr::Field(name)) => {
571                Some((name.clone(), None, Some((lhs.clone(), true))))
572            }
573            _ => None,
574        },
575        // .col < literal  →  upper=(literal, exclusive)
576        BinOp::Lt => match (lhs, rhs) {
577            (Expr::Field(name), Expr::Literal(_)) => {
578                Some((name.clone(), None, Some((rhs.clone(), false))))
579            }
580            (Expr::Literal(_), Expr::Field(name)) => {
581                Some((name.clone(), Some((lhs.clone(), false)), None))
582            }
583            _ => None,
584        },
585        // .col <= literal  →  upper=(literal, inclusive)
586        BinOp::Lte => match (lhs, rhs) {
587            (Expr::Field(name), Expr::Literal(_)) => {
588                Some((name.clone(), None, Some((rhs.clone(), true))))
589            }
590            (Expr::Literal(_), Expr::Field(name)) => {
591                Some((name.clone(), Some((lhs.clone(), true)), None))
592            }
593            _ => None,
594        },
595        _ => None,
596    }
597}
598
599/// If the predicate is an inequality or a conjunction of two inequalities
600/// on the same indexed column, return a RangeScan plan node.
601/// Handles: `.col > lit`, `.col >= lit`, `.col < lit`, `.col <= lit`,
602/// and AND-conjunctions like `.col >= low AND .col <= high` (BETWEEN pattern).
603fn try_extract_range_index_keys(table: &str, pred: &Expr) -> Option<PlanNode> {
604    // Case 1: AND conjunction — try to merge two bounds on the same column.
605    if let Expr::BinaryOp(lhs, BinOp::And, rhs) = pred {
606        if let (Some((col1, s1, e1)), Some((col2, s2, e2))) =
607            (extract_single_bound(lhs), extract_single_bound(rhs))
608        {
609            if col1 == col2 {
610                let start = s1.or(s2);
611                let end = e1.or(e2);
612                if start.is_some() || end.is_some() {
613                    return Some(PlanNode::RangeScan {
614                        table: table.to_string(),
615                        column: col1,
616                        start,
617                        end,
618                    });
619                }
620            }
621        }
622    }
623
624    // Case 2: single inequality.
625    if let Some((col, start, end)) = extract_single_bound(pred) {
626        return Some(PlanNode::RangeScan {
627            table: table.to_string(),
628            column: col,
629            start,
630            end,
631        });
632    }
633
634    None
635}
636
637/// Walk projection fields, replacing every `Expr::Window { .. }` with
638/// `Expr::Field("__win_N")` and collecting the corresponding `WindowDef`
639/// descriptors. Returns the list of window definitions to insert as a
640/// `PlanNode::Window` before the `Project` node.
641fn extract_windows(proj_fields: &mut [ProjectField]) -> Vec<WindowDef> {
642    let mut defs = Vec::new();
643    let mut counter = 0usize;
644    for f in proj_fields.iter_mut() {
645        if let Expr::Window {
646            function,
647            args,
648            partition_by,
649            order_by,
650        } = &f.expr
651        {
652            let output_name = format!("__win_{counter}");
653            defs.push(WindowDef {
654                function: *function,
655                args: args.clone(),
656                partition_by: partition_by.clone(),
657                order_by: order_by
658                    .iter()
659                    .map(|k| SortKey {
660                        field: k.field.clone(),
661                        descending: k.descending,
662                    })
663                    .collect(),
664                output_name: output_name.clone(),
665            });
666            f.expr = Expr::Field(output_name);
667            counter += 1;
668        }
669    }
670    defs
671}
672
673/// Walk projection fields and HAVING expression, replacing every
674/// `Expr::FunctionCall(func, Field(col))` with `Expr::Field("__agg_N")`
675/// and collecting the corresponding `GroupAgg` descriptors. Deduplicates:
676/// if the same (func, field) pair appears in both projection and HAVING,
677/// they share a single `GroupAgg` entry.
678fn extract_aggregates(
679    proj_fields: &mut [ProjectField],
680    having: &mut Option<Expr>,
681) -> Vec<GroupAgg> {
682    let mut aggs: Vec<GroupAgg> = Vec::new();
683    let mut counter = 0usize;
684    for f in proj_fields.iter_mut() {
685        rewrite_agg_expr(&mut f.expr, &mut aggs, &mut counter);
686    }
687    if let Some(h) = having {
688        rewrite_agg_expr(h, &mut aggs, &mut counter);
689    }
690    aggs
691}
692
693fn rewrite_agg_expr(expr: &mut Expr, aggs: &mut Vec<GroupAgg>, counter: &mut usize) {
694    match expr {
695        Expr::FunctionCall(func, inner) => {
696            if let Expr::Field(name) = inner.as_ref() {
697                let output = find_or_insert_agg(aggs, *func, name, counter);
698                *expr = Expr::Field(output);
699            }
700        }
701        Expr::BinaryOp(l, _, r) => {
702            rewrite_agg_expr(l, aggs, counter);
703            rewrite_agg_expr(r, aggs, counter);
704        }
705        Expr::UnaryOp(_, inner) => rewrite_agg_expr(inner, aggs, counter),
706        Expr::Coalesce(l, r) => {
707            rewrite_agg_expr(l, aggs, counter);
708            rewrite_agg_expr(r, aggs, counter);
709        }
710        Expr::InList { expr: e, list, .. } => {
711            rewrite_agg_expr(e, aggs, counter);
712            for item in list {
713                rewrite_agg_expr(item, aggs, counter);
714            }
715        }
716        Expr::InSubquery { expr: e, .. } => {
717            rewrite_agg_expr(e, aggs, counter);
718        }
719        _ => {}
720    }
721}
722
723fn find_or_insert_agg(
724    aggs: &mut Vec<GroupAgg>,
725    func: AggFunc,
726    field: &str,
727    counter: &mut usize,
728) -> String {
729    for existing in aggs.iter() {
730        if existing.function == func && existing.field == field {
731            return existing.output_name.clone();
732        }
733    }
734    let output_name = format!("__agg_{counter}");
735    aggs.push(GroupAgg {
736        function: func,
737        field: field.to_string(),
738        output_name: output_name.clone(),
739    });
740    *counter += 1;
741    output_name
742}
743
744#[cfg(test)]
745mod tests {
746    use super::*;
747    use crate::plan::PlanNode;
748
749    #[test]
750    fn test_plan_simple_scan() {
751        let plan = plan("User").unwrap();
752        assert!(matches!(plan, PlanNode::SeqScan { table } if table == "User"));
753    }
754
755    #[test]
756    fn test_plan_filter() {
757        let plan = plan("User filter .age > 30").unwrap();
758        assert!(matches!(plan, PlanNode::RangeScan { .. }));
759    }
760
761    #[test]
762    fn test_plan_filter_with_projection() {
763        let plan = plan("User filter .age > 30 { name, email }").unwrap();
764        assert!(matches!(plan, PlanNode::Project { .. }));
765    }
766
767    #[test]
768    fn test_plan_insert() {
769        let plan = plan(r#"insert User { name := "Alice", age := 30 }"#).unwrap();
770        assert!(matches!(plan, PlanNode::Insert { .. }));
771    }
772
773    #[test]
774    fn test_plan_order_limit() {
775        let plan = plan("User order .name limit 10").unwrap();
776        match plan {
777            PlanNode::Limit { input, .. } => {
778                assert!(matches!(*input, PlanNode::Sort { .. }));
779            }
780            _ => panic!("expected Limit(Sort(SeqScan))"),
781        }
782    }
783
784    #[test]
785    fn test_plan_count() {
786        let plan = plan("count(User)").unwrap();
787        assert!(matches!(plan, PlanNode::Aggregate { .. }));
788    }
789
790    #[test]
791    fn test_plan_eq_becomes_index_scan() {
792        // `filter .col = literal` should fold into an IndexScan — the executor
793        // falls back to a scan if the column happens to lack an index.
794        let plan = plan("User filter .id = 42").unwrap();
795        match plan {
796            PlanNode::IndexScan { table, column, key } => {
797                assert_eq!(table, "User");
798                assert_eq!(column, "id");
799                assert!(matches!(key, Expr::Literal(Literal::Int(42))));
800            }
801            other => panic!("expected IndexScan, got {other:?}"),
802        }
803    }
804
805    #[test]
806    fn test_plan_eq_reversed_becomes_index_scan() {
807        // Literal-on-the-left form should fold the same way.
808        let plan = plan(r#"User filter "NYC" = .city"#).unwrap();
809        assert!(matches!(plan, PlanNode::IndexScan { .. }));
810    }
811
812    #[test]
813    fn test_plan_non_eq_stays_filter() {
814        // `>` now emits a RangeScan instead of SeqScan+Filter.
815        let plan = plan("User filter .age > 30").unwrap();
816        match plan {
817            PlanNode::RangeScan {
818                column, start, end, ..
819            } => {
820                assert_eq!(column, "age");
821                assert!(start.is_some(), "expected lower bound");
822                assert!(end.is_none(), "expected no upper bound");
823                let (_, inclusive) = start.unwrap();
824                assert!(!inclusive, "expected exclusive lower bound for >");
825            }
826            other => panic!("expected RangeScan, got {other:?}"),
827        }
828    }
829
830    #[test]
831    fn test_plan_index_scan_with_projection() {
832        // Projection on top of an IndexScan should layer correctly.
833        let plan = plan("User filter .id = 1 { .name }").unwrap();
834        match plan {
835            PlanNode::Project { input, .. } => {
836                assert!(matches!(*input, PlanNode::IndexScan { .. }));
837            }
838            other => panic!("expected Project(IndexScan), got {other:?}"),
839        }
840    }
841
842    #[test]
843    fn test_plan_update_by_pk_becomes_index_scan() {
844        // `.id = literal` update should fold to Update(IndexScan), not
845        // Update(Filter(SeqScan)).
846        let plan = plan("User filter .id = 42 update { age := 31 }").unwrap();
847        match plan {
848            PlanNode::Update { input, .. } => {
849                assert!(
850                    matches!(*input, PlanNode::IndexScan { .. }),
851                    "expected Update(IndexScan), got {input:?}"
852                );
853            }
854            other => panic!("expected Update, got {other:?}"),
855        }
856    }
857
858    #[test]
859    fn test_plan_update_range_stays_range_scan() {
860        let plan = plan("User filter .age > 30 update { age := 31 }").unwrap();
861        match plan {
862            PlanNode::Update { input, .. } => {
863                assert!(
864                    matches!(*input, PlanNode::RangeScan { .. }),
865                    "expected Update(RangeScan), got {input:?}"
866                );
867            }
868            other => panic!("expected Update, got {other:?}"),
869        }
870    }
871
872    #[test]
873    fn test_plan_delete_by_pk_becomes_index_scan() {
874        let plan = plan("User filter .id = 7 delete").unwrap();
875        match plan {
876            PlanNode::Delete { input, .. } => {
877                assert!(matches!(*input, PlanNode::IndexScan { .. }));
878            }
879            other => panic!("expected Delete, got {other:?}"),
880        }
881    }
882
883    #[test]
884    fn test_plan_inner_join_builds_nested_loop() {
885        // Mission E1.2: a join query should plan to NestedLoopJoin with
886        // AliasScan leaves on both sides.
887        let plan = plan("User as u join Order as o on u.id = o.user_id").unwrap();
888        match plan {
889            PlanNode::NestedLoopJoin {
890                left,
891                right,
892                on,
893                kind,
894            } => {
895                assert_eq!(kind, JoinKind::Inner);
896                assert!(on.is_some());
897                assert!(matches!(*left, PlanNode::AliasScan { .. }));
898                assert!(matches!(*right, PlanNode::AliasScan { .. }));
899            }
900            other => panic!("expected NestedLoopJoin, got {other:?}"),
901        }
902    }
903
904    #[test]
905    fn test_plan_right_join_rewritten_as_left_with_swapped_inputs() {
906        let plan = plan("User as u right join Order as o on u.id = o.user_id").unwrap();
907        match plan {
908            PlanNode::NestedLoopJoin {
909                left, right, kind, ..
910            } => {
911                assert_eq!(kind, JoinKind::LeftOuter);
912                // Swapped: Order is now on the left, User on the right.
913                match *left {
914                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Order"),
915                    other => panic!("expected AliasScan(Order), got {other:?}"),
916                }
917                match *right {
918                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "User"),
919                    other => panic!("expected AliasScan(User), got {other:?}"),
920                }
921            }
922            other => panic!("expected NestedLoopJoin, got {other:?}"),
923        }
924    }
925
926    #[test]
927    fn test_plan_multi_join_is_left_deep() {
928        // Three sources → two NestedLoopJoins, left-deep.
929        let plan = plan(
930            "User as u join Order as o on u.id = o.user_id \
931             join Product as p on o.product_id = p.id",
932        )
933        .unwrap();
934        match plan {
935            PlanNode::NestedLoopJoin { left, right, .. } => {
936                // Outer (Product) join: right is AliasScan(Product)
937                match *right {
938                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Product"),
939                    other => panic!("expected AliasScan(Product), got {other:?}"),
940                }
941                // Outer.left is inner (Order) NestedLoopJoin
942                assert!(matches!(*left, PlanNode::NestedLoopJoin { .. }));
943            }
944            other => panic!("expected NestedLoopJoin, got {other:?}"),
945        }
946    }
947
948    #[test]
949    fn test_plan_join_with_filter_tail_wraps_filter_on_top() {
950        let plan =
951            plan("User as u join Order as o on u.id = o.user_id filter o.total > 100").unwrap();
952        match plan {
953            PlanNode::Filter { input, .. } => {
954                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
955            }
956            other => panic!("expected Filter(NestedLoopJoin), got {other:?}"),
957        }
958    }
959
960    #[test]
961    fn test_plan_group_by_builds_groupby_node() {
962        let plan = plan("User group .status { .status, n: count(.name) }").unwrap();
963        // Should be Project(GroupBy(SeqScan)).
964        match plan {
965            PlanNode::Project { input, fields } => {
966                assert_eq!(fields.len(), 2);
967                match *input {
968                    PlanNode::GroupBy {
969                        input: inner,
970                        keys,
971                        aggregates,
972                        having,
973                    } => {
974                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
975                        assert_eq!(keys, vec!["status"]);
976                        assert_eq!(aggregates.len(), 1);
977                        assert_eq!(aggregates[0].function, AggFunc::Count);
978                        assert_eq!(aggregates[0].field, "name");
979                        assert!(having.is_none());
980                    }
981                    other => panic!("expected GroupBy, got {other:?}"),
982                }
983            }
984            other => panic!("expected Project, got {other:?}"),
985        }
986    }
987
988    #[test]
989    fn test_plan_group_by_having_rewrites_agg_in_having() {
990        let plan = plan("User group .status having count(.name) > 1 { .status }").unwrap();
991        match plan {
992            PlanNode::Project { input, .. } => {
993                match *input {
994                    PlanNode::GroupBy {
995                        having, aggregates, ..
996                    } => {
997                        // The planner should have extracted count(.name) into
998                        // aggregates and rewritten the HAVING to reference __agg_0.
999                        assert_eq!(aggregates.len(), 1);
1000                        assert_eq!(aggregates[0].output_name, "__agg_0");
1001                        let h = having.expect("having should be Some");
1002                        match h {
1003                            Expr::BinaryOp(l, BinOp::Gt, _) => {
1004                                assert!(
1005                                    matches!(*l, Expr::Field(ref name) if name == "__agg_0"),
1006                                    "expected Field(__agg_0), got {l:?}"
1007                                );
1008                            }
1009                            other => panic!("expected BinaryOp, got {other:?}"),
1010                        }
1011                    }
1012                    other => panic!("expected GroupBy, got {other:?}"),
1013                }
1014            }
1015            other => panic!("expected Project, got {other:?}"),
1016        }
1017    }
1018
1019    #[test]
1020    fn test_plan_window_inserts_window_node_before_project() {
1021        let plan = plan("User { .name, rn: row_number() over (order .age) }").unwrap();
1022        // Expected shape: Project(Window(SeqScan))
1023        match plan {
1024            PlanNode::Project { input, fields } => {
1025                assert_eq!(fields.len(), 2);
1026                // The window expr should have been replaced with Field("__win_0")
1027                assert!(
1028                    matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"),
1029                    "expected Field(__win_0), got {:?}",
1030                    fields[1].expr
1031                );
1032                match *input {
1033                    PlanNode::Window {
1034                        input: inner,
1035                        windows,
1036                    } => {
1037                        assert_eq!(windows.len(), 1);
1038                        assert_eq!(windows[0].output_name, "__win_0");
1039                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
1040                    }
1041                    other => panic!("expected Window, got {other:?}"),
1042                }
1043            }
1044            other => panic!("expected Project, got {other:?}"),
1045        }
1046    }
1047
1048    #[test]
1049    fn test_plan_multiple_windows() {
1050        let plan = plan(
1051            "User { .name, rn: row_number() over (order .age), s: sum(.salary) over (partition .dept order .salary) }"
1052        ).unwrap();
1053        match plan {
1054            PlanNode::Project { input, fields } => {
1055                assert_eq!(fields.len(), 3);
1056                assert!(matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"));
1057                assert!(matches!(&fields[2].expr, Expr::Field(name) if name == "__win_1"));
1058                match *input {
1059                    PlanNode::Window { windows, .. } => {
1060                        assert_eq!(windows.len(), 2);
1061                        assert_eq!(windows[0].output_name, "__win_0");
1062                        assert_eq!(windows[1].output_name, "__win_1");
1063                    }
1064                    other => panic!("expected Window, got {other:?}"),
1065                }
1066            }
1067            other => panic!("expected Project, got {other:?}"),
1068        }
1069    }
1070
1071    #[test]
1072    fn test_plan_no_window_without_over() {
1073        // Plain aggregate in projection should not create a Window node.
1074        let plan = plan("User group .dept { .dept, total: sum(.salary) }").unwrap();
1075        match plan {
1076            PlanNode::Project { input, .. } => {
1077                // Input should be GroupBy, not Window.
1078                assert!(
1079                    matches!(*input, PlanNode::GroupBy { .. }),
1080                    "expected GroupBy under Project, got {:?}",
1081                    input
1082                );
1083            }
1084            other => panic!("expected Project, got {other:?}"),
1085        }
1086    }
1087
1088    #[test]
1089    fn test_plan_explain_wraps_inner() {
1090        let plan = plan("explain User filter .age > 30").unwrap();
1091        match plan {
1092            PlanNode::Explain { input } => {
1093                assert!(
1094                    matches!(*input, PlanNode::RangeScan { .. }),
1095                    "expected Explain(RangeScan), got {:?}",
1096                    input
1097                );
1098            }
1099            other => panic!("expected Explain, got {other:?}"),
1100        }
1101    }
1102
1103    #[test]
1104    fn test_plan_explain_simple_scan() {
1105        let plan = plan("explain User").unwrap();
1106        match plan {
1107            PlanNode::Explain { input } => {
1108                assert!(matches!(*input, PlanNode::SeqScan { .. }));
1109            }
1110            other => panic!("expected Explain(SeqScan), got {other:?}"),
1111        }
1112    }
1113
1114    #[test]
1115    fn test_plan_explain_join() {
1116        let plan = plan("explain User as u join Order as o on u.id = o.user_id").unwrap();
1117        match plan {
1118            PlanNode::Explain { input } => {
1119                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
1120            }
1121            other => panic!("expected Explain(NestedLoopJoin), got {other:?}"),
1122        }
1123    }
1124}