Skip to main content

datafusion_functions_aggregate/
grouping.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use std::any::Any;
21
22use arrow::datatypes::Field;
23use arrow::datatypes::{DataType, FieldRef};
24use datafusion_common::{Result, not_impl_err};
25use datafusion_expr::function::AccumulatorArgs;
26use datafusion_expr::function::StateFieldsArgs;
27use datafusion_expr::utils::format_state_name;
28use datafusion_expr::{
29    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
30};
31use datafusion_macros::user_doc;
32
33make_udaf_expr_and_func!(
34    Grouping,
35    grouping,
36    expression,
37    "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.",
38    grouping_udaf
39);
40
41#[user_doc(
42    doc_section(label = "General Functions"),
43    description = "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.",
44    syntax_example = "grouping(expression)",
45    sql_example = r#"```sql
46> SELECT column_name, GROUPING(column_name) AS group_column
47  FROM table_name
48  GROUP BY GROUPING SETS ((column_name), ());
49+-------------+-------------+
50| column_name | group_column |
51+-------------+-------------+
52| value1      | 0           |
53| value2      | 0           |
54| NULL        | 1           |
55+-------------+-------------+
56```"#,
57    argument(
58        name = "expression",
59        description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function."
60    )
61)]
62#[derive(PartialEq, Eq, Hash, Debug)]
63pub struct Grouping {
64    signature: Signature,
65}
66
67impl Default for Grouping {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl Grouping {
74    /// Create a new GROUPING aggregate function.
75    pub fn new() -> Self {
76        Self {
77            signature: Signature::variadic_any(Volatility::Immutable),
78        }
79    }
80}
81
82impl AggregateUDFImpl for Grouping {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn name(&self) -> &str {
88        "grouping"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96        Ok(DataType::Int32)
97    }
98
99    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
100        Ok(vec![
101            Field::new(
102                format_state_name(args.name, "grouping"),
103                DataType::Int32,
104                true,
105            )
106            .into(),
107        ])
108    }
109
110    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
111        not_impl_err!(
112            "physical plan is not yet implemented for GROUPING aggregate function"
113        )
114    }
115
116    fn documentation(&self) -> Option<&Documentation> {
117        self.doc()
118    }
119}