use std::fmt::Display;
use std::hash::Hash;
use std::sync::Arc;
use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::{
DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType,
Decimal128Type, DecimalType, Field, IntervalUnit, TimeUnit,
};
use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, internal_err, plan_err};
use indexmap::IndexSet;
use itertools::Itertools;
/// Placeholder timezone string.
///
/// Within this file it is used by `TypeSignatureClass::Timestamp` when listing
/// example types, to stand for "a timestamp that carries some timezone".
pub const TIMEZONE_WILDCARD: &str = "+TZ";
/// NOTE(review): presumably a placeholder meaning "any fixed-size list length";
/// it is not referenced in the visible part of this file — confirm against callers.
pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN;
/// Classifies how stable a function's result is across evaluations.
///
/// Every [`Signature`] carries one of these values in its `volatility` field.
///
/// `PartialOrd`/`Ord` are derived, so the declaration order below defines the
/// ordering: `Immutable < Stable < Volatile`. Do not reorder the variants.
///
/// NOTE(review): the precise per-variant semantics are not defined in this
/// file; the variant comments follow the usual SQL meaning — confirm.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Volatility {
/// Presumably: same input always yields the same output.
Immutable,
/// Presumably: output is fixed for a given input within one query.
Stable,
/// Presumably: output may differ on every evaluation.
Volatile,
}
/// Number of arguments a [`TypeSignature`] accepts, as reported by
/// `TypeSignature::arity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Arity {
/// Exactly this many arguments.
Fixed(usize),
/// No fixed argument count (variadic and user-defined signatures).
Variable,
}
/// Describes the argument types a function accepts.
///
/// `PartialOrd` is derived, so the declaration order of the variants is
/// observable (the unit tests compare variants with `<`). Do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum TypeSignature {
/// Variable number of arguments drawn from the listed types
/// (rendered as `T1/T2, ..`).
Variadic(Vec<DataType>),
/// Variable arity; rendered as `UserDefined`. Parameter names may still be
/// attached to a signature of this kind.
UserDefined,
/// Variable number of arguments, rendered as `Any, .., Any`.
VariadicAny,
/// A fixed number of arguments, each rendered as one of the listed types.
Uniform(usize, Vec<DataType>),
/// Exactly the listed types, one per argument, in order.
Exact(Vec<DataType>),
/// One [`Coercion`] rule per argument, in order.
Coercible(Vec<Coercion>),
/// A fixed number of arguments, rendered as `Comparable(n)`.
Comparable(usize),
/// A fixed number of arguments, each rendered as `Any`.
Any(usize),
/// Any one of the nested signatures.
OneOf(Vec<TypeSignature>),
/// An array-function specific signature; see [`ArrayFunctionSignature`].
ArraySignature(ArrayFunctionSignature),
/// A fixed number of arguments, rendered as `Numeric(n)`.
Numeric(usize),
/// A fixed number of arguments, rendered as `String(n)`.
String(usize),
/// No arguments at all.
Nullary,
}
impl TypeSignature {
#[inline]
pub fn is_one_of(&self) -> bool {
matches!(self, TypeSignature::OneOf(_))
}
pub fn arity(&self) -> Arity {
match self {
TypeSignature::Exact(types) => Arity::Fixed(types.len()),
TypeSignature::Uniform(count, _) => Arity::Fixed(*count),
TypeSignature::Numeric(count) => Arity::Fixed(*count),
TypeSignature::String(count) => Arity::Fixed(*count),
TypeSignature::Comparable(count) => Arity::Fixed(*count),
TypeSignature::Any(count) => Arity::Fixed(*count),
TypeSignature::Coercible(types) => Arity::Fixed(types.len()),
TypeSignature::Nullary => Arity::Fixed(0),
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments,
..
}) => Arity::Fixed(arguments.len()),
TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray) => {
Arity::Fixed(1)
}
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray) => {
Arity::Fixed(1)
}
TypeSignature::OneOf(variants) => {
let has_variable = variants.iter().any(|v| v.arity() == Arity::Variable);
if has_variable {
return Arity::Variable;
}
let max_arity = variants
.iter()
.filter_map(|v| match v.arity() {
Arity::Fixed(n) => Some(n),
Arity::Variable => None,
})
.max();
match max_arity {
Some(n) => Arity::Fixed(n),
None => Arity::Variable,
}
}
TypeSignature::Variadic(_)
| TypeSignature::VariadicAny
| TypeSignature::UserDefined => Arity::Variable,
}
}
}
/// A category of types that an argument may belong to.
///
/// Used by [`Coercion`] to describe both the desired argument type and the
/// source types that may be implicitly coerced. `matches_native_type` defines
/// which [`NativeType`]s each class accepts.
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)]
pub enum TypeSignatureClass {
/// Accepts every type.
Any,
/// Accepts types for which `NativeType::is_timestamp` holds.
Timestamp,
/// Accepts types for which `NativeType::is_time` holds.
Time,
/// Accepts types for which `NativeType::is_interval` holds.
Interval,
/// Accepts types for which `NativeType::is_duration` holds.
Duration,
/// Accepts exactly the native type of the wrapped logical type.
Native(LogicalTypeRef),
/// Accepts types for which `NativeType::is_integer` holds.
Integer,
/// Accepts types for which `NativeType::is_float` holds.
Float,
/// Accepts types for which `NativeType::is_decimal` holds.
Decimal,
/// Accepts types for which `NativeType::is_numeric` holds.
Numeric,
/// Accepts types for which `NativeType::is_binary` holds.
Binary,
}
impl Display for TypeSignatureClass {
    /// Renders the class as `TypeSignatureClass::` followed by its `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("TypeSignatureClass::{self:?}"))
    }
}
impl TypeSignatureClass {
fn get_example_types(&self) -> Vec<DataType> {
match self {
TypeSignatureClass::Any => vec![],
TypeSignatureClass::Native(l) => get_data_types(l.native()),
TypeSignatureClass::Timestamp => {
vec![
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(
TimeUnit::Nanosecond,
Some(TIMEZONE_WILDCARD.into()),
),
]
}
TypeSignatureClass::Time => {
vec![DataType::Time64(TimeUnit::Nanosecond)]
}
TypeSignatureClass::Interval => {
vec![DataType::Interval(IntervalUnit::DayTime)]
}
TypeSignatureClass::Duration => {
vec![DataType::Duration(TimeUnit::Nanosecond)]
}
TypeSignatureClass::Integer => {
vec![DataType::Int64]
}
TypeSignatureClass::Binary => {
vec![DataType::Binary]
}
TypeSignatureClass::Decimal => vec![Decimal128Type::DEFAULT_TYPE],
TypeSignatureClass::Float => vec![DataType::Float64],
TypeSignatureClass::Numeric => vec![
DataType::Float64,
DataType::Int64,
Decimal128Type::DEFAULT_TYPE,
],
}
}
pub fn matches_native_type(&self, logical_type: &NativeType) -> bool {
if logical_type == &NativeType::Null {
return true;
}
match self {
TypeSignatureClass::Any => true,
TypeSignatureClass::Native(t) if t.native() == logical_type => true,
TypeSignatureClass::Timestamp if logical_type.is_timestamp() => true,
TypeSignatureClass::Time if logical_type.is_time() => true,
TypeSignatureClass::Interval if logical_type.is_interval() => true,
TypeSignatureClass::Duration if logical_type.is_duration() => true,
TypeSignatureClass::Integer if logical_type.is_integer() => true,
TypeSignatureClass::Binary if logical_type.is_binary() => true,
TypeSignatureClass::Decimal if logical_type.is_decimal() => true,
TypeSignatureClass::Float if logical_type.is_float() => true,
TypeSignatureClass::Numeric if logical_type.is_numeric() => true,
_ => false,
}
}
pub fn default_casted_type(
&self,
native_type: &NativeType,
origin_type: &DataType,
) -> Result<DataType> {
match self {
TypeSignatureClass::Any => Ok(origin_type.to_owned()),
TypeSignatureClass::Native(logical_type) => {
logical_type.native().default_cast_for(origin_type)
}
TypeSignatureClass::Timestamp if native_type.is_timestamp() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Time if native_type.is_time() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Interval if native_type.is_interval() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Duration if native_type.is_duration() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Integer if native_type.is_integer() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Binary if native_type.is_binary() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Decimal if native_type.is_decimal() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Float if native_type.is_float() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Numeric if native_type.is_numeric() => {
Ok(origin_type.to_owned())
}
_ if native_type.is_null() => Ok(origin_type.to_owned()),
_ => internal_err!("May miss the matching logic in `matches_native_type`"),
}
}
}
/// Signature shapes specific to array functions, wrapped by
/// `TypeSignature::ArraySignature`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionSignature {
/// A fixed list of arguments, each described by an [`ArrayFunctionArgument`].
Array {
/// The role of each argument, in order; its length is the arity.
arguments: Vec<ArrayFunctionArgument>,
/// Optional list coercion to apply to array arguments.
array_coercion: Option<ListCoercion>,
},
/// Single-argument signature, displayed as `recursive_array`.
RecursiveArray,
/// Single-argument signature, displayed as `map_array`.
MapArray,
}
impl Display for ArrayFunctionSignature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ArrayFunctionSignature::Array { arguments, .. } => {
for (idx, argument) in arguments.iter().enumerate() {
write!(f, "{argument}")?;
if idx != arguments.len() - 1 {
write!(f, ", ")?;
}
}
Ok(())
}
ArrayFunctionSignature::RecursiveArray => {
write!(f, "recursive_array")
}
ArrayFunctionSignature::MapArray => {
write!(f, "map_array")
}
}
}
}
/// The role of one argument in an `ArrayFunctionSignature::Array` signature.
///
/// Each variant is displayed as its lowercase name.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionArgument {
/// Displayed as `element`.
Element,
/// Displayed as `index`.
Index,
/// Displayed as `array`.
Array,
/// Displayed as `string`.
String,
}
impl Display for ArrayFunctionArgument {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ArrayFunctionArgument::Element => {
write!(f, "element")
}
ArrayFunctionArgument::Index => {
write!(f, "index")
}
ArrayFunctionArgument::Array => {
write!(f, "array")
}
ArrayFunctionArgument::String => {
write!(f, "string")
}
}
}
}
impl TypeSignature {
/// Renders this signature as human-readable strings.
///
/// Every variant yields exactly one string except `OneOf`, which yields the
/// concatenation of the strings of all its alternatives.
pub fn to_string_repr(&self) -> Vec<String> {
    let rendered = match self {
        Self::OneOf(alternatives) => {
            return alternatives
                .iter()
                .flat_map(Self::to_string_repr)
                .collect();
        }
        Self::Nullary => "NullAry()".to_string(),
        Self::UserDefined => "UserDefined".to_string(),
        Self::VariadicAny => "Any, .., Any".to_string(),
        Self::Variadic(types) => format!("{}, ..", Self::join_types(types, "/")),
        Self::Uniform(arg_count, valid_types) => {
            // The same `T1/T2` alternative list, repeated once per argument.
            vec![Self::join_types(valid_types, "/"); *arg_count].join(", ")
        }
        Self::String(num) => format!("String({num})"),
        Self::Numeric(num) => format!("Numeric({num})"),
        Self::Comparable(num) => format!("Comparable({num})"),
        Self::Coercible(coercions) => Self::join_types(coercions, ", "),
        Self::Exact(types) => Self::join_types(types, ", "),
        Self::Any(arg_count) => vec!["Any"; *arg_count].join(", "),
        Self::ArraySignature(array_signature) => array_signature.to_string(),
    };
    vec![rendered]
}
pub fn to_string_repr_with_names(
&self,
parameter_names: Option<&[String]>,
) -> Vec<String> {
match self {
TypeSignature::Exact(types) => {
if let Some(names) = parameter_names {
vec![
names
.iter()
.zip(types.iter())
.map(|(name, typ)| format!("{name}: {typ}"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
vec![Self::join_types(types, ", ")]
}
}
TypeSignature::Any(count) => {
if let Some(names) = parameter_names {
vec![
names
.iter()
.take(*count)
.map(|name| format!("{name}: Any"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
vec![
std::iter::repeat_n("Any", *count)
.collect::<Vec<&str>>()
.join(", "),
]
}
}
TypeSignature::Uniform(count, types) => {
if let Some(names) = parameter_names {
let type_str = Self::join_types(types, "/");
vec![
names
.iter()
.take(*count)
.map(|name| format!("{name}: {type_str}"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
self.to_string_repr()
}
}
TypeSignature::Coercible(coercions) => {
if let Some(names) = parameter_names {
vec![
names
.iter()
.zip(coercions.iter())
.map(|(name, coercion)| format!("{name}: {coercion}"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
vec![Self::join_types(coercions, ", ")]
}
}
TypeSignature::Comparable(count) => {
if let Some(names) = parameter_names {
vec![
names
.iter()
.take(*count)
.map(|name| format!("{name}: Comparable"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
self.to_string_repr()
}
}
TypeSignature::Numeric(count) => {
if let Some(names) = parameter_names {
vec![
names
.iter()
.take(*count)
.map(|name| format!("{name}: Numeric"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
self.to_string_repr()
}
}
TypeSignature::String(count) => {
if let Some(names) = parameter_names {
vec![
names
.iter()
.take(*count)
.map(|name| format!("{name}: String"))
.collect::<Vec<_>>()
.join(", "),
]
} else {
self.to_string_repr()
}
}
TypeSignature::Nullary => self.to_string_repr(),
TypeSignature::ArraySignature(array_sig) => {
if let Some(names) = parameter_names {
match array_sig {
ArrayFunctionSignature::Array { arguments, .. } => {
vec![
names
.iter()
.zip(arguments.iter())
.map(|(name, arg_type)| format!("{name}: {arg_type}"))
.collect::<Vec<_>>()
.join(", "),
]
}
ArrayFunctionSignature::RecursiveArray => {
vec![
names
.iter()
.take(1)
.map(|name| format!("{name}: recursive_array"))
.collect::<Vec<_>>()
.join(", "),
]
}
ArrayFunctionSignature::MapArray => {
vec![
names
.iter()
.take(1)
.map(|name| format!("{name}: map_array"))
.collect::<Vec<_>>()
.join(", "),
]
}
}
} else {
self.to_string_repr()
}
}
TypeSignature::OneOf(sigs) => sigs
.iter()
.flat_map(|s| s.to_string_repr_with_names(parameter_names))
.collect(),
TypeSignature::UserDefined => {
if let Some(names) = parameter_names {
vec![names.join(", ")]
} else {
self.to_string_repr()
}
}
TypeSignature::Variadic(_) | TypeSignature::VariadicAny => {
self.to_string_repr()
}
}
}
/// Formats every item with `Display` and joins the results with `delimiter`.
///
/// An empty slice yields an empty string.
pub fn join_types<T: Display>(types: &[T], delimiter: &str) -> String {
    let mut joined = String::new();
    for (position, item) in types.iter().enumerate() {
        if position > 0 {
            joined.push_str(delimiter);
        }
        joined.push_str(&item.to_string());
    }
    joined
}
pub fn supports_zero_argument(&self) -> bool {
match &self {
TypeSignature::Exact(vec) => vec.is_empty(),
TypeSignature::Nullary => true,
TypeSignature::OneOf(types) => types
.iter()
.any(|type_sig| type_sig.supports_zero_argument()),
_ => false,
}
}
pub fn used_to_support_zero_arguments(&self) -> bool {
match &self {
TypeSignature::Any(num) => *num == 0,
_ => self.supports_zero_argument(),
}
}
#[deprecated(since = "46.0.0", note = "See get_example_types instead")]
pub fn get_possible_types(&self) -> Vec<Vec<DataType>> {
self.get_example_types()
}
pub fn get_example_types(&self) -> Vec<Vec<DataType>> {
match self {
TypeSignature::Exact(types) => vec![types.clone()],
TypeSignature::OneOf(types) => types
.iter()
.flat_map(|type_sig| type_sig.get_example_types())
.collect(),
TypeSignature::Uniform(arg_count, types) => types
.iter()
.cloned()
.map(|data_type| vec![data_type; *arg_count])
.collect(),
TypeSignature::Coercible(coercions) => coercions
.iter()
.map(|c| {
let mut all_types: IndexSet<DataType> =
c.desired_type().get_example_types().into_iter().collect();
if let Some(implicit_coercion) = c.implicit_coercion() {
let allowed_casts: Vec<DataType> = implicit_coercion
.allowed_source_types
.iter()
.flat_map(|t| t.get_example_types())
.collect();
all_types.extend(allowed_casts);
}
all_types.into_iter().collect::<Vec<_>>()
})
.multi_cartesian_product()
.collect(),
TypeSignature::Variadic(types) => types
.iter()
.cloned()
.map(|data_type| vec![data_type])
.collect(),
TypeSignature::Numeric(arg_count) => NUMERICS
.iter()
.cloned()
.map(|numeric_type| vec![numeric_type; *arg_count])
.collect(),
TypeSignature::String(arg_count) => get_data_types(&NativeType::String)
.into_iter()
.map(|dt| vec![dt; *arg_count])
.collect::<Vec<_>>(),
TypeSignature::Any(_)
| TypeSignature::Comparable(_)
| TypeSignature::Nullary
| TypeSignature::VariadicAny
| TypeSignature::ArraySignature(_)
| TypeSignature::UserDefined => vec![],
}
}
}
/// Lists the concrete [`DataType`]s that correspond to `native_type`.
///
/// Nested types other than fixed-size lists (`List`, `Struct`, `Union`, `Map`)
/// yield an empty list.
fn get_data_types(native_type: &NativeType) -> Vec<DataType> {
    match native_type {
        NativeType::Null => vec![DataType::Null],
        NativeType::Boolean => vec![DataType::Boolean],
        NativeType::Int8 => vec![DataType::Int8],
        NativeType::Int16 => vec![DataType::Int16],
        NativeType::Int32 => vec![DataType::Int32],
        NativeType::Int64 => vec![DataType::Int64],
        NativeType::UInt8 => vec![DataType::UInt8],
        NativeType::UInt16 => vec![DataType::UInt16],
        NativeType::UInt32 => vec![DataType::UInt32],
        NativeType::UInt64 => vec![DataType::UInt64],
        NativeType::Float16 => vec![DataType::Float16],
        NativeType::Float32 => vec![DataType::Float32],
        NativeType::Float64 => vec![DataType::Float64],
        NativeType::Date => vec![DataType::Date32, DataType::Date64],
        NativeType::Binary => vec![
            DataType::Binary,
            DataType::LargeBinary,
            DataType::BinaryView,
        ],
        NativeType::String => vec![
            DataType::Utf8,
            DataType::LargeUtf8,
            DataType::Utf8View,
        ],
        NativeType::Decimal(precision, scale) => {
            let (p, s) = (*precision, *scale);
            // Decimal256 is always listed first; narrower widths follow only
            // when the precision fits.
            let narrower = [
                (DECIMAL32_MAX_PRECISION, DataType::Decimal32(p, s)),
                (DECIMAL64_MAX_PRECISION, DataType::Decimal64(p, s)),
                (DECIMAL128_MAX_PRECISION, DataType::Decimal128(p, s)),
            ];
            std::iter::once(DataType::Decimal256(p, s))
                .chain(
                    narrower
                        .into_iter()
                        .filter(|(max_precision, _)| p <= *max_precision)
                        .map(|(_, data_type)| data_type),
                )
                .collect()
        }
        NativeType::Timestamp(time_unit, timezone) => {
            vec![DataType::Timestamp(*time_unit, timezone.to_owned())]
        }
        // Second/millisecond resolution maps to Time32, finer units to Time64.
        NativeType::Time(unit @ (TimeUnit::Second | TimeUnit::Millisecond)) => {
            vec![DataType::Time32(*unit)]
        }
        NativeType::Time(unit @ (TimeUnit::Microsecond | TimeUnit::Nanosecond)) => {
            vec![DataType::Time64(*unit)]
        }
        NativeType::Duration(time_unit) => vec![DataType::Duration(*time_unit)],
        NativeType::Interval(interval_unit) => vec![DataType::Interval(*interval_unit)],
        NativeType::FixedSizeBinary(size) => vec![DataType::FixedSizeBinary(*size)],
        NativeType::FixedSizeList(logical_field, size) => {
            // One fixed-size list type per concrete type of the element field.
            get_data_types(logical_field.logical_type.native())
                .into_iter()
                .map(|child_type| {
                    let element = Field::new(
                        logical_field.name.clone(),
                        child_type,
                        logical_field.nullable,
                    );
                    DataType::FixedSizeList(Arc::new(element), *size)
                })
                .collect()
        }
        NativeType::List(_)
        | NativeType::Struct(_)
        | NativeType::Union(_)
        | NativeType::Map(_) => Vec::new(),
    }
}
/// Coercion rule for a single argument of a `TypeSignature::Coercible`
/// signature.
///
/// `PartialEq` and `Hash` are implemented by hand in this file, in terms of
/// `desired_type()` and `implicit_coercion()`.
#[derive(Debug, Clone, Eq, PartialOrd)]
pub enum Coercion {
/// Only the desired type class is accepted; no implicit coercion rule.
Exact {
/// The type class the argument should have.
desired_type: TypeSignatureClass,
},
/// The desired type class plus a rule for implicitly coercing other types.
Implicit {
/// The type class the argument should have.
desired_type: TypeSignatureClass,
/// Which source classes may be coerced and the default cast target.
implicit_coercion: ImplicitCoercion,
},
}
impl Coercion {
    /// Creates a rule that accepts only `desired_type`, with no implicit coercion.
    pub fn new_exact(desired_type: TypeSignatureClass) -> Self {
        Self::Exact { desired_type }
    }

    /// Creates a rule that accepts `desired_type` and additionally records the
    /// source classes that may be implicitly coerced, together with the native
    /// type to cast them to by default.
    pub fn new_implicit(
        desired_type: TypeSignatureClass,
        allowed_source_types: Vec<TypeSignatureClass>,
        default_casted_type: NativeType,
    ) -> Self {
        let implicit_coercion = ImplicitCoercion {
            allowed_source_types,
            default_casted_type,
        };
        Self::Implicit {
            desired_type,
            implicit_coercion,
        }
    }

    /// Source classes that may be implicitly coerced; empty for `Exact`.
    pub fn allowed_source_types(&self) -> &[TypeSignatureClass] {
        self.implicit_coercion()
            .map(|rule| rule.allowed_source_types.as_slice())
            .unwrap_or(&[])
    }

    /// Default cast target of the implicit rule; `None` for `Exact`.
    pub fn default_casted_type(&self) -> Option<&NativeType> {
        self.implicit_coercion()
            .map(|rule| &rule.default_casted_type)
    }

    /// The type class the argument should have.
    pub fn desired_type(&self) -> &TypeSignatureClass {
        match self {
            Self::Exact { desired_type } | Self::Implicit { desired_type, .. } => {
                desired_type
            }
        }
    }

    /// The implicit coercion rule, if this is the `Implicit` variant.
    pub fn implicit_coercion(&self) -> Option<&ImplicitCoercion> {
        match self {
            Self::Implicit {
                implicit_coercion, ..
            } => Some(implicit_coercion),
            Self::Exact { .. } => None,
        }
    }
}
impl Display for Coercion {
    /// Formats the rule as `Coercion(<desired>)`, or as
    /// `Coercion(<desired>, implicit_coercion=<rule>)` when an implicit
    /// coercion rule is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Coercion({}", self.desired_type())?;
        if let Some(implicit_coercion) = self.implicit_coercion() {
            write!(f, ", implicit_coercion={implicit_coercion}")?;
        }
        // Close the parenthesis on both paths; the implicit case previously
        // left it open, producing unbalanced output.
        write!(f, ")")
    }
}
impl PartialEq for Coercion {
    /// Two rules are equal when both their desired type and their optional
    /// implicit coercion rule are equal.
    fn eq(&self, other: &Self) -> bool {
        if self.desired_type() != other.desired_type() {
            return false;
        }
        self.implicit_coercion() == other.implicit_coercion()
    }
}
impl Hash for Coercion {
    /// Hashes the desired type followed by the optional implicit coercion
    /// rule — the same components `PartialEq` compares.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Hash::hash(self.desired_type(), state);
        Hash::hash(&self.implicit_coercion(), state);
    }
}
/// The implicit part of a `Coercion::Implicit` rule.
///
/// `PartialEq` and `Hash` are implemented by hand in this file over both fields.
#[derive(Debug, Clone, Eq, PartialOrd)]
pub struct ImplicitCoercion {
/// Type classes that may be implicitly coerced.
allowed_source_types: Vec<TypeSignatureClass>,
/// Native type recorded as the default cast target.
default_casted_type: NativeType,
}
impl Display for ImplicitCoercion {
    /// Formats both fields using their `Debug` representations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            allowed_source_types,
            default_casted_type,
        } = self;
        write!(
            f,
            "ImplicitCoercion({allowed_source_types:?}, default_type={default_casted_type:?})"
        )
    }
}
impl PartialEq for ImplicitCoercion {
    /// Field-wise equality: allowed source types first, then the default cast type.
    fn eq(&self, other: &Self) -> bool {
        if self.allowed_source_types != other.allowed_source_types {
            return false;
        }
        self.default_casted_type == other.default_casted_type
    }
}
impl Hash for ImplicitCoercion {
    /// Hashes both fields in declaration order.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let Self {
            allowed_source_types,
            default_casted_type,
        } = self;
        allowed_source_types.hash(state);
        default_casted_type.hash(state);
    }
}
/// A function signature: the accepted argument types, the function's
/// volatility, and optional parameter names.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Signature {
/// The argument types the function accepts.
pub type_signature: TypeSignature,
/// How stable the function's result is; see [`Volatility`].
pub volatility: Volatility,
/// Optional parameter names; the constructors in this file leave it `None`
/// and `with_parameter_names` sets it after validation.
pub parameter_names: Option<Vec<String>>,
}
impl Signature {
pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self {
Signature {
type_signature,
volatility,
parameter_names: None,
}
}
pub fn variadic(common_types: Vec<DataType>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Variadic(common_types),
volatility,
parameter_names: None,
}
}
pub fn user_defined(volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::UserDefined,
volatility,
parameter_names: None,
}
}
pub fn numeric(arg_count: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Numeric(arg_count),
volatility,
parameter_names: None,
}
}
pub fn string(arg_count: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::String(arg_count),
volatility,
parameter_names: None,
}
}
pub fn variadic_any(volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::VariadicAny,
volatility,
parameter_names: None,
}
}
pub fn uniform(
arg_count: usize,
valid_types: Vec<DataType>,
volatility: Volatility,
) -> Self {
Self {
type_signature: TypeSignature::Uniform(arg_count, valid_types),
volatility,
parameter_names: None,
}
}
pub fn exact(exact_types: Vec<DataType>, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Exact(exact_types),
volatility,
parameter_names: None,
}
}
pub fn coercible(target_types: Vec<Coercion>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
parameter_names: None,
}
}
pub fn comparable(arg_count: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Comparable(arg_count),
volatility,
parameter_names: None,
}
}
pub fn nullary(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Nullary,
volatility,
parameter_names: None,
}
}
pub fn any(arg_count: usize, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Any(arg_count),
volatility,
parameter_names: None,
}
}
pub fn one_of(type_signatures: Vec<TypeSignature>, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::OneOf(type_signatures),
volatility,
parameter_names: None,
}
}
pub fn array_and_element(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Element,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
},
),
volatility,
parameter_names: None,
}
}
pub fn element_and_array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Element,
ArrayFunctionArgument::Array,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
},
),
volatility,
parameter_names: None,
}
}
pub fn arrays(
n: usize,
coercion: Option<ListCoercion>,
volatility: Volatility,
) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
arguments: vec![ArrayFunctionArgument::Array; n],
array_coercion: coercion,
},
),
volatility,
parameter_names: None,
}
}
pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::OneOf(vec![
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Element,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
}),
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Element,
ArrayFunctionArgument::Index,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
}),
]),
volatility,
parameter_names: None,
}
}
pub fn array_and_index(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Index,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
},
),
volatility,
parameter_names: None,
}
}
pub fn array(volatility: Volatility) -> Self {
Signature::arrays(1, Some(ListCoercion::FixedSizedListToList), volatility)
}
pub fn with_parameter_names(mut self, names: Vec<impl Into<String>>) -> Result<Self> {
let names = names.into_iter().map(Into::into).collect::<Vec<String>>();
self.validate_parameter_names(&names)?;
self.parameter_names = Some(names);
Ok(self)
}
fn validate_parameter_names(&self, names: &[String]) -> Result<()> {
match self.type_signature.arity() {
Arity::Fixed(expected) => {
if names.len() != expected {
return plan_err!(
"Parameter names count ({}) does not match signature arity ({})",
names.len(),
expected
);
}
}
Arity::Variable => {
if self.type_signature != TypeSignature::UserDefined {
return plan_err!(
"Cannot specify parameter names for variable arity signature: {:?}",
self.type_signature
);
}
}
}
let mut seen = std::collections::HashSet::new();
for name in names {
if !seen.insert(name) {
return plan_err!("Duplicate parameter name: '{}'", name);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use datafusion_common::types::{logical_int32, logical_int64, logical_string};
use super::*;
use crate::signature::{
ArrayFunctionArgument, ArrayFunctionSignature, Coercion, TypeSignatureClass,
};
#[test]
fn supports_zero_argument_tests() {
let positive_cases = vec![
TypeSignature::Exact(vec![]),
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int8]),
TypeSignature::Nullary,
TypeSignature::Uniform(1, vec![DataType::Int8]),
]),
TypeSignature::Nullary,
];
for case in positive_cases {
assert!(
case.supports_zero_argument(),
"Expected {case:?} to support zero arguments"
);
}
let negative_cases = vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Uniform(1, vec![DataType::Float64]),
TypeSignature::Any(1),
TypeSignature::VariadicAny,
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int8]),
TypeSignature::Uniform(1, vec![DataType::Int8]),
]),
];
for case in negative_cases {
assert!(
!case.supports_zero_argument(),
"Expected {case:?} not to support zero arguments"
);
}
}
#[test]
fn type_signature_partial_ord() {
assert!(TypeSignature::UserDefined < TypeSignature::VariadicAny);
assert!(TypeSignature::UserDefined < TypeSignature::Any(1));
assert!(
TypeSignature::Uniform(1, vec![DataType::Null])
< TypeSignature::Uniform(1, vec![DataType::Boolean])
);
assert!(
TypeSignature::Uniform(1, vec![DataType::Null])
< TypeSignature::Uniform(2, vec![DataType::Null])
);
assert!(
TypeSignature::Uniform(usize::MAX, vec![DataType::Null])
< TypeSignature::Exact(vec![DataType::Null])
);
}
#[test]
fn test_get_possible_types() {
let type_signature = TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]);
let possible_types = type_signature.get_example_types();
assert_eq!(possible_types, vec![vec![DataType::Int32, DataType::Int64]]);
let type_signature = TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]),
]);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
vec![DataType::Int32, DataType::Int64],
vec![DataType::Float32, DataType::Float64]
]
);
let type_signature = TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]),
TypeSignature::Exact(vec![DataType::Utf8]),
]);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
vec![DataType::Int32, DataType::Int64],
vec![DataType::Float32, DataType::Float64],
vec![DataType::Utf8]
]
);
let type_signature =
TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
vec![DataType::Float32, DataType::Float32],
vec![DataType::Int64, DataType::Int64]
]
);
let type_signature = TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
]);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
vec![DataType::Utf8, DataType::Int64],
vec![DataType::LargeUtf8, DataType::Int64],
vec![DataType::Utf8View, DataType::Int64]
]
);
let type_signature =
TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![vec![DataType::Int32], vec![DataType::Int64]]
);
let type_signature = TypeSignature::Numeric(2);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
vec![DataType::Int8, DataType::Int8],
vec![DataType::Int16, DataType::Int16],
vec![DataType::Int32, DataType::Int32],
vec![DataType::Int64, DataType::Int64],
vec![DataType::UInt8, DataType::UInt8],
vec![DataType::UInt16, DataType::UInt16],
vec![DataType::UInt32, DataType::UInt32],
vec![DataType::UInt64, DataType::UInt64],
vec![DataType::Float16, DataType::Float16],
vec![DataType::Float32, DataType::Float32],
vec![DataType::Float64, DataType::Float64]
]
);
let type_signature = TypeSignature::String(2);
let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
vec![DataType::Utf8, DataType::Utf8],
vec![DataType::LargeUtf8, DataType::LargeUtf8],
vec![DataType::Utf8View, DataType::Utf8View]
]
);
}
#[test]
fn test_signature_with_parameter_names() {
let sig = Signature::exact(
vec![DataType::Int32, DataType::Utf8],
Volatility::Immutable,
)
.with_parameter_names(vec!["count".to_string(), "name".to_string()])
.unwrap();
assert_eq!(
sig.parameter_names,
Some(vec!["count".to_string(), "name".to_string()])
);
assert_eq!(
sig.type_signature,
TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8])
);
}
#[test]
fn test_signature_parameter_names_wrong_count() {
let result = Signature::exact(
vec![DataType::Int32, DataType::Utf8],
Volatility::Immutable,
)
.with_parameter_names(vec!["count".to_string()]);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("does not match signature arity")
);
}
#[test]
fn test_signature_parameter_names_duplicate() {
let result = Signature::exact(
vec![DataType::Int32, DataType::Int32],
Volatility::Immutable,
)
.with_parameter_names(vec!["count".to_string(), "count".to_string()]);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Duplicate parameter name")
);
}
#[test]
fn test_signature_parameter_names_variadic() {
let result = Signature::variadic(vec![DataType::Int32], Volatility::Immutable)
.with_parameter_names(vec!["arg".to_string()]);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("variable arity signature")
);
}
#[test]
fn test_signature_without_parameter_names() {
let sig = Signature::exact(
vec![DataType::Int32, DataType::Utf8],
Volatility::Immutable,
);
assert_eq!(sig.parameter_names, None);
}
#[test]
fn test_signature_uniform_with_parameter_names() {
let sig = Signature::uniform(3, vec![DataType::Float64], Volatility::Immutable)
.with_parameter_names(vec!["x".to_string(), "y".to_string(), "z".to_string()])
.unwrap();
assert_eq!(
sig.parameter_names,
Some(vec!["x".to_string(), "y".to_string(), "z".to_string()])
);
}
#[test]
fn test_signature_numeric_with_parameter_names() {
let sig = Signature::numeric(2, Volatility::Immutable)
.with_parameter_names(vec!["a".to_string(), "b".to_string()])
.unwrap();
assert_eq!(
sig.parameter_names,
Some(vec!["a".to_string(), "b".to_string()])
);
}
#[test]
fn test_signature_nullary_with_empty_names() {
let sig = Signature::nullary(Volatility::Immutable)
.with_parameter_names(Vec::<String>::new())
.unwrap();
assert_eq!(sig.parameter_names, Some(vec![]));
}
#[test]
fn test_to_string_repr_with_names_exact() {
let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]);
assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]);
let names = vec!["id".to_string(), "name".to_string()];
assert_eq!(
sig.to_string_repr_with_names(Some(&names)),
vec!["id: Int32, name: Utf8"]
);
}
#[test]
fn test_to_string_repr_with_names_any() {
let sig = TypeSignature::Any(3);
assert_eq!(sig.to_string_repr_with_names(None), vec!["Any, Any, Any"]);
let names = vec!["x".to_string(), "y".to_string(), "z".to_string()];
assert_eq!(
sig.to_string_repr_with_names(Some(&names)),
vec!["x: Any, y: Any, z: Any"]
);
}
#[test]
fn test_to_string_repr_with_names_one_of() {
let sig =
TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]);
assert_eq!(
sig.to_string_repr_with_names(None),
vec!["Any, Any", "Any, Any, Any"]
);
let names = vec![
"str".to_string(),
"start_pos".to_string(),
"length".to_string(),
];
assert_eq!(
sig.to_string_repr_with_names(Some(&names)),
vec![
"str: Any, start_pos: Any",
"str: Any, start_pos: Any, length: Any"
]
);
}
#[test]
fn test_to_string_repr_with_names_partial() {
let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]);
let names = vec!["a".to_string(), "b".to_string(), "c".to_string()];
assert_eq!(
sig.to_string_repr_with_names(Some(&names)),
vec!["a: Int32, b: Utf8"]
);
}
#[test]
fn test_to_string_repr_with_names_uniform() {
let sig = TypeSignature::Uniform(2, vec![DataType::Float64]);
assert_eq!(
sig.to_string_repr_with_names(None),
vec!["Float64, Float64"]
);
let names = vec!["x".to_string(), "y".to_string()];
assert_eq!(
sig.to_string_repr_with_names(Some(&names)),
vec!["x: Float64, y: Float64"]
);
}
/// `Coercible` signatures produce a single entry in which both parameter
/// names appear as `name: ` labels; the exact type text is not pinned here.
#[test]
fn test_to_string_repr_with_names_coercible() {
    let int32_arg = || Coercion::new_exact(TypeSignatureClass::Native(logical_int32()));
    let signature = TypeSignature::Coercible(vec![int32_arg(), int32_arg()]);

    let names = vec![String::from("a"), String::from("b")];
    let rendered = signature.to_string_repr_with_names(Some(&names));

    assert_eq!(rendered.len(), 1);
    // Only the labels are checked, not how the coercion itself is displayed.
    let entry = &rendered[0];
    assert!(entry.starts_with("a: "));
    assert!(entry.contains(", b: "));
}
/// `Comparable`, `Numeric` and `String` signatures each repeat their own
/// label once per argument, prefixed by the matching parameter name.
#[test]
fn test_to_string_repr_with_names_comparable_numeric_string() {
    let names: Vec<String> = ["a", "b", "c"].iter().map(|n| n.to_string()).collect();

    // Three arguments: all three names are used.
    let comparable_repr =
        TypeSignature::Comparable(3).to_string_repr_with_names(Some(&names));
    assert_eq!(
        comparable_repr,
        vec!["a: Comparable, b: Comparable, c: Comparable"]
    );

    // Two arguments: only the first two names are used.
    let numeric_repr = TypeSignature::Numeric(2).to_string_repr_with_names(Some(&names));
    assert_eq!(numeric_repr, vec!["a: Numeric, b: Numeric"]);

    let string_repr = TypeSignature::String(2).to_string_repr_with_names(Some(&names));
    assert_eq!(string_repr, vec!["a: String, b: String"]);
}
/// Variadic signatures ignore parameter names and match the plain
/// `to_string_repr`, while `UserDefined` shows the names when they are given
/// and falls back to the plain representation otherwise.
#[test]
fn test_to_string_repr_with_names_variadic_fallback() {
    let names = vec![String::from("x")];

    // `Variadic` and `VariadicAny`: names make no difference to the output.
    let typed_variadic = TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]);
    assert_eq!(
        typed_variadic.to_string_repr_with_names(Some(&names)),
        typed_variadic.to_string_repr()
    );

    let any_variadic = TypeSignature::VariadicAny;
    assert_eq!(
        any_variadic.to_string_repr_with_names(Some(&names)),
        any_variadic.to_string_repr()
    );

    // `UserDefined`: the names themselves are the rendered output.
    let custom = TypeSignature::UserDefined;
    let custom_named = custom.to_string_repr_with_names(Some(&names));
    assert_eq!(custom_named, vec!["x"]);

    let custom_unnamed = custom.to_string_repr_with_names(None);
    assert_eq!(custom_unnamed, custom.to_string_repr());
}
/// `Nullary` always renders as `NullAry()`, whether or not names are supplied.
#[test]
fn test_to_string_repr_with_names_nullary() {
    let signature = TypeSignature::Nullary;
    let names = vec![String::from("x")];

    let named = signature.to_string_repr_with_names(Some(&names));
    assert_eq!(named, vec!["NullAry()"]);

    let unnamed = signature.to_string_repr_with_names(None);
    assert_eq!(unnamed, vec!["NullAry()"]);
}
/// Array signatures render their argument kinds (`array`, `index`,
/// `element`, `recursive_array`, `map_array`), labelled when names are given.
#[test]
fn test_to_string_repr_with_names_array_signature() {
    // Explicit argument list: array, index, element.
    let arguments = vec![
        ArrayFunctionArgument::Array,
        ArrayFunctionArgument::Index,
        ArrayFunctionArgument::Element,
    ];
    let array_sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
        arguments,
        array_coercion: None,
    });

    let unnamed = array_sig.to_string_repr_with_names(None);
    assert_eq!(unnamed, vec!["array, index, element"]);

    let array_names: Vec<String> =
        ["arr", "idx", "val"].iter().map(|n| n.to_string()).collect();
    let named = array_sig.to_string_repr_with_names(Some(&array_names));
    assert_eq!(named, vec!["arr: array, idx: index, val: element"]);

    // Single-argument recursive array.
    let recursive_sig =
        TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray);
    let recursive_names = vec![String::from("array")];
    assert_eq!(
        recursive_sig.to_string_repr_with_names(Some(&recursive_names)),
        vec!["array: recursive_array"]
    );

    // Single-argument map array.
    let map_sig = TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray);
    let map_names = vec![String::from("map")];
    assert_eq!(
        map_sig.to_string_repr_with_names(Some(&map_names)),
        vec!["map: map_array"]
    );
}
/// The arity of `Exact` is the number of listed types, including zero.
#[test]
fn test_type_signature_arity_exact() {
    let two_args = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]);
    assert_eq!(two_args.arity(), Arity::Fixed(2));

    let no_args = TypeSignature::Exact(Vec::new());
    assert_eq!(no_args.arity(), Arity::Fixed(0));
}
/// The arity of `Uniform` is its declared argument count.
#[test]
fn test_type_signature_arity_uniform() {
    let cases = [(3, DataType::Float64), (1, DataType::Int32)];

    for (count, data_type) in cases {
        let signature = TypeSignature::Uniform(count, vec![data_type]);
        assert_eq!(signature.arity(), Arity::Fixed(count));
    }
}
/// The arity of `Numeric(n)` is fixed at `n`.
#[test]
fn test_type_signature_arity_numeric() {
    let arity = TypeSignature::Numeric(2).arity();
    assert_eq!(arity, Arity::Fixed(2));
}
/// The arity of `String(n)` is fixed at `n`.
#[test]
fn test_type_signature_arity_string() {
    let arity = TypeSignature::String(3).arity();
    assert_eq!(arity, Arity::Fixed(3));
}
/// The arity of `Comparable(n)` is fixed at `n`.
#[test]
fn test_type_signature_arity_comparable() {
    let arity = TypeSignature::Comparable(2).arity();
    assert_eq!(arity, Arity::Fixed(2));
}
/// The arity of `Any(n)` is fixed at `n`.
#[test]
fn test_type_signature_arity_any() {
    let arity = TypeSignature::Any(4).arity();
    assert_eq!(arity, Arity::Fixed(4));
}
/// The arity of `Coercible` is the number of coercions it lists.
#[test]
fn test_type_signature_arity_coercible() {
    let int_arg = Coercion::new_exact(TypeSignatureClass::Native(logical_int32()));
    let string_arg = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));

    let signature = TypeSignature::Coercible(vec![int_arg, string_arg]);

    assert_eq!(signature.arity(), Arity::Fixed(2));
}
/// `Nullary` has a fixed arity of zero.
#[test]
fn test_type_signature_arity_nullary() {
    let arity = TypeSignature::Nullary.arity();
    assert_eq!(arity, Arity::Fixed(0));
}
/// Array signatures have fixed arity: the argument count for `Array`, and one
/// for both `RecursiveArray` and `MapArray`.
#[test]
fn test_type_signature_arity_array_signature() {
    // Builds an `Array` signature with no list coercion from an argument list.
    let array_signature = |arguments: Vec<ArrayFunctionArgument>| {
        TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
            arguments,
            array_coercion: None,
        })
    };

    let two_args = array_signature(vec![
        ArrayFunctionArgument::Array,
        ArrayFunctionArgument::Index,
    ]);
    assert_eq!(two_args.arity(), Arity::Fixed(2));

    let three_args = array_signature(vec![
        ArrayFunctionArgument::Array,
        ArrayFunctionArgument::Element,
        ArrayFunctionArgument::Index,
    ]);
    assert_eq!(three_args.arity(), Arity::Fixed(3));

    let recursive = TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray);
    assert_eq!(recursive.arity(), Arity::Fixed(1));

    let map_array = TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray);
    assert_eq!(map_array.arity(), Arity::Fixed(1));
}
/// A `OneOf` whose alternatives take one, two and three arguments reports
/// `Fixed(3)`, the largest of those counts.
#[test]
fn test_type_signature_arity_one_of_fixed() {
    let one_arg = TypeSignature::Exact(vec![DataType::Int32]);
    let two_args = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]);
    let three_args = TypeSignature::Exact(vec![
        DataType::Int32,
        DataType::Utf8,
        DataType::Float64,
    ]);

    let signature = TypeSignature::OneOf(vec![one_arg, two_args, three_args]);

    assert_eq!(signature.arity(), Arity::Fixed(3));
}
/// A `OneOf` that mixes a fixed-arity alternative with `VariadicAny` reports
/// variable arity.
#[test]
fn test_type_signature_arity_one_of_variable() {
    let fixed = TypeSignature::Exact(vec![DataType::Int32]);
    let variadic = TypeSignature::VariadicAny;

    let signature = TypeSignature::OneOf(vec![fixed, variadic]);

    assert_eq!(signature.arity(), Arity::Variable);
}
/// Both `Variadic` and `VariadicAny` report variable arity.
#[test]
fn test_type_signature_arity_variadic() {
    let signatures = [
        TypeSignature::Variadic(vec![DataType::Int32]),
        TypeSignature::VariadicAny,
    ];

    for signature in signatures {
        assert_eq!(signature.arity(), Arity::Variable);
    }
}
/// `UserDefined` signatures report variable arity.
#[test]
fn test_type_signature_arity_user_defined() {
    let arity = TypeSignature::UserDefined.arity();
    assert_eq!(arity, Arity::Variable);
}
}