Skip to main content

palimpsest_sql/
lower.rs

1// Copyright 2026 Thousand Birds Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Lowering from sqlparser AST into [`MirGraph`].
5
6use std::collections::HashMap;
7
8use sqlparser::ast::{
9    BinaryOperator, Distinct, DuplicateTreatment, Expr, Function, FunctionArgExpr,
10    FunctionArguments, GroupByExpr, Join, JoinConstraint, JoinOperator, Query, Select, SelectItem,
11    SetExpr, SetOperator, SetQuantifier, Statement, TableFactor, TableWithJoins, Value,
12};
13
14use crate::{
15    limits::{enforce_graph_size, QueryLimits},
16    mir::{AggExpr, ColumnRef, JoinKind, MirGraph, MirNodeKind, OrderKey, SetQuantifierKind},
17    SqlError,
18};
19
20/// Parses `sql` and lowers it into an [`MirGraph`] under the default
21/// [`QueryLimits`].
22///
23/// # Errors
24/// Surfaces parse, validation, and size-bound errors.
25pub fn parse_and_lower(sql: &str) -> Result<MirGraph, SqlError> {
26    parse_and_lower_with_limits(sql, QueryLimits::DEFAULT)
27}
28
29/// Parse + lower while enforcing both `max_input_bytes` and
30/// `max_mir_nodes` from `limits`.
31///
32/// # Errors
33/// Surfaces [`SqlError::QueryTooLarge`] / [`SqlError::QueryTooComplex`]
34/// in addition to the usual parse/lower errors.
35pub fn parse_and_lower_with_limits(sql: &str, limits: QueryLimits) -> Result<MirGraph, SqlError> {
36    let statement = crate::parser::parse_select_with_limits(sql, limits)?;
37    let graph = lower_select_statement(&statement)?;
38    enforce_graph_size(graph.node_count(), limits)?;
39    Ok(graph)
40}
41
42/// Lowers an already-parsed `SELECT` [`Statement`] into an [`MirGraph`].
43///
44/// Skips the byte-budget check (the input is no longer textual at this
45/// point) but still produces graphs that should be size-checked by the
46/// caller via [`enforce_graph_size`].
47///
48/// # Errors
49/// [`SqlError::UnsupportedStatement`] on non-`SELECT` input, plus any
50/// downstream lowering error.
51pub fn lower_select_statement(statement: &Statement) -> Result<MirGraph, SqlError> {
52    let Statement::Query(query) = statement else {
53        return Err(SqlError::UnsupportedStatement);
54    };
55
56    lower_query(query)
57}
58
59fn lower_query(query: &Query) -> Result<MirGraph, SqlError> {
60    let mut context = LowerContext::default();
61
62    if let Some(with) = &query.with {
63        for cte in &with.cte_tables {
64            let graph = lower_query_with_context(&cte.query, &context)?;
65            context.ctes.insert(cte.alias.name.value.clone(), graph);
66        }
67    }
68
69    lower_query_with_context(query, &context)
70}
71
72fn lower_query_with_context(query: &Query, context: &LowerContext) -> Result<MirGraph, SqlError> {
73    let mut graph = lower_set_expr(&query.body, context)?;
74
75    if let Some(order_by) = &query.order_by {
76        // ORDER BY without LIMIT plans as a sort over the whole input —
77        // represented in the MIR as `TopK` with `usize::MAX` so we
78        // don't need a separate node kind. Downstream operators see
79        // "ordered, unbounded" and can pick the right physical plan.
80        let limit = query
81            .limit
82            .as_ref()
83            .map(literal_usize)
84            .transpose()?
85            .unwrap_or(usize::MAX);
86        let offset = query
87            .offset
88            .as_ref()
89            .map(|offset| literal_usize(&offset.value))
90            .transpose()?
91            .unwrap_or(0);
92        let order_by = order_by
93            .exprs
94            .iter()
95            .map(|expr| OrderKey {
96                expression: expr.expr.to_string(),
97                descending: expr.asc == Some(false),
98            })
99            .collect();
100        push_unary(
101            &mut graph,
102            MirNodeKind::TopK {
103                order_by,
104                limit,
105                offset,
106            },
107        );
108    }
109
110    Ok(graph)
111}
112
113#[derive(Debug, Default)]
114struct LowerContext {
115    ctes: HashMap<String, MirGraph>,
116}
117
118fn lower_set_expr(expr: &SetExpr, context: &LowerContext) -> Result<MirGraph, SqlError> {
119    match expr {
120        SetExpr::Select(select) => lower_select_body(select, context),
121        SetExpr::SetOperation {
122            op,
123            set_quantifier,
124            left,
125            right,
126        } => lower_set_operation(*op, *set_quantifier, left, right, context),
127        SetExpr::Query(query) => lower_query_with_context(query, context),
128        SetExpr::Values(_) => Err(SqlError::UnsupportedFeature("VALUES queries")),
129        SetExpr::Insert(_) => Err(SqlError::UnsupportedFeature("INSERT in query body")),
130        SetExpr::Update(_) => Err(SqlError::UnsupportedFeature("UPDATE in query body")),
131        SetExpr::Table(_) => Err(SqlError::UnsupportedFeature("TABLE queries")),
132    }
133}
134
135fn lower_set_operation(
136    op: SetOperator,
137    set_quantifier: SetQuantifier,
138    left: &SetExpr,
139    right: &SetExpr,
140    context: &LowerContext,
141) -> Result<MirGraph, SqlError> {
142    let quantifier = lower_set_quantifier(set_quantifier)?;
143    let mut graph = lower_set_expr(left, context)?;
144    let left_root = graph.root();
145    let right = lower_set_expr(right, context)?;
146    let right_root = graph.append_graph(&right);
147
148    let set_op = graph.add_node(match op {
149        SetOperator::Union => MirNodeKind::Union { quantifier },
150        SetOperator::Except => MirNodeKind::Except { quantifier },
151        SetOperator::Intersect => MirNodeKind::Intersect { quantifier },
152    });
153    graph.add_input(left_root, set_op);
154    graph.add_input(right_root, set_op);
155    graph.set_root(set_op);
156    Ok(graph)
157}
158
159const fn lower_set_quantifier(quantifier: SetQuantifier) -> Result<SetQuantifierKind, SqlError> {
160    match quantifier {
161        SetQuantifier::All => Ok(SetQuantifierKind::All),
162        SetQuantifier::None | SetQuantifier::Distinct => Ok(SetQuantifierKind::Distinct),
163        SetQuantifier::ByName | SetQuantifier::AllByName | SetQuantifier::DistinctByName => {
164            Err(SqlError::UnsupportedFeature("set operations BY NAME"))
165        }
166    }
167}
168
169fn lower_select_body(select: &Select, context: &LowerContext) -> Result<MirGraph, SqlError> {
170    reject_select_features_not_lowered(select)?;
171
172    let mut graph = lower_from(select, context)?;
173
174    if let Some(predicate) = &select.selection {
175        push_unary(
176            &mut graph,
177            MirNodeKind::Filter {
178                predicate: canonical_predicate(predicate),
179            },
180        );
181    }
182
183    let group_by = group_by_columns(&select.group_by)?;
184    let aggs = aggregate_exprs(&select.projection)?;
185    if !group_by.is_empty() || !aggs.is_empty() {
186        push_unary(&mut graph, MirNodeKind::Aggregate { group_by, aggs });
187    }
188
189    push_unary(
190        &mut graph,
191        MirNodeKind::Project {
192            columns: select.projection.iter().map(select_item_name).collect(),
193        },
194    );
195
196    if matches!(select.distinct, Some(Distinct::Distinct)) {
197        push_unary(&mut graph, MirNodeKind::Distinct);
198    }
199
200    Ok(graph)
201}
202
203fn reject_select_features_not_lowered(select: &Select) -> Result<(), SqlError> {
204    if select.having.is_some() {
205        return Err(SqlError::UnsupportedFeature("HAVING"));
206    }
207    if has_group_by_modifiers(&select.group_by) {
208        return Err(SqlError::UnsupportedFeature("GROUP BY modifiers"));
209    }
210    if select.distinct.is_some() && !matches!(select.distinct, Some(Distinct::Distinct)) {
211        return Err(SqlError::UnsupportedFeature("DISTINCT ON"));
212    }
213    if select.top.is_some() {
214        return Err(SqlError::UnsupportedFeature("TOP"));
215    }
216    if select.into.is_some() {
217        return Err(SqlError::UnsupportedFeature("SELECT INTO"));
218    }
219    if !select.lateral_views.is_empty()
220        || select.prewhere.is_some()
221        || !select.cluster_by.is_empty()
222        || !select.distribute_by.is_empty()
223        || !select.sort_by.is_empty()
224        || !select.named_window.is_empty()
225        || select.qualify.is_some()
226        || select.value_table_mode.is_some()
227        || select.connect_by.is_some()
228    {
229        return Err(SqlError::UnsupportedFeature("non-standard SELECT clauses"));
230    }
231
232    Ok(())
233}
234
235fn lower_from(select: &Select, context: &LowerContext) -> Result<MirGraph, SqlError> {
236    let [source] = select.from.as_slice() else {
237        return Err(SqlError::UnsupportedFeature(
238            "MIR lowering for zero or multiple FROM items",
239        ));
240    };
241
242    lower_table_with_joins(source, context)
243}
244
245fn lower_table_with_joins(
246    source: &TableWithJoins,
247    context: &LowerContext,
248) -> Result<MirGraph, SqlError> {
249    let mut graph = lower_table_factor(&source.relation, context)?;
250
251    for join in &source.joins {
252        lower_join(&mut graph, join, context)?;
253    }
254
255    Ok(graph)
256}
257
258fn lower_join(graph: &mut MirGraph, join: &Join, context: &LowerContext) -> Result<(), SqlError> {
259    let right_graph = lower_table_factor(&join.relation, context)?;
260    let right = graph.append_graph(&right_graph);
261
262    let (kind, on) = match &join.join_operator {
263        JoinOperator::Inner(JoinConstraint::On(predicate)) => {
264            (JoinKind::Inner, equi_join_columns(predicate)?)
265        }
266        JoinOperator::LeftOuter(JoinConstraint::On(predicate)) => {
267            (JoinKind::Left, equi_join_columns(predicate)?)
268        }
269        JoinOperator::Inner(
270            JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None,
271        )
272        | JoinOperator::LeftOuter(
273            JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None,
274        ) => {
275            return Err(SqlError::UnsupportedFeature(
276                "MIR lowering for non-ON joins",
277            ));
278        }
279        JoinOperator::CrossJoin => {
280            return Err(SqlError::UnsupportedFeature("MIR lowering for cross joins"));
281        }
282        _ => return Err(SqlError::UnsupportedFeature("non-standard joins")),
283    };
284
285    let left = graph.root();
286    let join = graph.add_node(MirNodeKind::Join { kind, on });
287    graph.add_input(left, join);
288    graph.add_input(right, join);
289    graph.set_root(join);
290    Ok(())
291}
292
293fn lower_table_factor(table: &TableFactor, context: &LowerContext) -> Result<MirGraph, SqlError> {
294    match table {
295        TableFactor::Table { name, .. } => {
296            let name = name.to_string();
297            if let Some(cte) = context.ctes.get(&name) {
298                let mut graph = MirGraph::new(MirNodeKind::CteRef { cte: name });
299                let cte_root = graph.append_graph(cte);
300                graph.add_cte_expansion(cte_root, graph.root());
301                Ok(graph)
302            } else {
303                Ok(MirGraph::new(MirNodeKind::BaseTable {
304                    table: name,
305                    project: Vec::new(),
306                }))
307            }
308        }
309        TableFactor::Derived {
310            lateral: false,
311            subquery,
312            ..
313        } => lower_query_with_context(subquery, context),
314        TableFactor::Derived { lateral: true, .. } => {
315            Err(SqlError::UnsupportedFeature("LATERAL derived tables"))
316        }
317        _ => Err(SqlError::UnsupportedFeature(
318            "table functions or special table factors",
319        )),
320    }
321}
322
323fn equi_join_columns(predicate: &Expr) -> Result<Vec<(ColumnRef, ColumnRef)>, SqlError> {
324    match predicate {
325        Expr::BinaryOp { left, op, right } if *op == BinaryOperator::Eq => {
326            Ok(vec![(column_ref(left)?, column_ref(right)?)])
327        }
328        Expr::BinaryOp {
329            left,
330            op: BinaryOperator::And,
331            right,
332        } => {
333            let mut pairs = equi_join_columns(left)?;
334            pairs.extend(equi_join_columns(right)?);
335            Ok(pairs)
336        }
337        _ => Err(SqlError::UnsupportedFeature("theta joins")),
338    }
339}
340
341fn column_ref(expr: &Expr) -> Result<ColumnRef, SqlError> {
342    match expr {
343        Expr::Identifier(ident) => Ok(ColumnRef {
344            relation: None,
345            name: ident.value.clone(),
346        }),
347        Expr::CompoundIdentifier(parts) => {
348            let [relation, name] = parts.as_slice() else {
349                return Err(SqlError::UnsupportedFeature(
350                    "multi-part column references beyond relation.column",
351                ));
352            };
353
354            Ok(ColumnRef {
355                relation: Some(relation.value.clone()),
356                name: name.value.clone(),
357            })
358        }
359        _ => Err(SqlError::UnsupportedFeature("non-column join keys")),
360    }
361}
362
363fn canonical_predicate(expr: &Expr) -> String {
364    match expr {
365        Expr::BinaryOp {
366            left,
367            op: BinaryOperator::And,
368            right,
369        } => {
370            let mut parts = flatten_and(left);
371            parts.extend(flatten_and(right));
372            parts.sort();
373            parts.join(" AND ")
374        }
375        Expr::BinaryOp { left, op, right } if *op == BinaryOperator::Eq => {
376            let mut operands = [
377                (operand_sort_key(left), canonical_expr(left)),
378                (operand_sort_key(right), canonical_expr(right)),
379            ];
380            operands.sort_by(|left, right| left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1)));
381            format!("{} = {}", operands[0].1, operands[1].1)
382        }
383        Expr::BinaryOp { left, op, right } => {
384            format!("{} {op} {}", canonical_expr(left), canonical_expr(right))
385        }
386        _ => canonical_expr(expr),
387    }
388}
389
390fn operand_sort_key(expr: &Expr) -> String {
391    match expr {
392        Expr::Identifier(_) | Expr::CompoundIdentifier(_) => format!("0:{expr}"),
393        _ => format!("1:{}", canonical_expr(expr)),
394    }
395}
396
397fn canonical_expr(expr: &Expr) -> String {
398    match expr {
399        Expr::Value(value) => canonical_value(value),
400        Expr::UnaryOp { op, expr } => format!("{op} {}", canonical_expr(expr)),
401        Expr::Nested(expr) => canonical_expr(expr),
402        Expr::BinaryOp { left, op, right } => {
403            format!("{} {op} {}", canonical_expr(left), canonical_expr(right))
404        }
405        _ => expr.to_string(),
406    }
407}
408
409fn canonical_value(value: &Value) -> String {
410    match value {
411        Value::Number(value, false) => canonical_number(value),
412        Value::SingleQuotedString(value)
413        | Value::EscapedStringLiteral(value)
414        | Value::UnicodeStringLiteral(value)
415        | Value::NationalStringLiteral(value) => format!("'{}'", value.replace('\'', "''")),
416        Value::Boolean(value) => value.to_string(),
417        Value::Null => "NULL".to_owned(),
418        _ => value.to_string(),
419    }
420}
421
/// Normalizes the text of a numeric literal.
///
/// Leading `+` signs are removed. Anything containing `.`, `e` or `E` is
/// only lower-cased; integers additionally lose leading zeros, and
/// negative zero collapses to `0`.
fn canonical_number(value: &str) -> String {
    let unsigned = value.trim_start_matches('+');

    let looks_fractional = unsigned.chars().any(|c| matches!(c, '.' | 'e' | 'E'));
    if looks_fractional {
        return unsigned.to_ascii_lowercase();
    }

    let (negative, magnitude) = match unsigned.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, unsigned),
    };

    match magnitude.trim_start_matches('0') {
        // All zeros (or nothing at all): the sign is meaningless.
        "" => "0".to_owned(),
        digits if negative => format!("-{digits}"),
        digits => digits.to_owned(),
    }
}
438
439fn flatten_and(expr: &Expr) -> Vec<String> {
440    match expr {
441        Expr::BinaryOp {
442            left,
443            op: BinaryOperator::And,
444            right,
445        } => {
446            let mut parts = flatten_and(left);
447            parts.extend(flatten_and(right));
448            parts
449        }
450        _ => vec![canonical_predicate(expr)],
451    }
452}
453
454fn group_by_columns(group_by: &GroupByExpr) -> Result<Vec<ColumnRef>, SqlError> {
455    match group_by {
456        GroupByExpr::Expressions(expressions, modifiers) if modifiers.is_empty() => {
457            expressions.iter().map(column_ref).collect()
458        }
459        GroupByExpr::Expressions(_, _) => Err(SqlError::UnsupportedFeature("GROUP BY modifiers")),
460        GroupByExpr::All(_) => Err(SqlError::UnsupportedFeature("GROUP BY ALL")),
461    }
462}
463
464fn aggregate_exprs(projection: &[SelectItem]) -> Result<Vec<AggExpr>, SqlError> {
465    projection.iter().try_fold(Vec::new(), |mut aggs, item| {
466        match item {
467            SelectItem::UnnamedExpr(Expr::Function(function)) => {
468                if let Some(agg) = aggregate_expr(function, None)? {
469                    aggs.push(agg);
470                }
471            }
472            SelectItem::ExprWithAlias {
473                expr: Expr::Function(function),
474                alias,
475            } => {
476                if let Some(agg) = aggregate_expr(function, Some(alias.value.clone()))? {
477                    aggs.push(agg);
478                }
479            }
480            SelectItem::UnnamedExpr(_)
481            | SelectItem::ExprWithAlias { .. }
482            | SelectItem::QualifiedWildcard(_, _)
483            | SelectItem::Wildcard(_) => {}
484        }
485
486        Ok(aggs)
487    })
488}
489
490fn aggregate_expr(function: &Function, alias: Option<String>) -> Result<Option<AggExpr>, SqlError> {
491    let name = function.name.to_string().to_ascii_lowercase();
492    if !matches!(name.as_str(), "count" | "sum" | "min" | "max" | "avg") {
493        return Ok(None);
494    }
495
496    let mut args = function_args(&function.args)?;
497    if matches!(
498        function.args,
499        FunctionArguments::List(ref args)
500            if args.duplicate_treatment == Some(DuplicateTreatment::Distinct)
501    ) {
502        args.insert(0, "DISTINCT".to_owned());
503    }
504
505    Ok(Some(AggExpr {
506        function: name,
507        args,
508        alias,
509    }))
510}
511
512fn function_args(args: &FunctionArguments) -> Result<Vec<String>, SqlError> {
513    match args {
514        FunctionArguments::None => Ok(Vec::new()),
515        FunctionArguments::Subquery(_) => Err(SqlError::UnsupportedFeature(
516            "subqueries in aggregate arguments",
517        )),
518        FunctionArguments::List(args) => args
519            .args
520            .iter()
521            .map(|arg| match arg {
522                sqlparser::ast::FunctionArg::Named { .. } => {
523                    Err(SqlError::UnsupportedFeature("named aggregate arguments"))
524                }
525                sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => {
526                    Ok(expr.to_string())
527                }
528                sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name)) => {
529                    Ok(format!("{name}.*"))
530                }
531                sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
532                    Ok("*".to_owned())
533                }
534            })
535            .collect(),
536    }
537}
538
539fn select_item_name(item: &SelectItem) -> String {
540    match item {
541        SelectItem::UnnamedExpr(expr) => expr.to_string(),
542        SelectItem::ExprWithAlias { alias, .. } => alias.to_string(),
543        SelectItem::QualifiedWildcard(name, _) => format!("{name}.*"),
544        SelectItem::Wildcard(_) => "*".to_owned(),
545    }
546}
547
548fn literal_usize(expr: &Expr) -> Result<usize, SqlError> {
549    match expr {
550        Expr::Value(Value::Number(value, false)) => value
551            .parse()
552            .map_err(|_| SqlError::UnsupportedFeature("non-integer LIMIT/OFFSET")),
553        _ => Err(SqlError::UnsupportedFeature("non-literal LIMIT/OFFSET")),
554    }
555}
556
557fn push_unary(graph: &mut MirGraph, node: MirNodeKind) {
558    let previous_root = graph.root();
559    let next_root = graph.add_node(node);
560    graph.add_input(previous_root, next_root);
561    graph.set_root(next_root);
562}
563
564fn has_group_by_modifiers(group_by: &GroupByExpr) -> bool {
565    match group_by {
566        GroupByExpr::Expressions(_, modifiers) | GroupByExpr::All(modifiers) => {
567            !modifiers.is_empty()
568        }
569    }
570}
571
#[cfg(test)]
mod tests {
    use crate::{
        lower::parse_and_lower,
        mir::{
            AggExpr, ColumnRef, JoinKind, MirEdgeKind, MirNodeKind, OrderKey, SetQuantifierKind,
        },
    };

    /// Shorthand for building an owned [`ColumnRef`].
    fn column(relation: Option<&str>, name: &str) -> ColumnRef {
        ColumnRef {
            relation: relation.map(str::to_owned),
            name: name.to_owned(),
        }
    }

    /// A full SELECT pipeline lowers to scan → filter → project →
    /// distinct → top-k, with top-k as the root.
    #[test]
    fn lowers_filter_project_distinct_topk_chain() {
        let graph = parse_and_lower(
            "SELECT DISTINCT id, title AS post_title
             FROM posts
             WHERE author_id = 42
             ORDER BY created_at DESC
             LIMIT 5 OFFSET 10",
        )
        .expect("supported query should lower");

        let expected_order = vec![OrderKey {
            expression: "created_at".to_owned(),
            descending: true,
        }];
        let expected_columns = vec!["id".to_owned(), "post_title".to_owned()];

        assert_eq!(graph.node_count(), 5);

        let root_is_topk = matches!(
            graph.root_kind(),
            MirNodeKind::TopK {
                order_by,
                limit: 5,
                offset: 10,
            } if order_by == &expected_order
        );
        assert!(root_is_topk);

        let scans_posts = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::BaseTable { table, .. } if table == "posts"
            )
        });
        assert!(scans_posts);

        let filters_author = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Filter { predicate } if predicate == "author_id = 42"
            )
        });
        assert!(filters_author);

        let projects_columns = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Project { columns } if columns == &expected_columns
            )
        });
        assert!(projects_columns);

        let has_distinct = graph
            .node_kinds()
            .any(|node| matches!(node, MirNodeKind::Distinct));
        assert!(has_distinct);
    }

    /// An inner equi-join keeps the qualified column pair in source order.
    #[test]
    fn lowers_equi_join() {
        let graph = parse_and_lower(
            "SELECT posts.id
             FROM posts JOIN authors ON posts.author_id = authors.id",
        )
        .expect("validated equi-join should lower");

        let expected_on = vec![(
            column(Some("posts"), "author_id"),
            column(Some("authors"), "id"),
        )];

        assert_eq!(graph.node_count(), 4);

        let has_inner_join = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Join {
                    kind: JoinKind::Inner,
                    on,
                } if on == &expected_on
            )
        });
        assert!(has_inner_join);
    }

    /// A LEFT JOIN with two AND-ed equalities yields two column pairs.
    #[test]
    fn lowers_left_equi_join_with_conjunction() {
        let graph = parse_and_lower(
            "SELECT posts.id
             FROM posts LEFT JOIN comments
               ON posts.id = comments.post_id AND posts.author_id = comments.author_id",
        )
        .expect("validated left equi-join should lower");

        let has_left_join = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Join {
                    kind: JoinKind::Left,
                    on,
                } if on.len() == 2
            )
        });
        assert!(has_left_join);
    }

    /// GROUP BY with aliased and unaliased aggregates lowers to one
    /// Aggregate node carrying both, in projection order.
    #[test]
    fn lowers_group_by_aggregate() {
        let graph = parse_and_lower(
            "SELECT author_id, count(*) AS post_count, max(created_at)
             FROM posts
             WHERE author_id = 42
             GROUP BY author_id",
        )
        .expect("basic aggregate query should lower");

        let expected_group_by = vec![column(None, "author_id")];
        let expected_aggs = vec![
            AggExpr {
                function: "count".to_owned(),
                args: vec!["*".to_owned()],
                alias: Some("post_count".to_owned()),
            },
            AggExpr {
                function: "max".to_owned(),
                args: vec!["created_at".to_owned()],
                alias: None,
            },
        ];

        assert_eq!(graph.node_count(), 4);

        let has_aggregate = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Aggregate { group_by, aggs }
                    if group_by == &expected_group_by && aggs == &expected_aggs
            )
        });
        assert!(has_aggregate);
    }

    /// An aggregate without GROUP BY still produces an Aggregate node.
    #[test]
    fn lowers_scalar_aggregate() {
        let graph = parse_and_lower("SELECT count(*) FROM posts")
            .expect("scalar aggregate query should lower");

        let has_scalar_aggregate = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Aggregate { group_by, aggs }
                    if group_by.is_empty() && aggs.len() == 1
            )
        });
        assert!(has_scalar_aggregate);
    }

    /// UNION ALL roots the graph at a Union node over two base tables.
    #[test]
    fn lowers_union_all() {
        let graph = parse_and_lower(
            "SELECT id FROM posts
             UNION ALL
             SELECT id FROM archived_posts",
        )
        .expect("UNION ALL should lower");

        assert_eq!(graph.node_count(), 5);

        let root_is_union_all = matches!(
            graph.root_kind(),
            MirNodeKind::Union {
                quantifier: SetQuantifierKind::All,
            }
        );
        assert!(root_is_union_all);

        let base_tables = graph
            .node_kinds()
            .filter(|node| matches!(node, MirNodeKind::BaseTable { .. }))
            .count();
        assert_eq!(base_tables, 2);
    }

    /// A bare UNION defaults to the DISTINCT quantifier.
    #[test]
    fn lowers_distinct_union() {
        let graph = parse_and_lower(
            "SELECT id FROM posts
             UNION
             SELECT id FROM archived_posts",
        )
        .expect("UNION DISTINCT should lower");

        let root_is_distinct_union = matches!(
            graph.root_kind(),
            MirNodeKind::Union {
                quantifier: SetQuantifierKind::Distinct,
            }
        );
        assert!(root_is_distinct_union);
    }

    /// EXCEPT and INTERSECT ALL map to their own node kinds and quantifiers.
    #[test]
    fn lowers_except_and_intersect() {
        let except = parse_and_lower(
            "SELECT id FROM posts
             EXCEPT
             SELECT id FROM archived_posts",
        )
        .expect("EXCEPT should lower");
        let intersect = parse_and_lower(
            "SELECT id FROM posts
             INTERSECT ALL
             SELECT id FROM archived_posts",
        )
        .expect("INTERSECT ALL should lower");

        let except_is_distinct = matches!(
            except.root_kind(),
            MirNodeKind::Except {
                quantifier: SetQuantifierKind::Distinct,
            }
        );
        assert!(except_is_distinct);

        let intersect_is_all = matches!(
            intersect.root_kind(),
            MirNodeKind::Intersect {
                quantifier: SetQuantifierKind::All,
            }
        );
        assert!(intersect_is_all);
    }

    /// A CTE reference produces a CteRef node plus a CTE-expansion edge.
    #[test]
    fn lowers_cte_reference() {
        let graph = parse_and_lower(
            "WITH recent_posts AS (
                SELECT id, author_id FROM posts WHERE author_id = 42
             )
             SELECT id FROM recent_posts",
        )
        .expect("non-recursive CTE should lower");

        assert_eq!(graph.node_count(), 5);

        let references_cte = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::CteRef { cte } if cte == "recent_posts"
            )
        });
        assert!(references_cte);

        let has_expansion_edge = graph
            .graph()
            .edge_weights()
            .any(|edge| *edge == MirEdgeKind::CteExpansion);
        assert!(has_expansion_edge);
    }

    /// A derived table is lowered inline through the nested-query path.
    #[test]
    fn lowers_derived_table() {
        let graph = parse_and_lower(
            "SELECT id
             FROM (
                SELECT id FROM posts WHERE author_id = 42
             ) AS recent_posts",
        )
        .expect("derived table should lower through nested query path");

        assert_eq!(graph.node_count(), 4);

        let keeps_inner_filter = graph.node_kinds().any(|node| {
            matches!(
                node,
                MirNodeKind::Filter { predicate } if predicate == "author_id = 42"
            )
        });
        assert!(keeps_inner_filter);
    }
}