datafusion_physical_optimizer/
aggregate_statistics.rs1use datafusion_common::Result;
20use datafusion_common::config::ConfigOptions;
21use datafusion_common::scalar::ScalarValue;
22use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
23use datafusion_physical_plan::aggregates::AggregateExec;
24use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
25use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
26use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
27use datafusion_physical_plan::{ExecutionPlan, expressions};
28use std::sync::Arc;
29
30use crate::PhysicalOptimizerRule;
31
/// Physical optimizer rule that replaces an ungrouped aggregation with a single row of
/// literal values when every aggregate result can be derived from input statistics.
#[derive(Debug, Default)]
pub struct AggregateStatistics {}
35
36impl AggregateStatistics {
37 #[expect(missing_docs)]
38 pub fn new() -> Self {
39 Self {}
40 }
41}
42
43impl PhysicalOptimizerRule for AggregateStatistics {
44 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
45 #[expect(clippy::allow_attributes)] #[allow(clippy::only_used_in_recursion)] fn optimize(
48 &self,
49 plan: Arc<dyn ExecutionPlan>,
50 config: &ConfigOptions,
51 ) -> Result<Arc<dyn ExecutionPlan>> {
52 if let Some(partial_agg_exec) = take_optimizable(&*plan) {
53 let partial_agg_exec = partial_agg_exec
54 .as_any()
55 .downcast_ref::<AggregateExec>()
56 .expect("take_optimizable() ensures that this is a AggregateExec");
57 let stats = partial_agg_exec.input().partition_statistics(None)?;
58 let mut projections = vec![];
59 for expr in partial_agg_exec.aggr_expr() {
60 let field = expr.field();
61 let args = expr.expressions();
62 let statistics_args = StatisticsArgs {
63 statistics: &stats,
64 return_type: field.data_type(),
65 is_distinct: expr.is_distinct(),
66 exprs: args.as_slice(),
67 };
68 if let Some((optimizable_statistic, name)) =
69 take_optimizable_value_from_statistics(&statistics_args, expr)
70 {
71 projections.push(ProjectionExpr {
72 expr: expressions::lit(optimizable_statistic),
73 alias: name.to_owned(),
74 });
75 } else {
76 break;
78 }
79 }
80
81 if projections.len() == partial_agg_exec.aggr_expr().len() {
83 Ok(Arc::new(ProjectionExec::try_new(
85 projections,
86 Arc::new(PlaceholderRowExec::new(plan.schema())),
87 )?))
88 } else {
89 plan.map_children(|child| {
90 self.optimize(child, config).map(Transformed::yes)
91 })
92 .data()
93 }
94 } else {
95 plan.map_children(|child| self.optimize(child, config).map(Transformed::yes))
96 .data()
97 }
98 }
99
100 fn name(&self) -> &str {
101 "aggregate_statistics"
102 }
103
104 fn schema_check(&self) -> bool {
106 false
107 }
108}
109
110fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
118 if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>()
119 && !final_agg_exec.mode().is_first_stage()
120 && final_agg_exec.group_expr().is_empty()
121 {
122 let mut child = Arc::clone(final_agg_exec.input());
123 loop {
124 if let Some(partial_agg_exec) = child.as_any().downcast_ref::<AggregateExec>()
125 && partial_agg_exec.mode().is_first_stage()
126 && partial_agg_exec.group_expr().is_empty()
127 && partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
128 {
129 return Some(child);
130 }
131 if let [childrens_child] = child.children().as_slice() {
132 child = Arc::clone(childrens_child);
133 } else {
134 break;
135 }
136 }
137 }
138 None
139}
140
141fn take_optimizable_value_from_statistics(
143 statistics_args: &StatisticsArgs,
144 agg_expr: &AggregateFunctionExpr,
145) -> Option<(ScalarValue, String)> {
146 let value = agg_expr.fun().value_from_stats(statistics_args);
147 value.map(|val| (val, agg_expr.name().to_string()))
148}
149
150