1use crate::ast::*;
2use crate::parser::{parse, ParseError};
3use crate::plan::*;
4
/// A single-column range constraint produced by `extract_single_bound`:
/// `(column, lower_bound, upper_bound)`.
///
/// Each bound is `Some((literal_expr, inclusive))`, or `None` when that side is
/// unbounded. `inclusive` is `true` for `>=` / `<=` and `false` for `>` / `<`.
type RangeBound = (String, Option<(Expr, bool)>, Option<(Expr, bool)>);
7
8#[derive(Debug)]
10pub enum PlanError {
11 Parse(ParseError),
13}
14
15impl PlanError {
16 pub fn message(&self) -> String {
18 self.to_string()
19 }
20}
21
22impl std::fmt::Display for PlanError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::Parse(e) => write!(f, "{e}"),
26 }
27 }
28}
29
30impl std::error::Error for PlanError {
31 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32 match self {
33 Self::Parse(e) => Some(e),
34 }
35 }
36}
37
38impl From<ParseError> for PlanError {
39 fn from(e: ParseError) -> Self {
40 PlanError::Parse(e)
41 }
42}
43
44pub fn plan(input: &str) -> Result<PlanNode, PlanError> {
45 let stmt = parse(input)?;
46 plan_statement(stmt)
47}
48
49pub fn plan_statement(stmt: Statement) -> Result<PlanNode, PlanError> {
50 match stmt {
51 Statement::Query(q) => plan_query(q),
52 Statement::Insert(ins) => plan_insert(ins),
53 Statement::UpdateQuery(upd) => plan_update(upd),
54 Statement::DeleteQuery(del) => plan_delete(del),
55 Statement::CreateType(ct) => plan_create_type(ct),
56 Statement::AlterTable(at) => Ok(PlanNode::AlterTable {
57 table: at.table,
58 action: at.action,
59 }),
60 Statement::DropTable(dt) => Ok(PlanNode::DropTable { name: dt.table }),
61 Statement::CreateView(cv) => Ok(PlanNode::CreateView {
62 name: cv.name,
63 query_text: cv.query_text,
64 }),
65 Statement::RefreshView(rv) => Ok(PlanNode::RefreshView { name: rv.name }),
66 Statement::DropView(dv) => Ok(PlanNode::DropView { name: dv.name }),
67 Statement::Union(u) => {
68 let left = plan_statement(*u.left)?;
69 let right = plan_statement(*u.right)?;
70 Ok(PlanNode::Union {
71 left: Box::new(left),
72 right: Box::new(right),
73 all: u.all,
74 })
75 }
76 Statement::Upsert(ups) => plan_upsert(ups),
77 Statement::Explain(inner) => {
78 let inner_plan = plan_statement(*inner)?;
79 Ok(PlanNode::Explain {
80 input: Box::new(inner_plan),
81 })
82 }
83 }
84}
85
86fn plan_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
87 if !q.joins.is_empty() {
93 return plan_joined_query(q);
94 }
95 let (source, filter) = match q.filter {
104 Some(pred) => match try_extract_eq_index_key(&q.source, &pred) {
105 Some(index_scan) => (index_scan, None),
106 None => match try_extract_range_index_keys(&q.source, &pred) {
107 Some(range_scan) => (range_scan, None),
108 None => (
109 PlanNode::SeqScan {
110 table: q.source.clone(),
111 },
112 Some(pred),
113 ),
114 },
115 },
116 None => (
117 PlanNode::SeqScan {
118 table: q.source.clone(),
119 },
120 None,
121 ),
122 };
123 let mut node = source;
124
125 if let Some(pred) = filter {
126 node = PlanNode::Filter {
127 input: Box::new(node),
128 predicate: pred,
129 };
130 }
131
132 if let Some(group) = q.group_by {
135 let mut proj_fields: Vec<ProjectField> = q
136 .projection
137 .map(|proj| {
138 proj.into_iter()
139 .map(|pf| ProjectField {
140 alias: pf.alias,
141 expr: pf.expr,
142 })
143 .collect()
144 })
145 .unwrap_or_default();
146 let mut having = group.having;
147 let aggregates = extract_aggregates(&mut proj_fields, &mut having);
148
149 node = PlanNode::GroupBy {
150 input: Box::new(node),
151 keys: group.keys,
152 aggregates,
153 having,
154 };
155
156 if !proj_fields.is_empty() {
157 node = PlanNode::Project {
158 input: Box::new(node),
159 fields: proj_fields,
160 };
161 }
162
163 if let Some(order) = q.order {
164 node = PlanNode::Sort {
165 input: Box::new(node),
166 keys: order
167 .keys
168 .into_iter()
169 .map(|k| SortKey {
170 field: k.field,
171 descending: k.descending,
172 })
173 .collect(),
174 };
175 }
176 if let Some(off) = q.offset {
180 node = PlanNode::Offset {
181 input: Box::new(node),
182 count: off,
183 };
184 }
185 if let Some(lim) = q.limit {
186 node = PlanNode::Limit {
187 input: Box::new(node),
188 count: lim,
189 };
190 }
191 if q.distinct {
192 node = PlanNode::Distinct {
193 input: Box::new(node),
194 };
195 }
196 return Ok(node);
197 }
198
199 if let Some(order) = q.order {
200 node = PlanNode::Sort {
201 input: Box::new(node),
202 keys: order
203 .keys
204 .into_iter()
205 .map(|k| SortKey {
206 field: k.field,
207 descending: k.descending,
208 })
209 .collect(),
210 };
211 }
212
213 if let Some(off) = q.offset {
217 node = PlanNode::Offset {
218 input: Box::new(node),
219 count: off,
220 };
221 }
222
223 if let Some(lim) = q.limit {
224 node = PlanNode::Limit {
225 input: Box::new(node),
226 count: lim,
227 };
228 }
229
230 if let Some(proj) = q.projection {
231 let mut fields: Vec<ProjectField> = proj
232 .into_iter()
233 .map(|pf| ProjectField {
234 alias: pf.alias,
235 expr: pf.expr,
236 })
237 .collect();
238 let windows = extract_windows(&mut fields);
239 if !windows.is_empty() {
240 node = PlanNode::Window {
241 input: Box::new(node),
242 windows,
243 };
244 }
245 node = PlanNode::Project {
246 input: Box::new(node),
247 fields,
248 };
249 }
250
251 if q.distinct {
252 node = PlanNode::Distinct {
253 input: Box::new(node),
254 };
255 }
256
257 if let Some(agg) = q.aggregation {
258 node = PlanNode::Aggregate {
259 input: Box::new(node),
260 function: agg.function,
261 field: agg.field,
262 };
263 }
264
265 Ok(node)
266}
267
268fn plan_joined_query(q: QueryExpr) -> Result<PlanNode, PlanError> {
290 let primary_alias = q.alias.clone().unwrap_or_else(|| q.source.clone());
291 let mut node = PlanNode::AliasScan {
292 table: q.source.clone(),
293 alias: primary_alias,
294 };
295
296 for join in q.joins {
297 let right_alias = join.alias.unwrap_or_else(|| join.source.clone());
298 let right = PlanNode::AliasScan {
299 table: join.source,
300 alias: right_alias,
301 };
302 match join.kind {
303 JoinKind::Inner | JoinKind::LeftOuter | JoinKind::Cross => {
304 node = PlanNode::NestedLoopJoin {
305 left: Box::new(node),
306 right: Box::new(right),
307 on: join.on,
308 kind: join.kind,
309 };
310 }
311 JoinKind::RightOuter => {
312 node = PlanNode::NestedLoopJoin {
314 left: Box::new(right),
315 right: Box::new(node),
316 on: join.on,
317 kind: JoinKind::LeftOuter,
318 };
319 }
320 }
321 }
322
323 if let Some(pred) = q.filter {
324 node = PlanNode::Filter {
325 input: Box::new(node),
326 predicate: pred,
327 };
328 }
329
330 if let Some(order) = q.order {
331 node = PlanNode::Sort {
332 input: Box::new(node),
333 keys: order
334 .keys
335 .into_iter()
336 .map(|k| SortKey {
337 field: k.field,
338 descending: k.descending,
339 })
340 .collect(),
341 };
342 }
343
344 if let Some(off) = q.offset {
348 node = PlanNode::Offset {
349 input: Box::new(node),
350 count: off,
351 };
352 }
353
354 if let Some(lim) = q.limit {
355 node = PlanNode::Limit {
356 input: Box::new(node),
357 count: lim,
358 };
359 }
360
361 if let Some(group) = q.group_by {
363 let mut proj_fields: Vec<ProjectField> = q
364 .projection
365 .map(|proj| {
366 proj.into_iter()
367 .map(|pf| ProjectField {
368 alias: pf.alias,
369 expr: pf.expr,
370 })
371 .collect()
372 })
373 .unwrap_or_default();
374 let mut having = group.having;
375 let aggregates = extract_aggregates(&mut proj_fields, &mut having);
376
377 node = PlanNode::GroupBy {
378 input: Box::new(node),
379 keys: group.keys,
380 aggregates,
381 having,
382 };
383
384 if !proj_fields.is_empty() {
385 node = PlanNode::Project {
386 input: Box::new(node),
387 fields: proj_fields,
388 };
389 }
390 if q.distinct {
391 node = PlanNode::Distinct {
392 input: Box::new(node),
393 };
394 }
395 return Ok(node);
396 }
397
398 if let Some(proj) = q.projection {
399 let mut fields: Vec<ProjectField> = proj
400 .into_iter()
401 .map(|pf| ProjectField {
402 alias: pf.alias,
403 expr: pf.expr,
404 })
405 .collect();
406 let windows = extract_windows(&mut fields);
407 if !windows.is_empty() {
408 node = PlanNode::Window {
409 input: Box::new(node),
410 windows,
411 };
412 }
413 node = PlanNode::Project {
414 input: Box::new(node),
415 fields,
416 };
417 }
418
419 if q.distinct {
420 node = PlanNode::Distinct {
421 input: Box::new(node),
422 };
423 }
424
425 if let Some(agg) = q.aggregation {
426 node = PlanNode::Aggregate {
427 input: Box::new(node),
428 function: agg.function,
429 field: agg.field,
430 };
431 }
432
433 Ok(node)
434}
435
436fn plan_insert(ins: InsertExpr) -> Result<PlanNode, PlanError> {
437 Ok(PlanNode::Insert {
438 table: ins.target,
439 assignments: ins.assignments,
440 })
441}
442
443fn plan_update(upd: UpdateExpr) -> Result<PlanNode, PlanError> {
444 let source = match upd.filter {
449 Some(pred) => match try_extract_eq_index_key(&upd.source, &pred) {
450 Some(index_scan) => index_scan,
451 None => match try_extract_range_index_keys(&upd.source, &pred) {
452 Some(range_scan) => range_scan,
453 None => PlanNode::Filter {
454 input: Box::new(PlanNode::SeqScan {
455 table: upd.source.clone(),
456 }),
457 predicate: pred,
458 },
459 },
460 },
461 None => PlanNode::SeqScan {
462 table: upd.source.clone(),
463 },
464 };
465 Ok(PlanNode::Update {
466 input: Box::new(source),
467 table: upd.source,
468 assignments: upd.assignments,
469 })
470}
471
472fn plan_delete(del: DeleteExpr) -> Result<PlanNode, PlanError> {
473 let source = match del.filter {
474 Some(pred) => match try_extract_eq_index_key(&del.source, &pred) {
475 Some(index_scan) => index_scan,
476 None => match try_extract_range_index_keys(&del.source, &pred) {
477 Some(range_scan) => range_scan,
478 None => PlanNode::Filter {
479 input: Box::new(PlanNode::SeqScan {
480 table: del.source.clone(),
481 }),
482 predicate: pred,
483 },
484 },
485 },
486 None => PlanNode::SeqScan {
487 table: del.source.clone(),
488 },
489 };
490 Ok(PlanNode::Delete {
491 input: Box::new(source),
492 table: del.source,
493 })
494}
495
496fn plan_upsert(ups: UpsertExpr) -> Result<PlanNode, PlanError> {
497 Ok(PlanNode::Upsert {
498 table: ups.target,
499 key_column: ups.key_column,
500 assignments: ups.assignments,
501 on_conflict: ups.on_conflict,
502 })
503}
504
505fn plan_create_type(ct: CreateTypeExpr) -> Result<PlanNode, PlanError> {
506 let fields = ct
507 .fields
508 .into_iter()
509 .map(|f| (f.name, f.type_name, f.required))
510 .collect();
511 Ok(PlanNode::CreateTable {
512 name: ct.name,
513 fields,
514 })
515}
516
517fn try_extract_eq_index_key(table: &str, pred: &Expr) -> Option<PlanNode> {
527 let (lhs, op, rhs) = match pred {
528 Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
529 _ => return None,
530 };
531 if op != BinOp::Eq {
532 return None;
533 }
534 let (column, key) = match (lhs, rhs) {
535 (Expr::Field(name), Expr::Literal(_)) => (name.clone(), rhs.clone()),
536 (Expr::Literal(_), Expr::Field(name)) => (name.clone(), lhs.clone()),
537 _ => return None,
538 };
539 Some(PlanNode::IndexScan {
540 table: table.to_string(),
541 column,
542 key,
543 })
544}
545
546fn extract_single_bound(pred: &Expr) -> Option<RangeBound> {
549 let (lhs, op, rhs) = match pred {
550 Expr::BinaryOp(lhs, op, rhs) => (lhs.as_ref(), *op, rhs.as_ref()),
551 _ => return None,
552 };
553 match op {
554 BinOp::Gt => match (lhs, rhs) {
556 (Expr::Field(name), Expr::Literal(_)) => {
557 Some((name.clone(), Some((rhs.clone(), false)), None))
558 }
559 (Expr::Literal(_), Expr::Field(name)) => {
560 Some((name.clone(), None, Some((lhs.clone(), false))))
562 }
563 _ => None,
564 },
565 BinOp::Gte => match (lhs, rhs) {
567 (Expr::Field(name), Expr::Literal(_)) => {
568 Some((name.clone(), Some((rhs.clone(), true)), None))
569 }
570 (Expr::Literal(_), Expr::Field(name)) => {
571 Some((name.clone(), None, Some((lhs.clone(), true))))
572 }
573 _ => None,
574 },
575 BinOp::Lt => match (lhs, rhs) {
577 (Expr::Field(name), Expr::Literal(_)) => {
578 Some((name.clone(), None, Some((rhs.clone(), false))))
579 }
580 (Expr::Literal(_), Expr::Field(name)) => {
581 Some((name.clone(), Some((lhs.clone(), false)), None))
582 }
583 _ => None,
584 },
585 BinOp::Lte => match (lhs, rhs) {
587 (Expr::Field(name), Expr::Literal(_)) => {
588 Some((name.clone(), None, Some((rhs.clone(), true))))
589 }
590 (Expr::Literal(_), Expr::Field(name)) => {
591 Some((name.clone(), Some((lhs.clone(), true)), None))
592 }
593 _ => None,
594 },
595 _ => None,
596 }
597}
598
599fn try_extract_range_index_keys(table: &str, pred: &Expr) -> Option<PlanNode> {
604 if let Expr::BinaryOp(lhs, BinOp::And, rhs) = pred {
606 if let (Some((col1, s1, e1)), Some((col2, s2, e2))) =
607 (extract_single_bound(lhs), extract_single_bound(rhs))
608 {
609 if col1 == col2 {
610 let start = s1.or(s2);
611 let end = e1.or(e2);
612 if start.is_some() || end.is_some() {
613 return Some(PlanNode::RangeScan {
614 table: table.to_string(),
615 column: col1,
616 start,
617 end,
618 });
619 }
620 }
621 }
622 }
623
624 if let Some((col, start, end)) = extract_single_bound(pred) {
626 return Some(PlanNode::RangeScan {
627 table: table.to_string(),
628 column: col,
629 start,
630 end,
631 });
632 }
633
634 None
635}
636
637fn extract_windows(proj_fields: &mut [ProjectField]) -> Vec<WindowDef> {
642 let mut defs = Vec::new();
643 let mut counter = 0usize;
644 for f in proj_fields.iter_mut() {
645 if let Expr::Window {
646 function,
647 args,
648 partition_by,
649 order_by,
650 } = &f.expr
651 {
652 let output_name = format!("__win_{counter}");
653 defs.push(WindowDef {
654 function: *function,
655 args: args.clone(),
656 partition_by: partition_by.clone(),
657 order_by: order_by
658 .iter()
659 .map(|k| SortKey {
660 field: k.field.clone(),
661 descending: k.descending,
662 })
663 .collect(),
664 output_name: output_name.clone(),
665 });
666 f.expr = Expr::Field(output_name);
667 counter += 1;
668 }
669 }
670 defs
671}
672
673fn extract_aggregates(
679 proj_fields: &mut [ProjectField],
680 having: &mut Option<Expr>,
681) -> Vec<GroupAgg> {
682 let mut aggs: Vec<GroupAgg> = Vec::new();
683 let mut counter = 0usize;
684 for f in proj_fields.iter_mut() {
685 rewrite_agg_expr(&mut f.expr, &mut aggs, &mut counter);
686 }
687 if let Some(h) = having {
688 rewrite_agg_expr(h, &mut aggs, &mut counter);
689 }
690 aggs
691}
692
693fn rewrite_agg_expr(expr: &mut Expr, aggs: &mut Vec<GroupAgg>, counter: &mut usize) {
694 match expr {
695 Expr::FunctionCall(func, inner) => {
696 if let Expr::Field(name) = inner.as_ref() {
697 let output = find_or_insert_agg(aggs, *func, name, counter);
698 *expr = Expr::Field(output);
699 }
700 }
701 Expr::BinaryOp(l, _, r) => {
702 rewrite_agg_expr(l, aggs, counter);
703 rewrite_agg_expr(r, aggs, counter);
704 }
705 Expr::UnaryOp(_, inner) => rewrite_agg_expr(inner, aggs, counter),
706 Expr::Coalesce(l, r) => {
707 rewrite_agg_expr(l, aggs, counter);
708 rewrite_agg_expr(r, aggs, counter);
709 }
710 Expr::InList { expr: e, list, .. } => {
711 rewrite_agg_expr(e, aggs, counter);
712 for item in list {
713 rewrite_agg_expr(item, aggs, counter);
714 }
715 }
716 Expr::InSubquery { expr: e, .. } => {
717 rewrite_agg_expr(e, aggs, counter);
718 }
719 _ => {}
720 }
721}
722
723fn find_or_insert_agg(
724 aggs: &mut Vec<GroupAgg>,
725 func: AggFunc,
726 field: &str,
727 counter: &mut usize,
728) -> String {
729 for existing in aggs.iter() {
730 if existing.function == func && existing.field == field {
731 return existing.output_name.clone();
732 }
733 }
734 let output_name = format!("__agg_{counter}");
735 aggs.push(GroupAgg {
736 function: func,
737 field: field.to_string(),
738 output_name: output_name.clone(),
739 });
740 *counter += 1;
741 output_name
742}
743
#[cfg(test)]
mod tests {
    //! Planner unit tests. Each test plans a query string and asserts the shape
    //! of the resulting tree: which node is on top and what it wraps.

    use super::*;
    use crate::plan::PlanNode;

    #[test]
    fn test_plan_simple_scan() {
        let plan = plan("User").unwrap();
        assert!(matches!(plan, PlanNode::SeqScan { table } if table == "User"));
    }

    #[test]
    fn test_plan_filter() {
        // A column-vs-literal comparison is absorbed into a range scan.
        let plan = plan("User filter .age > 30").unwrap();
        assert!(matches!(plan, PlanNode::RangeScan { .. }));
    }

    #[test]
    fn test_plan_filter_with_projection() {
        let plan = plan("User filter .age > 30 { name, email }").unwrap();
        assert!(matches!(plan, PlanNode::Project { .. }));
    }

    #[test]
    fn test_plan_insert() {
        let plan = plan(r#"insert User { name := "Alice", age := 30 }"#).unwrap();
        assert!(matches!(plan, PlanNode::Insert { .. }));
    }

    #[test]
    fn test_plan_order_limit() {
        let plan = plan("User order .name limit 10").unwrap();
        match plan {
            PlanNode::Limit { input, .. } => {
                assert!(matches!(*input, PlanNode::Sort { .. }));
            }
            _ => panic!("expected Limit(Sort(SeqScan))"),
        }
    }

    #[test]
    fn test_plan_count() {
        let plan = plan("count(User)").unwrap();
        assert!(matches!(plan, PlanNode::Aggregate { .. }));
    }

    #[test]
    fn test_plan_eq_becomes_index_scan() {
        let plan = plan("User filter .id = 42").unwrap();
        match plan {
            PlanNode::IndexScan { table, column, key } => {
                assert_eq!(table, "User");
                assert_eq!(column, "id");
                assert!(matches!(key, Expr::Literal(Literal::Int(42))));
            }
            other => panic!("expected IndexScan, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_eq_reversed_becomes_index_scan() {
        // The literal may appear on either side of `=`.
        let plan = plan(r#"User filter "NYC" = .city"#).unwrap();
        assert!(matches!(plan, PlanNode::IndexScan { .. }));
    }

    #[test]
    fn test_plan_gt_becomes_range_scan_with_exclusive_lower_bound() {
        // `>` does not stay a filter: it becomes a half-open range scan.
        let plan = plan("User filter .age > 30").unwrap();
        match plan {
            PlanNode::RangeScan {
                column, start, end, ..
            } => {
                assert_eq!(column, "age");
                assert!(start.is_some(), "expected lower bound");
                assert!(end.is_none(), "expected no upper bound");
                let (_, inclusive) = start.unwrap();
                assert!(!inclusive, "expected exclusive lower bound for >");
            }
            other => panic!("expected RangeScan, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_index_scan_with_projection() {
        let plan = plan("User filter .id = 1 { .name }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                assert!(matches!(*input, PlanNode::IndexScan { .. }));
            }
            other => panic!("expected Project(IndexScan), got {other:?}"),
        }
    }

    #[test]
    fn test_plan_update_by_pk_becomes_index_scan() {
        let plan = plan("User filter .id = 42 update { age := 31 }").unwrap();
        match plan {
            PlanNode::Update { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::IndexScan { .. }),
                    "expected Update(IndexScan), got {input:?}"
                );
            }
            other => panic!("expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_update_range_stays_range_scan() {
        let plan = plan("User filter .age > 30 update { age := 31 }").unwrap();
        match plan {
            PlanNode::Update { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::RangeScan { .. }),
                    "expected Update(RangeScan), got {input:?}"
                );
            }
            other => panic!("expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_delete_by_pk_becomes_index_scan() {
        let plan = plan("User filter .id = 7 delete").unwrap();
        match plan {
            PlanNode::Delete { input, .. } => {
                assert!(matches!(*input, PlanNode::IndexScan { .. }));
            }
            other => panic!("expected Delete, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_inner_join_builds_nested_loop() {
        let plan = plan("User as u join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::NestedLoopJoin {
                left,
                right,
                on,
                kind,
            } => {
                assert_eq!(kind, JoinKind::Inner);
                assert!(on.is_some());
                assert!(matches!(*left, PlanNode::AliasScan { .. }));
                assert!(matches!(*right, PlanNode::AliasScan { .. }));
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_right_join_rewritten_as_left_with_swapped_inputs() {
        let plan = plan("User as u right join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::NestedLoopJoin {
                left, right, kind, ..
            } => {
                assert_eq!(kind, JoinKind::LeftOuter);
                // The originally-right table now drives the (left) outer side.
                match *left {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Order"),
                    other => panic!("expected AliasScan(Order), got {other:?}"),
                }
                match *right {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "User"),
                    other => panic!("expected AliasScan(User), got {other:?}"),
                }
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_multi_join_is_left_deep() {
        let plan = plan(
            "User as u join Order as o on u.id = o.user_id \
             join Product as p on o.product_id = p.id",
        )
        .unwrap();
        match plan {
            PlanNode::NestedLoopJoin { left, right, .. } => {
                // The last join's table is the right child of the root...
                match *right {
                    PlanNode::AliasScan { table, .. } => assert_eq!(table, "Product"),
                    other => panic!("expected AliasScan(Product), got {other:?}"),
                }
                // ...and the earlier joins nest on the left.
                assert!(matches!(*left, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected NestedLoopJoin, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_join_with_filter_tail_wraps_filter_on_top() {
        let plan =
            plan("User as u join Order as o on u.id = o.user_id filter o.total > 100").unwrap();
        match plan {
            PlanNode::Filter { input, .. } => {
                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected Filter(NestedLoopJoin), got {other:?}"),
        }
    }

    #[test]
    fn test_plan_group_by_builds_groupby_node() {
        let plan = plan("User group .status { .status, n: count(.name) }").unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 2);
                match *input {
                    PlanNode::GroupBy {
                        input: inner,
                        keys,
                        aggregates,
                        having,
                    } => {
                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                        assert_eq!(keys, vec!["status"]);
                        assert_eq!(aggregates.len(), 1);
                        assert_eq!(aggregates[0].function, AggFunc::Count);
                        assert_eq!(aggregates[0].field, "name");
                        assert!(having.is_none());
                    }
                    other => panic!("expected GroupBy, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_group_by_having_rewrites_agg_in_having() {
        let plan = plan("User group .status having count(.name) > 1 { .status }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                match *input {
                    PlanNode::GroupBy {
                        having, aggregates, ..
                    } => {
                        assert_eq!(aggregates.len(), 1);
                        assert_eq!(aggregates[0].output_name, "__agg_0");
                        // The aggregate call inside `having` now reads the synthetic column.
                        let h = having.expect("having should be Some");
                        match h {
                            Expr::BinaryOp(l, BinOp::Gt, _) => {
                                assert!(
                                    matches!(*l, Expr::Field(ref name) if name == "__agg_0"),
                                    "expected Field(__agg_0), got {l:?}"
                                );
                            }
                            other => panic!("expected BinaryOp, got {other:?}"),
                        }
                    }
                    other => panic!("expected GroupBy, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_window_inserts_window_node_before_project() {
        let plan = plan("User { .name, rn: row_number() over (order .age) }").unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 2);
                assert!(
                    matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"),
                    "expected Field(__win_0), got {:?}",
                    fields[1].expr
                );
                match *input {
                    PlanNode::Window {
                        input: inner,
                        windows,
                    } => {
                        assert_eq!(windows.len(), 1);
                        assert_eq!(windows[0].output_name, "__win_0");
                        assert!(matches!(*inner, PlanNode::SeqScan { .. }));
                    }
                    other => panic!("expected Window, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_multiple_windows() {
        let plan = plan(
            "User { .name, rn: row_number() over (order .age), s: sum(.salary) over (partition .dept order .salary) }"
        ).unwrap();
        match plan {
            PlanNode::Project { input, fields } => {
                assert_eq!(fields.len(), 3);
                assert!(matches!(&fields[1].expr, Expr::Field(name) if name == "__win_0"));
                assert!(matches!(&fields[2].expr, Expr::Field(name) if name == "__win_1"));
                match *input {
                    PlanNode::Window { windows, .. } => {
                        assert_eq!(windows.len(), 2);
                        assert_eq!(windows[0].output_name, "__win_0");
                        assert_eq!(windows[1].output_name, "__win_1");
                    }
                    other => panic!("expected Window, got {other:?}"),
                }
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_no_window_without_over() {
        // `sum(...)` without `over` in a grouped query is an aggregate, not a window.
        let plan = plan("User group .dept { .dept, total: sum(.salary) }").unwrap();
        match plan {
            PlanNode::Project { input, .. } => {
                assert!(
                    matches!(*input, PlanNode::GroupBy { .. }),
                    "expected GroupBy under Project, got {:?}",
                    input
                );
            }
            other => panic!("expected Project, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_wraps_inner() {
        let plan = plan("explain User filter .age > 30").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(
                    matches!(*input, PlanNode::RangeScan { .. }),
                    "expected Explain(RangeScan), got {:?}",
                    input
                );
            }
            other => panic!("expected Explain, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_simple_scan() {
        let plan = plan("explain User").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(matches!(*input, PlanNode::SeqScan { .. }));
            }
            other => panic!("expected Explain(SeqScan), got {other:?}"),
        }
    }

    #[test]
    fn test_plan_explain_join() {
        let plan = plan("explain User as u join Order as o on u.id = o.user_id").unwrap();
        match plan {
            PlanNode::Explain { input } => {
                assert!(matches!(*input, PlanNode::NestedLoopJoin { .. }));
            }
            other => panic!("expected Explain(NestedLoopJoin), got {other:?}"),
        }
    }
}