//! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...`

use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
use crate::{OptimizerConfig, OptimizerRule};
use std::sync::Arc;

use datafusion_common::tree_node::Transformed;
use datafusion_common::{plan_err, Column, Result};
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, lit, ExprFunctionExt, Limit, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
30
/// Optimizer rule that replaces a logical [`Distinct`] with a logical [`Aggregate`].
///
/// `DISTINCT` over all input columns:
///
/// ```text
/// SELECT DISTINCT a, b FROM tab
/// ```
///
/// becomes
///
/// ```text
/// SELECT a, b FROM tab GROUP BY a, b
/// ```
///
/// Two special cases avoid the aggregate entirely:
/// * an input with no columns becomes `LIMIT 1`;
/// * an input whose functional dependencies already make all of its columns
///   distinct (e.g. the output of an identical `GROUP BY`) is returned as-is.
///
/// `DISTINCT ON` groups by the `ON` expressions and picks every selected
/// expression with `first_value`, so
///
/// ```text
/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c
/// ```
///
/// becomes, roughly,
///
/// ```text
/// SELECT b FROM (
///     SELECT a, first_value(b ORDER BY a DESC, c) AS b
///     FROM tab
///     GROUP BY a
/// )
/// ORDER BY a DESC
/// ```
#[derive(Default, Debug)]
pub struct ReplaceDistinctWithAggregate {}

impl ReplaceDistinctWithAggregate {
    /// Creates a new instance of the rule; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
77
78impl OptimizerRule for ReplaceDistinctWithAggregate {
79 fn supports_rewrite(&self) -> bool {
80 true
81 }
82
83 fn rewrite(
84 &self,
85 plan: LogicalPlan,
86 config: &dyn OptimizerConfig,
87 ) -> Result<Transformed<LogicalPlan>> {
88 match plan {
89 LogicalPlan::Distinct(Distinct::All(input)) => {
90 let group_expr = expand_wildcard(input.schema(), &input, None)?;
91
92 if group_expr.is_empty() {
93 return Ok(Transformed::yes(LogicalPlan::Limit(Limit {
96 skip: None,
97 fetch: Some(Box::new(lit(1i64))),
98 input,
99 })));
100 }
101
102 let field_count = input.schema().fields().len();
103 for dep in input.schema().functional_dependencies().iter() {
104 if dep.source_indices.len() >= field_count
107 && dep.source_indices[..field_count]
108 .iter()
109 .enumerate()
110 .all(|(idx, f_idx)| idx == *f_idx)
111 {
112 return Ok(Transformed::yes(input.as_ref().clone()));
113 }
114 }
115
116 let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
118 input,
119 group_expr,
120 vec![],
121 )?);
122 Ok(Transformed::yes(aggr_plan))
123 }
124 LogicalPlan::Distinct(Distinct::On(DistinctOn {
125 select_expr,
126 on_expr,
127 sort_expr,
128 input,
129 schema,
130 })) => {
131 let expr_cnt = on_expr.len();
132
133 let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
135 config.function_registry().unwrap().udaf("first_value")?;
136 let aggr_expr = select_expr.into_iter().map(|e| {
137 if let Some(order_by) = &sort_expr {
138 first_value_udaf
139 .call(vec![e])
140 .order_by(order_by.clone())
141 .build()
142 .unwrap()
144 } else {
145 first_value_udaf.call(vec![e])
146 }
147 });
148
149 let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
150 let group_expr = normalize_cols(on_expr, input.as_ref())?;
151
152 let plan = LogicalPlan::Aggregate(Aggregate::try_new(
154 input, group_expr, aggr_expr,
155 )?);
156 let lpb = LogicalPlanBuilder::from(plan);
159
160 let plan = if let Some(mut sort_expr) = sort_expr {
161 sort_expr.truncate(expr_cnt);
167
168 lpb.sort(sort_expr)?.build()?
169 } else {
170 lpb.build()?
171 };
172
173 let project_exprs = plan
177 .schema()
178 .iter()
179 .skip(expr_cnt)
180 .zip(schema.iter())
181 .map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
182 col(Column::from((new_qualifier, new_field)))
183 .alias_qualified(old_qualifier.cloned(), old_field.name())
184 })
185 .collect::<Vec<Expr>>();
186
187 let plan = LogicalPlanBuilder::from(plan)
188 .project(project_exprs)?
189 .build()?;
190
191 Ok(Transformed::yes(plan))
192 }
193 _ => Ok(Transformed::no(plan)),
194 }
195 }
196
197 fn name(&self) -> &str {
198 "replace_distinct_aggregate"
199 }
200
201 fn apply_order(&self) -> Option<ApplyOrder> {
202 Some(BottomUp)
203 }
204}
205
#[cfg(test)]
mod tests {
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
    use crate::test::*;
    use arrow::datatypes::{Fields, Schema};
    use std::sync::Arc;

    use crate::OptimizerContext;
    use datafusion_common::Result;
    use datafusion_expr::{
        col, logical_plan::builder::LogicalPlanBuilder, table_scan, Expr,
    };
    use datafusion_functions_aggregate::sum::sum;

    // Runs only `ReplaceDistinctWithAggregate`, for a single optimizer pass, and
    // compares the optimized plan with the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(ReplaceDistinctWithAggregate::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// DISTINCT over the (projected) output of a GROUP BY on the same single
    /// column is redundant and is removed.
    #[test]
    fn eliminate_redundant_distinct_simple() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], Vec::<Expr>::new())?
            .project(vec![col("c")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.c
          Aggregate: groupBy=[[test.c]], aggr=[[]]
            TableScan: test
        ")
    }

    /// Same as above, with a two-column grouping key.
    #[test]
    fn eliminate_redundant_distinct_pair() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a"), col("b")], Vec::<Expr>::new())?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a, test.b
          Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
            TableScan: test
        ")
    }

    /// A plain projection carries no uniqueness guarantee, so DISTINCT is
    /// replaced by a GROUP BY over all projected columns.
    #[test]
    fn do_not_eliminate_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            TableScan: test
        ")
    }

    /// Projecting away part of the grouping key (`c`) breaks uniqueness, so the
    /// DISTINCT must still be turned into a GROUP BY.
    #[test]
    fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]
              TableScan: test
        ")
    }

    /// With no columns to group by, DISTINCT becomes `LIMIT 1`.
    #[test]
    fn use_limit_1_when_no_columns() -> Result<()> {
        let plan = table_scan(Some("test"), &Schema::new(Fields::empty()), None)?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=0, fetch=1
          TableScan: test
        ")
    }
}