1use std::collections::{HashMap, HashSet};
12
13use crate::dialects::DialectType;
14use crate::expressions::{BooleanLiteral, Expression};
15use crate::optimizer::normalize::normalized;
16use crate::optimizer::simplify::simplify;
17use crate::scope::{build_scope, Scope, SourceInfo};
18
19pub fn pushdown_predicates(
37 expression: Expression,
38 dialect: Option<DialectType>,
39) -> Expression {
40 let root = build_scope(&expression);
41 let scope_ref_count = compute_ref_count(&root);
42
43 let unnest_requires_cross_join = matches!(
45 dialect,
46 Some(DialectType::Presto) | Some(DialectType::Trino) | Some(DialectType::Athena)
47 );
48
49 let mut result = expression.clone();
51 let scopes = collect_scopes(&root);
52
53 for scope in scopes.iter().rev() {
54 result = process_scope(&result, scope, &scope_ref_count, dialect, unnest_requires_cross_join);
55 }
56
57 result
58}
59
60fn collect_scopes(root: &Scope) -> Vec<Scope> {
62 let mut result = vec![root.clone()];
63 for child in &root.subquery_scopes {
65 result.extend(collect_scopes(child));
66 }
67 for child in &root.derived_table_scopes {
69 result.extend(collect_scopes(child));
70 }
71 for child in &root.cte_scopes {
73 result.extend(collect_scopes(child));
74 }
75 for child in &root.union_scopes {
77 result.extend(collect_scopes(child));
78 }
79 result
80}
81
82fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
84 let mut counts = HashMap::new();
85 compute_ref_count_recursive(root, &mut counts);
86 counts
87}
88
89fn compute_ref_count_recursive(scope: &Scope, counts: &mut HashMap<u64, usize>) {
90 let id = scope as *const Scope as u64;
92 *counts.entry(id).or_insert(0) += 1;
93
94 for child in &scope.subquery_scopes {
95 compute_ref_count_recursive(child, counts);
96 }
97 for child in &scope.derived_table_scopes {
98 compute_ref_count_recursive(child, counts);
99 }
100 for child in &scope.cte_scopes {
101 compute_ref_count_recursive(child, counts);
102 }
103 for child in &scope.union_scopes {
104 compute_ref_count_recursive(child, counts);
105 }
106}
107
108fn process_scope(
110 expression: &Expression,
111 scope: &Scope,
112 _scope_ref_count: &HashMap<u64, usize>,
113 dialect: Option<DialectType>,
114 _unnest_requires_cross_join: bool,
115) -> Expression {
116 let result = expression.clone();
117
118 let (where_condition, join_conditions, join_index) = if let Expression::Select(select) = &result {
120 let where_cond = select.where_clause.as_ref().map(|w| w.this.clone());
121
122 let mut idx: HashMap<String, usize> = HashMap::new();
123 for (i, join) in select.joins.iter().enumerate() {
124 if let Some(name) = get_table_alias_or_name(&join.this) {
125 idx.insert(name, i);
126 }
127 }
128
129 let join_conds: Vec<Expression> = select.joins.iter()
130 .filter_map(|j| j.on.clone())
131 .collect();
132
133 (where_cond, join_conds, idx)
134 } else {
135 (None, vec![], HashMap::new())
136 };
137
138 let mut result = result;
139
140 if let Some(where_cond) = where_condition {
142 let simplified = simplify(where_cond, dialect);
143 result = pushdown_impl(
144 result,
145 &simplified,
146 &scope.sources,
147 dialect,
148 Some(&join_index),
149 );
150 }
151
152 for join_cond in join_conditions {
154 let simplified = simplify(join_cond, dialect);
155 result = pushdown_impl(
156 result,
157 &simplified,
158 &scope.sources,
159 dialect,
160 None,
161 );
162 }
163
164 result
165}
166
167fn pushdown_impl(
169 expression: Expression,
170 condition: &Expression,
171 sources: &HashMap<String, SourceInfo>,
172 _dialect: Option<DialectType>,
173 join_index: Option<&HashMap<String, usize>>,
174) -> Expression {
175 let is_cnf = normalized(condition, false); let is_dnf = normalized(condition, true); let cnf_like = is_cnf || !is_dnf;
179
180 let predicates = flatten_predicates(condition, cnf_like);
182
183 if cnf_like {
184 pushdown_cnf(expression, &predicates, sources, join_index)
185 } else {
186 pushdown_dnf(expression, &predicates, sources)
187 }
188}
189
190fn flatten_predicates(expr: &Expression, cnf_like: bool) -> Vec<Expression> {
192 if cnf_like {
193 flatten_and(expr)
195 } else {
196 flatten_or(expr)
198 }
199}
200
201fn flatten_and(expr: &Expression) -> Vec<Expression> {
202 match expr {
203 Expression::And(bin) => {
204 let mut result = flatten_and(&bin.left);
205 result.extend(flatten_and(&bin.right));
206 result
207 }
208 Expression::Paren(p) => flatten_and(&p.this),
209 other => vec![other.clone()],
210 }
211}
212
213fn flatten_or(expr: &Expression) -> Vec<Expression> {
214 match expr {
215 Expression::Or(bin) => {
216 let mut result = flatten_or(&bin.left);
217 result.extend(flatten_or(&bin.right));
218 result
219 }
220 Expression::Paren(p) => flatten_or(&p.this),
221 other => vec![other.clone()],
222 }
223}
224
225fn pushdown_cnf(
227 expression: Expression,
228 predicates: &[Expression],
229 sources: &HashMap<String, SourceInfo>,
230 join_index: Option<&HashMap<String, usize>>,
231) -> Expression {
232 let mut result = expression;
233
234 for predicate in predicates {
235 let nodes = nodes_for_predicate(predicate, sources);
236
237 for (table_name, node_expr) in nodes {
238 if let Some(join_idx) = join_index {
240 if let Some(&this_index) = join_idx.get(&table_name) {
241 let predicate_tables = get_column_table_names(predicate);
242
243 let can_push = predicate_tables.iter().all(|t| {
245 join_idx.get(t).map_or(true, |&idx| idx <= this_index)
246 });
247
248 if can_push {
249 result = push_predicate_to_node(&result, predicate, &node_expr);
250 }
251 }
252 } else {
253 result = push_predicate_to_node(&result, predicate, &node_expr);
254 }
255 }
256 }
257
258 result
259}
260
261fn pushdown_dnf(
263 expression: Expression,
264 predicates: &[Expression],
265 sources: &HashMap<String, SourceInfo>,
266) -> Expression {
267 let mut pushdown_tables: HashSet<String> = HashSet::new();
270
271 for a in predicates {
272 let a_tables: HashSet<String> = get_column_table_names(a).into_iter().collect();
273
274 let common: HashSet<String> = predicates
275 .iter()
276 .fold(a_tables, |acc, b| {
277 let b_tables: HashSet<String> = get_column_table_names(b).into_iter().collect();
278 acc.intersection(&b_tables).cloned().collect()
279 });
280
281 pushdown_tables.extend(common);
282 }
283
284 let mut result = expression;
285
286 let mut conditions: HashMap<String, Expression> = HashMap::new();
288
289 for table in &pushdown_tables {
290 for predicate in predicates {
291 let nodes = nodes_for_predicate(predicate, sources);
292
293 if nodes.contains_key(table) {
294 let existing = conditions.remove(table);
295 conditions.insert(
296 table.clone(),
297 if let Some(existing) = existing {
298 make_or(existing, predicate.clone())
299 } else {
300 predicate.clone()
301 },
302 );
303 }
304 }
305 }
306
307 for (table, condition) in conditions {
309 if let Some(source_info) = sources.get(&table) {
310 result = push_predicate_to_node(&result, &condition, &source_info.expression);
311 }
312 }
313
314 result
315}
316
317fn nodes_for_predicate(
319 predicate: &Expression,
320 sources: &HashMap<String, SourceInfo>,
321) -> HashMap<String, Expression> {
322 let mut nodes = HashMap::new();
323 let tables = get_column_table_names(predicate);
324
325 for table in tables {
326 if let Some(source_info) = sources.get(&table) {
327 nodes.insert(table, source_info.expression.clone());
334 }
335 }
336
337 nodes
338}
339
340fn push_predicate_to_node(
342 expression: &Expression,
343 _predicate: &Expression,
344 _target_node: &Expression,
345) -> Expression {
346 expression.clone()
353}
354
355fn get_column_table_names(expr: &Expression) -> Vec<String> {
357 let mut tables = Vec::new();
358 collect_column_tables(expr, &mut tables);
359 tables
360}
361
362fn collect_column_tables(expr: &Expression, tables: &mut Vec<String>) {
363 match expr {
364 Expression::Column(col) => {
365 if let Some(ref table) = col.table {
366 tables.push(table.name.clone());
367 }
368 }
369 Expression::And(bin) | Expression::Or(bin) => {
370 collect_column_tables(&bin.left, tables);
371 collect_column_tables(&bin.right, tables);
372 }
373 Expression::Eq(bin) | Expression::Neq(bin) | Expression::Lt(bin) |
374 Expression::Lte(bin) | Expression::Gt(bin) | Expression::Gte(bin) => {
375 collect_column_tables(&bin.left, tables);
376 collect_column_tables(&bin.right, tables);
377 }
378 Expression::Not(un) => {
379 collect_column_tables(&un.this, tables);
380 }
381 Expression::Paren(p) => {
382 collect_column_tables(&p.this, tables);
383 }
384 Expression::In(in_expr) => {
385 collect_column_tables(&in_expr.this, tables);
386 for e in &in_expr.expressions {
387 collect_column_tables(e, tables);
388 }
389 }
390 Expression::Between(between) => {
391 collect_column_tables(&between.this, tables);
392 collect_column_tables(&between.low, tables);
393 collect_column_tables(&between.high, tables);
394 }
395 Expression::IsNull(is_null) => {
396 collect_column_tables(&is_null.this, tables);
397 }
398 Expression::Like(like) => {
399 collect_column_tables(&like.left, tables);
400 collect_column_tables(&like.right, tables);
401 }
402 Expression::Function(func) => {
403 for arg in &func.args {
404 collect_column_tables(arg, tables);
405 }
406 }
407 Expression::AggregateFunction(agg) => {
408 for arg in &agg.args {
409 collect_column_tables(arg, tables);
410 }
411 }
412 _ => {}
413 }
414}
415
416fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
418 match expr {
419 Expression::Table(table) => {
420 if let Some(ref alias) = table.alias {
421 Some(alias.name.clone())
422 } else {
423 Some(table.name.name.clone())
424 }
425 }
426 Expression::Subquery(subquery) => {
427 subquery.alias.as_ref().map(|a| a.name.clone())
428 }
429 Expression::Alias(alias) => {
430 Some(alias.alias.name.clone())
431 }
432 _ => None,
433 }
434}
435
436fn make_or(left: Expression, right: Expression) -> Expression {
438 Expression::Or(Box::new(crate::expressions::BinaryOp {
439 left,
440 right,
441 left_comments: vec![],
442 operator_comments: vec![],
443 trailing_comments: vec![],
444 }))
445}
446
447pub fn replace_aliases(
449 source: &Expression,
450 predicate: Expression,
451) -> Expression {
452 let mut aliases: HashMap<String, Expression> = HashMap::new();
454
455 if let Expression::Select(select) = source {
456 for select_expr in &select.expressions {
457 match select_expr {
458 Expression::Alias(alias) => {
459 aliases.insert(alias.alias.name.clone(), alias.this.clone());
460 }
461 Expression::Column(col) => {
462 aliases.insert(col.name.name.clone(), select_expr.clone());
463 }
464 _ => {}
465 }
466 }
467 }
468
469 replace_aliases_recursive(predicate, &aliases)
471}
472
473fn replace_aliases_recursive(
474 expr: Expression,
475 aliases: &HashMap<String, Expression>,
476) -> Expression {
477 match expr {
478 Expression::Column(col) => {
479 if let Some(replacement) = aliases.get(&col.name.name) {
480 replacement.clone()
481 } else {
482 Expression::Column(col)
483 }
484 }
485 Expression::And(bin) => {
486 let left = replace_aliases_recursive(bin.left, aliases);
487 let right = replace_aliases_recursive(bin.right, aliases);
488 Expression::And(Box::new(crate::expressions::BinaryOp {
489 left,
490 right,
491 left_comments: bin.left_comments,
492 operator_comments: bin.operator_comments,
493 trailing_comments: bin.trailing_comments,
494 }))
495 }
496 Expression::Or(bin) => {
497 let left = replace_aliases_recursive(bin.left, aliases);
498 let right = replace_aliases_recursive(bin.right, aliases);
499 Expression::Or(Box::new(crate::expressions::BinaryOp {
500 left,
501 right,
502 left_comments: bin.left_comments,
503 operator_comments: bin.operator_comments,
504 trailing_comments: bin.trailing_comments,
505 }))
506 }
507 Expression::Eq(bin) => {
508 let left = replace_aliases_recursive(bin.left, aliases);
509 let right = replace_aliases_recursive(bin.right, aliases);
510 Expression::Eq(Box::new(crate::expressions::BinaryOp {
511 left,
512 right,
513 left_comments: bin.left_comments,
514 operator_comments: bin.operator_comments,
515 trailing_comments: bin.trailing_comments,
516 }))
517 }
518 Expression::Neq(bin) => {
519 let left = replace_aliases_recursive(bin.left, aliases);
520 let right = replace_aliases_recursive(bin.right, aliases);
521 Expression::Neq(Box::new(crate::expressions::BinaryOp {
522 left,
523 right,
524 left_comments: bin.left_comments,
525 operator_comments: bin.operator_comments,
526 trailing_comments: bin.trailing_comments,
527 }))
528 }
529 Expression::Lt(bin) => {
530 let left = replace_aliases_recursive(bin.left, aliases);
531 let right = replace_aliases_recursive(bin.right, aliases);
532 Expression::Lt(Box::new(crate::expressions::BinaryOp {
533 left,
534 right,
535 left_comments: bin.left_comments,
536 operator_comments: bin.operator_comments,
537 trailing_comments: bin.trailing_comments,
538 }))
539 }
540 Expression::Gt(bin) => {
541 let left = replace_aliases_recursive(bin.left, aliases);
542 let right = replace_aliases_recursive(bin.right, aliases);
543 Expression::Gt(Box::new(crate::expressions::BinaryOp {
544 left,
545 right,
546 left_comments: bin.left_comments,
547 operator_comments: bin.operator_comments,
548 trailing_comments: bin.trailing_comments,
549 }))
550 }
551 Expression::Lte(bin) => {
552 let left = replace_aliases_recursive(bin.left, aliases);
553 let right = replace_aliases_recursive(bin.right, aliases);
554 Expression::Lte(Box::new(crate::expressions::BinaryOp {
555 left,
556 right,
557 left_comments: bin.left_comments,
558 operator_comments: bin.operator_comments,
559 trailing_comments: bin.trailing_comments,
560 }))
561 }
562 Expression::Gte(bin) => {
563 let left = replace_aliases_recursive(bin.left, aliases);
564 let right = replace_aliases_recursive(bin.right, aliases);
565 Expression::Gte(Box::new(crate::expressions::BinaryOp {
566 left,
567 right,
568 left_comments: bin.left_comments,
569 operator_comments: bin.operator_comments,
570 trailing_comments: bin.trailing_comments,
571 }))
572 }
573 Expression::Not(un) => {
574 let inner = replace_aliases_recursive(un.this, aliases);
575 Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
576 }
577 Expression::Paren(paren) => {
578 let inner = replace_aliases_recursive(paren.this, aliases);
579 Expression::Paren(Box::new(crate::expressions::Paren {
580 this: inner,
581 trailing_comments: paren.trailing_comments,
582 }))
583 }
584 other => other,
585 }
586}
587
588pub fn make_true() -> Expression {
590 Expression::Boolean(BooleanLiteral { value: true })
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql` and returns its WHERE condition.
    ///
    /// Panics (failing the test) if the statement is not a SELECT with a WHERE clause,
    /// so a parser change can never make the assertions below be skipped silently.
    fn where_condition(sql: &str) -> Expression {
        let expr = parse(sql);
        if let Expression::Select(select) = &expr {
            if let Some(where_clause) = &select.where_clause {
                return where_clause.this.clone();
            }
            panic!("expected a WHERE clause in: {}", sql);
        }
        panic!("expected a SELECT statement: {}", sql);
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t WHERE a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_column_table_names() {
        let condition = where_condition("SELECT 1 WHERE t.a = 1 AND s.b = 2");
        let tables = get_column_table_names(&condition);
        assert!(tables.contains(&"t".to_string()));
        assert!(tables.contains(&"s".to_string()));
    }

    #[test]
    fn test_flatten_and() {
        let condition = where_condition("SELECT 1 WHERE a = 1 AND b = 2 AND c = 3");
        assert_eq!(flatten_and(&condition).len(), 3);
    }

    #[test]
    fn test_flatten_or() {
        let condition = where_condition("SELECT 1 WHERE a = 1 OR b = 2 OR c = 3");
        assert_eq!(flatten_or(&condition).len(), 3);
    }

    #[test]
    fn test_replace_aliases() {
        let source = parse("SELECT x.a AS col_a FROM x");
        let predicate = where_condition("SELECT 1 WHERE col_a = 1");
        let sql = gen(&replace_aliases(&source, predicate));
        assert!(sql.contains("="));
        assert!(!sql.contains("col_a"), "alias should have been replaced: {}", sql);
    }

    #[test]
    fn test_pushdown_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id WHERE t.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_pushdown_complex_and() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2 AND c < 3");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("AND"));
    }

    #[test]
    fn test_pushdown_complex_or() {
        let expr = parse("SELECT 1 WHERE a = 1 OR b = 2");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("OR"));
    }

    #[test]
    fn test_normalized_dnf_simple() {
        let condition = where_condition("SELECT 1 WHERE a = 1");
        assert!(normalized(&condition, true));
    }

    #[test]
    fn test_make_true() {
        let t = make_true();
        let sql = gen(&t);
        assert_eq!(sql, "TRUE");
    }
}