1use std::collections::{HashMap, HashSet};
12
13use crate::dialects::DialectType;
14use crate::expressions::{BooleanLiteral, Expression};
15use crate::optimizer::normalize::normalized;
16use crate::optimizer::simplify::simplify;
17use crate::scope::{build_scope, Scope, SourceInfo};
18
19pub fn pushdown_predicates(expression: Expression, dialect: Option<DialectType>) -> Expression {
37 let root = build_scope(&expression);
38 let scope_ref_count = compute_ref_count(&root);
39
40 let unnest_requires_cross_join = matches!(
42 dialect,
43 Some(DialectType::Presto) | Some(DialectType::Trino) | Some(DialectType::Athena)
44 );
45
46 let mut result = expression.clone();
48 let scopes = collect_scopes(&root);
49
50 for scope in scopes.iter().rev() {
51 result = process_scope(
52 &result,
53 scope,
54 &scope_ref_count,
55 dialect,
56 unnest_requires_cross_join,
57 );
58 }
59
60 result
61}
62
63fn collect_scopes(root: &Scope) -> Vec<Scope> {
65 let mut result = vec![root.clone()];
66 for child in &root.subquery_scopes {
68 result.extend(collect_scopes(child));
69 }
70 for child in &root.derived_table_scopes {
72 result.extend(collect_scopes(child));
73 }
74 for child in &root.cte_scopes {
76 result.extend(collect_scopes(child));
77 }
78 for child in &root.union_scopes {
80 result.extend(collect_scopes(child));
81 }
82 result
83}
84
85fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
87 let mut counts = HashMap::new();
88 compute_ref_count_recursive(root, &mut counts);
89 counts
90}
91
92fn compute_ref_count_recursive(scope: &Scope, counts: &mut HashMap<u64, usize>) {
93 let id = scope as *const Scope as u64;
95 *counts.entry(id).or_insert(0) += 1;
96
97 for child in &scope.subquery_scopes {
98 compute_ref_count_recursive(child, counts);
99 }
100 for child in &scope.derived_table_scopes {
101 compute_ref_count_recursive(child, counts);
102 }
103 for child in &scope.cte_scopes {
104 compute_ref_count_recursive(child, counts);
105 }
106 for child in &scope.union_scopes {
107 compute_ref_count_recursive(child, counts);
108 }
109}
110
111fn process_scope(
113 expression: &Expression,
114 scope: &Scope,
115 _scope_ref_count: &HashMap<u64, usize>,
116 dialect: Option<DialectType>,
117 _unnest_requires_cross_join: bool,
118) -> Expression {
119 let result = expression.clone();
120
121 let (where_condition, join_conditions, join_index) = if let Expression::Select(select) = &result
123 {
124 let where_cond = select.where_clause.as_ref().map(|w| w.this.clone());
125
126 let mut idx: HashMap<String, usize> = HashMap::new();
127 for (i, join) in select.joins.iter().enumerate() {
128 if let Some(name) = get_table_alias_or_name(&join.this) {
129 idx.insert(name, i);
130 }
131 }
132
133 let join_conds: Vec<Expression> =
134 select.joins.iter().filter_map(|j| j.on.clone()).collect();
135
136 (where_cond, join_conds, idx)
137 } else {
138 (None, vec![], HashMap::new())
139 };
140
141 let mut result = result;
142
143 if let Some(where_cond) = where_condition {
145 let simplified = simplify(where_cond, dialect);
146 result = pushdown_impl(
147 result,
148 &simplified,
149 &scope.sources,
150 dialect,
151 Some(&join_index),
152 );
153 }
154
155 for join_cond in join_conditions {
157 let simplified = simplify(join_cond, dialect);
158 result = pushdown_impl(result, &simplified, &scope.sources, dialect, None);
159 }
160
161 result
162}
163
164fn pushdown_impl(
166 expression: Expression,
167 condition: &Expression,
168 sources: &HashMap<String, SourceInfo>,
169 _dialect: Option<DialectType>,
170 join_index: Option<&HashMap<String, usize>>,
171) -> Expression {
172 let is_cnf = normalized(condition, false); let is_dnf = normalized(condition, true); let cnf_like = is_cnf || !is_dnf;
176
177 let predicates = flatten_predicates(condition, cnf_like);
179
180 if cnf_like {
181 pushdown_cnf(expression, &predicates, sources, join_index)
182 } else {
183 pushdown_dnf(expression, &predicates, sources)
184 }
185}
186
187fn flatten_predicates(expr: &Expression, cnf_like: bool) -> Vec<Expression> {
189 if cnf_like {
190 flatten_and(expr)
192 } else {
193 flatten_or(expr)
195 }
196}
197
198fn flatten_and(expr: &Expression) -> Vec<Expression> {
199 match expr {
200 Expression::And(bin) => {
201 let mut result = flatten_and(&bin.left);
202 result.extend(flatten_and(&bin.right));
203 result
204 }
205 Expression::Paren(p) => flatten_and(&p.this),
206 other => vec![other.clone()],
207 }
208}
209
210fn flatten_or(expr: &Expression) -> Vec<Expression> {
211 match expr {
212 Expression::Or(bin) => {
213 let mut result = flatten_or(&bin.left);
214 result.extend(flatten_or(&bin.right));
215 result
216 }
217 Expression::Paren(p) => flatten_or(&p.this),
218 other => vec![other.clone()],
219 }
220}
221
222fn pushdown_cnf(
224 expression: Expression,
225 predicates: &[Expression],
226 sources: &HashMap<String, SourceInfo>,
227 join_index: Option<&HashMap<String, usize>>,
228) -> Expression {
229 let mut result = expression;
230
231 for predicate in predicates {
232 let nodes = nodes_for_predicate(predicate, sources);
233
234 for (table_name, node_expr) in nodes {
235 if let Some(join_idx) = join_index {
237 if let Some(&this_index) = join_idx.get(&table_name) {
238 let predicate_tables = get_column_table_names(predicate);
239
240 let can_push = predicate_tables
242 .iter()
243 .all(|t| join_idx.get(t).map_or(true, |&idx| idx <= this_index));
244
245 if can_push {
246 result = push_predicate_to_node(&result, predicate, &node_expr);
247 }
248 }
249 } else {
250 result = push_predicate_to_node(&result, predicate, &node_expr);
251 }
252 }
253 }
254
255 result
256}
257
258fn pushdown_dnf(
260 expression: Expression,
261 predicates: &[Expression],
262 sources: &HashMap<String, SourceInfo>,
263) -> Expression {
264 let mut pushdown_tables: HashSet<String> = HashSet::new();
267
268 for a in predicates {
269 let a_tables: HashSet<String> = get_column_table_names(a).into_iter().collect();
270
271 let common: HashSet<String> = predicates.iter().fold(a_tables, |acc, b| {
272 let b_tables: HashSet<String> = get_column_table_names(b).into_iter().collect();
273 acc.intersection(&b_tables).cloned().collect()
274 });
275
276 pushdown_tables.extend(common);
277 }
278
279 let mut result = expression;
280
281 let mut conditions: HashMap<String, Expression> = HashMap::new();
283
284 for table in &pushdown_tables {
285 for predicate in predicates {
286 let nodes = nodes_for_predicate(predicate, sources);
287
288 if nodes.contains_key(table) {
289 let existing = conditions.remove(table);
290 conditions.insert(
291 table.clone(),
292 if let Some(existing) = existing {
293 make_or(existing, predicate.clone())
294 } else {
295 predicate.clone()
296 },
297 );
298 }
299 }
300 }
301
302 for (table, condition) in conditions {
304 if let Some(source_info) = sources.get(&table) {
305 result = push_predicate_to_node(&result, &condition, &source_info.expression);
306 }
307 }
308
309 result
310}
311
312fn nodes_for_predicate(
314 predicate: &Expression,
315 sources: &HashMap<String, SourceInfo>,
316) -> HashMap<String, Expression> {
317 let mut nodes = HashMap::new();
318 let tables = get_column_table_names(predicate);
319
320 for table in tables {
321 if let Some(source_info) = sources.get(&table) {
322 nodes.insert(table, source_info.expression.clone());
329 }
330 }
331
332 nodes
333}
334
335fn push_predicate_to_node(
337 expression: &Expression,
338 _predicate: &Expression,
339 _target_node: &Expression,
340) -> Expression {
341 expression.clone()
348}
349
350fn get_column_table_names(expr: &Expression) -> Vec<String> {
352 let mut tables = Vec::new();
353 collect_column_tables(expr, &mut tables);
354 tables
355}
356
357fn collect_column_tables(expr: &Expression, tables: &mut Vec<String>) {
358 match expr {
359 Expression::Column(col) => {
360 if let Some(ref table) = col.table {
361 tables.push(table.name.clone());
362 }
363 }
364 Expression::And(bin) | Expression::Or(bin) => {
365 collect_column_tables(&bin.left, tables);
366 collect_column_tables(&bin.right, tables);
367 }
368 Expression::Eq(bin)
369 | Expression::Neq(bin)
370 | Expression::Lt(bin)
371 | Expression::Lte(bin)
372 | Expression::Gt(bin)
373 | Expression::Gte(bin) => {
374 collect_column_tables(&bin.left, tables);
375 collect_column_tables(&bin.right, tables);
376 }
377 Expression::Not(un) => {
378 collect_column_tables(&un.this, tables);
379 }
380 Expression::Paren(p) => {
381 collect_column_tables(&p.this, tables);
382 }
383 Expression::In(in_expr) => {
384 collect_column_tables(&in_expr.this, tables);
385 for e in &in_expr.expressions {
386 collect_column_tables(e, tables);
387 }
388 }
389 Expression::Between(between) => {
390 collect_column_tables(&between.this, tables);
391 collect_column_tables(&between.low, tables);
392 collect_column_tables(&between.high, tables);
393 }
394 Expression::IsNull(is_null) => {
395 collect_column_tables(&is_null.this, tables);
396 }
397 Expression::Like(like) => {
398 collect_column_tables(&like.left, tables);
399 collect_column_tables(&like.right, tables);
400 }
401 Expression::Function(func) => {
402 for arg in &func.args {
403 collect_column_tables(arg, tables);
404 }
405 }
406 Expression::AggregateFunction(agg) => {
407 for arg in &agg.args {
408 collect_column_tables(arg, tables);
409 }
410 }
411 _ => {}
412 }
413}
414
415fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
417 match expr {
418 Expression::Table(table) => {
419 if let Some(ref alias) = table.alias {
420 Some(alias.name.clone())
421 } else {
422 Some(table.name.name.clone())
423 }
424 }
425 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
426 Expression::Alias(alias) => Some(alias.alias.name.clone()),
427 _ => None,
428 }
429}
430
431fn make_or(left: Expression, right: Expression) -> Expression {
433 Expression::Or(Box::new(crate::expressions::BinaryOp {
434 left,
435 right,
436 left_comments: vec![],
437 operator_comments: vec![],
438 trailing_comments: vec![],
439 }))
440}
441
442pub fn replace_aliases(source: &Expression, predicate: Expression) -> Expression {
444 let mut aliases: HashMap<String, Expression> = HashMap::new();
446
447 if let Expression::Select(select) = source {
448 for select_expr in &select.expressions {
449 match select_expr {
450 Expression::Alias(alias) => {
451 aliases.insert(alias.alias.name.clone(), alias.this.clone());
452 }
453 Expression::Column(col) => {
454 aliases.insert(col.name.name.clone(), select_expr.clone());
455 }
456 _ => {}
457 }
458 }
459 }
460
461 replace_aliases_recursive(predicate, &aliases)
463}
464
465fn replace_aliases_recursive(
466 expr: Expression,
467 aliases: &HashMap<String, Expression>,
468) -> Expression {
469 match expr {
470 Expression::Column(col) => {
471 if let Some(replacement) = aliases.get(&col.name.name) {
472 replacement.clone()
473 } else {
474 Expression::Column(col)
475 }
476 }
477 Expression::And(bin) => {
478 let left = replace_aliases_recursive(bin.left, aliases);
479 let right = replace_aliases_recursive(bin.right, aliases);
480 Expression::And(Box::new(crate::expressions::BinaryOp {
481 left,
482 right,
483 left_comments: bin.left_comments,
484 operator_comments: bin.operator_comments,
485 trailing_comments: bin.trailing_comments,
486 }))
487 }
488 Expression::Or(bin) => {
489 let left = replace_aliases_recursive(bin.left, aliases);
490 let right = replace_aliases_recursive(bin.right, aliases);
491 Expression::Or(Box::new(crate::expressions::BinaryOp {
492 left,
493 right,
494 left_comments: bin.left_comments,
495 operator_comments: bin.operator_comments,
496 trailing_comments: bin.trailing_comments,
497 }))
498 }
499 Expression::Eq(bin) => {
500 let left = replace_aliases_recursive(bin.left, aliases);
501 let right = replace_aliases_recursive(bin.right, aliases);
502 Expression::Eq(Box::new(crate::expressions::BinaryOp {
503 left,
504 right,
505 left_comments: bin.left_comments,
506 operator_comments: bin.operator_comments,
507 trailing_comments: bin.trailing_comments,
508 }))
509 }
510 Expression::Neq(bin) => {
511 let left = replace_aliases_recursive(bin.left, aliases);
512 let right = replace_aliases_recursive(bin.right, aliases);
513 Expression::Neq(Box::new(crate::expressions::BinaryOp {
514 left,
515 right,
516 left_comments: bin.left_comments,
517 operator_comments: bin.operator_comments,
518 trailing_comments: bin.trailing_comments,
519 }))
520 }
521 Expression::Lt(bin) => {
522 let left = replace_aliases_recursive(bin.left, aliases);
523 let right = replace_aliases_recursive(bin.right, aliases);
524 Expression::Lt(Box::new(crate::expressions::BinaryOp {
525 left,
526 right,
527 left_comments: bin.left_comments,
528 operator_comments: bin.operator_comments,
529 trailing_comments: bin.trailing_comments,
530 }))
531 }
532 Expression::Gt(bin) => {
533 let left = replace_aliases_recursive(bin.left, aliases);
534 let right = replace_aliases_recursive(bin.right, aliases);
535 Expression::Gt(Box::new(crate::expressions::BinaryOp {
536 left,
537 right,
538 left_comments: bin.left_comments,
539 operator_comments: bin.operator_comments,
540 trailing_comments: bin.trailing_comments,
541 }))
542 }
543 Expression::Lte(bin) => {
544 let left = replace_aliases_recursive(bin.left, aliases);
545 let right = replace_aliases_recursive(bin.right, aliases);
546 Expression::Lte(Box::new(crate::expressions::BinaryOp {
547 left,
548 right,
549 left_comments: bin.left_comments,
550 operator_comments: bin.operator_comments,
551 trailing_comments: bin.trailing_comments,
552 }))
553 }
554 Expression::Gte(bin) => {
555 let left = replace_aliases_recursive(bin.left, aliases);
556 let right = replace_aliases_recursive(bin.right, aliases);
557 Expression::Gte(Box::new(crate::expressions::BinaryOp {
558 left,
559 right,
560 left_comments: bin.left_comments,
561 operator_comments: bin.operator_comments,
562 trailing_comments: bin.trailing_comments,
563 }))
564 }
565 Expression::Not(un) => {
566 let inner = replace_aliases_recursive(un.this, aliases);
567 Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
568 }
569 Expression::Paren(paren) => {
570 let inner = replace_aliases_recursive(paren.this, aliases);
571 Expression::Paren(Box::new(crate::expressions::Paren {
572 this: inner,
573 trailing_comments: paren.trailing_comments,
574 }))
575 }
576 other => other,
577 }
578}
579
580pub fn make_true() -> Expression {
582 Expression::Boolean(BooleanLiteral { value: true })
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns a copy of its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Returns the WHERE condition of a `SELECT`, or `None` when `expr` is
    /// not a `SELECT` or has no WHERE clause.
    fn where_condition(expr: &Expression) -> Option<Expression> {
        match expr {
            Expression::Select(select) => select
                .where_clause
                .as_ref()
                .map(|clause| clause.this.clone()),
            _ => None,
        }
    }

    /// Runs the pass with no dialect and renders the result.
    fn pushed_down_sql(sql: &str) -> String {
        to_sql(&pushdown_predicates(parse(sql), None))
    }

    #[test]
    fn test_pushdown_simple() {
        let sql = pushed_down_sql("SELECT a FROM t WHERE a = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let sql = pushed_down_sql("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_column_table_names() {
        let expr = parse("SELECT 1 WHERE t.a = 1 AND s.b = 2");
        if let Some(condition) = where_condition(&expr) {
            let tables = get_column_table_names(&condition);
            assert!(tables.iter().any(|name| name == "t"));
            assert!(tables.iter().any(|name| name == "s"));
        }
    }

    #[test]
    fn test_flatten_and() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b = 2 AND c = 3");
        if let Some(condition) = where_condition(&expr) {
            assert_eq!(flatten_and(&condition).len(), 3);
        }
    }

    #[test]
    fn test_flatten_or() {
        let expr = parse("SELECT 1 WHERE a = 1 OR b = 2 OR c = 3");
        if let Some(condition) = where_condition(&expr) {
            assert_eq!(flatten_or(&condition).len(), 3);
        }
    }

    #[test]
    fn test_replace_aliases() {
        let source = parse("SELECT x.a AS col_a FROM x");
        let predicate = parse("SELECT 1 WHERE col_a = 1");

        if let Some(condition) = where_condition(&predicate) {
            let replaced = replace_aliases(&source, condition);
            assert!(to_sql(&replaced).contains("="));
        }
    }

    #[test]
    fn test_pushdown_with_join() {
        let sql = pushed_down_sql("SELECT t.a FROM t JOIN s ON t.id = s.id WHERE t.a = 1");
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_pushdown_complex_and() {
        let sql = pushed_down_sql("SELECT 1 WHERE a = 1 AND b > 2 AND c < 3");
        assert!(sql.contains("AND"));
    }

    #[test]
    fn test_pushdown_complex_or() {
        let sql = pushed_down_sql("SELECT 1 WHERE a = 1 OR b = 2");
        assert!(sql.contains("OR"));
    }

    #[test]
    fn test_normalized_dnf_simple() {
        let expr = parse("SELECT 1 WHERE a = 1");
        if let Some(condition) = where_condition(&expr) {
            assert!(normalized(&condition, true));
        }
    }

    #[test]
    fn test_make_true() {
        assert_eq!(to_sql(&make_true()), "TRUE");
    }
}