use crate::signature::TypeSignature;
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::{Result, internal_err, plan_err};
/// Every fixed-width integer [`DataType`].
///
/// Signed widths come first, then unsigned widths; each group is ordered
/// from narrowest to widest.
pub static INTEGERS: &[DataType] = &[
    // Signed integers.
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    // Unsigned integers.
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
];
/// Every fixed-width integer and floating-point [`DataType`].
///
/// The integer prefix lists the same entries, in the same order, as
/// `INTEGERS`; the floating-point widths follow. Decimal types are not
/// part of this list.
pub static NUMERICS: &[DataType] = &[
    // Signed integers.
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    // Unsigned integers.
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
    // Floating point.
    DataType::Float16,
    DataType::Float32,
    DataType::Float64,
];
/// Checks that the number of arguments passed to the aggregate function
/// `func_name` is acceptable for `signature`.
///
/// Only the argument *count* is inspected here; the argument types in
/// `input_fields` are not looked at.
///
/// # Errors
///
/// * A plan error when a fixed-arity signature (`Uniform`, `Any`, `Exact`)
///   expects a different number of arguments than were provided.
/// * A plan error when no alternative of a `OneOf` signature accepts the
///   provided argument count.
/// * A plan error when a `VariadicAny` signature receives zero arguments.
/// * An internal error for any signature variant not handled below.
pub fn check_arg_count(
    func_name: &str,
    input_fields: &[FieldRef],
    signature: &TypeSignature,
) -> Result<()> {
    let provided = input_fields.len();
    match signature {
        // Fixed-arity signatures: the argument count must match exactly.
        TypeSignature::Uniform(expected, _) | TypeSignature::Any(expected)
            if *expected != provided =>
        {
            plan_err!(
                "The function {func_name} expects {:?} arguments, but {:?} were provided",
                expected,
                provided
            )
        }
        TypeSignature::Exact(types) if types.len() != provided => plan_err!(
            "The function {func_name} expects {:?} arguments, but {:?} were provided",
            types.len(),
            provided
        ),
        // A union of signatures accepts the call if at least one alternative
        // accepts it; each alternative is checked recursively.
        TypeSignature::OneOf(variants) => {
            let accepted_by_some = variants
                .iter()
                .any(|variant| check_arg_count(func_name, input_fields, variant).is_ok());
            if accepted_by_some {
                Ok(())
            } else {
                plan_err!(
                    "The function {func_name} does not accept {:?} function arguments.",
                    provided
                )
            }
        }
        // "Any number of arguments" still means at least one.
        TypeSignature::VariadicAny if provided == 0 => {
            plan_err!("The function {func_name} expects at least one argument")
        }
        // Either the count check above already passed, or the signature is
        // not arity-checked here at all (UserDefined / Numeric / Coercible);
        // presumably those are validated during type coercion — confirm.
        TypeSignature::Uniform(..)
        | TypeSignature::Any(_)
        | TypeSignature::Exact(_)
        | TypeSignature::VariadicAny
        | TypeSignature::UserDefined
        | TypeSignature::Numeric(_)
        | TypeSignature::Coercible(_) => Ok(()),
        _ => internal_err!("Aggregate functions do not support this {signature:?}"),
    }
}