Skip to main content

datafusion_optimizer/
single_distinct_to_groupby.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SingleDistinctToGroupBy`] replaces `AGG(DISTINCT ..)` with `AGG(..) GROUP BY ..`
19
20use std::sync::Arc;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24
25use datafusion_common::{
26    DataFusionError, HashSet, Result, assert_eq_or_internal_err, tree_node::Transformed,
27};
28use datafusion_expr::builder::project;
29use datafusion_expr::expr::AggregateFunctionParams;
30use datafusion_expr::{
31    Expr, col,
32    expr::AggregateFunction,
33    logical_plan::{Aggregate, LogicalPlan},
34};
35
/// Optimizer rule that rewrites an aggregation whose `DISTINCT` aggregates all
/// operate on one single expression into two nested aggregations without
/// `DISTINCT`: the inner aggregation additionally groups by the distinct
/// expression, the outer aggregation evaluates the original functions on the
/// result.
///
///  ```text
///    Before:
///    SELECT a, count(DISTINCT b), sum(c)
///    FROM t
///    GROUP BY a
///
///    After:
///    SELECT a, count(alias1), sum(alias2)
///    FROM (
///      SELECT a, b as alias1, sum(c) as alias2
///      FROM t
///      GROUP BY a, b
///    )
///    GROUP BY a
///  ```
#[derive(Default, Debug)]
pub struct SingleDistinctToGroupBy {}

/// Alias given to the `DISTINCT` argument once it has been turned into a
/// group-by expression of the inner aggregation.
const SINGLE_DISTINCT_ALIAS: &str = "alias1";

impl SingleDistinctToGroupBy {
    /// Creates a new [`SingleDistinctToGroupBy`] rule.
    ///
    /// The rule carries no state, so this is equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
63
64/// Check whether all aggregate exprs are distinct on a single field.
65fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
66    let mut fields_set = HashSet::new();
67    let mut aggregate_count = 0;
68    for expr in aggr_expr {
69        if let Expr::AggregateFunction(AggregateFunction {
70            func,
71            params:
72                AggregateFunctionParams {
73                    distinct,
74                    args,
75                    filter,
76                    order_by,
77                    null_treatment: _,
78                },
79        }) = expr
80        {
81            if filter.is_some() || !order_by.is_empty() {
82                return Ok(false);
83            }
84            aggregate_count += 1;
85            if *distinct {
86                for e in args {
87                    fields_set.insert(e);
88                }
89            } else if func.name() != "sum"
90                && func.name().to_lowercase() != "min"
91                && func.name().to_lowercase() != "max"
92            {
93                return Ok(false);
94            }
95        } else {
96            return Ok(false);
97        }
98    }
99    Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
100}
101
102/// Check if the first expr is [Expr::GroupingSet].
103fn contains_grouping_set(expr: &[Expr]) -> bool {
104    matches!(expr.first(), Some(Expr::GroupingSet(_)))
105}
106
107impl OptimizerRule for SingleDistinctToGroupBy {
108    fn name(&self) -> &str {
109        "single_distinct_aggregation_to_group_by"
110    }
111
112    fn apply_order(&self) -> Option<ApplyOrder> {
113        Some(ApplyOrder::TopDown)
114    }
115
116    fn supports_rewrite(&self) -> bool {
117        true
118    }
119
120    fn rewrite(
121        &self,
122        plan: LogicalPlan,
123        _config: &dyn OptimizerConfig,
124    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
125        match plan {
126            LogicalPlan::Aggregate(Aggregate {
127                input,
128                aggr_expr,
129                schema,
130                group_expr,
131                ..
132            }) if is_single_distinct_agg(&aggr_expr)?
133                && !contains_grouping_set(&group_expr) =>
134            {
135                let group_size = group_expr.len();
136                // alias all original group_by exprs
137                let (mut inner_group_exprs, out_group_expr_with_alias): (
138                    Vec<Expr>,
139                    Vec<(Expr, _)>,
140                ) = group_expr
141                    .into_iter()
142                    .enumerate()
143                    .map(|(i, group_expr)| {
144                        if let Expr::Column(_) = group_expr {
145                            // For Column expressions we can use existing expression as is.
146                            (group_expr.clone(), (group_expr, None))
147                        } else {
148                            // For complex expression write is as alias, to be able to refer
149                            // if from parent operators successfully.
150                            // Consider plan below.
151                            //
152                            // Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
153                            // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
154                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
155                            //
156                            // First aggregate(from bottom) refers to `test.a` column.
157                            // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate.
158
159                            // If we were to write plan above as below without alias
160                            //
161                            // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
162                            // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
163                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
164                            //
165                            // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it.
166                            let alias_str = format!("group_alias_{i}");
167                            let (qualifier, field) = schema.qualified_field(i);
168                            (
169                                group_expr.alias(alias_str.clone()),
170                                (col(alias_str), Some((qualifier, field.name()))),
171                            )
172                        }
173                    })
174                    .unzip();
175
176                // replace the distinct arg with alias
177                let mut index = 1;
178                let mut group_fields_set = HashSet::new();
179                let mut inner_aggr_exprs = vec![];
180                let outer_aggr_exprs = aggr_expr
181                    .into_iter()
182                    .map(|aggr_expr| match aggr_expr {
183                        Expr::AggregateFunction(AggregateFunction {
184                            func,
185                            params:
186                                AggregateFunctionParams {
187                                    mut args,
188                                    distinct,
189                                    filter,
190                                    order_by,
191                                    null_treatment,
192                                },
193                        }) => {
194                            if distinct {
195                                assert_eq_or_internal_err!(
196                                    args.len(),
197                                    1,
198                                    "DISTINCT aggregate should have exactly one argument"
199                                );
200                                let arg = args.swap_remove(0);
201
202                                if group_fields_set.insert(arg.schema_name().to_string())
203                                {
204                                    inner_group_exprs
205                                        .push(arg.alias(SINGLE_DISTINCT_ALIAS));
206                                }
207                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
208                                    func,
209                                    vec![col(SINGLE_DISTINCT_ALIAS)],
210                                    false, // intentional to remove distinct here
211                                    filter,
212                                    order_by,
213                                    null_treatment,
214                                )))
215                                // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
216                            } else {
217                                index += 1;
218                                let alias_str = format!("alias{index}");
219                                inner_aggr_exprs.push(
220                                    Expr::AggregateFunction(AggregateFunction::new_udf(
221                                        Arc::clone(&func),
222                                        args,
223                                        false,
224                                        filter,
225                                        order_by,
226                                        null_treatment,
227                                    ))
228                                    .alias(&alias_str),
229                                );
230                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
231                                    func,
232                                    vec![col(&alias_str)],
233                                    false,
234                                    None,
235                                    vec![],
236                                    None,
237                                )))
238                            }
239                        }
240                        _ => Ok(aggr_expr),
241                    })
242                    .collect::<Result<Vec<_>>>()?;
243
244                // construct the inner AggrPlan
245                let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
246                    input,
247                    inner_group_exprs,
248                    inner_aggr_exprs,
249                )?);
250
251                let outer_group_exprs = out_group_expr_with_alias
252                    .iter()
253                    .map(|(expr, _)| expr.clone())
254                    .collect();
255
256                // so the aggregates are displayed in the same way even after the rewrite
257                // this optimizer has two kinds of alias:
258                // - group_by aggr
259                // - aggr expr
260                let alias_expr: Vec<_> = out_group_expr_with_alias
261                    .into_iter()
262                    .map(|(group_expr, original_name)| match original_name {
263                        Some((qualifier, name)) => {
264                            group_expr.alias_qualified(qualifier.cloned(), name)
265                        }
266                        None => group_expr,
267                    })
268                    .chain(outer_aggr_exprs.iter().cloned().enumerate().map(
269                        |(idx, expr)| {
270                            let idx = idx + group_size;
271                            let (qualifier, field) = schema.qualified_field(idx);
272                            expr.alias_qualified(qualifier.cloned(), field.name())
273                        },
274                    ))
275                    .collect();
276
277                let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
278                    Arc::new(inner_agg),
279                    outer_group_exprs,
280                    outer_aggr_exprs,
281                )?);
282                Ok(Transformed::yes(project(outer_aggr, alias_expr)?))
283            }
284            _ => Ok(Transformed::no(plan)),
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use datafusion_expr::ExprFunctionExt;
    use datafusion_expr::expr::GroupingSet;
    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum};
    use datafusion_functions_aggregate::min_max::max_udaf;
    use datafusion_functions_aggregate::sum::sum_udaf;

    /// Builds `max(DISTINCT expr)`.
    fn max_distinct(expr: Expr) -> Expr {
        Expr::AggregateFunction(AggregateFunction::new_udf(
            max_udaf(),
            vec![expr],
            true,
            None,
            vec![],
            None,
        ))
    }

    /// Runs [`SingleDistinctToGroupBy`] on `$plan` and compares the indented
    /// plan display against the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(SingleDistinctToGroupBy::new());
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    // An aggregation without any DISTINCT aggregate is left alone.
    #[test]
    fn not_exist_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A single DISTINCT aggregate without group-by keys is rewritten.
    #[test]
    fn single_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(col("b"))])?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]
            Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_grouping_set() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
            vec![col("a")],
            vec![col("b")],
        ]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_cube() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_rollup() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set =
            Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The DISTINCT argument may be an arbitrary expression, not only a column.
    #[test]
    fn single_distinct_expr() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]
          Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]
            Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A column group-by key is kept as is on both aggregation levels.
    #[test]
    fn single_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![count_distinct(col("b"))])?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // DISTINCT aggregates over two different columns are not rewritten.
    #[test]
    fn two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count_distinct(col("c"))],
            )?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Two DISTINCT aggregates over the same column share one group-by key.
    #[test]
    fn one_field_two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), max_distinct(col("b"))],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]
          Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A non-distinct `count` next to the DISTINCT aggregate prevents the rewrite.
    #[test]
    fn distinct_and_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count(col("c"))],
            )?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A non-column group-by expression is aliased as `group_alias_0`.
    #[test]
    fn group_by_with_expr() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))])?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]
          Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]
            Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Two DISTINCT aggregates on one column plus a non-distinct `sum`.
    #[test]
    fn two_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![
                    sum(col("c")),
                    count_distinct(col("b")),
                    max_distinct(col("b")),
                ],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]
          Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // One DISTINCT aggregate plus non-distinct `sum` and `max`.
    #[test]
    fn one_distinct_and_two_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // One DISTINCT aggregate plus a non-distinct `min`.
    #[test]
    fn one_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("c")],
                vec![min(col("a")), count_distinct(col("b"))],
            )?
            .build()?;

        // Should work
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]
            Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A FILTER clause on the non-distinct aggregate prevents the rewrite.
    #[test]
    fn common_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // sum(a) FILTER (WHERE a > 5)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            Some(Box::new(col("a").gt(lit(5)))),
            vec![],
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A FILTER clause on the DISTINCT aggregate prevents the rewrite.
    #[test]
    fn distinct_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An ORDER BY clause on the non-distinct aggregate prevents the rewrite.
    #[test]
    fn common_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // SUM(a ORDER BY a)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            None,
            vec![col("a").sort(true, false)],
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An ORDER BY clause on the DISTINCT aggregate prevents the rewrite.
    #[test]
    fn distinct_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // FILTER and ORDER BY together on the DISTINCT aggregate prevent the rewrite.
    #[test]
    fn aggregate_with_filter_and_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Do nothing
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}