1pub mod cse;
4pub mod join_reordering;
5pub mod projection_pushdown;
6
7use crate::error::{QueryError, Result};
8use crate::parser::ast::*;
9use oxigdal_core::error::OxiGdalError;
10use std::collections::HashSet;
11
12pub use cse::CommonSubexpressionElimination;
14pub use join_reordering::JoinReordering;
15pub use projection_pushdown::ProjectionPushdown;
16
/// A single rewrite pass over a parsed `SELECT` statement.
///
/// Implementors take ownership of the statement and hand back the (possibly
/// rewritten) statement, so rules can be chained one after another.
pub trait OptimizationRule {
    /// Applies this rule to `stmt`.
    ///
    /// Returns the resulting statement on success, or an error if the rule
    /// cannot be applied.
    fn apply(&self, stmt: SelectStatement) -> Result<SelectStatement>;
}
22
23pub struct PredicatePushdown;
29
30impl OptimizationRule for PredicatePushdown {
31 fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
32 if stmt.selection.is_none() || stmt.from.is_none() {
34 return Ok(stmt);
35 }
36
37 let selection = stmt
39 .selection
40 .take()
41 .ok_or_else(|| QueryError::optimization("Internal error: selection disappeared"))?;
42 let from = stmt
43 .from
44 .take()
45 .ok_or_else(|| QueryError::optimization("Internal error: from disappeared"))?;
46
47 let mut predicates = Vec::new();
49 extract_predicates(&selection, &mut predicates);
50
51 let table_aliases = collect_table_aliases(&from);
53
54 if table_aliases.is_empty() {
56 return Err(QueryError::optimization(
57 OxiGdalError::invalid_state_builder(
58 "Cannot apply predicate pushdown without table references",
59 )
60 .with_operation("predicate_pushdown")
61 .with_parameter("predicate_count", predicates.len().to_string())
62 .with_suggestion("Ensure the FROM clause contains valid table references")
63 .build()
64 .to_string(),
65 ));
66 }
67
68 let mut pushed_predicates: Vec<Expr> = Vec::new();
70 let mut remaining_predicates: Vec<Expr> = Vec::new();
71
72 for predicate in predicates {
73 let predicate_tables = get_predicate_tables(&predicate);
74
75 if !predicate_tables.is_empty()
77 && !predicate_tables.iter().any(|t| table_aliases.contains(t))
78 {
79 return Err(QueryError::optimization(
80 OxiGdalError::invalid_operation_builder("Predicate references unknown table")
81 .with_operation("predicate_pushdown")
82 .with_parameter(
83 "unknown_tables",
84 predicate_tables
85 .iter()
86 .filter(|t| !table_aliases.contains(*t))
87 .cloned()
88 .collect::<Vec<_>>()
89 .join(", "),
90 )
91 .with_parameter(
92 "available_tables",
93 table_aliases.iter().cloned().collect::<Vec<_>>().join(", "),
94 )
95 .with_suggestion("Check table names and aliases in the FROM clause")
96 .build()
97 .to_string(),
98 ));
99 }
100
101 if predicate_tables.len() == 1 {
103 if let Some(table_name) = predicate_tables.iter().next() {
104 if table_aliases.contains(table_name) {
105 pushed_predicates.push(predicate);
106 continue;
107 }
108 }
109 }
110 remaining_predicates.push(predicate);
111 }
112
113 let optimized_from = push_predicates_to_joins(from, &mut pushed_predicates);
115
116 let new_selection = if remaining_predicates.is_empty() && pushed_predicates.is_empty() {
118 None
119 } else {
120 let all_remaining: Vec<Expr> = remaining_predicates
121 .into_iter()
122 .chain(pushed_predicates)
123 .collect();
124 combine_predicates_with_and(all_remaining)
125 };
126
127 stmt.from = Some(optimized_from);
128 stmt.selection = new_selection;
129
130 Ok(stmt)
131 }
132}
133
134pub(crate) fn extract_predicates(expr: &Expr, predicates: &mut Vec<Expr>) {
140 match expr {
141 Expr::BinaryOp {
142 left,
143 op: BinaryOperator::And,
144 right,
145 } => {
146 extract_predicates(left, predicates);
147 extract_predicates(right, predicates);
148 }
149 _ => {
150 predicates.push(expr.clone());
151 }
152 }
153}
154
155pub(crate) fn collect_table_aliases(table_ref: &TableReference) -> HashSet<String> {
157 let mut aliases = HashSet::new();
158 collect_table_aliases_recursive(table_ref, &mut aliases);
159 aliases
160}
161
162fn collect_table_aliases_recursive(table_ref: &TableReference, aliases: &mut HashSet<String>) {
163 match table_ref {
164 TableReference::Table { name, alias } => {
165 aliases.insert(alias.clone().unwrap_or_else(|| name.clone()));
166 aliases.insert(name.clone());
167 }
168 TableReference::Join { left, right, .. } => {
169 collect_table_aliases_recursive(left, aliases);
170 collect_table_aliases_recursive(right, aliases);
171 }
172 TableReference::Subquery { alias, .. } => {
173 aliases.insert(alias.clone());
174 }
175 }
176}
177
178pub(crate) fn get_predicate_tables(expr: &Expr) -> HashSet<String> {
180 let mut tables = HashSet::new();
181 collect_predicate_tables(expr, &mut tables);
182 tables
183}
184
185fn collect_predicate_tables(expr: &Expr, tables: &mut HashSet<String>) {
186 match expr {
187 Expr::Column { table, .. } => {
188 if let Some(t) = table {
189 tables.insert(t.clone());
190 }
191 }
192 Expr::BinaryOp { left, right, .. } => {
193 collect_predicate_tables(left, tables);
194 collect_predicate_tables(right, tables);
195 }
196 Expr::UnaryOp { expr, .. } => {
197 collect_predicate_tables(expr, tables);
198 }
199 Expr::Function { args, .. } => {
200 for arg in args {
201 collect_predicate_tables(arg, tables);
202 }
203 }
204 Expr::Case {
205 operand,
206 when_then,
207 else_result,
208 } => {
209 if let Some(op) = operand {
210 collect_predicate_tables(op, tables);
211 }
212 for (when, then) in when_then {
213 collect_predicate_tables(when, tables);
214 collect_predicate_tables(then, tables);
215 }
216 if let Some(else_expr) = else_result {
217 collect_predicate_tables(else_expr, tables);
218 }
219 }
220 Expr::Cast { expr, .. } => {
221 collect_predicate_tables(expr, tables);
222 }
223 Expr::IsNull(expr) | Expr::IsNotNull(expr) => {
224 collect_predicate_tables(expr, tables);
225 }
226 Expr::InList { expr, list, .. } => {
227 collect_predicate_tables(expr, tables);
228 for item in list {
229 collect_predicate_tables(item, tables);
230 }
231 }
232 Expr::Between {
233 expr, low, high, ..
234 } => {
235 collect_predicate_tables(expr, tables);
236 collect_predicate_tables(low, tables);
237 collect_predicate_tables(high, tables);
238 }
239 Expr::Subquery(subquery) => {
240 if let Some(ref from) = subquery.from {
242 for alias in collect_table_aliases(from) {
243 tables.insert(alias);
244 }
245 }
246 }
247 Expr::Literal(_) | Expr::Wildcard => {}
248 }
249}
250
251fn push_predicates_to_joins(
253 table_ref: TableReference,
254 predicates: &mut Vec<Expr>,
255) -> TableReference {
256 match table_ref {
257 TableReference::Join {
258 left,
259 right,
260 join_type,
261 on,
262 } => {
263 let optimized_left = push_predicates_to_joins(*left, predicates);
265 let optimized_right = push_predicates_to_joins(*right, predicates);
266
267 let left_tables = collect_table_aliases(&optimized_left);
269 let right_tables = collect_table_aliases(&optimized_right);
270 let all_tables: HashSet<String> = left_tables
271 .iter()
272 .chain(right_tables.iter())
273 .cloned()
274 .collect();
275
276 let mut join_predicates = Vec::new();
278 let mut remaining = Vec::new();
279
280 for predicate in predicates.drain(..) {
281 let pred_tables = get_predicate_tables(&predicate);
282
283 let can_push =
285 !pred_tables.is_empty() && pred_tables.iter().all(|t| all_tables.contains(t));
286
287 if can_push && join_type == JoinType::Inner {
290 join_predicates.push(predicate);
291 } else if can_push && join_type == JoinType::Cross {
292 join_predicates.push(predicate);
295 } else {
296 remaining.push(predicate);
297 }
298 }
299
300 *predicates = remaining;
301
302 let new_on = match (on, combine_predicates_with_and(join_predicates)) {
304 (Some(existing), Some(new_pred)) => Some(Expr::BinaryOp {
305 left: Box::new(existing),
306 op: BinaryOperator::And,
307 right: Box::new(new_pred),
308 }),
309 (Some(existing), None) => Some(existing),
310 (None, Some(new_pred)) => Some(new_pred),
311 (None, None) => None,
312 };
313
314 TableReference::Join {
315 left: Box::new(optimized_left),
316 right: Box::new(optimized_right),
317 join_type,
318 on: new_on,
319 }
320 }
321 TableReference::Subquery { query, alias } => {
322 let mut subquery_predicates = Vec::new();
324 let mut remaining = Vec::new();
325
326 for predicate in predicates.drain(..) {
327 let pred_tables = get_predicate_tables(&predicate);
328 if pred_tables.len() == 1 && pred_tables.contains(&alias) {
329 subquery_predicates.push(predicate);
330 } else {
331 remaining.push(predicate);
332 }
333 }
334
335 *predicates = remaining;
336
337 let mut optimized_query = *query;
339 if !subquery_predicates.is_empty() {
340 let combined = combine_predicates_with_and(subquery_predicates);
341 optimized_query.selection = match (optimized_query.selection, combined) {
342 (Some(existing), Some(new_pred)) => Some(Expr::BinaryOp {
343 left: Box::new(existing),
344 op: BinaryOperator::And,
345 right: Box::new(new_pred),
346 }),
347 (Some(existing), None) => Some(existing),
348 (None, Some(new_pred)) => Some(new_pred),
349 (None, None) => None,
350 };
351 }
352
353 TableReference::Subquery {
354 query: Box::new(optimized_query),
355 alias,
356 }
357 }
358 other => other,
359 }
360}
361
362pub(crate) fn combine_predicates_with_and(predicates: Vec<Expr>) -> Option<Expr> {
364 if predicates.is_empty() {
365 return None;
366 }
367
368 let mut iter = predicates.into_iter();
369 let first = iter.next()?;
370
371 Some(iter.fold(first, |acc, pred| Expr::BinaryOp {
372 left: Box::new(acc),
373 op: BinaryOperator::And,
374 right: Box::new(pred),
375 }))
376}
377
378pub(crate) fn collect_column_refs(expr: &Expr, columns: &mut HashSet<String>) {
380 match expr {
381 Expr::Column { table, name } => {
382 let full_name = if let Some(t) = table {
383 format!("{}.{}", t, name)
384 } else {
385 name.clone()
386 };
387 columns.insert(full_name);
388 }
389 Expr::BinaryOp { left, right, .. } => {
390 collect_column_refs(left, columns);
391 collect_column_refs(right, columns);
392 }
393 Expr::UnaryOp { expr, .. } => {
394 collect_column_refs(expr, columns);
395 }
396 Expr::Function { args, .. } => {
397 for arg in args {
398 collect_column_refs(arg, columns);
399 }
400 }
401 Expr::Case {
402 operand,
403 when_then,
404 else_result,
405 } => {
406 if let Some(op) = operand {
407 collect_column_refs(op, columns);
408 }
409 for (when, then) in when_then {
410 collect_column_refs(when, columns);
411 collect_column_refs(then, columns);
412 }
413 if let Some(else_expr) = else_result {
414 collect_column_refs(else_expr, columns);
415 }
416 }
417 Expr::Cast { expr, .. } => {
418 collect_column_refs(expr, columns);
419 }
420 Expr::IsNull(expr) | Expr::IsNotNull(expr) => {
421 collect_column_refs(expr, columns);
422 }
423 Expr::InList { expr, list, .. } => {
424 collect_column_refs(expr, columns);
425 for item in list {
426 collect_column_refs(item, columns);
427 }
428 }
429 Expr::Between {
430 expr, low, high, ..
431 } => {
432 collect_column_refs(expr, columns);
433 collect_column_refs(low, columns);
434 collect_column_refs(high, columns);
435 }
436 _ => {}
437 }
438}
439
440pub struct ConstantFolding;
448
449impl OptimizationRule for ConstantFolding {
450 fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
451 stmt.projection = stmt.projection.into_iter().map(fold_select_item).collect();
453
454 if let Some(selection) = stmt.selection {
456 stmt.selection = Some(fold_expr(selection));
457 }
458
459 if let Some(having) = stmt.having {
461 stmt.having = Some(fold_expr(having));
462 }
463
464 stmt.order_by = stmt
466 .order_by
467 .into_iter()
468 .map(|order| OrderByExpr {
469 expr: fold_expr(order.expr),
470 asc: order.asc,
471 nulls_first: order.nulls_first,
472 })
473 .collect();
474
475 Ok(stmt)
476 }
477}
478
479fn fold_select_item(item: SelectItem) -> SelectItem {
480 match item {
481 SelectItem::Expr { expr, alias } => SelectItem::Expr {
482 expr: fold_expr(expr),
483 alias,
484 },
485 other => other,
486 }
487}
488
489fn fold_expr(expr: Expr) -> Expr {
490 match expr {
491 Expr::BinaryOp { left, op, right } => {
492 let left = fold_expr(*left);
493 let right = fold_expr(*right);
494
495 if let (Expr::Literal(l), Expr::Literal(r)) = (&left, &right) {
497 if let Some(result) = try_fold_binary(l, op, r) {
498 return Expr::Literal(result);
499 }
500 }
501
502 Expr::BinaryOp {
503 left: Box::new(left),
504 op,
505 right: Box::new(right),
506 }
507 }
508 Expr::UnaryOp { op, expr } => {
509 let expr = fold_expr(*expr);
510 if let Expr::Literal(lit) = &expr {
511 if let Some(result) = try_fold_unary(op, lit) {
512 return Expr::Literal(result);
513 }
514 }
515 Expr::UnaryOp {
516 op,
517 expr: Box::new(expr),
518 }
519 }
520 Expr::Function { name, args } => {
521 let args = args.into_iter().map(fold_expr).collect();
522 Expr::Function { name, args }
523 }
524 Expr::Case {
525 operand,
526 when_then,
527 else_result,
528 } => {
529 let operand = operand.map(|e| Box::new(fold_expr(*e)));
530 let when_then = when_then
531 .into_iter()
532 .map(|(w, t)| (fold_expr(w), fold_expr(t)))
533 .collect();
534 let else_result = else_result.map(|e| Box::new(fold_expr(*e)));
535 Expr::Case {
536 operand,
537 when_then,
538 else_result,
539 }
540 }
541 other => other,
542 }
543}
544
545fn try_fold_binary(left: &Literal, op: BinaryOperator, right: &Literal) -> Option<Literal> {
546 match (left, right) {
547 (Literal::Integer(l), Literal::Integer(r)) => match op {
548 BinaryOperator::Plus => Some(Literal::Integer(l + r)),
549 BinaryOperator::Minus => Some(Literal::Integer(l - r)),
550 BinaryOperator::Multiply => Some(Literal::Integer(l * r)),
551 BinaryOperator::Divide if *r != 0 => Some(Literal::Integer(l / r)),
552 BinaryOperator::Modulo if *r != 0 => Some(Literal::Integer(l % r)),
553 BinaryOperator::Eq => Some(Literal::Boolean(l == r)),
554 BinaryOperator::NotEq => Some(Literal::Boolean(l != r)),
555 BinaryOperator::Lt => Some(Literal::Boolean(l < r)),
556 BinaryOperator::LtEq => Some(Literal::Boolean(l <= r)),
557 BinaryOperator::Gt => Some(Literal::Boolean(l > r)),
558 BinaryOperator::GtEq => Some(Literal::Boolean(l >= r)),
559 _ => None,
560 },
561 (Literal::Float(l), Literal::Float(r)) => match op {
562 BinaryOperator::Plus => Some(Literal::Float(l + r)),
563 BinaryOperator::Minus => Some(Literal::Float(l - r)),
564 BinaryOperator::Multiply => Some(Literal::Float(l * r)),
565 BinaryOperator::Divide if *r != 0.0 => Some(Literal::Float(l / r)),
566 BinaryOperator::Eq => Some(Literal::Boolean((l - r).abs() < f64::EPSILON)),
567 BinaryOperator::NotEq => Some(Literal::Boolean((l - r).abs() >= f64::EPSILON)),
568 BinaryOperator::Lt => Some(Literal::Boolean(l < r)),
569 BinaryOperator::LtEq => Some(Literal::Boolean(l <= r)),
570 BinaryOperator::Gt => Some(Literal::Boolean(l > r)),
571 BinaryOperator::GtEq => Some(Literal::Boolean(l >= r)),
572 _ => None,
573 },
574 (Literal::Boolean(l), Literal::Boolean(r)) => match op {
575 BinaryOperator::And => Some(Literal::Boolean(*l && *r)),
576 BinaryOperator::Or => Some(Literal::Boolean(*l || *r)),
577 BinaryOperator::Eq => Some(Literal::Boolean(l == r)),
578 BinaryOperator::NotEq => Some(Literal::Boolean(l != r)),
579 _ => None,
580 },
581 (Literal::String(l), Literal::String(r)) => match op {
582 BinaryOperator::Concat => Some(Literal::String(format!("{}{}", l, r))),
583 BinaryOperator::Eq => Some(Literal::Boolean(l == r)),
584 BinaryOperator::NotEq => Some(Literal::Boolean(l != r)),
585 _ => None,
586 },
587 _ => None,
588 }
589}
590
591fn try_fold_unary(op: UnaryOperator, lit: &Literal) -> Option<Literal> {
592 match (op, lit) {
593 (UnaryOperator::Minus, Literal::Integer(i)) => Some(Literal::Integer(-i)),
594 (UnaryOperator::Minus, Literal::Float(f)) => Some(Literal::Float(-f)),
595 (UnaryOperator::Not, Literal::Boolean(b)) => Some(Literal::Boolean(!b)),
596 _ => None,
597 }
598}
599
600pub struct FilterFusion;
608
609impl OptimizationRule for FilterFusion {
610 fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
611 if let Some(selection) = stmt.selection {
613 stmt.selection = Some(fuse_filters(selection));
614 }
615 Ok(stmt)
616 }
617}
618
619fn fuse_filters(expr: Expr) -> Expr {
620 match expr {
621 Expr::BinaryOp {
622 left,
623 op: BinaryOperator::And,
624 right,
625 } => {
626 let left = fuse_filters(*left);
627 let right = fuse_filters(*right);
628
629 let mut conditions = Vec::new();
631 collect_and_conditions(&left, &mut conditions);
632 collect_and_conditions(&right, &mut conditions);
633
634 if conditions.len() > 1 {
635 let mut result = conditions[0].clone();
637 for cond in &conditions[1..] {
638 result = Expr::BinaryOp {
639 left: Box::new(result),
640 op: BinaryOperator::And,
641 right: Box::new(cond.clone()),
642 };
643 }
644 result
645 } else {
646 Expr::BinaryOp {
647 left: Box::new(left),
648 op: BinaryOperator::And,
649 right: Box::new(right),
650 }
651 }
652 }
653 other => other,
654 }
655}
656
657fn collect_and_conditions(expr: &Expr, conditions: &mut Vec<Expr>) {
658 if let Expr::BinaryOp {
659 left,
660 op: BinaryOperator::And,
661 right,
662 } = expr
663 {
664 collect_and_conditions(left, conditions);
665 collect_and_conditions(right, conditions);
666 } else {
667 conditions.push(expr.clone());
668 }
669}
670
671pub fn optimize_with_rules(stmt: SelectStatement) -> Result<SelectStatement> {
677 let rules: Vec<Box<dyn OptimizationRule>> = vec![
678 Box::new(ConstantFolding),
679 Box::new(FilterFusion),
680 Box::new(ProjectionPushdown),
681 Box::new(PredicatePushdown),
682 Box::new(JoinReordering),
683 Box::new(CommonSubexpressionElimination),
684 ];
685
686 let mut current = stmt;
687 for rule in rules {
688 current = rule.apply(current)?;
689 }
690
691 Ok(current)
692}
693
#[cfg(test)]
// Panicking helpers are acceptable in test code.
#[allow(clippy::expect_used)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Boxes a literal expression.
    fn lit(value: Literal) -> Box<Expr> {
        Box::new(Expr::Literal(value))
    }

    /// Boxes an unqualified column reference.
    fn column(name: &str) -> Box<Expr> {
        Box::new(Expr::Column {
            table: None,
            name: name.to_string(),
        })
    }

    /// Builds a binary operation from already-boxed operands.
    fn binary(left: Box<Expr>, op: BinaryOperator, right: Box<Expr>) -> Expr {
        Expr::BinaryOp { left, op, right }
    }

    #[test]
    fn test_constant_folding_arithmetic() {
        let sum = binary(
            lit(Literal::Integer(10)),
            BinaryOperator::Plus,
            lit(Literal::Integer(20)),
        );
        assert_eq!(fold_expr(sum), Expr::Literal(Literal::Integer(30)));
    }

    #[test]
    fn test_constant_folding_boolean() {
        let conjunction = binary(
            lit(Literal::Boolean(true)),
            BinaryOperator::And,
            lit(Literal::Boolean(false)),
        );
        assert_eq!(
            fold_expr(conjunction),
            Expr::Literal(Literal::Boolean(false))
        );
    }

    #[test]
    fn test_unary_folding() {
        let negation = Expr::UnaryOp {
            op: UnaryOperator::Minus,
            expr: lit(Literal::Integer(42)),
        };
        assert_eq!(fold_expr(negation), Expr::Literal(Literal::Integer(-42)));
    }

    #[test]
    fn test_full_optimization_pipeline() {
        // `a + b` appears both in the projection and in the WHERE clause.
        let a_plus_b = binary(column("a"), BinaryOperator::Plus, column("b"));

        let stmt = SelectStatement {
            projection: vec![
                SelectItem::Expr {
                    expr: a_plus_b.clone(),
                    alias: None,
                },
                SelectItem::Expr {
                    expr: binary(
                        lit(Literal::Integer(1)),
                        BinaryOperator::Plus,
                        lit(Literal::Integer(2)),
                    ),
                    alias: Some("constant".to_string()),
                },
            ],
            from: Some(TableReference::Table {
                name: "t".to_string(),
                alias: None,
            }),
            selection: Some(binary(
                Box::new(a_plus_b),
                BinaryOperator::Gt,
                lit(Literal::Integer(10)),
            )),
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        };

        let result = optimize_with_rules(stmt).expect("Full optimization should succeed");

        // The second projection item (`1 + 2`) must have been folded to `3`.
        match &result.projection[1] {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(*expr, Expr::Literal(Literal::Integer(3)));
            }
            _ => {}
        }
    }
}