1use std::collections::HashMap;
7
8use sqlparser::ast::{
9 BinaryOperator, Distinct, DuplicateTreatment, Expr, Function, FunctionArgExpr,
10 FunctionArguments, GroupByExpr, Join, JoinConstraint, JoinOperator, Query, Select, SelectItem,
11 SetExpr, SetOperator, SetQuantifier, Statement, TableFactor, TableWithJoins, Value,
12};
13
14use crate::{
15 limits::{enforce_graph_size, QueryLimits},
16 mir::{AggExpr, ColumnRef, JoinKind, MirGraph, MirNodeKind, OrderKey, SetQuantifierKind},
17 SqlError,
18};
19
20pub fn parse_and_lower(sql: &str) -> Result<MirGraph, SqlError> {
26 parse_and_lower_with_limits(sql, QueryLimits::DEFAULT)
27}
28
29pub fn parse_and_lower_with_limits(sql: &str, limits: QueryLimits) -> Result<MirGraph, SqlError> {
36 let statement = crate::parser::parse_select_with_limits(sql, limits)?;
37 let graph = lower_select_statement(&statement)?;
38 enforce_graph_size(graph.node_count(), limits)?;
39 Ok(graph)
40}
41
42pub fn lower_select_statement(statement: &Statement) -> Result<MirGraph, SqlError> {
52 let Statement::Query(query) = statement else {
53 return Err(SqlError::UnsupportedStatement);
54 };
55
56 lower_query(query)
57}
58
59fn lower_query(query: &Query) -> Result<MirGraph, SqlError> {
60 let mut context = LowerContext::default();
61
62 if let Some(with) = &query.with {
63 for cte in &with.cte_tables {
64 let graph = lower_query_with_context(&cte.query, &context)?;
65 context.ctes.insert(cte.alias.name.value.clone(), graph);
66 }
67 }
68
69 lower_query_with_context(query, &context)
70}
71
72fn lower_query_with_context(query: &Query, context: &LowerContext) -> Result<MirGraph, SqlError> {
73 let mut graph = lower_set_expr(&query.body, context)?;
74
75 if let Some(order_by) = &query.order_by {
76 let limit = query
81 .limit
82 .as_ref()
83 .map(literal_usize)
84 .transpose()?
85 .unwrap_or(usize::MAX);
86 let offset = query
87 .offset
88 .as_ref()
89 .map(|offset| literal_usize(&offset.value))
90 .transpose()?
91 .unwrap_or(0);
92 let order_by = order_by
93 .exprs
94 .iter()
95 .map(|expr| OrderKey {
96 expression: expr.expr.to_string(),
97 descending: expr.asc == Some(false),
98 })
99 .collect();
100 push_unary(
101 &mut graph,
102 MirNodeKind::TopK {
103 order_by,
104 limit,
105 offset,
106 },
107 );
108 }
109
110 Ok(graph)
111}
112
113#[derive(Debug, Default)]
114struct LowerContext {
115 ctes: HashMap<String, MirGraph>,
116}
117
118fn lower_set_expr(expr: &SetExpr, context: &LowerContext) -> Result<MirGraph, SqlError> {
119 match expr {
120 SetExpr::Select(select) => lower_select_body(select, context),
121 SetExpr::SetOperation {
122 op,
123 set_quantifier,
124 left,
125 right,
126 } => lower_set_operation(*op, *set_quantifier, left, right, context),
127 SetExpr::Query(query) => lower_query_with_context(query, context),
128 SetExpr::Values(_) => Err(SqlError::UnsupportedFeature("VALUES queries")),
129 SetExpr::Insert(_) => Err(SqlError::UnsupportedFeature("INSERT in query body")),
130 SetExpr::Update(_) => Err(SqlError::UnsupportedFeature("UPDATE in query body")),
131 SetExpr::Table(_) => Err(SqlError::UnsupportedFeature("TABLE queries")),
132 }
133}
134
135fn lower_set_operation(
136 op: SetOperator,
137 set_quantifier: SetQuantifier,
138 left: &SetExpr,
139 right: &SetExpr,
140 context: &LowerContext,
141) -> Result<MirGraph, SqlError> {
142 let quantifier = lower_set_quantifier(set_quantifier)?;
143 let mut graph = lower_set_expr(left, context)?;
144 let left_root = graph.root();
145 let right = lower_set_expr(right, context)?;
146 let right_root = graph.append_graph(&right);
147
148 let set_op = graph.add_node(match op {
149 SetOperator::Union => MirNodeKind::Union { quantifier },
150 SetOperator::Except => MirNodeKind::Except { quantifier },
151 SetOperator::Intersect => MirNodeKind::Intersect { quantifier },
152 });
153 graph.add_input(left_root, set_op);
154 graph.add_input(right_root, set_op);
155 graph.set_root(set_op);
156 Ok(graph)
157}
158
159const fn lower_set_quantifier(quantifier: SetQuantifier) -> Result<SetQuantifierKind, SqlError> {
160 match quantifier {
161 SetQuantifier::All => Ok(SetQuantifierKind::All),
162 SetQuantifier::None | SetQuantifier::Distinct => Ok(SetQuantifierKind::Distinct),
163 SetQuantifier::ByName | SetQuantifier::AllByName | SetQuantifier::DistinctByName => {
164 Err(SqlError::UnsupportedFeature("set operations BY NAME"))
165 }
166 }
167}
168
169fn lower_select_body(select: &Select, context: &LowerContext) -> Result<MirGraph, SqlError> {
170 reject_select_features_not_lowered(select)?;
171
172 let mut graph = lower_from(select, context)?;
173
174 if let Some(predicate) = &select.selection {
175 push_unary(
176 &mut graph,
177 MirNodeKind::Filter {
178 predicate: canonical_predicate(predicate),
179 },
180 );
181 }
182
183 let group_by = group_by_columns(&select.group_by)?;
184 let aggs = aggregate_exprs(&select.projection)?;
185 if !group_by.is_empty() || !aggs.is_empty() {
186 push_unary(&mut graph, MirNodeKind::Aggregate { group_by, aggs });
187 }
188
189 push_unary(
190 &mut graph,
191 MirNodeKind::Project {
192 columns: select.projection.iter().map(select_item_name).collect(),
193 },
194 );
195
196 if matches!(select.distinct, Some(Distinct::Distinct)) {
197 push_unary(&mut graph, MirNodeKind::Distinct);
198 }
199
200 Ok(graph)
201}
202
203fn reject_select_features_not_lowered(select: &Select) -> Result<(), SqlError> {
204 if select.having.is_some() {
205 return Err(SqlError::UnsupportedFeature("HAVING"));
206 }
207 if has_group_by_modifiers(&select.group_by) {
208 return Err(SqlError::UnsupportedFeature("GROUP BY modifiers"));
209 }
210 if select.distinct.is_some() && !matches!(select.distinct, Some(Distinct::Distinct)) {
211 return Err(SqlError::UnsupportedFeature("DISTINCT ON"));
212 }
213 if select.top.is_some() {
214 return Err(SqlError::UnsupportedFeature("TOP"));
215 }
216 if select.into.is_some() {
217 return Err(SqlError::UnsupportedFeature("SELECT INTO"));
218 }
219 if !select.lateral_views.is_empty()
220 || select.prewhere.is_some()
221 || !select.cluster_by.is_empty()
222 || !select.distribute_by.is_empty()
223 || !select.sort_by.is_empty()
224 || !select.named_window.is_empty()
225 || select.qualify.is_some()
226 || select.value_table_mode.is_some()
227 || select.connect_by.is_some()
228 {
229 return Err(SqlError::UnsupportedFeature("non-standard SELECT clauses"));
230 }
231
232 Ok(())
233}
234
235fn lower_from(select: &Select, context: &LowerContext) -> Result<MirGraph, SqlError> {
236 let [source] = select.from.as_slice() else {
237 return Err(SqlError::UnsupportedFeature(
238 "MIR lowering for zero or multiple FROM items",
239 ));
240 };
241
242 lower_table_with_joins(source, context)
243}
244
245fn lower_table_with_joins(
246 source: &TableWithJoins,
247 context: &LowerContext,
248) -> Result<MirGraph, SqlError> {
249 let mut graph = lower_table_factor(&source.relation, context)?;
250
251 for join in &source.joins {
252 lower_join(&mut graph, join, context)?;
253 }
254
255 Ok(graph)
256}
257
258fn lower_join(graph: &mut MirGraph, join: &Join, context: &LowerContext) -> Result<(), SqlError> {
259 let right_graph = lower_table_factor(&join.relation, context)?;
260 let right = graph.append_graph(&right_graph);
261
262 let (kind, on) = match &join.join_operator {
263 JoinOperator::Inner(JoinConstraint::On(predicate)) => {
264 (JoinKind::Inner, equi_join_columns(predicate)?)
265 }
266 JoinOperator::LeftOuter(JoinConstraint::On(predicate)) => {
267 (JoinKind::Left, equi_join_columns(predicate)?)
268 }
269 JoinOperator::Inner(
270 JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None,
271 )
272 | JoinOperator::LeftOuter(
273 JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None,
274 ) => {
275 return Err(SqlError::UnsupportedFeature(
276 "MIR lowering for non-ON joins",
277 ));
278 }
279 JoinOperator::CrossJoin => {
280 return Err(SqlError::UnsupportedFeature("MIR lowering for cross joins"));
281 }
282 _ => return Err(SqlError::UnsupportedFeature("non-standard joins")),
283 };
284
285 let left = graph.root();
286 let join = graph.add_node(MirNodeKind::Join { kind, on });
287 graph.add_input(left, join);
288 graph.add_input(right, join);
289 graph.set_root(join);
290 Ok(())
291}
292
293fn lower_table_factor(table: &TableFactor, context: &LowerContext) -> Result<MirGraph, SqlError> {
294 match table {
295 TableFactor::Table { name, .. } => {
296 let name = name.to_string();
297 if let Some(cte) = context.ctes.get(&name) {
298 let mut graph = MirGraph::new(MirNodeKind::CteRef { cte: name });
299 let cte_root = graph.append_graph(cte);
300 graph.add_cte_expansion(cte_root, graph.root());
301 Ok(graph)
302 } else {
303 Ok(MirGraph::new(MirNodeKind::BaseTable {
304 table: name,
305 project: Vec::new(),
306 }))
307 }
308 }
309 TableFactor::Derived {
310 lateral: false,
311 subquery,
312 ..
313 } => lower_query_with_context(subquery, context),
314 TableFactor::Derived { lateral: true, .. } => {
315 Err(SqlError::UnsupportedFeature("LATERAL derived tables"))
316 }
317 _ => Err(SqlError::UnsupportedFeature(
318 "table functions or special table factors",
319 )),
320 }
321}
322
323fn equi_join_columns(predicate: &Expr) -> Result<Vec<(ColumnRef, ColumnRef)>, SqlError> {
324 match predicate {
325 Expr::BinaryOp { left, op, right } if *op == BinaryOperator::Eq => {
326 Ok(vec![(column_ref(left)?, column_ref(right)?)])
327 }
328 Expr::BinaryOp {
329 left,
330 op: BinaryOperator::And,
331 right,
332 } => {
333 let mut pairs = equi_join_columns(left)?;
334 pairs.extend(equi_join_columns(right)?);
335 Ok(pairs)
336 }
337 _ => Err(SqlError::UnsupportedFeature("theta joins")),
338 }
339}
340
341fn column_ref(expr: &Expr) -> Result<ColumnRef, SqlError> {
342 match expr {
343 Expr::Identifier(ident) => Ok(ColumnRef {
344 relation: None,
345 name: ident.value.clone(),
346 }),
347 Expr::CompoundIdentifier(parts) => {
348 let [relation, name] = parts.as_slice() else {
349 return Err(SqlError::UnsupportedFeature(
350 "multi-part column references beyond relation.column",
351 ));
352 };
353
354 Ok(ColumnRef {
355 relation: Some(relation.value.clone()),
356 name: name.value.clone(),
357 })
358 }
359 _ => Err(SqlError::UnsupportedFeature("non-column join keys")),
360 }
361}
362
363fn canonical_predicate(expr: &Expr) -> String {
364 match expr {
365 Expr::BinaryOp {
366 left,
367 op: BinaryOperator::And,
368 right,
369 } => {
370 let mut parts = flatten_and(left);
371 parts.extend(flatten_and(right));
372 parts.sort();
373 parts.join(" AND ")
374 }
375 Expr::BinaryOp { left, op, right } if *op == BinaryOperator::Eq => {
376 let mut operands = [
377 (operand_sort_key(left), canonical_expr(left)),
378 (operand_sort_key(right), canonical_expr(right)),
379 ];
380 operands.sort_by(|left, right| left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1)));
381 format!("{} = {}", operands[0].1, operands[1].1)
382 }
383 Expr::BinaryOp { left, op, right } => {
384 format!("{} {op} {}", canonical_expr(left), canonical_expr(right))
385 }
386 _ => canonical_expr(expr),
387 }
388}
389
390fn operand_sort_key(expr: &Expr) -> String {
391 match expr {
392 Expr::Identifier(_) | Expr::CompoundIdentifier(_) => format!("0:{expr}"),
393 _ => format!("1:{}", canonical_expr(expr)),
394 }
395}
396
397fn canonical_expr(expr: &Expr) -> String {
398 match expr {
399 Expr::Value(value) => canonical_value(value),
400 Expr::UnaryOp { op, expr } => format!("{op} {}", canonical_expr(expr)),
401 Expr::Nested(expr) => canonical_expr(expr),
402 Expr::BinaryOp { left, op, right } => {
403 format!("{} {op} {}", canonical_expr(left), canonical_expr(right))
404 }
405 _ => expr.to_string(),
406 }
407}
408
409fn canonical_value(value: &Value) -> String {
410 match value {
411 Value::Number(value, false) => canonical_number(value),
412 Value::SingleQuotedString(value)
413 | Value::EscapedStringLiteral(value)
414 | Value::UnicodeStringLiteral(value)
415 | Value::NationalStringLiteral(value) => format!("'{}'", value.replace('\'', "''")),
416 Value::Boolean(value) => value.to_string(),
417 Value::Null => "NULL".to_owned(),
418 _ => value.to_string(),
419 }
420}
421
422fn canonical_number(value: &str) -> String {
423 let value = value.trim_start_matches('+');
424 if value.contains(['.', 'e', 'E']) {
425 return value.to_ascii_lowercase();
426 }
427
428 let negative = value.starts_with('-');
429 let digits = if negative { &value[1..] } else { value };
430 let digits = digits.trim_start_matches('0');
431 let digits = if digits.is_empty() { "0" } else { digits };
432 if negative && digits != "0" {
433 format!("-{digits}")
434 } else {
435 digits.to_owned()
436 }
437}
438
439fn flatten_and(expr: &Expr) -> Vec<String> {
440 match expr {
441 Expr::BinaryOp {
442 left,
443 op: BinaryOperator::And,
444 right,
445 } => {
446 let mut parts = flatten_and(left);
447 parts.extend(flatten_and(right));
448 parts
449 }
450 _ => vec![canonical_predicate(expr)],
451 }
452}
453
454fn group_by_columns(group_by: &GroupByExpr) -> Result<Vec<ColumnRef>, SqlError> {
455 match group_by {
456 GroupByExpr::Expressions(expressions, modifiers) if modifiers.is_empty() => {
457 expressions.iter().map(column_ref).collect()
458 }
459 GroupByExpr::Expressions(_, _) => Err(SqlError::UnsupportedFeature("GROUP BY modifiers")),
460 GroupByExpr::All(_) => Err(SqlError::UnsupportedFeature("GROUP BY ALL")),
461 }
462}
463
464fn aggregate_exprs(projection: &[SelectItem]) -> Result<Vec<AggExpr>, SqlError> {
465 projection.iter().try_fold(Vec::new(), |mut aggs, item| {
466 match item {
467 SelectItem::UnnamedExpr(Expr::Function(function)) => {
468 if let Some(agg) = aggregate_expr(function, None)? {
469 aggs.push(agg);
470 }
471 }
472 SelectItem::ExprWithAlias {
473 expr: Expr::Function(function),
474 alias,
475 } => {
476 if let Some(agg) = aggregate_expr(function, Some(alias.value.clone()))? {
477 aggs.push(agg);
478 }
479 }
480 SelectItem::UnnamedExpr(_)
481 | SelectItem::ExprWithAlias { .. }
482 | SelectItem::QualifiedWildcard(_, _)
483 | SelectItem::Wildcard(_) => {}
484 }
485
486 Ok(aggs)
487 })
488}
489
490fn aggregate_expr(function: &Function, alias: Option<String>) -> Result<Option<AggExpr>, SqlError> {
491 let name = function.name.to_string().to_ascii_lowercase();
492 if !matches!(name.as_str(), "count" | "sum" | "min" | "max" | "avg") {
493 return Ok(None);
494 }
495
496 let mut args = function_args(&function.args)?;
497 if matches!(
498 function.args,
499 FunctionArguments::List(ref args)
500 if args.duplicate_treatment == Some(DuplicateTreatment::Distinct)
501 ) {
502 args.insert(0, "DISTINCT".to_owned());
503 }
504
505 Ok(Some(AggExpr {
506 function: name,
507 args,
508 alias,
509 }))
510}
511
512fn function_args(args: &FunctionArguments) -> Result<Vec<String>, SqlError> {
513 match args {
514 FunctionArguments::None => Ok(Vec::new()),
515 FunctionArguments::Subquery(_) => Err(SqlError::UnsupportedFeature(
516 "subqueries in aggregate arguments",
517 )),
518 FunctionArguments::List(args) => args
519 .args
520 .iter()
521 .map(|arg| match arg {
522 sqlparser::ast::FunctionArg::Named { .. } => {
523 Err(SqlError::UnsupportedFeature("named aggregate arguments"))
524 }
525 sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => {
526 Ok(expr.to_string())
527 }
528 sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name)) => {
529 Ok(format!("{name}.*"))
530 }
531 sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
532 Ok("*".to_owned())
533 }
534 })
535 .collect(),
536 }
537}
538
539fn select_item_name(item: &SelectItem) -> String {
540 match item {
541 SelectItem::UnnamedExpr(expr) => expr.to_string(),
542 SelectItem::ExprWithAlias { alias, .. } => alias.to_string(),
543 SelectItem::QualifiedWildcard(name, _) => format!("{name}.*"),
544 SelectItem::Wildcard(_) => "*".to_owned(),
545 }
546}
547
548fn literal_usize(expr: &Expr) -> Result<usize, SqlError> {
549 match expr {
550 Expr::Value(Value::Number(value, false)) => value
551 .parse()
552 .map_err(|_| SqlError::UnsupportedFeature("non-integer LIMIT/OFFSET")),
553 _ => Err(SqlError::UnsupportedFeature("non-literal LIMIT/OFFSET")),
554 }
555}
556
557fn push_unary(graph: &mut MirGraph, node: MirNodeKind) {
558 let previous_root = graph.root();
559 let next_root = graph.add_node(node);
560 graph.add_input(previous_root, next_root);
561 graph.set_root(next_root);
562}
563
564fn has_group_by_modifiers(group_by: &GroupByExpr) -> bool {
565 match group_by {
566 GroupByExpr::Expressions(_, modifiers) | GroupByExpr::All(modifiers) => {
567 !modifiers.is_empty()
568 }
569 }
570}
571
// Unit tests for SQL-to-MIR lowering: each test lowers one query through
// `parse_and_lower` and checks the resulting node kinds and counts.
#[cfg(test)]
mod tests {
    use crate::{
        lower::parse_and_lower,
        mir::{
            AggExpr, ColumnRef, JoinKind, MirEdgeKind, MirNodeKind, OrderKey, SetQuantifierKind,
        },
    };

    // Single-table query exercising every unary stage. Expected 5 nodes:
    // BaseTable -> Filter -> Project -> Distinct -> TopK (root).
    #[test]
    fn lowers_filter_project_distinct_topk_chain() {
        let graph = parse_and_lower(
            "SELECT DISTINCT id, title AS post_title
             FROM posts
             WHERE author_id = 42
             ORDER BY created_at DESC
             LIMIT 5 OFFSET 10",
        )
        .expect("supported query should lower");

        assert_eq!(graph.node_count(), 5);
        assert!(matches!(
            graph.root_kind(),
            MirNodeKind::TopK {
                order_by,
                limit: 5,
                offset: 10,
            } if order_by == &vec![OrderKey {
                expression: "created_at".to_owned(),
                descending: true,
            }]
        ));
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::BaseTable { table, .. } if table == "posts"
        )));
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Filter { predicate } if predicate == "author_id = 42"
        )));
        // Aliased items are named by their alias, unaliased ones by their SQL text.
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Project { columns } if columns == &vec!["id".to_owned(), "post_title".to_owned()]
        )));
        assert!(graph
            .node_kinds()
            .any(|node| matches!(node, MirNodeKind::Distinct)));
    }

    // Inner equi-join. Expected 4 nodes: two BaseTables, Join, Project.
    #[test]
    fn lowers_equi_join() {
        let graph = parse_and_lower(
            "SELECT posts.id
             FROM posts JOIN authors ON posts.author_id = authors.id",
        )
        .expect("validated equi-join should lower");

        assert_eq!(graph.node_count(), 4);
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Join {
                kind: JoinKind::Inner,
                on,
            } if on == &vec![(
                ColumnRef {
                    relation: Some("posts".to_owned()),
                    name: "author_id".to_owned(),
                },
                ColumnRef {
                    relation: Some("authors".to_owned()),
                    name: "id".to_owned(),
                },
            )]
        )));
    }

    // LEFT JOIN whose ON is an AND of two equalities yields two key pairs.
    #[test]
    fn lowers_left_equi_join_with_conjunction() {
        let graph = parse_and_lower(
            "SELECT posts.id
             FROM posts LEFT JOIN comments
             ON posts.id = comments.post_id AND posts.author_id = comments.author_id",
        )
        .expect("validated left equi-join should lower");

        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Join {
                kind: JoinKind::Left,
                on,
            } if on.len() == 2
        )));
    }

    // GROUP BY with aliased and unaliased aggregates. Expected 4 nodes:
    // BaseTable -> Filter -> Aggregate -> Project.
    #[test]
    fn lowers_group_by_aggregate() {
        let graph = parse_and_lower(
            "SELECT author_id, count(*) AS post_count, max(created_at)
             FROM posts
             WHERE author_id = 42
             GROUP BY author_id",
        )
        .expect("basic aggregate query should lower");

        assert_eq!(graph.node_count(), 4);
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Aggregate { group_by, aggs }
                if group_by == &vec![ColumnRef {
                    relation: None,
                    name: "author_id".to_owned(),
                }]
                && aggs == &vec![
                    AggExpr {
                        function: "count".to_owned(),
                        args: vec!["*".to_owned()],
                        alias: Some("post_count".to_owned()),
                    },
                    AggExpr {
                        function: "max".to_owned(),
                        args: vec!["created_at".to_owned()],
                        alias: None,
                    },
                ]
        )));
    }

    // An aggregate without GROUP BY still produces an Aggregate node with no keys.
    #[test]
    fn lowers_scalar_aggregate() {
        let graph = parse_and_lower("SELECT count(*) FROM posts")
            .expect("scalar aggregate query should lower");

        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Aggregate { group_by, aggs }
                if group_by.is_empty() && aggs.len() == 1
        )));
    }

    // UNION ALL. Expected 5 nodes: (BaseTable, Project) per side plus the Union root.
    #[test]
    fn lowers_union_all() {
        let graph = parse_and_lower(
            "SELECT id FROM posts
             UNION ALL
             SELECT id FROM archived_posts",
        )
        .expect("UNION ALL should lower");

        assert_eq!(graph.node_count(), 5);
        assert!(matches!(
            graph.root_kind(),
            MirNodeKind::Union {
                quantifier: SetQuantifierKind::All,
            }
        ));
        assert_eq!(
            graph
                .node_kinds()
                .filter(|node| matches!(node, MirNodeKind::BaseTable { .. }))
                .count(),
            2
        );
    }

    // A bare UNION (no quantifier) lowers with the Distinct quantifier.
    #[test]
    fn lowers_distinct_union() {
        let graph = parse_and_lower(
            "SELECT id FROM posts
             UNION
             SELECT id FROM archived_posts",
        )
        .expect("UNION DISTINCT should lower");

        assert!(matches!(
            graph.root_kind(),
            MirNodeKind::Union {
                quantifier: SetQuantifierKind::Distinct,
            }
        ));
    }

    // EXCEPT defaults to Distinct; INTERSECT ALL keeps All.
    #[test]
    fn lowers_except_and_intersect() {
        let except = parse_and_lower(
            "SELECT id FROM posts
             EXCEPT
             SELECT id FROM archived_posts",
        )
        .expect("EXCEPT should lower");
        let intersect = parse_and_lower(
            "SELECT id FROM posts
             INTERSECT ALL
             SELECT id FROM archived_posts",
        )
        .expect("INTERSECT ALL should lower");

        assert!(matches!(
            except.root_kind(),
            MirNodeKind::Except {
                quantifier: SetQuantifierKind::Distinct,
            }
        ));
        assert!(matches!(
            intersect.root_kind(),
            MirNodeKind::Intersect {
                quantifier: SetQuantifierKind::All,
            }
        ));
    }

    // A CTE reference becomes a CteRef node linked to the CTE's lowered graph by a
    // CteExpansion edge. Expected 5 nodes: CteRef, the CTE's BaseTable/Filter/Project,
    // and the outer Project.
    #[test]
    fn lowers_cte_reference() {
        let graph = parse_and_lower(
            "WITH recent_posts AS (
                SELECT id, author_id FROM posts WHERE author_id = 42
             )
             SELECT id FROM recent_posts",
        )
        .expect("non-recursive CTE should lower");

        assert_eq!(graph.node_count(), 5);
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::CteRef { cte } if cte == "recent_posts"
        )));
        assert!(graph
            .graph()
            .edge_weights()
            .any(|edge| *edge == MirEdgeKind::CteExpansion));
    }

    // A derived table is lowered inline as a nested query. Expected 4 nodes: the inner
    // BaseTable/Filter/Project and the outer Project.
    #[test]
    fn lowers_derived_table() {
        let graph = parse_and_lower(
            "SELECT id
             FROM (
                SELECT id FROM posts WHERE author_id = 42
             ) AS recent_posts",
        )
        .expect("derived table should lower through nested query path");

        assert_eq!(graph.node_count(), 4);
        assert!(graph.node_kinds().any(|node| matches!(
            node,
            MirNodeKind::Filter { predicate } if predicate == "author_id = 42"
        )));
    }
}