Skip to main content

polyglot_sql/optimizer/
pushdown_projections.rs

1//! Projection Pushdown Module
2//!
3//! This module provides functionality for removing unused column projections
4//! from SQL queries. When a subquery selects columns that are never used by
5//! the outer query, those columns can be eliminated to reduce data processing.
6//!
7//! Ported from sqlglot's optimizer/pushdown_projections.py
8
9use std::collections::{HashMap, HashSet};
10
11use crate::dialects::DialectType;
12use crate::expressions::{AggregateFunction, Alias, Expression, Identifier, Literal};
13use crate::scope::{build_scope, traverse_scope, Scope};
14
/// Sentinel column name meaning "every column of this scope is required".
///
/// It is placed in a scope's requested-column set when no enclosing scope has
/// recorded requirements for it, or when the scope is a `DISTINCT` select.
/// While it is present, `get_selections_to_keep` keeps every projection.
const SELECT_ALL: &str = "__SELECT_ALL__";
17
18/// Rewrite SQL AST to remove unused column projections.
19///
20/// # Example
21///
22/// ```sql
23/// -- Before:
24/// SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y
25/// -- After:
26/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y
27/// ```
28///
29/// # Arguments
30/// * `expression` - The expression to optimize
31/// * `dialect` - Optional dialect for dialect-specific behavior
32/// * `remove_unused_selections` - Whether to actually remove unused selections
33///
34/// # Returns
35/// The optimized expression with unused projections removed
36pub fn pushdown_projections(
37    expression: Expression,
38    _dialect: Option<DialectType>,
39    remove_unused_selections: bool,
40) -> Expression {
41    let _root = build_scope(&expression);
42
43    // Map of scope to columns being selected by outer queries
44    let mut referenced_columns: HashMap<u64, HashSet<String>> = HashMap::new();
45    let source_column_alias_count: HashMap<u64, usize> = HashMap::new();
46
47    // Collect all scopes and process in reverse order (bottom-up)
48    let scopes = traverse_scope(&expression);
49
50    for scope in scopes.iter().rev() {
51        let scope_id = scope as *const Scope as u64;
52        let parent_selections = referenced_columns
53            .get(&scope_id)
54            .cloned()
55            .unwrap_or_else(|| {
56                let mut set = HashSet::new();
57                set.insert(SELECT_ALL.to_string());
58                set
59            });
60
61        let alias_count = source_column_alias_count
62            .get(&scope_id)
63            .copied()
64            .unwrap_or(0);
65
66        // Check for DISTINCT - can't optimize if present
67        let has_distinct = if let Expression::Select(ref select) = scope.expression {
68            select.distinct || select.distinct_on.is_some()
69        } else {
70            false
71        };
72
73        let parent_selections = if has_distinct {
74            let mut set = HashSet::new();
75            set.insert(SELECT_ALL.to_string());
76            set
77        } else {
78            parent_selections
79        };
80
81        // Handle set operations (UNION, INTERSECT, EXCEPT)
82        process_set_operations(&scope, &parent_selections, &mut referenced_columns);
83
84        // Handle SELECT statements
85        if let Expression::Select(ref select) = scope.expression {
86            if remove_unused_selections {
87                // Note: actual removal would require mutable access to expression
88                // For now, we just track what would be removed
89                let _selections_to_keep =
90                    get_selections_to_keep(select, &parent_selections, alias_count);
91            }
92
93            // Check if SELECT *
94            let is_star = select
95                .expressions
96                .iter()
97                .any(|e| matches!(e, Expression::Star(_)));
98            if is_star {
99                continue;
100            }
101
102            // Group columns by source name
103            let mut selects: HashMap<String, HashSet<String>> = HashMap::new();
104            for col_expr in &select.expressions {
105                collect_column_refs(col_expr, &mut selects);
106            }
107
108            // Push selected columns down to child scopes
109            for source_name in scope.sources.keys() {
110                let columns = selects.get(source_name).cloned().unwrap_or_default();
111
112                // Find the child scope for this source
113                for child_scope in collect_child_scopes(&scope) {
114                    let child_id = child_scope as *const Scope as u64;
115                    referenced_columns
116                        .entry(child_id)
117                        .or_insert_with(HashSet::new)
118                        .extend(columns.clone());
119                }
120            }
121        }
122    }
123
124    // In a full implementation, we would modify the expression tree
125    // For now, return unchanged
126    expression
127}
128
129/// Process set operations (UNION, INTERSECT, EXCEPT)
130fn process_set_operations(
131    scope: &Scope,
132    parent_selections: &HashSet<String>,
133    referenced_columns: &mut HashMap<u64, HashSet<String>>,
134) {
135    match &scope.expression {
136        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => {
137            // Propagate parent selections to both sides of set operation
138            for child_scope in &scope.union_scopes {
139                let child_id = child_scope as *const Scope as u64;
140                referenced_columns
141                    .entry(child_id)
142                    .or_insert_with(HashSet::new)
143                    .extend(parent_selections.clone());
144            }
145        }
146        _ => {}
147    }
148}
149
150/// Get the list of selections that should be kept
151fn get_selections_to_keep(
152    select: &crate::expressions::Select,
153    parent_selections: &HashSet<String>,
154    mut alias_count: usize,
155) -> Vec<usize> {
156    let mut keep_indices = Vec::new();
157    let select_all = parent_selections.contains(SELECT_ALL);
158
159    // Get ORDER BY column references (unqualified columns)
160    let order_refs: HashSet<String> = select
161        .order_by
162        .as_ref()
163        .map(|o| get_order_by_column_refs(&o.expressions))
164        .unwrap_or_default();
165
166    for (i, selection) in select.expressions.iter().enumerate() {
167        let name = get_alias_or_name(selection);
168
169        if select_all
170            || parent_selections.contains(&name)
171            || order_refs.contains(&name)
172            || alias_count > 0
173        {
174            keep_indices.push(i);
175            if alias_count > 0 {
176                alias_count -= 1;
177            }
178        }
179    }
180
181    // If no selections remain, we need at least one
182    if keep_indices.is_empty() {
183        // Would add a default selection like "1 AS _"
184        keep_indices.push(0);
185    }
186
187    keep_indices
188}
189
190/// Get column references from ORDER BY expressions
191fn get_order_by_column_refs(ordered_exprs: &[crate::expressions::Ordered]) -> HashSet<String> {
192    let mut refs = HashSet::new();
193    for ordered in ordered_exprs {
194        collect_unqualified_column_names(&ordered.this, &mut refs);
195    }
196    refs
197}
198
199/// Collect unqualified column names from an expression
200fn collect_unqualified_column_names(expr: &Expression, names: &mut HashSet<String>) {
201    match expr {
202        Expression::Column(col) => {
203            if col.table.is_none() {
204                names.insert(col.name.name.clone());
205            }
206        }
207        Expression::And(bin) | Expression::Or(bin) => {
208            collect_unqualified_column_names(&bin.left, names);
209            collect_unqualified_column_names(&bin.right, names);
210        }
211        Expression::Function(func) => {
212            for arg in &func.args {
213                collect_unqualified_column_names(arg, names);
214            }
215        }
216        Expression::AggregateFunction(agg) => {
217            for arg in &agg.args {
218                collect_unqualified_column_names(arg, names);
219            }
220        }
221        Expression::Paren(p) => {
222            collect_unqualified_column_names(&p.this, names);
223        }
224        _ => {}
225    }
226}
227
228/// Get the alias or name from a selection expression
229fn get_alias_or_name(expr: &Expression) -> String {
230    match expr {
231        Expression::Alias(alias) => alias.alias.name.clone(),
232        Expression::Column(col) => col.name.name.clone(),
233        _ => String::new(),
234    }
235}
236
237/// Collect column references grouped by table name
238fn collect_column_refs(expr: &Expression, selects: &mut HashMap<String, HashSet<String>>) {
239    match expr {
240        Expression::Column(col) => {
241            if let Some(ref table) = col.table {
242                selects
243                    .entry(table.name.clone())
244                    .or_insert_with(HashSet::new)
245                    .insert(col.name.name.clone());
246            }
247        }
248        Expression::Alias(alias) => {
249            collect_column_refs(&alias.this, selects);
250        }
251        Expression::Function(func) => {
252            for arg in &func.args {
253                collect_column_refs(arg, selects);
254            }
255        }
256        Expression::AggregateFunction(agg) => {
257            for arg in &agg.args {
258                collect_column_refs(arg, selects);
259            }
260        }
261        Expression::And(bin) | Expression::Or(bin) => {
262            collect_column_refs(&bin.left, selects);
263            collect_column_refs(&bin.right, selects);
264        }
265        Expression::Eq(bin)
266        | Expression::Neq(bin)
267        | Expression::Lt(bin)
268        | Expression::Lte(bin)
269        | Expression::Gt(bin)
270        | Expression::Gte(bin)
271        | Expression::Add(bin)
272        | Expression::Sub(bin)
273        | Expression::Mul(bin)
274        | Expression::Div(bin) => {
275            collect_column_refs(&bin.left, selects);
276            collect_column_refs(&bin.right, selects);
277        }
278        Expression::Paren(p) => {
279            collect_column_refs(&p.this, selects);
280        }
281        Expression::Case(case) => {
282            if let Some(ref operand) = case.operand {
283                collect_column_refs(operand, selects);
284            }
285            for (when, then) in &case.whens {
286                collect_column_refs(when, selects);
287                collect_column_refs(then, selects);
288            }
289            if let Some(ref else_) = case.else_ {
290                collect_column_refs(else_, selects);
291            }
292        }
293        _ => {}
294    }
295}
296
297/// Collect all child scopes
298fn collect_child_scopes(scope: &Scope) -> Vec<&Scope> {
299    let mut children = Vec::new();
300    children.extend(&scope.subquery_scopes);
301    children.extend(&scope.derived_table_scopes);
302    children.extend(&scope.cte_scopes);
303    children.extend(&scope.union_scopes);
304    children
305}
306
307/// Create a default selection when all others are removed
308pub fn default_selection(is_agg: bool) -> Expression {
309    if is_agg {
310        // MAX(1) AS _
311        Expression::Alias(Box::new(Alias {
312            this: Expression::AggregateFunction(Box::new(AggregateFunction {
313                name: "MAX".to_string(),
314                args: vec![Expression::Literal(Literal::Number("1".to_string()))],
315                distinct: false,
316                filter: None,
317                order_by: Vec::new(),
318                limit: None,
319                ignore_nulls: None,
320                inferred_type: None,
321            })),
322            alias: Identifier {
323                name: "_".to_string(),
324                quoted: false,
325                trailing_comments: vec![],
326                span: None,
327            },
328            column_aliases: vec![],
329            pre_alias_comments: vec![],
330            trailing_comments: vec![],
331            inferred_type: None,
332        }))
333    } else {
334        // 1 AS _
335        Expression::Alias(Box::new(Alias {
336            this: Expression::Literal(Literal::Number("1".to_string())),
337            alias: Identifier {
338                name: "_".to_string(),
339                quoted: false,
340                trailing_comments: vec![],
341                span: None,
342            },
343            column_aliases: vec![],
344            pre_alias_comments: vec![],
345            trailing_comments: vec![],
346            inferred_type: None,
347        }))
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    ///
    /// Not named `gen`, which is a reserved keyword from the 2024 edition on.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Output name of the first projection of a `SELECT` statement.
    ///
    /// Panics when `sql` is not a `SELECT` or has no projections, so a test
    /// built on it cannot pass without asserting anything.
    fn first_projection_name(sql: &str) -> String {
        let expr = parse(sql);
        match &expr {
            Expression::Select(select) => {
                let first = select
                    .expressions
                    .first()
                    .expect("SELECT should have at least one projection");
                get_alias_or_name(first)
            }
            _ => panic!("expected a SELECT expression for: {}", sql),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a, x.b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_alias_or_name_alias() {
        assert_eq!(first_projection_name("SELECT a AS col_a FROM t"), "col_a");
    }

    #[test]
    fn test_get_alias_or_name_column() {
        assert_eq!(first_projection_name("SELECT a FROM t"), "a");
    }

    #[test]
    fn test_collect_column_refs() {
        let expr = parse("SELECT t.a, t.b, s.c FROM t, s");
        match &expr {
            Expression::Select(select) => {
                let mut refs: HashMap<String, HashSet<String>> = HashMap::new();
                for sel in &select.expressions {
                    collect_column_refs(sel, &mut refs);
                }
                assert!(refs.contains_key("t"));
                assert!(refs.contains_key("s"));
                assert!(refs.get("t").unwrap().contains("a"));
                assert!(refs.get("t").unwrap().contains("b"));
                assert!(refs.get("s").unwrap().contains("c"));
            }
            _ => panic!("expected a SELECT expression"),
        }
    }

    #[test]
    fn test_default_selection_non_agg() {
        let sel = default_selection(false);
        let sql = generate_sql(&sel);
        assert!(sql.contains("1"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_default_selection_agg() {
        let sel = default_selection(true);
        let sql = generate_sql(&sel);
        assert!(sql.contains("MAX"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_pushdown_with_distinct() {
        let expr = parse("SELECT DISTINCT a FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("DISTINCT"));
    }

    #[test]
    fn test_pushdown_with_star() {
        let expr = parse("SELECT * FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("*"));
    }

    #[test]
    fn test_pushdown_subquery() {
        let expr = parse("SELECT y.a FROM (SELECT a, b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_union() {
        let expr = parse("SELECT a FROM t UNION SELECT a FROM s");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("UNION"));
    }
}