1use crate::ast::*;
2use crate::parser::{parse, ParseError};
3use crate::plan::*;
4
5type RangeBound = (String, Option<(Expr, bool)>, Option<(Expr, bool)>);
7
8#[derive(Debug)]
10pub enum PlanError {
11 Parse(ParseError),
13}
14
15impl PlanError {
16 pub fn message(&self) -> String {
18 self.to_string()
19 }
20}
21
22impl std::fmt::Display for PlanError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::Parse(e) => write!(f, "{e}"),
26 }
27 }
28}
29
30impl std::error::Error for PlanError {
31 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32 match self {
33 Self::Parse(e) => Some(e),
34 }
35 }
36}
37
38impl From<ParseError> for PlanError {
39 fn from(e: ParseError) -> Self {
40 PlanError::Parse(e)
41 }
42}
43
44pub fn plan(input: &str) -> Result<PlanNode, PlanError> {
45 let stmt = parse(input)?;
46 plan_statement(stmt)
47}
48
49pub fn plan_statement(stmt: Statement) -> Result<PlanNode, PlanError> {
50 match stmt {
51 Statement::Query(q) => plan_query(q),
52 Statement::Insert(ins) => plan_insert(ins),
53 Statement::UpdateQuery(upd) => plan_update(upd),
54 Statement::DeleteQuery(del) => plan_delete(del),
55 Statement::CreateType(ct) => plan_create_type(ct),
56 Statement::AlterTable(at) => Ok(PlanNode::AlterTable {
57 table: at.table,
58 action: at.action,
59 }),
60 Statement::DropTable(dt) => Ok(PlanNode::DropTable { name: dt.table }),
61 Statement::CreateView(cv) => Ok(PlanNode::CreateView {
62 name: cv.name,
63 query_text: cv.query_text,
64 }),
65 Statement::RefreshView(rv) => Ok(PlanNode::RefreshView { name: rv.name }),
66 Statement::DropView(dv) => Ok(PlanNode::DropView { name: dv.name }),
67 Statement::Union(u) => {
68 let left = plan_statement(*u.left)?;
69 let right = plan_statement(*u.right)?;
70 Ok(PlanNode::Union {
71 left: Box::new(left),
72 right: Box::new(right),
73 all: u.all,
74 })
75 }
76 Statement::Upsert(ups) => plan_upsert(ups),
77 Statement::Begin => Ok(PlanNode::Begin),
78 Statement::Commit => Ok(PlanNode::Commit),
79 Statement::Rollback => Ok(PlanNode::Rollback),
80 Statement::Explain(inner) => {
81 let inner_plan = plan_statement(*inner)?;
82 Ok(PlanNode::Explain {
83 input: Box::new(inner_plan),
84 })
85 }
86 }
87}
88
89fn plan_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
90 if !q.joins.is_empty() {
96 return plan_joined_query(q);
97 }
98 let (source, filter) = match q.filter {
107 Some(pred) => match try_extract_eq_index_key(&q.source, &pred) {
108 Some(index_scan) => (index_scan, None),
109 None => match try_extract_range_index_keys(&q.source, &pred) {
110 Some(range_scan) => (range_scan, None),
111 None => (
112 PlanNode::SeqScan {
113 table: q.source.clone(),
114 },
115 Some(pred),
116 ),
117 },
118 },
119 None => (
120 PlanNode::SeqScan {
121 table: q.source.clone(),
122 },
123 None,
124 ),
125 };
126 let mut node = source;
127
128 if let Some(pred) = filter {
129 node = PlanNode::Filter {
130 input: Box::new(node),
131 predicate: pred,
132 };
133 }
134
135 if let Some(group) = q.group_by {
138 let mut proj_fields: Vec<ProjectField> = q
139 .projection
140 .map(|proj| {
141 proj.into_iter()
142 .map(|pf| ProjectField {
143 alias: pf.alias,
144 expr: pf.expr,
145 })
146 .collect()
147 })
148 .unwrap_or_default();
149 let mut having = group.having;
150 let aggregates = extract_aggregates(&mut proj_fields, &mut having);
151
152 node = PlanNode::GroupBy {
153 input: Box::new(node),
154 keys: group.keys,
155 aggregates,
156 having,
157 };
158
159 if !proj_fields.is_empty() {
160 node = PlanNode::Project {
161 input: Box::new(node),
162 fields: proj_fields,
163 };
164 }
165
166 if let Some(order) = q.order {
167 node = PlanNode::Sort {
168 input: Box::new(node),
169 keys: order
170 .keys
171 .into_iter()
172 .map(|k| SortKey {
173 field: k.field,
174 descending: k.descending,
175 })
176 .collect(),
177 };
178 }
179 if let Some(off) = q.offset {
183 node = PlanNode::Offset {
184 input: Box::new(node),
185 count: off,
186 };
187 }
188 if let Some(lim) = q.limit {
189 node = PlanNode::Limit {
190 input: Box::new(node),
191 count: lim,
192 };
193 }
194 if q.distinct {
195 node = PlanNode::Distinct {
196 input: Box::new(node),
197 };
198 }
199 return Ok(node);
200 }
201
202 if let Some(order) = q.order {
203 node = PlanNode::Sort {
204 input: Box::new(node),
205 keys: order
206 .keys
207 .into_iter()
208 .map(|k| SortKey {
209 field: k.field,
210 descending: k.descending,
211 })
212 .collect(),
213 };
214 }
215
216 if let Some(off) = q.offset {
220 node = PlanNode::Offset {
221 input: Box::new(node),
222 count: off,
223 };
224 }
225
226 if let Some(lim) = q.limit {
227 node = PlanNode::Limit {
228 input: Box::new(node),
229 count: lim,
230 };
231 }
232
233 if let Some(proj) = q.projection {
234 let mut fields: Vec<ProjectField> = proj
235 .into_iter()
236 .map(|pf| ProjectField {
237 alias: pf.alias,
238 expr: pf.expr,
239 })
240 .collect();
241 let windows = extract_windows(&mut fields);
242 if !windows.is_empty() {
243 node = PlanNode::Window {
244 input: Box::new(node),
245 windows,
246 };
247 }
248 node = PlanNode::Project {
249 input: Box::new(node),
250 fields,
251 };
252 }
253
254 if q.distinct {
255 node = PlanNode::Distinct {
256 input: Box::new(node),
257 };
258 }
259
260 if let Some(agg) = q.aggregation {
261 node = PlanNode::Aggregate {
262 input: Box::new(node),
263 function: agg.function,
264 field: agg.field,
265 };
266 }
267
268 Ok(node)
269}
270
271fn plan_joined_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
293 let primary_alias = q.alias.clone().unwrap_or_else(|| q.source.clone());
294 let mut node = PlanNode::AliasScan {
295 table: q.source.clone(),
296 alias: primary_alias,
297 };
298
299 for join in q.joins {
300 let right_alias = join.alias.unwrap_or_else(|| join.source.clone());
301 let right = PlanNode::AliasScan {
302 table: join.source,
303 alias: right_alias,
304 };
305 match join.kind {
306 JoinKind::Inner | JoinKind::LeftOuter | JoinKind::Cross => {
307 node = PlanNode::NestedLoopJoin {
308 left: Box::new(node),
309 right: Box::new(right),
310 on: join.on,
311 kind: join.kind,
312 };
313 }
314 JoinKind::RightOuter => {
315 node = PlanNode::NestedLoopJoin {
317 left: Box::new(right),
318 right: Box::new(node),
319 on: join.on,
320 kind: JoinKind::LeftOuter,
321 };
322 }
323 }
324 }
325
326 if let Some(pred) = q.filter {
327 node = PlanNode::Filter {
328 input: Box::new(node),
329 predicate: pred,
330 };
331 }
332
333 if let Some(order) = q.order {
334 node = PlanNode::Sort {
335 input: Box::new(node),
336 keys: order
337 .keys
338 .into_iter()
339 .map(|k| SortKey {
340 field: k.field,
341 descending: k.descending,
342 })
343 .collect(),
344 };
345 }
346
347 if let Some(off) = q.offset {
351 node = PlanNode::Offset {
352 input: Box::new(node),
353 count: off,
354 };
355 }
356
357 if let Some(lim) = q.limit {
358 node = PlanNode::Limit {
359 input: Box::new(node),
360 count: lim,
361 };
362 }
363
364 if let Some(group) = q.group_by {
366 let mut proj_fields: Vec<ProjectField> = q
367 .projection
368 .map(|proj| {
369 proj.into_iter()
370 .map(|pf| ProjectField {
371 alias: pf.alias,
372 expr: pf.expr,
373 })
374 .collect()
375 })
376 .unwrap_or_default();
377 let mut having = group.having;
378 let aggregates = extract_aggregates(&mut proj_fields, &mut having);
379
380 node = PlanNode::GroupBy {
381 input: Box::new(node),
382 keys: group.keys,
383 aggregates,
384 having,
385 };
386
387 if !proj_fields.is_empty() {
388 node = PlanNode::Project {
389 input: Box::new(node),
390 fields: proj_fields,
391 };
392 }
393 if q.distinct {
394 node = PlanNode::Distinct {
395 input: Box::new(node),
396 };
397 }
398 return Ok(node);
399 }
400
401 if let Some(proj) = q.projection {
402 let mut fields: Vec<ProjectField> = proj
403 .into_iter()
404 .map(|pf| ProjectField {
405 alias: pf.alias,
406 expr: pf.expr,
407 })
408 .collect();
409 let windows = extract_windows(&mut fields);
410 if !windows.is_empty() {
411 node = PlanNode::Window {
412 input: Box::new(node),
413 windows,
414 };
415 }
416 node = PlanNode::Project {
417 input: Box::new(node),
418 fields,
419 };
420 }
421
422 if q.distinct {
423 node = PlanNode::Distinct {
424 input: Box::new(node),
425 };
426 }
427
428 if let Some(agg) = q.aggregation {
429 node = PlanNode::Aggregate {
430 input: Box::new(node),
431 function: agg.function,
432 field: agg.field,
433 };
434 }
435
436 Ok(node)
437}
438
439fn plan_insert(ins: InsertExpr) -> Result<PlanNode, PlanError> {
440 Ok(PlanNode::Insert {
441 table: ins.target,
442 assignments: ins.assignments,
443 })
444}
445
446fn plan_update(upd: UpdateExpr) -> Result<PlanNode, PlanError> {
447 let source = match upd.filter {
452 Some(pred) => match try_extract_eq_index_key(&upd.source, &pred) {
453 Some(index_scan) => index_scan,
454 None => match try_extract_range_index_keys(&upd.source, &pred) {
455 Some(range_scan) => range_scan,
456 None => PlanNode::Filter {
457 input: Box::new(PlanNode::SeqScan {
458 table: upd.source.clone(),
459 }),
460 predicate: pred,
461 },
462 },
463 },
464 None => PlanNode::SeqScan {
465 table: upd.source.clone(),
466 },
467 };
468 Ok(PlanNode::Update {
469 input: Box::new(source),
470 table: upd.source,
471 assignments: upd.assignments,
472 })
473}
474
475fn plan_delete(del: DeleteExpr) -> Result<PlanNode, PlanError> {
476 let source = match del.filter {
477 Some(pred) => match try_extract_eq_index_key(&del.source, &pred) {
478 Some(index_scan) => index_scan,
479 None => match try_extract_range_index_keys(&del.source, &pred) {
480 Some(range_scan) => range_scan,
481 None => PlanNode::Filter {
482 input: Box::new(PlanNode::SeqScan {
483 table: del.source.clone(),
484 }),
485 predicate: pred,
486 },
487 },
488 },
489 None => PlanNode::SeqScan {
490 table: del.source.clone(),
491 },
492 };
493 Ok(PlanNode::Delete {
494 input: Box::new(source),
495 table: del.source,
496 })
497}
498
499fn plan_upsert(ups: UpsertExpr) -> Result<PlanNode, PlanError> {
500 Ok(PlanNode::Upsert {
501 table: ups.target,
502 key_column: ups.key_column,
503 assignments: ups.assignments,
504 on_conflict: ups.on_conflict,
505 })
506}
507
508fn plan_create_type(ct: CreateTypeExpr) -> Result<PlanNode, PlanError> {
509 let fields = ct
510 .fields
511 .into_iter()
512 .map(|f| (f.name, f.type_name, f.required))
513 .collect();
514 Ok(PlanNode::CreateTable {
515 name: ct.name,
516 fields,
517 })
518}
519
520fn try_extract_eq_index_key(table: &str, pred: &Expr) -> Option<PlanNode> {
530 let (lhs, op, rhs) = match pred {
531 Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
532 _ => return None,
533 };
534 if op != BinOp::Eq {
535 return None;
536 }
537 let (column, key) = match (lhs, rhs) {
538 (Expr::Field(name), Expr::Literal(_)) => (name.clone(), rhs.clone()),
539 (Expr::Literal(_), Expr::Field(name)) => (name.clone(), lhs.clone()),
540 _ => return None,
541 };
542 Some(PlanNode::IndexScan {
543 table: table.to_string(),
544 column,
545 key,
546 })
547}
548
549fn extract_single_bound(pred: &Expr) -> Option<RangeBound> {
552 let (lhs, op, rhs) = match pred {
553 Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
554 _ => return None,
555 };
556 match op {
557 BinOp::Gt => match (lhs, rhs) {
559 (Expr::Field(name), Expr::Literal(_)) => {
560 Some((name.clone(), Some((rhs.clone(), false)), None))
561 }
562 (Expr::Literal(_), Expr::Field(name)) => {
563 Some((name.clone(), None, Some((lhs.clone(), false))))
565 }
566 _ => None,
567 },
568 BinOp::Gte => match (lhs, rhs) {
570 (Expr::Field(name), Expr::Literal(_)) => {
571 Some((name.clone(), Some((rhs.clone(), true)), None))
572 }
573 (Expr::Literal(_), Expr::Field(name)) => {
574 Some((name.clone(), None, Some((lhs.clone(), true))))
575 }
576 _ => None,
577 },
578 BinOp::Lt => match (lhs, rhs) {
580 (Expr::Field(name), Expr::Literal(_)) => {
581 Some((name.clone(), None, Some((rhs.clone(), false))))
582 }
583 (Expr::Literal(_), Expr::Field(name)) => {
584 Some((name.clone(), Some((lhs.clone(), false)), None))
585 }
586 _ => None,
587 },
588 BinOp::Lte => match (lhs, rhs) {
590 (Expr::Field(name), Expr::Literal(_)) => {
591 Some((name.clone(), None, Some((rhs.clone(), true))))
592 }
593 (Expr::Literal(_), Expr::Field(name)) => {
594 Some((name.clone(), Some((lhs.clone(), true)), None))
595 }
596 _ => None,
597 },
598 _ => None,
599 }
600}
601
602fn try_extract_range_index_keys(table: &str, pred: &Expr) -> Option<PlanNode> {
607 if let Expr::BinaryOp(lhs, BinOp::And, rhs) = pred {
609 if let (Some((col1, s1, e1)), Some((col2, s2, e2))) =
610 (extract_single_bound(lhs), extract_single_bound(rhs))
611 {
612 if col1 == col2 {
613 let start = s1.or(s2);
614 let end = e1.or(e2);
615 if start.is_some() || end.is_some() {
616 return Some(PlanNode::RangeScan {
617 table: table.to_string(),
618 column: col1,
619 start,
620 end,
621 });
622 }
623 }
624 }
625 }
626
627 if let Some((col, start, end)) = extract_single_bound(pred) {
629 return Some(PlanNode::RangeScan {
630 table: table.to_string(),
631 column: col,
632 start,
633 end,
634 });
635 }
636
637 None
638}
639
640fn extract_windows(proj_fields: &mut [ProjectField]) -> Vec<WindowDef> {
645 let mut defs = Vec::new();
646 let mut counter = 0usize;
647 for f in proj_fields.iter_mut() {
648 if let Expr::Window {
649 function,
650 args,
651 partition_by,
652 order_by,
653 } = &f.expr
654 {
655 let output_name = format!("__win_{counter}");
656 defs.push(WindowDef {
657 function: *function,
658 args: args.clone(),
659 partition_by: partition_by.clone(),
660 order_by: order_by
661 .iter()
662 .map(|k| SortKey {
663 field: k.field.clone(),
664 descending: k.descending,
665 })
666 .collect(),
667 output_name: output_name.clone(),
668 });
669 f.expr = Expr::Field(output_name);
670 counter += 1;
671 }
672 }
673 defs
674}
675
676fn extract_aggregates(
682 proj_fields: &mut [ProjectField],
683 having: &mut Option<Expr>,
684) -> Vec<GroupAgg> {
685 let mut aggs: Vec<GroupAgg> = Vec::new();
686 let mut counter = 0usize;
687 for f in proj_fields.iter_mut() {
688 rewrite_agg_expr(&mut f.expr, &mut aggs, &mut counter);
689 }
690 if let Some(h) = having {
691 rewrite_agg_expr(h, &mut aggs, &mut counter);
692 }
693 aggs
694}
695
696fn rewrite_agg_expr(expr: &mut Expr, aggs: &mut Vec<GroupAgg>, counter: &mut usize) {
697 match expr {
698 Expr::FunctionCall(func, inner) => {
699 if let Expr::Field(name) = inner.as_ref() {
700 let output = find_or_insert_agg(aggs, *func, name, counter);
701 *expr = Expr::Field(output);
702 }
703 }
704 Expr::BinaryOp(l, _, r) => {
705 rewrite_agg_expr(l, aggs, counter);
706 rewrite_agg_expr(r, aggs, counter);
707 }
708 Expr::UnaryOp(_, inner) => rewrite_agg_expr(inner, aggs, counter),
709 Expr::Coalesce(l, r) => {
710 rewrite_agg_expr(l, aggs, counter);
711 rewrite_agg_expr(r, aggs, counter);
712 }
713 Expr::InList { expr: e, list, .. } => {
714 rewrite_agg_expr(e, aggs, counter);
715 for item in list {
716 rewrite_agg_expr(item, aggs, counter);
717 }
718 }
719 Expr::InSubquery { expr: e, .. } => {
720 rewrite_agg_expr(e, aggs, counter);
721 }
722 _ => {}
723 }
724}
725
726fn find_or_insert_agg(
727 aggs: &mut Vec<GroupAgg>,
728 func: AggFunc,
729 field: &str,
730 counter: &mut usize,
731) -> String {
732 for existing in aggs.iter() {
733 if existing.function == func && existing.field == field {
734 return existing.output_name.clone();
735 }
736 }
737 let output_name = format!("__agg_{counter}");
738 aggs.push(GroupAgg {
739 function: func,
740 field: field.to_string(),
741 output_name: output_name.clone(),
742 });
743 *counter += 1;
744 output_name
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plan::PlanNode;

    // ---- Access-path selection and simple pipelines --------------------------------

    #[test]
    fn test_plan_simple_scan() {
        let plan = plan("User").unwrap();
        assert!(matches!(plan, PlanNode::SeqScan { table } if table == "User"));
    }

    #[test]
    fn test_plan_filter() {
        let plan = plan("User filter .age > 30").unwrap();
        assert!(matches!(plan, PlanNode::RangeScan { .. }));
    }

    #[test]
    fn test_plan_filter_with_projection() {
        let plan = plan("User filter .age > 30 { name, email }").unwrap();
        assert!(matches!(plan, PlanNode::Project { .. }));
    }

    #[test]
    fn test_plan_insert() {
        let plan = plan(r#"insert User { name := "Alice", age := 30 }"#).unwrap();
        assert!(matches!(plan, PlanNode::Insert { .. }));
    }

    #[test]
    fn test_plan_order_limit() {
        let plan = plan("User order .name limit 10").unwrap();
        match plan {
            PlanNode::Limit { input, .. } => match *input {
                PlanNode::Sort { input: inner, .. } => {
                    assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                }
                other => panic!("expected Sort under Limit, got {other:?}"),
            },
            other => panic!("expected Limit(Sort(SeqScan)), got {other:?}"),
        }
    }

    #[test]
    fn test_plan_count() {
        let plan = plan("count(User)").unwrap();
        assert!(matches!(plan, PlanNode::Aggregate { .. }));
    }

    #[test]
    fn test_plan_eq_becomes_index_scan() {
        let plan = plan("User filter .id = 42").unwrap();
        match plan {
            PlanNode::IndexScan { table, column, key } => {
                assert_eq!(table, "User");
                assert_eq!(column, "id");
                assert!(matches!(key, Expr::Literal(Literal::Int(42))));
            }
            other => panic!("expected IndexScan, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_eq_reversed_becomes_index_scan() {
        let plan = plan(r#"User filter "NYC" = .city"#).unwrap();
        assert!(matches!(plan, PlanNode::IndexScan { .. }));
    }

    // A single comparison is absorbed into a one-sided RangeScan (no Filter remains).
    #[test]
    fn test_plan_range_predicate_becomes_range_scan() {
        let plan = plan("User filter .age > 30").unwrap();
        match plan {
            PlanNode::RangeScan {
                column, start, end, ..
            } => {
                assert_eq!(column, "age");
                assert!(start.is_some(), "expected lower bound");
                assert!(end.is_none(), "expected no upper bound");
                let (_, inclusive) = start.unwrap();
                assert!(!inclusive, "expected exclusive lower bound for >");
            }
            other => panic!("expected RangeScan, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_index_scan_with_projection() {
        let plan = plan("User filter .id = 1 { .name }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                assert!(matches!(*input, PlanNode::IndexScan { .. }));
            }
            other => panic!("expected Project(IndexScan), got {other:?}"),
        }
    }

    // ---- Data modification ---------------------------------------------------------

    #[test]
    fn test_plan_update_by_pk_becomes_index_scan() {
        let plan = plan("User filter .id = 42 update { age := 31 }").unwrap();
        match plan {
            PlanNode::Update { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::IndexScan { .. }),
                    "expected Update(IndexScan), got {input:?}"
                );
            }
            other => panic!("expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_update_range_stays_range_scan() {
        let plan = plan("User filter .age > 30 update { age := 31 }").unwrap();
        match plan {
            PlanNode::Update { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::RangeScan { .. }),
                    "expected Update(RangeScan), got {input:?}"
                );
            }
            other => panic!("expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_delete_by_pk_becomes_index_scan() {
        let plan = plan("User filter .id = 7 delete").unwrap();
        match plan {
            PlanNode::Delete { input, .. } => {
                assert!(matches!(*input, PlanNode::IndexScan { .. }));
            }
            other => panic!("expected Delete, got {other:?}"),
        }
    }

    // ---- Joins -----------------------------------------------------------------------

    #[test]
    fn test_plan_inner_join_builds_nested_loop() {
        let plan = plan("User as u join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::NestedLoopJoin {
                left,
                right,
                on,
                kind,
            } => {
                assert_eq!(kind, JoinKind::Inner);
                assert!(on.is_some());
                assert!(matches!(*left, PlanNode::AliasScan { .. }));
                assert!(matches!(*right, PlanNode::AliasScan { .. }));
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_right_join_rewritten_as_left_with_swapped_inputs() {
        let plan = plan("User as u right join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::NestedLoopJoin {
                left, right, kind, ..
            } => {
                assert_eq!(kind, JoinKind::LeftOuter);
                match *left {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Order"),
                    other => panic!("expected AliasScan(Order), got {other:?}"),
                }
                match *right {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "User"),
                    other => panic!("expected AliasScan(User), got {other:?}"),
                }
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_multi_join_is_left_deep() {
        let plan = plan(
            "User as u join Order as o on u.id = o.user_id \
             join Product as p on o.product_id = p.id",
        )
        .unwrap();
        match plan {
            PlanNode::NestedLoopJoin { left, right, .. } => {
                match *right {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Product"),
                    other => panic!("expected AliasScan(Product), got {other:?}"),
                }
                assert!(matches!(*left, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_join_with_filter_tail_wraps_filter_on_top() {
        let plan =
            plan("User as u join Order as o on u.id = o.user_id filter o.total > 100").unwrap();
        match plan {
            PlanNode::Filter { input, .. } => {
                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected Filter(NestedLoopJoin), got {other:?}"),
        }
    }

    // ---- Grouping and aggregates ---------------------------------------------------------

    #[test]
    fn test_plan_group_by_builds_groupby_node() {
        let plan = plan("User group .status { .status, n: count(.name) }").unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 2);
                match *input {
                    PlanNode::GroupBy {
                        input: inner,
                        keys,
                        aggregates,
                        having,
                    } => {
                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                        assert_eq!(keys, vec!["status"]);
                        assert_eq!(aggregates.len(), 1);
                        assert_eq!(aggregates[0].function, AggFunc::Count);
                        assert_eq!(aggregates[0].field, "name");
                        assert!(having.is_none());
                    }
                    other => panic!("expected GroupBy, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_group_by_having_rewrites_agg_in_having() {
        let plan = plan("User group .status having count(.name) > 1 { .status }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => match *input {
                PlanNode::GroupBy {
                    having, aggregates, ..
                } => {
                    assert_eq!(aggregates.len(), 1);
                    assert_eq!(aggregates[0].output_name, "__agg_0");
                    let h = having.expect("having should be Some");
                    match h {
                        Expr::BinaryOp(l, BinOp::Gt, _) => {
                            assert!(
                                matches!(*l, Expr::Field(ref name) if name == "__agg_0"),
                                "expected Field(__agg_0), got {l:?}"
                            );
                        }
                        other => panic!("expected BinaryOp, got {other:?}"),
                    }
                }
                other => panic!("expected GroupBy, got {other:?}"),
            },
            other => panic!("expected Project, got {other:?}"),
        }
    }

    // ---- Window functions ------------------------------------------------------------------

    #[test]
    fn test_plan_window_inserts_window_node_before_project() {
        let plan = plan("User { .name, rn: row_number() over (order .age) }").unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 2);
                assert!(
                    matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"),
                    "expected Field(__win_0), got {:?}",
                    fields[1].expr
                );
                match *input {
                    PlanNode::Window {
                        input: inner,
                        windows,
                    } => {
                        assert_eq!(windows.len(), 1);
                        assert_eq!(windows[0].output_name, "__win_0");
                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                    }
                    other => panic!("expected Window, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_multiple_windows() {
        let plan = plan(
            "User { .name, rn: row_number() over (order .age), s: sum(.salary) over (partition .dept order .salary) }"
        ).unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 3);
                assert!(matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"));
                assert!(matches!(&fields[2].expr, Expr::Field(name) if name == "__win_1"));
                match *input {
                    PlanNode::Window { windows, .. } => {
                        assert_eq!(windows.len(), 2);
                        assert_eq!(windows[0].output_name, "__win_0");
                        assert_eq!(windows[1].output_name, "__win_1");
                    }
                    other => panic!("expected Window, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_no_window_without_over() {
        let plan = plan("User group .dept { .dept, total: sum(.salary) }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::GroupBy { .. }),
                    "expected GroupBy under Project, got {:?}",
                    input
                );
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    // ---- Explain ---------------------------------------------------------------------------

    #[test]
    fn test_plan_explain_wraps_inner() {
        let plan = plan("explain User filter .age > 30").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(
                    matches!(*input, PlanNode::RangeScan { .. }),
                    "expected Explain(RangeScan), got {:?}",
                    input
                );
            }
            other => panic!("expected Explain, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_simple_scan() {
        let plan = plan("explain User").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(matches!(*input, PlanNode::SeqScan { .. }));
            }
            other => panic!("expected Explain(SeqScan), got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_join() {
        let plan = plan("explain User as u join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected Explain(NestedLoopJoin), got {other:?}"),
        }
    }
}
1127}