datafusion_optimizer/
replace_distinct_aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...`
19
20use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
21use crate::{OptimizerConfig, OptimizerRule};
22use std::sync::Arc;
23
24use datafusion_common::tree_node::Transformed;
25use datafusion_common::{Column, Result};
26use datafusion_expr::expr_rewriter::normalize_cols;
27use datafusion_expr::utils::expand_wildcard;
28use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder};
29use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
30
/// Optimizer rule that replaces a logical [`Distinct`] with a logical [`Aggregate`].
///
/// A plain `DISTINCT` query such as
/// ```text
/// SELECT DISTINCT a, b FROM tab
/// ```
///
/// is rewritten into
/// ```text
/// SELECT a, b FROM tab GROUP BY a, b
/// ```
///
/// A `DISTINCT ON` query needs a slightly more involved rewrite, which
/// effectively converts
/// ```text
/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c
/// ```
///
/// into
/// ```text
/// SELECT b FROM (
///     SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b
///     FROM tab
///     GROUP BY a
/// )
/// ORDER BY a DESC
/// ```
#[derive(Debug, Default)]
pub struct ReplaceDistinctWithAggregate {}
59
60impl ReplaceDistinctWithAggregate {
61    #[allow(missing_docs)]
62    pub fn new() -> Self {
63        Self {}
64    }
65}
66
67impl OptimizerRule for ReplaceDistinctWithAggregate {
68    fn supports_rewrite(&self) -> bool {
69        true
70    }
71
72    fn rewrite(
73        &self,
74        plan: LogicalPlan,
75        config: &dyn OptimizerConfig,
76    ) -> Result<Transformed<LogicalPlan>> {
77        match plan {
78            LogicalPlan::Distinct(Distinct::All(input)) => {
79                let group_expr = expand_wildcard(input.schema(), &input, None)?;
80
81                let field_count = input.schema().fields().len();
82                for dep in input.schema().functional_dependencies().iter() {
83                    // If distinct is exactly the same with a previous GROUP BY, we can
84                    // simply remove it:
85                    if dep.source_indices.len() >= field_count
86                        && dep.source_indices[..field_count]
87                            .iter()
88                            .enumerate()
89                            .all(|(idx, f_idx)| idx == *f_idx)
90                    {
91                        return Ok(Transformed::yes(input.as_ref().clone()));
92                    }
93                }
94
95                // Replace with aggregation:
96                let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
97                    input,
98                    group_expr,
99                    vec![],
100                )?);
101                Ok(Transformed::yes(aggr_plan))
102            }
103            LogicalPlan::Distinct(Distinct::On(DistinctOn {
104                select_expr,
105                on_expr,
106                sort_expr,
107                input,
108                schema,
109            })) => {
110                let expr_cnt = on_expr.len();
111
112                // Construct the aggregation expression to be used to fetch the selected expressions.
113                let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
114                    config.function_registry().unwrap().udaf("first_value")?;
115                let aggr_expr = select_expr.into_iter().map(|e| {
116                    if let Some(order_by) = &sort_expr {
117                        first_value_udaf
118                            .call(vec![e])
119                            .order_by(order_by.clone())
120                            .build()
121                            // guaranteed to be `Expr::AggregateFunction`
122                            .unwrap()
123                    } else {
124                        first_value_udaf.call(vec![e])
125                    }
126                });
127
128                let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
129                let group_expr = normalize_cols(on_expr, input.as_ref())?;
130
131                // Build the aggregation plan
132                let plan = LogicalPlan::Aggregate(Aggregate::try_new(
133                    input, group_expr, aggr_expr,
134                )?);
135                // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate
136                // when https://github.com/apache/datafusion/issues/10485 is available
137                let lpb = LogicalPlanBuilder::from(plan);
138
139                let plan = if let Some(mut sort_expr) = sort_expr {
140                    // While sort expressions were used in the `FIRST_VALUE` aggregation itself above,
141                    // this on it's own isn't enough to guarantee the proper output order of the grouping
142                    // (`ON`) expression, so we need to sort those as well.
143
144                    // truncate the sort_expr to the length of on_expr
145                    sort_expr.truncate(expr_cnt);
146
147                    lpb.sort(sort_expr)?.build()?
148                } else {
149                    lpb.build()?
150                };
151
152                // Whereas the aggregation plan by default outputs both the grouping and the aggregation
153                // expressions, for `DISTINCT ON` we only need to emit the original selection expressions.
154
155                let project_exprs = plan
156                    .schema()
157                    .iter()
158                    .skip(expr_cnt)
159                    .zip(schema.iter())
160                    .map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
161                        col(Column::from((new_qualifier, new_field)))
162                            .alias_qualified(old_qualifier.cloned(), old_field.name())
163                    })
164                    .collect::<Vec<Expr>>();
165
166                let plan = LogicalPlanBuilder::from(plan)
167                    .project(project_exprs)?
168                    .build()?;
169
170                Ok(Transformed::yes(plan))
171            }
172            _ => Ok(Transformed::no(plan)),
173        }
174    }
175
176    fn name(&self) -> &str {
177        "replace_distinct_aggregate"
178    }
179
180    fn apply_order(&self) -> Option<ApplyOrder> {
181        Some(BottomUp)
182    }
183}
184
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
    use crate::test::*;

    use datafusion_common::Result;
    use datafusion_expr::{
        col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
    };
    use datafusion_functions_aggregate::sum::sum;

    /// Hands a copy of `plan`, a fresh `ReplaceDistinctWithAggregate` rule and
    /// the `expected` plan text to `assert_optimized_plan_eq`.
    fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
        let rule = Arc::new(ReplaceDistinctWithAggregate::new());
        assert_optimized_plan_eq(rule, plan.clone(), expected)
    }

    /// A distinct over the single column of a preceding `GROUP BY c` is dropped.
    #[test]
    fn eliminate_redundant_distinct_simple() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let no_aggregates = Vec::<Expr>::new();
        let distinct_plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("c")], no_aggregates)?
            .project(vec![col("c")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal(
            &distinct_plan,
            "Projection: test.c\n  Aggregate: groupBy=[[test.c]], aggr=[[]]\n    TableScan: test",
        )
    }

    /// A distinct over both columns of a preceding `GROUP BY a, b` is dropped.
    #[test]
    fn eliminate_redundant_distinct_pair() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let no_aggregates = Vec::<Expr>::new();
        let distinct_plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("a"), col("b")], no_aggregates)?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal(
            &distinct_plan,
            "Projection: test.a, test.b\n  Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n    TableScan: test",
        )
    }

    /// Without a preceding aggregate the distinct is turned into an `Aggregate`.
    #[test]
    fn do_not_eliminate_distinct() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let distinct_plan = LogicalPlanBuilder::from(scan)
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal(
            &distinct_plan,
            "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n  Projection: test.a, test.b\n    TableScan: test",
        )
    }

    /// Projecting only part of a wider `GROUP BY` keeps the distinct (as an `Aggregate`).
    #[test]
    fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let group_by = vec![col("a"), col("b"), col("c")];
        let distinct_plan = LogicalPlanBuilder::from(scan)
            .aggregate(group_by, vec![sum(col("c"))])?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal(
            &distinct_plan,
            "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n  Projection: test.a, test.b\n    Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n      TableScan: test",
        )
    }
}
258}