use std::collections::{HashMap, HashSet};
use crate::dialects::DialectType;
use crate::expressions::{AggregateFunction, Alias, Expression, Identifier, Literal};
use crate::scope::{build_scope, traverse_scope, Scope};
const SELECT_ALL: &str = "__SELECT_ALL__";
pub fn pushdown_projections(
expression: Expression,
_dialect: Option<DialectType>,
remove_unused_selections: bool,
) -> Expression {
let _root = build_scope(&expression);
let mut referenced_columns: HashMap<u64, HashSet<String>> = HashMap::new();
let source_column_alias_count: HashMap<u64, usize> = HashMap::new();
let scopes = traverse_scope(&expression);
for scope in scopes.iter().rev() {
let scope_id = scope as *const Scope as u64;
let parent_selections = referenced_columns
.get(&scope_id)
.cloned()
.unwrap_or_else(|| {
let mut set = HashSet::new();
set.insert(SELECT_ALL.to_string());
set
});
let alias_count = source_column_alias_count
.get(&scope_id)
.copied()
.unwrap_or(0);
let has_distinct = if let Expression::Select(ref select) = scope.expression {
select.distinct || select.distinct_on.is_some()
} else {
false
};
let parent_selections = if has_distinct {
let mut set = HashSet::new();
set.insert(SELECT_ALL.to_string());
set
} else {
parent_selections
};
process_set_operations(&scope, &parent_selections, &mut referenced_columns);
if let Expression::Select(ref select) = scope.expression {
if remove_unused_selections {
let _selections_to_keep =
get_selections_to_keep(select, &parent_selections, alias_count);
}
let is_star = select
.expressions
.iter()
.any(|e| matches!(e, Expression::Star(_)));
if is_star {
continue;
}
let mut selects: HashMap<String, HashSet<String>> = HashMap::new();
for col_expr in &select.expressions {
collect_column_refs(col_expr, &mut selects);
}
for source_name in scope.sources.keys() {
let columns = selects.get(source_name).cloned().unwrap_or_default();
for child_scope in collect_child_scopes(&scope) {
let child_id = child_scope as *const Scope as u64;
referenced_columns
.entry(child_id)
.or_insert_with(HashSet::new)
.extend(columns.clone());
}
}
}
}
expression
}
fn process_set_operations(
scope: &Scope,
parent_selections: &HashSet<String>,
referenced_columns: &mut HashMap<u64, HashSet<String>>,
) {
match &scope.expression {
Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => {
for child_scope in &scope.union_scopes {
let child_id = child_scope as *const Scope as u64;
referenced_columns
.entry(child_id)
.or_insert_with(HashSet::new)
.extend(parent_selections.clone());
}
}
_ => {}
}
}
fn get_selections_to_keep(
select: &crate::expressions::Select,
parent_selections: &HashSet<String>,
mut alias_count: usize,
) -> Vec<usize> {
let mut keep_indices = Vec::new();
let select_all = parent_selections.contains(SELECT_ALL);
let order_refs: HashSet<String> = select
.order_by
.as_ref()
.map(|o| get_order_by_column_refs(&o.expressions))
.unwrap_or_default();
for (i, selection) in select.expressions.iter().enumerate() {
let name = get_alias_or_name(selection);
if select_all
|| parent_selections.contains(&name)
|| order_refs.contains(&name)
|| alias_count > 0
{
keep_indices.push(i);
if alias_count > 0 {
alias_count -= 1;
}
}
}
if keep_indices.is_empty() {
keep_indices.push(0);
}
keep_indices
}
fn get_order_by_column_refs(ordered_exprs: &[crate::expressions::Ordered]) -> HashSet<String> {
let mut refs = HashSet::new();
for ordered in ordered_exprs {
collect_unqualified_column_names(&ordered.this, &mut refs);
}
refs
}
fn collect_unqualified_column_names(expr: &Expression, names: &mut HashSet<String>) {
match expr {
Expression::Column(col) => {
if col.table.is_none() {
names.insert(col.name.name.clone());
}
}
Expression::And(bin) | Expression::Or(bin) => {
collect_unqualified_column_names(&bin.left, names);
collect_unqualified_column_names(&bin.right, names);
}
Expression::Function(func) => {
for arg in &func.args {
collect_unqualified_column_names(arg, names);
}
}
Expression::AggregateFunction(agg) => {
for arg in &agg.args {
collect_unqualified_column_names(arg, names);
}
}
Expression::Paren(p) => {
collect_unqualified_column_names(&p.this, names);
}
_ => {}
}
}
fn get_alias_or_name(expr: &Expression) -> String {
match expr {
Expression::Alias(alias) => alias.alias.name.clone(),
Expression::Column(col) => col.name.name.clone(),
_ => String::new(),
}
}
fn collect_column_refs(expr: &Expression, selects: &mut HashMap<String, HashSet<String>>) {
match expr {
Expression::Column(col) => {
if let Some(ref table) = col.table {
selects
.entry(table.name.clone())
.or_insert_with(HashSet::new)
.insert(col.name.name.clone());
}
}
Expression::Alias(alias) => {
collect_column_refs(&alias.this, selects);
}
Expression::Function(func) => {
for arg in &func.args {
collect_column_refs(arg, selects);
}
}
Expression::AggregateFunction(agg) => {
for arg in &agg.args {
collect_column_refs(arg, selects);
}
}
Expression::And(bin) | Expression::Or(bin) => {
collect_column_refs(&bin.left, selects);
collect_column_refs(&bin.right, selects);
}
Expression::Eq(bin)
| Expression::Neq(bin)
| Expression::Lt(bin)
| Expression::Lte(bin)
| Expression::Gt(bin)
| Expression::Gte(bin)
| Expression::Add(bin)
| Expression::Sub(bin)
| Expression::Mul(bin)
| Expression::Div(bin) => {
collect_column_refs(&bin.left, selects);
collect_column_refs(&bin.right, selects);
}
Expression::Paren(p) => {
collect_column_refs(&p.this, selects);
}
Expression::Case(case) => {
if let Some(ref operand) = case.operand {
collect_column_refs(operand, selects);
}
for (when, then) in &case.whens {
collect_column_refs(when, selects);
collect_column_refs(then, selects);
}
if let Some(ref else_) = case.else_ {
collect_column_refs(else_, selects);
}
}
_ => {}
}
}
fn collect_child_scopes(scope: &Scope) -> Vec<&Scope> {
let mut children = Vec::new();
children.extend(&scope.subquery_scopes);
children.extend(&scope.derived_table_scopes);
children.extend(&scope.cte_scopes);
children.extend(&scope.union_scopes);
children
}
pub fn default_selection(is_agg: bool) -> Expression {
if is_agg {
Expression::Alias(Box::new(Alias {
this: Expression::AggregateFunction(Box::new(AggregateFunction {
name: "MAX".to_string(),
args: vec![Expression::Literal(Box::new(Literal::Number(
"1".to_string(),
)))],
distinct: false,
filter: None,
order_by: Vec::new(),
limit: None,
ignore_nulls: None,
inferred_type: None,
})),
alias: Identifier {
name: "_".to_string(),
quoted: false,
trailing_comments: vec![],
span: None,
},
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))
} else {
Expression::Alias(Box::new(Alias {
this: Expression::Literal(Box::new(Literal::Number("1".to_string()))),
alias: Identifier {
name: "_".to_string(),
quoted: false,
trailing_comments: vec![],
span: None,
},
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::generator::Generator;
use crate::parser::Parser;
fn gen(expr: &Expression) -> String {
Generator::new().generate(expr).unwrap()
}
fn parse(sql: &str) -> Expression {
Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
}
#[test]
fn test_pushdown_simple() {
let expr = parse("SELECT a FROM t");
let result = pushdown_projections(expr, None, true);
let sql = gen(&result);
assert!(sql.contains("SELECT"));
}
#[test]
fn test_pushdown_preserves_structure() {
let expr = parse("SELECT y.a FROM (SELECT x.a, x.b FROM x) AS y");
let result = pushdown_projections(expr, None, true);
let sql = gen(&result);
assert!(sql.contains("SELECT"));
}
#[test]
fn test_get_alias_or_name_alias() {
let expr = parse("SELECT a AS col_a FROM t");
if let Expression::Select(select) = &expr {
if let Some(first) = select.expressions.first() {
let name = get_alias_or_name(first);
assert_eq!(name, "col_a");
}
}
}
#[test]
fn test_get_alias_or_name_column() {
let expr = parse("SELECT a FROM t");
if let Expression::Select(select) = &expr {
if let Some(first) = select.expressions.first() {
let name = get_alias_or_name(first);
assert_eq!(name, "a");
}
}
}
#[test]
fn test_collect_column_refs() {
let expr = parse("SELECT t.a, t.b, s.c FROM t, s");
if let Expression::Select(select) = &expr {
let mut refs: HashMap<String, HashSet<String>> = HashMap::new();
for sel in &select.expressions {
collect_column_refs(sel, &mut refs);
}
assert!(refs.contains_key("t"));
assert!(refs.contains_key("s"));
assert!(refs.get("t").unwrap().contains("a"));
assert!(refs.get("t").unwrap().contains("b"));
assert!(refs.get("s").unwrap().contains("c"));
}
}
#[test]
fn test_default_selection_non_agg() {
let sel = default_selection(false);
let sql = gen(&sel);
assert!(sql.contains("1"));
assert!(sql.contains("AS"));
}
#[test]
fn test_default_selection_agg() {
let sel = default_selection(true);
let sql = gen(&sel);
assert!(sql.contains("MAX"));
assert!(sql.contains("AS"));
}
#[test]
fn test_pushdown_with_distinct() {
let expr = parse("SELECT DISTINCT a FROM t");
let result = pushdown_projections(expr, None, true);
let sql = gen(&result);
assert!(sql.contains("DISTINCT"));
}
#[test]
fn test_pushdown_with_star() {
let expr = parse("SELECT * FROM t");
let result = pushdown_projections(expr, None, true);
let sql = gen(&result);
assert!(sql.contains("*"));
}
#[test]
fn test_pushdown_subquery() {
let expr = parse("SELECT y.a FROM (SELECT a, b FROM x) AS y");
let result = pushdown_projections(expr, None, true);
let sql = gen(&result);
assert!(sql.contains("SELECT"));
}
#[test]
fn test_pushdown_union() {
let expr = parse("SELECT a FROM t UNION SELECT a FROM s");
let result = pushdown_projections(expr, None, true);
let sql = gen(&result);
assert!(sql.contains("UNION"));
}
}