datafusion_functions_extra/
mode.rs1use datafusion::{arrow, common as df_common, error, logical_expr};
19use std::{any, fmt};
20
21use crate::common;
22
// Invokes the project macro `make_udaf_expr_and_func!` (defined outside this file), passing in
// order: the type `ModeFunction`, the identifiers `mode` and `x`, a description string, and the
// identifier `mode_udaf`.
// NOTE(review): presumably this generates a `mode(x)` expression builder and a `mode_udaf()`
// accessor for the aggregate function — confirm against the macro definition.
make_udaf_expr_and_func!(
    ModeFunction,
    mode,
    x,
    "Calculates the most frequent value.",
    mode_udaf
);
30
31pub struct ModeFunction {
36 signature: logical_expr::Signature,
37}
38
39impl fmt::Debug for ModeFunction {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 f.debug_struct("ModeFunction")
42 .field("signature", &self.signature)
43 .finish()
44 }
45}
46
47impl Default for ModeFunction {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl ModeFunction {
54 pub fn new() -> Self {
55 Self {
56 signature: logical_expr::Signature::variadic_any(logical_expr::Volatility::Immutable),
57 }
58 }
59}
60
61impl logical_expr::AggregateUDFImpl for ModeFunction {
62 fn as_any(&self) -> &dyn any::Any {
63 self
64 }
65
66 fn name(&self) -> &str {
67 "mode"
68 }
69
70 fn signature(&self) -> &logical_expr::Signature {
71 &self.signature
72 }
73
74 fn return_type(
75 &self,
76 arg_types: &[arrow::datatypes::DataType],
77 ) -> error::Result<arrow::datatypes::DataType> {
78 Ok(arg_types[0].clone())
79 }
80
81 fn state_fields(
82 &self,
83 args: logical_expr::function::StateFieldsArgs,
84 ) -> error::Result<Vec<arrow::datatypes::FieldRef>> {
85 let value_type = args.input_fields[0].data_type().clone();
86
87 Ok(vec![
88 arrow::datatypes::Field::new("values", value_type, true).into(),
89 arrow::datatypes::Field::new("frequencies", arrow::datatypes::DataType::UInt64, true)
90 .into(),
91 ])
92 }
93
94 fn accumulator(
95 &self,
96 acc_args: logical_expr::function::AccumulatorArgs,
97 ) -> error::Result<Box<dyn logical_expr::Accumulator>> {
98 let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
99
100 let accumulator: Box<dyn logical_expr::Accumulator> = match data_type {
101 arrow::datatypes::DataType::Int8 => Box::new(common::mode::PrimitiveModeAccumulator::<
102 arrow::datatypes::Int8Type,
103 >::new(data_type)),
104 arrow::datatypes::DataType::Int16 => {
105 Box::new(common::mode::PrimitiveModeAccumulator::<
106 arrow::datatypes::Int16Type,
107 >::new(data_type))
108 }
109 arrow::datatypes::DataType::Int32 => {
110 Box::new(common::mode::PrimitiveModeAccumulator::<
111 arrow::datatypes::Int32Type,
112 >::new(data_type))
113 }
114 arrow::datatypes::DataType::Int64 => {
115 Box::new(common::mode::PrimitiveModeAccumulator::<
116 arrow::datatypes::Int64Type,
117 >::new(data_type))
118 }
119 arrow::datatypes::DataType::UInt8 => {
120 Box::new(common::mode::PrimitiveModeAccumulator::<
121 arrow::datatypes::UInt8Type,
122 >::new(data_type))
123 }
124 arrow::datatypes::DataType::UInt16 => {
125 Box::new(common::mode::PrimitiveModeAccumulator::<
126 arrow::datatypes::UInt16Type,
127 >::new(data_type))
128 }
129 arrow::datatypes::DataType::UInt32 => {
130 Box::new(common::mode::PrimitiveModeAccumulator::<
131 arrow::datatypes::UInt32Type,
132 >::new(data_type))
133 }
134 arrow::datatypes::DataType::UInt64 => {
135 Box::new(common::mode::PrimitiveModeAccumulator::<
136 arrow::datatypes::UInt64Type,
137 >::new(data_type))
138 }
139
140 arrow::datatypes::DataType::Date32 => {
141 Box::new(common::mode::PrimitiveModeAccumulator::<
142 arrow::datatypes::Date32Type,
143 >::new(data_type))
144 }
145 arrow::datatypes::DataType::Date64 => {
146 Box::new(common::mode::PrimitiveModeAccumulator::<
147 arrow::datatypes::Date64Type,
148 >::new(data_type))
149 }
150 arrow::datatypes::DataType::Time32(arrow::datatypes::TimeUnit::Millisecond) => {
151 Box::new(common::mode::PrimitiveModeAccumulator::<
152 arrow::datatypes::Time32MillisecondType,
153 >::new(data_type))
154 }
155 arrow::datatypes::DataType::Time32(arrow::datatypes::TimeUnit::Second) => {
156 Box::new(common::mode::PrimitiveModeAccumulator::<
157 arrow::datatypes::Time32SecondType,
158 >::new(data_type))
159 }
160 arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond) => {
161 Box::new(common::mode::PrimitiveModeAccumulator::<
162 arrow::datatypes::Time64MicrosecondType,
163 >::new(data_type))
164 }
165 arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond) => {
166 Box::new(common::mode::PrimitiveModeAccumulator::<
167 arrow::datatypes::Time64NanosecondType,
168 >::new(data_type))
169 }
170 arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
171 Box::new(common::mode::PrimitiveModeAccumulator::<
172 arrow::datatypes::TimestampMicrosecondType,
173 >::new(data_type))
174 }
175 arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => {
176 Box::new(common::mode::PrimitiveModeAccumulator::<
177 arrow::datatypes::TimestampMillisecondType,
178 >::new(data_type))
179 }
180 arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, _) => {
181 Box::new(common::mode::PrimitiveModeAccumulator::<
182 arrow::datatypes::TimestampNanosecondType,
183 >::new(data_type))
184 }
185 arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Second, _) => {
186 Box::new(common::mode::PrimitiveModeAccumulator::<
187 arrow::datatypes::TimestampSecondType,
188 >::new(data_type))
189 }
190
191 arrow::datatypes::DataType::Float16 => Box::new(common::mode::FloatModeAccumulator::<
192 arrow::datatypes::Float16Type,
193 >::new(data_type)),
194 arrow::datatypes::DataType::Float32 => Box::new(common::mode::FloatModeAccumulator::<
195 arrow::datatypes::Float32Type,
196 >::new(data_type)),
197 arrow::datatypes::DataType::Float64 => Box::new(common::mode::FloatModeAccumulator::<
198 arrow::datatypes::Float64Type,
199 >::new(data_type)),
200
201 arrow::datatypes::DataType::Utf8
202 | arrow::datatypes::DataType::Utf8View
203 | arrow::datatypes::DataType::LargeUtf8 => {
204 Box::new(common::mode::BytesModeAccumulator::new(data_type))
205 }
206 _ => {
207 return df_common::not_impl_err!(
208 "Unsupported data type: {:?} for mode function",
209 data_type
210 );
211 }
212 };
213
214 Ok(accumulator)
215 }
216}