Skip to main content

polyglot_sql/
planner.rs

1//! Query Execution Planner
2//!
3//! This module provides functionality to convert SQL AST into an execution plan
4//! represented as a DAG (Directed Acyclic Graph) of steps.
5//!
6
7use crate::expressions::{Expression, JoinKind};
8use std::collections::{HashMap, HashSet};
9
10/// A query execution plan
11#[derive(Debug)]
12pub struct Plan {
13    /// The root step of the plan DAG
14    pub root: Step,
15    /// Cached DAG representation
16    dag: Option<HashMap<usize, HashSet<usize>>>,
17}
18
19impl Plan {
20    /// Create a new plan from an expression
21    pub fn from_expression(expression: &Expression) -> Option<Self> {
22        let root = Step::from_expression(expression, &HashMap::new())?;
23        Some(Self { root, dag: None })
24    }
25
26    /// Get the DAG representation of the plan
27    pub fn dag(&mut self) -> &HashMap<usize, HashSet<usize>> {
28        if self.dag.is_none() {
29            let mut dag = HashMap::new();
30            self.build_dag(&self.root, &mut dag, 0);
31            self.dag = Some(dag);
32        }
33        self.dag.as_ref().unwrap()
34    }
35
36    fn build_dag(&self, step: &Step, dag: &mut HashMap<usize, HashSet<usize>>, id: usize) {
37        let deps: HashSet<usize> = step.dependencies
38            .iter()
39            .enumerate()
40            .map(|(i, _)| id + i + 1)
41            .collect();
42        dag.insert(id, deps);
43
44        for (i, dep) in step.dependencies.iter().enumerate() {
45            self.build_dag(dep, dag, id + i + 1);
46        }
47    }
48
49    /// Get all leaf steps (steps with no dependencies)
50    pub fn leaves(&self) -> Vec<&Step> {
51        let mut leaves = Vec::new();
52        self.collect_leaves(&self.root, &mut leaves);
53        leaves
54    }
55
56    fn collect_leaves<'a>(&'a self, step: &'a Step, leaves: &mut Vec<&'a Step>) {
57        if step.dependencies.is_empty() {
58            leaves.push(step);
59        } else {
60            for dep in &step.dependencies {
61                self.collect_leaves(dep, leaves);
62            }
63        }
64    }
65}
66
67/// A step in the execution plan
68#[derive(Debug, Clone)]
69pub struct Step {
70    /// Name of this step
71    pub name: String,
72    /// Type of step
73    pub kind: StepKind,
74    /// Projections to output
75    pub projections: Vec<Expression>,
76    /// Dependencies (other steps that must complete first)
77    pub dependencies: Vec<Step>,
78    /// Aggregation expressions (for Aggregate steps)
79    pub aggregations: Vec<Expression>,
80    /// Group by expressions (for Aggregate steps)
81    pub group_by: Vec<Expression>,
82    /// Join condition (for Join steps)
83    pub condition: Option<Expression>,
84    /// Sort expressions (for Sort steps)
85    pub order_by: Vec<Expression>,
86    /// Limit value (for Scan/other steps)
87    pub limit: Option<Expression>,
88}
89
/// Types of execution steps.
///
/// Derives `Copy`, `Eq` and `Hash` so kinds can be compared, copied out of a borrowed
/// [`Step`], and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StepKind {
    /// Scan a table (also used, with an empty name, for a SELECT without FROM)
    Scan,
    /// Join two inputs of the given join type
    Join(JoinType),
    /// Aggregate rows
    Aggregate,
    /// Sort rows
    Sort,
    /// Set operation (UNION, INTERSECT, EXCEPT)
    SetOperation(SetOperationType),
}

/// Types of joins in execution plans
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}

/// Types of set operations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SetOperationType {
    Union,
    UnionAll,
    Intersect,
    Except,
}
123
124impl Step {
125    /// Create a new step
126    pub fn new(name: impl Into<String>, kind: StepKind) -> Self {
127        Self {
128            name: name.into(),
129            kind,
130            projections: Vec::new(),
131            dependencies: Vec::new(),
132            aggregations: Vec::new(),
133            group_by: Vec::new(),
134            condition: None,
135            order_by: Vec::new(),
136            limit: None,
137        }
138    }
139
140    /// Build a step from an expression
141    pub fn from_expression(
142        expression: &Expression,
143        ctes: &HashMap<String, Step>,
144    ) -> Option<Self> {
145        match expression {
146            Expression::Select(select) => {
147                let mut step = Self::from_select(select, ctes)?;
148
149                // Handle ORDER BY
150                if let Some(ref order_by) = select.order_by {
151                    let sort_step = Step {
152                        name: step.name.clone(),
153                        kind: StepKind::Sort,
154                        projections: Vec::new(),
155                        dependencies: vec![step],
156                        aggregations: Vec::new(),
157                        group_by: Vec::new(),
158                        condition: None,
159                        order_by: order_by.expressions.iter().map(|o| o.this.clone()).collect(),
160                        limit: None,
161                    };
162                    step = sort_step;
163                }
164
165                // Handle LIMIT
166                if let Some(ref limit) = select.limit {
167                    step.limit = Some(limit.this.clone());
168                }
169
170                Some(step)
171            }
172            Expression::Union(union) => {
173                let left = Self::from_expression(&union.left, ctes)?;
174                let right = Self::from_expression(&union.right, ctes)?;
175
176                let op_type = if union.all {
177                    SetOperationType::UnionAll
178                } else {
179                    SetOperationType::Union
180                };
181
182                Some(Step {
183                    name: "UNION".to_string(),
184                    kind: StepKind::SetOperation(op_type),
185                    projections: Vec::new(),
186                    dependencies: vec![left, right],
187                    aggregations: Vec::new(),
188                    group_by: Vec::new(),
189                    condition: None,
190                    order_by: Vec::new(),
191                    limit: None,
192                })
193            }
194            Expression::Intersect(intersect) => {
195                let left = Self::from_expression(&intersect.left, ctes)?;
196                let right = Self::from_expression(&intersect.right, ctes)?;
197
198                Some(Step {
199                    name: "INTERSECT".to_string(),
200                    kind: StepKind::SetOperation(SetOperationType::Intersect),
201                    projections: Vec::new(),
202                    dependencies: vec![left, right],
203                    aggregations: Vec::new(),
204                    group_by: Vec::new(),
205                    condition: None,
206                    order_by: Vec::new(),
207                    limit: None,
208                })
209            }
210            Expression::Except(except) => {
211                let left = Self::from_expression(&except.left, ctes)?;
212                let right = Self::from_expression(&except.right, ctes)?;
213
214                Some(Step {
215                    name: "EXCEPT".to_string(),
216                    kind: StepKind::SetOperation(SetOperationType::Except),
217                    projections: Vec::new(),
218                    dependencies: vec![left, right],
219                    aggregations: Vec::new(),
220                    group_by: Vec::new(),
221                    condition: None,
222                    order_by: Vec::new(),
223                    limit: None,
224                })
225            }
226            _ => None,
227        }
228    }
229
230    fn from_select(
231        select: &crate::expressions::Select,
232        ctes: &HashMap<String, Step>,
233    ) -> Option<Self> {
234        // Process CTEs first
235        let mut ctes = ctes.clone();
236        if let Some(ref with) = select.with {
237            for cte in &with.ctes {
238                if let Some(step) = Self::from_expression(&cte.this, &ctes) {
239                    ctes.insert(cte.alias.name.clone(), step);
240                }
241            }
242        }
243
244        // Start with the FROM clause
245        let mut step = if let Some(ref from) = select.from {
246            if let Some(table_expr) = from.expressions.first() {
247                Self::from_table_expression(table_expr, &ctes)?
248            } else {
249                return None;
250            }
251        } else {
252            // SELECT without FROM (e.g., SELECT 1)
253            Step::new("", StepKind::Scan)
254        };
255
256        // Process JOINs
257        for join in &select.joins {
258            let right = Self::from_table_expression(&join.this, &ctes)?;
259
260            let join_type = match join.kind {
261                JoinKind::Inner => JoinType::Inner,
262                JoinKind::Left | JoinKind::NaturalLeft => JoinType::Left,
263                JoinKind::Right | JoinKind::NaturalRight => JoinType::Right,
264                JoinKind::Full | JoinKind::NaturalFull => JoinType::Full,
265                JoinKind::Cross | JoinKind::Natural => JoinType::Cross,
266                _ => JoinType::Inner,
267            };
268
269            let join_step = Step {
270                name: step.name.clone(),
271                kind: StepKind::Join(join_type),
272                projections: Vec::new(),
273                dependencies: vec![step, right],
274                aggregations: Vec::new(),
275                group_by: Vec::new(),
276                condition: join.on.clone(),
277                order_by: Vec::new(),
278                limit: None,
279            };
280            step = join_step;
281        }
282
283        // Check for aggregations
284        let has_aggregations = select.expressions.iter().any(|e| contains_aggregate(e));
285        let has_group_by = select.group_by.is_some();
286
287        if has_aggregations || has_group_by {
288            // Create aggregate step
289            let agg_step = Step {
290                name: step.name.clone(),
291                kind: StepKind::Aggregate,
292                projections: select.expressions.clone(),
293                dependencies: vec![step],
294                aggregations: extract_aggregations(&select.expressions),
295                group_by: select.group_by.as_ref()
296                    .map(|g| g.expressions.clone())
297                    .unwrap_or_default(),
298                condition: None,
299                order_by: Vec::new(),
300                limit: None,
301            };
302            step = agg_step;
303        } else {
304            step.projections = select.expressions.clone();
305        }
306
307        Some(step)
308    }
309
310    fn from_table_expression(
311        expr: &Expression,
312        ctes: &HashMap<String, Step>,
313    ) -> Option<Self> {
314        match expr {
315            Expression::Table(table) => {
316                // Check if this references a CTE
317                if let Some(cte_step) = ctes.get(&table.name.name) {
318                    return Some(cte_step.clone());
319                }
320
321                // Regular table scan
322                Some(Step::new(&table.name.name, StepKind::Scan))
323            }
324            Expression::Alias(alias) => {
325                let mut step = Self::from_table_expression(&alias.this, ctes)?;
326                step.name = alias.alias.name.clone();
327                Some(step)
328            }
329            Expression::Subquery(sq) => {
330                let step = Self::from_expression(&sq.this, ctes)?;
331                Some(step)
332            }
333            _ => None,
334        }
335    }
336
337    /// Add a dependency to this step
338    pub fn add_dependency(&mut self, dep: Step) {
339        self.dependencies.push(dep);
340    }
341}
342
343/// Check if an expression contains an aggregate function
344fn contains_aggregate(expr: &Expression) -> bool {
345    match expr {
346        // Specific aggregate function variants
347        Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
348        Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
349        Expression::StringAgg(_) | Expression::ListAgg(_) |
350        Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
351        Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
352        Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
353        Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
354        Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
355        Expression::AggregateFunction(_) => true,
356
357        Expression::Alias(alias) => contains_aggregate(&alias.this),
358        Expression::Add(op) | Expression::Sub(op) |
359        Expression::Mul(op) | Expression::Div(op) => {
360            contains_aggregate(&op.left) || contains_aggregate(&op.right)
361        }
362        Expression::Function(func) => {
363            // Check for aggregate function names (fallback)
364            let name = func.name.to_uppercase();
365            matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
366                     "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT")
367        }
368        _ => false,
369    }
370}
371
372/// Extract aggregate expressions from a list
373fn extract_aggregations(expressions: &[Expression]) -> Vec<Expression> {
374    let mut aggs = Vec::new();
375    for expr in expressions {
376        collect_aggregations(expr, &mut aggs);
377    }
378    aggs
379}
380
381fn collect_aggregations(expr: &Expression, aggs: &mut Vec<Expression>) {
382    match expr {
383        // Specific aggregate function variants
384        Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
385        Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
386        Expression::StringAgg(_) | Expression::ListAgg(_) |
387        Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
388        Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
389        Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
390        Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
391        Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
392        Expression::AggregateFunction(_) => {
393            aggs.push(expr.clone());
394        }
395        Expression::Alias(alias) => {
396            collect_aggregations(&alias.this, aggs);
397        }
398        Expression::Add(op) | Expression::Sub(op) |
399        Expression::Mul(op) | Expression::Div(op) => {
400            collect_aggregations(&op.left, aggs);
401            collect_aggregations(&op.right, aggs);
402        }
403        Expression::Function(func) => {
404            let name = func.name.to_uppercase();
405            if matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
406                        "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT") {
407                aggs.push(expr.clone());
408            } else {
409                for arg in &func.args {
410                    collect_aggregations(arg, aggs);
411                }
412            }
413        }
414        _ => {}
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let mut statements = dialect.parse(sql).unwrap().into_iter();
        statements.next().unwrap()
    }

    /// Parses and plans `sql`, failing the test if no plan is produced.
    fn plan_sql(sql: &str) -> Plan {
        let expr = parse(sql);
        let plan = Plan::from_expression(&expr);
        assert!(plan.is_some());
        plan.unwrap()
    }

    /// Reports whether the first projection of a parsed SELECT contains an aggregate.
    fn first_projection_is_aggregate(sql: &str) -> bool {
        match parse(sql) {
            Expression::Select(ref sel) => {
                assert!(!sel.expressions.is_empty());
                contains_aggregate(&sel.expressions[0])
            }
            _ => panic!("Expected SELECT expression"),
        }
    }

    #[test]
    fn test_simple_scan() {
        let root = plan_sql("SELECT a, b FROM t").root;
        assert_eq!(root.kind, StepKind::Scan);
        assert_eq!(root.name, "t");
    }

    #[test]
    fn test_join() {
        let root = plan_sql("SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.id").root;
        assert!(matches!(root.kind, StepKind::Join(_)));
        assert_eq!(root.dependencies.len(), 2);
    }

    #[test]
    fn test_aggregate() {
        let root = plan_sql("SELECT x, SUM(y) FROM t GROUP BY x").root;
        assert_eq!(root.kind, StepKind::Aggregate);
    }

    #[test]
    fn test_union() {
        let root = plan_sql("SELECT a FROM t1 UNION SELECT b FROM t2").root;
        assert!(matches!(root.kind, StepKind::SetOperation(SetOperationType::Union)));
    }

    #[test]
    fn test_contains_aggregate() {
        assert!(
            first_projection_is_aggregate("SELECT SUM(x) FROM t"),
            "Expected SUM to be detected as aggregate function"
        );
        assert!(
            !first_projection_is_aggregate("SELECT x + 1 FROM t"),
            "Expected x + 1 to not be an aggregate function"
        );
    }
}