1use crate::ast::*;
2use crate::parser::{parse, ParseError};
3use crate::plan::*;
4
5type RangeBound = (String, Option<(Expr, bool)>, Option<(Expr, bool)>);
7
8#[derive(Debug)]
10pub enum PlanError {
11 Parse(ParseError),
13}
14
15impl PlanError {
16 pub fn message(&self) -> String {
18 self.to_string()
19 }
20}
21
22impl std::fmt::Display for PlanError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::Parse(e) => write!(f, "{e}"),
26 }
27 }
28}
29
30impl std::error::Error for PlanError {
31 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32 match self {
33 Self::Parse(e) => Some(e),
34 }
35 }
36}
37
38impl From<ParseError> for PlanError {
39 fn from(e: ParseError) -> Self {
40 PlanError::Parse(e)
41 }
42}
43
44pub fn plan(input: &str) -> Result<PlanNode, PlanError> {
45 let stmt = parse(input)?;
46 plan_statement(stmt)
47}
48
49pub fn plan_statement(stmt: Statement) -> Result<PlanNode, PlanError> {
50 match stmt {
51 Statement::Query(q) => plan_query(q),
52 Statement::Insert(ins) => plan_insert(ins),
53 Statement::UpdateQuery(upd) => plan_update(upd),
54 Statement::DeleteQuery(del) => plan_delete(del),
55 Statement::CreateType(ct) => plan_create_type(ct),
56 Statement::AlterTable(at) => Ok(PlanNode::AlterTable {
57 table: at.table,
58 action: at.action,
59 }),
60 Statement::DropTable(dt) => Ok(PlanNode::DropTable { name: dt.table }),
61 Statement::CreateView(cv) => Ok(PlanNode::CreateView {
62 name: cv.name,
63 query_text: cv.query_text,
64 }),
65 Statement::RefreshView(rv) => Ok(PlanNode::RefreshView { name: rv.name }),
66 Statement::DropView(dv) => Ok(PlanNode::DropView { name: dv.name }),
67 Statement::Union(u) => {
68 let left = plan_statement(*u.left)?;
69 let right = plan_statement(*u.right)?;
70 Ok(PlanNode::Union {
71 left: Box::new(left),
72 right: Box::new(right),
73 all: u.all,
74 })
75 }
76 Statement::Upsert(ups) => plan_upsert(ups),
77 Statement::Begin => Ok(PlanNode::Begin),
78 Statement::Commit => Ok(PlanNode::Commit),
79 Statement::Rollback => Ok(PlanNode::Rollback),
80 Statement::Explain(inner) => {
81 let inner_plan = plan_statement(*inner)?;
82 Ok(PlanNode::Explain {
83 input: Box::new(inner_plan),
84 })
85 }
86 }
87}
88
89fn plan_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
90 if !q.joins.is_empty() {
96 return plan_joined_query(q);
97 }
98 let (source, filter) = match q.filter {
107 Some(pred) => match try_extract_eq_index_key(&q.source, &pred) {
108 Some(index_scan) => (index_scan, None),
109 None => match try_extract_range_index_keys(&q.source, &pred) {
110 Some(range_scan) => (range_scan, None),
111 None => (
112 PlanNode::SeqScan {
113 table: q.source.clone(),
114 },
115 Some(pred),
116 ),
117 },
118 },
119 None => (
120 PlanNode::SeqScan {
121 table: q.source.clone(),
122 },
123 None,
124 ),
125 };
126 let mut node = source;
127
128 if let Some(pred) = filter {
129 node = PlanNode::Filter {
130 input: Box::new(node),
131 predicate: pred,
132 };
133 }
134
135 if let Some(group) = q.group_by {
138 let mut proj_fields: Vec<ProjectField> = q
139 .projection
140 .map(|proj| {
141 proj.into_iter()
142 .map(|pf| ProjectField {
143 alias: pf.alias,
144 expr: pf.expr,
145 })
146 .collect()
147 })
148 .unwrap_or_default();
149 let mut having = group.having;
150 let aggregates = extract_aggregates(&mut proj_fields, &mut having);
151
152 node = PlanNode::GroupBy {
153 input: Box::new(node),
154 keys: group.keys,
155 aggregates,
156 having,
157 };
158
159 if !proj_fields.is_empty() {
160 node = PlanNode::Project {
161 input: Box::new(node),
162 fields: proj_fields,
163 };
164 }
165
166 if let Some(order) = q.order {
167 node = PlanNode::Sort {
168 input: Box::new(node),
169 keys: order
170 .keys
171 .into_iter()
172 .map(|k| SortKey {
173 field: k.field,
174 descending: k.descending,
175 })
176 .collect(),
177 };
178 }
179 if let Some(off) = q.offset {
183 node = PlanNode::Offset {
184 input: Box::new(node),
185 count: off,
186 };
187 }
188 if let Some(lim) = q.limit {
189 node = PlanNode::Limit {
190 input: Box::new(node),
191 count: lim,
192 };
193 }
194 if q.distinct {
195 node = PlanNode::Distinct {
196 input: Box::new(node),
197 };
198 }
199 return Ok(node);
200 }
201
202 if let Some(order) = q.order {
203 node = PlanNode::Sort {
204 input: Box::new(node),
205 keys: order
206 .keys
207 .into_iter()
208 .map(|k| SortKey {
209 field: k.field,
210 descending: k.descending,
211 })
212 .collect(),
213 };
214 }
215
216 if let Some(off) = q.offset {
220 node = PlanNode::Offset {
221 input: Box::new(node),
222 count: off,
223 };
224 }
225
226 if let Some(lim) = q.limit {
227 node = PlanNode::Limit {
228 input: Box::new(node),
229 count: lim,
230 };
231 }
232
233 if let Some(proj) = q.projection {
234 let mut fields: Vec<ProjectField> = proj
235 .into_iter()
236 .map(|pf| ProjectField {
237 alias: pf.alias,
238 expr: pf.expr,
239 })
240 .collect();
241 let windows = extract_windows(&mut fields);
242 if !windows.is_empty() {
243 node = PlanNode::Window {
244 input: Box::new(node),
245 windows,
246 };
247 }
248 node = PlanNode::Project {
249 input: Box::new(node),
250 fields,
251 };
252 }
253
254 if q.distinct {
255 node = PlanNode::Distinct {
256 input: Box::new(node),
257 };
258 }
259
260 if let Some(agg) = q.aggregation {
261 node = PlanNode::Aggregate {
262 input: Box::new(node),
263 function: agg.function,
264 field: agg.field,
265 };
266 }
267
268 Ok(node)
269}
270
271fn plan_joined_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
293 let primary_alias = q.alias.clone().unwrap_or_else(|| q.source.clone());
294 let mut node = PlanNode::AliasScan {
295 table: q.source.clone(),
296 alias: primary_alias,
297 };
298
299 for join in q.joins {
300 let right_alias = join.alias.unwrap_or_else(|| join.source.clone());
301 let right = PlanNode::AliasScan {
302 table: join.source,
303 alias: right_alias,
304 };
305 match join.kind {
306 JoinKind::Inner | JoinKind::LeftOuter | JoinKind::Cross => {
307 node = PlanNode::NestedLoopJoin {
308 left: Box::new(node),
309 right: Box::new(right),
310 on: join.on,
311 kind: join.kind,
312 };
313 }
314 JoinKind::RightOuter => {
315 node = PlanNode::NestedLoopJoin {
317 left: Box::new(right),
318 right: Box::new(node),
319 on: join.on,
320 kind: JoinKind::LeftOuter,
321 };
322 }
323 }
324 }
325
326 if let Some(pred) = q.filter {
327 node = PlanNode::Filter {
328 input: Box::new(node),
329 predicate: pred,
330 };
331 }
332
333 if let Some(order) = q.order {
334 node = PlanNode::Sort {
335 input: Box::new(node),
336 keys: order
337 .keys
338 .into_iter()
339 .map(|k| SortKey {
340 field: k.field,
341 descending: k.descending,
342 })
343 .collect(),
344 };
345 }
346
347 if let Some(off) = q.offset {
351 node = PlanNode::Offset {
352 input: Box::new(node),
353 count: off,
354 };
355 }
356
357 if let Some(lim) = q.limit {
358 node = PlanNode::Limit {
359 input: Box::new(node),
360 count: lim,
361 };
362 }
363
364 if let Some(group) = q.group_by {
366 let mut proj_fields: Vec<ProjectField> = q
367 .projection
368 .map(|proj| {
369 proj.into_iter()
370 .map(|pf| ProjectField {
371 alias: pf.alias,
372 expr: pf.expr,
373 })
374 .collect()
375 })
376 .unwrap_or_default();
377 let mut having = group.having;
378 let aggregates = extract_aggregates(&mut proj_fields, &mut having);
379
380 node = PlanNode::GroupBy {
381 input: Box::new(node),
382 keys: group.keys,
383 aggregates,
384 having,
385 };
386
387 if !proj_fields.is_empty() {
388 node = PlanNode::Project {
389 input: Box::new(node),
390 fields: proj_fields,
391 };
392 }
393 if q.distinct {
394 node = PlanNode::Distinct {
395 input: Box::new(node),
396 };
397 }
398 return Ok(node);
399 }
400
401 if let Some(proj) = q.projection {
402 let mut fields: Vec<ProjectField> = proj
403 .into_iter()
404 .map(|pf| ProjectField {
405 alias: pf.alias,
406 expr: pf.expr,
407 })
408 .collect();
409 let windows = extract_windows(&mut fields);
410 if !windows.is_empty() {
411 node = PlanNode::Window {
412 input: Box::new(node),
413 windows,
414 };
415 }
416 node = PlanNode::Project {
417 input: Box::new(node),
418 fields,
419 };
420 }
421
422 if q.distinct {
423 node = PlanNode::Distinct {
424 input: Box::new(node),
425 };
426 }
427
428 if let Some(agg) = q.aggregation {
429 node = PlanNode::Aggregate {
430 input: Box::new(node),
431 function: agg.function,
432 field: agg.field,
433 };
434 }
435
436 Ok(node)
437}
438
439fn plan_insert(ins: InsertExpr) -> Result<PlanNode, PlanError> {
440 Ok(PlanNode::Insert {
441 table: ins.target,
442 rows: ins.rows,
443 })
444}
445
446fn plan_update(upd: UpdateExpr) -> Result<PlanNode, PlanError> {
447 let source = match upd.filter {
452 Some(pred) => match try_extract_eq_index_key(&upd.source, &pred) {
453 Some(index_scan) => index_scan,
454 None => match try_extract_range_index_keys(&upd.source, &pred) {
455 Some(range_scan) => range_scan,
456 None => PlanNode::Filter {
457 input: Box::new(PlanNode::SeqScan {
458 table: upd.source.clone(),
459 }),
460 predicate: pred,
461 },
462 },
463 },
464 None => PlanNode::SeqScan {
465 table: upd.source.clone(),
466 },
467 };
468 Ok(PlanNode::Update {
469 input: Box::new(source),
470 table: upd.source,
471 assignments: upd.assignments,
472 })
473}
474
475fn plan_delete(del: DeleteExpr) -> Result<PlanNode, PlanError> {
476 let source = match del.filter {
477 Some(pred) => match try_extract_eq_index_key(&del.source, &pred) {
478 Some(index_scan) => index_scan,
479 None => match try_extract_range_index_keys(&del.source, &pred) {
480 Some(range_scan) => range_scan,
481 None => PlanNode::Filter {
482 input: Box::new(PlanNode::SeqScan {
483 table: del.source.clone(),
484 }),
485 predicate: pred,
486 },
487 },
488 },
489 None => PlanNode::SeqScan {
490 table: del.source.clone(),
491 },
492 };
493 Ok(PlanNode::Delete {
494 input: Box::new(source),
495 table: del.source,
496 })
497}
498
499fn plan_upsert(ups: UpsertExpr) -> Result<PlanNode, PlanError> {
500 Ok(PlanNode::Upsert {
501 table: ups.target,
502 key_column: ups.key_column,
503 assignments: ups.assignments,
504 on_conflict: ups.on_conflict,
505 })
506}
507
508fn plan_create_type(ct: CreateTypeExpr) -> Result<PlanNode, PlanError> {
509 let fields = ct
510 .fields
511 .into_iter()
512 .map(|f| crate::plan::CreateField {
513 name: f.name,
514 type_name: f.type_name,
515 required: f.required,
516 unique: f.unique,
517 })
518 .collect();
519 Ok(PlanNode::CreateTable {
520 name: ct.name,
521 fields,
522 })
523}
524
525fn try_extract_eq_index_key(table: &str, pred: &Expr) -> Option<PlanNode> {
535 let (lhs, op, rhs) = match pred {
536 Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
537 _ => return None,
538 };
539 if op != BinOp::Eq {
540 return None;
541 }
542 let (column, key) = match (lhs, rhs) {
543 (Expr::Field(name), Expr::Literal(_)) => (name.clone(), rhs.clone()),
544 (Expr::Literal(_), Expr::Field(name)) => (name.clone(), lhs.clone()),
545 _ => return None,
546 };
547 Some(PlanNode::IndexScan {
548 table: table.to_string(),
549 column,
550 key,
551 })
552}
553
554fn extract_single_bound(pred: &Expr) -> Option<RangeBound> {
557 let (lhs, op, rhs) = match pred {
558 Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
559 _ => return None,
560 };
561 match op {
562 BinOp::Gt => match (lhs, rhs) {
564 (Expr::Field(name), Expr::Literal(_)) => {
565 Some((name.clone(), Some((rhs.clone(), false)), None))
566 }
567 (Expr::Literal(_), Expr::Field(name)) => {
568 Some((name.clone(), None, Some((lhs.clone(), false))))
570 }
571 _ => None,
572 },
573 BinOp::Gte => match (lhs, rhs) {
575 (Expr::Field(name), Expr::Literal(_)) => {
576 Some((name.clone(), Some((rhs.clone(), true)), None))
577 }
578 (Expr::Literal(_), Expr::Field(name)) => {
579 Some((name.clone(), None, Some((lhs.clone(), true))))
580 }
581 _ => None,
582 },
583 BinOp::Lt => match (lhs, rhs) {
585 (Expr::Field(name), Expr::Literal(_)) => {
586 Some((name.clone(), None, Some((rhs.clone(), false))))
587 }
588 (Expr::Literal(_), Expr::Field(name)) => {
589 Some((name.clone(), Some((lhs.clone(), false)), None))
590 }
591 _ => None,
592 },
593 BinOp::Lte => match (lhs, rhs) {
595 (Expr::Field(name), Expr::Literal(_)) => {
596 Some((name.clone(), None, Some((rhs.clone(), true))))
597 }
598 (Expr::Literal(_), Expr::Field(name)) => {
599 Some((name.clone(), Some((lhs.clone(), true)), None))
600 }
601 _ => None,
602 },
603 _ => None,
604 }
605}
606
607fn try_extract_range_index_keys(table: &str, pred: &Expr) -> Option<PlanNode> {
612 if let Expr::BinaryOp(lhs, BinOp::And, rhs) = pred {
614 if let (Some((col1, s1, e1)), Some((col2, s2, e2))) =
615 (extract_single_bound(lhs), extract_single_bound(rhs))
616 {
617 if col1 == col2 {
618 let start = s1.or(s2);
619 let end = e1.or(e2);
620 if start.is_some() || end.is_some() {
621 return Some(PlanNode::RangeScan {
622 table: table.to_string(),
623 column: col1,
624 start,
625 end,
626 });
627 }
628 }
629 }
630 }
631
632 if let Some((col, start, end)) = extract_single_bound(pred) {
634 return Some(PlanNode::RangeScan {
635 table: table.to_string(),
636 column: col,
637 start,
638 end,
639 });
640 }
641
642 None
643}
644
645fn extract_windows(proj_fields: &mut [ProjectField]) -> Vec<WindowDef> {
650 let mut defs = Vec::new();
651 let mut counter = 0usize;
652 for f in proj_fields.iter_mut() {
653 if let Expr::Window {
654 function,
655 args,
656 partition_by,
657 order_by,
658 } = &f.expr
659 {
660 let output_name = format!("__win_{counter}");
661 defs.push(WindowDef {
662 function: *function,
663 args: args.clone(),
664 partition_by: partition_by.clone(),
665 order_by: order_by
666 .iter()
667 .map(|k| SortKey {
668 field: k.field.clone(),
669 descending: k.descending,
670 })
671 .collect(),
672 output_name: output_name.clone(),
673 });
674 f.expr = Expr::Field(output_name);
675 counter += 1;
676 }
677 }
678 defs
679}
680
681fn extract_aggregates(
687 proj_fields: &mut [ProjectField],
688 having: &mut Option<Expr>,
689) -> Vec<GroupAgg> {
690 let mut aggs: Vec<GroupAgg> = Vec::new();
691 let mut counter = 0usize;
692 for f in proj_fields.iter_mut() {
693 rewrite_agg_expr(&mut f.expr, &mut aggs, &mut counter);
694 }
695 if let Some(h) = having {
696 rewrite_agg_expr(h, &mut aggs, &mut counter);
697 }
698 aggs
699}
700
701fn rewrite_agg_expr(expr: &mut Expr, aggs: &mut Vec<GroupAgg>, counter: &mut usize) {
702 match expr {
703 Expr::FunctionCall(func, inner) => {
704 if let Expr::Field(name) = inner.as_ref() {
705 let output = find_or_insert_agg(aggs, *func, name, counter);
706 *expr = Expr::Field(output);
707 }
708 }
709 Expr::BinaryOp(l, _, r) => {
710 rewrite_agg_expr(l, aggs, counter);
711 rewrite_agg_expr(r, aggs, counter);
712 }
713 Expr::UnaryOp(_, inner) => rewrite_agg_expr(inner, aggs, counter),
714 Expr::Coalesce(l, r) => {
715 rewrite_agg_expr(l, aggs, counter);
716 rewrite_agg_expr(r, aggs, counter);
717 }
718 Expr::InList { expr: e, list, .. } => {
719 rewrite_agg_expr(e, aggs, counter);
720 for item in list {
721 rewrite_agg_expr(item, aggs, counter);
722 }
723 }
724 Expr::InSubquery { expr: e, .. } => {
725 rewrite_agg_expr(e, aggs, counter);
726 }
727 _ => {}
728 }
729}
730
731fn find_or_insert_agg(
732 aggs: &mut Vec<GroupAgg>,
733 func: AggFunc,
734 field: &str,
735 counter: &mut usize,
736) -> String {
737 for existing in aggs.iter() {
738 if existing.function == func && existing.field == field {
739 return existing.output_name.clone();
740 }
741 }
742 let output_name = format!("__agg_{counter}");
743 aggs.push(GroupAgg {
744 function: func,
745 field: field.to_string(),
746 output_name: output_name.clone(),
747 });
748 *counter += 1;
749 output_name
750}
751
#[cfg(test)]
mod tests {
    //! Plan-shape tests: each case parses a query and asserts on the structure of the
    //! resulting `PlanNode` tree, outermost node first.

    use super::*;
    use crate::plan::PlanNode;

    // --- Single-table access paths -------------------------------------------------

    #[test]
    fn test_plan_simple_scan() {
        let plan = plan("User").unwrap();
        assert!(matches!(plan, PlanNode::SeqScan { table } if table == "User"));
    }

    #[test]
    fn test_plan_filter() {
        let plan = plan("User filter .age > 30").unwrap();
        assert!(matches!(plan, PlanNode::RangeScan { .. }));
    }

    #[test]
    fn test_plan_filter_with_projection() {
        let plan = plan("User filter .age > 30 { name, email }").unwrap();
        assert!(matches!(plan, PlanNode::Project { .. }));
    }

    #[test]
    fn test_plan_insert() {
        let plan = plan(r#"insert User { name := "Alice", age := 30 }"#).unwrap();
        assert!(matches!(plan, PlanNode::Insert { .. }));
    }

    #[test]
    fn test_plan_order_limit() {
        let plan = plan("User order .name limit 10").unwrap();
        match plan {
            PlanNode::Limit { input, .. } => {
                assert!(matches!(*input, PlanNode::Sort { .. }));
            }
            _ => panic!("expected Limit(Sort(SeqScan))"),
        }
    }

    #[test]
    fn test_plan_count() {
        let plan = plan("count(User)").unwrap();
        assert!(matches!(plan, PlanNode::Aggregate { .. }));
    }

    #[test]
    fn test_plan_eq_becomes_index_scan() {
        let plan = plan("User filter .id = 42").unwrap();
        match plan {
            PlanNode::IndexScan { table, column, key } => {
                assert_eq!(table, "User");
                assert_eq!(column, "id");
                assert!(matches!(key, Expr::Literal(Literal::Int(42))));
            }
            other => panic!("expected IndexScan, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_eq_reversed_becomes_index_scan() {
        let plan = plan(r#"User filter "NYC" = .city"#).unwrap();
        assert!(matches!(plan, PlanNode::IndexScan { .. }));
    }

    // A strict `>` comparison is pushed into a RangeScan with an exclusive lower bound.
    #[test]
    fn test_plan_gt_becomes_range_scan() {
        let plan = plan("User filter .age > 30").unwrap();
        match plan {
            PlanNode::RangeScan {
                column, start, end, ..
            } => {
                assert_eq!(column, "age");
                assert!(start.is_some(), "expected lower bound");
                assert!(end.is_none(), "expected no upper bound");
                let (_, inclusive) = start.unwrap();
                assert!(!inclusive, "expected exclusive lower bound for >");
            }
            other => panic!("expected RangeScan, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_index_scan_with_projection() {
        let plan = plan("User filter .id = 1 { .name }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                assert!(matches!(*input, PlanNode::IndexScan { .. }));
            }
            other => panic!("expected Project(IndexScan), got {other:?}"),
        }
    }

    // --- Update / delete access paths ------------------------------------------------

    #[test]
    fn test_plan_update_by_pk_becomes_index_scan() {
        let plan = plan("User filter .id = 42 update { age := 31 }").unwrap();
        match plan {
            PlanNode::Update { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::IndexScan { .. }),
                    "expected Update(IndexScan), got {input:?}"
                );
            }
            other => panic!("expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_update_range_stays_range_scan() {
        let plan = plan("User filter .age > 30 update { age := 31 }").unwrap();
        match plan {
            PlanNode::Update { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::RangeScan { .. }),
                    "expected Update(RangeScan), got {input:?}"
                );
            }
            other => panic!("expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_delete_by_pk_becomes_index_scan() {
        let plan = plan("User filter .id = 7 delete").unwrap();
        match plan {
            PlanNode::Delete { input, .. } => {
                assert!(matches!(*input, PlanNode::IndexScan { .. }));
            }
            other => panic!("expected Delete, got {other:?}"),
        }
    }

    // --- Joins -------------------------------------------------------------------------

    #[test]
    fn test_plan_inner_join_builds_nested_loop() {
        let plan = plan("User as u join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::NestedLoopJoin {
                left,
                right,
                on,
                kind,
            } => {
                assert_eq!(kind, JoinKind::Inner);
                assert!(on.is_some());
                assert!(matches!(*left, PlanNode::AliasScan { .. }));
                assert!(matches!(*right, PlanNode::AliasScan { .. }));
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_right_join_rewritten_as_left_with_swapped_inputs() {
        let plan = plan("User as u right join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::NestedLoopJoin {
                left, right, kind, ..
            } => {
                assert_eq!(kind, JoinKind::LeftOuter);
                match *left {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Order"),
                    other => panic!("expected AliasScan(Order), got {other:?}"),
                }
                match *right {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "User"),
                    other => panic!("expected AliasScan(User), got {other:?}"),
                }
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_multi_join_is_left_deep() {
        let plan = plan(
            "User as u join Order as o on u.id = o.user_id \
             join Product as p on o.product_id = p.id",
        )
        .unwrap();
        match plan {
            PlanNode::NestedLoopJoin { left, right, .. } => {
                match *right {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Product"),
                    other => panic!("expected AliasScan(Product), got {other:?}"),
                }
                assert!(matches!(*left, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_join_with_filter_tail_wraps_filter_on_top() {
        let plan =
            plan("User as u join Order as o on u.id = o.user_id filter o.total > 100").unwrap();
        match plan {
            PlanNode::Filter { input, .. } => {
                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected Filter(NestedLoopJoin), got {other:?}"),
        }
    }

    // --- Grouping and aggregate rewriting ----------------------------------------------

    #[test]
    fn test_plan_group_by_builds_groupby_node() {
        let plan = plan("User group .status { .status, n: count(.name) }").unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 2);
                match *input {
                    PlanNode::GroupBy {
                        input: inner,
                        keys,
                        aggregates,
                        having,
                    } => {
                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                        assert_eq!(keys, vec!["status"]);
                        assert_eq!(aggregates.len(), 1);
                        assert_eq!(aggregates[0].function, AggFunc::Count);
                        assert_eq!(aggregates[0].field, "name");
                        assert!(having.is_none());
                    }
                    other => panic!("expected GroupBy, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_group_by_having_rewrites_agg_in_having() {
        let plan = plan("User group .status having count(.name) > 1 { .status }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                match *input {
                    PlanNode::GroupBy {
                        having, aggregates, ..
                    } => {
                        assert_eq!(aggregates.len(), 1);
                        assert_eq!(aggregates[0].output_name, "__agg_0");
                        let h = having.expect("having should be Some");
                        match h {
                            Expr::BinaryOp(l, BinOp::Gt, _) => {
                                assert!(
                                    matches!(*l, Expr::Field(ref name) if name == "__agg_0"),
                                    "expected Field(__agg_0), got {l:?}"
                                );
                            }
                            other => panic!("expected BinaryOp, got {other:?}"),
                        }
                    }
                    other => panic!("expected GroupBy, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    // --- Window functions ----------------------------------------------------------------

    #[test]
    fn test_plan_window_inserts_window_node_before_project() {
        let plan = plan("User { .name, rn: row_number() over (order .age) }").unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 2);
                assert!(
                    matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"),
                    "expected Field(__win_0), got {:?}",
                    fields[1].expr
                );
                match *input {
                    PlanNode::Window {
                        input: inner,
                        windows,
                    } => {
                        assert_eq!(windows.len(), 1);
                        assert_eq!(windows[0].output_name, "__win_0");
                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                    }
                    other => panic!("expected Window, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_multiple_windows() {
        let plan = plan(
            "User { .name, rn: row_number() over (order .age), s: sum(.salary) over (partition .dept order .salary) }"
        ).unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 3);
                assert!(matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"));
                assert!(matches!(&fields[2].expr, Expr::Field(name) if name == "__win_1"));
                match *input {
                    PlanNode::Window { windows, .. } => {
                        assert_eq!(windows.len(), 2);
                        assert_eq!(windows[0].output_name, "__win_0");
                        assert_eq!(windows[1].output_name, "__win_1");
                    }
                    other => panic!("expected Window, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_no_window_without_over() {
        let plan = plan("User group .dept { .dept, total: sum(.salary) }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::GroupBy { .. }),
                    "expected GroupBy under Project, got {:?}",
                    input
                );
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    // --- Explain -------------------------------------------------------------------------

    #[test]
    fn test_plan_explain_wraps_inner() {
        let plan = plan("explain User filter .age > 30").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(
                    matches!(*input, PlanNode::RangeScan { .. }),
                    "expected Explain(RangeScan), got {:?}",
                    input
                );
            }
            other => panic!("expected Explain, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_simple_scan() {
        let plan = plan("explain User").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(matches!(*input, PlanNode::SeqScan { .. }));
            }
            other => panic!("expected Explain(SeqScan), got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_join() {
        let plan = plan("explain User as u join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected Explain(NestedLoopJoin), got {other:?}"),
        }
    }
}
1132}