datafusion_functions_aggregate/
planner.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL planning extensions like [`AggregateFunctionPlanner`]
19
20use datafusion_common::Result;
21use datafusion_expr::{
22    expr::{AggregateFunction, AggregateFunctionParams},
23    expr_rewriter::NamePreserver,
24    planner::{ExprPlanner, PlannerResult, RawAggregateExpr},
25    utils::COUNT_STAR_EXPANSION,
26    Expr,
27};
28
/// [`ExprPlanner`] for aggregate function calls.
///
/// Its `plan_aggregate` hook rewrites `count()` and `count(*)` into
/// `count(1)`, keeping the expression's original display name. Any other
/// aggregate is handed back unchanged.
///
/// The planner carries no state, so it is trivially `Default`, `Clone` and `Copy`.
#[derive(Debug, Default, Clone, Copy)]
pub struct AggregateFunctionPlanner;
31
32impl ExprPlanner for AggregateFunctionPlanner {
33    fn plan_aggregate(
34        &self,
35        raw_expr: RawAggregateExpr,
36    ) -> Result<PlannerResult<RawAggregateExpr>> {
37        let RawAggregateExpr {
38            func,
39            args,
40            distinct,
41            filter,
42            order_by,
43            null_treatment,
44        } = raw_expr;
45
46        let origin_expr = Expr::AggregateFunction(AggregateFunction {
47            func,
48            params: AggregateFunctionParams {
49                args,
50                distinct,
51                filter,
52                order_by,
53                null_treatment,
54            },
55        });
56
57        let saved_name = NamePreserver::new_for_projection().save(&origin_expr);
58
59        let Expr::AggregateFunction(AggregateFunction {
60            func,
61            params:
62                AggregateFunctionParams {
63                    args,
64                    distinct,
65                    filter,
66                    order_by,
67                    null_treatment,
68                },
69        }) = origin_expr
70        else {
71            unreachable!("")
72        };
73        let raw_expr = RawAggregateExpr {
74            func,
75            args,
76            distinct,
77            filter,
78            order_by,
79            null_treatment,
80        };
81
82        // handle count() and count(*) case
83        // convert to count(1) as "count()"
84        // or         count(1) as "count(*)"
85        // TODO: remove the next line after `Expr::Wildcard` is removed
86        #[expect(deprecated)]
87        if raw_expr.func.name() == "count"
88            && (raw_expr.args.len() == 1
89                && matches!(raw_expr.args[0], Expr::Wildcard { .. })
90                || raw_expr.args.is_empty())
91        {
92            let RawAggregateExpr {
93                func,
94                args: _,
95                distinct,
96                filter,
97                order_by,
98                null_treatment,
99            } = raw_expr;
100
101            let new_expr = Expr::AggregateFunction(AggregateFunction::new_udf(
102                func,
103                vec![Expr::Literal(COUNT_STAR_EXPANSION)],
104                distinct,
105                filter,
106                order_by,
107                null_treatment,
108            ));
109
110            let new_expr = saved_name.restore(new_expr);
111            return Ok(PlannerResult::Planned(new_expr));
112        }
113
114        Ok(PlannerResult::Original(raw_expr))
115    }
116}