1use crate::expressions::{Expression, JoinKind};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// An execution plan for a single query, held as a tree of [`Step`]s.
///
/// Build one with [`Plan::from_expression`]. The dependency map returned by
/// [`Plan::dag`] is computed on first request and cached afterwards.
#[derive(Debug)]
pub struct Plan {
    /// The step the rest of the plan hangs off; its `dependencies` are the
    /// steps that feed it.
    pub root: Step,
    /// Cache for [`Plan::dag`]: maps a step id to the ids of that step's
    /// direct dependencies. `None` until `dag()` has been called once.
    dag: Option<HashMap<usize, HashSet<usize>>>,
}
19
20impl Plan {
21 pub fn from_expression(expression: &Expression) -> Option<Self> {
23 let root = Step::from_expression(expression, &HashMap::new())?;
24 Some(Self { root, dag: None })
25 }
26
27 pub fn dag(&mut self) -> &HashMap<usize, HashSet<usize>> {
29 if self.dag.is_none() {
30 let mut dag = HashMap::new();
31 self.build_dag(&self.root, &mut dag, 0);
32 self.dag = Some(dag);
33 }
34 self.dag.as_ref().unwrap()
35 }
36
37 fn build_dag(&self, step: &Step, dag: &mut HashMap<usize, HashSet<usize>>, id: usize) {
38 let deps: HashSet<usize> = step.dependencies
39 .iter()
40 .enumerate()
41 .map(|(i, _)| id + i + 1)
42 .collect();
43 dag.insert(id, deps);
44
45 for (i, dep) in step.dependencies.iter().enumerate() {
46 self.build_dag(dep, dag, id + i + 1);
47 }
48 }
49
50 pub fn leaves(&self) -> Vec<&Step> {
52 let mut leaves = Vec::new();
53 self.collect_leaves(&self.root, &mut leaves);
54 leaves
55 }
56
57 fn collect_leaves<'a>(&'a self, step: &'a Step, leaves: &mut Vec<&'a Step>) {
58 if step.dependencies.is_empty() {
59 leaves.push(step);
60 } else {
61 for dep in &step.dependencies {
62 self.collect_leaves(dep, leaves);
63 }
64 }
65 }
66}
67
/// One node of a query plan.
///
/// A step owns the steps it reads from (`dependencies`), so a plan is a tree
/// of `Step` values. Which of the remaining fields are filled in depends on
/// `kind`; unused ones are left empty / `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Step {
    /// Label for the step: the scanned table's name, the alias given to it,
    /// or `UNION` / `INTERSECT` / `EXCEPT` for set operations. Join, aggregate
    /// and sort steps reuse the name of the step they wrap. Empty for a
    /// `SELECT` without a `FROM` clause.
    pub name: String,
    /// What this step does.
    pub kind: StepKind,
    /// The select-list expressions of the query this step was built from.
    pub projections: Vec<Expression>,
    /// The steps whose output this step consumes.
    pub dependencies: Vec<Step>,
    /// Aggregate expressions found in the select list (aggregate steps).
    pub aggregations: Vec<Expression>,
    /// `GROUP BY` expressions (aggregate steps).
    pub group_by: Vec<Expression>,
    /// The join's `ON` expression, if any (join steps).
    pub condition: Option<Expression>,
    /// Sort keys (sort steps).
    pub order_by: Vec<Expression>,
    /// The `LIMIT` expression, attached to the outermost step of a `SELECT`.
    pub limit: Option<Expression>,
}
90
/// The operation a [`Step`] performs.
///
/// Serialized with snake_case variant names.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepKind {
    /// Reads a table; also used for a `SELECT` that has no `FROM` clause.
    Scan,
    /// Joins its two dependencies (left, then right).
    Join(JoinType),
    /// Groups and/or aggregates the output of its single dependency.
    Aggregate,
    /// Orders the output of its single dependency.
    Sort,
    /// Combines its two dependencies with a set operation.
    SetOperation(SetOperationType),
}
106
/// The flavour of join carried by [`StepKind::Join`].
///
/// Serialized with snake_case variant names.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum JoinType {
    /// Inner join; also the fallback for join kinds the planner has no
    /// dedicated mapping for.
    Inner,
    /// Left join (including `NATURAL LEFT`).
    Left,
    /// Right join (including `NATURAL RIGHT`).
    Right,
    /// Full join (including `NATURAL FULL`).
    Full,
    /// Cross join.
    Cross,
}
117
/// The set operation carried by [`StepKind::SetOperation`].
///
/// Serialized with snake_case variant names.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SetOperationType {
    /// `UNION` without `ALL`.
    Union,
    /// `UNION ALL`.
    UnionAll,
    /// `INTERSECT`.
    Intersect,
    /// `EXCEPT`.
    Except,
}
127
128impl Step {
129 pub fn new(name: impl Into<String>, kind: StepKind) -> Self {
131 Self {
132 name: name.into(),
133 kind,
134 projections: Vec::new(),
135 dependencies: Vec::new(),
136 aggregations: Vec::new(),
137 group_by: Vec::new(),
138 condition: None,
139 order_by: Vec::new(),
140 limit: None,
141 }
142 }
143
144 pub fn from_expression(
146 expression: &Expression,
147 ctes: &HashMap<String, Step>,
148 ) -> Option<Self> {
149 match expression {
150 Expression::Select(select) => {
151 let mut step = Self::from_select(select, ctes)?;
152
153 if let Some(ref order_by) = select.order_by {
155 let sort_step = Step {
156 name: step.name.clone(),
157 kind: StepKind::Sort,
158 projections: Vec::new(),
159 dependencies: vec![step],
160 aggregations: Vec::new(),
161 group_by: Vec::new(),
162 condition: None,
163 order_by: order_by.expressions.iter().map(|o| o.this.clone()).collect(),
164 limit: None,
165 };
166 step = sort_step;
167 }
168
169 if let Some(ref limit) = select.limit {
171 step.limit = Some(limit.this.clone());
172 }
173
174 Some(step)
175 }
176 Expression::Union(union) => {
177 let left = Self::from_expression(&union.left, ctes)?;
178 let right = Self::from_expression(&union.right, ctes)?;
179
180 let op_type = if union.all {
181 SetOperationType::UnionAll
182 } else {
183 SetOperationType::Union
184 };
185
186 Some(Step {
187 name: "UNION".to_string(),
188 kind: StepKind::SetOperation(op_type),
189 projections: Vec::new(),
190 dependencies: vec![left, right],
191 aggregations: Vec::new(),
192 group_by: Vec::new(),
193 condition: None,
194 order_by: Vec::new(),
195 limit: None,
196 })
197 }
198 Expression::Intersect(intersect) => {
199 let left = Self::from_expression(&intersect.left, ctes)?;
200 let right = Self::from_expression(&intersect.right, ctes)?;
201
202 Some(Step {
203 name: "INTERSECT".to_string(),
204 kind: StepKind::SetOperation(SetOperationType::Intersect),
205 projections: Vec::new(),
206 dependencies: vec![left, right],
207 aggregations: Vec::new(),
208 group_by: Vec::new(),
209 condition: None,
210 order_by: Vec::new(),
211 limit: None,
212 })
213 }
214 Expression::Except(except) => {
215 let left = Self::from_expression(&except.left, ctes)?;
216 let right = Self::from_expression(&except.right, ctes)?;
217
218 Some(Step {
219 name: "EXCEPT".to_string(),
220 kind: StepKind::SetOperation(SetOperationType::Except),
221 projections: Vec::new(),
222 dependencies: vec![left, right],
223 aggregations: Vec::new(),
224 group_by: Vec::new(),
225 condition: None,
226 order_by: Vec::new(),
227 limit: None,
228 })
229 }
230 _ => None,
231 }
232 }
233
234 fn from_select(
235 select: &crate::expressions::Select,
236 ctes: &HashMap<String, Step>,
237 ) -> Option<Self> {
238 let mut ctes = ctes.clone();
240 if let Some(ref with) = select.with {
241 for cte in &with.ctes {
242 if let Some(step) = Self::from_expression(&cte.this, &ctes) {
243 ctes.insert(cte.alias.name.clone(), step);
244 }
245 }
246 }
247
248 let mut step = if let Some(ref from) = select.from {
250 if let Some(table_expr) = from.expressions.first() {
251 Self::from_table_expression(table_expr, &ctes)?
252 } else {
253 return None;
254 }
255 } else {
256 Step::new("", StepKind::Scan)
258 };
259
260 for join in &select.joins {
262 let right = Self::from_table_expression(&join.this, &ctes)?;
263
264 let join_type = match join.kind {
265 JoinKind::Inner => JoinType::Inner,
266 JoinKind::Left | JoinKind::NaturalLeft => JoinType::Left,
267 JoinKind::Right | JoinKind::NaturalRight => JoinType::Right,
268 JoinKind::Full | JoinKind::NaturalFull => JoinType::Full,
269 JoinKind::Cross | JoinKind::Natural => JoinType::Cross,
270 _ => JoinType::Inner,
271 };
272
273 let join_step = Step {
274 name: step.name.clone(),
275 kind: StepKind::Join(join_type),
276 projections: Vec::new(),
277 dependencies: vec![step, right],
278 aggregations: Vec::new(),
279 group_by: Vec::new(),
280 condition: join.on.clone(),
281 order_by: Vec::new(),
282 limit: None,
283 };
284 step = join_step;
285 }
286
287 let has_aggregations = select.expressions.iter().any(|e| contains_aggregate(e));
289 let has_group_by = select.group_by.is_some();
290
291 if has_aggregations || has_group_by {
292 let agg_step = Step {
294 name: step.name.clone(),
295 kind: StepKind::Aggregate,
296 projections: select.expressions.clone(),
297 dependencies: vec![step],
298 aggregations: extract_aggregations(&select.expressions),
299 group_by: select.group_by.as_ref()
300 .map(|g| g.expressions.clone())
301 .unwrap_or_default(),
302 condition: None,
303 order_by: Vec::new(),
304 limit: None,
305 };
306 step = agg_step;
307 } else {
308 step.projections = select.expressions.clone();
309 }
310
311 Some(step)
312 }
313
314 fn from_table_expression(
315 expr: &Expression,
316 ctes: &HashMap<String, Step>,
317 ) -> Option<Self> {
318 match expr {
319 Expression::Table(table) => {
320 if let Some(cte_step) = ctes.get(&table.name.name) {
322 return Some(cte_step.clone());
323 }
324
325 Some(Step::new(&table.name.name, StepKind::Scan))
327 }
328 Expression::Alias(alias) => {
329 let mut step = Self::from_table_expression(&alias.this, ctes)?;
330 step.name = alias.alias.name.clone();
331 Some(step)
332 }
333 Expression::Subquery(sq) => {
334 let step = Self::from_expression(&sq.this, ctes)?;
335 Some(step)
336 }
337 _ => None,
338 }
339 }
340
341 pub fn add_dependency(&mut self, dep: Step) {
343 self.dependencies.push(dep);
344 }
345}
346
347fn contains_aggregate(expr: &Expression) -> bool {
349 match expr {
350 Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
352 Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
353 Expression::StringAgg(_) | Expression::ListAgg(_) |
354 Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
355 Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
356 Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
357 Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
358 Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
359 Expression::AggregateFunction(_) => true,
360
361 Expression::Alias(alias) => contains_aggregate(&alias.this),
362 Expression::Add(op) | Expression::Sub(op) |
363 Expression::Mul(op) | Expression::Div(op) => {
364 contains_aggregate(&op.left) || contains_aggregate(&op.right)
365 }
366 Expression::Function(func) => {
367 let name = func.name.to_uppercase();
369 matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
370 "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT")
371 }
372 _ => false,
373 }
374}
375
376fn extract_aggregations(expressions: &[Expression]) -> Vec<Expression> {
378 let mut aggs = Vec::new();
379 for expr in expressions {
380 collect_aggregations(expr, &mut aggs);
381 }
382 aggs
383}
384
385fn collect_aggregations(expr: &Expression, aggs: &mut Vec<Expression>) {
386 match expr {
387 Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
389 Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
390 Expression::StringAgg(_) | Expression::ListAgg(_) |
391 Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
392 Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
393 Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
394 Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
395 Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
396 Expression::AggregateFunction(_) => {
397 aggs.push(expr.clone());
398 }
399 Expression::Alias(alias) => {
400 collect_aggregations(&alias.this, aggs);
401 }
402 Expression::Add(op) | Expression::Sub(op) |
403 Expression::Mul(op) | Expression::Div(op) => {
404 collect_aggregations(&op.left, aggs);
405 collect_aggregations(&op.right, aggs);
406 }
407 Expression::Function(func) => {
408 let name = func.name.to_uppercase();
409 if matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
410 "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT") {
411 aggs.push(expr.clone());
412 } else {
413 for arg in &func.args {
414 collect_aggregations(arg, aggs);
415 }
416 }
417 }
418 _ => {}
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let generic = Dialect::get(DialectType::Generic);
        let mut statements = generic.parse(sql).unwrap().into_iter();
        statements.next().unwrap()
    }

    /// Plans `sql`, asserting that planning succeeded.
    fn plan_for(sql: &str) -> Plan {
        let planned = Plan::from_expression(&parse(sql));
        assert!(planned.is_some());
        planned.unwrap()
    }

    /// Runs `contains_aggregate` on the first select-list item of `sql`.
    fn first_projection_is_aggregate(sql: &str) -> bool {
        match parse(sql) {
            Expression::Select(ref sel) => {
                assert!(!sel.expressions.is_empty());
                contains_aggregate(&sel.expressions[0])
            }
            _ => panic!("Expected SELECT expression"),
        }
    }

    #[test]
    fn test_simple_scan() {
        let plan = plan_for("SELECT a, b FROM t");

        assert_eq!(plan.root.kind, StepKind::Scan);
        assert_eq!(plan.root.name, "t");
    }

    #[test]
    fn test_join() {
        let plan = plan_for("SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.id");

        assert!(matches!(plan.root.kind, StepKind::Join(_)));
        assert_eq!(plan.root.dependencies.len(), 2);
    }

    #[test]
    fn test_aggregate() {
        let plan = plan_for("SELECT x, SUM(y) FROM t GROUP BY x");

        assert_eq!(plan.root.kind, StepKind::Aggregate);
    }

    #[test]
    fn test_union() {
        let plan = plan_for("SELECT a FROM t1 UNION SELECT b FROM t2");

        assert!(matches!(plan.root.kind, StepKind::SetOperation(SetOperationType::Union)));
    }

    #[test]
    fn test_contains_aggregate() {
        assert!(first_projection_is_aggregate("SELECT SUM(x) FROM t"),
            "Expected SUM to be detected as aggregate function");

        assert!(!first_projection_is_aggregate("SELECT x + 1 FROM t"),
            "Expected x + 1 to not be an aggregate function");
    }
}