datafusion_functions_aggregate/planner.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! SQL planning extensions for aggregate functions, such as [`AggregateFunctionPlanner`].
19
20use datafusion_common::Result;
21use datafusion_expr::{
22 expr::{AggregateFunction, AggregateFunctionParams},
23 expr_rewriter::NamePreserver,
24 planner::{ExprPlanner, PlannerResult, RawAggregateExpr},
25 utils::COUNT_STAR_EXPANSION,
26 Expr,
27};
28
/// [`ExprPlanner`] that rewrites `count()` and `count(*)` into a `count` over
/// [`COUNT_STAR_EXPANSION`], restoring the output column name the user wrote
/// (e.g. `count(*)`).
///
/// All other aggregate expressions are handed back unchanged.
#[derive(Debug)]
pub struct AggregateFunctionPlanner;
31
32impl ExprPlanner for AggregateFunctionPlanner {
33 fn plan_aggregate(
34 &self,
35 raw_expr: RawAggregateExpr,
36 ) -> Result<PlannerResult<RawAggregateExpr>> {
37 let RawAggregateExpr {
38 func,
39 args,
40 distinct,
41 filter,
42 order_by,
43 null_treatment,
44 } = raw_expr;
45
46 let origin_expr = Expr::AggregateFunction(AggregateFunction {
47 func,
48 params: AggregateFunctionParams {
49 args,
50 distinct,
51 filter,
52 order_by,
53 null_treatment,
54 },
55 });
56
57 let saved_name = NamePreserver::new_for_projection().save(&origin_expr);
58
59 let Expr::AggregateFunction(AggregateFunction {
60 func,
61 params:
62 AggregateFunctionParams {
63 args,
64 distinct,
65 filter,
66 order_by,
67 null_treatment,
68 },
69 }) = origin_expr
70 else {
71 unreachable!("")
72 };
73 let raw_expr = RawAggregateExpr {
74 func,
75 args,
76 distinct,
77 filter,
78 order_by,
79 null_treatment,
80 };
81
82 // handle count() and count(*) case
83 // convert to count(1) as "count()"
84 // or count(1) as "count(*)"
85 // TODO: remove the next line after `Expr::Wildcard` is removed
86 #[expect(deprecated)]
87 if raw_expr.func.name() == "count"
88 && (raw_expr.args.len() == 1
89 && matches!(raw_expr.args[0], Expr::Wildcard { .. })
90 || raw_expr.args.is_empty())
91 {
92 let RawAggregateExpr {
93 func,
94 args: _,
95 distinct,
96 filter,
97 order_by,
98 null_treatment,
99 } = raw_expr;
100
101 let new_expr = Expr::AggregateFunction(AggregateFunction::new_udf(
102 func,
103 vec![Expr::Literal(COUNT_STAR_EXPANSION)],
104 distinct,
105 filter,
106 order_by,
107 null_treatment,
108 ));
109
110 let new_expr = saved_name.restore(new_expr);
111 return Ok(PlannerResult::Planned(new_expr));
112 }
113
114 Ok(PlannerResult::Original(raw_expr))
115 }
116}