Skip to main content

sqlglot_rust/planner/
mod.rs

1//! Logical query planner.
2//!
3//! Generates a logical execution plan (a DAG of [`Step`]s) from an
4//! optimized SQL AST. Inspired by Python sqlglot's `planner.py`.
5//!
6//! The planner sits between the optimizer and the executor: the optimizer
7//! rewrites the AST, then the planner produces a plan that an execution
8//! engine can follow.
9//!
10//! # Example
11//!
12//! ```rust
13//! use sqlglot_rust::parser::parse;
14//! use sqlglot_rust::dialects::Dialect;
15//! use sqlglot_rust::planner::{plan, Plan};
16//!
17//! let ast = parse("SELECT a, b FROM t WHERE a > 1 ORDER BY b", Dialect::Ansi).unwrap();
18//! let p = plan(&ast).unwrap();
19//! println!("{}", p.to_mermaid());
20//! ```
21
22use std::fmt;
23
24use crate::ast::*;
25use crate::errors::{Result, SqlglotError};
26
27// ═══════════════════════════════════════════════════════════════════════
28// Step ID
29// ═══════════════════════════════════════════════════════════════════════
30
/// Opaque identifier for a step within a plan.
///
/// The wrapped value is private; IDs are compared and ordered by that value,
/// so "step A comes before step B" can be written as `a < b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StepId(usize);
34
35impl fmt::Display for StepId {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        write!(f, "step_{}", self.0)
38    }
39}
40
41// ═══════════════════════════════════════════════════════════════════════
42// Column projection
43// ═══════════════════════════════════════════════════════════════════════
44
45/// A projected column in a plan step.
46#[derive(Debug, Clone, PartialEq)]
47pub struct Projection {
48    /// The expression being projected.
49    pub expr: Expr,
50    /// Output alias (if any).
51    pub alias: Option<String>,
52}
53
54// ═══════════════════════════════════════════════════════════════════════
55// Plan step types
56// ═══════════════════════════════════════════════════════════════════════
57
58/// A single step in the logical execution plan.
59#[derive(Debug, Clone, PartialEq)]
60pub enum Step {
61    /// Full table scan with optional filter pushdown.
62    Scan {
63        /// Fully-qualified table name.
64        table: String,
65        /// Alias for the table (if any).
66        alias: Option<String>,
67        /// Projected columns.
68        projections: Vec<Projection>,
69        /// Predicate pushed down to the scan.
70        predicate: Option<Expr>,
71        /// IDs of steps this step depends on (always empty for a scan).
72        dependencies: Vec<StepId>,
73    },
74    /// Filter (WHERE / HAVING) applied to its input.
75    Filter {
76        /// The filter predicate.
77        predicate: Expr,
78        /// Projected columns.
79        projections: Vec<Projection>,
80        /// The single input step.
81        dependencies: Vec<StepId>,
82    },
83    /// Projection (SELECT list evaluation).
84    Project {
85        /// Output projections.
86        projections: Vec<Projection>,
87        /// The single input step.
88        dependencies: Vec<StepId>,
89    },
90    /// Aggregation (GROUP BY + aggregate functions).
91    Aggregate {
92        /// GROUP BY keys.
93        group_by: Vec<Expr>,
94        /// Aggregate expressions (COUNT, SUM, etc.).
95        aggregations: Vec<Projection>,
96        /// Projected output columns.
97        projections: Vec<Projection>,
98        /// The single input step.
99        dependencies: Vec<StepId>,
100    },
101    /// Sort (ORDER BY).
102    Sort {
103        /// Order-by items.
104        order_by: Vec<OrderByItem>,
105        /// Projected columns (pass-through).
106        projections: Vec<Projection>,
107        /// The single input step.
108        dependencies: Vec<StepId>,
109    },
110    /// Join two inputs.
111    Join {
112        /// Type of join.
113        join_type: JoinType,
114        /// Join condition (ON clause).
115        condition: Option<Expr>,
116        /// USING columns (if specified instead of ON).
117        using_columns: Vec<String>,
118        /// Projected columns.
119        projections: Vec<Projection>,
120        /// Two input steps: [left, right].
121        dependencies: Vec<StepId>,
122    },
123    /// LIMIT / OFFSET.
124    Limit {
125        /// Row limit.
126        limit: Option<Expr>,
127        /// Row offset.
128        offset: Option<Expr>,
129        /// Projected columns (pass-through).
130        projections: Vec<Projection>,
131        /// The single input step.
132        dependencies: Vec<StepId>,
133    },
134    /// UNION / INTERSECT / EXCEPT.
135    SetOperation {
136        /// The kind of set operation.
137        op: SetOperationType,
138        /// Whether ALL (no deduplication).
139        all: bool,
140        /// Projected columns from the combined result.
141        projections: Vec<Projection>,
142        /// Two input steps: [left, right].
143        dependencies: Vec<StepId>,
144    },
145    /// DISTINCT elimination.
146    Distinct {
147        /// Projected columns.
148        projections: Vec<Projection>,
149        /// The single input step.
150        dependencies: Vec<StepId>,
151    },
152}
153
154impl Step {
155    /// Returns the list of step IDs this step depends on.
156    #[must_use]
157    pub fn dependencies(&self) -> &[StepId] {
158        match self {
159            Step::Scan { dependencies, .. }
160            | Step::Filter { dependencies, .. }
161            | Step::Project { dependencies, .. }
162            | Step::Aggregate { dependencies, .. }
163            | Step::Sort { dependencies, .. }
164            | Step::Join { dependencies, .. }
165            | Step::Limit { dependencies, .. }
166            | Step::SetOperation { dependencies, .. }
167            | Step::Distinct { dependencies, .. } => dependencies,
168        }
169    }
170
171    /// Returns the projected columns of this step.
172    #[must_use]
173    pub fn projections(&self) -> &[Projection] {
174        match self {
175            Step::Scan { projections, .. }
176            | Step::Filter { projections, .. }
177            | Step::Project { projections, .. }
178            | Step::Aggregate { projections, .. }
179            | Step::Sort { projections, .. }
180            | Step::Join { projections, .. }
181            | Step::Limit { projections, .. }
182            | Step::SetOperation { projections, .. }
183            | Step::Distinct { projections, .. } => projections,
184        }
185    }
186
187    /// A short human-readable label for the step type.
188    #[must_use]
189    pub fn kind(&self) -> &'static str {
190        match self {
191            Step::Scan { .. } => "Scan",
192            Step::Filter { .. } => "Filter",
193            Step::Project { .. } => "Project",
194            Step::Aggregate { .. } => "Aggregate",
195            Step::Sort { .. } => "Sort",
196            Step::Join { .. } => "Join",
197            Step::Limit { .. } => "Limit",
198            Step::SetOperation { .. } => "SetOperation",
199            Step::Distinct { .. } => "Distinct",
200        }
201    }
202}
203
204// ═══════════════════════════════════════════════════════════════════════
205// Plan
206// ═══════════════════════════════════════════════════════════════════════
207
208/// A logical execution plan — a directed acyclic graph (DAG) of steps.
209///
210/// Steps are stored in topological order: a step's dependencies always
211/// have a smaller [`StepId`] than the step itself.
212#[derive(Debug, Clone)]
213pub struct Plan {
214    /// All steps in topological order.
215    steps: Vec<Step>,
216    /// The "root" step that produces the final result.
217    root: StepId,
218}
219
220impl Plan {
221    /// Returns the root step ID.
222    #[must_use]
223    pub fn root(&self) -> StepId {
224        self.root
225    }
226
227    /// Returns a reference to all steps.
228    #[must_use]
229    pub fn steps(&self) -> &[Step] {
230        &self.steps
231    }
232
233    /// Looks up a step by its ID.
234    #[must_use]
235    pub fn get(&self, id: StepId) -> Option<&Step> {
236        self.steps.get(id.0)
237    }
238
239    /// Number of steps in the plan.
240    #[must_use]
241    pub fn len(&self) -> usize {
242        self.steps.len()
243    }
244
245    /// Whether the plan has zero steps.
246    #[must_use]
247    pub fn is_empty(&self) -> bool {
248        self.steps.is_empty()
249    }
250
251    /// Render the plan as a Mermaid flowchart.
252    #[must_use]
253    pub fn to_mermaid(&self) -> String {
254        let mut out = String::from("graph TD\n");
255        for (i, step) in self.steps.iter().enumerate() {
256            let id = StepId(i);
257            let label = step_label(step);
258            out.push_str(&format!("    {id}[\"{label}\"]\n"));
259            for dep in step.dependencies() {
260                out.push_str(&format!("    {dep} --> {id}\n"));
261            }
262        }
263        out
264    }
265
266    /// Render the plan as a DOT (Graphviz) digraph.
267    #[must_use]
268    pub fn to_dot(&self) -> String {
269        let mut out = String::from("digraph plan {\n    rankdir=BT;\n");
270        for (i, step) in self.steps.iter().enumerate() {
271            let id = StepId(i);
272            let label = step_label(step);
273            out.push_str(&format!("    {id} [label=\"{label}\"];\n"));
274            for dep in step.dependencies() {
275                out.push_str(&format!("    {dep} -> {id};\n"));
276            }
277        }
278        out.push_str("}\n");
279        out
280    }
281}
282
283impl fmt::Display for Plan {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        for (i, step) in self.steps.iter().enumerate() {
286            let id = StepId(i);
287            let root_marker = if id == self.root { " (root)" } else { "" };
288            writeln!(f, "{id}{root_marker}: {}", step_label(step))?;
289            for dep in step.dependencies() {
290                writeln!(f, "  <- {dep}")?;
291            }
292        }
293        Ok(())
294    }
295}
296
297/// Produce a concise label for visualization.
298fn step_label(step: &Step) -> String {
299    match step {
300        Step::Scan {
301            table,
302            alias,
303            predicate,
304            ..
305        } => {
306            let name = alias.as_deref().unwrap_or(table.as_str());
307            if predicate.is_some() {
308                format!("Scan({name} + filter)")
309            } else {
310                format!("Scan({name})")
311            }
312        }
313        Step::Filter { .. } => "Filter".to_string(),
314        Step::Project { projections, .. } => {
315            let cols: Vec<_> = projections
316                .iter()
317                .map(|p| {
318                    p.alias
319                        .as_deref()
320                        .unwrap_or_else(|| expr_short_name(&p.expr))
321                })
322                .collect();
323            if cols.len() <= 4 {
324                format!("Project({})", cols.join(", "))
325            } else {
326                format!("Project({} cols)", cols.len())
327            }
328        }
329        Step::Aggregate { group_by, .. } => {
330            if group_by.is_empty() {
331                "Aggregate(scalar)".to_string()
332            } else {
333                format!("Aggregate({} keys)", group_by.len())
334            }
335        }
336        Step::Sort { order_by, .. } => format!("Sort({} keys)", order_by.len()),
337        Step::Join { join_type, .. } => format!("Join({join_type:?})"),
338        Step::Limit { limit, offset, .. } => {
339            let mut parts = Vec::new();
340            if limit.is_some() {
341                parts.push("limit");
342            }
343            if offset.is_some() {
344                parts.push("offset");
345            }
346            format!("Limit({})", parts.join("+"))
347        }
348        Step::SetOperation { op, all, .. } => {
349            let all_str = if *all { " ALL" } else { "" };
350            format!("{op:?}{all_str}")
351        }
352        Step::Distinct { .. } => "Distinct".to_string(),
353    }
354}
355
356/// Short name for an expression (used in labels).
357fn expr_short_name(expr: &Expr) -> &str {
358    match expr {
359        Expr::Column { name, .. } => name.as_str(),
360        Expr::Wildcard => "*",
361        _ => "expr",
362    }
363}
364
365// ═══════════════════════════════════════════════════════════════════════
366// Plan builder
367// ═══════════════════════════════════════════════════════════════════════
368
369/// Build a logical execution plan from a parsed SQL statement.
370///
371/// The statement should ideally be optimized first (via
372/// [`crate::optimizer::optimize`]) for the best plan quality, but this
373/// is not required.
374///
375/// # Errors
376///
377/// Returns [`SqlglotError`] when the statement cannot be planned (e.g.,
378/// DDL statements, unsupported constructs).
379pub fn plan(statement: &Statement) -> Result<Plan> {
380    let mut builder = PlanBuilder::new();
381    let _root = builder.plan_statement(statement)?;
382    Ok(builder.build())
383}
384
385/// Internal builder that accumulates steps.
386struct PlanBuilder {
387    steps: Vec<Step>,
388}
389
390impl PlanBuilder {
391    fn new() -> Self {
392        Self { steps: Vec::new() }
393    }
394
395    fn add_step(&mut self, step: Step) -> StepId {
396        let id = StepId(self.steps.len());
397        self.steps.push(step);
398        id
399    }
400
401    fn build(self) -> Plan {
402        let root = if self.steps.is_empty() {
403            StepId(0)
404        } else {
405            StepId(self.steps.len() - 1)
406        };
407        Plan {
408            steps: self.steps,
409            root,
410        }
411    }
412
413    // ───────────────────────────────────────────────────────────────
414    // Statement dispatch
415    // ───────────────────────────────────────────────────────────────
416
417    fn plan_statement(&mut self, stmt: &Statement) -> Result<StepId> {
418        match stmt {
419            Statement::Select(sel) => self.plan_select(sel),
420            Statement::SetOperation(set_op) => self.plan_set_operation(set_op),
421            _ => Err(SqlglotError::Internal(format!(
422                "Planner does not support {:?} statements",
423                std::mem::discriminant(stmt)
424            ))),
425        }
426    }
427
428    // ───────────────────────────────────────────────────────────────
429    // SELECT
430    // ───────────────────────────────────────────────────────────────
431
432    fn plan_select(&mut self, sel: &SelectStatement) -> Result<StepId> {
433        // 1. Resolve FROM source(s)
434        let mut current = if let Some(from) = &sel.from {
435            self.plan_table_source(&from.source)?
436        } else {
437            // No FROM — single-row virtual scan (e.g., SELECT 1+2)
438            self.add_step(Step::Scan {
439                table: String::new(),
440                alias: None,
441                projections: vec![],
442                predicate: None,
443                dependencies: vec![],
444            })
445        };
446
447        // 2. JOINs
448        for join in &sel.joins {
449            let right = self.plan_table_source(&join.table)?;
450            let projections = vec![]; // pass-through
451            current = self.add_step(Step::Join {
452                join_type: join.join_type.clone(),
453                condition: join.on.clone(),
454                using_columns: join.using.clone(),
455                projections,
456                dependencies: vec![current, right],
457            });
458        }
459
460        // 3. WHERE
461        if let Some(pred) = &sel.where_clause {
462            current = self.add_step(Step::Filter {
463                predicate: pred.clone(),
464                projections: vec![],
465                dependencies: vec![current],
466            });
467        }
468
469        // 4. GROUP BY / Aggregation
470        if !sel.group_by.is_empty() || has_aggregates(&sel.columns) {
471            let aggregations = extract_aggregates(&sel.columns);
472            current = self.add_step(Step::Aggregate {
473                group_by: sel.group_by.clone(),
474                aggregations,
475                projections: vec![],
476                dependencies: vec![current],
477            });
478        }
479
480        // 5. HAVING
481        if let Some(having) = &sel.having {
482            current = self.add_step(Step::Filter {
483                predicate: having.clone(),
484                projections: vec![],
485                dependencies: vec![current],
486            });
487        }
488
489        // 6. DISTINCT
490        if sel.distinct {
491            current = self.add_step(Step::Distinct {
492                projections: vec![],
493                dependencies: vec![current],
494            });
495        }
496
497        // 7. ORDER BY
498        if !sel.order_by.is_empty() {
499            current = self.add_step(Step::Sort {
500                order_by: sel.order_by.clone(),
501                projections: vec![],
502                dependencies: vec![current],
503            });
504        }
505
506        // 8. LIMIT / OFFSET
507        if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() {
508            let limit = sel.limit.clone().or_else(|| sel.fetch_first.clone());
509            current = self.add_step(Step::Limit {
510                limit,
511                offset: sel.offset.clone(),
512                projections: vec![],
513                dependencies: vec![current],
514            });
515        }
516
517        // 9. Project (SELECT columns)
518        let projections = select_items_to_projections(&sel.columns);
519        if !projections.is_empty() {
520            current = self.add_step(Step::Project {
521                projections,
522                dependencies: vec![current],
523            });
524        }
525
526        Ok(current)
527    }
528
529    // ───────────────────────────────────────────────────────────────
530    // Table sources
531    // ───────────────────────────────────────────────────────────────
532
533    fn plan_table_source(&mut self, source: &TableSource) -> Result<StepId> {
534        match source {
535            TableSource::Table(tref) => {
536                let table = fully_qualified_name(tref);
537                Ok(self.add_step(Step::Scan {
538                    table,
539                    alias: tref.alias.clone(),
540                    projections: vec![],
541                    predicate: None,
542                    dependencies: vec![],
543                }))
544            }
545            TableSource::Subquery {
546                query, alias: _, ..
547            } => self.plan_statement(query),
548            TableSource::Lateral { source } => self.plan_table_source(source),
549            TableSource::TableFunction {
550                name, args, alias, ..
551            } => Ok(self.add_step(Step::Scan {
552                table: name.clone(),
553                alias: alias.clone(),
554                projections: args
555                    .iter()
556                    .map(|a| Projection {
557                        expr: a.clone(),
558                        alias: None,
559                    })
560                    .collect(),
561                predicate: None,
562                dependencies: vec![],
563            })),
564            TableSource::Unnest { expr, alias, .. } => Ok(self.add_step(Step::Scan {
565                table: "UNNEST".to_string(),
566                alias: alias.clone(),
567                projections: vec![Projection {
568                    expr: *expr.clone(),
569                    alias: None,
570                }],
571                predicate: None,
572                dependencies: vec![],
573            })),
574            TableSource::Pivot { source, alias, .. }
575            | TableSource::Unpivot { source, alias, .. } => {
576                // Plan the underlying source; the pivot/unpivot is treated
577                // as a virtual scan wrapping it.
578                let inner = self.plan_table_source(source)?;
579                // For simplicity, wrap pivot/unpivot into a project.
580                Ok(self.add_step(Step::Project {
581                    projections: vec![Projection {
582                        expr: Expr::Wildcard,
583                        alias: alias.clone(),
584                    }],
585                    dependencies: vec![inner],
586                }))
587            }
588        }
589    }
590
591    // ───────────────────────────────────────────────────────────────
592    // Set operations
593    // ───────────────────────────────────────────────────────────────
594
595    fn plan_set_operation(&mut self, set_op: &SetOperationStatement) -> Result<StepId> {
596        let left = self.plan_statement(&set_op.left)?;
597        let right = self.plan_statement(&set_op.right)?;
598
599        let mut current = self.add_step(Step::SetOperation {
600            op: set_op.op.clone(),
601            all: set_op.all,
602            projections: vec![],
603            dependencies: vec![left, right],
604        });
605
606        if !set_op.order_by.is_empty() {
607            current = self.add_step(Step::Sort {
608                order_by: set_op.order_by.clone(),
609                projections: vec![],
610                dependencies: vec![current],
611            });
612        }
613
614        if set_op.limit.is_some() || set_op.offset.is_some() {
615            current = self.add_step(Step::Limit {
616                limit: set_op.limit.clone(),
617                offset: set_op.offset.clone(),
618                projections: vec![],
619                dependencies: vec![current],
620            });
621        }
622
623        Ok(current)
624    }
625}
626
627// ═══════════════════════════════════════════════════════════════════════
628// Helpers
629// ═══════════════════════════════════════════════════════════════════════
630
631/// Build a fully qualified table name from a [`TableRef`].
632fn fully_qualified_name(tref: &TableRef) -> String {
633    let mut parts = Vec::new();
634    if let Some(catalog) = &tref.catalog {
635        parts.push(catalog.as_str());
636    }
637    if let Some(schema) = &tref.schema {
638        parts.push(schema.as_str());
639    }
640    parts.push(tref.name.as_str());
641    parts.join(".")
642}
643
644/// Convert SELECT items to projections.
645fn select_items_to_projections(items: &[SelectItem]) -> Vec<Projection> {
646    items
647        .iter()
648        .map(|item| match item {
649            SelectItem::Wildcard => Projection {
650                expr: Expr::Wildcard,
651                alias: None,
652            },
653            SelectItem::QualifiedWildcard { table } => Projection {
654                expr: Expr::QualifiedWildcard {
655                    table: table.clone(),
656                },
657                alias: None,
658            },
659            SelectItem::Expr { expr, alias, .. } => Projection {
660                expr: expr.clone(),
661                alias: alias.clone(),
662            },
663        })
664        .collect()
665}
666
667/// Check whether any SELECT items contain aggregate functions.
668fn has_aggregates(items: &[SelectItem]) -> bool {
669    items.iter().any(|item| match item {
670        SelectItem::Expr { expr, .. } => expr_has_aggregate(expr),
671        _ => false,
672    })
673}
674
675/// Recursively check whether an expression contains an aggregate function.
676fn expr_has_aggregate(expr: &Expr) -> bool {
677    match expr {
678        Expr::Function { name, .. } => is_aggregate_name(name),
679        Expr::TypedFunction { func, .. } => typed_function_is_aggregate(func),
680        Expr::BinaryOp { left, right, .. } => expr_has_aggregate(left) || expr_has_aggregate(right),
681        Expr::UnaryOp { expr, .. } => expr_has_aggregate(expr),
682        Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr_has_aggregate(expr),
683        Expr::Case {
684            operand,
685            when_clauses,
686            else_clause,
687        } => {
688            operand.as_ref().is_some_and(|e| expr_has_aggregate(e))
689                || when_clauses
690                    .iter()
691                    .any(|(cond, result)| expr_has_aggregate(cond) || expr_has_aggregate(result))
692                || else_clause.as_ref().is_some_and(|e| expr_has_aggregate(e))
693        }
694        Expr::Alias { expr, .. } => expr_has_aggregate(expr),
695        _ => false,
696    }
697}
698
/// Returns `true` when `name`, compared case-insensitively, is one of the
/// function names this planner treats as an aggregate.
fn is_aggregate_name(name: &str) -> bool {
    /// Upper-case names of the recognized aggregate functions.
    const AGGREGATE_NAMES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "COLLECT_LIST",
        "COLLECT_SET",
        "ANY_VALUE",
        "APPROX_COUNT_DISTINCT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "FIRST_VALUE",
        "LAST_VALUE",
        "NTH_VALUE",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "EVERY",
    ];

    let upper = name.to_uppercase();
    AGGREGATE_NAMES.contains(&upper.as_str())
}
738
739/// Check whether a TypedFunction variant is an aggregate.
740fn typed_function_is_aggregate(func: &TypedFunction) -> bool {
741    matches!(
742        func,
743        TypedFunction::Count { .. }
744            | TypedFunction::Sum { .. }
745            | TypedFunction::Avg { .. }
746            | TypedFunction::Min { .. }
747            | TypedFunction::Max { .. }
748            | TypedFunction::ArrayAgg { .. }
749            | TypedFunction::ApproxDistinct { .. }
750            | TypedFunction::Variance { .. }
751            | TypedFunction::Stddev { .. }
752            | TypedFunction::GroupConcat { .. }
753    )
754}
755
756/// Extract aggregation projections from SELECT items.
757fn extract_aggregates(items: &[SelectItem]) -> Vec<Projection> {
758    let mut aggs = Vec::new();
759    for item in items {
760        if let SelectItem::Expr { expr, alias, .. } = item {
761            collect_aggregates(expr, alias, &mut aggs);
762        }
763    }
764    aggs
765}
766
767fn collect_aggregates(expr: &Expr, alias: &Option<String>, out: &mut Vec<Projection>) {
768    match expr {
769        Expr::Function { name, .. } if is_aggregate_name(name) => {
770            out.push(Projection {
771                expr: expr.clone(),
772                alias: alias.clone(),
773            });
774        }
775        Expr::TypedFunction { func, .. } if typed_function_is_aggregate(func) => {
776            out.push(Projection {
777                expr: expr.clone(),
778                alias: alias.clone(),
779            });
780        }
781        Expr::BinaryOp { left, right, .. } => {
782            collect_aggregates(left, &None, out);
783            collect_aggregates(right, &None, out);
784        }
785        Expr::Alias { expr: inner, name } => {
786            collect_aggregates(inner, &Some(name.clone()), out);
787        }
788        _ => {}
789    }
790}
791
792// ═══════════════════════════════════════════════════════════════════════
793// Tests
794// ═══════════════════════════════════════════════════════════════════════
795
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Parses `sql` with the ANSI dialect and plans it, panicking on failure.
    fn plan_sql(sql: &str) -> Plan {
        let ast = parse(sql, Dialect::Ansi).unwrap();
        plan(&ast).unwrap()
    }

    /// Kind label of every step in the plan, in step order.
    fn kinds_of(p: &Plan) -> Vec<&'static str> {
        p.steps().iter().map(Step::kind).collect()
    }

    /// Join type of the first Join step in the plan.
    fn first_join_type(p: &Plan) -> &JoinType {
        match p.steps().iter().find(|s| s.kind() == "Join").unwrap() {
            Step::Join { join_type, .. } => join_type,
            _ => panic!("expected Join step"),
        }
    }

    #[test]
    fn test_simple_select() {
        let p = plan_sql("SELECT a, b FROM t");
        assert!(p.len() >= 2); // Scan + Project
        assert_eq!(p.get(p.root()).unwrap().kind(), "Project");
    }

    #[test]
    fn test_select_with_where() {
        // Scan -> Filter -> Project
        let kinds = kinds_of(&plan_sql("SELECT a FROM t WHERE a > 1"));
        for expected in ["Scan", "Filter", "Project"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_select_with_order_by() {
        let kinds = kinds_of(&plan_sql("SELECT a FROM t ORDER BY a"));
        assert!(kinds.contains(&"Sort"));
    }

    #[test]
    fn test_select_with_group_by() {
        let kinds = kinds_of(&plan_sql("SELECT a, COUNT(*) FROM t GROUP BY a"));
        assert!(kinds.contains(&"Aggregate"));
    }

    #[test]
    fn test_select_with_having() {
        let kinds = kinds_of(&plan_sql(
            "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
        ));
        // Should have Aggregate and a Filter for HAVING
        assert!(kinds.contains(&"Aggregate"));
        assert!(kinds.contains(&"Filter"));
    }

    #[test]
    fn test_join() {
        let kinds = kinds_of(&plan_sql("SELECT a.x FROM a JOIN b ON a.id = b.id"));
        assert!(kinds.contains(&"Join"));
    }

    #[test]
    fn test_multiple_joins() {
        let p = plan_sql("SELECT a.x FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id");
        let join_count = kinds_of(&p).into_iter().filter(|k| *k == "Join").count();
        assert_eq!(join_count, 2);
    }

    #[test]
    fn test_union() {
        let kinds = kinds_of(&plan_sql("SELECT a FROM t1 UNION ALL SELECT b FROM t2"));
        assert!(kinds.contains(&"SetOperation"));
    }

    #[test]
    fn test_limit_offset() {
        let kinds = kinds_of(&plan_sql("SELECT a FROM t LIMIT 10 OFFSET 5"));
        assert!(kinds.contains(&"Limit"));
    }

    #[test]
    fn test_distinct() {
        let kinds = kinds_of(&plan_sql("SELECT DISTINCT a FROM t"));
        assert!(kinds.contains(&"Distinct"));
    }

    #[test]
    fn test_subquery_in_from() {
        let p = plan_sql("SELECT x FROM (SELECT a AS x FROM t) sub");
        // Inner scan + inner project + outer project
        assert!(p.len() >= 3);
    }

    #[test]
    fn test_complex_query() {
        let kinds = kinds_of(&plan_sql(
            "SELECT a, SUM(b) AS total FROM t WHERE c > 0 GROUP BY a HAVING SUM(b) > 10 ORDER BY total DESC LIMIT 5",
        ));
        // "Filter" covers both WHERE and HAVING.
        for expected in ["Scan", "Filter", "Aggregate", "Sort", "Limit", "Project"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_dag_dependencies() {
        let p = plan_sql("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id");
        // Every step's dependencies should reference valid earlier steps
        for (i, step) in p.steps().iter().enumerate() {
            for dep in step.dependencies() {
                assert!(dep.0 < i, "step {i} depends on {dep} which is not earlier");
            }
        }
    }

    #[test]
    fn test_mermaid_output() {
        let mermaid = plan_sql("SELECT a FROM t WHERE a > 1").to_mermaid();
        assert!(mermaid.starts_with("graph TD"));
        assert!(mermaid.contains("Scan"));
    }

    #[test]
    fn test_dot_output() {
        let dot = plan_sql("SELECT a FROM t WHERE a > 1").to_dot();
        assert!(dot.starts_with("digraph plan"));
        assert!(dot.contains("Scan"));
    }

    #[test]
    fn test_display() {
        let display = plan_sql("SELECT a FROM t").to_string();
        assert!(display.contains("(root)"));
    }

    #[test]
    fn test_ddl_rejected() {
        let ast = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
        assert!(plan(&ast).is_err());
    }

    #[test]
    fn test_no_from_select() {
        let p = plan_sql("SELECT 1 + 2");
        assert!(!p.is_empty());
    }

    #[test]
    fn test_left_join() {
        let p = plan_sql("SELECT a.x FROM a LEFT JOIN b ON a.id = b.id");
        assert_eq!(*first_join_type(&p), JoinType::Left);
    }

    #[test]
    fn test_cross_join() {
        let p = plan_sql("SELECT a.x FROM a CROSS JOIN b");
        assert_eq!(*first_join_type(&p), JoinType::Cross);
    }

    #[test]
    fn test_union_with_order_limit() {
        let kinds = kinds_of(&plan_sql(
            "SELECT a FROM t1 UNION SELECT b FROM t2 ORDER BY 1 LIMIT 10",
        ));
        for expected in ["SetOperation", "Sort", "Limit"] {
            assert!(kinds.contains(&expected));
        }
    }
}
1011}