1use crate::expressions::{Expression, JoinKind};
8use std::collections::{HashMap, HashSet};
9
10#[derive(Debug)]
12pub struct Plan {
13 pub root: Step,
15 dag: Option<HashMap<usize, HashSet<usize>>>,
17}
18
19impl Plan {
20 pub fn from_expression(expression: &Expression) -> Option<Self> {
22 let root = Step::from_expression(expression, &HashMap::new())?;
23 Some(Self { root, dag: None })
24 }
25
26 pub fn dag(&mut self) -> &HashMap<usize, HashSet<usize>> {
28 if self.dag.is_none() {
29 let mut dag = HashMap::new();
30 self.build_dag(&self.root, &mut dag, 0);
31 self.dag = Some(dag);
32 }
33 self.dag.as_ref().unwrap()
34 }
35
36 fn build_dag(&self, step: &Step, dag: &mut HashMap<usize, HashSet<usize>>, id: usize) {
37 let deps: HashSet<usize> = step.dependencies
38 .iter()
39 .enumerate()
40 .map(|(i, _)| id + i + 1)
41 .collect();
42 dag.insert(id, deps);
43
44 for (i, dep) in step.dependencies.iter().enumerate() {
45 self.build_dag(dep, dag, id + i + 1);
46 }
47 }
48
49 pub fn leaves(&self) -> Vec<&Step> {
51 let mut leaves = Vec::new();
52 self.collect_leaves(&self.root, &mut leaves);
53 leaves
54 }
55
56 fn collect_leaves<'a>(&'a self, step: &'a Step, leaves: &mut Vec<&'a Step>) {
57 if step.dependencies.is_empty() {
58 leaves.push(step);
59 } else {
60 for dep in &step.dependencies {
61 self.collect_leaves(dep, leaves);
62 }
63 }
64 }
65}
66
67#[derive(Debug, Clone)]
69pub struct Step {
70 pub name: String,
72 pub kind: StepKind,
74 pub projections: Vec<Expression>,
76 pub dependencies: Vec<Step>,
78 pub aggregations: Vec<Expression>,
80 pub group_by: Vec<Expression>,
82 pub condition: Option<Expression>,
84 pub order_by: Vec<Expression>,
86 pub limit: Option<Expression>,
88}
89
90#[derive(Debug, Clone, PartialEq)]
92pub enum StepKind {
93 Scan,
95 Join(JoinType),
97 Aggregate,
99 Sort,
101 SetOperation(SetOperationType),
103}
104
/// The flavour of join performed by a [`StepKind::Join`] step.
#[derive(Clone, Debug, PartialEq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    Left,
    /// Right outer join.
    Right,
    /// Full outer join.
    Full,
    /// Cross join.
    Cross,
}
114
/// The set operation performed by a [`StepKind::SetOperation`] step.
#[derive(Clone, Debug, PartialEq)]
pub enum SetOperationType {
    /// `UNION` (without `ALL`).
    Union,
    /// `UNION ALL`.
    UnionAll,
    /// `INTERSECT`.
    Intersect,
    /// `EXCEPT`.
    Except,
}
123
124impl Step {
125 pub fn new(name: impl Into<String>, kind: StepKind) -> Self {
127 Self {
128 name: name.into(),
129 kind,
130 projections: Vec::new(),
131 dependencies: Vec::new(),
132 aggregations: Vec::new(),
133 group_by: Vec::new(),
134 condition: None,
135 order_by: Vec::new(),
136 limit: None,
137 }
138 }
139
140 pub fn from_expression(
142 expression: &Expression,
143 ctes: &HashMap<String, Step>,
144 ) -> Option<Self> {
145 match expression {
146 Expression::Select(select) => {
147 let mut step = Self::from_select(select, ctes)?;
148
149 if let Some(ref order_by) = select.order_by {
151 let sort_step = Step {
152 name: step.name.clone(),
153 kind: StepKind::Sort,
154 projections: Vec::new(),
155 dependencies: vec![step],
156 aggregations: Vec::new(),
157 group_by: Vec::new(),
158 condition: None,
159 order_by: order_by.expressions.iter().map(|o| o.this.clone()).collect(),
160 limit: None,
161 };
162 step = sort_step;
163 }
164
165 if let Some(ref limit) = select.limit {
167 step.limit = Some(limit.this.clone());
168 }
169
170 Some(step)
171 }
172 Expression::Union(union) => {
173 let left = Self::from_expression(&union.left, ctes)?;
174 let right = Self::from_expression(&union.right, ctes)?;
175
176 let op_type = if union.all {
177 SetOperationType::UnionAll
178 } else {
179 SetOperationType::Union
180 };
181
182 Some(Step {
183 name: "UNION".to_string(),
184 kind: StepKind::SetOperation(op_type),
185 projections: Vec::new(),
186 dependencies: vec![left, right],
187 aggregations: Vec::new(),
188 group_by: Vec::new(),
189 condition: None,
190 order_by: Vec::new(),
191 limit: None,
192 })
193 }
194 Expression::Intersect(intersect) => {
195 let left = Self::from_expression(&intersect.left, ctes)?;
196 let right = Self::from_expression(&intersect.right, ctes)?;
197
198 Some(Step {
199 name: "INTERSECT".to_string(),
200 kind: StepKind::SetOperation(SetOperationType::Intersect),
201 projections: Vec::new(),
202 dependencies: vec![left, right],
203 aggregations: Vec::new(),
204 group_by: Vec::new(),
205 condition: None,
206 order_by: Vec::new(),
207 limit: None,
208 })
209 }
210 Expression::Except(except) => {
211 let left = Self::from_expression(&except.left, ctes)?;
212 let right = Self::from_expression(&except.right, ctes)?;
213
214 Some(Step {
215 name: "EXCEPT".to_string(),
216 kind: StepKind::SetOperation(SetOperationType::Except),
217 projections: Vec::new(),
218 dependencies: vec![left, right],
219 aggregations: Vec::new(),
220 group_by: Vec::new(),
221 condition: None,
222 order_by: Vec::new(),
223 limit: None,
224 })
225 }
226 _ => None,
227 }
228 }
229
230 fn from_select(
231 select: &crate::expressions::Select,
232 ctes: &HashMap<String, Step>,
233 ) -> Option<Self> {
234 let mut ctes = ctes.clone();
236 if let Some(ref with) = select.with {
237 for cte in &with.ctes {
238 if let Some(step) = Self::from_expression(&cte.this, &ctes) {
239 ctes.insert(cte.alias.name.clone(), step);
240 }
241 }
242 }
243
244 let mut step = if let Some(ref from) = select.from {
246 if let Some(table_expr) = from.expressions.first() {
247 Self::from_table_expression(table_expr, &ctes)?
248 } else {
249 return None;
250 }
251 } else {
252 Step::new("", StepKind::Scan)
254 };
255
256 for join in &select.joins {
258 let right = Self::from_table_expression(&join.this, &ctes)?;
259
260 let join_type = match join.kind {
261 JoinKind::Inner => JoinType::Inner,
262 JoinKind::Left | JoinKind::NaturalLeft => JoinType::Left,
263 JoinKind::Right | JoinKind::NaturalRight => JoinType::Right,
264 JoinKind::Full | JoinKind::NaturalFull => JoinType::Full,
265 JoinKind::Cross | JoinKind::Natural => JoinType::Cross,
266 _ => JoinType::Inner,
267 };
268
269 let join_step = Step {
270 name: step.name.clone(),
271 kind: StepKind::Join(join_type),
272 projections: Vec::new(),
273 dependencies: vec![step, right],
274 aggregations: Vec::new(),
275 group_by: Vec::new(),
276 condition: join.on.clone(),
277 order_by: Vec::new(),
278 limit: None,
279 };
280 step = join_step;
281 }
282
283 let has_aggregations = select.expressions.iter().any(|e| contains_aggregate(e));
285 let has_group_by = select.group_by.is_some();
286
287 if has_aggregations || has_group_by {
288 let agg_step = Step {
290 name: step.name.clone(),
291 kind: StepKind::Aggregate,
292 projections: select.expressions.clone(),
293 dependencies: vec![step],
294 aggregations: extract_aggregations(&select.expressions),
295 group_by: select.group_by.as_ref()
296 .map(|g| g.expressions.clone())
297 .unwrap_or_default(),
298 condition: None,
299 order_by: Vec::new(),
300 limit: None,
301 };
302 step = agg_step;
303 } else {
304 step.projections = select.expressions.clone();
305 }
306
307 Some(step)
308 }
309
310 fn from_table_expression(
311 expr: &Expression,
312 ctes: &HashMap<String, Step>,
313 ) -> Option<Self> {
314 match expr {
315 Expression::Table(table) => {
316 if let Some(cte_step) = ctes.get(&table.name.name) {
318 return Some(cte_step.clone());
319 }
320
321 Some(Step::new(&table.name.name, StepKind::Scan))
323 }
324 Expression::Alias(alias) => {
325 let mut step = Self::from_table_expression(&alias.this, ctes)?;
326 step.name = alias.alias.name.clone();
327 Some(step)
328 }
329 Expression::Subquery(sq) => {
330 let step = Self::from_expression(&sq.this, ctes)?;
331 Some(step)
332 }
333 _ => None,
334 }
335 }
336
337 pub fn add_dependency(&mut self, dep: Step) {
339 self.dependencies.push(dep);
340 }
341}
342
343fn contains_aggregate(expr: &Expression) -> bool {
345 match expr {
346 Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
348 Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
349 Expression::StringAgg(_) | Expression::ListAgg(_) |
350 Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
351 Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
352 Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
353 Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
354 Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
355 Expression::AggregateFunction(_) => true,
356
357 Expression::Alias(alias) => contains_aggregate(&alias.this),
358 Expression::Add(op) | Expression::Sub(op) |
359 Expression::Mul(op) | Expression::Div(op) => {
360 contains_aggregate(&op.left) || contains_aggregate(&op.right)
361 }
362 Expression::Function(func) => {
363 let name = func.name.to_uppercase();
365 matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
366 "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT")
367 }
368 _ => false,
369 }
370}
371
372fn extract_aggregations(expressions: &[Expression]) -> Vec<Expression> {
374 let mut aggs = Vec::new();
375 for expr in expressions {
376 collect_aggregations(expr, &mut aggs);
377 }
378 aggs
379}
380
381fn collect_aggregations(expr: &Expression, aggs: &mut Vec<Expression>) {
382 match expr {
383 Expression::Sum(_) | Expression::Count(_) | Expression::Avg(_) |
385 Expression::Min(_) | Expression::Max(_) | Expression::ArrayAgg(_) |
386 Expression::StringAgg(_) | Expression::ListAgg(_) |
387 Expression::Stddev(_) | Expression::StddevPop(_) | Expression::StddevSamp(_) |
388 Expression::Variance(_) | Expression::VarPop(_) | Expression::VarSamp(_) |
389 Expression::Median(_) | Expression::Mode(_) | Expression::First(_) | Expression::Last(_) |
390 Expression::AnyValue(_) | Expression::ApproxDistinct(_) | Expression::ApproxCountDistinct(_) |
391 Expression::LogicalAnd(_) | Expression::LogicalOr(_) |
392 Expression::AggregateFunction(_) => {
393 aggs.push(expr.clone());
394 }
395 Expression::Alias(alias) => {
396 collect_aggregations(&alias.this, aggs);
397 }
398 Expression::Add(op) | Expression::Sub(op) |
399 Expression::Mul(op) | Expression::Div(op) => {
400 collect_aggregations(&op.left, aggs);
401 collect_aggregations(&op.right, aggs);
402 }
403 Expression::Function(func) => {
404 let name = func.name.to_uppercase();
405 if matches!(name.as_str(), "SUM" | "COUNT" | "AVG" | "MIN" | "MAX" |
406 "ARRAY_AGG" | "STRING_AGG" | "GROUP_CONCAT") {
407 aggs.push(expr.clone());
408 } else {
409 for arg in &func.args {
410 collect_aggregations(arg, aggs);
411 }
412 }
413 }
414 _ => {}
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let mut statements = dialect.parse(sql).unwrap().into_iter();
        statements.next().unwrap()
    }

    /// Plans `sql`, asserting that a plan could be built.
    fn plan_for(sql: &str) -> Plan {
        let planned = Plan::from_expression(&parse(sql));
        assert!(planned.is_some());
        planned.unwrap()
    }

    /// Returns the first projection of the `SELECT` statement in `sql`.
    fn first_projection(sql: &str) -> Expression {
        match parse(sql) {
            Expression::Select(sel) => {
                assert!(!sel.expressions.is_empty());
                sel.expressions[0].clone()
            }
            _ => panic!("Expected SELECT expression"),
        }
    }

    #[test]
    fn test_simple_scan() {
        let plan = plan_for("SELECT a, b FROM t");

        assert_eq!(plan.root.kind, StepKind::Scan);
        assert_eq!(plan.root.name, "t");
    }

    #[test]
    fn test_join() {
        let plan = plan_for("SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.id");

        assert!(matches!(plan.root.kind, StepKind::Join(_)));
        assert_eq!(plan.root.dependencies.len(), 2);
    }

    #[test]
    fn test_aggregate() {
        let plan = plan_for("SELECT x, SUM(y) FROM t GROUP BY x");

        assert_eq!(plan.root.kind, StepKind::Aggregate);
    }

    #[test]
    fn test_union() {
        let plan = plan_for("SELECT a FROM t1 UNION SELECT b FROM t2");

        assert!(matches!(
            plan.root.kind,
            StepKind::SetOperation(SetOperationType::Union)
        ));
    }

    #[test]
    fn test_contains_aggregate() {
        let aggregated = first_projection("SELECT SUM(x) FROM t");
        assert!(contains_aggregate(&aggregated),
            "Expected SUM to be detected as aggregate function");

        let plain = first_projection("SELECT x + 1 FROM t");
        assert!(!contains_aggregate(&plain),
            "Expected x + 1 to not be an aggregate function");
    }
}