1use std::collections::{HashMap, HashSet};
10
11use crate::dialects::DialectType;
12use crate::expressions::{AggregateFunction, Alias, Expression, Identifier, Literal};
13use crate::scope::{build_scope, traverse_scope, Scope};
14
/// Sentinel entry placed in a "parent selections" set to mean that every
/// projection of a scope has to be kept.
///
/// `pushdown_projections` inserts it when no column set has been recorded for
/// a scope or when the scope's SELECT is DISTINCT, and
/// `get_selections_to_keep` checks for it before filtering projections.
const SELECT_ALL: &str = "__SELECT_ALL__";
17
18pub fn pushdown_projections(
37 expression: Expression,
38 _dialect: Option<DialectType>,
39 remove_unused_selections: bool,
40) -> Expression {
41 let _root = build_scope(&expression);
42
43 let mut referenced_columns: HashMap<u64, HashSet<String>> = HashMap::new();
45 let source_column_alias_count: HashMap<u64, usize> = HashMap::new();
46
47 let scopes = traverse_scope(&expression);
49
50 for scope in scopes.iter().rev() {
51 let scope_id = scope as *const Scope as u64;
52 let parent_selections = referenced_columns
53 .get(&scope_id)
54 .cloned()
55 .unwrap_or_else(|| {
56 let mut set = HashSet::new();
57 set.insert(SELECT_ALL.to_string());
58 set
59 });
60
61 let alias_count = source_column_alias_count
62 .get(&scope_id)
63 .copied()
64 .unwrap_or(0);
65
66 let has_distinct = if let Expression::Select(ref select) = scope.expression {
68 select.distinct || select.distinct_on.is_some()
69 } else {
70 false
71 };
72
73 let parent_selections = if has_distinct {
74 let mut set = HashSet::new();
75 set.insert(SELECT_ALL.to_string());
76 set
77 } else {
78 parent_selections
79 };
80
81 process_set_operations(&scope, &parent_selections, &mut referenced_columns);
83
84 if let Expression::Select(ref select) = scope.expression {
86 if remove_unused_selections {
87 let _selections_to_keep =
90 get_selections_to_keep(select, &parent_selections, alias_count);
91 }
92
93 let is_star = select
95 .expressions
96 .iter()
97 .any(|e| matches!(e, Expression::Star(_)));
98 if is_star {
99 continue;
100 }
101
102 let mut selects: HashMap<String, HashSet<String>> = HashMap::new();
104 for col_expr in &select.expressions {
105 collect_column_refs(col_expr, &mut selects);
106 }
107
108 for source_name in scope.sources.keys() {
110 let columns = selects.get(source_name).cloned().unwrap_or_default();
111
112 for child_scope in collect_child_scopes(&scope) {
114 let child_id = child_scope as *const Scope as u64;
115 referenced_columns
116 .entry(child_id)
117 .or_insert_with(HashSet::new)
118 .extend(columns.clone());
119 }
120 }
121 }
122 }
123
124 expression
127}
128
129fn process_set_operations(
131 scope: &Scope,
132 parent_selections: &HashSet<String>,
133 referenced_columns: &mut HashMap<u64, HashSet<String>>,
134) {
135 match &scope.expression {
136 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => {
137 for child_scope in &scope.union_scopes {
139 let child_id = child_scope as *const Scope as u64;
140 referenced_columns
141 .entry(child_id)
142 .or_insert_with(HashSet::new)
143 .extend(parent_selections.clone());
144 }
145 }
146 _ => {}
147 }
148}
149
150fn get_selections_to_keep(
152 select: &crate::expressions::Select,
153 parent_selections: &HashSet<String>,
154 mut alias_count: usize,
155) -> Vec<usize> {
156 let mut keep_indices = Vec::new();
157 let select_all = parent_selections.contains(SELECT_ALL);
158
159 let order_refs: HashSet<String> = select
161 .order_by
162 .as_ref()
163 .map(|o| get_order_by_column_refs(&o.expressions))
164 .unwrap_or_default();
165
166 for (i, selection) in select.expressions.iter().enumerate() {
167 let name = get_alias_or_name(selection);
168
169 if select_all
170 || parent_selections.contains(&name)
171 || order_refs.contains(&name)
172 || alias_count > 0
173 {
174 keep_indices.push(i);
175 if alias_count > 0 {
176 alias_count -= 1;
177 }
178 }
179 }
180
181 if keep_indices.is_empty() {
183 keep_indices.push(0);
185 }
186
187 keep_indices
188}
189
190fn get_order_by_column_refs(ordered_exprs: &[crate::expressions::Ordered]) -> HashSet<String> {
192 let mut refs = HashSet::new();
193 for ordered in ordered_exprs {
194 collect_unqualified_column_names(&ordered.this, &mut refs);
195 }
196 refs
197}
198
199fn collect_unqualified_column_names(expr: &Expression, names: &mut HashSet<String>) {
201 match expr {
202 Expression::Column(col) => {
203 if col.table.is_none() {
204 names.insert(col.name.name.clone());
205 }
206 }
207 Expression::And(bin) | Expression::Or(bin) => {
208 collect_unqualified_column_names(&bin.left, names);
209 collect_unqualified_column_names(&bin.right, names);
210 }
211 Expression::Function(func) => {
212 for arg in &func.args {
213 collect_unqualified_column_names(arg, names);
214 }
215 }
216 Expression::AggregateFunction(agg) => {
217 for arg in &agg.args {
218 collect_unqualified_column_names(arg, names);
219 }
220 }
221 Expression::Paren(p) => {
222 collect_unqualified_column_names(&p.this, names);
223 }
224 _ => {}
225 }
226}
227
228fn get_alias_or_name(expr: &Expression) -> String {
230 match expr {
231 Expression::Alias(alias) => alias.alias.name.clone(),
232 Expression::Column(col) => col.name.name.clone(),
233 _ => String::new(),
234 }
235}
236
237fn collect_column_refs(expr: &Expression, selects: &mut HashMap<String, HashSet<String>>) {
239 match expr {
240 Expression::Column(col) => {
241 if let Some(ref table) = col.table {
242 selects
243 .entry(table.name.clone())
244 .or_insert_with(HashSet::new)
245 .insert(col.name.name.clone());
246 }
247 }
248 Expression::Alias(alias) => {
249 collect_column_refs(&alias.this, selects);
250 }
251 Expression::Function(func) => {
252 for arg in &func.args {
253 collect_column_refs(arg, selects);
254 }
255 }
256 Expression::AggregateFunction(agg) => {
257 for arg in &agg.args {
258 collect_column_refs(arg, selects);
259 }
260 }
261 Expression::And(bin) | Expression::Or(bin) => {
262 collect_column_refs(&bin.left, selects);
263 collect_column_refs(&bin.right, selects);
264 }
265 Expression::Eq(bin)
266 | Expression::Neq(bin)
267 | Expression::Lt(bin)
268 | Expression::Lte(bin)
269 | Expression::Gt(bin)
270 | Expression::Gte(bin)
271 | Expression::Add(bin)
272 | Expression::Sub(bin)
273 | Expression::Mul(bin)
274 | Expression::Div(bin) => {
275 collect_column_refs(&bin.left, selects);
276 collect_column_refs(&bin.right, selects);
277 }
278 Expression::Paren(p) => {
279 collect_column_refs(&p.this, selects);
280 }
281 Expression::Case(case) => {
282 if let Some(ref operand) = case.operand {
283 collect_column_refs(operand, selects);
284 }
285 for (when, then) in &case.whens {
286 collect_column_refs(when, selects);
287 collect_column_refs(then, selects);
288 }
289 if let Some(ref else_) = case.else_ {
290 collect_column_refs(else_, selects);
291 }
292 }
293 _ => {}
294 }
295}
296
297fn collect_child_scopes(scope: &Scope) -> Vec<&Scope> {
299 let mut children = Vec::new();
300 children.extend(&scope.subquery_scopes);
301 children.extend(&scope.derived_table_scopes);
302 children.extend(&scope.cte_scopes);
303 children.extend(&scope.union_scopes);
304 children
305}
306
307pub fn default_selection(is_agg: bool) -> Expression {
309 if is_agg {
310 Expression::Alias(Box::new(Alias {
312 this: Expression::AggregateFunction(Box::new(AggregateFunction {
313 name: "MAX".to_string(),
314 args: vec![Expression::Literal(Literal::Number("1".to_string()))],
315 distinct: false,
316 filter: None,
317 order_by: Vec::new(),
318 limit: None,
319 ignore_nulls: None,
320 inferred_type: None,
321 })),
322 alias: Identifier {
323 name: "_".to_string(),
324 quoted: false,
325 trailing_comments: vec![],
326 span: None,
327 },
328 column_aliases: vec![],
329 pre_alias_comments: vec![],
330 trailing_comments: vec![],
331 inferred_type: None,
332 }))
333 } else {
334 Expression::Alias(Box::new(Alias {
336 this: Expression::Literal(Literal::Number("1".to_string())),
337 alias: Identifier {
338 name: "_".to_string(),
339 quoted: false,
340 trailing_comments: vec![],
341 span: None,
342 },
343 column_aliases: vec![],
344 pre_alias_comments: vec![],
345 trailing_comments: vec![],
346 inferred_type: None,
347 }))
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword in
    /// the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Returns the first projection of a SELECT, failing the test loudly when
    /// `expr` is not a SELECT or has no projections (instead of silently
    /// skipping the assertions that follow).
    fn first_projection(expr: &Expression) -> &Expression {
        match expr {
            Expression::Select(select) => select
                .expressions
                .first()
                .expect("SELECT has no projections"),
            _ => panic!("expected a SELECT expression"),
        }
    }

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a, x.b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_alias_or_name_alias() {
        let expr = parse("SELECT a AS col_a FROM t");
        let name = get_alias_or_name(first_projection(&expr));
        assert_eq!(name, "col_a");
    }

    #[test]
    fn test_get_alias_or_name_column() {
        let expr = parse("SELECT a FROM t");
        let name = get_alias_or_name(first_projection(&expr));
        assert_eq!(name, "a");
    }

    #[test]
    fn test_collect_column_refs() {
        let expr = parse("SELECT t.a, t.b, s.c FROM t, s");
        // Fail rather than pass vacuously if the parser returns something else.
        let select = match &expr {
            Expression::Select(select) => select,
            _ => panic!("expected a SELECT expression"),
        };

        let mut refs: HashMap<String, HashSet<String>> = HashMap::new();
        for sel in &select.expressions {
            collect_column_refs(sel, &mut refs);
        }

        assert!(refs.contains_key("t"));
        assert!(refs.contains_key("s"));
        assert!(refs.get("t").unwrap().contains("a"));
        assert!(refs.get("t").unwrap().contains("b"));
        assert!(refs.get("s").unwrap().contains("c"));
    }

    #[test]
    fn test_default_selection_non_agg() {
        let sel = default_selection(false);
        let sql = to_sql(&sel);
        assert!(sql.contains("1"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_default_selection_agg() {
        let sel = default_selection(true);
        let sql = to_sql(&sel);
        assert!(sql.contains("MAX"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_pushdown_with_distinct() {
        let expr = parse("SELECT DISTINCT a FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("DISTINCT"));
    }

    #[test]
    fn test_pushdown_with_star() {
        let expr = parse("SELECT * FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("*"));
    }

    #[test]
    fn test_pushdown_subquery() {
        let expr = parse("SELECT y.a FROM (SELECT a, b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_union() {
        let expr = parse("SELECT a FROM t UNION SELECT a FROM s");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("UNION"));
    }
}