use std::sync::Arc;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{plan_datafusion_err, Result};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::{
reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement,
};
use datafusion_physical_expr::{LexOrdering, LexRequirement};
use datafusion_physical_plan::aggregates::concat_slices;
use datafusion_physical_plan::windows::get_ordered_partition_by_indices;
use datafusion_physical_plan::{
aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties,
};
use crate::PhysicalOptimizerRule;
/// Physical optimizer rule that rewrites the aggregate expressions of
/// first-stage [`AggregateExec`] nodes so that aggregates which benefit from
/// input ordering are configured according to the ordering their input
/// already provides.
///
/// The rule carries no state; all work happens in its
/// [`PhysicalOptimizerRule`] implementation.
#[derive(Debug, Default)]
pub struct OptimizeAggregateOrder {}

impl OptimizeAggregateOrder {
    /// Creates a new instance of the rule (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for OptimizeAggregateOrder {
    /// Traverses `plan` bottom-up and, for every [`AggregateExec`] whose mode
    /// is a first stage, replaces its aggregate expressions with the result of
    /// [`try_convert_aggregate_if_better`].
    ///
    /// Nodes that are not aggregates, and aggregates that are not in a
    /// first-stage mode, are returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by [`try_convert_aggregate_if_better`].
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        plan.transform_up(|node| {
            // Guard clauses: only first-stage aggregates are rewritten.
            let Some(aggregate) = node.as_any().downcast_ref::<AggregateExec>() else {
                return Ok(Transformed::no(node));
            };
            if !aggregate.mode().is_first_stage() {
                return Ok(Transformed::no(node));
            }

            let child = aggregate.input();
            let group_exprs = aggregate.group_expr().input_exprs();

            // Turn the GROUP BY expressions selected by
            // `get_ordered_partition_by_indices` into sort requirements without
            // explicit sort options; they act as the prefix that every
            // aggregate ordering requirement is appended to.
            // NOTE(review): presumably these are the GROUP BY expressions the
            // input is already ordered on -- confirm against the helper.
            let ordered_indices = get_ordered_partition_by_indices(&group_exprs, child);
            let prefix_requirement: Vec<_> = ordered_indices
                .iter()
                .copied()
                .map(|idx| {
                    PhysicalSortRequirement::new(
                        Arc::<dyn datafusion_physical_plan::PhysicalExpr>::clone(
                            &group_exprs[idx],
                        ),
                        None,
                    )
                })
                .collect();

            let updated_exprs = try_convert_aggregate_if_better(
                aggregate.aggr_expr().to_vec(),
                &prefix_requirement,
                child.equivalence_properties(),
            )?;

            // The node is always reported as transformed here, even when none
            // of its aggregate expressions ended up changing.
            let rewritten: Arc<dyn ExecutionPlan> =
                Arc::new(aggregate.with_new_aggr_exprs(updated_exprs));
            Ok(Transformed::yes(rewritten))
        })
        .data()
    }

    /// Returns the name of this rule.
    fn name(&self) -> &str {
        "OptimizeAggregateOrder"
    }

    /// Always returns `true` for this rule.
    fn schema_check(&self) -> bool {
        true
    }
}
fn try_convert_aggregate_if_better(
aggr_exprs: Vec<Arc<AggregateFunctionExpr>>,
prefix_requirement: &[PhysicalSortRequirement],
eq_properties: &EquivalenceProperties,
) -> Result<Vec<Arc<AggregateFunctionExpr>>> {
aggr_exprs
.into_iter()
.map(|aggr_expr| {
let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(LexOrdering::empty());
let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs);
let aggr_sort_reqs = LexRequirement::from(aggr_sort_exprs.clone());
let reverse_aggr_req = LexRequirement::from(reverse_aggr_sort_exprs);
if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty()
{
let reqs = LexRequirement {
inner: concat_slices(prefix_requirement, &aggr_sort_reqs),
};
let prefix_requirement = LexRequirement {
inner: prefix_requirement.to_vec(),
};
if eq_properties.ordering_satisfy_requirement(&reqs) {
aggr_expr.with_beneficial_ordering(true)?.map(Arc::new)
} else if eq_properties.ordering_satisfy_requirement(&LexRequirement {
inner: concat_slices(&prefix_requirement, &reverse_aggr_req),
}) {
aggr_expr
.reverse_expr()
.map(Arc::new)
.unwrap_or(aggr_expr)
.with_beneficial_ordering(true)?
.map(Arc::new)
} else {
aggr_expr.with_beneficial_ordering(false)?.map(Arc::new)
}
.ok_or_else(|| {
plan_datafusion_err!(
"Expects an aggregate expression that can benefit from input ordering"
)
})
} else {
Ok(aggr_expr)
}
})
.collect()
}