datafusion_optimizer/
eliminate_group_by_constant.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`EliminateGroupByConstant`] removes constant expressions from the `GROUP BY` clause
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21
22use datafusion_common::tree_node::Transformed;
23use datafusion_common::Result;
24use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility};
25
/// Optimizer rule that drops constant expressions from the `GROUP BY` clause
/// of an aggregation.
///
/// To keep the original output schema intact, the rule places an additional
/// projection on top of the simplified aggregation; that projection
/// re-introduces the removed constant expressions.
#[derive(Debug, Default)]
pub struct EliminateGroupByConstant {}
31
32impl EliminateGroupByConstant {
33    pub fn new() -> Self {
34        Self {}
35    }
36}
37
38impl OptimizerRule for EliminateGroupByConstant {
39    fn supports_rewrite(&self) -> bool {
40        true
41    }
42
43    fn rewrite(
44        &self,
45        plan: LogicalPlan,
46        _config: &dyn OptimizerConfig,
47    ) -> Result<Transformed<LogicalPlan>> {
48        match plan {
49            LogicalPlan::Aggregate(aggregate) => {
50                let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate
51                    .group_expr
52                    .iter()
53                    .partition(|expr| is_constant_expression(expr));
54
55                // If no constant expressions found (nothing to optimize) or
56                // constant expression is the only expression in aggregate,
57                // optimization is skipped
58                if const_group_expr.is_empty()
59                    || (!const_group_expr.is_empty()
60                        && nonconst_group_expr.is_empty()
61                        && aggregate.aggr_expr.is_empty())
62                {
63                    return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
64                }
65
66                let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
67                    aggregate.input,
68                    nonconst_group_expr.into_iter().cloned().collect(),
69                    aggregate.aggr_expr.clone(),
70                )?);
71
72                let projection_expr =
73                    aggregate.group_expr.into_iter().chain(aggregate.aggr_expr);
74
75                let projection = LogicalPlanBuilder::from(simplified_aggregate)
76                    .project(projection_expr)?
77                    .build()?;
78
79                Ok(Transformed::yes(projection))
80            }
81            _ => Ok(Transformed::no(plan)),
82        }
83    }
84
85    fn name(&self) -> &str {
86        "eliminate_group_by_constant"
87    }
88
89    fn apply_order(&self) -> Option<ApplyOrder> {
90        Some(ApplyOrder::BottomUp)
91    }
92}
93
94/// Checks if expression is constant, and can be eliminated from group by.
95///
96/// Intended to be used only within this rule, helper function, which heavily
97/// relies on `SimplifyExpressions` result.
98fn is_constant_expression(expr: &Expr) -> bool {
99    match expr {
100        Expr::Alias(e) => is_constant_expression(&e.expr),
101        Expr::BinaryExpr(e) => {
102            is_constant_expression(&e.left) && is_constant_expression(&e.right)
103        }
104        Expr::Literal(_) => true,
105        Expr::ScalarFunction(e) => {
106            matches!(
107                e.func.signature().volatility,
108                Volatility::Immutable | Volatility::Stable
109            ) && e.args.iter().all(is_constant_expression)
110        }
111        _ => false,
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;

    use std::sync::Arc;

    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{
        col, lit, ColumnarValue, LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF,
        ScalarUDFImpl, Signature, TypeSignature,
    };
    use datafusion_functions_aggregate::expr_fn::count;

    /// Minimal scalar UDF used to exercise the volatility check of the rule.
    /// It accepts a single argument of any type and is never actually invoked.
    #[derive(Debug)]
    struct ScalarUDFMock {
        signature: Signature,
    }

    impl ScalarUDFMock {
        /// Builds a mock whose signature reports the given `volatility`.
        fn new_with_volatility(volatility: Volatility) -> Self {
            let signature = Signature::new(TypeSignature::Any(1), volatility);
            Self { signature }
        }
    }

    impl ScalarUDFImpl for ScalarUDFMock {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn name(&self) -> &str {
            "scalar_fn_mock"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            // The optimizer tests never execute the function.
            unimplemented!()
        }
    }

    /// Runs [`EliminateGroupByConstant`] over `plan` and checks the optimized
    /// plan against its `expected` textual form.
    fn assert_rule_output(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq(
            Arc::new(EliminateGroupByConstant::new()),
            plan,
            expected,
        )
    }

    /// Builds `scalar_fn_mock(123u32)` with the requested `volatility`.
    fn mock_fn_call(volatility: Volatility) -> Expr {
        let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(volatility));
        Expr::ScalarFunction(ScalarFunction::new_udf(
            Arc::new(udf),
            vec![lit(123u32)],
        ))
    }

    /// A literal next to a column is removed from the grouping keys.
    #[test]
    fn test_eliminate_gby_literal() -> Result<()> {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Projection: test.a, UInt32(1), count(test.c)\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
            \n    TableScan: test\
            ",
        )
    }

    /// Only constant grouping keys, but aggregate expressions are present.
    #[test]
    fn test_eliminate_constant() -> Result<()> {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Projection: Utf8(\"test\"), UInt32(123), count(test.c)\
            \n  Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\
            \n    TableScan: test\
            ",
        )
    }

    /// Without constant grouping keys the plan is left alone.
    #[test]
    fn test_no_op_no_constants() -> Result<()> {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\
            \n  TableScan: test\
            ",
        )
    }

    /// A constant that is the only expression of the aggregate is kept.
    #[test]
    fn test_no_op_only_constant() -> Result<()> {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![lit(123u32)], Vec::<Expr>::new())?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\
            \n  TableScan: test\
            ",
        )
    }

    /// An aliased literal is recognised as constant; the alias survives in
    /// the projection.
    #[test]
    fn test_eliminate_constant_with_alias() -> Result<()> {
        let group_by = vec![lit(123u32).alias("const"), col("a")];
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(group_by, vec![count(col("c"))])?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Projection: UInt32(123) AS const, test.a, count(test.c)\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
            \n    TableScan: test\
            ",
        )
    }

    /// An immutable function over a literal is constant and gets removed.
    #[test]
    fn test_eliminate_scalar_fn_with_constant_arg() -> Result<()> {
        let udf_expr = mock_fn_call(Volatility::Immutable);
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
            \n    TableScan: test\
            ",
        )
    }

    /// A volatile function is never constant, even with a literal argument.
    #[test]
    fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
        let udf_expr = mock_fn_call(Volatility::Volatile);
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
            .build()?;

        assert_rule_output(
            plan,
            "\
            Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\
            \n  TableScan: test\
            ",
        )
    }
}