use super::{
functions::Signature,
type_coercion::{coerce, data_types},
Accumulator, AggregateExpr, PhysicalExpr,
};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Schema, TimeUnit};
use expressions::{avg_return_type, sum_return_type};
use std::{fmt, str::FromStr, sync::Arc};
/// Shared, thread-safe factory closure; each call yields a new boxed
/// [`Accumulator`] (or an error).
pub type AccumulatorFunctionImplementation =
    Arc<dyn Send + Sync + Fn() -> Result<Box<dyn Accumulator>>>;

/// Shared, thread-safe closure that maps an input [`DataType`] to a shared
/// list of [`DataType`]s (judging by the name, the types of an accumulator's
/// state — confirm at call sites).
pub type StateTypeFunction =
    Arc<dyn Send + Sync + Fn(&DataType) -> Result<Arc<Vec<DataType>>>>;
/// The built-in aggregate functions.
///
/// `Hash` is derived alongside `Eq` so that a function can be used as a key in
/// hash-based collections (e.g. a registry keyed by function).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `AVG`
    Avg,
}
impl fmt::Display for AggregateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", format!("{:?}", self).to_uppercase())
}
}
impl FromStr for AggregateFunction {
    type Err = DataFusionError;

    /// Parses a built-in aggregate from its lower-case name
    /// (`count`, `sum`, `min`, `max`, `avg`).
    ///
    /// Matching is exact and case-sensitive; any other input yields a
    /// `DataFusionError::Plan`.
    fn from_str(name: &str) -> Result<AggregateFunction> {
        let function = match name {
            "count" => AggregateFunction::Count,
            "sum" => AggregateFunction::Sum,
            "min" => AggregateFunction::Min,
            "max" => AggregateFunction::Max,
            "avg" => AggregateFunction::Avg,
            unknown => {
                return Err(DataFusionError::Plan(format!(
                    "There is no built-in function named {}",
                    unknown
                )));
            }
        };
        Ok(function)
    }
}
/// Returns the output type of the aggregate `fun` applied to arguments of
/// types `arg_types`.
///
/// The argument types are first validated and coerced against the function's
/// signature via `data_types`. The output type is then derived from the
/// *coerced* types, not the raw inputs: `create_aggregate_expr` hands the
/// physical expression the argument produced by `coerce` with the same
/// signature, so the declared return type must describe that argument.
///
/// # Errors
///
/// Returns an error if `arg_types` do not satisfy the function's signature,
/// or if `sum_return_type` / `avg_return_type` reject the coerced type.
pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
    // NOTE(review): assumes `data_types` and `coerce` pick the same target
    // type for a given signature (both are driven by `signature(fun)`).
    let coerced_types = data_types(arg_types, &signature(fun))?;
    match fun {
        AggregateFunction::Count => Ok(DataType::UInt64),
        AggregateFunction::Max | AggregateFunction::Min => Ok(coerced_types[0].clone()),
        AggregateFunction::Sum => sum_return_type(&coerced_types[0]),
        AggregateFunction::Avg => avg_return_type(&coerced_types[0]),
    }
}
/// Builds the physical aggregate expression for `fun` over `args`.
///
/// The single argument is coerced to a type accepted by the function's
/// signature before being handed to the expression. `name` becomes the
/// expression's output name.
///
/// `distinct` is honoured for `COUNT`. It is ignored for `MIN`/`MAX`, where it
/// cannot change the result, and rejected for `SUM`/`AVG`.
///
/// # Errors
///
/// Returns an error if the arguments cannot be coerced or typed against
/// `input_schema`. Returns `DataFusionError::NotImplemented` for
/// `SUM(DISTINCT)` and `AVG(DISTINCT)`.
pub fn create_aggregate_expr(
    fun: &AggregateFunction,
    distinct: bool,
    args: &[Arc<dyn PhysicalExpr>],
    input_schema: &Schema,
    name: String,
) -> Result<Arc<dyn AggregateExpr>> {
    // Every built-in aggregate takes exactly one argument (see `signature`),
    // so the coerced argument list has a single element.
    let arg = coerce(args, input_schema, &signature(fun))?[0].clone();

    let arg_types = args
        .iter()
        .map(|e| e.data_type(input_schema))
        .collect::<Result<Vec<_>>>()?;

    // Named `output_type` so it does not shadow the `return_type` function.
    let output_type = return_type(fun, &arg_types)?;

    let expr: Arc<dyn AggregateExpr> = match (fun, distinct) {
        (AggregateFunction::Count, false) => {
            Arc::new(expressions::Count::new(arg, name, output_type))
        }
        (AggregateFunction::Count, true) => Arc::new(distinct_expressions::DistinctCount::new(
            arg_types,
            args.to_vec(),
            name,
            output_type,
        )),
        (AggregateFunction::Sum, false) => {
            Arc::new(expressions::Sum::new(arg, name, output_type))
        }
        (AggregateFunction::Sum, true) => {
            return Err(DataFusionError::NotImplemented(
                "SUM(DISTINCT) aggregations are not available".to_string(),
            ));
        }
        // DISTINCT cannot change the result of MIN/MAX, so both cases share one expression.
        (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(arg, name, output_type)),
        (AggregateFunction::Max, _) => Arc::new(expressions::Max::new(arg, name, output_type)),
        (AggregateFunction::Avg, false) => {
            Arc::new(expressions::Avg::new(arg, name, output_type))
        }
        (AggregateFunction::Avg, true) => {
            return Err(DataFusionError::NotImplemented(
                "AVG(DISTINCT) aggregations are not available".to_string(),
            ));
        }
    };
    Ok(expr)
}
// Type lists used by `signature` to build the accepted argument types.
// NOTE(review): coercion may prefer earlier entries when several types are
// compatible — confirm in `type_coercion` before reordering any of these lists.

/// String types; accepted by `MIN`/`MAX`.
static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
/// Signed/unsigned integer and floating-point types; accepted by
/// `SUM`, `AVG`, `MIN` and `MAX`.
static NUMERICS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Float32,
DataType::Float64,
];
/// Timestamps in every time unit, without a timezone; accepted by `MIN`/`MAX`.
static TIMESTAMPS: &[DataType] = &[
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, None),
];
/// Returns the argument signature of the built-in aggregate `fun`.
///
/// Every built-in aggregate takes exactly one argument; the functions differ
/// only in which argument types they accept.
fn signature(fun: &AggregateFunction) -> Signature {
    match fun {
        // COUNT accepts a single argument of any type.
        AggregateFunction::Count => Signature::Any(1),
        // SUM and AVG accept numeric arguments only.
        AggregateFunction::Sum | AggregateFunction::Avg => {
            Signature::Uniform(1, NUMERICS.to_vec())
        }
        // MIN and MAX accept strings, numerics and timestamps, listed in that order.
        AggregateFunction::Min | AggregateFunction::Max => {
            Signature::Uniform(1, [STRINGS, NUMERICS, TIMESTAMPS].concat())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    #[test]
    fn test_min_max() -> Result<()> {
        let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
        assert_eq!(DataType::Utf8, observed);

        let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?;
        assert_eq!(DataType::Int32, observed);
        Ok(())
    }

    #[test]
    fn test_sum_no_utf8() {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    #[test]
    fn test_sum_upcasts() -> Result<()> {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_count_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?;
        assert_eq!(DataType::UInt64, observed);

        let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32])?;
        assert_eq!(DataType::Float64, observed);

        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
        assert_eq!(DataType::Float64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_no_utf8() {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    #[test]
    fn test_display_is_uppercase_name() {
        assert_eq!(AggregateFunction::Count.to_string(), "COUNT");
        assert_eq!(AggregateFunction::Sum.to_string(), "SUM");
        assert_eq!(AggregateFunction::Min.to_string(), "MIN");
        assert_eq!(AggregateFunction::Max.to_string(), "MAX");
        assert_eq!(AggregateFunction::Avg.to_string(), "AVG");
    }

    #[test]
    fn test_from_str_display_round_trip() -> Result<()> {
        let all = [
            AggregateFunction::Count,
            AggregateFunction::Sum,
            AggregateFunction::Min,
            AggregateFunction::Max,
            AggregateFunction::Avg,
        ];
        for fun in all.iter() {
            // `FromStr` expects lower-case names while `Display` emits upper case.
            let parsed: AggregateFunction = fun.to_string().to_lowercase().parse()?;
            assert_eq!(fun, &parsed);
        }
        Ok(())
    }

    #[test]
    fn test_from_str_unknown_name() {
        assert!("median".parse::<AggregateFunction>().is_err());
    }
}