1use crate::optimizer::ApplyOrder;
21use crate::{OptimizerConfig, OptimizerRule};
22
23use std::collections::HashSet;
24
25use datafusion_common::Result;
26use datafusion_common::tree_node::Transformed;
27use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility};
28
/// Optimizer rule that drops redundant expressions from an aggregate's
/// `GROUP BY` list.
///
/// A group expression is considered redundant when it is not a bare column
/// and is built only from literals, non-volatile scalar functions and columns
/// that already appear as bare columns in the same `GROUP BY`. The removed
/// expressions are re-created by a projection placed on top of the simplified
/// aggregate, so the output expressions of the plan stay the same.
#[derive(Debug, Default)]
pub struct EliminateGroupByConstant {}

impl EliminateGroupByConstant {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self::default()
    }
}
40
41impl OptimizerRule for EliminateGroupByConstant {
42 fn supports_rewrite(&self) -> bool {
43 true
44 }
45
46 fn rewrite(
47 &self,
48 plan: LogicalPlan,
49 _config: &dyn OptimizerConfig,
50 ) -> Result<Transformed<LogicalPlan>> {
51 match plan {
52 LogicalPlan::Aggregate(aggregate) => {
53 let group_by_columns: HashSet<&datafusion_common::Column> = aggregate
55 .group_expr
56 .iter()
57 .filter_map(|expr| match expr {
58 Expr::Column(c) => Some(c),
59 _ => None,
60 })
61 .collect();
62
63 let (redundant, required): (Vec<_>, Vec<_>) = aggregate
64 .group_expr
65 .iter()
66 .partition(|expr| is_redundant_group_expr(expr, &group_by_columns));
67
68 if redundant.is_empty()
69 || (required.is_empty() && aggregate.aggr_expr.is_empty())
70 {
71 return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
72 }
73
74 let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
75 aggregate.input,
76 required.into_iter().cloned().collect(),
77 aggregate.aggr_expr.clone(),
78 )?);
79
80 let projection_expr =
81 aggregate.group_expr.into_iter().chain(aggregate.aggr_expr);
82
83 let projection = LogicalPlanBuilder::from(simplified_aggregate)
84 .project(projection_expr)?
85 .build()?;
86
87 Ok(Transformed::yes(projection))
88 }
89 _ => Ok(Transformed::no(plan)),
90 }
91 }
92
93 fn name(&self) -> &str {
94 "eliminate_group_by_constant"
95 }
96
97 fn apply_order(&self) -> Option<ApplyOrder> {
98 Some(ApplyOrder::BottomUp)
99 }
100}
101
102fn is_redundant_group_expr(
107 expr: &Expr,
108 group_by_columns: &HashSet<&datafusion_common::Column>,
109) -> bool {
110 if matches!(expr, Expr::Column(_)) {
112 return false;
113 }
114 is_deterministic_of(expr, group_by_columns)
115}
116
117fn is_deterministic_of(
120 expr: &Expr,
121 known_columns: &HashSet<&datafusion_common::Column>,
122) -> bool {
123 match expr {
124 Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns),
125 Expr::Column(c) => known_columns.contains(c),
126 Expr::Literal(_, _) => true,
127 Expr::BinaryExpr(e) => {
128 is_deterministic_of(&e.left, known_columns)
129 && is_deterministic_of(&e.right, known_columns)
130 }
131 Expr::ScalarFunction(e) => {
132 matches!(
133 e.func.signature().volatility,
134 Volatility::Immutable | Volatility::Stable
135 ) && e
136 .args
137 .iter()
138 .all(|arg| is_deterministic_of(arg, known_columns))
139 }
140 Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns),
141 Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns),
142 Expr::Negative(e) => is_deterministic_of(e, known_columns),
143 _ => false,
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;

    use arrow::datatypes::DataType;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{
        ColumnarValue, LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
        Signature, TypeSignature, col, lit,
    };

    use datafusion_functions_aggregate::expr_fn::count;

    use std::sync::Arc;

    /// Runs `EliminateGroupByConstant` as the only rule, with the optimizer
    /// limited to a single pass, and compares the resulting plan against the
    /// inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateGroupByConstant::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Scalar UDF stub whose only configurable property is its volatility.
    ///
    /// It is used to build plans; invoking it panics via `unimplemented!()`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct ScalarUDFMock {
        signature: Signature,
    }

    impl ScalarUDFMock {
        /// Builds a mock accepting any single argument with the given
        /// `volatility`.
        fn new_with_volatility(volatility: Volatility) -> Self {
            Self {
                signature: Signature::new(TypeSignature::Any(1), volatility),
            }
        }
    }

    impl ScalarUDFImpl for ScalarUDFMock {
        fn name(&self) -> &str {
            "scalar_fn_mock"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        // Always reports `Int32`, regardless of the argument types.
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }
        // These tests only inspect plans, so the function is never executed.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// A literal next to a column in GROUP BY is removed from the aggregate
    /// and restored by the projection.
    #[test]
    fn test_eliminate_gby_literal() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a, UInt32(1), count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    /// When every group expression is a literal but an aggregate expression
    /// exists, the GROUP BY list becomes empty.
    #[test]
    fn test_eliminate_constant() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])?
            .build()?;

        assert_optimized_plan_equal!(plan, @r#"
        Projection: Utf8("test"), UInt32(123), count(test.c)
          Aggregate: groupBy=[[]], aggr=[[count(test.c)]]
            TableScan: test
        "#)
    }

    /// A GROUP BY made only of bare columns is left unchanged.
    #[test]
    fn test_no_op_no_constants() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]
          TableScan: test
        ")
    }

    /// With no aggregate expressions, a GROUP BY consisting only of a literal
    /// is left unchanged: removing it would leave an empty aggregate.
    #[test]
    fn test_no_op_only_constant() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![lit(123u32)], Vec::<Expr>::new())?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]
          TableScan: test
        ")
    }

    /// An aliased literal is removed as well; the alias is kept in the
    /// projection.
    #[test]
    fn test_eliminate_constant_with_alias() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(
                vec![lit(123u32).alias("const"), col("a")],
                vec![count(col("c"))],
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: UInt32(123) AS const, test.a, count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    /// An immutable scalar function applied to a literal is removed from the
    /// GROUP BY list.
    #[test]
    fn test_eliminate_scalar_fn_with_constant_arg() -> Result<()> {
        let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(
            Volatility::Immutable,
        ));
        let udf_expr =
            Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)]));
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    /// Expressions over a column that is itself a bare GROUP BY column are
    /// removed: GROUP BY a, a - 1, a - 2, a - 3 becomes GROUP BY a.
    #[test]
    fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(
                vec![
                    col("a"),
                    col("a") - lit(1u32),
                    col("a") - lit(2u32),
                    col("a") - lit(3u32),
                ],
                vec![count(col("c"))],
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c)
          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
            TableScan: test
        ")
    }

    /// An expression over a column that is not a bare GROUP BY column is
    /// kept: `b - 1` stays because only `a` is grouped on directly.
    #[test]
    fn test_no_eliminate_independent_columns() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("a"), col("b") - lit(1u32)], vec![count(col("c"))])?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]]
          TableScan: test
        ")
    }

    /// A volatile scalar function is kept in the GROUP BY list even when its
    /// argument is a literal.
    #[test]
    fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
        let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(
            Volatility::Volatile,
        ));
        let udf_expr =
            Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)]));
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(scan)
            .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]
          TableScan: test
        ")
    }
}