use std::{collections::HashSet, sync::Arc};
use arrow::datatypes::Schema;
use datafusion_common::tree_node::TreeNodeContainer;
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
Column, HashMap, Result, TableReference,
};
use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX};
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
use sqlparser::ast::Ident;
/// Removes table qualifiers from every `Union` schema in `plan`.
///
/// A `Sort` whose direct input is a `Union` also has the table qualifiers stripped
/// from the column references in its sort expressions. Those references then agree
/// with the (now unqualified) union output columns.
///
/// The rewrite is applied bottom-up to a clone of `plan`; the input is not modified.
pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> {
    plan.clone()
        .transform_up(|node| match node {
            LogicalPlan::Union(mut union) => {
                let unqualified = Arc::unwrap_or_clone(union.schema).strip_qualifiers();
                union.schema = Arc::new(unqualified);
                Ok(Transformed::yes(LogicalPlan::Union(union)))
            }
            // Only sorts sitting directly on top of a union need their keys rewritten.
            LogicalPlan::Sort(sort) if matches!(&*sort.input, LogicalPlan::Union(_)) => {
                Ok(Transformed::yes(LogicalPlan::Sort(Sort {
                    expr: rewrite_sort_expr_for_union(sort.expr)?,
                    input: sort.input,
                    fetch: sort.fetch,
                })))
            }
            other => Ok(Transformed::no(other)),
        })
        .data()
}
/// Clears the table qualifier (`relation`) of every column referenced anywhere
/// inside the given sort expressions, returning the rewritten sort expressions.
fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
    // Walks one expression tree bottom-up and unqualifies each column it finds.
    let mut strip_qualifiers = |expr: Expr| {
        expr.transform_up(|node| match node {
            Expr::Column(mut col) => {
                col.relation = None;
                Ok(Transformed::yes(Expr::Column(col)))
            }
            other => Ok(Transformed::no(other)),
        })
    };

    exprs.map_elements(&mut strip_qualifiers).data()
}
/// Rewrites `Projection -> Sort -> Projection` so the outer projection's expressions
/// are moved into the inner projection, returning the resulting `Sort` plan.
///
/// The rewrite applies only when two string sets are equal:
/// - the outer projection expressions together with the sort expressions, and
/// - the inner projection's output column references.
///
/// Returns `None` when the plan shape does not match or the sets differ.
pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
    p: &Projection,
) -> Option<LogicalPlan> {
    let LogicalPlan::Sort(sort) = p.input.as_ref() else {
        return None;
    };

    let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else {
        return None;
    };

    // Maps each inner output column reference back to the inner expression producing it.
    let mut map = HashMap::new();
    let inner_exprs = inner_p
        .expr
        .iter()
        .enumerate()
        .map(|(i, f)| match f {
            Expr::Alias(alias) => {
                let a = Expr::Column(alias.name.clone().into());
                map.insert(a.clone(), f.clone());
                a
            }
            Expr::Column(_) => {
                map.insert(
                    Expr::Column(inner_p.schema.field(i).name().into()),
                    f.clone(),
                );
                f.clone()
            }
            _ => {
                let a = Expr::Column(inner_p.schema.field(i).name().into());
                map.insert(a.clone(), f.clone());
                a
            }
        })
        .collect::<Vec<_>>();

    // Stringify outer projection + sort expressions directly from borrowed iterators,
    // instead of deep-cloning every outer expression into a temporary Vec first.
    let outer_collects = p
        .expr
        .iter()
        .chain(sort.expr.iter().map(|sort_expr| &sort_expr.expr))
        .map(Expr::to_string)
        .collect::<HashSet<_>>();
    let inner_collects = inner_exprs
        .iter()
        .map(Expr::to_string)
        .collect::<HashSet<_>>();

    if outer_collects != inner_collects {
        return None;
    }

    let new_exprs = p
        .expr
        .iter()
        .map(|e| map.get(e).unwrap_or(e).clone())
        .collect::<Vec<_>>();

    let mut inner_p = inner_p.clone();
    // `new_exprs` is owned and unused afterwards, so move it in rather than copy it.
    inner_p.expr = new_exprs;

    let mut sort = sort.clone();
    sort.input = Arc::new(LogicalPlan::Projection(inner_p));
    Some(LogicalPlan::Sort(sort))
}
/// Finds the plan under a `SubqueryAlias` together with its column-alias identifiers.
///
/// Returns `(inner_plan, column_aliases)` in two cases:
/// - The input is a `Subquery` wrapping an unnest alias projection. The identifier is
///   that single alias name.
/// - The input is a projection of aliases laid over another projection, reachable
///   through `Limit`/`Distinct`/`Sort`. Each outer alias must wrap the corresponding
///   inner output.
///
/// Otherwise returns the alias input unchanged with an empty alias list.
pub(super) fn subquery_alias_inner_query_and_columns(
    subquery_alias: &datafusion_expr::SubqueryAlias,
) -> (&LogicalPlan, Vec<Ident>) {
    let plan: &LogicalPlan = subquery_alias.input.as_ref();

    if let LogicalPlan::Subquery(subquery) = plan {
        let (inner_projection, Some(column)) =
            find_unnest_column_alias(subquery.subquery.as_ref())
        else {
            return (plan, vec![]);
        };
        return (inner_projection, vec![Ident::new(column)]);
    }

    let LogicalPlan::Projection(outer_projections) = plan else {
        return (plan, vec![]);
    };

    // Check if it's a projection inside a projection.
    let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else {
        return (plan, vec![]);
    };

    // The pattern only holds when every outer expression aliases exactly one inner
    // expression. Without this guard:
    // - a narrower outer projection panics on indexing;
    // - a wider one silently drops its extra expressions.
    if inner_projection.expr.len() != outer_projections.expr.len() {
        return (plan, vec![]);
    }

    // Match a pattern such as:
    //     Projection: j1.j1_id AS id
    //       Projection: j1.j1_id
    let mut columns: Vec<Ident> = Vec::with_capacity(inner_projection.expr.len());
    for (i, (inner_expr, outer_expr)) in inner_projection
        .expr
        .iter()
        .zip(&outer_projections.expr)
        .enumerate()
    {
        let Expr::Alias(outer_alias) = outer_expr else {
            return (plan, vec![]);
        };

        // Non-column inner expressions are referenced by their output field name.
        let inner_expr_string = match inner_expr {
            Expr::Column(_) => inner_expr.to_string(),
            _ => inner_projection.schema.field(i).name().clone(),
        };

        if outer_alias.expr.to_string() != inner_expr_string {
            return (plan, vec![]);
        }

        columns.push(outer_alias.name.as_str().into());
    }

    (outer_projections.input.as_ref(), columns)
}
/// Detects a single-expression projection `<UNNEST_COLUMN_PREFIX>(...) AS name`.
///
/// On a match, returns the projection's input plan and the alias name. Otherwise
/// returns `plan` unchanged with `None`.
pub(super) fn find_unnest_column_alias(
    plan: &LogicalPlan,
) -> (&LogicalPlan, Option<String>) {
    let LogicalPlan::Projection(projection) = plan else {
        return (plan, None);
    };

    // Exactly one expression, and it must be an alias.
    let [Expr::Alias(alias)] = projection.expr.as_slice() else {
        return (plan, None);
    };

    let unnest_prefix = format!("{UNNEST_COLUMN_PREFIX}(");
    if alias.expr.schema_name().to_string().starts_with(&unnest_prefix) {
        (projection.input.as_ref(), Some(alias.name.clone()))
    } else {
        (plan, None)
    }
}
/// Applies `aliases` as column aliases to the projection at the root of `plan`.
///
/// If the root is not a projection, the aliases are instead applied to each
/// immediate child that is a projection. Other children are left untouched.
pub(super) fn inject_column_aliases_into_subquery(
    plan: LogicalPlan,
    aliases: Vec<Ident>,
) -> Result<LogicalPlan> {
    if let LogicalPlan::Projection(projection) = &plan {
        return Ok(inject_column_aliases(projection, aliases));
    }

    plan.map_children(|child| match &child {
        LogicalPlan::Projection(projection) => Ok(Transformed::yes(
            inject_column_aliases(projection, aliases.clone()),
        )),
        _ => Ok(Transformed::no(child)),
    })
    .data()
}
/// Returns a copy of `projection` in which the i-th expression is wrapped in an alias
/// named by the i-th identifier from `aliases`.
///
/// Column expressions keep their table qualifier as the alias relation; other
/// expressions get an unqualified alias.
///
/// If `aliases` yields fewer identifiers than there are expressions, the remaining
/// expressions are kept unaliased. They are not dropped, so the expression list keeps
/// the projection's width. Surplus identifiers are ignored.
pub(super) fn inject_column_aliases(
    projection: &Projection,
    aliases: impl IntoIterator<Item = Ident>,
) -> LogicalPlan {
    let mut updated_projection = projection.clone();
    let mut aliases = aliases.into_iter();

    updated_projection.expr = std::mem::take(&mut updated_projection.expr)
        .into_iter()
        .map(|expr| {
            let Some(col_alias) = aliases.next() else {
                return expr;
            };
            let relation = match &expr {
                Expr::Column(col) => col.relation.clone(),
                _ => None,
            };
            Expr::Alias(Alias {
                // `expr` is owned here; move it rather than cloning the whole tree.
                expr: Box::new(expr),
                relation,
                name: col_alias.value,
                metadata: None,
            })
        })
        .collect();

    LogicalPlan::Projection(updated_projection)
}
/// Finds the nearest `Projection`, descending only through `Limit`, `Distinct` and
/// `Sort` nodes.
///
/// Returns `None` as soon as any other node kind is reached.
fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
    let mut current = logical_plan;
    loop {
        current = match current {
            LogicalPlan::Projection(projection) => return Some(projection),
            LogicalPlan::Limit(limit) => limit.input.as_ref(),
            LogicalPlan::Distinct(distinct) => distinct.input().as_ref(),
            LogicalPlan::Sort(sort) => sort.input.as_ref(),
            _ => return None,
        };
    }
}
/// Expression rewriter that requalifies columns with a table alias.
///
/// Any column whose name matches a field of `table_schema` is replaced by a column
/// qualified with `alias_name`.
pub struct TableAliasRewriter<'a> {
    /// Schema whose field names decide which columns get requalified.
    pub table_schema: &'a Schema,
    /// Qualifier attached to every matching column.
    pub alias_name: TableReference,
}

impl TreeNodeRewriter for TableAliasRewriter<'_> {
    type Node = Expr;

    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        let column = match expr {
            Expr::Column(column) => column,
            other => return Ok(Transformed::no(other)),
        };

        // Matching is by bare name only; the column's existing qualifier is ignored.
        match self.table_schema.field_with_name(&column.name) {
            Ok(field) => {
                let requalified =
                    Column::new(Some(self.alias_name.clone()), field.name().clone());
                Ok(Transformed::yes(Expr::Column(requalified)))
            }
            Err(_) => Ok(Transformed::no(Expr::Column(column))),
        }
    }
}