use datafusion::{arrow, common as df_common, error, logical_expr};
use std::{any, fmt};
use crate::common;
// Invokes `make_udaf_expr_and_func!` for `ModeFunction`, passing the names
// `mode` and `mode_udaf`, the argument name `x`, and a one-line description.
// NOTE(review): the macro is defined outside this file; presumably it emits an
// expression-builder function `mode(x)` and a `mode_udaf()` accessor for the
// aggregate — confirm against the macro definition.
make_udaf_expr_and_func!(
ModeFunction,
mode,
x,
"Calculates the most frequent value.",
mode_udaf
);
pub struct ModeFunction {
signature: logical_expr::Signature,
}
impl fmt::Debug for ModeFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ModeFunction")
.field("signature", &self.signature)
.finish()
}
}
impl Default for ModeFunction {
fn default() -> Self {
Self::new()
}
}
impl ModeFunction {
    /// Creates a `ModeFunction` whose signature is `variadic_any` with
    /// `Immutable` volatility.
    pub fn new() -> Self {
        let signature =
            logical_expr::Signature::variadic_any(logical_expr::Volatility::Immutable);
        Self { signature }
    }
}
impl logical_expr::AggregateUDFImpl for ModeFunction {
fn as_any(&self) -> &dyn any::Any {
self
}
fn name(&self) -> &str {
"mode"
}
fn signature(&self) -> &logical_expr::Signature {
&self.signature
}
fn return_type(
&self,
arg_types: &[arrow::datatypes::DataType],
) -> error::Result<arrow::datatypes::DataType> {
Ok(arg_types[0].clone())
}
fn state_fields(
&self,
args: logical_expr::function::StateFieldsArgs,
) -> error::Result<Vec<arrow::datatypes::FieldRef>> {
let value_type = args.input_fields[0].data_type().clone();
Ok(vec![
arrow::datatypes::Field::new("values", value_type, true).into(),
arrow::datatypes::Field::new("frequencies", arrow::datatypes::DataType::UInt64, true)
.into(),
])
}
fn accumulator(
&self,
acc_args: logical_expr::function::AccumulatorArgs,
) -> error::Result<Box<dyn logical_expr::Accumulator>> {
let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
let accumulator: Box<dyn logical_expr::Accumulator> = match data_type {
arrow::datatypes::DataType::Int8 => Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Int8Type,
>::new(data_type)),
arrow::datatypes::DataType::Int16 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Int16Type,
>::new(data_type))
}
arrow::datatypes::DataType::Int32 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Int32Type,
>::new(data_type))
}
arrow::datatypes::DataType::Int64 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Int64Type,
>::new(data_type))
}
arrow::datatypes::DataType::UInt8 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::UInt8Type,
>::new(data_type))
}
arrow::datatypes::DataType::UInt16 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::UInt16Type,
>::new(data_type))
}
arrow::datatypes::DataType::UInt32 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::UInt32Type,
>::new(data_type))
}
arrow::datatypes::DataType::UInt64 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::UInt64Type,
>::new(data_type))
}
arrow::datatypes::DataType::Date32 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Date32Type,
>::new(data_type))
}
arrow::datatypes::DataType::Date64 => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Date64Type,
>::new(data_type))
}
arrow::datatypes::DataType::Time32(arrow::datatypes::TimeUnit::Millisecond) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Time32MillisecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Time32(arrow::datatypes::TimeUnit::Second) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Time32SecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Time64MicrosecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::Time64NanosecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::TimestampMicrosecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::TimestampMillisecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, _) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::TimestampNanosecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Second, _) => {
Box::new(common::mode::PrimitiveModeAccumulator::<
arrow::datatypes::TimestampSecondType,
>::new(data_type))
}
arrow::datatypes::DataType::Float16 => Box::new(common::mode::FloatModeAccumulator::<
arrow::datatypes::Float16Type,
>::new(data_type)),
arrow::datatypes::DataType::Float32 => Box::new(common::mode::FloatModeAccumulator::<
arrow::datatypes::Float32Type,
>::new(data_type)),
arrow::datatypes::DataType::Float64 => Box::new(common::mode::FloatModeAccumulator::<
arrow::datatypes::Float64Type,
>::new(data_type)),
arrow::datatypes::DataType::Utf8
| arrow::datatypes::DataType::Utf8View
| arrow::datatypes::DataType::LargeUtf8 => {
Box::new(common::mode::BytesModeAccumulator::new(data_type))
}
_ => {
return df_common::not_impl_err!(
"Unsupported data type: {:?} for mode function",
data_type
);
}
};
Ok(accumulator)
}
}