Skip to main content

polyglot_sql/
planner.rs

1//! Query Execution Planner
2//!
3//! This module provides functionality to convert SQL AST into an execution plan
4//! represented as a DAG (Directed Acyclic Graph) of steps.
5//!
6
7use crate::expressions::{Expression, JoinKind};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
11/// A query execution plan
12#[derive(Debug)]
13pub struct Plan {
14    /// The root step of the plan DAG
15    pub root: Step,
16    /// Cached DAG representation
17    dag: Option<HashMap<usize, HashSet<usize>>>,
18}
19
20impl Plan {
21    /// Create a new plan from an expression
22    pub fn from_expression(expression: &Expression) -> Option<Self> {
23        let root = Step::from_expression(expression, &HashMap::new())?;
24        Some(Self { root, dag: None })
25    }
26
27    /// Get the DAG representation of the plan
28    pub fn dag(&mut self) -> &HashMap<usize, HashSet<usize>> {
29        if self.dag.is_none() {
30            let mut dag = HashMap::new();
31            self.build_dag(&self.root, &mut dag, 0);
32            self.dag = Some(dag);
33        }
34        self.dag.as_ref().unwrap()
35    }
36
37    fn build_dag(&self, step: &Step, dag: &mut HashMap<usize, HashSet<usize>>, id: usize) {
38        let deps: HashSet<usize> = step.dependencies
39            .iter()
40            .enumerate()
41            .map(|(i, _)| id + i + 1)
42            .collect();
43        dag.insert(id, deps);
44
45        for (i, dep) in step.dependencies.iter().enumerate() {
46            self.build_dag(dep, dag, id + i + 1);
47        }
48    }
49
50    /// Get all leaf steps (steps with no dependencies)
51    pub fn leaves(&self) -> Vec<&Step> {
52        let mut leaves = Vec::new();
53        self.collect_leaves(&self.root, &mut leaves);
54        leaves
55    }
56
57    fn collect_leaves<'a>(&'a self, step: &'a Step, leaves: &mut Vec<&'a Step>) {
58        if step.dependencies.is_empty() {
59            leaves.push(step);
60        } else {
61            for dep in &step.dependencies {
62                self.collect_leaves(dep, leaves);
63            }
64        }
65    }
66}
67
68/// A step in the execution plan
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Step {
71    /// Name of this step
72    pub name: String,
73    /// Type of step
74    pub kind: StepKind,
75    /// Projections to output
76    pub projections: Vec<Expression>,
77    /// Dependencies (other steps that must complete first)
78    pub dependencies: Vec<Step>,
79    /// Aggregation expressions (for Aggregate steps)
80    pub aggregations: Vec<Expression>,
81    /// Group by expressions (for Aggregate steps)
82    pub group_by: Vec<Expression>,
83    /// Join condition (for Join steps)
84    pub condition: Option<Expression>,
85    /// Sort expressions (for Sort steps)
86    pub order_by: Vec<Expression>,
87    /// Limit value (for Scan/other steps)
88    pub limit: Option<Expression>,
89}
90
91/// Types of execution steps
92#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93#[serde(rename_all = "snake_case")]
94pub enum StepKind {
95    /// Scan a table
96    Scan,
97    /// Join multiple inputs
98    Join(JoinType),
99    /// Aggregate rows
100    Aggregate,
101    /// Sort rows
102    Sort,
103    /// Set operation (UNION, INTERSECT, EXCEPT)
104    SetOperation(SetOperationType),
105}
106
107/// Types of joins in execution plans
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109#[serde(rename_all = "snake_case")]
110pub enum JoinType {
111    Inner,
112    Left,
113    Right,
114    Full,
115    Cross,
116}
117
118/// Types of set operations
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120#[serde(rename_all = "snake_case")]
121pub enum SetOperationType {
122    Union,
123    UnionAll,
124    Intersect,
125    Except,
126}
127
128impl Step {
129    /// Create a new step
130    pub fn new(name: impl Into<String>, kind: StepKind) -> Self {
131        Self {
132            name: name.into(),
133            kind,
134            projections: Vec::new(),
135            dependencies: Vec::new(),
136            aggregations: Vec::new(),
137            group_by: Vec::new(),
138            condition: None,
139            order_by: Vec::new(),
140            limit: None,
141        }
142    }
143
144    /// Build a step from an expression
145    pub fn from_expression(
146        expression: &Expression,
147        ctes: &HashMap<String, Step>,
148    ) -> Option<Self> {
149        match expression {
150            Expression::Select(select) => {
151                let mut step = Self::from_select(select, ctes)?;
152
153                // Handle ORDER BY
154                if let Some(ref order_by) = select.order_by {
155                    let sort_step = Step {
156                        name: step.name.clone(),
157                        kind: StepKind::Sort,
158                        projections: Vec::new(),
159                        dependencies: vec![step],
160                        aggregations: Vec::new(),
161                        group_by: Vec::new(),
162                        condition: None,
163                        order_by: order_by.expressions.iter().map(|o| o.this.clone()).collect(),
164                        limit: None,
165                    };
166                    step = sort_step;
167                }
168
169                // Handle LIMIT
170                if let Some(ref limit) = select.limit {
171                    step.limit = Some(limit.this.clone());
172                }
173
174                Some(step)
175            }
176            Expression::Union(union) => {
177                let left = Self::from_expression(&union.left, ctes)?;
178                let right = Self::from_expression(&union.right, ctes)?;
179
180                let op_type = if union.all {
181                    SetOperationType::UnionAll
182                } else {
183                    SetOperationType::Union
184                };
185
186                Some(Step {
187                    name: "UNION".to_string(),
188                    kind: StepKind::SetOperation(op_type),
189                    projections: Vec::new(),
190                    dependencies: vec![left, right],
191                    aggregations: Vec::new(),
192                    group_by: Vec::new(),
193                    condition: None,
194                    order_by: Vec::new(),
195                    limit: None,
196                })
197            }
198            Expression::Intersect(intersect) => {
199                let left = Self::from_expression(&intersect.left, ctes)?;
200                let right = Self::from_expression(&intersect.right, ctes)?;
201
202                Some(Step {
203                    name: "INTERSECT".to_string(),
204                    kind: StepKind::SetOperation(SetOperationType::Intersect),
205                    projections: Vec::new(),
206                    dependencies: vec![left, right],
207                    aggregations: Vec::new(),
208                    group_by: Vec::new(),
209                    condition: None,
210                    order_by: Vec::new(),
211                    limit: None,
212                })
213            }
214            Expression::Except(except) => {
215                let left = Self::from_expression(&except.left, ctes)?;
216                let right = Self::from_expression(&except.right, ctes)?;
217
218                Some(Step {
219                    name: "EXCEPT".to_string(),
220                    kind: StepKind::SetOperation(SetOperationType::Except),
221                    projections: Vec::new(),
222                    dependencies: vec![left, right],
223                    aggregations: Vec::new(),
224                    group_by: Vec::new(),
225                    condition: None,
226                    order_by: Vec::new(),
227                    limit: None,
228                })
229            }
230            _ => None,
231        }
232    }
233
234    fn from_select(
235        select: &crate::expressions::Select,
236        ctes: &HashMap<String, Step>,
237    ) -> Option<Self> {
238        // Process CTEs first
239        let mut ctes = ctes.clone();
240        if let Some(ref with) = select.with {
241            for cte in &with.ctes {
242                if let Some(step) = Self::from_expression(&cte.this, &ctes) {
243                    ctes.insert(cte.alias.name.clone(), step);
244                }
245            }
246        }
247
248        // Start with the FROM clause
249        let mut step = if let Some(ref from) = select.from {
250            if let Some(table_expr) = from.expressions.first() {
251                Self::from_table_expression(table_expr, &ctes)?
252            } else {
253                return None;
254            }
255        } else {
256            // SELECT without FROM (e.g., SELECT 1)
257            Step::new("", StepKind::Scan)
258        };
259
260        // Process JOINs
261        for join in &select.joins {
262            let right = Self::from_table_expression(&join.this, &ctes)?;
263
264            let join_type = match join.kind {
265                JoinKind::Inner => JoinType::Inner,
266                JoinKind::Left | JoinKind::NaturalLeft => JoinType::Left,
267                JoinKind::Right | JoinKind::NaturalRight => JoinType::Right,
268                JoinKind::Full | JoinKind::NaturalFull => JoinType::Full,
269                JoinKind::Cross | JoinKind::Natural => JoinType::Cross,
270                _ => JoinType::Inner,
271            };
272
273            let join_step = Step {
274                name: step.name.clone(),
275                kind: StepKind::Join(join_type),
276                projections: Vec::new(),
277                dependencies: vec![step, right],
278                aggregations: Vec::new(),
279                group_by: Vec::new(),
280                condition: join.on.clone(),
281                order_by: Vec::new(),
282                limit: None,
283            };
284            step = join_step;
285        }
286
287        // Check for aggregations
288        let has_aggregations = select.expressions.iter().any(|e| contains_aggregate(e));
289        let has_group_by = select.group_by.is_some();
290
291        if has_aggregations || has_group_by {
292            // Create aggregate step
293            let agg_step = Step {
294                name: step.name.clone(),
295                kind: StepKind::Aggregate,
296                projections: select.expressions.clone(),
297                dependencies: vec![step],
298                aggregations: extract_aggregations(&select.expressions),
299                group_by: select.group_by.as_ref()
300                    .map(|g| g.expressions.clone())
301                    .unwrap_or_default(),
302                condition: None,
303                order_by: Vec::new(),
304                limit: None,
305            };
306            step = agg_step;
307        } else {
308            step.projections = select.expressions.clone();
309        }
310
311        Some(step)
312    }
313
314    fn from_table_expression(
315        expr: &Expression,
316        ctes: &HashMap<String, Step>,
317    ) -> Option<Self> {
318        match expr {
319            Expression::Table(table) => {
320                // Check if this references a CTE
321                if let Some(cte_step) = ctes.get(&table.name.name) {
322                    return Some(cte_step.clone());
323                }
324
325                // Regular table scan
326                Some(Step::new(&table.name.name, StepKind::Scan))
327            }
328            Expression::Alias(alias) => {
329                let mut step = Self::from_table_expression(&alias.this, ctes)?;
330                step.name = alias.alias.name.clone();
331                Some(step)
332            }
333            Expression::Subquery(sq) => {
334                let step = Self::from_expression(&sq.this, ctes)?;
335                Some(step)
336            }
337            _ => None,
338        }
339    }
340
341    /// Add a dependency to this step
342    pub fn add_dependency(&mut self, dep: Step) {
343        self.dependencies.push(dep);
344    }
345}
346
347/// Check if an expression contains an aggregate function
348fn contains_aggregate(expr: &Expression) -> bool {
349    match expr {
350        // Specific aggregate function variants
351        Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
352        Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
353        Expression::StringAgg(_) | Expression::ListAgg(_) |
354        Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
355        Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
356        Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
357        Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
358        Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
359        Expression::AggregateFunction(_) => true,
360
361        Expression::Alias(alias) => contains_aggregate(&alias.this),
362        Expression::Add(op) | Expression::Sub(op) |
363        Expression::Mul(op) | Expression::Div(op) => {
364            contains_aggregate(&op.left) || contains_aggregate(&op.right)
365        }
366        Expression::Function(func) => {
367            // Check for aggregate function names (fallback)
368            let name = func.name.to_uppercase();
369            matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
370                     "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT")
371        }
372        _ => false,
373    }
374}
375
376/// Extract aggregate expressions from a list
377fn extract_aggregations(expressions: &[Expression]) -> Vec<Expression> {
378    let mut aggs = Vec::new();
379    for expr in expressions {
380        collect_aggregations(expr, &mut aggs);
381    }
382    aggs
383}
384
385fn collect_aggregations(expr: &Expression, aggs: &mut Vec<Expression>) {
386    match expr {
387        // Specific aggregate function variants
388        Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
389        Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
390        Expression::StringAgg(_) | Expression::ListAgg(_) |
391        Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
392        Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
393        Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
394        Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
395        Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
396        Expression::AggregateFunction(_) => {
397            aggs.push(expr.clone());
398        }
399        Expression::Alias(alias) => {
400            collect_aggregations(&alias.this, aggs);
401        }
402        Expression::Add(op) | Expression::Sub(op) |
403        Expression::Mul(op) | Expression::Div(op) => {
404            collect_aggregations(&op.left, aggs);
405            collect_aggregations(&op.right, aggs);
406        }
407        Expression::Function(func) => {
408            let name = func.name.to_uppercase();
409            if matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
410                        "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT") {
411                aggs.push(expr.clone());
412            } else {
413                for arg in &func.args {
414                    collect_aggregations(arg, aggs);
415                }
416            }
417        }
418        _ => {}
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let ast = dialect.parse(sql).unwrap();
        ast.into_iter().next().unwrap()
    }

    /// Parses and plans `sql`, failing the test if planning is not possible.
    fn plan_of(sql: &str) -> Plan {
        let planned = Plan::from_expression(&parse(sql));
        assert!(planned.is_some());
        planned.unwrap()
    }

    /// Returns the first projection of a `SELECT` statement.
    fn first_projection(sql: &str) -> Expression {
        match parse(sql) {
            Expression::Select(select) => {
                assert!(!select.expressions.is_empty());
                select.expressions[0].clone()
            }
            _ => panic!("Expected SELECT expression"),
        }
    }

    #[test]
    fn test_simple_scan() {
        let plan = plan_of("SELECT a, b FROM t");
        assert_eq!(plan.root.kind, StepKind::Scan);
        assert_eq!(plan.root.name, "t");
    }

    #[test]
    fn test_join() {
        let plan = plan_of("SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.id");
        assert!(matches!(plan.root.kind, StepKind::Join(_)));
        assert_eq!(plan.root.dependencies.len(), 2);
    }

    #[test]
    fn test_aggregate() {
        let plan = plan_of("SELECT x, SUM(y) FROM t GROUP BY x");
        assert_eq!(plan.root.kind, StepKind::Aggregate);
    }

    #[test]
    fn test_union() {
        let plan = plan_of("SELECT a FROM t1 UNION SELECT b FROM t2");
        assert!(matches!(plan.root.kind, StepKind::SetOperation(SetOperationType::Union)));
    }

    #[test]
    fn test_contains_aggregate() {
        // An aggregate call must be detected...
        let aggregated = first_projection("SELECT SUM(x) FROM t");
        assert!(contains_aggregate(&aggregated),
            "Expected SUM to be detected as aggregate function");

        // ...and plain arithmetic must not be.
        let scalar = first_projection("SELECT x + 1 FROM t");
        assert!(!contains_aggregate(&scalar),
            "Expected x + 1 to not be an aggregate function");
    }
}