1use std::collections::{HashMap, HashSet};
10
11use crate::dialects::DialectType;
12use crate::expressions::{Alias, AggregateFunction, Expression, Identifier, Literal};
13use crate::scope::{build_scope, traverse_scope, Scope};
14
/// Sentinel entry in a set of parent selections meaning "keep every selection".
///
/// `pushdown_projections` uses a set containing only this value when no
/// referenced columns were recorded for a scope or when the scope is a
/// `SELECT DISTINCT`; `get_selections_to_keep` keeps all selections whenever
/// the set contains it.
const SELECT_ALL: &str = "__SELECT_ALL__";
17
18pub fn pushdown_projections(
37 expression: Expression,
38 _dialect: Option<DialectType>,
39 remove_unused_selections: bool,
40) -> Expression {
41 let _root = build_scope(&expression);
42
43 let mut referenced_columns: HashMap<u64, HashSet<String>> = HashMap::new();
45 let source_column_alias_count: HashMap<u64, usize> = HashMap::new();
46
47 let scopes = traverse_scope(&expression);
49
50 for scope in scopes.iter().rev() {
51 let scope_id = scope as *const Scope as u64;
52 let parent_selections = referenced_columns.get(&scope_id)
53 .cloned()
54 .unwrap_or_else(|| {
55 let mut set = HashSet::new();
56 set.insert(SELECT_ALL.to_string());
57 set
58 });
59
60 let alias_count = source_column_alias_count.get(&scope_id).copied().unwrap_or(0);
61
62 let has_distinct = if let Expression::Select(ref select) = scope.expression {
64 select.distinct || select.distinct_on.is_some()
65 } else {
66 false
67 };
68
69 let parent_selections = if has_distinct {
70 let mut set = HashSet::new();
71 set.insert(SELECT_ALL.to_string());
72 set
73 } else {
74 parent_selections
75 };
76
77 process_set_operations(&scope, &parent_selections, &mut referenced_columns);
79
80 if let Expression::Select(ref select) = scope.expression {
82 if remove_unused_selections {
83 let _selections_to_keep = get_selections_to_keep(
86 select,
87 &parent_selections,
88 alias_count,
89 );
90 }
91
92 let is_star = select.expressions.iter().any(|e| matches!(e, Expression::Star(_)));
94 if is_star {
95 continue;
96 }
97
98 let mut selects: HashMap<String, HashSet<String>> = HashMap::new();
100 for col_expr in &select.expressions {
101 collect_column_refs(col_expr, &mut selects);
102 }
103
104 for source_name in scope.sources.keys() {
106 let columns = selects.get(source_name).cloned().unwrap_or_default();
107
108 for child_scope in collect_child_scopes(&scope) {
110 let child_id = child_scope as *const Scope as u64;
111 referenced_columns
112 .entry(child_id)
113 .or_insert_with(HashSet::new)
114 .extend(columns.clone());
115 }
116 }
117 }
118 }
119
120 expression
123}
124
125fn process_set_operations(
127 scope: &Scope,
128 parent_selections: &HashSet<String>,
129 referenced_columns: &mut HashMap<u64, HashSet<String>>,
130) {
131 match &scope.expression {
132 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => {
133 for child_scope in &scope.union_scopes {
135 let child_id = child_scope as *const Scope as u64;
136 referenced_columns
137 .entry(child_id)
138 .or_insert_with(HashSet::new)
139 .extend(parent_selections.clone());
140 }
141 }
142 _ => {}
143 }
144}
145
146fn get_selections_to_keep(
148 select: &crate::expressions::Select,
149 parent_selections: &HashSet<String>,
150 mut alias_count: usize,
151) -> Vec<usize> {
152 let mut keep_indices = Vec::new();
153 let select_all = parent_selections.contains(SELECT_ALL);
154
155 let order_refs: HashSet<String> = select.order_by
157 .as_ref()
158 .map(|o| get_order_by_column_refs(&o.expressions))
159 .unwrap_or_default();
160
161 for (i, selection) in select.expressions.iter().enumerate() {
162 let name = get_alias_or_name(selection);
163
164 if select_all
165 || parent_selections.contains(&name)
166 || order_refs.contains(&name)
167 || alias_count > 0
168 {
169 keep_indices.push(i);
170 if alias_count > 0 {
171 alias_count -= 1;
172 }
173 }
174 }
175
176 if keep_indices.is_empty() {
178 keep_indices.push(0);
180 }
181
182 keep_indices
183}
184
185fn get_order_by_column_refs(ordered_exprs: &[crate::expressions::Ordered]) -> HashSet<String> {
187 let mut refs = HashSet::new();
188 for ordered in ordered_exprs {
189 collect_unqualified_column_names(&ordered.this, &mut refs);
190 }
191 refs
192}
193
194fn collect_unqualified_column_names(expr: &Expression, names: &mut HashSet<String>) {
196 match expr {
197 Expression::Column(col) => {
198 if col.table.is_none() {
199 names.insert(col.name.name.clone());
200 }
201 }
202 Expression::And(bin) | Expression::Or(bin) => {
203 collect_unqualified_column_names(&bin.left, names);
204 collect_unqualified_column_names(&bin.right, names);
205 }
206 Expression::Function(func) => {
207 for arg in &func.args {
208 collect_unqualified_column_names(arg, names);
209 }
210 }
211 Expression::AggregateFunction(agg) => {
212 for arg in &agg.args {
213 collect_unqualified_column_names(arg, names);
214 }
215 }
216 Expression::Paren(p) => {
217 collect_unqualified_column_names(&p.this, names);
218 }
219 _ => {}
220 }
221}
222
223fn get_alias_or_name(expr: &Expression) -> String {
225 match expr {
226 Expression::Alias(alias) => alias.alias.name.clone(),
227 Expression::Column(col) => col.name.name.clone(),
228 _ => String::new(),
229 }
230}
231
232fn collect_column_refs(expr: &Expression, selects: &mut HashMap<String, HashSet<String>>) {
234 match expr {
235 Expression::Column(col) => {
236 if let Some(ref table) = col.table {
237 selects
238 .entry(table.name.clone())
239 .or_insert_with(HashSet::new)
240 .insert(col.name.name.clone());
241 }
242 }
243 Expression::Alias(alias) => {
244 collect_column_refs(&alias.this, selects);
245 }
246 Expression::Function(func) => {
247 for arg in &func.args {
248 collect_column_refs(arg, selects);
249 }
250 }
251 Expression::AggregateFunction(agg) => {
252 for arg in &agg.args {
253 collect_column_refs(arg, selects);
254 }
255 }
256 Expression::And(bin) | Expression::Or(bin) => {
257 collect_column_refs(&bin.left, selects);
258 collect_column_refs(&bin.right, selects);
259 }
260 Expression::Eq(bin) | Expression::Neq(bin) | Expression::Lt(bin) |
261 Expression::Lte(bin) | Expression::Gt(bin) | Expression::Gte(bin) |
262 Expression::Add(bin) | Expression::Sub(bin) | Expression::Mul(bin) |
263 Expression::Div(bin) => {
264 collect_column_refs(&bin.left, selects);
265 collect_column_refs(&bin.right, selects);
266 }
267 Expression::Paren(p) => {
268 collect_column_refs(&p.this, selects);
269 }
270 Expression::Case(case) => {
271 if let Some(ref operand) = case.operand {
272 collect_column_refs(operand, selects);
273 }
274 for (when, then) in &case.whens {
275 collect_column_refs(when, selects);
276 collect_column_refs(then, selects);
277 }
278 if let Some(ref else_) = case.else_ {
279 collect_column_refs(else_, selects);
280 }
281 }
282 _ => {}
283 }
284}
285
286fn collect_child_scopes(scope: &Scope) -> Vec<&Scope> {
288 let mut children = Vec::new();
289 children.extend(&scope.subquery_scopes);
290 children.extend(&scope.derived_table_scopes);
291 children.extend(&scope.cte_scopes);
292 children.extend(&scope.union_scopes);
293 children
294}
295
296pub fn default_selection(is_agg: bool) -> Expression {
298 if is_agg {
299 Expression::Alias(Box::new(Alias {
301 this: Expression::AggregateFunction(Box::new(AggregateFunction {
302 name: "MAX".to_string(),
303 args: vec![Expression::Literal(Literal::Number("1".to_string()))],
304 distinct: false,
305 filter: None,
306 order_by: Vec::new(),
307 limit: None,
308 ignore_nulls: None,
309 })),
310 alias: Identifier {
311 name: "_".to_string(),
312 quoted: false,
313 trailing_comments: vec![],
314 },
315 column_aliases: vec![],
316 pre_alias_comments: vec![],
317 trailing_comments: vec![],
318 }))
319 } else {
320 Expression::Alias(Box::new(Alias {
322 this: Expression::Literal(Literal::Number("1".to_string())),
323 alias: Identifier {
324 name: "_".to_string(),
325 quoted: false,
326 trailing_comments: vec![],
327 },
328 column_aliases: vec![],
329 pre_alias_comments: vec![],
330 trailing_comments: vec![],
331 }))
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text.
    ///
    /// Deliberately not called `gen`, which is a reserved keyword in the Rust
    /// 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql` and returns its projection list, failing the test when
    /// the statement is not a `SELECT` (so assertions can never be skipped
    /// silently).
    fn projections(sql: &str) -> Vec<Expression> {
        match parse(sql) {
            Expression::Select(select) => select.expressions.clone(),
            _ => panic!("expected a SELECT statement: {}", sql),
        }
    }

    // The pushdown tests below are smoke tests: they check that the pass runs
    // and that the regenerated SQL still contains the expected keyword.

    #[test]
    fn test_pushdown_simple() {
        let expr = parse("SELECT a FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_preserves_structure() {
        let expr = parse("SELECT y.a FROM (SELECT x.a, x.b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_get_alias_or_name_alias() {
        let items = projections("SELECT a AS col_a FROM t");
        let first = items.first().expect("SELECT list should not be empty");
        assert_eq!(get_alias_or_name(first), "col_a");
    }

    #[test]
    fn test_get_alias_or_name_column() {
        let items = projections("SELECT a FROM t");
        let first = items.first().expect("SELECT list should not be empty");
        assert_eq!(get_alias_or_name(first), "a");
    }

    #[test]
    fn test_collect_column_refs() {
        let items = projections("SELECT t.a, t.b, s.c FROM t, s");
        let mut refs: HashMap<String, HashSet<String>> = HashMap::new();
        for sel in &items {
            collect_column_refs(sel, &mut refs);
        }
        assert!(refs.contains_key("t"));
        assert!(refs.contains_key("s"));
        assert!(refs.get("t").unwrap().contains("a"));
        assert!(refs.get("t").unwrap().contains("b"));
        assert!(refs.get("s").unwrap().contains("c"));
    }

    #[test]
    fn test_get_selections_to_keep() {
        let expr = parse("SELECT a, b FROM t");
        let select = match &expr {
            Expression::Select(select) => select,
            _ => panic!("expected a SELECT statement"),
        };

        // Only the selection requested by the parent is kept.
        let only_b: HashSet<String> = std::iter::once("b".to_string()).collect();
        assert_eq!(get_selections_to_keep(select, &only_b, 0), vec![1]);

        // The sentinel keeps everything.
        let all: HashSet<String> = std::iter::once(SELECT_ALL.to_string()).collect();
        assert_eq!(get_selections_to_keep(select, &all, 0), vec![0, 1]);

        // Nothing requested: the first selection is kept as a fallback.
        let none: HashSet<String> = HashSet::new();
        assert_eq!(get_selections_to_keep(select, &none, 0), vec![0]);
    }

    #[test]
    fn test_default_selection_non_agg() {
        let sel = default_selection(false);
        let sql = to_sql(&sel);
        assert!(sql.contains("1"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_default_selection_agg() {
        let sel = default_selection(true);
        let sql = to_sql(&sel);
        assert!(sql.contains("MAX"));
        assert!(sql.contains("AS"));
    }

    #[test]
    fn test_pushdown_with_distinct() {
        let expr = parse("SELECT DISTINCT a FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("DISTINCT"));
    }

    #[test]
    fn test_pushdown_with_star() {
        let expr = parse("SELECT * FROM t");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("*"));
    }

    #[test]
    fn test_pushdown_subquery() {
        let expr = parse("SELECT y.a FROM (SELECT a, b FROM x) AS y");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_pushdown_union() {
        let expr = parse("SELECT a FROM t UNION SELECT a FROM s");
        let result = pushdown_projections(expr, None, true);
        let sql = to_sql(&result);
        assert!(sql.contains("UNION"));
    }
}