use crate::nullif::SUPPORTED_NULLIF_TYPES;
use crate::type_coercion::functions::data_types;
use crate::ColumnarValue;
use crate::{
array_expressions, conditional_expressions, struct_expressions, Accumulator,
BuiltinScalarFunction, Signature, TypeSignature,
};
use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
/// Shared, thread-safe (`Send + Sync`) closure that takes a slice of
/// [`ColumnarValue`] arguments and returns a single [`ColumnarValue`] or an
/// error.
///
/// NOTE(review): presumably the runtime implementation of a scalar function,
/// as the name suggests — confirm against the call sites.
pub type ScalarFunctionImplementation =
Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
/// Shared, thread-safe (`Send + Sync`) closure that takes a slice of
/// [`DataType`]s and returns a reference-counted [`DataType`] or an error.
///
/// NOTE(review): presumably maps a function's argument types to its return
/// type, as the name suggests — confirm against the call sites.
pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
/// Shared, thread-safe (`Send + Sync`) closure that takes a [`DataType`] and
/// returns a freshly boxed [`Accumulator`] or an error.
///
/// NOTE(review): the meaning of the `DataType` argument (e.g. the aggregate's
/// return type) is not visible here — confirm against the call sites.
pub type AccumulatorFunctionImplementation =
Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;
/// Shared, thread-safe (`Send + Sync`) closure that takes a [`DataType`] and
/// returns a reference-counted list of [`DataType`]s or an error.
///
/// NOTE(review): presumably yields the types of an accumulator's intermediate
/// state, as the name suggests — confirm against the call sites.
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
/// Generates a private function `$FUNC(arg_type, name)` that derives a return
/// type from a string argument type:
///
/// * `DataType::LargeUtf8` yields `$largeUtf8Type`
/// * `DataType::Utf8` yields `$utf8Type`
/// * `DataType::Null` yields `DataType::Null`
///
/// Any other argument type produces a `DataFusionError::Internal` whose
/// message quotes `name` (the function name shown to the user).
macro_rules! make_utf8_to_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            match arg_type {
                DataType::LargeUtf8 => Ok($largeUtf8Type),
                DataType::Utf8 => Ok($utf8Type),
                DataType::Null => Ok(DataType::Null),
                // Everything that is neither a string nor NULL is rejected.
                _ => Err(DataFusionError::Internal(format!(
                    "The {:?} function can only accept strings.",
                    name
                ))),
            }
        }
    };
}
// `utf8_to_str_type`: string in, string of the same width out
// (LargeUtf8 -> LargeUtf8, Utf8 -> Utf8, Null -> Null).
make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
// `utf8_to_int_type`: string in, integer out
// (LargeUtf8 -> Int64, Utf8 -> Int32, Null -> Null).
make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
/// Derives the return type of a function that accepts a string or binary
/// argument and produces binary output.
///
/// * `Utf8`, `LargeUtf8`, `Binary` and `LargeBinary` all yield `DataType::Binary`.
/// * `Null` yields `DataType::Null`.
///
/// # Errors
///
/// Returns `DataFusionError::Internal`, quoting `name`, for any other
/// argument type.
fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result<DataType> {
    match arg_type {
        DataType::Null => Ok(DataType::Null),
        DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Binary
        | DataType::LargeBinary => Ok(DataType::Binary),
        _ => Err(DataFusionError::Internal(format!(
            "The {name:?} function can only accept strings or binary arrays."
        ))),
    }
}
/// Returns the [`DataType`] produced by the builtin scalar function `fun`
/// when it is called with arguments of types `input_expr_types`.
///
/// The argument types are first checked (and coerced) against
/// `signature(fun)`; only then is the per-function return-type rule applied.
///
/// # Errors
///
/// * `DataFusionError::Internal` if `input_expr_types` is empty and `fun`
///   does not support being called without arguments.
/// * Any error reported by `data_types` when the argument types do not fit
///   the function's signature.
/// * `DataFusionError::Internal` if the first argument's type is not handled
///   by the function's return-type rule (e.g. a non-string passed to a
///   string function).
pub fn return_type(
    fun: &BuiltinScalarFunction,
    input_expr_types: &[DataType],
) -> Result<DataType> {
    if input_expr_types.is_empty() && !fun.supports_zero_argument() {
        return Err(DataFusionError::Internal(format!(
            "Builtin scalar function {fun} does not support empty arguments"
        )));
    }

    // Validate the arguments against the signature once and keep the coerced
    // types: Coalesce and NullIf derive their return type from them.
    let coerced_types = data_types(input_expr_types, &signature(fun))?;

    // From here on, indexing `input_expr_types[0]` is only done in arms of
    // functions that take at least one argument.
    match fun {
        BuiltinScalarFunction::MakeArray => Ok(DataType::FixedSizeList(
            Box::new(Field::new("item", input_expr_types[0].clone(), true)),
            input_expr_types.len() as i32,
        )),
        BuiltinScalarFunction::Ascii => Ok(DataType::Int32),
        BuiltinScalarFunction::BitLength => {
            utf8_to_int_type(&input_expr_types[0], "bit_length")
        }
        BuiltinScalarFunction::Btrim => utf8_to_str_type(&input_expr_types[0], "btrim"),
        BuiltinScalarFunction::CharacterLength => {
            utf8_to_int_type(&input_expr_types[0], "character_length")
        }
        BuiltinScalarFunction::Chr => Ok(DataType::Utf8),
        // The result has the type of the first argument after coercion.
        BuiltinScalarFunction::Coalesce | BuiltinScalarFunction::NullIf => {
            Ok(coerced_types[0].clone())
        }
        BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
        BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8),
        BuiltinScalarFunction::DatePart => Ok(DataType::Float64),
        BuiltinScalarFunction::DateTrunc => {
            Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
        }
        BuiltinScalarFunction::DateBin => {
            Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
        }
        BuiltinScalarFunction::InitCap => {
            utf8_to_str_type(&input_expr_types[0], "initcap")
        }
        BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"),
        BuiltinScalarFunction::Lower => utf8_to_str_type(&input_expr_types[0], "lower"),
        BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
        BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"),
        BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"),
        BuiltinScalarFunction::OctetLength => {
            utf8_to_int_type(&input_expr_types[0], "octet_length")
        }
        BuiltinScalarFunction::Random => Ok(DataType::Float64),
        BuiltinScalarFunction::Uuid => Ok(DataType::Utf8),
        BuiltinScalarFunction::RegexpReplace => {
            utf8_to_str_type(&input_expr_types[0], "regexp_replace")
        }
        BuiltinScalarFunction::Repeat => utf8_to_str_type(&input_expr_types[0], "repeat"),
        BuiltinScalarFunction::Replace => {
            utf8_to_str_type(&input_expr_types[0], "replace")
        }
        BuiltinScalarFunction::Reverse => {
            utf8_to_str_type(&input_expr_types[0], "reverse")
        }
        BuiltinScalarFunction::Right => utf8_to_str_type(&input_expr_types[0], "right"),
        BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
        BuiltinScalarFunction::Rtrim => utf8_to_str_type(&input_expr_types[0], "rtrim"),
        BuiltinScalarFunction::SHA224 => {
            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224")
        }
        BuiltinScalarFunction::SHA256 => {
            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256")
        }
        BuiltinScalarFunction::SHA384 => {
            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384")
        }
        BuiltinScalarFunction::SHA512 => {
            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512")
        }
        BuiltinScalarFunction::Digest => {
            utf8_or_binary_to_binary_type(&input_expr_types[0], "digest")
        }
        BuiltinScalarFunction::SplitPart => {
            utf8_to_str_type(&input_expr_types[0], "split_part")
        }
        BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean),
        BuiltinScalarFunction::Strpos => utf8_to_int_type(&input_expr_types[0], "strpos"),
        BuiltinScalarFunction::Substr => utf8_to_str_type(&input_expr_types[0], "substr"),
        BuiltinScalarFunction::ToHex => match input_expr_types[0] {
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                Ok(DataType::Utf8)
            }
            _ => Err(DataFusionError::Internal(
                "The to_hex function can only accept integers.".to_string(),
            )),
        },
        BuiltinScalarFunction::ToTimestamp => {
            Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
        }
        BuiltinScalarFunction::ToTimestampMillis => {
            Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
        }
        BuiltinScalarFunction::ToTimestampMicros => {
            Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
        }
        BuiltinScalarFunction::ToTimestampSeconds => {
            Ok(DataType::Timestamp(TimeUnit::Second, None))
        }
        BuiltinScalarFunction::FromUnixtime => {
            Ok(DataType::Timestamp(TimeUnit::Second, None))
        }
        // `now()` is the only timestamp here that carries a time zone.
        BuiltinScalarFunction::Now => Ok(DataType::Timestamp(
            TimeUnit::Nanosecond,
            Some("+00:00".to_owned()),
        )),
        BuiltinScalarFunction::CurrentDate => Ok(DataType::Date32),
        BuiltinScalarFunction::CurrentTime => Ok(DataType::Time64(TimeUnit::Nanosecond)),
        BuiltinScalarFunction::Translate => {
            utf8_to_str_type(&input_expr_types[0], "translate")
        }
        BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"),
        BuiltinScalarFunction::Upper => utf8_to_str_type(&input_expr_types[0], "upper"),
        // A list of strings with the same width as the input string.
        BuiltinScalarFunction::RegexpMatch => match input_expr_types[0] {
            DataType::LargeUtf8 => Ok(DataType::List(Box::new(Field::new(
                "item",
                DataType::LargeUtf8,
                true,
            )))),
            DataType::Utf8 => Ok(DataType::List(Box::new(Field::new(
                "item",
                DataType::Utf8,
                true,
            )))),
            DataType::Null => Ok(DataType::Null),
            _ => Err(DataFusionError::Internal(
                "The regexp_match function can only accept strings.".to_string(),
            )),
        },
        BuiltinScalarFunction::Power => match &input_expr_types[0] {
            DataType::Int64 => Ok(DataType::Int64),
            _ => Ok(DataType::Float64),
        },
        BuiltinScalarFunction::Struct => Ok(DataType::Struct(vec![])),
        BuiltinScalarFunction::Atan2 => match &input_expr_types[0] {
            DataType::Float32 => Ok(DataType::Float32),
            _ => Ok(DataType::Float64),
        },
        BuiltinScalarFunction::ArrowTypeof => Ok(DataType::Utf8),
        // Unary math functions keep Float32 and widen everything else to Float64.
        BuiltinScalarFunction::Abs
        | BuiltinScalarFunction::Acos
        | BuiltinScalarFunction::Asin
        | BuiltinScalarFunction::Atan
        | BuiltinScalarFunction::Ceil
        | BuiltinScalarFunction::Cos
        | BuiltinScalarFunction::Exp
        | BuiltinScalarFunction::Floor
        | BuiltinScalarFunction::Log
        | BuiltinScalarFunction::Ln
        | BuiltinScalarFunction::Log10
        | BuiltinScalarFunction::Log2
        | BuiltinScalarFunction::Round
        | BuiltinScalarFunction::Signum
        | BuiltinScalarFunction::Sin
        | BuiltinScalarFunction::Sqrt
        | BuiltinScalarFunction::Tan
        | BuiltinScalarFunction::Trunc => match input_expr_types[0] {
            DataType::Float32 => Ok(DataType::Float32),
            _ => Ok(DataType::Float64),
        },
    }
}
pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
match fun {
BuiltinScalarFunction::MakeArray => Signature::variadic(
array_expressions::SUPPORTED_ARRAY_TYPES.to_vec(),
fun.volatility(),
),
BuiltinScalarFunction::Struct => Signature::variadic(
struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(),
fun.volatility(),
),
BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => {
Signature::variadic(vec![DataType::Utf8], fun.volatility())
}
BuiltinScalarFunction::Coalesce => Signature::variadic(
conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(),
fun.volatility(),
),
BuiltinScalarFunction::SHA224
| BuiltinScalarFunction::SHA256
| BuiltinScalarFunction::SHA384
| BuiltinScalarFunction::SHA512
| BuiltinScalarFunction::MD5 => Signature::uniform(
1,
vec![
DataType::Utf8,
DataType::LargeUtf8,
DataType::Binary,
DataType::LargeBinary,
],
fun.volatility(),
),
BuiltinScalarFunction::Ascii
| BuiltinScalarFunction::BitLength
| BuiltinScalarFunction::CharacterLength
| BuiltinScalarFunction::InitCap
| BuiltinScalarFunction::Lower
| BuiltinScalarFunction::OctetLength
| BuiltinScalarFunction::Reverse
| BuiltinScalarFunction::Upper => Signature::uniform(
1,
vec![DataType::Utf8, DataType::LargeUtf8],
fun.volatility(),
),
BuiltinScalarFunction::Btrim
| BuiltinScalarFunction::Ltrim
| BuiltinScalarFunction::Rtrim
| BuiltinScalarFunction::Trim => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
],
fun.volatility(),
),
BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => {
Signature::uniform(1, vec![DataType::Int64], fun.volatility())
}
BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Int64,
DataType::Utf8,
]),
TypeSignature::Exact(vec![
DataType::LargeUtf8,
DataType::Int64,
DataType::Utf8,
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Int64,
DataType::LargeUtf8,
]),
TypeSignature::Exact(vec![
DataType::LargeUtf8,
DataType::Int64,
DataType::LargeUtf8,
]),
],
fun.volatility(),
),
BuiltinScalarFunction::Left
| BuiltinScalarFunction::Repeat
| BuiltinScalarFunction::Right => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]),
],
fun.volatility(),
),
BuiltinScalarFunction::ToTimestamp => Signature::uniform(
1,
vec![
DataType::Int64,
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Second, None),
DataType::Utf8,
],
fun.volatility(),
),
BuiltinScalarFunction::ToTimestampMillis => Signature::uniform(
1,
vec![
DataType::Int64,
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Second, None),
DataType::Utf8,
],
fun.volatility(),
),
BuiltinScalarFunction::ToTimestampMicros => Signature::uniform(
1,
vec![
DataType::Int64,
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Second, None),
DataType::Utf8,
],
fun.volatility(),
),
BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform(
1,
vec![
DataType::Int64,
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Second, None),
DataType::Utf8,
],
fun.volatility(),
),
BuiltinScalarFunction::FromUnixtime => {
Signature::uniform(1, vec![DataType::Int64], fun.volatility())
}
BuiltinScalarFunction::Digest => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Binary, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeBinary, DataType::Utf8]),
],
fun.volatility(),
),
BuiltinScalarFunction::DateTrunc => Signature::exact(
vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Nanosecond, None),
],
fun.volatility(),
),
BuiltinScalarFunction::DateBin => Signature::exact(
vec![
DataType::Interval(IntervalUnit::DayTime),
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, None),
],
fun.volatility(),
),
BuiltinScalarFunction::DatePart => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Date32]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Second, None),
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Microsecond, None),
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Millisecond, None),
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Nanosecond, None),
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".to_owned())),
]),
],
fun.volatility(),
),
BuiltinScalarFunction::SplitPart => Signature::one_of(
vec![
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Utf8,
DataType::Int64,
]),
TypeSignature::Exact(vec![
DataType::LargeUtf8,
DataType::Utf8,
DataType::Int64,
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::LargeUtf8,
DataType::Int64,
]),
TypeSignature::Exact(vec![
DataType::LargeUtf8,
DataType::LargeUtf8,
DataType::Int64,
]),
],
fun.volatility(),
),
BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => {
Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
],
fun.volatility(),
)
}
BuiltinScalarFunction::Substr => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Int64,
DataType::Int64,
]),
TypeSignature::Exact(vec![
DataType::LargeUtf8,
DataType::Int64,
DataType::Int64,
]),
],
fun.volatility(),
),
BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
Signature::one_of(
vec![TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Utf8,
DataType::Utf8,
])],
fun.volatility(),
)
}
BuiltinScalarFunction::RegexpReplace => Signature::one_of(
vec![
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Utf8,
DataType::Utf8,
]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Utf8,
DataType::Utf8,
DataType::Utf8,
]),
],
fun.volatility(),
),
BuiltinScalarFunction::NullIf => {
Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), fun.volatility())
}
BuiltinScalarFunction::RegexpMatch => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Utf8,
DataType::Utf8,
]),
TypeSignature::Exact(vec![
DataType::LargeUtf8,
DataType::Utf8,
DataType::Utf8,
]),
],
fun.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()),
BuiltinScalarFunction::Uuid => Signature::exact(vec![], fun.volatility()),
BuiltinScalarFunction::Power => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
],
fun.volatility(),
),
BuiltinScalarFunction::Round => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float32, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float64]),
TypeSignature::Exact(vec![DataType::Float32]),
],
fun.volatility(),
),
BuiltinScalarFunction::Atan2 => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]),
TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
],
fun.volatility(),
),
BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()),
_ => Signature::uniform(
1,
vec![DataType::Float64, DataType::Float32],
fun.volatility(),
),
}
}