use arrow::array::*;
use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
use datafusion_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_functions::{
downcast_named_arg, make_abs_function, make_try_abs_function,
make_wrapping_abs_function,
};
use std::any::Any;
use std::sync::Arc;
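/// Spark-compatible `abs` function.
///
/// Mirrors Spark's two overflow behaviors for signed integers and decimals:
/// under ANSI mode, `abs` of the minimum value (e.g. `i32::MIN`) raises an
/// overflow error; in legacy mode it wraps, returning the input unchanged.
/// Floats use plain `abs`; unsigned integers and nulls pass through as-is.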
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkAbs {
signature: Signature,
}
impl Default for SparkAbs {
fn default() -> Self {
Self::new()
}
}
impl SparkAbs {
pub fn new() -> Self {
Self {
signature: Signature::numeric(1, Volatility::Immutable),
}
}
}
impl ScalarUDFImpl for SparkAbs {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"abs"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!(
"SparkAbs: return_type() is not used; return_field_from_args() is implemented"
)
}
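    /// `abs` preserves both the input's data type and its nullability: the
    /// output is null only where the input is null (overflow either wraps or
    /// errors, never producing null), so the input field's nullability is
    /// passed through unchanged.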
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let input_field = &args.arg_fields[0];
let out_dt = input_field.data_type().clone();
let out_nullable = input_field.is_nullable();
Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
spark_abs(&args.args, args.config_options.execution.enable_ansi_mode)
}
}
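// A minimal registration sketch (assumes the `datafusion` crate's
// `SessionContext`, which is not a dependency of this module):
//
//   use datafusion::prelude::SessionContext;
//   use datafusion_expr::ScalarUDF;
//
//   let ctx = SessionContext::new();
//   ctx.register_udf(ScalarUDF::from(SparkAbs::new()));
//   // SELECT abs(c) FROM t -- errors on overflow only when ANSI mode is on
/// Computes `abs` for a single non-null scalar. The first arm handles plain
/// integer scalars; the second additionally threads a decimal's precision and
/// scale through to the result. In ANSI mode `checked_abs` turns overflow
/// into a `ComputeError`; otherwise `wrapping_abs` reproduces Spark's legacy
/// wraparound.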
macro_rules! scalar_compute_op {
($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{
let result = if $ENABLE_ANSI_MODE {
$INPUT.checked_abs().ok_or_else(|| {
ArrowError::ComputeError(format!(
"{} overflow on abs({:?})",
stringify!($SCALAR_TYPE),
$INPUT
))
})?
} else {
$INPUT.wrapping_abs()
};
Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(
result,
))))
}};
($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
let result = if $ENABLE_ANSI_MODE {
$INPUT.checked_abs().ok_or_else(|| {
ArrowError::ComputeError(format!(
"{} overflow on abs({:?})",
stringify!($SCALAR_TYPE),
$INPUT
))
})?
} else {
$INPUT.wrapping_abs()
};
Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
Some(result),
$PRECISION,
$SCALE,
)))
}};
}
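/// Shared kernel behind [`SparkAbs::invoke_with_args`], also callable
/// directly. Dispatches on the argument's type: nulls and unsigned integers
/// are returned unchanged, floats use `abs`, and signed integers and decimals
/// either error on overflow (`enable_ansi_mode = true`) or wrap (legacy).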
pub fn spark_abs(
args: &[ColumnarValue],
enable_ansi_mode: bool,
) -> Result<ColumnarValue, DataFusionError> {
if args.len() != 1 {
return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
}
match &args[0] {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Null
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => Ok(args[0].clone()),
DataType::Int8 => {
let abs_fun = if enable_ansi_mode {
make_try_abs_function!(Int8Array)
} else {
make_wrapping_abs_function!(Int8Array)
};
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Int16 => {
let abs_fun = if enable_ansi_mode {
make_try_abs_function!(Int16Array)
} else {
make_wrapping_abs_function!(Int16Array)
};
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Int32 => {
let abs_fun = if enable_ansi_mode {
make_try_abs_function!(Int32Array)
} else {
make_wrapping_abs_function!(Int32Array)
};
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Int64 => {
let abs_fun = if enable_ansi_mode {
make_try_abs_function!(Int64Array)
} else {
make_wrapping_abs_function!(Int64Array)
};
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Float32 => {
let abs_fun = make_abs_function!(Float32Array);
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Float64 => {
let abs_fun = make_abs_function!(Float64Array);
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Decimal128(_, _) => {
let abs_fun = if enable_ansi_mode {
make_try_abs_function!(Decimal128Array)
} else {
make_wrapping_abs_function!(Decimal128Array)
};
abs_fun(array).map(ColumnarValue::Array)
}
DataType::Decimal256(_, _) => {
let abs_fun = if enable_ansi_mode {
make_try_abs_function!(Decimal256Array)
} else {
make_wrapping_abs_function!(Decimal256Array)
};
abs_fun(array).map(ColumnarValue::Array)
}
dt => internal_err!("Unsupported data type for Spark ABS: {dt}"),
},
ColumnarValue::Scalar(sv) => match sv {
ScalarValue::Null
| ScalarValue::UInt8(_)
| ScalarValue::UInt16(_)
| ScalarValue::UInt32(_)
| ScalarValue::UInt64(_) => Ok(args[0].clone()),
sv if sv.is_null() => Ok(args[0].clone()),
ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8),
ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16),
ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32),
ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64),
ScalarValue::Float32(Some(v)) => {
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
}
ScalarValue::Float64(Some(v)) => {
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128)
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256)
}
sv => internal_err!("Unsupported data type for Spark ABS: {}", sv.data_type()),
},
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::i256;
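    // Applies `spark_abs` in legacy (wrapping) mode to an input array and
    // asserts the downcast result equals the expected array.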
macro_rules! eval_array_legacy_mode {
($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
let input = $INPUT;
let args = ColumnarValue::Array(Arc::new(input));
let expected = $OUTPUT;
match spark_abs(&[args], false) {
Ok(ColumnarValue::Array(result)) => {
let actual = datafusion_common::cast::$FUNC(&result).unwrap();
assert_eq!(actual, &expected);
}
_ => unreachable!(),
}
}};
}
#[test]
fn test_abs_array_legacy_mode() {
eval_array_legacy_mode!(
Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
as_int8_array
);
eval_array_legacy_mode!(
Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
as_int16_array
);
eval_array_legacy_mode!(
Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
as_int32_array
);
eval_array_legacy_mode!(
Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
as_int64_array
);
eval_array_legacy_mode!(
Float32Array::from(vec![
Some(-1f32),
Some(f32::MIN),
Some(f32::MAX),
None,
Some(f32::NAN),
Some(f32::INFINITY),
Some(f32::NEG_INFINITY),
Some(0.0),
Some(-0.0),
]),
Float32Array::from(vec![
Some(1f32),
Some(f32::MAX),
Some(f32::MAX),
None,
Some(f32::NAN),
Some(f32::INFINITY),
Some(f32::INFINITY),
Some(0.0),
Some(0.0),
]),
as_float32_array
);
eval_array_legacy_mode!(
Float64Array::from(vec![
Some(-1f64),
Some(f64::MIN),
Some(f64::MAX),
None,
Some(f64::NAN),
Some(f64::INFINITY),
Some(f64::NEG_INFINITY),
Some(0.0),
Some(-0.0),
]),
Float64Array::from(vec![
Some(1f64),
Some(f64::MAX),
Some(f64::MAX),
None,
Some(f64::NAN),
Some(f64::INFINITY),
Some(f64::INFINITY),
Some(0.0),
Some(0.0),
]),
as_float64_array
);
eval_array_legacy_mode!(
Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None])
.with_precision_and_scale(38, 37)
.unwrap(),
Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None])
.with_precision_and_scale(38, 37)
.unwrap(),
as_decimal128_array
);
eval_array_legacy_mode!(
Decimal256Array::from(vec![
Some(i256::MIN),
Some(i256::MINUS_ONE),
Some(i256::MIN + i256::from(1)),
None
])
.with_precision_and_scale(5, 2)
.unwrap(),
Decimal256Array::from(vec![
Some(i256::MIN),
Some(i256::ONE),
Some(i256::MAX),
None
])
.with_precision_and_scale(5, 2)
.unwrap(),
as_decimal256_array
);
}
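    // ANSI-mode helper: the one-argument form expects an overflow error; the
    // three-argument form asserts a successful result, mirroring
    // `eval_array_legacy_mode!`.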
macro_rules! eval_array_ansi_mode {
($INPUT:expr) => {{
let input = $INPUT;
let args = ColumnarValue::Array(Arc::new(input));
match spark_abs(&[args], true) {
Err(e) => {
assert!(
e.to_string().contains("overflow on abs"),
"Error message did not match. Actual message: {e}"
);
}
_ => unreachable!(),
}
}};
($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
let input = $INPUT;
let args = ColumnarValue::Array(Arc::new(input));
let expected = $OUTPUT;
match spark_abs(&[args], true) {
Ok(ColumnarValue::Array(result)) => {
let actual = datafusion_common::cast::$FUNC(&result).unwrap();
assert_eq!(actual, &expected);
}
_ => unreachable!(),
}
}};
}
#[test]
fn test_abs_array_ansi_mode() {
eval_array_ansi_mode!(
UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
as_uint64_array
);
eval_array_ansi_mode!(Int8Array::from(vec![
Some(-1),
Some(i8::MIN),
Some(i8::MAX),
None
]));
eval_array_ansi_mode!(Int16Array::from(vec![
Some(-1),
Some(i16::MIN),
Some(i16::MAX),
None
]));
eval_array_ansi_mode!(Int32Array::from(vec![
Some(-1),
Some(i32::MIN),
Some(i32::MAX),
None
]));
eval_array_ansi_mode!(Int64Array::from(vec![
Some(-1),
Some(i64::MIN),
Some(i64::MAX),
None
]));
eval_array_ansi_mode!(
Float32Array::from(vec![
Some(-1f32),
Some(f32::MIN),
Some(f32::MAX),
None,
Some(f32::NAN),
Some(f32::INFINITY),
Some(f32::NEG_INFINITY),
Some(0.0),
Some(-0.0),
]),
Float32Array::from(vec![
Some(1f32),
Some(f32::MAX),
Some(f32::MAX),
None,
Some(f32::NAN),
Some(f32::INFINITY),
Some(f32::INFINITY),
Some(0.0),
Some(0.0),
]),
as_float32_array
);
eval_array_ansi_mode!(
Float64Array::from(vec![
Some(-1f64),
Some(f64::MIN),
Some(f64::MAX),
None,
Some(f64::NAN),
Some(f64::INFINITY),
Some(f64::NEG_INFINITY),
Some(0.0),
Some(-0.0),
]),
Float64Array::from(vec![
Some(1f64),
Some(f64::MAX),
Some(f64::MAX),
None,
Some(f64::NAN),
Some(f64::INFINITY),
Some(f64::INFINITY),
Some(0.0),
Some(0.0),
]),
as_float64_array
);
eval_array_ansi_mode!(
Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)])
.with_precision_and_scale(38, 37)
.unwrap(),
Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)])
.with_precision_and_scale(38, 37)
.unwrap(),
as_decimal128_array
);
eval_array_ansi_mode!(
Decimal256Array::from(vec![
Some(i256::MINUS_ONE),
Some(i256::from(-2)),
Some(i256::MIN + i256::from(1))
])
.with_precision_and_scale(18, 7)
.unwrap(),
Decimal256Array::from(vec![
Some(i256::ONE),
Some(i256::from(2)),
Some(i256::MAX)
])
.with_precision_and_scale(18, 7)
.unwrap(),
as_decimal256_array
);
eval_array_ansi_mode!(
Decimal128Array::from(vec![Some(i128::MIN), None])
.with_precision_and_scale(38, 37)
.unwrap()
);
eval_array_ansi_mode!(
Decimal256Array::from(vec![Some(i256::MIN), None])
.with_precision_and_scale(5, 2)
.unwrap()
);
}
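    // Scalar-path sketch: exercises the `scalar_compute_op!` arms directly,
    // using the same overflow expectations as the array tests above.
    #[test]
    fn test_abs_scalar() {
        // Legacy mode wraps: abs(i32::MIN) stays i32::MIN.
        match spark_abs(
            &[ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)))],
            false,
        ) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(v)))) => {
                assert_eq!(v, i32::MIN)
            }
            _ => unreachable!(),
        }
        // ANSI mode surfaces the overflow as an error.
        match spark_abs(
            &[ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)))],
            true,
        ) {
            Err(e) => assert!(e.to_string().contains("overflow on abs")),
            _ => unreachable!(),
        }
        // Ordinary negative values are negated in either mode.
        match spark_abs(
            &[ColumnarValue::Scalar(ScalarValue::Int64(Some(-42)))],
            true,
        ) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => assert_eq!(v, 42),
            _ => unreachable!(),
        }
    }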
#[test]
fn test_abs_nullability() {
let abs = SparkAbs::new();
let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false));
let out_non_null = abs
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&non_nullable_i32)],
scalar_arguments: &[None],
})
.unwrap();
assert!(!out_non_null.is_nullable());
assert_eq!(out_non_null.data_type(), &DataType::Int32);
let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true));
let out_nullable = abs
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&nullable_i32)],
scalar_arguments: &[None],
})
.unwrap();
assert!(out_nullable.is_nullable());
assert_eq!(out_nullable.data_type(), &DataType::Int32);
let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false));
let out_f64 = abs
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&non_nullable_f64)],
scalar_arguments: &[None],
})
.unwrap();
assert!(!out_f64.is_nullable());
assert_eq!(out_f64.data_type(), &DataType::Float64);
let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true));
let out_f64_null = abs
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&nullable_f64)],
scalar_arguments: &[None],
})
.unwrap();
assert!(out_f64_null.is_nullable());
assert_eq!(out_f64_null.data_type(), &DataType::Float64);
}
}