use std::sync::Arc;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{plan_datafusion_err, Result};
use datafusion_physical_expr::{
reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalSortRequirement,
};
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::aggregates::concat_slices;
use datafusion_physical_plan::windows::get_ordered_partition_by_indices;
use datafusion_physical_plan::{
aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties,
};
/// Physical optimizer rule that revisits the aggregate expressions of every
/// first-stage `AggregateExec` in a plan and adjusts them to the ordering
/// already present on the aggregate's input (see the `PhysicalOptimizerRule`
/// implementation for the exact rewrite).
#[derive(Default, Debug)]
pub struct OptimizeAggregateOrder {}
impl OptimizeAggregateOrder {
    /// Creates a new `OptimizeAggregateOrder` rule.
    ///
    /// The rule carries no configuration, so this is equivalent to
    /// `OptimizeAggregateOrder::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for OptimizeAggregateOrder {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_up(|plan| {
if let Some(aggr_exec) = plan.as_any().downcast_ref::<AggregateExec>() {
if !aggr_exec.mode().is_first_stage() {
return Ok(Transformed::no(plan));
}
let input = aggr_exec.input();
let mut aggr_expr = aggr_exec.aggr_expr().to_vec();
let groupby_exprs = aggr_exec.group_expr().input_exprs();
let indices = get_ordered_partition_by_indices(&groupby_exprs, input);
let requirement = indices
.iter()
.map(|&idx| {
PhysicalSortRequirement::new(groupby_exprs[idx].clone(), None)
})
.collect::<Vec<_>>();
aggr_expr = try_convert_aggregate_if_better(
aggr_expr,
&requirement,
input.equivalence_properties(),
)?;
let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_expr);
Ok(Transformed::yes(Arc::new(aggr_exec) as _))
} else {
Ok(Transformed::no(plan))
}
})
.data()
}
fn name(&self) -> &str {
"OptimizeAggregateOrder"
}
fn schema_check(&self) -> bool {
true
}
}
/// Tries to adapt each aggregate expression in `aggr_exprs` to the ordering
/// of its input, as described by `eq_properties`.
///
/// Only expressions whose order sensitivity is beneficial and that carry a
/// non-empty ORDER BY are considered; all others are returned unchanged.
/// For each candidate:
///
/// - if the input satisfies `prefix_requirement` followed by the expression's
///   own ordering, the expression is marked as having its ordering satisfied;
/// - otherwise, if the input satisfies `prefix_requirement` followed by the
///   *reversed* ordering and the expression has a reverse form, that reverse
///   form is used and marked as satisfied;
/// - otherwise the expression is marked as not benefiting from the input
///   ordering (it still computes a result, just without that shortcut).
///
/// # Errors
///
/// Propagates errors from `with_beneficial_ordering`, and returns a plan
/// error if `with_beneficial_ordering` yields `None` for a candidate.
fn try_convert_aggregate_if_better(
    aggr_exprs: Vec<Arc<dyn AggregateExpr>>,
    prefix_requirement: &[PhysicalSortRequirement],
    eq_properties: &EquivalenceProperties,
) -> Result<Vec<Arc<dyn AggregateExpr>>> {
    aggr_exprs
        .into_iter()
        .map(|aggr_expr| -> Result<Arc<dyn AggregateExpr>> {
            let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(&[]);
            let aggr_sort_reqs = PhysicalSortRequirement::from_sort_exprs(aggr_sort_exprs);

            // Expressions that cannot exploit input ordering are left as-is.
            if !aggr_expr.order_sensitivity().is_beneficial() || aggr_sort_reqs.is_empty() {
                return Ok(aggr_expr);
            }

            let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs);
            let converted = if eq_properties.ordering_satisfy_requirement(&reqs) {
                // The existing input ordering already matches the aggregate's ORDER BY.
                aggr_expr.with_beneficial_ordering(true)?
            } else {
                // The reversed requirement is only needed when the forward check fails.
                let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs);
                let reverse_aggr_req =
                    PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_sort_exprs);
                let reverse_satisfied = eq_properties.ordering_satisfy_requirement(
                    &concat_slices(prefix_requirement, &reverse_aggr_req),
                );
                let reversed = if reverse_satisfied {
                    aggr_expr.reverse_expr()
                } else {
                    None
                };
                match reversed {
                    // The input follows the reversed ordering and a reversed
                    // expression exists, so the reversed form can use it.
                    Some(reversed) => reversed.with_beneficial_ordering(true)?,
                    // No usable ordering. This includes a reverse-ordered input
                    // when no reverse form exists: marking the unreversed
                    // expression as satisfied would claim an ordering the
                    // input does not have.
                    None => aggr_expr.with_beneficial_ordering(false)?,
                }
            };

            converted.ok_or_else(|| {
                plan_datafusion_err!(
                    "Expects an aggregate expression that can benefit from input ordering"
                )
            })
        })
        .collect()
}