Skip to main content

datafusion_optimizer/
eliminate_group_by_constant.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateGroupByConstant`] removes constant and functionally redundant
19//! expressions from `GROUP BY` clause
20use crate::optimizer::ApplyOrder;
21use crate::{OptimizerConfig, OptimizerRule};
22
23use std::collections::HashSet;
24
25use datafusion_common::Result;
26use datafusion_common::tree_node::Transformed;
27use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility};
28
/// Optimizer rule that drops `GROUP BY` expressions which cannot affect the
/// grouping: constants, and deterministic expressions built only from
/// constants and columns that already appear as bare `GROUP BY` columns.
///
/// The removed expressions are re-created by a projection placed on top of
/// the simplified aggregation, so the original output columns are preserved.
#[derive(Debug, Default)]
pub struct EliminateGroupByConstant {}
34
35impl EliminateGroupByConstant {
36    pub fn new() -> Self {
37        Self {}
38    }
39}
40
41impl OptimizerRule for EliminateGroupByConstant {
42    fn supports_rewrite(&self) -> bool {
43        true
44    }
45
46    fn rewrite(
47        &self,
48        plan: LogicalPlan,
49        _config: &dyn OptimizerConfig,
50    ) -> Result<Transformed<LogicalPlan>> {
51        match plan {
52            LogicalPlan::Aggregate(aggregate) => {
53                // Collect bare column references in GROUP BY
54                let group_by_columns: HashSet<&datafusion_common::Column> = aggregate
55                    .group_expr
56                    .iter()
57                    .filter_map(|expr| match expr {
58                        Expr::Column(c) => Some(c),
59                        _ => None,
60                    })
61                    .collect();
62
63                let (redundant, required): (Vec<_>, Vec<_>) = aggregate
64                    .group_expr
65                    .iter()
66                    .partition(|expr| is_redundant_group_expr(expr, &group_by_columns));
67
68                if redundant.is_empty()
69                    || (required.is_empty() && aggregate.aggr_expr.is_empty())
70                {
71                    return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
72                }
73
74                let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
75                    aggregate.input,
76                    required.into_iter().cloned().collect(),
77                    aggregate.aggr_expr.clone(),
78                )?);
79
80                let projection_expr =
81                    aggregate.group_expr.into_iter().chain(aggregate.aggr_expr);
82
83                let projection = LogicalPlanBuilder::from(simplified_aggregate)
84                    .project(projection_expr)?
85                    .build()?;
86
87                Ok(Transformed::yes(projection))
88            }
89            _ => Ok(Transformed::no(plan)),
90        }
91    }
92
93    fn name(&self) -> &str {
94        "eliminate_group_by_constant"
95    }
96
97    fn apply_order(&self) -> Option<ApplyOrder> {
98        Some(ApplyOrder::BottomUp)
99    }
100}
101
102/// Checks if a GROUP BY expression is redundant (can be removed without
103/// changing grouping semantics). An expression is redundant if it is a
104/// deterministic function of constants and columns already present as bare
105/// column references in the GROUP BY.
106fn is_redundant_group_expr(
107    expr: &Expr,
108    group_by_columns: &HashSet<&datafusion_common::Column>,
109) -> bool {
110    // Bare column references are never redundant - they define the grouping
111    if matches!(expr, Expr::Column(_)) {
112        return false;
113    }
114    is_deterministic_of(expr, group_by_columns)
115}
116
117/// Returns true if `expr` is a deterministic expression whose only column
118/// references are contained in `known_columns`.
119fn is_deterministic_of(
120    expr: &Expr,
121    known_columns: &HashSet<&datafusion_common::Column>,
122) -> bool {
123    match expr {
124        Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns),
125        Expr::Column(c) => known_columns.contains(c),
126        Expr::Literal(_, _) => true,
127        Expr::BinaryExpr(e) => {
128            is_deterministic_of(&e.left, known_columns)
129                && is_deterministic_of(&e.right, known_columns)
130        }
131        Expr::ScalarFunction(e) => {
132            matches!(
133                e.func.signature().volatility,
134                Volatility::Immutable | Volatility::Stable
135            ) && e
136                .args
137                .iter()
138                .all(|arg| is_deterministic_of(arg, known_columns))
139        }
140        Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns),
141        Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns),
142        Expr::Negative(e) => is_deterministic_of(e, known_columns),
143        _ => false,
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;

    use arrow::datatypes::DataType;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{
        ColumnarValue, LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
        Signature, TypeSignature, col, lit,
    };

    use datafusion_functions_aggregate::expr_fn::count;

    use std::sync::Arc;

    /// Runs a single pass of [`EliminateGroupByConstant`] over `$plan` and
    /// compares the result with the inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let ctx = OptimizerContext::new().with_max_passes(1);
            let rule_set: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateGroupByConstant::new())];
            assert_optimized_plan_eq_snapshot!(
                ctx,
                rule_set,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Scalar UDF stub with a configurable volatility; it is never invoked.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct ScalarUDFMock {
        signature: Signature,
    }

    impl ScalarUDFMock {
        fn new_with_volatility(volatility: Volatility) -> Self {
            let signature = Signature::new(TypeSignature::Any(1), volatility);
            Self { signature }
        }
    }

    impl ScalarUDFImpl for ScalarUDFMock {
        fn name(&self) -> &str {
            "scalar_fn_mock"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Builds `scalar_fn_mock(123u32)` using a mock UDF of the given volatility.
    fn mock_fn_of_literal(volatility: Volatility) -> Expr {
        let mock = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(volatility));
        Expr::ScalarFunction(ScalarFunction::new_udf(
            Arc::new(mock),
            vec![lit(123u32)],
        ))
    }

    /// Builds an aggregate with the given expressions over the test table scan.
    fn aggregate_over_test_table(
        group_by: Vec<Expr>,
        aggregates: Vec<Expr>,
    ) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(group_by, aggregates)?
            .build()
    }

    #[test]
    fn test_eliminate_gby_literal() -> Result<()> {
        let plan =
            aggregate_over_test_table(vec![col("a"), lit(1u32)], vec![count(col("c"))])?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a, UInt32(1), count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    #[test]
    fn test_eliminate_constant() -> Result<()> {
        let plan = aggregate_over_test_table(
            vec![lit("test"), lit(123u32)],
            vec![count(col("c"))],
        )?;

        assert_optimized_plan_equal!(plan, @r#"
        Projection: Utf8("test"), UInt32(123), count(test.c)
          Aggregate: groupBy=[[]], aggr=[[count(test.c)]]
            TableScan: test
        "#)
    }

    #[test]
    fn test_no_op_no_constants() -> Result<()> {
        let plan =
            aggregate_over_test_table(vec![col("a"), col("b")], vec![count(col("c"))])?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]
          TableScan: test
        ")
    }

    #[test]
    fn test_no_op_only_constant() -> Result<()> {
        // Removing the only group expression would leave an empty aggregate.
        let plan = aggregate_over_test_table(vec![lit(123u32)], Vec::new())?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]
          TableScan: test
        ")
    }

    #[test]
    fn test_eliminate_constant_with_alias() -> Result<()> {
        let plan = aggregate_over_test_table(
            vec![lit(123u32).alias("const"), col("a")],
            vec![count(col("c"))],
        )?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: UInt32(123) AS const, test.a, count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    #[test]
    fn test_eliminate_scalar_fn_with_constant_arg() -> Result<()> {
        let immutable_call = mock_fn_of_literal(Volatility::Immutable);
        let plan = aggregate_over_test_table(
            vec![immutable_call, col("a")],
            vec![count(col("c"))],
        )?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    #[test]
    fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> {
        // GROUP BY a, a - 1, a - 2, a - 3  ->  GROUP BY a
        let group_by = vec![
            col("a"),
            col("a") - lit(1u32),
            col("a") - lit(2u32),
            col("a") - lit(3u32),
        ];
        let plan = aggregate_over_test_table(group_by, vec![count(col("c"))])?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    #[test]
    fn test_no_eliminate_independent_columns() -> Result<()> {
        // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column)
        let plan = aggregate_over_test_table(
            vec![col("a"), col("b") - lit(1u32)],
            vec![count(col("c"))],
        )?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]]
          TableScan: test
        ")
    }

    #[test]
    fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
        let volatile_call = mock_fn_of_literal(Volatility::Volatile);
        let plan = aggregate_over_test_table(
            vec![volatile_call, col("a")],
            vec![count(col("c"))],
        )?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]
          TableScan: test
        ")
    }
}