datafusion_optimizer/
single_distinct_to_groupby.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SingleDistinctToGroupBy`] replaces `AGG(DISTINCT ..)` with `AGG(..) GROUP BY ..`
19
20use std::sync::Arc;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24
25use datafusion_common::{
26    DataFusionError, HashSet, Result, assert_eq_or_internal_err, tree_node::Transformed,
27};
28use datafusion_expr::builder::project;
29use datafusion_expr::expr::AggregateFunctionParams;
30use datafusion_expr::{
31    Expr, col,
32    expr::AggregateFunction,
33    logical_plan::{Aggregate, LogicalPlan},
34};
35
/// Optimizer rule that turns an aggregation over a single DISTINCT argument
/// into two nested, non-distinct aggregations.
///
/// The inner aggregation groups by the original keys plus the DISTINCT
/// argument, which de-duplicates that argument within every group. The outer
/// aggregation groups by the original keys only and applies the aggregate
/// functions to the de-duplicated values. Non-distinct `sum` / `min` / `max`
/// aggregates are evaluated in two phases: partially in the inner aggregation
/// and finally in the outer one.
///
///  ```text
///    Before:
///    SELECT a, count(DISTINCT b), sum(c)
///    FROM t
///    GROUP BY a
///
///    After:
///    SELECT a, count(alias1), sum(alias2)
///    FROM (
///      SELECT a, b as alias1, sum(c) as alias2
///      FROM t
///      GROUP BY a, b
///    )
///    GROUP BY a
///  ```
#[derive(Debug, Default)]
pub struct SingleDistinctToGroupBy {}
54
/// Name given to the DISTINCT argument when it is added as a group key of the
/// inner aggregation; the outer aggregation refers to it through this name.
const SINGLE_DISTINCT_ALIAS: &str = "alias1";
56
57impl SingleDistinctToGroupBy {
58    #[expect(missing_docs)]
59    pub fn new() -> Self {
60        Self {}
61    }
62}
63
64/// Check whether all aggregate exprs are distinct on a single field.
65fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
66    let mut fields_set = HashSet::new();
67    let mut aggregate_count = 0;
68    for expr in aggr_expr {
69        if let Expr::AggregateFunction(AggregateFunction {
70            func,
71            params:
72                AggregateFunctionParams {
73                    distinct,
74                    args,
75                    filter,
76                    order_by,
77                    null_treatment: _,
78                },
79        }) = expr
80        {
81            if filter.is_some() || !order_by.is_empty() {
82                return Ok(false);
83            }
84            aggregate_count += 1;
85            if *distinct {
86                for e in args {
87                    fields_set.insert(e);
88                }
89            } else if func.name() != "sum"
90                && func.name().to_lowercase() != "min"
91                && func.name().to_lowercase() != "max"
92            {
93                return Ok(false);
94            }
95        } else {
96            return Ok(false);
97        }
98    }
99    Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
100}
101
102/// Check if the first expr is [Expr::GroupingSet].
103fn contains_grouping_set(expr: &[Expr]) -> bool {
104    matches!(expr.first(), Some(Expr::GroupingSet(_)))
105}
106
107impl OptimizerRule for SingleDistinctToGroupBy {
108    fn name(&self) -> &str {
109        "single_distinct_aggregation_to_group_by"
110    }
111
112    fn apply_order(&self) -> Option<ApplyOrder> {
113        Some(ApplyOrder::TopDown)
114    }
115
116    fn supports_rewrite(&self) -> bool {
117        true
118    }
119
120    fn rewrite(
121        &self,
122        plan: LogicalPlan,
123        _config: &dyn OptimizerConfig,
124    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
125        match plan {
126            LogicalPlan::Aggregate(Aggregate {
127                input,
128                aggr_expr,
129                schema,
130                group_expr,
131                ..
132            }) if is_single_distinct_agg(&aggr_expr)?
133                && !contains_grouping_set(&group_expr) =>
134            {
135                let group_size = group_expr.len();
136                // alias all original group_by exprs
137                let (mut inner_group_exprs, out_group_expr_with_alias): (
138                    Vec<Expr>,
139                    Vec<(Expr, _)>,
140                ) = group_expr
141                    .into_iter()
142                    .enumerate()
143                    .map(|(i, group_expr)| {
144                        if let Expr::Column(_) = group_expr {
145                            // For Column expressions we can use existing expression as is.
146                            (group_expr.clone(), (group_expr, None))
147                        } else {
148                            // For complex expression write is as alias, to be able to refer
149                            // if from parent operators successfully.
150                            // Consider plan below.
151                            //
152                            // Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
153                            // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
154                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
155                            //
156                            // First aggregate(from bottom) refers to `test.a` column.
157                            // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate.
158
159                            // If we were to write plan above as below without alias
160                            //
161                            // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
162                            // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
163                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
164                            //
165                            // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it.
166                            let alias_str = format!("group_alias_{i}");
167                            let (qualifier, field) = schema.qualified_field(i);
168                            (
169                                group_expr.alias(alias_str.clone()),
170                                (col(alias_str), Some((qualifier, field.name()))),
171                            )
172                        }
173                    })
174                    .unzip();
175
176                // replace the distinct arg with alias
177                let mut index = 1;
178                let mut group_fields_set = HashSet::new();
179                let mut inner_aggr_exprs = vec![];
180                let outer_aggr_exprs = aggr_expr
181                    .into_iter()
182                    .map(|aggr_expr| match aggr_expr {
183                        Expr::AggregateFunction(AggregateFunction {
184                            func,
185                            params:
186                                AggregateFunctionParams {
187                                    mut args, distinct, ..
188                                },
189                        }) => {
190                            if distinct {
191                                assert_eq_or_internal_err!(
192                                    args.len(),
193                                    1,
194                                    "DISTINCT aggregate should have exactly one argument"
195                                );
196                                let arg = args.swap_remove(0);
197
198                                if group_fields_set.insert(arg.schema_name().to_string())
199                                {
200                                    inner_group_exprs
201                                        .push(arg.alias(SINGLE_DISTINCT_ALIAS));
202                                }
203                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
204                                    func,
205                                    vec![col(SINGLE_DISTINCT_ALIAS)],
206                                    false, // intentional to remove distinct here
207                                    None,
208                                    vec![],
209                                    None,
210                                )))
211                                // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
212                            } else {
213                                index += 1;
214                                let alias_str = format!("alias{index}");
215                                inner_aggr_exprs.push(
216                                    Expr::AggregateFunction(AggregateFunction::new_udf(
217                                        Arc::clone(&func),
218                                        args,
219                                        false,
220                                        None,
221                                        vec![],
222                                        None,
223                                    ))
224                                    .alias(&alias_str),
225                                );
226                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
227                                    func,
228                                    vec![col(&alias_str)],
229                                    false,
230                                    None,
231                                    vec![],
232                                    None,
233                                )))
234                            }
235                        }
236                        _ => Ok(aggr_expr),
237                    })
238                    .collect::<Result<Vec<_>>>()?;
239
240                // construct the inner AggrPlan
241                let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
242                    input,
243                    inner_group_exprs,
244                    inner_aggr_exprs,
245                )?);
246
247                let outer_group_exprs = out_group_expr_with_alias
248                    .iter()
249                    .map(|(expr, _)| expr.clone())
250                    .collect();
251
252                // so the aggregates are displayed in the same way even after the rewrite
253                // this optimizer has two kinds of alias:
254                // - group_by aggr
255                // - aggr expr
256                let alias_expr: Vec<_> = out_group_expr_with_alias
257                    .into_iter()
258                    .map(|(group_expr, original_name)| match original_name {
259                        Some((qualifier, name)) => {
260                            group_expr.alias_qualified(qualifier.cloned(), name)
261                        }
262                        None => group_expr,
263                    })
264                    .chain(outer_aggr_exprs.iter().cloned().enumerate().map(
265                        |(idx, expr)| {
266                            let idx = idx + group_size;
267                            let (qualifier, field) = schema.qualified_field(idx);
268                            expr.alias_qualified(qualifier.cloned(), field.name())
269                        },
270                    ))
271                    .collect();
272
273                let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
274                    Arc::new(inner_agg),
275                    outer_group_exprs,
276                    outer_aggr_exprs,
277                )?);
278                Ok(Transformed::yes(project(outer_aggr, alias_expr)?))
279            }
280            _ => Ok(Transformed::no(plan)),
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use datafusion_expr::ExprFunctionExt;
    use datafusion_expr::expr::GroupingSet;
    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum};
    use datafusion_functions_aggregate::min_max::max_udaf;
    use datafusion_functions_aggregate::sum::sum_udaf;

    /// Builds `max(DISTINCT expr)` with no filter and no ordering.
    fn max_distinct(expr: Expr) -> Expr {
        Expr::AggregateFunction(AggregateFunction::new_udf(
            max_udaf(),
            vec![expr],
            true,
            None,
            vec![],
            None,
        ))
    }

    /// Runs [`SingleDistinctToGroupBy`] on `$plan` and compares the indented
    /// plan display against the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(SingleDistinctToGroupBy::new());
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    /// An aggregation without any DISTINCT aggregate is left unchanged.
    #[test]
    fn not_exist_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A lone `count(DISTINCT b)` without group keys is rewritten.
    #[test]
    fn single_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(col("b"))])?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]
            Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    /// Grouping by `GROUPING SETS` keeps the plan unchanged.
    #[test]
    fn single_distinct_and_grouping_set() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
            vec![col("a")],
            vec![col("b")],
        ]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    /// Grouping by `CUBE` keeps the plan unchanged.
    #[test]
    fn single_distinct_and_cube() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    /// Grouping by `ROLLUP` keeps the plan unchanged.
    #[test]
    fn single_distinct_and_rollup() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set =
            Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// The DISTINCT argument may be an arbitrary expression, not only a column.
    #[test]
    fn single_distinct_expr() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]
          Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]
            Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A column group key is reused unchanged by both aggregates.
    #[test]
    fn single_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![count_distinct(col("b"))])?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// DISTINCT aggregates over two different columns are not rewritten.
    #[test]
    fn two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count_distinct(col("c"))],
            )?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// Two DISTINCT aggregates over the same column share one inner group key.
    #[test]
    fn one_field_two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), max_distinct(col("b"))],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]
          Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A non-distinct `count` next to the DISTINCT aggregate blocks the rewrite.
    #[test]
    fn distinct_and_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count(col("c"))],
            )?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A non-column group key is aliased as `group_alias_0` in the inner
    /// aggregate and renamed back by the final projection.
    #[test]
    fn group_by_with_expr() -> Result<()> {
        // Propagate the error with `?` like every other test in this module.
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))])?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]
          Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]
            Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// One non-distinct `sum` plus two DISTINCT aggregates over the same column.
    #[test]
    fn two_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![
                    sum(col("c")),
                    count_distinct(col("b")),
                    max_distinct(col("b")),
                ],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]
          Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// Two non-distinct aggregates receive the aliases `alias2` and `alias3`.
    #[test]
    fn one_distinct_and_two_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A non-distinct `min` is split into an inner and an outer `min`.
    #[test]
    fn one_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("c")],
                vec![min(col("a")), count_distinct(col("b"))],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]
            Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A FILTER clause on the non-distinct aggregate blocks the rewrite.
    #[test]
    fn common_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // sum(a) FILTER (WHERE a > 5)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            Some(Box::new(col("a").gt(lit(5)))),
            vec![],
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A FILTER clause on the DISTINCT aggregate blocks the rewrite.
    #[test]
    fn distinct_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// An ORDER BY clause on the non-distinct aggregate blocks the rewrite.
    #[test]
    fn common_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // SUM(a ORDER BY a)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            None,
            vec![col("a").sort(true, false)],
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// An ORDER BY clause on the DISTINCT aggregate blocks the rewrite.
    #[test]
    fn distinct_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// FILTER and ORDER BY together on the DISTINCT aggregate block the rewrite.
    #[test]
    fn aggregate_with_filter_and_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}