use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
/// The aggregation-like node(s) located by `find_agg_node_within_select`.
///
/// Both variants only borrow from the plan being inspected (lifetime `'a`);
/// nothing is cloned.
pub(crate) enum AggVariant<'a> {
/// A single `Aggregate` plan node.
Aggregate(&'a Aggregate),
/// One or more `Window` plan nodes, stored in the order they were pushed
/// while walking down the plan (outermost node first).
Window(Vec<&'a Window>),
}
/// Walks down the single-input chain below `plan` looking for the
/// aggregation-like node(s) feeding the current SELECT.
///
/// Only the *input* of `plan` is inspected at each step; `plan` itself is
/// never matched.
///
/// * An `Aggregate` input is returned immediately as
///   [`AggVariant::Aggregate`].
/// * Every `Window` input is appended to the windows carried in
///   `prev_windows` (a fresh list is started if `prev_windows` does not
///   already hold windows) and the walk continues below it.
/// * A `Projection` input is stepped through at most once: if
///   `already_projected` is `true` the walk stops and whatever has been
///   collected so far is returned, otherwise the walk continues with
///   `already_projected = true`.
/// * A `TableScan` input ends the walk and returns what has been collected.
/// * Any other input is simply stepped through.
///
/// If `plan` has no input at all, or has more than one input (such as a
/// join), the walk cannot continue and the windows collected so far are
/// returned. Previously these two cases returned `None`, silently dropping
/// `Window` nodes that had already been found above a join or above a leaf
/// other than `TableScan`.
///
/// # Parameters
/// * `plan` - node whose input chain is searched.
/// * `prev_windows` - result accumulated by outer recursion levels; pass
///   `None` for a fresh search.
/// * `already_projected` - whether a `Projection` has already been crossed.
///
/// # Returns
/// The aggregate or window nodes found, or `None` if there are none.
pub(crate) fn find_agg_node_within_select<'a>(
    plan: &'a LogicalPlan,
    mut prev_windows: Option<AggVariant<'a>>,
    already_projected: bool,
) -> Option<AggVariant<'a>> {
    // Nodes that can sit between a SELECT's projection and its aggregate or
    // window node have exactly one input. Anything else ends the search, but
    // must not discard windows that were already collected.
    let inputs = plan.inputs();
    let input = match inputs.as_slice() {
        [only] => *only,
        _ => return prev_windows,
    };

    match input {
        // Aggregates are returned on their own as soon as one is seen.
        LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)),
        // Windows accumulate until the walk terminates.
        LogicalPlan::Window(window) => {
            prev_windows = match prev_windows {
                Some(AggVariant::Window(mut windows)) => {
                    windows.push(window);
                    Some(AggVariant::Window(windows))
                }
                _ => Some(AggVariant::Window(vec![window])),
            };
            find_agg_node_within_select(input, prev_windows, already_projected)
        }
        LogicalPlan::Projection(_) => {
            if already_projected {
                // A second projection ends the search.
                prev_windows
            } else {
                find_agg_node_within_select(input, prev_windows, true)
            }
        }
        LogicalPlan::TableScan(_) => prev_windows,
        _ => find_agg_node_within_select(input, prev_windows, already_projected),
    }
}
/// Rewrites `expr` so that every column reference into the output of `agg`
/// is replaced by the expression that produced that column.
///
/// Each `Expr::Column` inside `expr` is looked up in `agg.schema`; the
/// resulting index `n` selects the `n`-th element of `agg.group_expr`
/// followed by `agg.aggr_expr`, and a clone of that expression replaces the
/// column. Non-column sub-expressions are left untouched. `expr` itself is
/// not modified; a rewritten clone is returned.
///
/// # Errors
/// Returns an internal error if
/// * a column is not found in `agg.schema`, or
/// * the column's schema index lies beyond the combined group/aggregate
///   expression list. This case used to panic through `unwrap()`.
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
    expr.clone()
        .transform(|sub_expr| match sub_expr {
            Expr::Column(c) => {
                // Locate the column in the aggregate's output schema.
                let n = match agg.schema.index_of_column(&c) {
                    Ok(n) => n,
                    Err(_) => {
                        return internal_err!(
                            "Tried to unproject agg expr not found in provided Aggregate!"
                        )
                    }
                };
                // The index counts group expressions first, then aggregate
                // expressions.
                match agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(n) {
                    Some(unprojected_expr) => {
                        Ok(Transformed::yes(unprojected_expr.clone()))
                    }
                    None => internal_err!(
                        "Tried to unproject column '{}' at schema index {} which has no matching group or aggregate expression in provided Aggregate!",
                        c.name,
                        n
                    ),
                }
            }
            other => Ok(Transformed::no(other)),
        })
        .map(|e| e.data)
}
/// Rewrites `expr` so that column references naming a window expression are
/// replaced by that window expression.
///
/// For each `Expr::Column` inside `expr`, the window expressions of all
/// `windows` are scanned in order and the first one whose `schema_name()`
/// renders to the column's `name` is cloned in place of the column. Columns
/// with no match, and all non-column sub-expressions, are kept as they are.
/// `expr` itself is not modified; a rewritten clone is returned.
pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> {
    let rewritten = expr.clone().transform(|node| match node {
        Expr::Column(col) => {
            // First window expression (across all windows) named like the column.
            let matching = windows
                .iter()
                .flat_map(|w| w.window_expr.iter())
                .find(|candidate| candidate.schema_name().to_string() == col.name);
            Ok(match matching {
                Some(found) => Transformed::yes(found.clone()),
                None => Transformed::no(Expr::Column(col)),
            })
        }
        other => Ok(Transformed::no(other)),
    })?;
    Ok(rewritten.data)
}