Skip to main content

polyglot_sql/optimizer/
pushdown_predicates.rs

1//! Predicate Pushdown Module
2//!
3//! This module provides functionality for pushing WHERE predicates down
4//! into subqueries and JOINs for better query performance.
5//!
6//! When a predicate in the outer query only references columns from a subquery,
7//! it can be pushed down into that subquery's WHERE clause to filter data earlier.
8//!
9//! Ported from sqlglot's optimizer/pushdown_predicates.py
10
11use std::collections::{HashMap, HashSet};
12
13use crate::dialects::DialectType;
14use crate::expressions::{BooleanLiteral, Expression};
15use crate::optimizer::normalize::normalized;
16use crate::optimizer::simplify::simplify;
17use crate::scope::{build_scope, Scope, SourceInfo};
18
19/// Rewrite SQL AST to pushdown predicates in FROMs and JOINs.
20///
21/// # Example
22///
23/// ```sql
24/// -- Before:
25/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1
26/// -- After:
27/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE
28/// ```
29///
30/// # Arguments
31/// * `expression` - The expression to optimize
32/// * `dialect` - Optional dialect for dialect-specific behavior
33///
34/// # Returns
35/// The optimized expression with predicates pushed down
36pub fn pushdown_predicates(expression: Expression, dialect: Option<DialectType>) -> Expression {
37    let root = build_scope(&expression);
38    let scope_ref_count = compute_ref_count(&root);
39
40    // Check if dialect requires special handling for UNNEST
41    let unnest_requires_cross_join = matches!(
42        dialect,
43        Some(DialectType::Presto) | Some(DialectType::Trino) | Some(DialectType::Athena)
44    );
45
46    // Process scopes in reverse order (bottom-up)
47    let mut result = expression.clone();
48    let scopes = collect_scopes(&root);
49
50    for scope in scopes.iter().rev() {
51        result = process_scope(
52            &result,
53            scope,
54            &scope_ref_count,
55            dialect,
56            unnest_requires_cross_join,
57        );
58    }
59
60    result
61}
62
63/// Collect all scopes from the tree
64fn collect_scopes(root: &Scope) -> Vec<Scope> {
65    let mut result = vec![root.clone()];
66    // Collect from subquery scopes
67    for child in &root.subquery_scopes {
68        result.extend(collect_scopes(child));
69    }
70    // Collect from derived table scopes
71    for child in &root.derived_table_scopes {
72        result.extend(collect_scopes(child));
73    }
74    // Collect from CTE scopes
75    for child in &root.cte_scopes {
76        result.extend(collect_scopes(child));
77    }
78    // Collect from union scopes
79    for child in &root.union_scopes {
80        result.extend(collect_scopes(child));
81    }
82    result
83}
84
85/// Compute reference counts for each scope
86fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
87    let mut counts = HashMap::new();
88    compute_ref_count_recursive(root, &mut counts);
89    counts
90}
91
92fn compute_ref_count_recursive(scope: &Scope, counts: &mut HashMap<u64, usize>) {
93    // Use the pointer address as a pseudo-ID
94    let id = scope as *const Scope as u64;
95    *counts.entry(id).or_insert(0) += 1;
96
97    for child in &scope.subquery_scopes {
98        compute_ref_count_recursive(child, counts);
99    }
100    for child in &scope.derived_table_scopes {
101        compute_ref_count_recursive(child, counts);
102    }
103    for child in &scope.cte_scopes {
104        compute_ref_count_recursive(child, counts);
105    }
106    for child in &scope.union_scopes {
107        compute_ref_count_recursive(child, counts);
108    }
109}
110
111/// Process a single scope for predicate pushdown
112fn process_scope(
113    expression: &Expression,
114    scope: &Scope,
115    _scope_ref_count: &HashMap<u64, usize>,
116    dialect: Option<DialectType>,
117    _unnest_requires_cross_join: bool,
118) -> Expression {
119    let result = expression.clone();
120
121    // Extract data we need before processing
122    let (where_condition, join_conditions, join_index) = if let Expression::Select(select) = &result
123    {
124        let where_cond = select.where_clause.as_ref().map(|w| w.this.clone());
125
126        let mut idx: HashMap<String, usize> = HashMap::new();
127        for (i, join) in select.joins.iter().enumerate() {
128            if let Some(name) = get_table_alias_or_name(&join.this) {
129                idx.insert(name, i);
130            }
131        }
132
133        let join_conds: Vec<Expression> =
134            select.joins.iter().filter_map(|j| j.on.clone()).collect();
135
136        (where_cond, join_conds, idx)
137    } else {
138        (None, vec![], HashMap::new())
139    };
140
141    let mut result = result;
142
143    // Process WHERE clause
144    if let Some(where_cond) = where_condition {
145        let simplified = simplify(where_cond, dialect);
146        result = pushdown_impl(
147            result,
148            &simplified,
149            &scope.sources,
150            dialect,
151            Some(&join_index),
152        );
153    }
154
155    // Process JOIN ON conditions
156    for join_cond in join_conditions {
157        let simplified = simplify(join_cond, dialect);
158        result = pushdown_impl(result, &simplified, &scope.sources, dialect, None);
159    }
160
161    result
162}
163
164/// Push down a condition into sources
165fn pushdown_impl(
166    expression: Expression,
167    condition: &Expression,
168    sources: &HashMap<String, SourceInfo>,
169    _dialect: Option<DialectType>,
170    join_index: Option<&HashMap<String, usize>>,
171) -> Expression {
172    // Check if condition is in CNF or DNF form
173    let is_cnf = normalized(condition, false); // CNF check
174    let is_dnf = normalized(condition, true); // DNF check
175    let cnf_like = is_cnf || !is_dnf;
176
177    // Flatten the condition into predicates
178    let predicates = flatten_predicates(condition, cnf_like);
179
180    if cnf_like {
181        pushdown_cnf(expression, &predicates, sources, join_index)
182    } else {
183        pushdown_dnf(expression, &predicates, sources)
184    }
185}
186
187/// Flatten predicates from AND/OR expressions
188fn flatten_predicates(expr: &Expression, cnf_like: bool) -> Vec<Expression> {
189    if cnf_like {
190        // For CNF, flatten AND
191        flatten_and(expr)
192    } else {
193        // For DNF, flatten OR
194        flatten_or(expr)
195    }
196}
197
198fn flatten_and(expr: &Expression) -> Vec<Expression> {
199    match expr {
200        Expression::And(bin) => {
201            let mut result = flatten_and(&bin.left);
202            result.extend(flatten_and(&bin.right));
203            result
204        }
205        Expression::Paren(p) => flatten_and(&p.this),
206        other => vec![other.clone()],
207    }
208}
209
210fn flatten_or(expr: &Expression) -> Vec<Expression> {
211    match expr {
212        Expression::Or(bin) => {
213            let mut result = flatten_or(&bin.left);
214            result.extend(flatten_or(&bin.right));
215            result
216        }
217        Expression::Paren(p) => flatten_or(&p.this),
218        other => vec![other.clone()],
219    }
220}
221
222/// Pushdown predicates in CNF form
223fn pushdown_cnf(
224    expression: Expression,
225    predicates: &[Expression],
226    sources: &HashMap<String, SourceInfo>,
227    join_index: Option<&HashMap<String, usize>>,
228) -> Expression {
229    let mut result = expression;
230
231    for predicate in predicates {
232        let nodes = nodes_for_predicate(predicate, sources);
233
234        for (table_name, node_expr) in nodes {
235            // Check if this is a JOIN node
236            if let Some(join_idx) = join_index {
237                if let Some(&this_index) = join_idx.get(&table_name) {
238                    let predicate_tables = get_column_table_names(predicate);
239
240                    // Don't push if predicate references tables from later joins
241                    let can_push = predicate_tables
242                        .iter()
243                        .all(|t| join_idx.get(t).map_or(true, |&idx| idx <= this_index));
244
245                    if can_push {
246                        result = push_predicate_to_node(&result, predicate, &node_expr);
247                    }
248                }
249            } else {
250                result = push_predicate_to_node(&result, predicate, &node_expr);
251            }
252        }
253    }
254
255    result
256}
257
258/// Pushdown predicates in DNF form
259fn pushdown_dnf(
260    expression: Expression,
261    predicates: &[Expression],
262    sources: &HashMap<String, SourceInfo>,
263) -> Expression {
264    // Find tables that can be pushed down to
265    // These are tables referenced in ALL blocks of the DNF
266    let mut pushdown_tables: HashSet<String> = HashSet::new();
267
268    for a in predicates {
269        let a_tables: HashSet<String> = get_column_table_names(a).into_iter().collect();
270
271        let common: HashSet<String> = predicates.iter().fold(a_tables, |acc, b| {
272            let b_tables: HashSet<String> = get_column_table_names(b).into_iter().collect();
273            acc.intersection(&b_tables).cloned().collect()
274        });
275
276        pushdown_tables.extend(common);
277    }
278
279    let mut result = expression;
280
281    // Build conditions for each table
282    let mut conditions: HashMap<String, Expression> = HashMap::new();
283
284    for table in &pushdown_tables {
285        for predicate in predicates {
286            let nodes = nodes_for_predicate(predicate, sources);
287
288            if nodes.contains_key(table) {
289                let existing = conditions.remove(table);
290                conditions.insert(
291                    table.clone(),
292                    if let Some(existing) = existing {
293                        make_or(existing, predicate.clone())
294                    } else {
295                        predicate.clone()
296                    },
297                );
298            }
299        }
300    }
301
302    // Push conditions to nodes
303    for (table, condition) in conditions {
304        if let Some(source_info) = sources.get(&table) {
305            result = push_predicate_to_node(&result, &condition, &source_info.expression);
306        }
307    }
308
309    result
310}
311
312/// Get nodes that a predicate can be pushed down to
313fn nodes_for_predicate(
314    predicate: &Expression,
315    sources: &HashMap<String, SourceInfo>,
316) -> HashMap<String, Expression> {
317    let mut nodes = HashMap::new();
318    let tables = get_column_table_names(predicate);
319
320    for table in tables {
321        if let Some(source_info) = sources.get(&table) {
322            // For now, add the node if it's a valid pushdown target
323            // In a full implementation, we'd check for:
324            // - RIGHT joins (can only push to itself)
325            // - GROUP BY (push to HAVING instead)
326            // - Window functions (can't push)
327            // - Multiple references (can't push)
328            nodes.insert(table, source_info.expression.clone());
329        }
330    }
331
332    nodes
333}
334
335/// Push a predicate to a node (JOIN or subquery)
336fn push_predicate_to_node(
337    expression: &Expression,
338    _predicate: &Expression,
339    _target_node: &Expression,
340) -> Expression {
341    // In a full implementation, this would:
342    // 1. Find the target node in the expression tree
343    // 2. Add the predicate to its WHERE/ON clause
344    // 3. Replace the original predicate with TRUE
345
346    // For now, return unchanged - the structure is complex
347    expression.clone()
348}
349
350/// Extract table names from column references in an expression
351fn get_column_table_names(expr: &Expression) -> Vec<String> {
352    let mut tables = Vec::new();
353    collect_column_tables(expr, &mut tables);
354    tables
355}
356
357fn collect_column_tables(expr: &Expression, tables: &mut Vec<String>) {
358    match expr {
359        Expression::Column(col) => {
360            if let Some(ref table) = col.table {
361                tables.push(table.name.clone());
362            }
363        }
364        Expression::And(bin) | Expression::Or(bin) => {
365            collect_column_tables(&bin.left, tables);
366            collect_column_tables(&bin.right, tables);
367        }
368        Expression::Eq(bin)
369        | Expression::Neq(bin)
370        | Expression::Lt(bin)
371        | Expression::Lte(bin)
372        | Expression::Gt(bin)
373        | Expression::Gte(bin) => {
374            collect_column_tables(&bin.left, tables);
375            collect_column_tables(&bin.right, tables);
376        }
377        Expression::Not(un) => {
378            collect_column_tables(&un.this, tables);
379        }
380        Expression::Paren(p) => {
381            collect_column_tables(&p.this, tables);
382        }
383        Expression::In(in_expr) => {
384            collect_column_tables(&in_expr.this, tables);
385            for e in &in_expr.expressions {
386                collect_column_tables(e, tables);
387            }
388        }
389        Expression::Between(between) => {
390            collect_column_tables(&between.this, tables);
391            collect_column_tables(&between.low, tables);
392            collect_column_tables(&between.high, tables);
393        }
394        Expression::IsNull(is_null) => {
395            collect_column_tables(&is_null.this, tables);
396        }
397        Expression::Like(like) => {
398            collect_column_tables(&like.left, tables);
399            collect_column_tables(&like.right, tables);
400        }
401        Expression::Function(func) => {
402            for arg in &func.args {
403                collect_column_tables(arg, tables);
404            }
405        }
406        Expression::AggregateFunction(agg) => {
407            for arg in &agg.args {
408                collect_column_tables(arg, tables);
409            }
410        }
411        _ => {}
412    }
413}
414
415/// Get table name or alias from an expression
416fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
417    match expr {
418        Expression::Table(table) => {
419            if let Some(ref alias) = table.alias {
420                Some(alias.name.clone())
421            } else {
422                Some(table.name.name.clone())
423            }
424        }
425        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
426        Expression::Alias(alias) => Some(alias.alias.name.clone()),
427        _ => None,
428    }
429}
430
431/// Create an OR expression from two expressions
432fn make_or(left: Expression, right: Expression) -> Expression {
433    Expression::Or(Box::new(crate::expressions::BinaryOp {
434        left,
435        right,
436        left_comments: vec![],
437        operator_comments: vec![],
438        trailing_comments: vec![],
439    }))
440}
441
442/// Replace aliases in a predicate with the original expressions
443pub fn replace_aliases(source: &Expression, predicate: Expression) -> Expression {
444    // Build alias map from source SELECT expressions
445    let mut aliases: HashMap<String, Expression> = HashMap::new();
446
447    if let Expression::Select(select) = source {
448        for select_expr in &select.expressions {
449            match select_expr {
450                Expression::Alias(alias) => {
451                    aliases.insert(alias.alias.name.clone(), alias.this.clone());
452                }
453                Expression::Column(col) => {
454                    aliases.insert(col.name.name.clone(), select_expr.clone());
455                }
456                _ => {}
457            }
458        }
459    }
460
461    // Transform predicate, replacing column references with aliases
462    replace_aliases_recursive(predicate, &aliases)
463}
464
465fn replace_aliases_recursive(
466    expr: Expression,
467    aliases: &HashMap<String, Expression>,
468) -> Expression {
469    match expr {
470        Expression::Column(col) => {
471            if let Some(replacement) = aliases.get(&col.name.name) {
472                replacement.clone()
473            } else {
474                Expression::Column(col)
475            }
476        }
477        Expression::And(bin) => {
478            let left = replace_aliases_recursive(bin.left, aliases);
479            let right = replace_aliases_recursive(bin.right, aliases);
480            Expression::And(Box::new(crate::expressions::BinaryOp {
481                left,
482                right,
483                left_comments: bin.left_comments,
484                operator_comments: bin.operator_comments,
485                trailing_comments: bin.trailing_comments,
486            }))
487        }
488        Expression::Or(bin) => {
489            let left = replace_aliases_recursive(bin.left, aliases);
490            let right = replace_aliases_recursive(bin.right, aliases);
491            Expression::Or(Box::new(crate::expressions::BinaryOp {
492                left,
493                right,
494                left_comments: bin.left_comments,
495                operator_comments: bin.operator_comments,
496                trailing_comments: bin.trailing_comments,
497            }))
498        }
499        Expression::Eq(bin) => {
500            let left = replace_aliases_recursive(bin.left, aliases);
501            let right = replace_aliases_recursive(bin.right, aliases);
502            Expression::Eq(Box::new(crate::expressions::BinaryOp {
503                left,
504                right,
505                left_comments: bin.left_comments,
506                operator_comments: bin.operator_comments,
507                trailing_comments: bin.trailing_comments,
508            }))
509        }
510        Expression::Neq(bin) => {
511            let left = replace_aliases_recursive(bin.left, aliases);
512            let right = replace_aliases_recursive(bin.right, aliases);
513            Expression::Neq(Box::new(crate::expressions::BinaryOp {
514                left,
515                right,
516                left_comments: bin.left_comments,
517                operator_comments: bin.operator_comments,
518                trailing_comments: bin.trailing_comments,
519            }))
520        }
521        Expression::Lt(bin) => {
522            let left = replace_aliases_recursive(bin.left, aliases);
523            let right = replace_aliases_recursive(bin.right, aliases);
524            Expression::Lt(Box::new(crate::expressions::BinaryOp {
525                left,
526                right,
527                left_comments: bin.left_comments,
528                operator_comments: bin.operator_comments,
529                trailing_comments: bin.trailing_comments,
530            }))
531        }
532        Expression::Gt(bin) => {
533            let left = replace_aliases_recursive(bin.left, aliases);
534            let right = replace_aliases_recursive(bin.right, aliases);
535            Expression::Gt(Box::new(crate::expressions::BinaryOp {
536                left,
537                right,
538                left_comments: bin.left_comments,
539                operator_comments: bin.operator_comments,
540                trailing_comments: bin.trailing_comments,
541            }))
542        }
543        Expression::Lte(bin) => {
544            let left = replace_aliases_recursive(bin.left, aliases);
545            let right = replace_aliases_recursive(bin.right, aliases);
546            Expression::Lte(Box::new(crate::expressions::BinaryOp {
547                left,
548                right,
549                left_comments: bin.left_comments,
550                operator_comments: bin.operator_comments,
551                trailing_comments: bin.trailing_comments,
552            }))
553        }
554        Expression::Gte(bin) => {
555            let left = replace_aliases_recursive(bin.left, aliases);
556            let right = replace_aliases_recursive(bin.right, aliases);
557            Expression::Gte(Box::new(crate::expressions::BinaryOp {
558                left,
559                right,
560                left_comments: bin.left_comments,
561                operator_comments: bin.operator_comments,
562                trailing_comments: bin.trailing_comments,
563            }))
564        }
565        Expression::Not(un) => {
566            let inner = replace_aliases_recursive(un.this, aliases);
567            Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
568        }
569        Expression::Paren(paren) => {
570            let inner = replace_aliases_recursive(paren.this, aliases);
571            Expression::Paren(Box::new(crate::expressions::Paren {
572                this: inner,
573                trailing_comments: paren.trailing_comments,
574            }))
575        }
576        other => other,
577    }
578}
579
580/// Create a TRUE literal expression
581pub fn make_true() -> Expression {
582    Expression::Boolean(BooleanLiteral { value: true })
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    // Named `generate_sql` rather than `gen`, which is a reserved keyword in
    // the 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parse `sql` and return the condition of its WHERE clause.
    ///
    /// Panics when the statement is not a SELECT or has no WHERE clause, so a
    /// test built on it cannot pass vacuously.
    fn where_condition(sql: &str) -> Expression {
        match parse(sql) {
            Expression::Select(select) => select
                .where_clause
                .as_ref()
                .map(|w| w.this.clone())
                .unwrap_or_else(|| panic!("expected a WHERE clause in: {}", sql)),
            _ => panic!("expected a SELECT statement: {}", sql),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t WHERE a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_column_table_names() {
        let condition = where_condition("SELECT 1 WHERE t.a = 1 AND s.b = 2");
        let tables = get_column_table_names(&condition);
        assert!(tables.iter().any(|t| t == "t"));
        assert!(tables.iter().any(|t| t == "s"));
    }

    #[test]
    fn test_flatten_and() {
        let condition = where_condition("SELECT 1 WHERE a = 1 AND b = 2 AND c = 3");
        let predicates = flatten_and(&condition);
        assert_eq!(predicates.len(), 3);
    }

    #[test]
    fn test_flatten_or() {
        let condition = where_condition("SELECT 1 WHERE a = 1 OR b = 2 OR c = 3");
        let predicates = flatten_or(&condition);
        assert_eq!(predicates.len(), 3);
    }

    #[test]
    fn test_replace_aliases() {
        let source = parse("SELECT x.a AS col_a FROM x");
        let condition = where_condition("SELECT 1 WHERE col_a = 1");

        let replaced = replace_aliases(&source, condition);
        // The comparison itself must survive the substitution.
        let sql = generate_sql(&replaced);
        assert!(sql.contains("="));
    }

    #[test]
    fn test_pushdown_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id WHERE t.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_pushdown_complex_and() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2 AND c < 3");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("AND"));
    }

    #[test]
    fn test_pushdown_complex_or() {
        let expr = parse("SELECT 1 WHERE a = 1 OR b = 2");
        let result = pushdown_predicates(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("OR"));
    }

    #[test]
    fn test_normalized_dnf_simple() {
        // a = 1 is in both CNF and DNF form
        let condition = where_condition("SELECT 1 WHERE a = 1");
        // Check DNF: pass true for dnf flag
        assert!(normalized(&condition, true));
    }

    #[test]
    fn test_make_true() {
        let t = make_true();
        let sql = generate_sql(&t);
        assert_eq!(sql, "TRUE");
    }
}