use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
/// Optimizer rule that replaces a logical `Distinct` node with a logical
/// `Aggregate` node.
///
/// A plain `DISTINCT`, e.g.
///
/// ```text
/// SELECT DISTINCT a, b FROM tab
/// ```
///
/// is rewritten into a grouping over every column of the input:
///
/// ```text
/// SELECT a, b FROM tab GROUP BY a, b
/// ```
///
/// A `DISTINCT ON` is rewritten into an aggregation that groups by the `ON`
/// expressions and picks each selected expression with `first_value`
/// (ordered by the query's sort expressions, if any), followed by an optional
/// sort and a projection that restores the original output columns.
///
/// The rule carries no state, so every instance behaves identically.
#[derive(Default, Debug)]
pub struct ReplaceDistinctWithAggregate {}
impl ReplaceDistinctWithAggregate {
    /// Creates a new instance of the rule.
    ///
    /// The rule has no configuration, so this is equivalent to
    /// `ReplaceDistinctWithAggregate::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl OptimizerRule for ReplaceDistinctWithAggregate {
    /// Signals that this rule is driven through `rewrite`, which takes the
    /// plan by value, rather than through `try_optimize`.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a `Distinct` plan node into an `Aggregate`-based plan.
    ///
    /// * `Distinct::All` becomes an `Aggregate` over the same input whose
    ///   group expressions are the expansion of the input schema and which has
    ///   no aggregate expressions.
    /// * `Distinct::On` becomes an `Aggregate` grouped by the `ON` expressions
    ///   with one `first_value` aggregate per selected expression, optionally
    ///   followed by a sort, and finally a projection that keeps only the
    ///   aggregate outputs under the names of the original schema.
    /// * Any other node is returned unchanged and flagged as not transformed.
    ///
    /// # Errors
    ///
    /// Returns an error if plan construction fails, if `first_value` cannot be
    /// found in the function registry, or if `config` exposes no function
    /// registry at all while a `DISTINCT ON` needs to be rewritten.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Distinct(Distinct::All(input)) => {
                // DISTINCT over all columns is a GROUP BY over all columns
                // with no aggregate expressions.
                let group_expr = expand_wildcard(input.schema(), &input, None)?;
                let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
                    input,
                    group_expr,
                    vec![],
                )?);
                Ok(Transformed::yes(aggr_plan))
            }
            LogicalPlan::Distinct(Distinct::On(DistinctOn {
                select_expr,
                on_expr,
                sort_expr,
                input,
                schema,
            })) => {
                // Number of grouping (`ON`) expressions; they occupy the
                // leading columns of the aggregate's output schema.
                let expr_cnt = on_expr.len();

                // A config without a function registry is reported as an
                // error rather than a panic, so callers can handle it.
                let first_value_udaf = match config.function_registry() {
                    Some(registry) => registry.udaf("first_value")?,
                    None => {
                        return internal_err!(
                            "Cannot rewrite DISTINCT ON: the optimizer config provides no function registry to resolve `first_value`"
                        );
                    }
                };

                // One `first_value(expr ORDER BY sort_expr)` per selected
                // expression picks the value from the first row of each group.
                let aggr_expr = select_expr.into_iter().map(|e| {
                    Expr::AggregateFunction(AggregateFunction::new_udf(
                        first_value_udaf.clone(),
                        vec![e],
                        false,
                        None,
                        sort_expr.clone(),
                        None,
                    ))
                });

                let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
                let group_expr = normalize_cols(on_expr, input.as_ref())?;

                // Build the aggregation plan: GROUP BY the `ON` expressions.
                let plan = LogicalPlan::Aggregate(Aggregate::try_new(
                    input, group_expr, aggr_expr,
                )?);
                let lpb = LogicalPlanBuilder::from(plan);

                let plan = if let Some(mut sort_expr) = sort_expr {
                    // The ordering inside `first_value` only selects the row
                    // within each group; the groups themselves still need to
                    // be ordered. Only the first `expr_cnt` sort expressions
                    // are kept (presumably the ones that correspond to the
                    // `ON` expressions - confirm against the planner).
                    sort_expr.truncate(expr_cnt);
                    lpb.sort(sort_expr)?.build()?
                } else {
                    lpb.build()?
                };

                // The aggregate emits grouping columns first, then aggregate
                // columns. Skip the grouping columns and alias each aggregate
                // column back to the qualifier/name of the original schema.
                let project_exprs = plan
                    .schema()
                    .iter()
                    .skip(expr_cnt)
                    .zip(schema.iter())
                    .map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
                        col(Column::from((new_qualifier, new_field)))
                            .alias_qualified(old_qualifier.cloned(), old_field.name())
                    })
                    .collect::<Vec<Expr>>();

                let plan = LogicalPlanBuilder::from(plan)
                    .project(project_exprs)?
                    .build()?;

                Ok(Transformed::yes(plan))
            }
            // Not a `Distinct` node: nothing to rewrite.
            _ => Ok(Transformed::no(plan)),
        }
    }

    /// Always returns an internal error: this rule must be invoked through
    /// `rewrite` (see `supports_rewrite`).
    fn try_optimize(
        &self,
        _plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        internal_err!("Should have called ReplaceDistinctWithAggregate::rewrite")
    }

    /// Name under which this rule is identified.
    fn name(&self) -> &str {
        "replace_distinct_aggregate"
    }

    /// The rule asks to be applied to the plan tree bottom-up.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(BottomUp)
    }
}