Skip to main content

polyglot_sql/optimizer/
pushdown_predicates.rs

//! Predicate Pushdown Module
//!
//! This module provides functionality for pushing WHERE predicates down
//! into subqueries and JOINs for better query performance.
//!
//! When a predicate in the outer query only references columns from a subquery,
//! it can be pushed down into that subquery's WHERE clause to filter data earlier.
//!
//! Ported from sqlglot's optimizer/pushdown_predicates.py
10
11use std::collections::{HashMap, HashSet};
12
13use crate::dialects::DialectType;
14use crate::expressions::{BooleanLiteral, Expression};
15use crate::optimizer::normalize::normalized;
16use crate::optimizer::simplify::simplify;
17use crate::scope::{build_scope, Scope, SourceInfo};
18
19/// Rewrite SQL AST to pushdown predicates in FROMs and JOINs.
20///
21/// # Example
22///
23/// ```sql
24/// -- Before:
25/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1
26/// -- After:
27/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE
28/// ```
29///
30/// # Arguments
31/// * `expression` - The expression to optimize
32/// * `dialect` - Optional dialect for dialect-specific behavior
33///
34/// # Returns
35/// The optimized expression with predicates pushed down
36pub fn pushdown_predicates(
37    expression: Expression,
38    dialect: Option<DialectType>,
39) -> Expression {
40    let root = build_scope(&expression);
41    let scope_ref_count = compute_ref_count(&root);
42
43    // Check if dialect requires special handling for UNNEST
44    let unnest_requires_cross_join = matches!(
45        dialect,
46        Some(DialectType::Presto) | Some(DialectType::Trino) | Some(DialectType::Athena)
47    );
48
49    // Process scopes in reverse order (bottom-up)
50    let mut result = expression.clone();
51    let scopes = collect_scopes(&root);
52
53    for scope in scopes.iter().rev() {
54        result = process_scope(&result, scope, &scope_ref_count, dialect, unnest_requires_cross_join);
55    }
56
57    result
58}
59
60/// Collect all scopes from the tree
61fn collect_scopes(root: &Scope) -> Vec<Scope> {
62    let mut result = vec![root.clone()];
63    // Collect from subquery scopes
64    for child in &root.subquery_scopes {
65        result.extend(collect_scopes(child));
66    }
67    // Collect from derived table scopes
68    for child in &root.derived_table_scopes {
69        result.extend(collect_scopes(child));
70    }
71    // Collect from CTE scopes
72    for child in &root.cte_scopes {
73        result.extend(collect_scopes(child));
74    }
75    // Collect from union scopes
76    for child in &root.union_scopes {
77        result.extend(collect_scopes(child));
78    }
79    result
80}
81
82/// Compute reference counts for each scope
83fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
84    let mut counts = HashMap::new();
85    compute_ref_count_recursive(root, &mut counts);
86    counts
87}
88
89fn compute_ref_count_recursive(scope: &Scope, counts: &mut HashMap<u64, usize>) {
90    // Use the pointer address as a pseudo-ID
91    let id = scope as *const Scope as u64;
92    *counts.entry(id).or_insert(0) += 1;
93
94    for child in &scope.subquery_scopes {
95        compute_ref_count_recursive(child, counts);
96    }
97    for child in &scope.derived_table_scopes {
98        compute_ref_count_recursive(child, counts);
99    }
100    for child in &scope.cte_scopes {
101        compute_ref_count_recursive(child, counts);
102    }
103    for child in &scope.union_scopes {
104        compute_ref_count_recursive(child, counts);
105    }
106}
107
108/// Process a single scope for predicate pushdown
109fn process_scope(
110    expression: &Expression,
111    scope: &Scope,
112    _scope_ref_count: &HashMap<u64, usize>,
113    dialect: Option<DialectType>,
114    _unnest_requires_cross_join: bool,
115) -> Expression {
116    let result = expression.clone();
117
118    // Extract data we need before processing
119    let (where_condition, join_conditions, join_index) = if let Expression::Select(select) = &result {
120        let where_cond = select.where_clause.as_ref().map(|w| w.this.clone());
121
122        let mut idx: HashMap<String, usize> = HashMap::new();
123        for (i, join) in select.joins.iter().enumerate() {
124            if let Some(name) = get_table_alias_or_name(&join.this) {
125                idx.insert(name, i);
126            }
127        }
128
129        let join_conds: Vec<Expression> = select.joins.iter()
130            .filter_map(|j| j.on.clone())
131            .collect();
132
133        (where_cond, join_conds, idx)
134    } else {
135        (None, vec![], HashMap::new())
136    };
137
138    let mut result = result;
139
140    // Process WHERE clause
141    if let Some(where_cond) = where_condition {
142        let simplified = simplify(where_cond, dialect);
143        result = pushdown_impl(
144            result,
145            &simplified,
146            &scope.sources,
147            dialect,
148            Some(&join_index),
149        );
150    }
151
152    // Process JOIN ON conditions
153    for join_cond in join_conditions {
154        let simplified = simplify(join_cond, dialect);
155        result = pushdown_impl(
156            result,
157            &simplified,
158            &scope.sources,
159            dialect,
160            None,
161        );
162    }
163
164    result
165}
166
167/// Push down a condition into sources
168fn pushdown_impl(
169    expression: Expression,
170    condition: &Expression,
171    sources: &HashMap<String, SourceInfo>,
172    _dialect: Option<DialectType>,
173    join_index: Option<&HashMap<String, usize>>,
174) -> Expression {
175    // Check if condition is in CNF or DNF form
176    let is_cnf = normalized(condition, false); // CNF check
177    let is_dnf = normalized(condition, true);  // DNF check
178    let cnf_like = is_cnf || !is_dnf;
179
180    // Flatten the condition into predicates
181    let predicates = flatten_predicates(condition, cnf_like);
182
183    if cnf_like {
184        pushdown_cnf(expression, &predicates, sources, join_index)
185    } else {
186        pushdown_dnf(expression, &predicates, sources)
187    }
188}
189
190/// Flatten predicates from AND/OR expressions
191fn flatten_predicates(expr: &Expression, cnf_like: bool) -> Vec<Expression> {
192    if cnf_like {
193        // For CNF, flatten AND
194        flatten_and(expr)
195    } else {
196        // For DNF, flatten OR
197        flatten_or(expr)
198    }
199}
200
201fn flatten_and(expr: &Expression) -> Vec<Expression> {
202    match expr {
203        Expression::And(bin) => {
204            let mut result = flatten_and(&bin.left);
205            result.extend(flatten_and(&bin.right));
206            result
207        }
208        Expression::Paren(p) => flatten_and(&p.this),
209        other => vec![other.clone()],
210    }
211}
212
213fn flatten_or(expr: &Expression) -> Vec<Expression> {
214    match expr {
215        Expression::Or(bin) => {
216            let mut result = flatten_or(&bin.left);
217            result.extend(flatten_or(&bin.right));
218            result
219        }
220        Expression::Paren(p) => flatten_or(&p.this),
221        other => vec![other.clone()],
222    }
223}
224
225/// Pushdown predicates in CNF form
226fn pushdown_cnf(
227    expression: Expression,
228    predicates: &[Expression],
229    sources: &HashMap<String, SourceInfo>,
230    join_index: Option<&HashMap<String, usize>>,
231) -> Expression {
232    let mut result = expression;
233
234    for predicate in predicates {
235        let nodes = nodes_for_predicate(predicate, sources);
236
237        for (table_name, node_expr) in nodes {
238            // Check if this is a JOIN node
239            if let Some(join_idx) = join_index {
240                if let Some(&this_index) = join_idx.get(&table_name) {
241                    let predicate_tables = get_column_table_names(predicate);
242
243                    // Don't push if predicate references tables from later joins
244                    let can_push = predicate_tables.iter().all(|t| {
245                        join_idx.get(t).map_or(true, |&idx| idx <= this_index)
246                    });
247
248                    if can_push {
249                        result = push_predicate_to_node(&result, predicate, &node_expr);
250                    }
251                }
252            } else {
253                result = push_predicate_to_node(&result, predicate, &node_expr);
254            }
255        }
256    }
257
258    result
259}
260
261/// Pushdown predicates in DNF form
262fn pushdown_dnf(
263    expression: Expression,
264    predicates: &[Expression],
265    sources: &HashMap<String, SourceInfo>,
266) -> Expression {
267    // Find tables that can be pushed down to
268    // These are tables referenced in ALL blocks of the DNF
269    let mut pushdown_tables: HashSet<String> = HashSet::new();
270
271    for a in predicates {
272        let a_tables: HashSet<String> = get_column_table_names(a).into_iter().collect();
273
274        let common: HashSet<String> = predicates
275            .iter()
276            .fold(a_tables, |acc, b| {
277                let b_tables: HashSet<String> = get_column_table_names(b).into_iter().collect();
278                acc.intersection(&b_tables).cloned().collect()
279            });
280
281        pushdown_tables.extend(common);
282    }
283
284    let mut result = expression;
285
286    // Build conditions for each table
287    let mut conditions: HashMap<String, Expression> = HashMap::new();
288
289    for table in &pushdown_tables {
290        for predicate in predicates {
291            let nodes = nodes_for_predicate(predicate, sources);
292
293            if nodes.contains_key(table) {
294                let existing = conditions.remove(table);
295                conditions.insert(
296                    table.clone(),
297                    if let Some(existing) = existing {
298                        make_or(existing, predicate.clone())
299                    } else {
300                        predicate.clone()
301                    },
302                );
303            }
304        }
305    }
306
307    // Push conditions to nodes
308    for (table, condition) in conditions {
309        if let Some(source_info) = sources.get(&table) {
310            result = push_predicate_to_node(&result, &condition, &source_info.expression);
311        }
312    }
313
314    result
315}
316
317/// Get nodes that a predicate can be pushed down to
318fn nodes_for_predicate(
319    predicate: &Expression,
320    sources: &HashMap<String, SourceInfo>,
321) -> HashMap<String, Expression> {
322    let mut nodes = HashMap::new();
323    let tables = get_column_table_names(predicate);
324
325    for table in tables {
326        if let Some(source_info) = sources.get(&table) {
327            // For now, add the node if it's a valid pushdown target
328            // In a full implementation, we'd check for:
329            // - RIGHT joins (can only push to itself)
330            // - GROUP BY (push to HAVING instead)
331            // - Window functions (can't push)
332            // - Multiple references (can't push)
333            nodes.insert(table, source_info.expression.clone());
334        }
335    }
336
337    nodes
338}
339
340/// Push a predicate to a node (JOIN or subquery)
341fn push_predicate_to_node(
342    expression: &Expression,
343    _predicate: &Expression,
344    _target_node: &Expression,
345) -> Expression {
346    // In a full implementation, this would:
347    // 1. Find the target node in the expression tree
348    // 2. Add the predicate to its WHERE/ON clause
349    // 3. Replace the original predicate with TRUE
350
351    // For now, return unchanged - the structure is complex
352    expression.clone()
353}
354
355/// Extract table names from column references in an expression
356fn get_column_table_names(expr: &Expression) -> Vec<String> {
357    let mut tables = Vec::new();
358    collect_column_tables(expr, &mut tables);
359    tables
360}
361
362fn collect_column_tables(expr: &Expression, tables: &mut Vec<String>) {
363    match expr {
364        Expression::Column(col) => {
365            if let Some(ref table) = col.table {
366                tables.push(table.name.clone());
367            }
368        }
369        Expression::And(bin) | Expression::Or(bin) => {
370            collect_column_tables(&bin.left, tables);
371            collect_column_tables(&bin.right, tables);
372        }
373        Expression::Eq(bin) | Expression::Neq(bin) | Expression::Lt(bin) |
374        Expression::Lte(bin) | Expression::Gt(bin) | Expression::Gte(bin) => {
375            collect_column_tables(&bin.left, tables);
376            collect_column_tables(&bin.right, tables);
377        }
378        Expression::Not(un) => {
379            collect_column_tables(&un.this, tables);
380        }
381        Expression::Paren(p) => {
382            collect_column_tables(&p.this, tables);
383        }
384        Expression::In(in_expr) => {
385            collect_column_tables(&in_expr.this, tables);
386            for e in &in_expr.expressions {
387                collect_column_tables(e, tables);
388            }
389        }
390        Expression::Between(between) => {
391            collect_column_tables(&between.this, tables);
392            collect_column_tables(&between.low, tables);
393            collect_column_tables(&between.high, tables);
394        }
395        Expression::IsNull(is_null) => {
396            collect_column_tables(&is_null.this, tables);
397        }
398        Expression::Like(like) => {
399            collect_column_tables(&like.left, tables);
400            collect_column_tables(&like.right, tables);
401        }
402        Expression::Function(func) => {
403            for arg in &func.args {
404                collect_column_tables(arg, tables);
405            }
406        }
407        Expression::AggregateFunction(agg) => {
408            for arg in &agg.args {
409                collect_column_tables(arg, tables);
410            }
411        }
412        _ => {}
413    }
414}
415
416/// Get table name or alias from an expression
417fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
418    match expr {
419        Expression::Table(table) => {
420            if let Some(ref alias) = table.alias {
421                Some(alias.name.clone())
422            } else {
423                Some(table.name.name.clone())
424            }
425        }
426        Expression::Subquery(subquery) => {
427            subquery.alias.as_ref().map(|a| a.name.clone())
428        }
429        Expression::Alias(alias) => {
430            Some(alias.alias.name.clone())
431        }
432        _ => None,
433    }
434}
435
436/// Create an OR expression from two expressions
437fn make_or(left: Expression, right: Expression) -> Expression {
438    Expression::Or(Box::new(crate::expressions::BinaryOp {
439        left,
440        right,
441        left_comments: vec![],
442        operator_comments: vec![],
443        trailing_comments: vec![],
444    }))
445}
446
447/// Replace aliases in a predicate with the original expressions
448pub fn replace_aliases(
449    source: &Expression,
450    predicate: Expression,
451) -> Expression {
452    // Build alias map from source SELECT expressions
453    let mut aliases: HashMap<String, Expression> = HashMap::new();
454
455    if let Expression::Select(select) = source {
456        for select_expr in &select.expressions {
457            match select_expr {
458                Expression::Alias(alias) => {
459                    aliases.insert(alias.alias.name.clone(), alias.this.clone());
460                }
461                Expression::Column(col) => {
462                    aliases.insert(col.name.name.clone(), select_expr.clone());
463                }
464                _ => {}
465            }
466        }
467    }
468
469    // Transform predicate, replacing column references with aliases
470    replace_aliases_recursive(predicate, &aliases)
471}
472
473fn replace_aliases_recursive(
474    expr: Expression,
475    aliases: &HashMap<String, Expression>,
476) -> Expression {
477    match expr {
478        Expression::Column(col) => {
479            if let Some(replacement) = aliases.get(&col.name.name) {
480                replacement.clone()
481            } else {
482                Expression::Column(col)
483            }
484        }
485        Expression::And(bin) => {
486            let left = replace_aliases_recursive(bin.left, aliases);
487            let right = replace_aliases_recursive(bin.right, aliases);
488            Expression::And(Box::new(crate::expressions::BinaryOp {
489                left,
490                right,
491                left_comments: bin.left_comments,
492                operator_comments: bin.operator_comments,
493                trailing_comments: bin.trailing_comments,
494            }))
495        }
496        Expression::Or(bin) => {
497            let left = replace_aliases_recursive(bin.left, aliases);
498            let right = replace_aliases_recursive(bin.right, aliases);
499            Expression::Or(Box::new(crate::expressions::BinaryOp {
500                left,
501                right,
502                left_comments: bin.left_comments,
503                operator_comments: bin.operator_comments,
504                trailing_comments: bin.trailing_comments,
505            }))
506        }
507        Expression::Eq(bin) => {
508            let left = replace_aliases_recursive(bin.left, aliases);
509            let right = replace_aliases_recursive(bin.right, aliases);
510            Expression::Eq(Box::new(crate::expressions::BinaryOp {
511                left,
512                right,
513                left_comments: bin.left_comments,
514                operator_comments: bin.operator_comments,
515                trailing_comments: bin.trailing_comments,
516            }))
517        }
518        Expression::Neq(bin) => {
519            let left = replace_aliases_recursive(bin.left, aliases);
520            let right = replace_aliases_recursive(bin.right, aliases);
521            Expression::Neq(Box::new(crate::expressions::BinaryOp {
522                left,
523                right,
524                left_comments: bin.left_comments,
525                operator_comments: bin.operator_comments,
526                trailing_comments: bin.trailing_comments,
527            }))
528        }
529        Expression::Lt(bin) => {
530            let left = replace_aliases_recursive(bin.left, aliases);
531            let right = replace_aliases_recursive(bin.right, aliases);
532            Expression::Lt(Box::new(crate::expressions::BinaryOp {
533                left,
534                right,
535                left_comments: bin.left_comments,
536                operator_comments: bin.operator_comments,
537                trailing_comments: bin.trailing_comments,
538            }))
539        }
540        Expression::Gt(bin) => {
541            let left = replace_aliases_recursive(bin.left, aliases);
542            let right = replace_aliases_recursive(bin.right, aliases);
543            Expression::Gt(Box::new(crate::expressions::BinaryOp {
544                left,
545                right,
546                left_comments: bin.left_comments,
547                operator_comments: bin.operator_comments,
548                trailing_comments: bin.trailing_comments,
549            }))
550        }
551        Expression::Lte(bin) => {
552            let left = replace_aliases_recursive(bin.left, aliases);
553            let right = replace_aliases_recursive(bin.right, aliases);
554            Expression::Lte(Box::new(crate::expressions::BinaryOp {
555                left,
556                right,
557                left_comments: bin.left_comments,
558                operator_comments: bin.operator_comments,
559                trailing_comments: bin.trailing_comments,
560            }))
561        }
562        Expression::Gte(bin) => {
563            let left = replace_aliases_recursive(bin.left, aliases);
564            let right = replace_aliases_recursive(bin.right, aliases);
565            Expression::Gte(Box::new(crate::expressions::BinaryOp {
566                left,
567                right,
568                left_comments: bin.left_comments,
569                operator_comments: bin.operator_comments,
570                trailing_comments: bin.trailing_comments,
571            }))
572        }
573        Expression::Not(un) => {
574            let inner = replace_aliases_recursive(un.this, aliases);
575            Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
576        }
577        Expression::Paren(paren) => {
578            let inner = replace_aliases_recursive(paren.this, aliases);
579            Expression::Paren(Box::new(crate::expressions::Paren {
580                this: inner,
581                trailing_comments: paren.trailing_comments,
582            }))
583        }
584        other => other,
585    }
586}
587
588/// Create a TRUE literal expression
589pub fn make_true() -> Expression {
590    Expression::Boolean(BooleanLiteral { value: true })
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// WHERE condition of the SELECT in `sql`.
    ///
    /// Panics (failing the test) when `sql` is not a SELECT with a WHERE clause, so
    /// an unexpected parse shape cannot make a test's assertions silently vanish.
    fn where_condition(sql: &str) -> Expression {
        let expr = parse(sql);
        match &expr {
            Expression::Select(select) => select
                .where_clause
                .as_ref()
                .map(|w| w.this.clone())
                .expect("expected the SELECT to have a WHERE clause"),
            _ => panic!("expected a SELECT statement for: {}", sql),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t WHERE a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_column_table_names() {
        let condition = where_condition("SELECT 1 WHERE t.a = 1 AND s.b = 2");
        let tables = get_column_table_names(&condition);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains(&"t".to_string()));
        assert!(tables.contains(&"s".to_string()));
    }

    #[test]
    fn test_flatten_and() {
        let condition = where_condition("SELECT 1 WHERE a = 1 AND b = 2 AND c = 3");
        assert_eq!(flatten_and(&condition).len(), 3);
    }

    #[test]
    fn test_flatten_and_through_parens() {
        let condition = where_condition("SELECT 1 WHERE (a = 1 AND b = 2) AND c = 3");
        assert_eq!(flatten_and(&condition).len(), 3);
    }

    #[test]
    fn test_flatten_or() {
        let condition = where_condition("SELECT 1 WHERE a = 1 OR b = 2 OR c = 3");
        assert_eq!(flatten_or(&condition).len(), 3);
    }

    #[test]
    fn test_flatten_or_keeps_conjunction_whole() {
        let condition = where_condition("SELECT 1 WHERE a = 1 AND b = 2");
        assert_eq!(flatten_or(&condition).len(), 1);
    }

    #[test]
    fn test_replace_aliases() {
        let source = parse("SELECT x.a AS col_a FROM x");
        let predicate = where_condition("SELECT 1 WHERE col_a = 1");

        let replaced = replace_aliases(&source, predicate);
        let sql = gen(&replaced);

        // The alias must be replaced by the projected expression.
        assert!(!sql.contains("col_a"), "alias was not replaced: {}", sql);
        assert!(sql.contains("x.a"), "expected the source expression in: {}", sql);
    }

    #[test]
    fn test_pushdown_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id WHERE t.a = 1");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_pushdown_complex_and() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2 AND c < 3");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("AND"));
    }

    #[test]
    fn test_pushdown_complex_or() {
        let expr = parse("SELECT 1 WHERE a = 1 OR b = 2");
        let result = pushdown_predicates(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("OR"));
    }

    #[test]
    fn test_normalized_dnf_simple() {
        // a = 1 is in both CNF and DNF form; check DNF by passing true for the dnf flag.
        let condition = where_condition("SELECT 1 WHERE a = 1");
        assert!(normalized(&condition, true));
    }

    #[test]
    fn test_make_true() {
        let t = make_true();
        let sql = gen(&t);
        assert_eq!(sql, "TRUE");
    }
}