1use crate::expressions::{Expression, JoinKind};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// A query execution plan: a tree of [`Step`]s whose final step is `root`.
#[derive(Debug)]
pub struct Plan {
    /// The outermost step; its `dependencies` (recursively) form the rest of the plan.
    pub root: Step,
    /// Lazily computed dependency graph (step id -> ids of that step's direct
    /// dependencies), populated on the first call to `Plan::dag` and cached afterwards.
    /// NOTE(review): `root` is public, so mutating it after `dag()` was called leaves
    /// this cache stale.
    dag: Option<HashMap<usize, HashSet<usize>>>,
}
19
20impl Plan {
21 pub fn from_expression(expression: &Expression) -> Option<Self> {
23 let root = Step::from_expression(expression, &HashMap::new())?;
24 Some(Self { root, dag: None })
25 }
26
27 pub fn dag(&mut self) -> &HashMap<usize, HashSet<usize>> {
29 if self.dag.is_none() {
30 let mut dag = HashMap::new();
31 self.build_dag(&self.root, &mut dag, 0);
32 self.dag = Some(dag);
33 }
34 self.dag.as_ref().unwrap()
35 }
36
37 fn build_dag(&self, step: &Step, dag: &mut HashMap<usize, HashSet<usize>>, id: usize) {
38 let deps: HashSet<usize> = step
39 .dependencies
40 .iter()
41 .enumerate()
42 .map(|(i, _)| id + i + 1)
43 .collect();
44 dag.insert(id, deps);
45
46 for (i, dep) in step.dependencies.iter().enumerate() {
47 self.build_dag(dep, dag, id + i + 1);
48 }
49 }
50
51 pub fn leaves(&self) -> Vec<&Step> {
53 let mut leaves = Vec::new();
54 self.collect_leaves(&self.root, &mut leaves);
55 leaves
56 }
57
58 fn collect_leaves<'a>(&'a self, step: &'a Step, leaves: &mut Vec<&'a Step>) {
59 if step.dependencies.is_empty() {
60 leaves.push(step);
61 } else {
62 for dep in &step.dependencies {
63 self.collect_leaves(dep, leaves);
64 }
65 }
66 }
67}
68
/// One node of a query plan. A step consumes the output of its `dependencies`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Step {
    /// Step name. For scans it is the table name, or the alias when one is given.
    /// Join, aggregate and sort steps inherit the name of their (left) input.
    /// Set operations are named "UNION", "INTERSECT" or "EXCEPT".
    pub name: String,
    /// The operation this step performs.
    pub kind: StepKind,
    /// Output expressions (the SELECT list). They are attached to the aggregate step
    /// when one exists, otherwise to the top step of the FROM/JOIN tree.
    pub projections: Vec<Expression>,
    /// Input steps this step consumes; empty for leaf scans.
    pub dependencies: Vec<Step>,
    /// Aggregate calls extracted from the projections (aggregate steps).
    pub aggregations: Vec<Expression>,
    /// GROUP BY expressions (aggregate steps).
    pub group_by: Vec<Expression>,
    /// Row condition. The planner in this file only sets it from a join's ON clause.
    pub condition: Option<Expression>,
    /// Sort key expressions (sort steps). Only each ORDER BY item's `this` is kept;
    /// presumably direction/null ordering is dropped — TODO confirm.
    pub order_by: Vec<Expression>,
    /// LIMIT expression, attached to the outermost step planned for a SELECT.
    pub limit: Option<Expression>,
}
91
/// The operation performed by a [`Step`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepKind {
    /// Produce rows from a table. Also used, with an empty name, for a SELECT without FROM.
    Scan,
    /// Combine the step's two dependencies (left, right) with the given join type.
    Join(JoinType),
    /// Group and aggregate the rows of the step's single dependency.
    Aggregate,
    /// Order the rows of the step's single dependency by `Step::order_by`.
    Sort,
    /// Combine the results of the step's two dependencies with a set operator.
    SetOperation(SetOperationType),
}
107
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
110#[serde(rename_all = "snake_case")]
111pub enum JoinType {
112 Inner,
113 Left,
114 Right,
115 Full,
116 Cross,
117}
118
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
121#[serde(rename_all = "snake_case")]
122pub enum SetOperationType {
123 Union,
124 UnionAll,
125 Intersect,
126 Except,
127}
128
129impl Step {
130 pub fn new(name: impl Into<String>, kind: StepKind) -> Self {
132 Self {
133 name: name.into(),
134 kind,
135 projections: Vec::new(),
136 dependencies: Vec::new(),
137 aggregations: Vec::new(),
138 group_by: Vec::new(),
139 condition: None,
140 order_by: Vec::new(),
141 limit: None,
142 }
143 }
144
145 pub fn from_expression(expression: &Expression, ctes: &HashMap<String, Step>) -> Option<Self> {
147 match expression {
148 Expression::Select(select) => {
149 let mut step = Self::from_select(select, ctes)?;
150
151 if let Some(ref order_by) = select.order_by {
153 let sort_step = Step {
154 name: step.name.clone(),
155 kind: StepKind::Sort,
156 projections: Vec::new(),
157 dependencies: vec![step],
158 aggregations: Vec::new(),
159 group_by: Vec::new(),
160 condition: None,
161 order_by: order_by
162 .expressions
163 .iter()
164 .map(|o| o.this.clone())
165 .collect(),
166 limit: None,
167 };
168 step = sort_step;
169 }
170
171 if let Some(ref limit) = select.limit {
173 step.limit = Some(limit.this.clone());
174 }
175
176 Some(step)
177 }
178 Expression::Union(union) => {
179 let left = Self::from_expression(&union.left, ctes)?;
180 let right = Self::from_expression(&union.right, ctes)?;
181
182 let op_type = if union.all {
183 SetOperationType::UnionAll
184 } else {
185 SetOperationType::Union
186 };
187
188 Some(Step {
189 name: "UNION".to_string(),
190 kind: StepKind::SetOperation(op_type),
191 projections: Vec::new(),
192 dependencies: vec![left, right],
193 aggregations: Vec::new(),
194 group_by: Vec::new(),
195 condition: None,
196 order_by: Vec::new(),
197 limit: None,
198 })
199 }
200 Expression::Intersect(intersect) => {
201 let left = Self::from_expression(&intersect.left, ctes)?;
202 let right = Self::from_expression(&intersect.right, ctes)?;
203
204 Some(Step {
205 name: "INTERSECT".to_string(),
206 kind: StepKind::SetOperation(SetOperationType::Intersect),
207 projections: Vec::new(),
208 dependencies: vec![left, right],
209 aggregations: Vec::new(),
210 group_by: Vec::new(),
211 condition: None,
212 order_by: Vec::new(),
213 limit: None,
214 })
215 }
216 Expression::Except(except) => {
217 let left = Self::from_expression(&except.left, ctes)?;
218 let right = Self::from_expression(&except.right, ctes)?;
219
220 Some(Step {
221 name: "EXCEPT".to_string(),
222 kind: StepKind::SetOperation(SetOperationType::Except),
223 projections: Vec::new(),
224 dependencies: vec![left, right],
225 aggregations: Vec::new(),
226 group_by: Vec::new(),
227 condition: None,
228 order_by: Vec::new(),
229 limit: None,
230 })
231 }
232 _ => None,
233 }
234 }
235
236 fn from_select(
237 select: &crate::expressions::Select,
238 ctes: &HashMap<String, Step>,
239 ) -> Option<Self> {
240 let mut ctes = ctes.clone();
242 if let Some(ref with) = select.with {
243 for cte in &with.ctes {
244 if let Some(step) = Self::from_expression(&cte.this, &ctes) {
245 ctes.insert(cte.alias.name.clone(), step);
246 }
247 }
248 }
249
250 let mut step = if let Some(ref from) = select.from {
252 if let Some(table_expr) = from.expressions.first() {
253 Self::from_table_expression(table_expr, &ctes)?
254 } else {
255 return None;
256 }
257 } else {
258 Step::new("", StepKind::Scan)
260 };
261
262 for join in &select.joins {
264 let right = Self::from_table_expression(&join.this, &ctes)?;
265
266 let join_type = match join.kind {
267 JoinKind::Inner => JoinType::Inner,
268 JoinKind::Left | JoinKind::NaturalLeft => JoinType::Left,
269 JoinKind::Right | JoinKind::NaturalRight => JoinType::Right,
270 JoinKind::Full | JoinKind::NaturalFull => JoinType::Full,
271 JoinKind::Cross | JoinKind::Natural => JoinType::Cross,
272 _ => JoinType::Inner,
273 };
274
275 let join_step = Step {
276 name: step.name.clone(),
277 kind: StepKind::Join(join_type),
278 projections: Vec::new(),
279 dependencies: vec![step, right],
280 aggregations: Vec::new(),
281 group_by: Vec::new(),
282 condition: join.on.clone(),
283 order_by: Vec::new(),
284 limit: None,
285 };
286 step = join_step;
287 }
288
289 let has_aggregations = select.expressions.iter().any(|e| contains_aggregate(e));
291 let has_group_by = select.group_by.is_some();
292
293 if has_aggregations || has_group_by {
294 let agg_step = Step {
296 name: step.name.clone(),
297 kind: StepKind::Aggregate,
298 projections: select.expressions.clone(),
299 dependencies: vec![step],
300 aggregations: extract_aggregations(&select.expressions),
301 group_by: select
302 .group_by
303 .as_ref()
304 .map(|g| g.expressions.clone())
305 .unwrap_or_default(),
306 condition: None,
307 order_by: Vec::new(),
308 limit: None,
309 };
310 step = agg_step;
311 } else {
312 step.projections = select.expressions.clone();
313 }
314
315 Some(step)
316 }
317
318 fn from_table_expression(expr: &Expression, ctes: &HashMap<String, Step>) -> Option<Self> {
319 match expr {
320 Expression::Table(table) => {
321 if let Some(cte_step) = ctes.get(&table.name.name) {
323 return Some(cte_step.clone());
324 }
325
326 Some(Step::new(&table.name.name, StepKind::Scan))
328 }
329 Expression::Alias(alias) => {
330 let mut step = Self::from_table_expression(&alias.this, ctes)?;
331 step.name = alias.alias.name.clone();
332 Some(step)
333 }
334 Expression::Subquery(sq) => {
335 let step = Self::from_expression(&sq.this, ctes)?;
336 Some(step)
337 }
338 _ => None,
339 }
340 }
341
342 pub fn add_dependency(&mut self, dep: Step) {
344 self.dependencies.push(dep);
345 }
346}
347
348fn contains_aggregate(expr: &Expression) -> bool {
350 match expr {
351 Expression::Sum(_)
353 | Expression::Count(_)
354 | Expression::Avg(_)
355 | Expression::Min(_)
356 | Expression::Max(_)
357 | Expression::ArrayAgg(_)
358 | Expression::StringAgg(_)
359 | Expression::ListAgg(_)
360 | Expression::Stddev(_)
361 | Expression::StddevPop(_)
362 | Expression::StddevSamp(_)
363 | Expression::Variance(_)
364 | Expression::VarPop(_)
365 | Expression::VarSamp(_)
366 | Expression::Median(_)
367 | Expression::Mode(_)
368 | Expression::First(_)
369 | Expression::Last(_)
370 | Expression::AnyValue(_)
371 | Expression::ApproxDistinct(_)
372 | Expression::ApproxCountDistinct(_)
373 | Expression::LogicalAnd(_)
374 | Expression::LogicalOr(_)
375 | Expression::AggregateFunction(_) => true,
376
377 Expression::Alias(alias) => contains_aggregate(&alias.this),
378 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) | Expression::Div(op) => {
379 contains_aggregate(&op.left) || contains_aggregate(&op.right)
380 }
381 Expression::Function(func) => {
382 let name = func.name.to_uppercase();
384 matches!(
385 name.as_str(),
386 "SUM"
387 | "COUNT"
388 | "AVG"
389 | "MIN"
390 | "MAX"
391 | "ARRAY_AGG"
392 | "STRING_AGG"
393 | "GROUP_CONCAT"
394 )
395 }
396 _ => false,
397 }
398}
399
400fn extract_aggregations(expressions: &[Expression]) -> Vec<Expression> {
402 let mut aggs = Vec::new();
403 for expr in expressions {
404 collect_aggregations(expr, &mut aggs);
405 }
406 aggs
407}
408
409fn collect_aggregations(expr: &Expression, aggs: &mut Vec<Expression>) {
410 match expr {
411 Expression::Sum(_)
413 | Expression::Count(_)
414 | Expression::Avg(_)
415 | Expression::Min(_)
416 | Expression::Max(_)
417 | Expression::ArrayAgg(_)
418 | Expression::StringAgg(_)
419 | Expression::ListAgg(_)
420 | Expression::Stddev(_)
421 | Expression::StddevPop(_)
422 | Expression::StddevSamp(_)
423 | Expression::Variance(_)
424 | Expression::VarPop(_)
425 | Expression::VarSamp(_)
426 | Expression::Median(_)
427 | Expression::Mode(_)
428 | Expression::First(_)
429 | Expression::Last(_)
430 | Expression::AnyValue(_)
431 | Expression::ApproxDistinct(_)
432 | Expression::ApproxCountDistinct(_)
433 | Expression::LogicalAnd(_)
434 | Expression::LogicalOr(_)
435 | Expression::AggregateFunction(_) => {
436 aggs.push(expr.clone());
437 }
438 Expression::Alias(alias) => {
439 collect_aggregations(&alias.this, aggs);
440 }
441 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) | Expression::Div(op) => {
442 collect_aggregations(&op.left, aggs);
443 collect_aggregations(&op.right, aggs);
444 }
445 Expression::Function(func) => {
446 let name = func.name.to_uppercase();
447 if matches!(
448 name.as_str(),
449 "SUM"
450 | "COUNT"
451 | "AVG"
452 | "MIN"
453 | "MAX"
454 | "ARRAY_AGG"
455 | "STRING_AGG"
456 | "GROUP_CONCAT"
457 ) {
458 aggs.push(expr.clone());
459 } else {
460 for arg in &func.args {
461 collect_aggregations(arg, aggs);
462 }
463 }
464 }
465 _ => {}
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Dialect::get(DialectType::Generic)
            .parse(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Parses and plans `sql`, asserting that a plan is produced.
    fn plan_for(sql: &str) -> Plan {
        let plan = Plan::from_expression(&parse(sql));
        assert!(plan.is_some());
        plan.unwrap()
    }

    /// Returns the first projection of the SELECT statement parsed from `sql`.
    fn first_projection(sql: &str) -> Expression {
        match parse(sql) {
            Expression::Select(sel) => {
                assert!(!sel.expressions.is_empty());
                sel.expressions[0].clone()
            }
            _ => panic!("Expected SELECT expression"),
        }
    }

    #[test]
    fn test_simple_scan() {
        let plan = plan_for("SELECT a, b FROM t");
        assert_eq!(plan.root.kind, StepKind::Scan);
        assert_eq!(plan.root.name, "t");
    }

    #[test]
    fn test_join() {
        let plan = plan_for("SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.id");
        assert!(matches!(plan.root.kind, StepKind::Join(_)));
        assert_eq!(plan.root.dependencies.len(), 2);
    }

    #[test]
    fn test_aggregate() {
        let plan = plan_for("SELECT x, SUM(y) FROM t GROUP BY x");
        assert_eq!(plan.root.kind, StepKind::Aggregate);
    }

    #[test]
    fn test_union() {
        let plan = plan_for("SELECT a FROM t1 UNION SELECT b FROM t2");
        assert!(matches!(
            plan.root.kind,
            StepKind::SetOperation(SetOperationType::Union)
        ));
    }

    #[test]
    fn test_contains_aggregate() {
        assert!(
            contains_aggregate(&first_projection("SELECT SUM(x) FROM t")),
            "Expected SUM to be detected as aggregate function"
        );
        assert!(
            !contains_aggregate(&first_projection("SELECT x + 1 FROM t")),
            "Expected x + 1 to not be an aggregate function"
        );
    }
}