datafusion_optimizer/
single_distinct_to_groupby.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
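//! [`SingleDistinctToGroupBy`] rewrites aggregates whose `DISTINCT` calls share a single argument, `AGG(DISTINCT x)`, into `AGG(x)` over an inner aggregate that additionally groups by `x`.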
19
20use std::sync::Arc;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24
25use datafusion_common::{
26    internal_err, tree_node::Transformed, DataFusionError, HashSet, Result,
27};
28use datafusion_expr::builder::project;
29use datafusion_expr::expr::AggregateFunctionParams;
30use datafusion_expr::{
31    col,
32    expr::AggregateFunction,
33    logical_plan::{Aggregate, LogicalPlan},
34    Expr,
35};
36
/// Optimizer rule that turns aggregates over a single `DISTINCT` argument into
/// a two-level `GROUP BY`.
///
///  ```text
///    Before:
///    SELECT a, count(DISTINCT b), sum(c)
///    FROM t
///    GROUP BY a
///
///    After:
///    SELECT a, count(alias1), sum(alias2)
///    FROM (
///      SELECT a, b as alias1, sum(c) as alias2
///      FROM t
///      GROUP BY a, b
///    )
///    GROUP BY a
///  ```
///
/// The rewrite is only attempted when all `DISTINCT` aggregates share one
/// argument expression, every other aggregate is `sum`, `min` or `max`, no
/// aggregate carries a `FILTER` or `ORDER BY` clause, and the `GROUP BY` is not
/// a grouping set (`GROUPING SETS`, `CUBE`, `ROLLUP`).
#[derive(Debug, Default)]
pub struct SingleDistinctToGroupBy {}

/// Name given to the distinct argument in the inner aggregate; the outer
/// aggregate refers to it via `col(SINGLE_DISTINCT_ALIAS)`.
const SINGLE_DISTINCT_ALIAS: &str = "alias1";

impl SingleDistinctToGroupBy {
    /// Creates the rule. Identical to [`SingleDistinctToGroupBy::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
64
65/// Check whether all aggregate exprs are distinct on a single field.
66fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
67    let mut fields_set = HashSet::new();
68    let mut aggregate_count = 0;
69    for expr in aggr_expr {
70        if let Expr::AggregateFunction(AggregateFunction {
71            func,
72            params:
73                AggregateFunctionParams {
74                    distinct,
75                    args,
76                    filter,
77                    order_by,
78                    null_treatment: _,
79                },
80        }) = expr
81        {
82            if filter.is_some() || order_by.is_some() {
83                return Ok(false);
84            }
85            aggregate_count += 1;
86            if *distinct {
87                for e in args {
88                    fields_set.insert(e);
89                }
90            } else if func.name() != "sum"
91                && func.name().to_lowercase() != "min"
92                && func.name().to_lowercase() != "max"
93            {
94                return Ok(false);
95            }
96        } else {
97            return Ok(false);
98        }
99    }
100    Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
101}
102
103/// Check if the first expr is [Expr::GroupingSet].
104fn contains_grouping_set(expr: &[Expr]) -> bool {
105    matches!(expr.first(), Some(Expr::GroupingSet(_)))
106}
107
108impl OptimizerRule for SingleDistinctToGroupBy {
109    fn name(&self) -> &str {
110        "single_distinct_aggregation_to_group_by"
111    }
112
113    fn apply_order(&self) -> Option<ApplyOrder> {
114        Some(ApplyOrder::TopDown)
115    }
116
117    fn supports_rewrite(&self) -> bool {
118        true
119    }
120
121    fn rewrite(
122        &self,
123        plan: LogicalPlan,
124        _config: &dyn OptimizerConfig,
125    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
126        match plan {
127            LogicalPlan::Aggregate(Aggregate {
128                input,
129                aggr_expr,
130                schema,
131                group_expr,
132                ..
133            }) if is_single_distinct_agg(&aggr_expr)?
134                && !contains_grouping_set(&group_expr) =>
135            {
136                let group_size = group_expr.len();
137                // alias all original group_by exprs
138                let (mut inner_group_exprs, out_group_expr_with_alias): (
139                    Vec<Expr>,
140                    Vec<(Expr, _)>,
141                ) = group_expr
142                    .into_iter()
143                    .enumerate()
144                    .map(|(i, group_expr)| {
145                        if let Expr::Column(_) = group_expr {
146                            // For Column expressions we can use existing expression as is.
147                            (group_expr.clone(), (group_expr, None))
148                        } else {
149                            // For complex expression write is as alias, to be able to refer
150                            // if from parent operators successfully.
151                            // Consider plan below.
152                            //
153                            // Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
154                            // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
155                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
156                            //
157                            // First aggregate(from bottom) refers to `test.a` column.
158                            // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate.
159
160                            // If we were to write plan above as below without alias
161                            //
162                            // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
163                            // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
164                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
165                            //
166                            // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it.
167                            let alias_str = format!("group_alias_{i}");
168                            let (qualifier, field) = schema.qualified_field(i);
169                            (
170                                group_expr.alias(alias_str.clone()),
171                                (col(alias_str), Some((qualifier, field.name()))),
172                            )
173                        }
174                    })
175                    .unzip();
176
177                // replace the distinct arg with alias
178                let mut index = 1;
179                let mut group_fields_set = HashSet::new();
180                let mut inner_aggr_exprs = vec![];
181                let outer_aggr_exprs = aggr_expr
182                    .into_iter()
183                    .map(|aggr_expr| match aggr_expr {
184                        Expr::AggregateFunction(AggregateFunction {
185                            func,
186                            params: AggregateFunctionParams { mut args, distinct, .. }
187                        }) => {
188                            if distinct {
189                                if args.len() != 1 {
190                                    return internal_err!("DISTINCT aggregate should have exactly one argument");
191                                }
192                                let arg = args.swap_remove(0);
193
194                                if group_fields_set.insert(arg.schema_name().to_string()) {
195                                    inner_group_exprs
196                                        .push(arg.alias(SINGLE_DISTINCT_ALIAS));
197                                }
198                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
199                                    func,
200                                    vec![col(SINGLE_DISTINCT_ALIAS)],
201                                    false, // intentional to remove distinct here
202                                    None,
203                                    None,
204                                    None,
205                                )))
206                                // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
207                            } else {
208                                index += 1;
209                                let alias_str = format!("alias{}", index);
210                                inner_aggr_exprs.push(
211                                    Expr::AggregateFunction(AggregateFunction::new_udf(
212                                        Arc::clone(&func),
213                                        args,
214                                        false,
215                                        None,
216                                        None,
217                                        None,
218                                    ))
219                                    .alias(&alias_str),
220                                );
221                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
222                                    func,
223                                    vec![col(&alias_str)],
224                                    false,
225                                    None,
226                                    None,
227                                    None,
228                                )))
229                            }
230                        }
231                        _ => Ok(aggr_expr),
232                    })
233                    .collect::<Result<Vec<_>>>()?;
234
235                // construct the inner AggrPlan
236                let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
237                    input,
238                    inner_group_exprs,
239                    inner_aggr_exprs,
240                )?);
241
242                let outer_group_exprs = out_group_expr_with_alias
243                    .iter()
244                    .map(|(expr, _)| expr.clone())
245                    .collect();
246
247                // so the aggregates are displayed in the same way even after the rewrite
248                // this optimizer has two kinds of alias:
249                // - group_by aggr
250                // - aggr expr
251                let alias_expr: Vec<_> = out_group_expr_with_alias
252                    .into_iter()
253                    .map(|(group_expr, original_name)| match original_name {
254                        Some((qualifier, name)) => {
255                            group_expr.alias_qualified(qualifier.cloned(), name)
256                        }
257                        None => group_expr,
258                    })
259                    .chain(outer_aggr_exprs.iter().cloned().enumerate().map(
260                        |(idx, expr)| {
261                            let idx = idx + group_size;
262                            let (qualifier, field) = schema.qualified_field(idx);
263                            expr.alias_qualified(qualifier.cloned(), field.name())
264                        },
265                    ))
266                    .collect();
267
268                let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
269                    Arc::new(inner_agg),
270                    outer_group_exprs,
271                    outer_aggr_exprs,
272                )?);
273                Ok(Transformed::yes(project(outer_aggr, alias_expr)?))
274            }
275            _ => Ok(Transformed::no(plan)),
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use datafusion_expr::expr::GroupingSet;
    use datafusion_expr::ExprFunctionExt;
    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum};
    use datafusion_functions_aggregate::min_max::max_udaf;
    use datafusion_functions_aggregate::sum::sum_udaf;

    /// Builds `max(DISTINCT expr)` with no filter, ordering or null treatment.
    fn max_distinct(expr: Expr) -> Expr {
        Expr::AggregateFunction(AggregateFunction::new_udf(
            max_udaf(),
            vec![expr],
            true,
            None,
            None,
            None,
        ))
    }

    /// Optimizes `plan` with this rule alone and asserts that its indented
    /// display (including schemas) equals `expected`.
    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq_display_indent(
            Arc::new(SingleDistinctToGroupBy::new()),
            plan,
            expected,
        );
        Ok(())
    }

    #[test]
    fn not_exist_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
            .build()?;

        // Not rewritten: there is no DISTINCT aggregate.
        let expected =
            "Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn single_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(col("b"))])?
            .build()?;

        // Rewritten: the distinct argument becomes the inner group key.
        let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\
                            \n  Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
                            \n    Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_grouping_set() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
            vec![col("a")],
            vec![col("b")],
        ]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Not rewritten: GROUPING SETS in the group-by.
        let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_cube() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Not rewritten: CUBE in the group-by.
        let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_rollup() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set =
            Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Not rewritten: ROLLUP in the group-by.
        let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn single_distinct_expr() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
            .build()?;

        // Rewritten: a non-column distinct argument is aliased as alias1.
        let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\
                            \n  Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
                            \n    Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn single_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![count_distinct(col("b"))])?
            .build()?;

        // Rewritten: the original group key is kept in both aggregates.
        let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\
                            \n  Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\
                            \n    Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count_distinct(col("c"))],
            )?
            .build()?;

        // Not rewritten: DISTINCT aggregates over two different arguments.
        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn one_field_two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), max_distinct(col("b"))],
            )?
            .build()?;
        // Rewritten: both DISTINCT aggregates share the single alias1 key.
        let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\
                            \n  Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]\
                            \n    Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn distinct_and_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count(col("c"))],
            )?
            .build()?;

        // Not rewritten: non-distinct `count` is not one of the aggregates
        // this rule splits into two phases (sum/min/max).
        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn group_by_with_expr() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))])?
            .build()?;

        // Rewritten: the non-column group key is aliased as group_alias_0.
        let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\
                            \n  Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\
                            \n    Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn two_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![
                    sum(col("c")),
                    count_distinct(col("b")),
                    max_distinct(col("b")),
                ],
            )?
            .build()?;
        // Rewritten: `sum` is split into sum(test.c) AS alias2 / sum(alias2).
        let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\
                            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]\
                            \n    Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn one_distinct_and_two_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
            )?
            .build()?;
        // Rewritten: each non-distinct aggregate gets its own alias (alias2, alias3).
        let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\
                            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\
                            \n    Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn one_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("c")],
                vec![min(col("a")), count_distinct(col("b"))],
            )?
            .build()?;
        // Rewritten: `min` is split into min(test.a) AS alias2 / min(alias2).
        let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\
                            \n  Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\
                            \n    Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
                            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn common_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // sum(a) FILTER (WHERE a > 5)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            Some(Box::new(col("a").gt(lit(5)))),
            None,
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;
        // Not rewritten: the non-distinct aggregate has a FILTER clause.
        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn distinct_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;
        // Not rewritten: the DISTINCT aggregate has a FILTER clause.
        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn common_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // SUM(a ORDER BY a)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            None,
            Some(vec![col("a").sort(true, false)]),
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;
        // Not rewritten: the non-distinct aggregate has an ORDER BY clause.
        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn distinct_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;
        // Not rewritten: the DISTINCT aggregate has an ORDER BY clause.
        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn aggregate_with_filter_and_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;
        // Not rewritten: the DISTINCT aggregate has both FILTER and ORDER BY.
        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\
                            \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_equal(plan, expected)
    }
}