// datafusion_optimizer/replace_distinct_aggregate.rs
use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
21use crate::{OptimizerConfig, OptimizerRule};
22use std::sync::Arc;
23
24use datafusion_common::tree_node::Transformed;
25use datafusion_common::{Column, Result};
26use datafusion_expr::expr_rewriter::normalize_cols;
27use datafusion_expr::utils::expand_wildcard;
28use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder};
29use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
30
/// Optimizer rule that rewrites `LogicalPlan::Distinct` nodes into
/// `LogicalPlan::Aggregate` plans.
///
/// * `Distinct::All` becomes an aggregate that groups by every input column
///   and has no aggregate expressions.
/// * `Distinct::On` becomes an aggregate that groups by the `ON` expressions
///   and wraps each selected expression in `first_value`, followed by an
///   optional sort and a projection that restores the original output names.
///
/// The rule carries no state; construct it with [`Self::new`] or
/// `Default::default()`.
#[derive(Debug, Default)]
pub struct ReplaceDistinctWithAggregate {}
59
60impl ReplaceDistinctWithAggregate {
61 #[allow(missing_docs)]
62 pub fn new() -> Self {
63 Self {}
64 }
65}
66
67impl OptimizerRule for ReplaceDistinctWithAggregate {
68 fn supports_rewrite(&self) -> bool {
69 true
70 }
71
72 fn rewrite(
73 &self,
74 plan: LogicalPlan,
75 config: &dyn OptimizerConfig,
76 ) -> Result<Transformed<LogicalPlan>> {
77 match plan {
78 LogicalPlan::Distinct(Distinct::All(input)) => {
79 let group_expr = expand_wildcard(input.schema(), &input, None)?;
80
81 let field_count = input.schema().fields().len();
82 for dep in input.schema().functional_dependencies().iter() {
83 if dep.source_indices.len() >= field_count
86 && dep.source_indices[..field_count]
87 .iter()
88 .enumerate()
89 .all(|(idx, f_idx)| idx == *f_idx)
90 {
91 return Ok(Transformed::yes(input.as_ref().clone()));
92 }
93 }
94
95 let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
97 input,
98 group_expr,
99 vec![],
100 )?);
101 Ok(Transformed::yes(aggr_plan))
102 }
103 LogicalPlan::Distinct(Distinct::On(DistinctOn {
104 select_expr,
105 on_expr,
106 sort_expr,
107 input,
108 schema,
109 })) => {
110 let expr_cnt = on_expr.len();
111
112 let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
114 config.function_registry().unwrap().udaf("first_value")?;
115 let aggr_expr = select_expr.into_iter().map(|e| {
116 if let Some(order_by) = &sort_expr {
117 first_value_udaf
118 .call(vec![e])
119 .order_by(order_by.clone())
120 .build()
121 .unwrap()
123 } else {
124 first_value_udaf.call(vec![e])
125 }
126 });
127
128 let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
129 let group_expr = normalize_cols(on_expr, input.as_ref())?;
130
131 let plan = LogicalPlan::Aggregate(Aggregate::try_new(
133 input, group_expr, aggr_expr,
134 )?);
135 let lpb = LogicalPlanBuilder::from(plan);
138
139 let plan = if let Some(mut sort_expr) = sort_expr {
140 sort_expr.truncate(expr_cnt);
146
147 lpb.sort(sort_expr)?.build()?
148 } else {
149 lpb.build()?
150 };
151
152 let project_exprs = plan
156 .schema()
157 .iter()
158 .skip(expr_cnt)
159 .zip(schema.iter())
160 .map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
161 col(Column::from((new_qualifier, new_field)))
162 .alias_qualified(old_qualifier.cloned(), old_field.name())
163 })
164 .collect::<Vec<Expr>>();
165
166 let plan = LogicalPlanBuilder::from(plan)
167 .project(project_exprs)?
168 .build()?;
169
170 Ok(Transformed::yes(plan))
171 }
172 _ => Ok(Transformed::no(plan)),
173 }
174 }
175
176 fn name(&self) -> &str {
177 "replace_distinct_aggregate"
178 }
179
180 fn apply_order(&self) -> Option<ApplyOrder> {
181 Some(BottomUp)
182 }
183}
184
// Unit tests for `ReplaceDistinctWithAggregate`, written against inline plan
// snapshots.
185#[cfg(test)]
186mod tests {
187 use std::sync::Arc;
188
189 use crate::assert_optimized_plan_eq_snapshot;
190 use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
191 use crate::test::*;
192
193 use crate::OptimizerContext;
194 use datafusion_common::Result;
195 use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr};
196 use datafusion_functions_aggregate::sum::sum;
197
 // Test helper: builds an `OptimizerContext` limited to a single pass and a
 // rule list containing only `ReplaceDistinctWithAggregate`, then forwards
 // both, together with `$plan` and the inline snapshot `$expected`, to
 // `assert_optimized_plan_eq_snapshot!`.
198 macro_rules! assert_optimized_plan_equal {
199 (
200 $plan:expr,
201 @ $expected:literal $(,)?
202 ) => {{
203 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
204 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(ReplaceDistinctWithAggregate::new())];
205 assert_optimized_plan_eq_snapshot!(
206 optimizer_ctx,
207 rules,
208 $plan,
209 @ $expected,
210 )
211 }};
212 }
213
 // DISTINCT over a projection of `c` that sits on an aggregate grouped by `c`:
 // the expected snapshot contains no additional Aggregate above the Projection.
214 #[test]
215 fn eliminate_redundant_distinct_simple() -> Result<()> {
216 let table_scan = test_table_scan().unwrap();
217 let plan = LogicalPlanBuilder::from(table_scan)
218 .aggregate(vec![col("c")], Vec::<Expr>::new())?
219 .project(vec![col("c")])?
220 .distinct()?
221 .build()?;
222
223 assert_optimized_plan_equal!(plan, @r"
224 Projection: test.c
225 Aggregate: groupBy=[[test.c]], aggr=[[]]
226 TableScan: test
227 ")
228 }
229
 // Same scenario with two grouping columns (`a`, `b`): the expected snapshot
 // again has the Projection as its top node.
230 #[test]
231 fn eliminate_redundant_distinct_pair() -> Result<()> {
232 let table_scan = test_table_scan().unwrap();
233 let plan = LogicalPlanBuilder::from(table_scan)
234 .aggregate(vec![col("a"), col("b")], Vec::<Expr>::new())?
235 .project(vec![col("a"), col("b")])?
236 .distinct()?
237 .build()?;
238
239 assert_optimized_plan_equal!(plan, @r"
240 Projection: test.a, test.b
241 Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
242 TableScan: test
243 ")
244 }
245
 // DISTINCT over a plain projection of `a`, `b` with no aggregate below it:
 // the expected snapshot has an Aggregate grouped by `a`, `b` above the
 // Projection.
246 #[test]
247 fn do_not_eliminate_distinct() -> Result<()> {
248 let table_scan = test_table_scan().unwrap();
249 let plan = LogicalPlanBuilder::from(table_scan)
250 .project(vec![col("a"), col("b")])?
251 .distinct()?
252 .build()?;
253
254 assert_optimized_plan_equal!(plan, @r"
255 Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
256 Projection: test.a, test.b
257 TableScan: test
258 ")
259 }
260
 // The inner aggregate groups by `a`, `b`, `c` and computes `sum(c)`, but the
 // projection keeps only `a`, `b`: the expected snapshot keeps an Aggregate
 // grouped by `a`, `b` above the Projection.
261 #[test]
262 fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
263 let table_scan = test_table_scan().unwrap();
264 let plan = LogicalPlanBuilder::from(table_scan)
265 .aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])?
266 .project(vec![col("a"), col("b")])?
267 .distinct()?
268 .build()?;
269
270 assert_optimized_plan_equal!(plan, @r"
271 Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
272 Projection: test.a, test.b
273 Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]
274 TableScan: test
275 ")
276 }
277}