1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21
22use datafusion_common::tree_node::Transformed;
23use datafusion_common::Result;
24use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility};
25
/// Optimizer rule that drops constant expressions from the `GROUP BY` list of
/// an aggregate and re-emits them through a projection placed on top of the
/// simplified aggregate.
///
/// The rule carries no state; every instance behaves the same.
#[derive(Debug, Default)]
pub struct EliminateGroupByConstant {}
31
32impl EliminateGroupByConstant {
33 pub fn new() -> Self {
34 Self {}
35 }
36}
37
38impl OptimizerRule for EliminateGroupByConstant {
39 fn supports_rewrite(&self) -> bool {
40 true
41 }
42
43 fn rewrite(
44 &self,
45 plan: LogicalPlan,
46 _config: &dyn OptimizerConfig,
47 ) -> Result<Transformed<LogicalPlan>> {
48 match plan {
49 LogicalPlan::Aggregate(aggregate) => {
50 let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate
51 .group_expr
52 .iter()
53 .partition(|expr| is_constant_expression(expr));
54
55 if const_group_expr.is_empty()
59 || (!const_group_expr.is_empty()
60 && nonconst_group_expr.is_empty()
61 && aggregate.aggr_expr.is_empty())
62 {
63 return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
64 }
65
66 let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
67 aggregate.input,
68 nonconst_group_expr.into_iter().cloned().collect(),
69 aggregate.aggr_expr.clone(),
70 )?);
71
72 let projection_expr =
73 aggregate.group_expr.into_iter().chain(aggregate.aggr_expr);
74
75 let projection = LogicalPlanBuilder::from(simplified_aggregate)
76 .project(projection_expr)?
77 .build()?;
78
79 Ok(Transformed::yes(projection))
80 }
81 _ => Ok(Transformed::no(plan)),
82 }
83 }
84
85 fn name(&self) -> &str {
86 "eliminate_group_by_constant"
87 }
88
89 fn apply_order(&self) -> Option<ApplyOrder> {
90 Some(ApplyOrder::BottomUp)
91 }
92}
93
94fn is_constant_expression(expr: &Expr) -> bool {
99 match expr {
100 Expr::Alias(e) => is_constant_expression(&e.expr),
101 Expr::BinaryExpr(e) => {
102 is_constant_expression(&e.left) && is_constant_expression(&e.right)
103 }
104 Expr::Literal(_) => true,
105 Expr::ScalarFunction(e) => {
106 matches!(
107 e.func.signature().volatility,
108 Volatility::Immutable | Volatility::Stable
109 ) && e.args.iter().all(is_constant_expression)
110 }
111 _ => false,
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;

    use std::sync::Arc;

    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{
        col, lit, ColumnarValue, LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF,
        ScalarUDFImpl, Signature, TypeSignature,
    };
    use datafusion_functions_aggregate::expr_fn::count;

    /// Minimal scalar UDF used to exercise the volatility check; the only
    /// thing that varies between instances is the signature's volatility.
    #[derive(Debug)]
    struct ScalarUDFMock {
        signature: Signature,
    }

    impl ScalarUDFMock {
        /// Creates a one-argument mock with the requested volatility.
        fn new_with_volatility(volatility: Volatility) -> Self {
            let signature = Signature::new(TypeSignature::Any(1), volatility);
            Self { signature }
        }
    }

    impl ScalarUDFImpl for ScalarUDFMock {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn name(&self) -> &str {
            "scalar_fn_mock"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        // The mock is only planned, never executed.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Builds the expression `scalar_fn_mock(123)` with the given volatility.
    fn mock_udf_call(volatility: Volatility) -> Expr {
        let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(volatility));
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![lit(123u32)]))
    }

    /// Builds an aggregate over the shared test table scan.
    fn aggregate_plan(group_by: Vec<Expr>, aggregates: Vec<Expr>) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(group_by, aggregates)?
            .build()
    }

    /// Runs `EliminateGroupByConstant` on `plan` and compares the result with
    /// `expected`.
    ///
    /// NOTE(review): the expected strings are compared against the plan's
    /// display output, so the spacing after each `\n` matters — verify it
    /// against the real indentation produced by the plan formatter.
    fn assert_rule_output(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq(Arc::new(EliminateGroupByConstant::new()), plan, expected)
    }

    #[test]
    fn test_eliminate_gby_literal() -> Result<()> {
        let plan = aggregate_plan(vec![col("a"), lit(1u32)], vec![count(col("c"))])?;

        let expected = "\
            Projection: test.a, UInt32(1), count(test.c)\
            \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }

    #[test]
    fn test_eliminate_constant() -> Result<()> {
        let plan = aggregate_plan(vec![lit("test"), lit(123u32)], vec![count(col("c"))])?;

        let expected = "\
            Projection: Utf8(\"test\"), UInt32(123), count(test.c)\
            \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }

    #[test]
    fn test_no_op_no_constants() -> Result<()> {
        let plan = aggregate_plan(vec![col("a"), col("b")], vec![count(col("c"))])?;

        let expected = "\
            Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }

    #[test]
    fn test_no_op_only_constant() -> Result<()> {
        let plan = aggregate_plan(vec![lit(123u32)], Vec::<Expr>::new())?;

        let expected = "\
            Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }

    #[test]
    fn test_eliminate_constant_with_alias() -> Result<()> {
        let plan = aggregate_plan(
            vec![lit(123u32).alias("const"), col("a")],
            vec![count(col("c"))],
        )?;

        let expected = "\
            Projection: UInt32(123) AS const, test.a, count(test.c)\
            \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }

    #[test]
    fn test_eliminate_scalar_fn_with_constant_arg() -> Result<()> {
        let udf_expr = mock_udf_call(Volatility::Immutable);
        let plan = aggregate_plan(vec![udf_expr, col("a")], vec![count(col("c"))])?;

        let expected = "\
            Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\
            \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }

    #[test]
    fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
        let udf_expr = mock_udf_call(Volatility::Volatile);
        let plan = aggregate_plan(vec![udf_expr, col("a")], vec![count(col("c"))])?;

        let expected = "\
            Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\
            \n TableScan: test\
        ";

        assert_rule_output(plan, expected)
    }
}