use std::sync::Arc;
use std::{fmt, str::FromStr};
use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, Volatility};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
use strum_macros::EnumIter;
/// Aggregate functions represented directly by this enum.
///
/// Each variant displays as its upper-case name (see `name`) and is parsed
/// from its lower-case name via `FromStr`. Because `PartialOrd` and
/// `EnumIter` are derived, the declaration order below is both the ordering
/// of the variants and the order in which `AggregateFunction::iter()` yields
/// them.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// `MIN`: one string, numeric, temporal or binary argument; returns the
/// coerced argument type.
Min,
/// `MAX`: one string, numeric, temporal or binary argument; returns the
/// coerced argument type.
Max,
/// `ARRAY_AGG`: one argument of any type; returns a `List` whose `item`
/// field has the argument's type and nullability.
ArrayAgg,
}
impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Min => "MIN",
Max => "MAX",
ArrayAgg => "ARRAY_AGG",
}
}
}
impl fmt::Display for AggregateFunction {
    /// Writes the upper-case function name returned by
    /// [`AggregateFunction::name`]. Width/fill flags on the formatter are not
    /// applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.name())
    }
}
impl FromStr for AggregateFunction {
    type Err = DataFusionError;

    /// Parses one of the lower-case names `"min"`, `"max"` or `"array_agg"`.
    ///
    /// Matching is exact and case-sensitive; any other input produces a
    /// planning error naming the unknown function.
    fn from_str(name: &str) -> Result<AggregateFunction> {
        let function = match name {
            "min" => Self::Min,
            "max" => Self::Max,
            "array_agg" => Self::ArrayAgg,
            _ => return plan_err!("There is no built-in function named {name}"),
        };
        Ok(function)
    }
}
impl AggregateFunction {
    /// Returns the output [`DataType`] of this aggregate for the given arguments.
    ///
    /// `input_expr_types` are first coerced against [`Self::signature`]. Then:
    /// * `MIN` / `MAX` return the coerced type of their single argument;
    /// * `ARRAY_AGG` returns a `List` whose `item` field has the coerced
    ///   argument type and the nullability given by the first entry of
    ///   `input_expr_nullable`.
    ///
    /// # Errors
    ///
    /// Returns a planning error if:
    /// * the argument types do not match the function's signature (the message
    ///   lists the accepted signatures);
    /// * coercion yields no argument type;
    /// * `ARRAY_AGG` is given no nullability entry for its argument.
    pub fn return_type(
        &self,
        input_expr_types: &[DataType],
        input_expr_nullable: &[bool],
    ) -> Result<DataType> {
        let coerced_data_types = coerce_types(self, input_expr_types, &self.signature())
            // Any coercion failure is reported as a signature mismatch, with a
            // message that lists the signatures this function accepts.
            .map_err(|_| {
                plan_datafusion_err!(
                    "{}",
                    utils::generate_signature_error_msg(
                        self.name(),
                        self.signature(),
                        input_expr_types,
                    )
                )
            })?;

        // Every signature here takes exactly one argument, so a successful
        // coercion is expected to yield one type. Use checked access so a
        // violated expectation becomes an error rather than a panic.
        let arg_type = coerced_data_types.into_iter().next().ok_or_else(|| {
            plan_datafusion_err!(
                "{} expects one argument, but none was provided",
                self.name()
            )
        })?;

        match self {
            AggregateFunction::Max | AggregateFunction::Min => Ok(arg_type),
            AggregateFunction::ArrayAgg => {
                // The caller supplies nullability separately from the types.
                // An empty slice previously caused an index-out-of-bounds panic.
                let item_nullable =
                    input_expr_nullable.first().copied().ok_or_else(|| {
                        plan_datafusion_err!(
                            "{} requires the nullability of its argument, but none was provided",
                            self.name()
                        )
                    })?;
                Ok(DataType::List(Arc::new(Field::new(
                    "item",
                    arg_type,
                    item_nullable,
                ))))
            }
        }
    }

    /// Returns whether the result of this aggregate is declared nullable.
    ///
    /// `MIN` / `MAX` are always nullable; `ARRAY_AGG` is reported as
    /// non-nullable. This never returns an error.
    pub fn nullable(&self) -> Result<bool> {
        match self {
            AggregateFunction::Max | AggregateFunction::Min => Ok(true),
            AggregateFunction::ArrayAgg => Ok(false),
        }
    }
}
impl AggregateFunction {
    /// Returns the [`Signature`] describing which arguments this aggregate accepts.
    ///
    /// * `MIN` / `MAX` take exactly one argument, which must be one of the
    ///   string, numeric, timestamp, date, time or binary types (listed in
    ///   that order).
    /// * `ARRAY_AGG` takes exactly one argument of any type.
    ///
    /// Every signature is [`Volatility::Immutable`].
    pub fn signature(&self) -> Signature {
        match self {
            Self::Min | Self::Max => {
                let mut accepted = Vec::new();
                for group in [
                    STRINGS.iter(),
                    NUMERICS.iter(),
                    TIMESTAMPS.iter(),
                    DATES.iter(),
                    TIMES.iter(),
                    BINARYS.iter(),
                ] {
                    accepted.extend(group.cloned());
                }
                Signature::uniform(1, accepted, Volatility::Immutable)
            }
            Self::ArrayAgg => Signature::any(1, Volatility::Immutable),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use strum::IntoEnumIterator;

    /// Each variant must round-trip through `Display` (upper case) and
    /// `FromStr`, which expects the lower-cased name.
    #[test]
    fn test_display_and_from_str() {
        AggregateFunction::iter().for_each(|expected| {
            let lowered = expected.to_string().to_lowercase();
            let parsed = AggregateFunction::from_str(&lowered).unwrap();
            assert_eq!(parsed, expected);
        });
    }
}