use std::sync::Arc;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, plan_datafusion_err};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement};
use datafusion_physical_plan::aggregates::{
AggregateExec, AggregateInputMode, concat_slices,
};
use datafusion_physical_plan::windows::get_ordered_partition_by_indices;
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use crate::PhysicalOptimizerRule;
/// Physical optimizer rule that inspects [`AggregateExec`] nodes and, where
/// the input already provides a useful ordering, rewrites their aggregate
/// expressions (via [`try_convert_aggregate_if_better`]) so they can benefit
/// from that ordering.
#[derive(Default, Debug)]
pub struct OptimizeAggregateOrder {}

impl OptimizeAggregateOrder {
    /// Creates a new [`OptimizeAggregateOrder`] rule; equivalent to
    /// [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
impl PhysicalOptimizerRule for OptimizeAggregateOrder {
    /// Walks the plan bottom-up and, for every eligible [`AggregateExec`],
    /// swaps in aggregate expressions converted to exploit the input's
    /// existing ordering.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        plan.transform_up(|node| {
            // Only aggregate nodes are of interest; pass everything else
            // through untouched.
            let aggr_exec = match node.as_any().downcast_ref::<AggregateExec>() {
                Some(aggr_exec) => aggr_exec,
                None => return Ok(Transformed::no(node)),
            };
            // Aggregates whose input mode is `Partial` are skipped.
            // NOTE(review): presumably these consume partially-aggregated
            // data whose original input ordering is no longer available —
            // confirm against the `AggregateInputMode` definition.
            if aggr_exec.mode().input_mode() == AggregateInputMode::Partial {
                return Ok(Transformed::no(node));
            }

            let input = aggr_exec.input();
            let group_exprs = aggr_exec.group_expr().input_exprs();
            // Group-by columns the input is already ordered on become a
            // prefix requirement that precedes each aggregate's ORDER BY.
            let ordered = get_ordered_partition_by_indices(&group_exprs, input)?;
            let prefix_requirement = ordered
                .iter()
                .map(|&idx| {
                    PhysicalSortRequirement::new(Arc::clone(&group_exprs[idx]), None)
                })
                .collect::<Vec<_>>();

            let converted = try_convert_aggregate_if_better(
                aggr_exec.aggr_expr().to_vec(),
                &prefix_requirement,
                input.equivalence_properties(),
            )?;
            let new_exec = aggr_exec.with_new_aggr_exprs(converted);
            Ok(Transformed::yes(Arc::new(new_exec) as _))
        })
        .data()
    }

    fn name(&self) -> &str {
        "OptimizeAggregateOrder"
    }

    fn schema_check(&self) -> bool {
        true
    }
}
/// Attempts to convert order-sensitive aggregate expressions so that they
/// benefit from an ordering the input already provides.
///
/// For each expression whose `order_sensitivity()` reports a possible
/// benefit and which has a non-empty ORDER BY, the function checks whether
/// `eq_properties` satisfies `prefix_requirement` followed by the
/// expression's own ORDER BY — first as written, then reversed. Matching
/// forward marks the expression as benefiting; matching only the reverse
/// additionally swaps in the reversed expression; matching neither marks it
/// as not benefiting.
///
/// # Parameters
/// - `aggr_exprs`: aggregate expressions to inspect and possibly convert.
/// - `prefix_requirement`: sort requirements that must precede each
///   expression's ORDER BY (e.g. derived from ordered group-by columns).
/// - `eq_properties`: equivalence/ordering properties of the input plan.
///
/// # Errors
/// Returns an error if `with_beneficial_ordering` yields `None` for an
/// expression (it cannot accept the ordering decision), or if any of the
/// underlying ordering checks fail.
fn try_convert_aggregate_if_better(
    aggr_exprs: Vec<Arc<AggregateFunctionExpr>>,
    prefix_requirement: &[PhysicalSortRequirement],
    eq_properties: &EquivalenceProperties,
) -> Result<Vec<Arc<AggregateFunctionExpr>>> {
    aggr_exprs
        .into_iter()
        .map(|aggr_expr| {
            let order_bys = aggr_expr.order_bys();
            if !aggr_expr.order_sensitivity().is_beneficial() {
                // Ordering cannot help this expression; keep it as-is.
                Ok(aggr_expr)
            } else if !order_bys.is_empty() {
                // Forward check: prefix requirement + ORDER BY as written.
                if eq_properties.ordering_satisfy_requirement(concat_slices(
                    prefix_requirement,
                    &order_bys
                        .iter()
                        .map(|e| e.clone().into())
                        .collect::<Vec<_>>(),
                ))? {
                    aggr_expr.with_beneficial_ordering(true)?.map(Arc::new)
                // Reverse check: prefix requirement + each sort expression
                // reversed.
                } else if eq_properties.ordering_satisfy_requirement(concat_slices(
                    prefix_requirement,
                    &order_bys
                        .iter()
                        .map(|e| e.reverse().into())
                        .collect::<Vec<_>>(),
                ))? {
                    // Use the reversed expression when one exists; fall back
                    // to the original otherwise.
                    aggr_expr
                        .reverse_expr()
                        .map(Arc::new)
                        .unwrap_or(aggr_expr)
                        .with_beneficial_ordering(true)?
                        .map(Arc::new)
                } else {
                    // No usable ordering; record that explicitly.
                    aggr_expr.with_beneficial_ordering(false)?.map(Arc::new)
                }
                // All three branches produce an `Option`; `None` means the
                // expression could not accept the ordering decision.
                .ok_or_else(|| {
                    plan_datafusion_err!(
                        "Expects an aggregate expression that can benefit from input ordering"
                    )
                })
            } else {
                // Beneficial in principle, but there is no ORDER BY to
                // satisfy — nothing to convert.
                Ok(aggr_expr)
            }
        })
        .collect()
}