Skip to main content

polyglot_sql/optimizer/
pushdown_predicates.rs

1//! Predicate Pushdown Module
2//!
3//! This module provides functionality for pushing WHERE predicates down
4//! into subqueries and JOINs for better query performance.
5//!
6//! When a predicate in the outer query only references columns from a subquery,
7//! it can be pushed down into that subquery's WHERE clause to filter data earlier.
8//!
9//! Ported from sqlglot's optimizer/pushdown_predicates.py
10
11use std::collections::{HashMap, HashSet};
12
13use crate::dialects::DialectType;
14use crate::expressions::{BooleanLiteral, Expression};
15use crate::optimizer::normalize::normalized;
16use crate::optimizer::simplify::simplify;
17use crate::scope::{build_scope, Scope, SourceInfo};
18
19/// Rewrite SQL AST to pushdown predicates in FROMs and JOINs.
20///
21/// # Example
22///
23/// ```sql
24/// -- Before:
25/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1
26/// -- After:
27/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE
28/// ```
29///
30/// # Arguments
31/// * `expression` - The expression to optimize
32/// * `dialect` - Optional dialect for dialect-specific behavior
33///
34/// # Returns
35/// The optimized expression with predicates pushed down
36pub fn pushdown_predicates(expression: Expression, dialect: Option<DialectType>) -> Expression {
37    let root = build_scope(&expression);
38    let scope_ref_count = compute_ref_count(&root);
39
40    // Check if dialect requires special handling for UNNEST
41    let unnest_requires_cross_join = matches!(
42        dialect,
43        Some(DialectType::Presto) | Some(DialectType::Trino) | Some(DialectType::Athena)
44    );
45
46    // Process scopes in reverse order (bottom-up)
47    let mut result = expression.clone();
48    let scopes = collect_scopes(&root);
49
50    for scope in scopes.iter().rev() {
51        result = process_scope(
52            &result,
53            scope,
54            &scope_ref_count,
55            dialect,
56            unnest_requires_cross_join,
57        );
58    }
59
60    result
61}
62
63/// Collect all scopes from the tree
64fn collect_scopes(root: &Scope) -> Vec<Scope> {
65    let mut result = vec![root.clone()];
66    // Collect from subquery scopes
67    for child in &root.subquery_scopes {
68        result.extend(collect_scopes(child));
69    }
70    // Collect from derived table scopes
71    for child in &root.derived_table_scopes {
72        result.extend(collect_scopes(child));
73    }
74    // Collect from CTE scopes
75    for child in &root.cte_scopes {
76        result.extend(collect_scopes(child));
77    }
78    // Collect from union scopes
79    for child in &root.union_scopes {
80        result.extend(collect_scopes(child));
81    }
82    result
83}
84
85/// Compute reference counts for each scope
86fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
87    let mut counts = HashMap::new();
88    compute_ref_count_recursive(root, &mut counts);
89    counts
90}
91
92fn compute_ref_count_recursive(scope: &Scope, counts: &mut HashMap<u64, usize>) {
93    // Use the pointer address as a pseudo-ID
94    let id = scope as *const Scope as u64;
95    *counts.entry(id).or_insert(0) += 1;
96
97    for child in &scope.subquery_scopes {
98        compute_ref_count_recursive(child, counts);
99    }
100    for child in &scope.derived_table_scopes {
101        compute_ref_count_recursive(child, counts);
102    }
103    for child in &scope.cte_scopes {
104        compute_ref_count_recursive(child, counts);
105    }
106    for child in &scope.union_scopes {
107        compute_ref_count_recursive(child, counts);
108    }
109}
110
111/// Process a single scope for predicate pushdown
112fn process_scope(
113    expression: &Expression,
114    scope: &Scope,
115    _scope_ref_count: &HashMap<u64, usize>,
116    dialect: Option<DialectType>,
117    _unnest_requires_cross_join: bool,
118) -> Expression {
119    let result = expression.clone();
120
121    // Extract data we need before processing
122    let (where_condition, join_conditions, join_index) = if let Expression::Select(select) = &result
123    {
124        let where_cond = select.where_clause.as_ref().map(|w| w.this.clone());
125
126        let mut idx: HashMap<String, usize> = HashMap::new();
127        for (i, join) in select.joins.iter().enumerate() {
128            if let Some(name) = get_table_alias_or_name(&join.this) {
129                idx.insert(name, i);
130            }
131        }
132
133        let join_conds: Vec<Expression> =
134            select.joins.iter().filter_map(|j| j.on.clone()).collect();
135
136        (where_cond, join_conds, idx)
137    } else {
138        (None, vec![], HashMap::new())
139    };
140
141    let mut result = result;
142
143    // Process WHERE clause
144    if let Some(where_cond) = where_condition {
145        let simplified = simplify(where_cond, dialect);
146        result = pushdown_impl(
147            result,
148            &simplified,
149            &scope.sources,
150            dialect,
151            Some(&join_index),
152        );
153    }
154
155    // Process JOIN ON conditions
156    for join_cond in join_conditions {
157        let simplified = simplify(join_cond, dialect);
158        result = pushdown_impl(result, &simplified, &scope.sources, dialect, None);
159    }
160
161    result
162}
163
164/// Push down a condition into sources
165fn pushdown_impl(
166    expression: Expression,
167    condition: &Expression,
168    sources: &HashMap<String, SourceInfo>,
169    _dialect: Option<DialectType>,
170    join_index: Option<&HashMap<String, usize>>,
171) -> Expression {
172    // Check if condition is in CNF or DNF form
173    let is_cnf = normalized(condition, false); // CNF check
174    let is_dnf = normalized(condition, true); // DNF check
175    let cnf_like = is_cnf || !is_dnf;
176
177    // Flatten the condition into predicates
178    let predicates = flatten_predicates(condition, cnf_like);
179
180    if cnf_like {
181        pushdown_cnf(expression, &predicates, sources, join_index)
182    } else {
183        pushdown_dnf(expression, &predicates, sources)
184    }
185}
186
187/// Flatten predicates from AND/OR expressions
188fn flatten_predicates(expr: &Expression, cnf_like: bool) -> Vec<Expression> {
189    if cnf_like {
190        // For CNF, flatten AND
191        flatten_and(expr)
192    } else {
193        // For DNF, flatten OR
194        flatten_or(expr)
195    }
196}
197
198fn flatten_and(expr: &Expression) -> Vec<Expression> {
199    match expr {
200        Expression::And(bin) => {
201            let mut result = flatten_and(&bin.left);
202            result.extend(flatten_and(&bin.right));
203            result
204        }
205        Expression::Paren(p) => flatten_and(&p.this),
206        other => vec![other.clone()],
207    }
208}
209
210fn flatten_or(expr: &Expression) -> Vec<Expression> {
211    match expr {
212        Expression::Or(bin) => {
213            let mut result = flatten_or(&bin.left);
214            result.extend(flatten_or(&bin.right));
215            result
216        }
217        Expression::Paren(p) => flatten_or(&p.this),
218        other => vec![other.clone()],
219    }
220}
221
222/// Pushdown predicates in CNF form
223fn pushdown_cnf(
224    expression: Expression,
225    predicates: &[Expression],
226    sources: &HashMap<String, SourceInfo>,
227    join_index: Option<&HashMap<String, usize>>,
228) -> Expression {
229    let mut result = expression;
230
231    for predicate in predicates {
232        let nodes = nodes_for_predicate(predicate, sources);
233
234        for (table_name, node_expr) in nodes {
235            // Check if this is a JOIN node
236            if let Some(join_idx) = join_index {
237                if let Some(&this_index) = join_idx.get(&table_name) {
238                    let predicate_tables = get_column_table_names(predicate);
239
240                    // Don't push if predicate references tables from later joins
241                    let can_push = predicate_tables
242                        .iter()
243                        .all(|t| join_idx.get(t).map_or(true, |&idx| idx <= this_index));
244
245                    if can_push {
246                        result = push_predicate_to_node(&result, predicate, &node_expr);
247                    }
248                }
249            } else {
250                result = push_predicate_to_node(&result, predicate, &node_expr);
251            }
252        }
253    }
254
255    result
256}
257
258/// Pushdown predicates in DNF form
259fn pushdown_dnf(
260    expression: Expression,
261    predicates: &[Expression],
262    sources: &HashMap<String, SourceInfo>,
263) -> Expression {
264    // Find tables that can be pushed down to
265    // These are tables referenced in ALL blocks of the DNF
266    let mut pushdown_tables: HashSet<String> = HashSet::new();
267
268    for a in predicates {
269        let a_tables: HashSet<String> = get_column_table_names(a).into_iter().collect();
270
271        let common: HashSet<String> = predicates.iter().fold(a_tables, |acc, b| {
272            let b_tables: HashSet<String> = get_column_table_names(b).into_iter().collect();
273            acc.intersection(&b_tables).cloned().collect()
274        });
275
276        pushdown_tables.extend(common);
277    }
278
279    let mut result = expression;
280
281    // Build conditions for each table
282    let mut conditions: HashMap<String, Expression> = HashMap::new();
283
284    for table in &pushdown_tables {
285        for predicate in predicates {
286            let nodes = nodes_for_predicate(predicate, sources);
287
288            if nodes.contains_key(table) {
289                let existing = conditions.remove(table);
290                conditions.insert(
291                    table.clone(),
292                    if let Some(existing) = existing {
293                        make_or(existing, predicate.clone())
294                    } else {
295                        predicate.clone()
296                    },
297                );
298            }
299        }
300    }
301
302    // Push conditions to nodes
303    for (table, condition) in conditions {
304        if let Some(source_info) = sources.get(&table) {
305            result = push_predicate_to_node(&result, &condition, &source_info.expression);
306        }
307    }
308
309    result
310}
311
312/// Get nodes that a predicate can be pushed down to
313fn nodes_for_predicate(
314    predicate: &Expression,
315    sources: &HashMap<String, SourceInfo>,
316) -> HashMap<String, Expression> {
317    let mut nodes = HashMap::new();
318    let tables = get_column_table_names(predicate);
319
320    for table in tables {
321        if let Some(source_info) = sources.get(&table) {
322            // For now, add the node if it's a valid pushdown target
323            // In a full implementation, we'd check for:
324            // - RIGHT joins (can only push to itself)
325            // - GROUP BY (push to HAVING instead)
326            // - Window functions (can't push)
327            // - Multiple references (can't push)
328            nodes.insert(table, source_info.expression.clone());
329        }
330    }
331
332    nodes
333}
334
335/// Push a predicate to a node (JOIN or subquery)
336fn push_predicate_to_node(
337    expression: &Expression,
338    _predicate: &Expression,
339    _target_node: &Expression,
340) -> Expression {
341    // In a full implementation, this would:
342    // 1. Find the target node in the expression tree
343    // 2. Add the predicate to its WHERE/ON clause
344    // 3. Replace the original predicate with TRUE
345
346    // For now, return unchanged - the structure is complex
347    expression.clone()
348}
349
350/// Extract table names from column references in an expression
351fn get_column_table_names(expr: &Expression) -> Vec<String> {
352    let mut tables = Vec::new();
353    collect_column_tables(expr, &mut tables);
354    tables
355}
356
357fn collect_column_tables(expr: &Expression, tables: &mut Vec<String>) {
358    match expr {
359        Expression::Column(col) => {
360            if let Some(ref table) = col.table {
361                tables.push(table.name.clone());
362            }
363        }
364        Expression::And(bin) | Expression::Or(bin) => {
365            collect_column_tables(&bin.left, tables);
366            collect_column_tables(&bin.right, tables);
367        }
368        Expression::Eq(bin)
369        | Expression::Neq(bin)
370        | Expression::Lt(bin)
371        | Expression::Lte(bin)
372        | Expression::Gt(bin)
373        | Expression::Gte(bin) => {
374            collect_column_tables(&bin.left, tables);
375            collect_column_tables(&bin.right, tables);
376        }
377        Expression::Not(un) => {
378            collect_column_tables(&un.this, tables);
379        }
380        Expression::Paren(p) => {
381            collect_column_tables(&p.this, tables);
382        }
383        Expression::In(in_expr) => {
384            collect_column_tables(&in_expr.this, tables);
385            for e in &in_expr.expressions {
386                collect_column_tables(e, tables);
387            }
388        }
389        Expression::Between(between) => {
390            collect_column_tables(&between.this, tables);
391            collect_column_tables(&between.low, tables);
392            collect_column_tables(&between.high, tables);
393        }
394        Expression::IsNull(is_null) => {
395            collect_column_tables(&is_null.this, tables);
396        }
397        Expression::Like(like) => {
398            collect_column_tables(&like.left, tables);
399            collect_column_tables(&like.right, tables);
400        }
401        Expression::Function(func) => {
402            for arg in &func.args {
403                collect_column_tables(arg, tables);
404            }
405        }
406        Expression::AggregateFunction(agg) => {
407            for arg in &agg.args {
408                collect_column_tables(arg, tables);
409            }
410        }
411        _ => {}
412    }
413}
414
415/// Get table name or alias from an expression
416fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
417    match expr {
418        Expression::Table(table) => {
419            if let Some(ref alias) = table.alias {
420                Some(alias.name.clone())
421            } else {
422                Some(table.name.name.clone())
423            }
424        }
425        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
426        Expression::Alias(alias) => Some(alias.alias.name.clone()),
427        _ => None,
428    }
429}
430
431/// Create an OR expression from two expressions
432fn make_or(left: Expression, right: Expression) -> Expression {
433    Expression::Or(Box::new(crate::expressions::BinaryOp {
434        left,
435        right,
436        left_comments: vec![],
437        operator_comments: vec![],
438        trailing_comments: vec![],
439        inferred_type: None,
440    }))
441}
442
443/// Replace aliases in a predicate with the original expressions
444pub fn replace_aliases(source: &Expression, predicate: Expression) -> Expression {
445    // Build alias map from source SELECT expressions
446    let mut aliases: HashMap<String, Expression> = HashMap::new();
447
448    if let Expression::Select(select) = source {
449        for select_expr in &select.expressions {
450            match select_expr {
451                Expression::Alias(alias) => {
452                    aliases.insert(alias.alias.name.clone(), alias.this.clone());
453                }
454                Expression::Column(col) => {
455                    aliases.insert(col.name.name.clone(), select_expr.clone());
456                }
457                _ => {}
458            }
459        }
460    }
461
462    // Transform predicate, replacing column references with aliases
463    replace_aliases_recursive(predicate, &aliases)
464}
465
466fn replace_aliases_recursive(
467    expr: Expression,
468    aliases: &HashMap<String, Expression>,
469) -> Expression {
470    match expr {
471        Expression::Column(col) => {
472            if let Some(replacement) = aliases.get(&col.name.name) {
473                replacement.clone()
474            } else {
475                Expression::Column(col)
476            }
477        }
478        Expression::And(bin) => {
479            let left = replace_aliases_recursive(bin.left, aliases);
480            let right = replace_aliases_recursive(bin.right, aliases);
481            Expression::And(Box::new(crate::expressions::BinaryOp {
482                left,
483                right,
484                left_comments: bin.left_comments,
485                operator_comments: bin.operator_comments,
486                trailing_comments: bin.trailing_comments,
487                inferred_type: None,
488            }))
489        }
490        Expression::Or(bin) => {
491            let left = replace_aliases_recursive(bin.left, aliases);
492            let right = replace_aliases_recursive(bin.right, aliases);
493            Expression::Or(Box::new(crate::expressions::BinaryOp {
494                left,
495                right,
496                left_comments: bin.left_comments,
497                operator_comments: bin.operator_comments,
498                trailing_comments: bin.trailing_comments,
499                inferred_type: None,
500            }))
501        }
502        Expression::Eq(bin) => {
503            let left = replace_aliases_recursive(bin.left, aliases);
504            let right = replace_aliases_recursive(bin.right, aliases);
505            Expression::Eq(Box::new(crate::expressions::BinaryOp {
506                left,
507                right,
508                left_comments: bin.left_comments,
509                operator_comments: bin.operator_comments,
510                trailing_comments: bin.trailing_comments,
511                inferred_type: None,
512            }))
513        }
514        Expression::Neq(bin) => {
515            let left = replace_aliases_recursive(bin.left, aliases);
516            let right = replace_aliases_recursive(bin.right, aliases);
517            Expression::Neq(Box::new(crate::expressions::BinaryOp {
518                left,
519                right,
520                left_comments: bin.left_comments,
521                operator_comments: bin.operator_comments,
522                trailing_comments: bin.trailing_comments,
523                inferred_type: None,
524            }))
525        }
526        Expression::Lt(bin) => {
527            let left = replace_aliases_recursive(bin.left, aliases);
528            let right = replace_aliases_recursive(bin.right, aliases);
529            Expression::Lt(Box::new(crate::expressions::BinaryOp {
530                left,
531                right,
532                left_comments: bin.left_comments,
533                operator_comments: bin.operator_comments,
534                trailing_comments: bin.trailing_comments,
535                inferred_type: None,
536            }))
537        }
538        Expression::Gt(bin) => {
539            let left = replace_aliases_recursive(bin.left, aliases);
540            let right = replace_aliases_recursive(bin.right, aliases);
541            Expression::Gt(Box::new(crate::expressions::BinaryOp {
542                left,
543                right,
544                left_comments: bin.left_comments,
545                operator_comments: bin.operator_comments,
546                trailing_comments: bin.trailing_comments,
547                inferred_type: None,
548            }))
549        }
550        Expression::Lte(bin) => {
551            let left = replace_aliases_recursive(bin.left, aliases);
552            let right = replace_aliases_recursive(bin.right, aliases);
553            Expression::Lte(Box::new(crate::expressions::BinaryOp {
554                left,
555                right,
556                left_comments: bin.left_comments,
557                operator_comments: bin.operator_comments,
558                trailing_comments: bin.trailing_comments,
559                inferred_type: None,
560            }))
561        }
562        Expression::Gte(bin) => {
563            let left = replace_aliases_recursive(bin.left, aliases);
564            let right = replace_aliases_recursive(bin.right, aliases);
565            Expression::Gte(Box::new(crate::expressions::BinaryOp {
566                left,
567                right,
568                left_comments: bin.left_comments,
569                operator_comments: bin.operator_comments,
570                trailing_comments: bin.trailing_comments,
571                inferred_type: None,
572            }))
573        }
574        Expression::Not(un) => {
575            let inner = replace_aliases_recursive(un.this, aliases);
576            Expression::Not(Box::new(crate::expressions::UnaryOp {
577                this: inner,
578                inferred_type: None,
579            }))
580        }
581        Expression::Paren(paren) => {
582            let inner = replace_aliases_recursive(paren.this, aliases);
583            Expression::Paren(Box::new(crate::expressions::Paren {
584                this: inner,
585                trailing_comments: paren.trailing_comments,
586            }))
587        }
588        other => other,
589    }
590}
591
592/// Create a TRUE literal expression
593pub fn make_true() -> Expression {
594    Expression::Boolean(BooleanLiteral { value: true })
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    ///
    /// Named `generate_sql` rather than `gen`, which is a reserved keyword in Rust 2024.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Returns the WHERE condition of a parsed SELECT.
    ///
    /// Panics instead of silently skipping, so a test cannot pass vacuously when the parser
    /// produces an unexpected shape.
    fn where_condition(expr: &Expression) -> Expression {
        match expr {
            Expression::Select(select) => select
                .where_clause
                .as_ref()
                .map(|w| w.this.clone())
                .expect("SELECT should have a WHERE clause"),
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t WHERE a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_column_table_names() {
        let condition = where_condition(&parse("SELECT 1 WHERE t.a = 1 AND s.b = 2"));
        let tables = get_column_table_names(&condition);
        assert!(tables.contains(&"t".to_string()));
        assert!(tables.contains(&"s".to_string()));
    }

    #[test]
    fn test_flatten_and() {
        let condition = where_condition(&parse("SELECT 1 WHERE a = 1 AND b = 2 AND c = 3"));
        assert_eq!(flatten_and(&condition).len(), 3);
    }

    #[test]
    fn test_flatten_or() {
        let condition = where_condition(&parse("SELECT 1 WHERE a = 1 OR b = 2 OR c = 3"));
        assert_eq!(flatten_or(&condition).len(), 3);
    }

    #[test]
    fn test_replace_aliases() {
        let source = parse("SELECT x.a AS col_a FROM x");
        let predicate = where_condition(&parse("SELECT 1 WHERE col_a = 1"));
        let sql = generate_sql(&replace_aliases(&source, predicate));
        // The alias must be rewritten to the expression it names.
        assert!(!sql.contains("col_a"), "alias not replaced: {}", sql);
        assert!(sql.contains("x.a"), "unexpected rewrite: {}", sql);
    }

    #[test]
    fn test_pushdown_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id WHERE t.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_pushdown_complex_and() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2 AND c < 3");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("AND"));
    }

    #[test]
    fn test_pushdown_complex_or() {
        let expr = parse("SELECT 1 WHERE a = 1 OR b = 2");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("OR"));
    }

    #[test]
    fn test_normalized_dnf_simple() {
        // `a = 1` is a single literal, so it is in both CNF and DNF form.
        let condition = where_condition(&parse("SELECT 1 WHERE a = 1"));
        assert!(normalized(&condition, true));
    }

    #[test]
    fn test_make_true() {
        assert_eq!(generate_sql(&make_true()), "TRUE");
    }
}