Skip to main content

polyglot_sql/optimizer/
pushdown_projections.rs

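//! Projection Pushdown Module
//!
//! This module provides functionality for removing unused column projections
//! from SQL queries. When a subquery selects columns that are never used by
//! the outer query, those columns can be eliminated to reduce data processing.
//!
//! Status: the port is incomplete. [`pushdown_projections`] works out which
//! projections each scope needs, but does not yet rewrite the tree — it returns
//! its input expression unchanged.
//!
//! Ported from sqlglot's optimizer/pushdown_projections.py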
8
9use std::collections::{HashMap, HashSet};
10
11use crate::dialects::DialectType;
12use crate::expressions::{Alias, AggregateFunction, Expression, Identifier, Literal};
13use crate::scope::{build_scope, traverse_scope, Scope};
14
15/// Sentinel value indicating all columns are selected
16const SELECT_ALL: &str = "__SELECT_ALL__";
17
18/// Rewrite SQL AST to remove unused column projections.
19///
20/// # Example
21///
22/// ```sql
23/// -- Before:
24/// SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y
25/// -- After:
26/// SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y
27/// ```
28///
29/// # Arguments
30/// * `expression` - The expression to optimize
31/// * `dialect` - Optional dialect for dialect-specific behavior
32/// * `remove_unused_selections` - Whether to actually remove unused selections
33///
34/// # Returns
35/// The optimized expression with unused projections removed
36pub fn pushdown_projections(
37    expression: Expression,
38    _dialect: Option<DialectType>,
39    remove_unused_selections: bool,
40) -> Expression {
41    let _root = build_scope(&expression);
42
43    // Map of scope to columns being selected by outer queries
44    let mut referenced_columns: HashMap<u64, HashSet<String>> = HashMap::new();
45    let source_column_alias_count: HashMap<u64, usize> = HashMap::new();
46
47    // Collect all scopes and process in reverse order (bottom-up)
48    let scopes = traverse_scope(&expression);
49
50    for scope in scopes.iter().rev() {
51        let scope_id = scope as *const Scope as u64;
52        let parent_selections = referenced_columns.get(&scope_id)
53            .cloned()
54            .unwrap_or_else(|| {
55                let mut set = HashSet::new();
56                set.insert(SELECT_ALL.to_string());
57                set
58            });
59
60        let alias_count = source_column_alias_count.get(&scope_id).copied().unwrap_or(0);
61
62        // Check for DISTINCT - can't optimize if present
63        let has_distinct = if let Expression::Select(ref select) = scope.expression {
64            select.distinct || select.distinct_on.is_some()
65        } else {
66            false
67        };
68
69        let parent_selections = if has_distinct {
70            let mut set = HashSet::new();
71            set.insert(SELECT_ALL.to_string());
72            set
73        } else {
74            parent_selections
75        };
76
77        // Handle set operations (UNION, INTERSECT, EXCEPT)
78        process_set_operations(&scope, &parent_selections, &mut referenced_columns);
79
80        // Handle SELECT statements
81        if let Expression::Select(ref select) = scope.expression {
82            if remove_unused_selections {
83                // Note: actual removal would require mutable access to expression
84                // For now, we just track what would be removed
85                let _selections_to_keep = get_selections_to_keep(
86                    select,
87                    &parent_selections,
88                    alias_count,
89                );
90            }
91
92            // Check if SELECT *
93            let is_star = select.expressions.iter().any(|e| matches!(e, Expression::Star(_)));
94            if is_star {
95                continue;
96            }
97
98            // Group columns by source name
99            let mut selects: HashMap<String, HashSet<String>> = HashMap::new();
100            for col_expr in &select.expressions {
101                collect_column_refs(col_expr, &mut selects);
102            }
103
104            // Push selected columns down to child scopes
105            for source_name in scope.sources.keys() {
106                let columns = selects.get(source_name).cloned().unwrap_or_default();
107
108                // Find the child scope for this source
109                for child_scope in collect_child_scopes(&scope) {
110                    let child_id = child_scope as *const Scope as u64;
111                    referenced_columns
112                        .entry(child_id)
113                        .or_insert_with(HashSet::new)
114                        .extend(columns.clone());
115                }
116            }
117        }
118    }
119
120    // In a full implementation, we would modify the expression tree
121    // For now, return unchanged
122    expression
123}
124
125/// Process set operations (UNION, INTERSECT, EXCEPT)
126fn process_set_operations(
127    scope: &Scope,
128    parent_selections: &HashSet<String>,
129    referenced_columns: &mut HashMap<u64, HashSet<String>>,
130) {
131    match &scope.expression {
132        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => {
133            // Propagate parent selections to both sides of set operation
134            for child_scope in &scope.union_scopes {
135                let child_id = child_scope as *const Scope as u64;
136                referenced_columns
137                    .entry(child_id)
138                    .or_insert_with(HashSet::new)
139                    .extend(parent_selections.clone());
140            }
141        }
142        _ => {}
143    }
144}
145
146/// Get the list of selections that should be kept
147fn get_selections_to_keep(
148    select: &crate::expressions::Select,
149    parent_selections: &HashSet<String>,
150    mut alias_count: usize,
151) -> Vec<usize> {
152    let mut keep_indices = Vec::new();
153    let select_all = parent_selections.contains(SELECT_ALL);
154
155    // Get ORDER BY column references (unqualified columns)
156    let order_refs: HashSet<String> = select.order_by
157        .as_ref()
158        .map(|o| get_order_by_column_refs(&o.expressions))
159        .unwrap_or_default();
160
161    for (i, selection) in select.expressions.iter().enumerate() {
162        let name = get_alias_or_name(selection);
163
164        if select_all
165            || parent_selections.contains(&name)
166            || order_refs.contains(&name)
167            || alias_count > 0
168        {
169            keep_indices.push(i);
170            if alias_count > 0 {
171                alias_count -= 1;
172            }
173        }
174    }
175
176    // If no selections remain, we need at least one
177    if keep_indices.is_empty() {
178        // Would add a default selection like "1 AS _"
179        keep_indices.push(0);
180    }
181
182    keep_indices
183}
184
185/// Get column references from ORDER BY expressions
186fn get_order_by_column_refs(ordered_exprs: &[crate::expressions::Ordered]) -> HashSet<String> {
187    let mut refs = HashSet::new();
188    for ordered in ordered_exprs {
189        collect_unqualified_column_names(&ordered.this, &mut refs);
190    }
191    refs
192}
193
194/// Collect unqualified column names from an expression
195fn collect_unqualified_column_names(expr: &Expression, names: &mut HashSet<String>) {
196    match expr {
197        Expression::Column(col) => {
198            if col.table.is_none() {
199                names.insert(col.name.name.clone());
200            }
201        }
202        Expression::And(bin) | Expression::Or(bin) => {
203            collect_unqualified_column_names(&bin.left, names);
204            collect_unqualified_column_names(&bin.right, names);
205        }
206        Expression::Function(func) => {
207            for arg in &func.args {
208                collect_unqualified_column_names(arg, names);
209            }
210        }
211        Expression::AggregateFunction(agg) => {
212            for arg in &agg.args {
213                collect_unqualified_column_names(arg, names);
214            }
215        }
216        Expression::Paren(p) => {
217            collect_unqualified_column_names(&p.this, names);
218        }
219        _ => {}
220    }
221}
222
223/// Get the alias or name from a selection expression
224fn get_alias_or_name(expr: &Expression) -> String {
225    match expr {
226        Expression::Alias(alias) => alias.alias.name.clone(),
227        Expression::Column(col) => col.name.name.clone(),
228        _ => String::new(),
229    }
230}
231
232/// Collect column references grouped by table name
233fn collect_column_refs(expr: &Expression, selects: &mut HashMap<String, HashSet<String>>) {
234    match expr {
235        Expression::Column(col) => {
236            if let Some(ref table) = col.table {
237                selects
238                    .entry(table.name.clone())
239                    .or_insert_with(HashSet::new)
240                    .insert(col.name.name.clone());
241            }
242        }
243        Expression::Alias(alias) => {
244            collect_column_refs(&alias.this, selects);
245        }
246        Expression::Function(func) => {
247            for arg in &func.args {
248                collect_column_refs(arg, selects);
249            }
250        }
251        Expression::AggregateFunction(agg) => {
252            for arg in &agg.args {
253                collect_column_refs(arg, selects);
254            }
255        }
256        Expression::And(bin) | Expression::Or(bin) => {
257            collect_column_refs(&bin.left, selects);
258            collect_column_refs(&bin.right, selects);
259        }
260        Expression::Eq(bin) | Expression::Neq(bin) | Expression::Lt(bin) |
261        Expression::Lte(bin) | Expression::Gt(bin) | Expression::Gte(bin) |
262        Expression::Add(bin) | Expression::Sub(bin) | Expression::Mul(bin) |
263        Expression::Div(bin) => {
264            collect_column_refs(&bin.left, selects);
265            collect_column_refs(&bin.right, selects);
266        }
267        Expression::Paren(p) => {
268            collect_column_refs(&p.this, selects);
269        }
270        Expression::Case(case) => {
271            if let Some(ref operand) = case.operand {
272                collect_column_refs(operand, selects);
273            }
274            for (when, then) in &case.whens {
275                collect_column_refs(when, selects);
276                collect_column_refs(then, selects);
277            }
278            if let Some(ref else_) = case.else_ {
279                collect_column_refs(else_, selects);
280            }
281        }
282        _ => {}
283    }
284}
285
286/// Collect all child scopes
287fn collect_child_scopes(scope: &Scope) -> Vec<&Scope> {
288    let mut children = Vec::new();
289    children.extend(&scope.subquery_scopes);
290    children.extend(&scope.derived_table_scopes);
291    children.extend(&scope.cte_scopes);
292    children.extend(&scope.union_scopes);
293    children
294}
295
296/// Create a default selection when all others are removed
297pub fn default_selection(is_agg: bool) -> Expression {
298    if is_agg {
299        // MAX(1) AS _
300        Expression::Alias(Box::new(Alias {
301            this: Expression::AggregateFunction(Box::new(AggregateFunction {
302                name: "MAX".to_string(),
303                args: vec![Expression::Literal(Literal::Number("1".to_string()))],
304                distinct: false,
305                filter: None,
306                order_by: Vec::new(),
307                limit: None,
308                ignore_nulls: None,
309            })),
310            alias: Identifier {
311                name: "_".to_string(),
312                quoted: false,
313                trailing_comments: vec![],
314            },
315            column_aliases: vec![],
316            pre_alias_comments: vec![],
317            trailing_comments: vec![],
318        }))
319    } else {
320        // 1 AS _
321        Expression::Alias(Box::new(Alias {
322            this: Expression::Literal(Literal::Number("1".to_string())),
323            alias: Identifier {
324                name: "_".to_string(),
325                quoted: false,
326                trailing_comments: vec![],
327            },
328            column_aliases: vec![],
329            pre_alias_comments: vec![],
330            trailing_comments: vec![],
331        }))
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render `expr` back to SQL with the default generator.
    ///
    /// Named `generate_sql` rather than `gen`, which is a reserved keyword from
    /// the Rust 2024 edition onwards.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("failed to generate SQL")
    }

    /// Parse `sql` and return its first statement, failing loudly if there is none.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .first()
            .cloned()
            .expect("no statement parsed")
    }

    /// Run pushdown on a query whose only SELECTs are top-level (or set-operation
    /// branches). The root scope always needs every column, so the output must
    /// render identically to the input. Returns the rendered SQL.
    fn assert_unchanged(sql: &str) -> String {
        let expected = generate_sql(&parse(sql));
        let actual = generate_sql(&pushdown_projections(parse(sql), None, true));
        assert_eq!(actual, expected);
        actual
    }

    /// Name of the first projection of `sql`. Panics (instead of silently
    /// passing) when the statement is not a SELECT or has no projections.
    fn first_projection_name(sql: &str) -> String {
        let expr = parse(sql);
        match &expr {
            Expression::Select(select) => {
                let first = select.expressions.first().expect("empty select list");
                get_alias_or_name(first)
            }
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let sql = assert_unchanged("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a, x.b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_alias_or_name_alias() {
        assert_eq!(first_projection_name("SELECT a AS col_a FROM t"), "col_a");
    }

    #[test]
    fn test_get_alias_or_name_column() {
        assert_eq!(first_projection_name("SELECT a FROM t"), "a");
    }

    #[test]
    fn test_collect_column_refs() {
        let expr = parse("SELECT t.a, t.b, s.c FROM t, s");
        let select = match &expr {
            Expression::Select(select) => select,
            _ => panic!("expected a SELECT statement"),
        };

        let mut refs: HashMap<String, HashSet<String>> = HashMap::new();
        for sel in &select.expressions {
            collect_column_refs(sel, &mut refs);
        }

        assert_eq!(refs.len(), 2);
        assert!(refs["t"].contains("a"));
        assert!(refs["t"].contains("b"));
        assert!(refs["s"].contains("c"));
    }

    #[test]
    fn test_default_selection_non_agg() {
        let sel = default_selection(false);
        let sql = generate_sql(&sel);
        assert!(sql.contains("1"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_default_selection_agg() {
        let sel = default_selection(true);
        let sql = generate_sql(&sel);
        assert!(sql.contains("MAX"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_pushdown_with_distinct() {
        let sql = assert_unchanged("SELECT DISTINCT a FROM t");
        assert!(sql.contains("DISTINCT"));
    }

    #[test]
    fn test_pushdown_with_star() {
        let sql = assert_unchanged("SELECT * FROM t");
        assert!(sql.contains("*"));
    }

    #[test]
    fn test_pushdown_subquery() {
        let expr = parse("SELECT y.a FROM (SELECT a, b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_union() {
        let sql = assert_unchanged("SELECT a FROM t UNION SELECT a FROM s");
        assert!(sql.contains("UNION"));
    }
}