datafusion_optimizer/
replace_distinct_aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...`
19
20use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
21use crate::{OptimizerConfig, OptimizerRule};
22use std::sync::Arc;
23
24use datafusion_common::tree_node::Transformed;
25use datafusion_common::{Column, Result};
26use datafusion_expr::expr_rewriter::normalize_cols;
27use datafusion_expr::utils::expand_wildcard;
28use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder};
29use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
30
/// Optimizer rule that rewrites a logical [`Distinct`] node into an
/// equivalent logical [`Aggregate`].
///
/// A plain `DISTINCT` is turned into a `GROUP BY` over every output column,
/// so that
/// ```text
/// SELECT DISTINCT a, b FROM tab
/// ```
///
/// becomes
/// ```text
/// SELECT a, b FROM tab GROUP BY a, b
/// ```
///
/// A `DISTINCT ON` query takes more work: the `ON` expressions become the
/// grouping keys, every selected expression is wrapped in `FIRST_VALUE`
/// (carrying the query's `ORDER BY`, if any), and the result is re-sorted on
/// the grouping keys and projected back to the selected expressions. For
/// example
/// ```text
/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c
/// ```
///
/// is effectively converted into
/// ```text
/// SELECT b FROM (
///     SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b
///     FROM tab
///     GROUP BY a
/// )
/// ORDER BY a DESC
/// ```
#[derive(Debug, Default)]
pub struct ReplaceDistinctWithAggregate {}
59
60impl ReplaceDistinctWithAggregate {
61    #[allow(missing_docs)]
62    pub fn new() -> Self {
63        Self {}
64    }
65}
66
67impl OptimizerRule for ReplaceDistinctWithAggregate {
68    fn supports_rewrite(&self) -> bool {
69        true
70    }
71
72    fn rewrite(
73        &self,
74        plan: LogicalPlan,
75        config: &dyn OptimizerConfig,
76    ) -> Result<Transformed<LogicalPlan>> {
77        match plan {
78            LogicalPlan::Distinct(Distinct::All(input)) => {
79                let group_expr = expand_wildcard(input.schema(), &input, None)?;
80
81                let field_count = input.schema().fields().len();
82                for dep in input.schema().functional_dependencies().iter() {
83                    // If distinct is exactly the same with a previous GROUP BY, we can
84                    // simply remove it:
85                    if dep.source_indices.len() >= field_count
86                        && dep.source_indices[..field_count]
87                            .iter()
88                            .enumerate()
89                            .all(|(idx, f_idx)| idx == *f_idx)
90                    {
91                        return Ok(Transformed::yes(input.as_ref().clone()));
92                    }
93                }
94
95                // Replace with aggregation:
96                let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
97                    input,
98                    group_expr,
99                    vec![],
100                )?);
101                Ok(Transformed::yes(aggr_plan))
102            }
103            LogicalPlan::Distinct(Distinct::On(DistinctOn {
104                select_expr,
105                on_expr,
106                sort_expr,
107                input,
108                schema,
109            })) => {
110                let expr_cnt = on_expr.len();
111
112                // Construct the aggregation expression to be used to fetch the selected expressions.
113                let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
114                    config.function_registry().unwrap().udaf("first_value")?;
115                let aggr_expr = select_expr.into_iter().map(|e| {
116                    if let Some(order_by) = &sort_expr {
117                        first_value_udaf
118                            .call(vec![e])
119                            .order_by(order_by.clone())
120                            .build()
121                            // guaranteed to be `Expr::AggregateFunction`
122                            .unwrap()
123                    } else {
124                        first_value_udaf.call(vec![e])
125                    }
126                });
127
128                let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
129                let group_expr = normalize_cols(on_expr, input.as_ref())?;
130
131                // Build the aggregation plan
132                let plan = LogicalPlan::Aggregate(Aggregate::try_new(
133                    input, group_expr, aggr_expr,
134                )?);
135                // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate
136                // when https://github.com/apache/datafusion/issues/10485 is available
137                let lpb = LogicalPlanBuilder::from(plan);
138
139                let plan = if let Some(mut sort_expr) = sort_expr {
140                    // While sort expressions were used in the `FIRST_VALUE` aggregation itself above,
141                    // this on it's own isn't enough to guarantee the proper output order of the grouping
142                    // (`ON`) expression, so we need to sort those as well.
143
144                    // truncate the sort_expr to the length of on_expr
145                    sort_expr.truncate(expr_cnt);
146
147                    lpb.sort(sort_expr)?.build()?
148                } else {
149                    lpb.build()?
150                };
151
152                // Whereas the aggregation plan by default outputs both the grouping and the aggregation
153                // expressions, for `DISTINCT ON` we only need to emit the original selection expressions.
154
155                let project_exprs = plan
156                    .schema()
157                    .iter()
158                    .skip(expr_cnt)
159                    .zip(schema.iter())
160                    .map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
161                        col(Column::from((new_qualifier, new_field)))
162                            .alias_qualified(old_qualifier.cloned(), old_field.name())
163                    })
164                    .collect::<Vec<Expr>>();
165
166                let plan = LogicalPlanBuilder::from(plan)
167                    .project(project_exprs)?
168                    .build()?;
169
170                Ok(Transformed::yes(plan))
171            }
172            _ => Ok(Transformed::no(plan)),
173        }
174    }
175
176    fn name(&self) -> &str {
177        "replace_distinct_aggregate"
178    }
179
180    fn apply_order(&self) -> Option<ApplyOrder> {
181        Some(BottomUp)
182    }
183}
184
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::assert_optimized_plan_eq_snapshot;
    use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
    use crate::test::*;

    use crate::OptimizerContext;
    use datafusion_common::Result;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr};
    use datafusion_functions_aggregate::sum::sum;

    /// Runs only `ReplaceDistinctWithAggregate`, for a single optimizer pass,
    /// over `$plan` and compares the result with the inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(ReplaceDistinctWithAggregate::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// A `DISTINCT` over the single grouping column of an aggregate is dropped.
    #[test]
    fn eliminate_redundant_distinct_simple() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], Vec::<Expr>::new())?
            .project(vec![col("c")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.c
          Aggregate: groupBy=[[test.c]], aggr=[[]]
            TableScan: test
        ")
    }

    /// A `DISTINCT` over both grouping columns of an aggregate is dropped.
    #[test]
    fn eliminate_redundant_distinct_pair() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a"), col("b")], Vec::<Expr>::new())?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            TableScan: test
        ")
    }

    /// Without a preceding aggregate, the `DISTINCT` is replaced by a `GROUP BY`.
    #[test]
    fn do_not_eliminate_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            TableScan: test
        ")
    }

    /// Projecting only part of the grouping key keeps the `DISTINCT`, which is
    /// replaced by a `GROUP BY` on the projected columns.
    #[test]
    fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]
              TableScan: test
        ")
    }
}