1use std::collections::{HashMap, HashSet};
12
13use crate::dialects::DialectType;
14use crate::expressions::{BooleanLiteral, Expression};
15use crate::optimizer::normalize::normalized;
16use crate::optimizer::simplify::simplify;
17use crate::scope::{build_scope, Scope, SourceInfo};
18
19pub fn pushdown_predicates(expression: Expression, dialect: Option<DialectType>) -> Expression {
37 let root = build_scope(&expression);
38 let scope_ref_count = compute_ref_count(&root);
39
40 let unnest_requires_cross_join = matches!(
42 dialect,
43 Some(DialectType::Presto) | Some(DialectType::Trino) | Some(DialectType::Athena)
44 );
45
46 let mut result = expression.clone();
48 let scopes = collect_scopes(&root);
49
50 for scope in scopes.iter().rev() {
51 result = process_scope(
52 &result,
53 scope,
54 &scope_ref_count,
55 dialect,
56 unnest_requires_cross_join,
57 );
58 }
59
60 result
61}
62
63fn collect_scopes(root: &Scope) -> Vec<Scope> {
65 let mut result = vec![root.clone()];
66 for child in &root.subquery_scopes {
68 result.extend(collect_scopes(child));
69 }
70 for child in &root.derived_table_scopes {
72 result.extend(collect_scopes(child));
73 }
74 for child in &root.cte_scopes {
76 result.extend(collect_scopes(child));
77 }
78 for child in &root.union_scopes {
80 result.extend(collect_scopes(child));
81 }
82 result
83}
84
85fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
87 let mut counts = HashMap::new();
88 compute_ref_count_recursive(root, &mut counts);
89 counts
90}
91
92fn compute_ref_count_recursive(scope: &Scope, counts: &mut HashMap<u64, usize>) {
93 let id = scope as *const Scope as u64;
95 *counts.entry(id).or_insert(0) += 1;
96
97 for child in &scope.subquery_scopes {
98 compute_ref_count_recursive(child, counts);
99 }
100 for child in &scope.derived_table_scopes {
101 compute_ref_count_recursive(child, counts);
102 }
103 for child in &scope.cte_scopes {
104 compute_ref_count_recursive(child, counts);
105 }
106 for child in &scope.union_scopes {
107 compute_ref_count_recursive(child, counts);
108 }
109}
110
111fn process_scope(
113 expression: &Expression,
114 scope: &Scope,
115 _scope_ref_count: &HashMap<u64, usize>,
116 dialect: Option<DialectType>,
117 _unnest_requires_cross_join: bool,
118) -> Expression {
119 let result = expression.clone();
120
121 let (where_condition, join_conditions, join_index) = if let Expression::Select(select) = &result
123 {
124 let where_cond = select.where_clause.as_ref().map(|w| w.this.clone());
125
126 let mut idx: HashMap<String, usize> = HashMap::new();
127 for (i, join) in select.joins.iter().enumerate() {
128 if let Some(name) = get_table_alias_or_name(&join.this) {
129 idx.insert(name, i);
130 }
131 }
132
133 let join_conds: Vec<Expression> =
134 select.joins.iter().filter_map(|j| j.on.clone()).collect();
135
136 (where_cond, join_conds, idx)
137 } else {
138 (None, vec![], HashMap::new())
139 };
140
141 let mut result = result;
142
143 if let Some(where_cond) = where_condition {
145 let simplified = simplify(where_cond, dialect);
146 result = pushdown_impl(
147 result,
148 &simplified,
149 &scope.sources,
150 dialect,
151 Some(&join_index),
152 );
153 }
154
155 for join_cond in join_conditions {
157 let simplified = simplify(join_cond, dialect);
158 result = pushdown_impl(result, &simplified, &scope.sources, dialect, None);
159 }
160
161 result
162}
163
164fn pushdown_impl(
166 expression: Expression,
167 condition: &Expression,
168 sources: &HashMap<String, SourceInfo>,
169 _dialect: Option<DialectType>,
170 join_index: Option<&HashMap<String, usize>>,
171) -> Expression {
172 let is_cnf = normalized(condition, false); let is_dnf = normalized(condition, true); let cnf_like = is_cnf || !is_dnf;
176
177 let predicates = flatten_predicates(condition, cnf_like);
179
180 if cnf_like {
181 pushdown_cnf(expression, &predicates, sources, join_index)
182 } else {
183 pushdown_dnf(expression, &predicates, sources)
184 }
185}
186
187fn flatten_predicates(expr: &Expression, cnf_like: bool) -> Vec<Expression> {
189 if cnf_like {
190 flatten_and(expr)
192 } else {
193 flatten_or(expr)
195 }
196}
197
198fn flatten_and(expr: &Expression) -> Vec<Expression> {
199 match expr {
200 Expression::And(bin) => {
201 let mut result = flatten_and(&bin.left);
202 result.extend(flatten_and(&bin.right));
203 result
204 }
205 Expression::Paren(p) => flatten_and(&p.this),
206 other => vec![other.clone()],
207 }
208}
209
210fn flatten_or(expr: &Expression) -> Vec<Expression> {
211 match expr {
212 Expression::Or(bin) => {
213 let mut result = flatten_or(&bin.left);
214 result.extend(flatten_or(&bin.right));
215 result
216 }
217 Expression::Paren(p) => flatten_or(&p.this),
218 other => vec![other.clone()],
219 }
220}
221
222fn pushdown_cnf(
224 expression: Expression,
225 predicates: &[Expression],
226 sources: &HashMap<String, SourceInfo>,
227 join_index: Option<&HashMap<String, usize>>,
228) -> Expression {
229 let mut result = expression;
230
231 for predicate in predicates {
232 let nodes = nodes_for_predicate(predicate, sources);
233
234 for (table_name, node_expr) in nodes {
235 if let Some(join_idx) = join_index {
237 if let Some(&this_index) = join_idx.get(&table_name) {
238 let predicate_tables = get_column_table_names(predicate);
239
240 let can_push = predicate_tables
242 .iter()
243 .all(|t| join_idx.get(t).map_or(true, |&idx| idx <= this_index));
244
245 if can_push {
246 result = push_predicate_to_node(&result, predicate, &node_expr);
247 }
248 }
249 } else {
250 result = push_predicate_to_node(&result, predicate, &node_expr);
251 }
252 }
253 }
254
255 result
256}
257
258fn pushdown_dnf(
260 expression: Expression,
261 predicates: &[Expression],
262 sources: &HashMap<String, SourceInfo>,
263) -> Expression {
264 let mut pushdown_tables: HashSet<String> = HashSet::new();
267
268 for a in predicates {
269 let a_tables: HashSet<String> = get_column_table_names(a).into_iter().collect();
270
271 let common: HashSet<String> = predicates.iter().fold(a_tables, |acc, b| {
272 let b_tables: HashSet<String> = get_column_table_names(b).into_iter().collect();
273 acc.intersection(&b_tables).cloned().collect()
274 });
275
276 pushdown_tables.extend(common);
277 }
278
279 let mut result = expression;
280
281 let mut conditions: HashMap<String, Expression> = HashMap::new();
283
284 for table in &pushdown_tables {
285 for predicate in predicates {
286 let nodes = nodes_for_predicate(predicate, sources);
287
288 if nodes.contains_key(table) {
289 let existing = conditions.remove(table);
290 conditions.insert(
291 table.clone(),
292 if let Some(existing) = existing {
293 make_or(existing, predicate.clone())
294 } else {
295 predicate.clone()
296 },
297 );
298 }
299 }
300 }
301
302 for (table, condition) in conditions {
304 if let Some(source_info) = sources.get(&table) {
305 result = push_predicate_to_node(&result, &condition, &source_info.expression);
306 }
307 }
308
309 result
310}
311
312fn nodes_for_predicate(
314 predicate: &Expression,
315 sources: &HashMap<String, SourceInfo>,
316) -> HashMap<String, Expression> {
317 let mut nodes = HashMap::new();
318 let tables = get_column_table_names(predicate);
319
320 for table in tables {
321 if let Some(source_info) = sources.get(&table) {
322 nodes.insert(table, source_info.expression.clone());
329 }
330 }
331
332 nodes
333}
334
335fn push_predicate_to_node(
337 expression: &Expression,
338 _predicate: &Expression,
339 _target_node: &Expression,
340) -> Expression {
341 expression.clone()
348}
349
350fn get_column_table_names(expr: &Expression) -> Vec<String> {
352 let mut tables = Vec::new();
353 collect_column_tables(expr, &mut tables);
354 tables
355}
356
357fn collect_column_tables(expr: &Expression, tables: &mut Vec<String>) {
358 match expr {
359 Expression::Column(col) => {
360 if let Some(ref table) = col.table {
361 tables.push(table.name.clone());
362 }
363 }
364 Expression::And(bin) | Expression::Or(bin) => {
365 collect_column_tables(&bin.left, tables);
366 collect_column_tables(&bin.right, tables);
367 }
368 Expression::Eq(bin)
369 | Expression::Neq(bin)
370 | Expression::Lt(bin)
371 | Expression::Lte(bin)
372 | Expression::Gt(bin)
373 | Expression::Gte(bin) => {
374 collect_column_tables(&bin.left, tables);
375 collect_column_tables(&bin.right, tables);
376 }
377 Expression::Not(un) => {
378 collect_column_tables(&un.this, tables);
379 }
380 Expression::Paren(p) => {
381 collect_column_tables(&p.this, tables);
382 }
383 Expression::In(in_expr) => {
384 collect_column_tables(&in_expr.this, tables);
385 for e in &in_expr.expressions {
386 collect_column_tables(e, tables);
387 }
388 }
389 Expression::Between(between) => {
390 collect_column_tables(&between.this, tables);
391 collect_column_tables(&between.low, tables);
392 collect_column_tables(&between.high, tables);
393 }
394 Expression::IsNull(is_null) => {
395 collect_column_tables(&is_null.this, tables);
396 }
397 Expression::Like(like) => {
398 collect_column_tables(&like.left, tables);
399 collect_column_tables(&like.right, tables);
400 }
401 Expression::Function(func) => {
402 for arg in &func.args {
403 collect_column_tables(arg, tables);
404 }
405 }
406 Expression::AggregateFunction(agg) => {
407 for arg in &agg.args {
408 collect_column_tables(arg, tables);
409 }
410 }
411 _ => {}
412 }
413}
414
415fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
417 match expr {
418 Expression::Table(table) => {
419 if let Some(ref alias) = table.alias {
420 Some(alias.name.clone())
421 } else {
422 Some(table.name.name.clone())
423 }
424 }
425 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
426 Expression::Alias(alias) => Some(alias.alias.name.clone()),
427 _ => None,
428 }
429}
430
431fn make_or(left: Expression, right: Expression) -> Expression {
433 Expression::Or(Box::new(crate::expressions::BinaryOp {
434 left,
435 right,
436 left_comments: vec![],
437 operator_comments: vec![],
438 trailing_comments: vec![],
439 inferred_type: None,
440 }))
441}
442
443pub fn replace_aliases(source: &Expression, predicate: Expression) -> Expression {
445 let mut aliases: HashMap<String, Expression> = HashMap::new();
447
448 if let Expression::Select(select) = source {
449 for select_expr in &select.expressions {
450 match select_expr {
451 Expression::Alias(alias) => {
452 aliases.insert(alias.alias.name.clone(), alias.this.clone());
453 }
454 Expression::Column(col) => {
455 aliases.insert(col.name.name.clone(), select_expr.clone());
456 }
457 _ => {}
458 }
459 }
460 }
461
462 replace_aliases_recursive(predicate, &aliases)
464}
465
466fn replace_aliases_recursive(
467 expr: Expression,
468 aliases: &HashMap<String, Expression>,
469) -> Expression {
470 match expr {
471 Expression::Column(col) => {
472 if let Some(replacement) = aliases.get(&col.name.name) {
473 replacement.clone()
474 } else {
475 Expression::Column(col)
476 }
477 }
478 Expression::And(bin) => {
479 let left = replace_aliases_recursive(bin.left, aliases);
480 let right = replace_aliases_recursive(bin.right, aliases);
481 Expression::And(Box::new(crate::expressions::BinaryOp {
482 left,
483 right,
484 left_comments: bin.left_comments,
485 operator_comments: bin.operator_comments,
486 trailing_comments: bin.trailing_comments,
487 inferred_type: None,
488 }))
489 }
490 Expression::Or(bin) => {
491 let left = replace_aliases_recursive(bin.left, aliases);
492 let right = replace_aliases_recursive(bin.right, aliases);
493 Expression::Or(Box::new(crate::expressions::BinaryOp {
494 left,
495 right,
496 left_comments: bin.left_comments,
497 operator_comments: bin.operator_comments,
498 trailing_comments: bin.trailing_comments,
499 inferred_type: None,
500 }))
501 }
502 Expression::Eq(bin) => {
503 let left = replace_aliases_recursive(bin.left, aliases);
504 let right = replace_aliases_recursive(bin.right, aliases);
505 Expression::Eq(Box::new(crate::expressions::BinaryOp {
506 left,
507 right,
508 left_comments: bin.left_comments,
509 operator_comments: bin.operator_comments,
510 trailing_comments: bin.trailing_comments,
511 inferred_type: None,
512 }))
513 }
514 Expression::Neq(bin) => {
515 let left = replace_aliases_recursive(bin.left, aliases);
516 let right = replace_aliases_recursive(bin.right, aliases);
517 Expression::Neq(Box::new(crate::expressions::BinaryOp {
518 left,
519 right,
520 left_comments: bin.left_comments,
521 operator_comments: bin.operator_comments,
522 trailing_comments: bin.trailing_comments,
523 inferred_type: None,
524 }))
525 }
526 Expression::Lt(bin) => {
527 let left = replace_aliases_recursive(bin.left, aliases);
528 let right = replace_aliases_recursive(bin.right, aliases);
529 Expression::Lt(Box::new(crate::expressions::BinaryOp {
530 left,
531 right,
532 left_comments: bin.left_comments,
533 operator_comments: bin.operator_comments,
534 trailing_comments: bin.trailing_comments,
535 inferred_type: None,
536 }))
537 }
538 Expression::Gt(bin) => {
539 let left = replace_aliases_recursive(bin.left, aliases);
540 let right = replace_aliases_recursive(bin.right, aliases);
541 Expression::Gt(Box::new(crate::expressions::BinaryOp {
542 left,
543 right,
544 left_comments: bin.left_comments,
545 operator_comments: bin.operator_comments,
546 trailing_comments: bin.trailing_comments,
547 inferred_type: None,
548 }))
549 }
550 Expression::Lte(bin) => {
551 let left = replace_aliases_recursive(bin.left, aliases);
552 let right = replace_aliases_recursive(bin.right, aliases);
553 Expression::Lte(Box::new(crate::expressions::BinaryOp {
554 left,
555 right,
556 left_comments: bin.left_comments,
557 operator_comments: bin.operator_comments,
558 trailing_comments: bin.trailing_comments,
559 inferred_type: None,
560 }))
561 }
562 Expression::Gte(bin) => {
563 let left = replace_aliases_recursive(bin.left, aliases);
564 let right = replace_aliases_recursive(bin.right, aliases);
565 Expression::Gte(Box::new(crate::expressions::BinaryOp {
566 left,
567 right,
568 left_comments: bin.left_comments,
569 operator_comments: bin.operator_comments,
570 trailing_comments: bin.trailing_comments,
571 inferred_type: None,
572 }))
573 }
574 Expression::Not(un) => {
575 let inner = replace_aliases_recursive(un.this, aliases);
576 Expression::Not(Box::new(crate::expressions::UnaryOp {
577 this: inner,
578 inferred_type: None,
579 }))
580 }
581 Expression::Paren(paren) => {
582 let inner = replace_aliases_recursive(paren.this, aliases);
583 Expression::Paren(Box::new(crate::expressions::Paren {
584 this: inner,
585 trailing_comments: paren.trailing_comments,
586 }))
587 }
588 other => other,
589 }
590}
591
592pub fn make_true() -> Expression {
594 Expression::Boolean(BooleanLiteral { value: true })
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql` and returns the condition of its WHERE clause.
    ///
    /// Panics when the statement is not a `SELECT` or has no WHERE clause, so
    /// a test can never pass without actually exercising its assertions.
    fn where_condition(sql: &str) -> Expression {
        let statement = parse(sql);
        match &statement {
            Expression::Select(select) => match &select.where_clause {
                Some(where_clause) => where_clause.this.clone(),
                None => panic!("expected a WHERE clause in: {}", sql),
            },
            _ => panic!("expected a SELECT statement: {}", sql),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t WHERE a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_column_table_names() {
        let condition = where_condition("SELECT 1 WHERE t.a = 1 AND s.b = 2");
        let tables = get_column_table_names(&condition);
        assert!(tables.contains(&"t".to_string()));
        assert!(tables.contains(&"s".to_string()));
    }

    #[test]
    fn test_flatten_and() {
        let condition = where_condition("SELECT 1 WHERE a = 1 AND b = 2 AND c = 3");
        let predicates = flatten_and(&condition);
        assert_eq!(predicates.len(), 3);
    }

    #[test]
    fn test_flatten_or() {
        let condition = where_condition("SELECT 1 WHERE a = 1 OR b = 2 OR c = 3");
        let predicates = flatten_or(&condition);
        assert_eq!(predicates.len(), 3);
    }

    #[test]
    fn test_replace_aliases() {
        let source = parse("SELECT x.a AS col_a FROM x");
        let condition = where_condition("SELECT 1 WHERE col_a = 1");

        let replaced = replace_aliases(&source, condition);
        let sql = gen(&replaced);
        // The comparison itself must survive the rewrite.
        assert!(sql.contains("="));
    }

    #[test]
    fn test_pushdown_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id WHERE t.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_pushdown_complex_and() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2 AND c < 3");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("AND"));
    }

    #[test]
    fn test_pushdown_complex_or() {
        let expr = parse("SELECT 1 WHERE a = 1 OR b = 2");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("OR"));
    }

    #[test]
    fn test_normalized_dnf_simple() {
        let condition = where_condition("SELECT 1 WHERE a = 1");
        assert!(normalized(&condition, true));
    }

    #[test]
    fn test_make_true() {
        let t = make_true();
        let sql = gen(&t);
        assert_eq!(sql, "TRUE");
    }
}