Skip to main content

polyglot_sql/
planner.rs

//! Query Execution Planner
//!
//! This module provides functionality to convert SQL AST into an execution plan
//! represented as a DAG (Directed Acyclic Graph) of steps.
//!
6
7use crate::expressions::{Expression, JoinKind};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
11/// A query execution plan
12#[derive(Debug)]
13pub struct Plan {
14    /// The root step of the plan DAG
15    pub root: Step,
16    /// Cached DAG representation
17    dag: Option<HashMap<usize, HashSet<usize>>>,
18}
19
20impl Plan {
21    /// Create a new plan from an expression
22    pub fn from_expression(expression: &Expression) -> Option<Self> {
23        let root = Step::from_expression(expression, &HashMap::new())?;
24        Some(Self { root, dag: None })
25    }
26
27    /// Get the DAG representation of the plan
28    pub fn dag(&mut self) -> &HashMap<usize, HashSet<usize>> {
29        if self.dag.is_none() {
30            let mut dag = HashMap::new();
31            self.build_dag(&self.root, &mut dag, 0);
32            self.dag = Some(dag);
33        }
34        self.dag.as_ref().unwrap()
35    }
36
37    fn build_dag(&self, step: &Step, dag: &mut HashMap<usize, HashSet<usize>>, id: usize) {
38        let deps: HashSet<usize> = step
39            .dependencies
40            .iter()
41            .enumerate()
42            .map(|(i, _)| id + i + 1)
43            .collect();
44        dag.insert(id, deps);
45
46        for (i, dep) in step.dependencies.iter().enumerate() {
47            self.build_dag(dep, dag, id + i + 1);
48        }
49    }
50
51    /// Get all leaf steps (steps with no dependencies)
52    pub fn leaves(&self) -> Vec<&Step> {
53        let mut leaves = Vec::new();
54        self.collect_leaves(&self.root, &mut leaves);
55        leaves
56    }
57
58    fn collect_leaves<'a>(&'a self, step: &'a Step, leaves: &mut Vec<&'a Step>) {
59        if step.dependencies.is_empty() {
60            leaves.push(step);
61        } else {
62            for dep in &step.dependencies {
63                self.collect_leaves(dep, leaves);
64            }
65        }
66    }
67}
68
69/// A step in the execution plan
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct Step {
72    /// Name of this step
73    pub name: String,
74    /// Type of step
75    pub kind: StepKind,
76    /// Projections to output
77    pub projections: Vec<Expression>,
78    /// Dependencies (other steps that must complete first)
79    pub dependencies: Vec<Step>,
80    /// Aggregation expressions (for Aggregate steps)
81    pub aggregations: Vec<Expression>,
82    /// Group by expressions (for Aggregate steps)
83    pub group_by: Vec<Expression>,
84    /// Join condition (for Join steps)
85    pub condition: Option<Expression>,
86    /// Sort expressions (for Sort steps)
87    pub order_by: Vec<Expression>,
88    /// Limit value (for Scan/other steps)
89    pub limit: Option<Expression>,
90}
91
92/// Types of execution steps
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum StepKind {
96    /// Scan a table
97    Scan,
98    /// Join multiple inputs
99    Join(JoinType),
100    /// Aggregate rows
101    Aggregate,
102    /// Sort rows
103    Sort,
104    /// Set operation (UNION, INTERSECT, EXCEPT)
105    SetOperation(SetOperationType),
106}
107
108/// Types of joins in execution plans
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
110#[serde(rename_all = "snake_case")]
111pub enum JoinType {
112    Inner,
113    Left,
114    Right,
115    Full,
116    Cross,
117}
118
119/// Types of set operations
120#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
121#[serde(rename_all = "snake_case")]
122pub enum SetOperationType {
123    Union,
124    UnionAll,
125    Intersect,
126    Except,
127}
128
129impl Step {
130    /// Create a new step
131    pub fn new(name: impl Into<String>, kind: StepKind) -> Self {
132        Self {
133            name: name.into(),
134            kind,
135            projections: Vec::new(),
136            dependencies: Vec::new(),
137            aggregations: Vec::new(),
138            group_by: Vec::new(),
139            condition: None,
140            order_by: Vec::new(),
141            limit: None,
142        }
143    }
144
145    /// Build a step from an expression
146    pub fn from_expression(expression: &Expression, ctes: &HashMap<String, Step>) -> Option<Self> {
147        match expression {
148            Expression::Select(select) => {
149                let mut step = Self::from_select(select, ctes)?;
150
151                // Handle ORDER BY
152                if let Some(ref order_by) = select.order_by {
153                    let sort_step = Step {
154                        name: step.name.clone(),
155                        kind: StepKind::Sort,
156                        projections: Vec::new(),
157                        dependencies: vec![step],
158                        aggregations: Vec::new(),
159                        group_by: Vec::new(),
160                        condition: None,
161                        order_by: order_by
162                            .expressions
163                            .iter()
164                            .map(|o| o.this.clone())
165                            .collect(),
166                        limit: None,
167                    };
168                    step = sort_step;
169                }
170
171                // Handle LIMIT
172                if let Some(ref limit) = select.limit {
173                    step.limit = Some(limit.this.clone());
174                }
175
176                Some(step)
177            }
178            Expression::Union(union) => {
179                let left = Self::from_expression(&union.left, ctes)?;
180                let right = Self::from_expression(&union.right, ctes)?;
181
182                let op_type = if union.all {
183                    SetOperationType::UnionAll
184                } else {
185                    SetOperationType::Union
186                };
187
188                Some(Step {
189                    name: "UNION".to_string(),
190                    kind: StepKind::SetOperation(op_type),
191                    projections: Vec::new(),
192                    dependencies: vec![left, right],
193                    aggregations: Vec::new(),
194                    group_by: Vec::new(),
195                    condition: None,
196                    order_by: Vec::new(),
197                    limit: None,
198                })
199            }
200            Expression::Intersect(intersect) => {
201                let left = Self::from_expression(&intersect.left, ctes)?;
202                let right = Self::from_expression(&intersect.right, ctes)?;
203
204                Some(Step {
205                    name: "INTERSECT".to_string(),
206                    kind: StepKind::SetOperation(SetOperationType::Intersect),
207                    projections: Vec::new(),
208                    dependencies: vec![left, right],
209                    aggregations: Vec::new(),
210                    group_by: Vec::new(),
211                    condition: None,
212                    order_by: Vec::new(),
213                    limit: None,
214                })
215            }
216            Expression::Except(except) => {
217                let left = Self::from_expression(&except.left, ctes)?;
218                let right = Self::from_expression(&except.right, ctes)?;
219
220                Some(Step {
221                    name: "EXCEPT".to_string(),
222                    kind: StepKind::SetOperation(SetOperationType::Except),
223                    projections: Vec::new(),
224                    dependencies: vec![left, right],
225                    aggregations: Vec::new(),
226                    group_by: Vec::new(),
227                    condition: None,
228                    order_by: Vec::new(),
229                    limit: None,
230                })
231            }
232            _ => None,
233        }
234    }
235
236    fn from_select(
237        select: &crate::expressions::Select,
238        ctes: &HashMap<String, Step>,
239    ) -> Option<Self> {
240        // Process CTEs first
241        let mut ctes = ctes.clone();
242        if let Some(ref with) = select.with {
243            for cte in &with.ctes {
244                if let Some(step) = Self::from_expression(&cte.this, &ctes) {
245                    ctes.insert(cte.alias.name.clone(), step);
246                }
247            }
248        }
249
250        // Start with the FROM clause
251        let mut step = if let Some(ref from) = select.from {
252            if let Some(table_expr) = from.expressions.first() {
253                Self::from_table_expression(table_expr, &ctes)?
254            } else {
255                return None;
256            }
257        } else {
258            // SELECT without FROM (e.g., SELECT 1)
259            Step::new("", StepKind::Scan)
260        };
261
262        // Process JOINs
263        for join in &select.joins {
264            let right = Self::from_table_expression(&join.this, &ctes)?;
265
266            let join_type = match join.kind {
267                JoinKind::Inner => JoinType::Inner,
268                JoinKind::Left | JoinKind::NaturalLeft => JoinType::Left,
269                JoinKind::Right | JoinKind::NaturalRight => JoinType::Right,
270                JoinKind::Full | JoinKind::NaturalFull => JoinType::Full,
271                JoinKind::Cross | JoinKind::Natural => JoinType::Cross,
272                _ => JoinType::Inner,
273            };
274
275            let join_step = Step {
276                name: step.name.clone(),
277                kind: StepKind::Join(join_type),
278                projections: Vec::new(),
279                dependencies: vec![step, right],
280                aggregations: Vec::new(),
281                group_by: Vec::new(),
282                condition: join.on.clone(),
283                order_by: Vec::new(),
284                limit: None,
285            };
286            step = join_step;
287        }
288
289        // Check for aggregations
290        let has_aggregations = select.expressions.iter().any(|e| contains_aggregate(e));
291        let has_group_by = select.group_by.is_some();
292
293        if has_aggregations || has_group_by {
294            // Create aggregate step
295            let agg_step = Step {
296                name: step.name.clone(),
297                kind: StepKind::Aggregate,
298                projections: select.expressions.clone(),
299                dependencies: vec![step],
300                aggregations: extract_aggregations(&select.expressions),
301                group_by: select
302                    .group_by
303                    .as_ref()
304                    .map(|g| g.expressions.clone())
305                    .unwrap_or_default(),
306                condition: None,
307                order_by: Vec::new(),
308                limit: None,
309            };
310            step = agg_step;
311        } else {
312            step.projections = select.expressions.clone();
313        }
314
315        Some(step)
316    }
317
318    fn from_table_expression(expr: &Expression, ctes: &HashMap<String, Step>) -> Option<Self> {
319        match expr {
320            Expression::Table(table) => {
321                // Check if this references a CTE
322                if let Some(cte_step) = ctes.get(&table.name.name) {
323                    return Some(cte_step.clone());
324                }
325
326                // Regular table scan
327                Some(Step::new(&table.name.name, StepKind::Scan))
328            }
329            Expression::Alias(alias) => {
330                let mut step = Self::from_table_expression(&alias.this, ctes)?;
331                step.name = alias.alias.name.clone();
332                Some(step)
333            }
334            Expression::Subquery(sq) => {
335                let step = Self::from_expression(&sq.this, ctes)?;
336                Some(step)
337            }
338            _ => None,
339        }
340    }
341
342    /// Add a dependency to this step
343    pub fn add_dependency(&mut self, dep: Step) {
344        self.dependencies.push(dep);
345    }
346}
347
348/// Check if an expression contains an aggregate function
349fn contains_aggregate(expr: &Expression) -> bool {
350    match expr {
351        // Specific aggregate function variants
352        Expression::Sum(_)
353        | Expression::Count(_)
354        | Expression::Avg(_)
355        | Expression::Min(_)
356        | Expression::Max(_)
357        | Expression::ArrayAgg(_)
358        | Expression::StringAgg(_)
359        | Expression::ListAgg(_)
360        | Expression::Stddev(_)
361        | Expression::StddevPop(_)
362        | Expression::StddevSamp(_)
363        | Expression::Variance(_)
364        | Expression::VarPop(_)
365        | Expression::VarSamp(_)
366        | Expression::Median(_)
367        | Expression::Mode(_)
368        | Expression::First(_)
369        | Expression::Last(_)
370        | Expression::AnyValue(_)
371        | Expression::ApproxDistinct(_)
372        | Expression::ApproxCountDistinct(_)
373        | Expression::LogicalAnd(_)
374        | Expression::LogicalOr(_)
375        | Expression::AggregateFunction(_) => true,
376
377        Expression::Alias(alias) => contains_aggregate(&alias.this),
378        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) | Expression::Div(op) => {
379            contains_aggregate(&op.left) || contains_aggregate(&op.right)
380        }
381        Expression::Function(func) => {
382            // Check for aggregate function names (fallback)
383            let name = func.name.to_uppercase();
384            matches!(
385                name.as_str(),
386                "SUM"
387                    | "COUNT"
388                    | "AVG"
389                    | "MIN"
390                    | "MAX"
391                    | "ARRAY_AGG"
392                    | "STRING_AGG"
393                    | "GROUP_CONCAT"
394            )
395        }
396        _ => false,
397    }
398}
399
400/// Extract aggregate expressions from a list
401fn extract_aggregations(expressions: &[Expression]) -> Vec<Expression> {
402    let mut aggs = Vec::new();
403    for expr in expressions {
404        collect_aggregations(expr, &mut aggs);
405    }
406    aggs
407}
408
409fn collect_aggregations(expr: &Expression, aggs: &mut Vec<Expression>) {
410    match expr {
411        // Specific aggregate function variants
412        Expression::Sum(_)
413        | Expression::Count(_)
414        | Expression::Avg(_)
415        | Expression::Min(_)
416        | Expression::Max(_)
417        | Expression::ArrayAgg(_)
418        | Expression::StringAgg(_)
419        | Expression::ListAgg(_)
420        | Expression::Stddev(_)
421        | Expression::StddevPop(_)
422        | Expression::StddevSamp(_)
423        | Expression::Variance(_)
424        | Expression::VarPop(_)
425        | Expression::VarSamp(_)
426        | Expression::Median(_)
427        | Expression::Mode(_)
428        | Expression::First(_)
429        | Expression::Last(_)
430        | Expression::AnyValue(_)
431        | Expression::ApproxDistinct(_)
432        | Expression::ApproxCountDistinct(_)
433        | Expression::LogicalAnd(_)
434        | Expression::LogicalOr(_)
435        | Expression::AggregateFunction(_) => {
436            aggs.push(expr.clone());
437        }
438        Expression::Alias(alias) => {
439            collect_aggregations(&alias.this, aggs);
440        }
441        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) | Expression::Div(op) => {
442            collect_aggregations(&op.left, aggs);
443            collect_aggregations(&op.right, aggs);
444        }
445        Expression::Function(func) => {
446            let name = func.name.to_uppercase();
447            if matches!(
448                name.as_str(),
449                "SUM"
450                    | "COUNT"
451                    | "AVG"
452                    | "MIN"
453                    | "MAX"
454                    | "ARRAY_AGG"
455                    | "STRING_AGG"
456                    | "GROUP_CONCAT"
457            ) {
458                aggs.push(expr.clone());
459            } else {
460                for arg in &func.args {
461                    collect_aggregations(arg, aggs);
462                }
463            }
464        }
465        _ => {}
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parse `sql` with the generic dialect and return its first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        dialect.parse(sql).unwrap().into_iter().next().unwrap()
    }

    /// Plan `sql`, asserting that planning succeeded.
    fn plan_for(sql: &str) -> Plan {
        let planned = Plan::from_expression(&parse(sql));
        assert!(planned.is_some());
        planned.unwrap()
    }

    /// Return the first select-list expression of `sql`, which must be a SELECT.
    fn first_projection(sql: &str) -> Expression {
        match parse(sql) {
            Expression::Select(sel) => {
                assert!(!sel.expressions.is_empty());
                sel.expressions[0].clone()
            }
            _ => panic!("Expected SELECT expression"),
        }
    }

    #[test]
    fn test_simple_scan() {
        let plan = plan_for("SELECT a, b FROM t");

        assert_eq!(plan.root.kind, StepKind::Scan);
        assert_eq!(plan.root.name, "t");
    }

    #[test]
    fn test_join() {
        let plan = plan_for("SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.id");

        assert!(matches!(plan.root.kind, StepKind::Join(_)));
        assert_eq!(plan.root.dependencies.len(), 2);
    }

    #[test]
    fn test_aggregate() {
        let plan = plan_for("SELECT x, SUM(y) FROM t GROUP BY x");

        assert_eq!(plan.root.kind, StepKind::Aggregate);
    }

    #[test]
    fn test_union() {
        let plan = plan_for("SELECT a FROM t1 UNION SELECT b FROM t2");

        assert!(matches!(
            plan.root.kind,
            StepKind::SetOperation(SetOperationType::Union)
        ));
    }

    #[test]
    fn test_contains_aggregate() {
        // A select list entry that is an aggregate call.
        let with_agg = first_projection("SELECT SUM(x) FROM t");
        assert!(
            contains_aggregate(&with_agg),
            "Expected SUM to be detected as aggregate function"
        );

        // A select list entry with plain arithmetic only.
        let without_agg = first_projection("SELECT x + 1 FROM t");
        assert!(
            !contains_aggregate(&without_agg),
            "Expected x + 1 to not be an aggregate function"
        );
    }
}