datafusion_physical_optimizer/
aggregate_statistics.rs1use datafusion_common::config::ConfigOptions;
20use datafusion_common::scalar::ScalarValue;
21use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
22use datafusion_common::Result;
23use datafusion_physical_plan::aggregates::AggregateExec;
24use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
25use datafusion_physical_plan::projection::ProjectionExec;
26use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
27use datafusion_physical_plan::{expressions, ExecutionPlan};
28use std::sync::Arc;
29
30use crate::PhysicalOptimizerRule;
31
/// Physical optimizer rule that replaces an ungrouped aggregation with a
/// projection of literal values when every aggregate's value can be derived
/// from the statistics of the aggregation's input.
#[derive(Debug, Default)]
pub struct AggregateStatistics {}
35
36impl AggregateStatistics {
37 #[allow(missing_docs)]
38 pub fn new() -> Self {
39 Self {}
40 }
41}
42
43impl PhysicalOptimizerRule for AggregateStatistics {
44 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
45 fn optimize(
46 &self,
47 plan: Arc<dyn ExecutionPlan>,
48 _config: &ConfigOptions,
49 ) -> Result<Arc<dyn ExecutionPlan>> {
50 if let Some(partial_agg_exec) = take_optimizable(&*plan) {
51 let partial_agg_exec = partial_agg_exec
52 .as_any()
53 .downcast_ref::<AggregateExec>()
54 .expect("take_optimizable() ensures that this is a AggregateExec");
55 let stats = partial_agg_exec.input().statistics()?;
56 let mut projections = vec![];
57 for expr in partial_agg_exec.aggr_expr() {
58 let field = expr.field();
59 let args = expr.expressions();
60 let statistics_args = StatisticsArgs {
61 statistics: &stats,
62 return_type: field.data_type(),
63 is_distinct: expr.is_distinct(),
64 exprs: args.as_slice(),
65 };
66 if let Some((optimizable_statistic, name)) =
67 take_optimizable_value_from_statistics(&statistics_args, expr)
68 {
69 projections
70 .push((expressions::lit(optimizable_statistic), name.to_owned()));
71 } else {
72 break;
74 }
75 }
76
77 if projections.len() == partial_agg_exec.aggr_expr().len() {
79 Ok(Arc::new(ProjectionExec::try_new(
81 projections,
82 Arc::new(PlaceholderRowExec::new(plan.schema())),
83 )?))
84 } else {
85 plan.map_children(|child| {
86 self.optimize(child, _config).map(Transformed::yes)
87 })
88 .data()
89 }
90 } else {
91 plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes))
92 .data()
93 }
94 }
95
96 fn name(&self) -> &str {
97 "aggregate_statistics"
98 }
99
100 fn schema_check(&self) -> bool {
102 false
103 }
104}
105
106fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
114 if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>() {
115 if !final_agg_exec.mode().is_first_stage()
116 && final_agg_exec.group_expr().is_empty()
117 {
118 let mut child = Arc::clone(final_agg_exec.input());
119 loop {
120 if let Some(partial_agg_exec) =
121 child.as_any().downcast_ref::<AggregateExec>()
122 {
123 if partial_agg_exec.mode().is_first_stage()
124 && partial_agg_exec.group_expr().is_empty()
125 && partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
126 {
127 return Some(child);
128 }
129 }
130 if let [childrens_child] = child.children().as_slice() {
131 child = Arc::clone(childrens_child);
132 } else {
133 break;
134 }
135 }
136 }
137 }
138 None
139}
140
141fn take_optimizable_value_from_statistics(
143 statistics_args: &StatisticsArgs,
144 agg_expr: &AggregateFunctionExpr,
145) -> Option<(ScalarValue, String)> {
146 let value = agg_expr.fun().value_from_stats(statistics_args);
147 value.map(|val| (val, agg_expr.name().to_string()))
148}
149
150