use datafusion_common::config::ConfigOptions;
use datafusion_common::scalar::ScalarValue;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
use datafusion_physical_plan::{expressions, ExecutionPlan};
use std::sync::Arc;
use crate::PhysicalOptimizerRule;
/// Physical optimizer rule that tries to answer aggregate expressions directly
/// from the statistics reported by the aggregate's input plan.
///
/// The type carries no state; all of the work happens in its
/// `PhysicalOptimizerRule` implementation.
#[derive(Debug, Default)]
pub struct AggregateStatistics {}

impl AggregateStatistics {
    /// Creates a new rule instance (equivalent to `AggregateStatistics::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
impl PhysicalOptimizerRule for AggregateStatistics {
    /// Rewrites `plan` so that aggregates which can be computed from input
    /// statistics are replaced by literal values.
    ///
    /// When `take_optimizable` finds a matching first-stage `AggregateExec`
    /// below `plan`, every aggregate expression of that node is asked for a
    /// value derived from the statistics of its input. If *all* of them yield
    /// a value, `plan` is replaced by a `ProjectionExec` of those literals on
    /// top of a `PlaceholderRowExec` built from `plan`'s schema. Otherwise the
    /// rule is applied recursively to the children of `plan`.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing the input statistics, from building
    /// the `ProjectionExec`, and from rewriting the children.
    ///
    /// # Panics
    ///
    /// Panics if `take_optimizable` returns a plan that is not an
    /// `AggregateExec`, which would be a bug in that helper.
    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if let Some(partial_agg_exec) = take_optimizable(&*plan) {
            let partial_agg_exec = partial_agg_exec
                .as_any()
                .downcast_ref::<AggregateExec>()
                .expect("take_optimizable() ensures that this is a AggregateExec");
            let stats = partial_agg_exec.input().statistics()?;

            // Collecting into `Option<Vec<_>>` stops at the first aggregate
            // that cannot be resolved from statistics and yields `None`, so a
            // `Some` result means every aggregate expression was resolved.
            let projections: Option<Vec<_>> = partial_agg_exec
                .aggr_expr()
                .iter()
                .map(|expr| {
                    let field = expr.field();
                    let args = expr.expressions();
                    let statistics_args = StatisticsArgs {
                        statistics: &stats,
                        return_type: field.data_type(),
                        is_distinct: expr.is_distinct(),
                        exprs: args.as_slice(),
                    };
                    // `name` is already an owned `String`; move it instead of
                    // cloning it.
                    take_optimizable_value_from_statistics(&statistics_args, expr)
                        .map(|(value, name)| (expressions::lit(value), name))
                })
                .collect();

            if let Some(projections) = projections {
                // The aggregate's input is no longer needed: emit the
                // precomputed literals instead.
                return Ok(Arc::new(ProjectionExec::try_new(
                    projections,
                    Arc::new(PlaceholderRowExec::new(plan.schema())),
                )?));
            }
        }

        // Nothing could be replaced at this node: try the children.
        plan.map_children(|child| self.optimize(child, config).map(Transformed::yes))
            .data()
    }

    /// Returns the identifier of this rule.
    fn name(&self) -> &str {
        "aggregate_statistics"
    }

    /// Returns `false`, i.e. opts this rule out of the schema check.
    ///
    /// NOTE(review): presumably because replacing an aggregate with literals
    /// can alter schema properties such as nullability — confirm against the
    /// `PhysicalOptimizerRule::schema_check` contract.
    fn schema_check(&self) -> bool {
        false
    }
}
/// Looks below `node` for an `AggregateExec` whose aggregates are candidates
/// for being answered from statistics.
///
/// `node` itself must be an `AggregateExec` that is not in a first-stage mode
/// and has no group-by expressions. Starting at its input, the plan is then
/// followed downwards for as long as each node has exactly one child. The
/// first `AggregateExec` on that path which is in a first-stage mode, has no
/// group-by expressions and has no filter on any of its aggregates is
/// returned.
///
/// Returns `None` when `node` does not qualify or when no such aggregate is
/// found before the single-child chain ends.
fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
    let final_agg_exec = node.as_any().downcast_ref::<AggregateExec>()?;
    if final_agg_exec.mode().is_first_stage() || !final_agg_exec.group_expr().is_empty() {
        return None;
    }

    let mut current = Arc::clone(final_agg_exec.input());
    loop {
        let is_candidate = current
            .as_any()
            .downcast_ref::<AggregateExec>()
            .is_some_and(|agg| {
                agg.mode().is_first_stage()
                    && agg.group_expr().is_empty()
                    && agg.filter_expr().iter().all(|filter| filter.is_none())
            });
        if is_candidate {
            return Some(current);
        }

        // Only keep descending through nodes that have exactly one child.
        let next = match current.children().as_slice() {
            [only_child] => Arc::clone(only_child),
            _ => return None,
        };
        current = next;
    }
}
/// Asks the aggregate function behind `agg_expr` for a value computed from
/// the statistics in `statistics_args`.
///
/// Returns the value paired with the name of `agg_expr`, or `None` when the
/// function does not produce a value from those statistics.
fn take_optimizable_value_from_statistics(
    statistics_args: &StatisticsArgs,
    agg_expr: &AggregateFunctionExpr,
) -> Option<(ScalarValue, String)> {
    let stat_value = agg_expr.fun().value_from_stats(statistics_args)?;
    let column_name = agg_expr.name().to_string();
    Some((stat_value, column_name))
}