// datafusion_functions_extra/mode.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{
    Date32Type, Date64Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
    UInt8Type,
};
use datafusion::arrow;

use datafusion::error::Result;

use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, plan_err};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};

use std::any::Any;
use std::fmt::Debug;

use crate::common::mode::{BytesModeAccumulator, FloatModeAccumulator, PrimitiveModeAccumulator};

38make_udaf_expr_and_func!(ModeFunction, mode, x, "Calculates the most frequent value.", mode_udaf);
39
40/// The `ModeFunction` calculates the mode (most frequent value) from a set of values.
41///
42/// - Null values are ignored during the calculation.
43/// - If multiple values have the same frequency, the MAX value with the highest frequency is returned.
44pub struct ModeFunction {
45    signature: Signature,
46}
47
48impl Debug for ModeFunction {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.debug_struct("ModeFunction")
51            .field("signature", &self.signature)
52            .finish()
53    }
54}
55
56impl Default for ModeFunction {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl ModeFunction {
63    pub fn new() -> Self {
64        Self {
65            signature: Signature::variadic_any(Volatility::Immutable),
66        }
67    }
68}
69
70impl AggregateUDFImpl for ModeFunction {
71    fn as_any(&self) -> &dyn Any {
72        self
73    }
74
75    fn name(&self) -> &str {
76        "mode"
77    }
78
79    fn signature(&self) -> &Signature {
80        &self.signature
81    }
82
83    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
84        Ok(arg_types[0].clone())
85    }
86
87    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
88        let value_type = args.input_types[0].clone();
89
90        Ok(vec![
91            Field::new("values", value_type, true),
92            Field::new("frequencies", DataType::UInt64, true),
93        ])
94    }
95
96    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
97        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
98
99        let accumulator: Box<dyn Accumulator> = match data_type {
100            DataType::Int8 => Box::new(PrimitiveModeAccumulator::<Int8Type>::new(data_type)),
101            DataType::Int16 => Box::new(PrimitiveModeAccumulator::<Int16Type>::new(data_type)),
102            DataType::Int32 => Box::new(PrimitiveModeAccumulator::<Int32Type>::new(data_type)),
103            DataType::Int64 => Box::new(PrimitiveModeAccumulator::<Int64Type>::new(data_type)),
104            DataType::UInt8 => Box::new(PrimitiveModeAccumulator::<UInt8Type>::new(data_type)),
105            DataType::UInt16 => Box::new(PrimitiveModeAccumulator::<UInt16Type>::new(data_type)),
106            DataType::UInt32 => Box::new(PrimitiveModeAccumulator::<UInt32Type>::new(data_type)),
107            DataType::UInt64 => Box::new(PrimitiveModeAccumulator::<UInt64Type>::new(data_type)),
108
109            DataType::Date32 => Box::new(PrimitiveModeAccumulator::<Date32Type>::new(data_type)),
110            DataType::Date64 => Box::new(PrimitiveModeAccumulator::<Date64Type>::new(data_type)),
111            DataType::Time32(TimeUnit::Millisecond) => {
112                Box::new(PrimitiveModeAccumulator::<Time32MillisecondType>::new(data_type))
113            }
114            DataType::Time32(TimeUnit::Second) => {
115                Box::new(PrimitiveModeAccumulator::<Time32SecondType>::new(data_type))
116            }
117            DataType::Time64(TimeUnit::Microsecond) => {
118                Box::new(PrimitiveModeAccumulator::<Time64MicrosecondType>::new(data_type))
119            }
120            DataType::Time64(TimeUnit::Nanosecond) => {
121                Box::new(PrimitiveModeAccumulator::<Time64NanosecondType>::new(data_type))
122            }
123            DataType::Timestamp(TimeUnit::Microsecond, _) => {
124                Box::new(PrimitiveModeAccumulator::<TimestampMicrosecondType>::new(data_type))
125            }
126            DataType::Timestamp(TimeUnit::Millisecond, _) => {
127                Box::new(PrimitiveModeAccumulator::<TimestampMillisecondType>::new(data_type))
128            }
129            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
130                Box::new(PrimitiveModeAccumulator::<TimestampNanosecondType>::new(data_type))
131            }
132            DataType::Timestamp(TimeUnit::Second, _) => {
133                Box::new(PrimitiveModeAccumulator::<TimestampSecondType>::new(data_type))
134            }
135
136            DataType::Float16 => Box::new(FloatModeAccumulator::<Float16Type>::new(data_type)),
137            DataType::Float32 => Box::new(FloatModeAccumulator::<Float32Type>::new(data_type)),
138            DataType::Float64 => Box::new(FloatModeAccumulator::<Float64Type>::new(data_type)),
139
140            DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => Box::new(BytesModeAccumulator::new(data_type)),
141            _ => {
142                return not_impl_err!("Unsupported data type: {:?} for mode function", data_type);
143            }
144        };
145
146        Ok(accumulator)
147    }
148}