datafusion_optimizer/
replace_distinct_aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...`
19
use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
use crate::{OptimizerConfig, OptimizerRule};
use std::sync::Arc;

use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Dependency, Result};
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, lit, ExprFunctionExt, Limit, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
30
/// Optimizer rule that replaces a logical [`Distinct`] with a logical [`Aggregate`].
///
/// ```text
/// SELECT DISTINCT a, b FROM tab
/// ```
///
/// Into
/// ```text
/// SELECT a, b FROM tab GROUP BY a, b
/// ```
///
/// If the input's functional dependencies already prove that its rows are
/// unique across all output columns (for example, the input is an aggregation
/// grouped by exactly those columns), the `DISTINCT` is simply removed.
///
/// On the other hand, for a `DISTINCT ON` query the replacement is
/// a bit more involved and effectively converts
/// ```text
/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c
/// ```
///
/// into
/// ```text
/// SELECT b FROM (
///     SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b
///     FROM tab
///     GROUP BY a
/// )
/// ORDER BY a DESC
/// ```
///
/// In case there are no columns, the [`Distinct`] is replaced by a [`Limit`]
///
/// ```text
/// SELECT DISTINCT * FROM empty_table
/// ```
///
/// Into
/// ```text
/// SELECT * FROM empty_table LIMIT 1
/// ```
#[derive(Default, Debug)]
pub struct ReplaceDistinctWithAggregate {}

impl ReplaceDistinctWithAggregate {
    /// Creates a new instance of the rule; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
77
78impl OptimizerRule for ReplaceDistinctWithAggregate {
79    fn supports_rewrite(&self) -> bool {
80        true
81    }
82
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        match plan {
89            LogicalPlan::Distinct(Distinct::All(input)) => {
90                let group_expr = expand_wildcard(input.schema(), &input, None)?;
91
92                if group_expr.is_empty() {
93                    // Special case: there are no columns to group by, so we can't replace it by a group by
94                    // however, we can replace it by LIMIT 1 because there is either no output or a single empty row
95                    return Ok(Transformed::yes(LogicalPlan::Limit(Limit {
96                        skip: None,
97                        fetch: Some(Box::new(lit(1i64))),
98                        input,
99                    })));
100                }
101
102                let field_count = input.schema().fields().len();
103                for dep in input.schema().functional_dependencies().iter() {
104                    // If distinct is exactly the same with a previous GROUP BY, we can
105                    // simply remove it:
106                    if dep.source_indices.len() >= field_count
107                        && dep.source_indices[..field_count]
108                            .iter()
109                            .enumerate()
110                            .all(|(idx, f_idx)| idx == *f_idx)
111                    {
112                        return Ok(Transformed::yes(input.as_ref().clone()));
113                    }
114                }
115
116                // Replace with aggregation:
117                let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
118                    input,
119                    group_expr,
120                    vec![],
121                )?);
122                Ok(Transformed::yes(aggr_plan))
123            }
124            LogicalPlan::Distinct(Distinct::On(DistinctOn {
125                select_expr,
126                on_expr,
127                sort_expr,
128                input,
129                schema,
130            })) => {
131                let expr_cnt = on_expr.len();
132
133                // Construct the aggregation expression to be used to fetch the selected expressions.
134                let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
135                    config.function_registry().unwrap().udaf("first_value")?;
136                let aggr_expr = select_expr.into_iter().map(|e| {
137                    if let Some(order_by) = &sort_expr {
138                        first_value_udaf
139                            .call(vec![e])
140                            .order_by(order_by.clone())
141                            .build()
142                            // guaranteed to be `Expr::AggregateFunction`
143                            .unwrap()
144                    } else {
145                        first_value_udaf.call(vec![e])
146                    }
147                });
148
149                let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
150                let group_expr = normalize_cols(on_expr, input.as_ref())?;
151
152                // Build the aggregation plan
153                let plan = LogicalPlan::Aggregate(Aggregate::try_new(
154                    input, group_expr, aggr_expr,
155                )?);
156                // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate
157                // when https://github.com/apache/datafusion/issues/10485 is available
158                let lpb = LogicalPlanBuilder::from(plan);
159
160                let plan = if let Some(mut sort_expr) = sort_expr {
161                    // While sort expressions were used in the `FIRST_VALUE` aggregation itself above,
162                    // this on it's own isn't enough to guarantee the proper output order of the grouping
163                    // (`ON`) expression, so we need to sort those as well.
164
165                    // truncate the sort_expr to the length of on_expr
166                    sort_expr.truncate(expr_cnt);
167
168                    lpb.sort(sort_expr)?.build()?
169                } else {
170                    lpb.build()?
171                };
172
173                // Whereas the aggregation plan by default outputs both the grouping and the aggregation
174                // expressions, for `DISTINCT ON` we only need to emit the original selection expressions.
175
176                let project_exprs = plan
177                    .schema()
178                    .iter()
179                    .skip(expr_cnt)
180                    .zip(schema.iter())
181                    .map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
182                        col(Column::from((new_qualifier, new_field)))
183                            .alias_qualified(old_qualifier.cloned(), old_field.name())
184                    })
185                    .collect::<Vec<Expr>>();
186
187                let plan = LogicalPlanBuilder::from(plan)
188                    .project(project_exprs)?
189                    .build()?;
190
191                Ok(Transformed::yes(plan))
192            }
193            _ => Ok(Transformed::no(plan)),
194        }
195    }
196
197    fn name(&self) -> &str {
198        "replace_distinct_aggregate"
199    }
200
201    fn apply_order(&self) -> Option<ApplyOrder> {
202        Some(BottomUp)
203    }
204}
205
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::Result;
    use datafusion_expr::{
        col, logical_plan::builder::LogicalPlanBuilder, table_scan, Expr,
    };
    use datafusion_functions_aggregate::sum::sum;

    use crate::assert_optimized_plan_eq_snapshot;
    use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
    use crate::test::*;
    use crate::OptimizerContext;

    // Runs only `ReplaceDistinctWithAggregate`, for a single optimizer pass,
    // over `$plan` and compares the optimized plan with the inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
                Arc::new(ReplaceDistinctWithAggregate::new());
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules = vec![rule];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn eliminate_redundant_distinct_simple() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("c")], Vec::<Expr>::new())?
            .project(vec![col("c")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.c
          Aggregate: groupBy=[[test.c]], aggr=[[]]
            TableScan: test
        ")
    }

    #[test]
    fn eliminate_redundant_distinct_pair() -> Result<()> {
        let scan = test_table_scan()?;
        let grouping = vec![col("a"), col("b")];
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(grouping, Vec::<Expr>::new())?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a, test.b
          Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
            TableScan: test
        ")
    }

    #[test]
    fn do_not_eliminate_distinct() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            TableScan: test
        ")
    }

    #[test]
    fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
        let scan = test_table_scan()?;
        let grouping = vec![col("a"), col("b"), col("c")];
        let aggregates = vec![sum(col("c"))];
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(grouping, aggregates)?
            .project(vec![col("a"), col("b")])?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
          Projection: test.a, test.b
            Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]
              TableScan: test
        ")
    }

    #[test]
    fn use_limit_1_when_no_columns() -> Result<()> {
        let empty_schema = Schema::new(Fields::empty());
        let plan = table_scan(Some("test"), &empty_schema, None)?
            .distinct()?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=0, fetch=1
          TableScan: test
        ")
    }
}