use datafusion_common::Result;
use datafusion_common::{
config::ConfigOptions,
tree_node::{Transformed, TransformedResult, TreeNode},
};
use datafusion_physical_expr::expressions::{FirstValue, LastValue};
use datafusion_physical_expr::{
equivalence::ProjectionMapping, reverse_order_bys, AggregateExpr,
EquivalenceProperties, PhysicalSortRequirement,
};
use datafusion_physical_plan::aggregates::concat_slices;
use datafusion_physical_plan::{
aggregates::{AggregateExec, AggregateMode},
ExecutionPlan, ExecutionPlanProperties, InputOrderMode,
};
use std::sync::Arc;
use datafusion_physical_plan::windows::get_ordered_partition_by_indices;
use super::PhysicalOptimizerRule;
/// Physical optimizer rule that revisits the aggregation nodes of a plan.
///
/// The rule holds no configuration or state of its own; all of the work is
/// done by its `PhysicalOptimizerRule` implementation.
#[derive(Default)]
pub struct OptimizeAggregateOrder {}

impl OptimizeAggregateOrder {
    /// Creates a new rule instance.
    ///
    /// Equivalent to `OptimizeAggregateOrder::default()`, since the struct
    /// has no fields.
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for OptimizeAggregateOrder {
    /// Runs `get_common_requirement_of_aggregate_input` over the nodes of
    /// `plan` via `transform_up` and returns the resulting plan.
    ///
    /// `_config` is accepted to satisfy the trait signature but is not read.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while rewriting a node.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let rewritten = plan.transform_up(get_common_requirement_of_aggregate_input);
        // Strip the `Transformed` wrapper; only the plan itself is returned.
        rewritten.data()
    }

    /// Human-readable identifier of this rule.
    fn name(&self) -> &str {
        "OptimizeAggregateOrder"
    }

    /// Always reports `true`.
    ///
    /// NOTE(review): presumably this asks the optimizer driver to verify
    /// that the rule left the plan schema unchanged — confirm against the
    /// `PhysicalOptimizerRule` trait documentation.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Rewrites a single plan node if it is an `AggregateExec`.
///
/// For an aggregation node the function:
///
/// 1. asks `get_ordered_partition_by_indices` which GROUP BY expressions to
///    use and turns the selected expressions (in the returned order) into
///    an input sort requirement with unspecified sort options
///    (NOTE(review): presumably these are the GROUP BY expressions the
///    input is already ordered on — confirm against that helper's docs);
/// 2. fetches the aggregate expressions through
///    `try_get_updated_aggr_expr_from_child` and lets
///    `try_convert_first_last_if_better` adjust them against that
///    requirement and the input's equivalence properties;
/// 3. derives the `InputOrderMode` from how many GROUP BY expressions were
///    selected (none, some, or all);
/// 4. recomputes the plan properties and builds a replacement
///    `AggregateExec` carrying the new expressions and ordering info.
///
/// Any other node is handed back unchanged as `Transformed::no`.
///
/// # Errors
///
/// Propagates errors from `try_convert_first_last_if_better` and from
/// `ProjectionMapping::try_new`.
fn get_common_requirement_of_aggregate_input(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    // Anything that is not an aggregation passes through untouched.
    let aggr_exec = match plan.as_any().downcast_ref::<AggregateExec>() {
        Some(exec) => exec,
        None => return Ok(Transformed::no(plan)),
    };

    let child = aggr_exec.input();
    let group_by = aggr_exec.group_expr();
    let groupby_exprs = group_by.input_exprs();

    // Build the input requirement from the selected GROUP BY expressions,
    // keeping the order in which the indices were reported.
    let indices = get_ordered_partition_by_indices(&groupby_exprs, child);
    let mut requirement = Vec::with_capacity(indices.len());
    for &idx in &indices {
        requirement.push(PhysicalSortRequirement {
            expr: groupby_exprs[idx].clone(),
            options: None,
        });
    }

    let mut aggr_expr = try_get_updated_aggr_expr_from_child(aggr_exec);
    try_convert_first_last_if_better(
        &requirement,
        &mut aggr_expr,
        child.equivalence_properties(),
    )?;

    // An empty requirement is expressed as "no requirement at all".
    let required_input_ordering = if requirement.is_empty() {
        None
    } else {
        Some(requirement)
    };

    let input_order_mode = if indices.is_empty() {
        InputOrderMode::Linear
    } else if indices.len() == groupby_exprs.len() {
        // Every GROUP BY expression was selected.
        InputOrderMode::Sorted
    } else {
        InputOrderMode::PartiallySorted(indices)
    };

    let projection_mapping =
        ProjectionMapping::try_new(group_by.expr(), &child.schema())?;
    let cache = AggregateExec::compute_properties(
        child,
        plan.schema().clone(),
        &projection_mapping,
        aggr_exec.mode(),
        &input_order_mode,
    );

    let rebuilt = aggr_exec.new_with_aggr_expr_and_ordering_info(
        required_input_ordering,
        aggr_expr,
        cache,
        input_order_mode,
    );
    let rebuilt: Arc<dyn ExecutionPlan> = Arc::new(rebuilt);
    Ok(Transformed::yes(rebuilt))
}
/// Picks the aggregate expressions that `aggr_exec` should be rebuilt with.
///
/// When `aggr_exec` runs in `Final` or `FinalPartitioned` mode and its
/// direct input is itself an `AggregateExec` in `Partial` mode, the
/// expressions of that input are returned. In every other case the node's
/// own expressions are returned.
///
/// The result is always a fresh `Vec` of cloned `Arc` handles.
fn try_get_updated_aggr_expr_from_child(
    aggr_exec: &AggregateExec,
) -> Vec<Arc<dyn AggregateExpr>> {
    let is_final_stage = matches!(
        aggr_exec.mode(),
        AggregateMode::Final | AggregateMode::FinalPartitioned
    );

    // Only a final-stage aggregation looks at its input, and only a
    // partial-stage aggregation directly below it qualifies.
    let partial_child = if is_final_stage {
        aggr_exec
            .input()
            .as_any()
            .downcast_ref::<AggregateExec>()
            .filter(|child| matches!(child.mode(), AggregateMode::Partial))
    } else {
        None
    };

    partial_child.unwrap_or(aggr_exec).aggr_expr().to_vec()
}
/// Updates the `FirstValue` / `LastValue` aggregates in `aggr_exprs` in
/// place, based on what `eq_properties` says about the input ordering.
///
/// For each such aggregate, the requirement tested is `prefix_requirement`
/// followed by the aggregate's own ORDER BY (empty if it has none):
///
/// * if that requirement is satisfied, the aggregate is kept and flagged
///   with `with_requirement_satisfied(true)`;
/// * otherwise, if the requirement with the ORDER BY reversed is satisfied,
///   the aggregate is swapped for its counterpart (`FirstValue` becomes
///   `LastValue` and vice versa) and flagged with `true`;
/// * otherwise the aggregate is kept and flagged with `false`.
///
/// Aggregates of any other type are left untouched.
///
/// # Errors
///
/// The current implementation always returns `Ok(())`.
fn try_convert_first_last_if_better(
    prefix_requirement: &[PhysicalSortRequirement],
    aggr_exprs: &mut [Arc<dyn AggregateExpr>],
    eq_properties: &EquivalenceProperties,
) -> Result<()> {
    // True when the known ordering satisfies `prefix_requirement` followed
    // by `suffix`.
    let satisfied_with = |suffix: &[PhysicalSortRequirement]| -> bool {
        let combined = concat_slices(prefix_requirement, suffix);
        eq_properties.ordering_satisfy_requirement(&combined)
    };

    for slot in aggr_exprs.iter_mut() {
        let order_bys = slot.order_bys().unwrap_or(&[]);
        let forward_req = PhysicalSortRequirement::from_sort_exprs(order_bys);
        let reversed_order_bys = reverse_order_bys(order_bys);
        let backward_req = PhysicalSortRequirement::from_sort_exprs(&reversed_order_bys);

        if let Some(first_value) = slot.as_any().downcast_ref::<FirstValue>().cloned() {
            let replacement = if satisfied_with(&forward_req) {
                Arc::new(first_value.with_requirement_satisfied(true))
                    as Arc<dyn AggregateExpr>
            } else if satisfied_with(&backward_req) {
                // Only the reversed ordering is available, so use the
                // counterpart aggregate instead.
                Arc::new(
                    first_value
                        .convert_to_last()
                        .with_requirement_satisfied(true),
                ) as Arc<dyn AggregateExpr>
            } else {
                // Neither direction is satisfied by the existing ordering.
                Arc::new(first_value.with_requirement_satisfied(false))
                    as Arc<dyn AggregateExpr>
            };
            *slot = replacement;
        } else if let Some(last_value) =
            slot.as_any().downcast_ref::<LastValue>().cloned()
        {
            let replacement = if satisfied_with(&forward_req) {
                Arc::new(last_value.with_requirement_satisfied(true))
                    as Arc<dyn AggregateExpr>
            } else if satisfied_with(&backward_req) {
                // Only the reversed ordering is available, so use the
                // counterpart aggregate instead.
                Arc::new(
                    last_value
                        .convert_to_first()
                        .with_requirement_satisfied(true),
                ) as Arc<dyn AggregateExpr>
            } else {
                // Neither direction is satisfied by the existing ordering.
                Arc::new(last_value.with_requirement_satisfied(false))
                    as Arc<dyn AggregateExpr>
            };
            *slot = replacement;
        }
    }

    Ok(())
}