use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan};
/// Searches downward from `plan` for the [`Aggregate`] node it reads from.
///
/// Only the inputs of `plan` are inspected, never `plan` itself. The walk
/// follows a chain of single-input nodes:
///
/// * a node with zero inputs, or more than one input, ends the search with `None`;
/// * an `Aggregate` child is returned immediately;
/// * a `TableScan` child ends the search with `None`;
/// * a `Projection` child is descended into only if `already_projected` is
///   `false`; the recursion then continues with `already_projected = true`, so a
///   further `Projection` ends the search;
/// * any other child is descended into with `already_projected` unchanged.
pub(crate) fn find_agg_node_within_select(
    plan: &LogicalPlan,
    already_projected: bool,
) -> Option<&Aggregate> {
    let children = plan.inputs();
    // Continue only when there is exactly one child. Zero or several children
    // (joins, unions, ...) mean no aggregate can be reached along a single chain.
    let &[child] = children.as_slice() else {
        return None;
    };

    match child {
        LogicalPlan::Aggregate(found) => Some(found),
        LogicalPlan::TableScan(_) => None,
        // A second projection belongs to a different SELECT, so stop here.
        LogicalPlan::Projection(_) if already_projected => None,
        LogicalPlan::Projection(_) => find_agg_node_within_select(child, true),
        _ => find_agg_node_within_select(child, already_projected),
    }
}
/// Rewrites a copy of `expr`, replacing each column reference with the
/// expression in `agg` that produces that column.
///
/// A column is resolved by its position in `agg.schema`. That position is
/// matched against `agg.group_expr` followed by `agg.aggr_expr`. Every other
/// kind of sub-expression is left unchanged.
///
/// # Errors
///
/// Returns an internal error if a column in `expr`:
/// * cannot be found in `agg.schema`, or
/// * sits at a schema position with no corresponding group/aggregate
///   expression. This case previously caused a panic.
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
    expr.clone()
        .transform(|sub_expr| {
            if let Expr::Column(c) = sub_expr {
                // Locate the column among the Aggregate's output fields.
                let Ok(n) = agg.schema.index_of_column(&c) else {
                    return internal_err!(
                        "Tried to unproject agg expr not found in provided Aggregate!"
                    );
                };
                // Nothing here guarantees the schema has no more fields than
                // there are group + aggregate expressions. When the position has
                // no matching expression, report an error instead of panicking.
                match agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(n) {
                    Some(unprojected_expr) => {
                        Ok(Transformed::yes(unprojected_expr.clone()))
                    }
                    None => internal_err!(
                        "Tried to unproject agg expr at schema index {} but the provided Aggregate only has {} group/aggregate expressions",
                        n,
                        agg.group_expr.len() + agg.aggr_expr.len()
                    ),
                }
            } else {
                Ok(Transformed::no(sub_expr))
            }
        })
        .map(|e| e.data)
}