datafusion_functions_extra/
mode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::{arrow, common as df_common, error, logical_expr};
19use std::{any, fmt};
20
21use crate::common;
22
// NOTE(review): `make_udaf_expr_and_func!` is defined outside this file. Judging
// by its arguments it presumably generates a `mode(x)` expression builder and a
// `mode_udaf()` accessor backed by `ModeFunction` — confirm against the macro
// definition.
make_udaf_expr_and_func!(
    ModeFunction,
    mode,
    x,
    "Calculates the most frequent value.",
    mode_udaf
);
30
31/// The `ModeFunction` calculates the mode (most frequent value) from a set of values.
32///
33/// - Null values are ignored during the calculation.
34/// - If multiple values have the same frequency, the MAX value with the highest frequency is returned.
35pub struct ModeFunction {
36    signature: logical_expr::Signature,
37}
38
39impl fmt::Debug for ModeFunction {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        f.debug_struct("ModeFunction")
42            .field("signature", &self.signature)
43            .finish()
44    }
45}
46
47impl Default for ModeFunction {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl ModeFunction {
54    pub fn new() -> Self {
55        Self {
56            signature: logical_expr::Signature::variadic_any(logical_expr::Volatility::Immutable),
57        }
58    }
59}
60
61impl logical_expr::AggregateUDFImpl for ModeFunction {
62    fn as_any(&self) -> &dyn any::Any {
63        self
64    }
65
66    fn name(&self) -> &str {
67        "mode"
68    }
69
70    fn signature(&self) -> &logical_expr::Signature {
71        &self.signature
72    }
73
74    fn return_type(
75        &self,
76        arg_types: &[arrow::datatypes::DataType],
77    ) -> error::Result<arrow::datatypes::DataType> {
78        Ok(arg_types[0].clone())
79    }
80
81    fn state_fields(
82        &self,
83        args: logical_expr::function::StateFieldsArgs,
84    ) -> error::Result<Vec<arrow::datatypes::FieldRef>> {
85        let value_type = args.input_fields[0].data_type().clone();
86
87        Ok(vec![
88            arrow::datatypes::Field::new("values", value_type, true).into(),
89            arrow::datatypes::Field::new("frequencies", arrow::datatypes::DataType::UInt64, true)
90                .into(),
91        ])
92    }
93
94    fn accumulator(
95        &self,
96        acc_args: logical_expr::function::AccumulatorArgs,
97    ) -> error::Result<Box<dyn logical_expr::Accumulator>> {
98        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
99
100        let accumulator: Box<dyn logical_expr::Accumulator> = match data_type {
101            arrow::datatypes::DataType::Int8 => Box::new(common::mode::PrimitiveModeAccumulator::<
102                arrow::datatypes::Int8Type,
103            >::new(data_type)),
104            arrow::datatypes::DataType::Int16 => {
105                Box::new(common::mode::PrimitiveModeAccumulator::<
106                    arrow::datatypes::Int16Type,
107                >::new(data_type))
108            }
109            arrow::datatypes::DataType::Int32 => {
110                Box::new(common::mode::PrimitiveModeAccumulator::<
111                    arrow::datatypes::Int32Type,
112                >::new(data_type))
113            }
114            arrow::datatypes::DataType::Int64 => {
115                Box::new(common::mode::PrimitiveModeAccumulator::<
116                    arrow::datatypes::Int64Type,
117                >::new(data_type))
118            }
119            arrow::datatypes::DataType::UInt8 => {
120                Box::new(common::mode::PrimitiveModeAccumulator::<
121                    arrow::datatypes::UInt8Type,
122                >::new(data_type))
123            }
124            arrow::datatypes::DataType::UInt16 => {
125                Box::new(common::mode::PrimitiveModeAccumulator::<
126                    arrow::datatypes::UInt16Type,
127                >::new(data_type))
128            }
129            arrow::datatypes::DataType::UInt32 => {
130                Box::new(common::mode::PrimitiveModeAccumulator::<
131                    arrow::datatypes::UInt32Type,
132                >::new(data_type))
133            }
134            arrow::datatypes::DataType::UInt64 => {
135                Box::new(common::mode::PrimitiveModeAccumulator::<
136                    arrow::datatypes::UInt64Type,
137                >::new(data_type))
138            }
139
140            arrow::datatypes::DataType::Date32 => {
141                Box::new(common::mode::PrimitiveModeAccumulator::<
142                    arrow::datatypes::Date32Type,
143                >::new(data_type))
144            }
145            arrow::datatypes::DataType::Date64 => {
146                Box::new(common::mode::PrimitiveModeAccumulator::<
147                    arrow::datatypes::Date64Type,
148                >::new(data_type))
149            }
150            arrow::datatypes::DataType::Time32(arrow::datatypes::TimeUnit::Millisecond) => {
151                Box::new(common::mode::PrimitiveModeAccumulator::<
152                    arrow::datatypes::Time32MillisecondType,
153                >::new(data_type))
154            }
155            arrow::datatypes::DataType::Time32(arrow::datatypes::TimeUnit::Second) => {
156                Box::new(common::mode::PrimitiveModeAccumulator::<
157                    arrow::datatypes::Time32SecondType,
158                >::new(data_type))
159            }
160            arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond) => {
161                Box::new(common::mode::PrimitiveModeAccumulator::<
162                    arrow::datatypes::Time64MicrosecondType,
163                >::new(data_type))
164            }
165            arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond) => {
166                Box::new(common::mode::PrimitiveModeAccumulator::<
167                    arrow::datatypes::Time64NanosecondType,
168                >::new(data_type))
169            }
170            arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
171                Box::new(common::mode::PrimitiveModeAccumulator::<
172                    arrow::datatypes::TimestampMicrosecondType,
173                >::new(data_type))
174            }
175            arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => {
176                Box::new(common::mode::PrimitiveModeAccumulator::<
177                    arrow::datatypes::TimestampMillisecondType,
178                >::new(data_type))
179            }
180            arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, _) => {
181                Box::new(common::mode::PrimitiveModeAccumulator::<
182                    arrow::datatypes::TimestampNanosecondType,
183                >::new(data_type))
184            }
185            arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Second, _) => {
186                Box::new(common::mode::PrimitiveModeAccumulator::<
187                    arrow::datatypes::TimestampSecondType,
188                >::new(data_type))
189            }
190
191            arrow::datatypes::DataType::Float16 => Box::new(common::mode::FloatModeAccumulator::<
192                arrow::datatypes::Float16Type,
193            >::new(data_type)),
194            arrow::datatypes::DataType::Float32 => Box::new(common::mode::FloatModeAccumulator::<
195                arrow::datatypes::Float32Type,
196            >::new(data_type)),
197            arrow::datatypes::DataType::Float64 => Box::new(common::mode::FloatModeAccumulator::<
198                arrow::datatypes::Float64Type,
199            >::new(data_type)),
200
201            arrow::datatypes::DataType::Utf8
202            | arrow::datatypes::DataType::Utf8View
203            | arrow::datatypes::DataType::LargeUtf8 => {
204                Box::new(common::mode::BytesModeAccumulator::new(data_type))
205            }
206            _ => {
207                return df_common::not_impl_err!(
208                    "Unsupported data type: {:?} for mode function",
209                    data_type
210                );
211            }
212        };
213
214        Ok(accumulator)
215    }
216}