use arrow::array::{Scalar, new_null_array};
use arrow::compute::kernels::numeric::add;
use arrow::compute::kernels::{
cmp::{eq, lt},
numeric::rem,
zip::zip,
};
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
fn try_rem(
left: &arrow::array::ArrayRef,
right: &arrow::array::ArrayRef,
enable_ansi_mode: bool,
) -> Result<arrow::array::ArrayRef> {
match rem(left, right) {
Ok(result) => Ok(result),
Err(arrow::error::ArrowError::DivideByZero) if !enable_ansi_mode => {
let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
let zero = Scalar::new(zero);
let null = Scalar::new(new_null_array(right.data_type(), 1));
let is_zero = eq(right, &zero)?;
let safe_right = zip(&is_zero, &null, right)?;
Ok(rem(left, &safe_right)?)
}
Err(e) => Err(e.into()),
}
}
/// Evaluates `mod(left, right)` over two columnar arguments.
///
/// Both arguments are materialised as arrays and the remainder is computed
/// with [`try_rem`]; the result is always returned as
/// `ColumnarValue::Array`, even when an input was a scalar.
///
/// # Errors
///
/// Returns an internal error unless exactly two arguments are supplied, and
/// propagates any failure from array conversion or from [`try_rem`] (which
/// includes division by zero when `enable_ansi_mode` is `true`).
pub fn spark_mod(
    args: &[ColumnarValue],
    enable_ansi_mode: bool,
) -> Result<ColumnarValue> {
    assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
    let arrays = ColumnarValue::values_to_arrays(args)?;
    let (dividend, divisor) = (&arrays[0], &arrays[1]);
    try_rem(dividend, divisor, enable_ansi_mode).map(ColumnarValue::Array)
}
pub fn spark_pmod(
args: &[ColumnarValue],
enable_ansi_mode: bool,
) -> Result<ColumnarValue> {
assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
let args = ColumnarValue::values_to_arrays(args)?;
let left = &args[0];
let right = &args[1];
let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
let result = try_rem(left, right, enable_ansi_mode)?;
let neg = lt(&result, &zero)?;
let plus = zip(&neg, right, &zero)?;
let result = add(&plus, &result)?;
let result = try_rem(&result, right, enable_ansi_mode)?;
Ok(ColumnarValue::Array(result))
}
/// Scalar UDF wrapper around [`spark_mod`], exposed under the function name `mod`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkMod {
/// Argument signature handed out by `ScalarUDFImpl::signature`.
signature: Signature,
}
impl Default for SparkMod {
fn default() -> Self {
Self::new()
}
}
impl SparkMod {
    /// Creates the UDF with a signature of two numeric arguments and
    /// immutable volatility.
    pub fn new() -> Self {
        let signature = Signature::numeric(2, Volatility::Immutable);
        Self { signature }
    }
}
/// Exposes [`SparkMod`] through the `ScalarUDFImpl` trait under the name `mod`.
impl ScalarUDFImpl for SparkMod {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Function name reported for this UDF.
    fn name(&self) -> &str {
        "mod"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result carries the data type of the first argument (the dividend).
    ///
    /// Returns an internal error unless exactly two argument types are given.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        assert_eq_or_internal_err!(
            arg_types.len(),
            2,
            "mod expects exactly two arguments"
        );
        let dividend_type = &arg_types[0];
        Ok(dividend_type.clone())
    }

    /// Delegates to [`spark_mod`], taking the ANSI flag from the execution
    /// options carried by `args`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let ansi_mode = args.config_options.execution.enable_ansi_mode;
        spark_mod(&args.args, ansi_mode)
    }
}
/// Scalar UDF wrapper around [`spark_pmod`], exposed under the function name `pmod`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkPmod {
/// Argument signature handed out by `ScalarUDFImpl::signature`.
signature: Signature,
}
impl Default for SparkPmod {
fn default() -> Self {
Self::new()
}
}
impl SparkPmod {
    /// Creates the UDF with a signature of two numeric arguments and
    /// immutable volatility.
    pub fn new() -> Self {
        let signature = Signature::numeric(2, Volatility::Immutable);
        Self { signature }
    }
}
/// Exposes [`SparkPmod`] through the `ScalarUDFImpl` trait under the name `pmod`.
impl ScalarUDFImpl for SparkPmod {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Function name reported for this UDF.
    fn name(&self) -> &str {
        "pmod"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result carries the data type of the first argument (the dividend).
    ///
    /// Returns an internal error unless exactly two argument types are given.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        assert_eq_or_internal_err!(
            arg_types.len(),
            2,
            "pmod expects exactly two arguments"
        );
        let dividend_type = &arg_types[0];
        Ok(dividend_type.clone())
    }

    /// Delegates to [`spark_pmod`], taking the ANSI flag from the execution
    /// options carried by `args`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let ansi_mode = args.config_options.execution.enable_ansi_mode;
        spark_pmod(&args.args, ansi_mode)
    }
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use super::*;
use arrow::array::*;
use datafusion_common::ScalarValue;
#[test]
fn test_mod_int32() {
    // The last row has a NULL dividend.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(7),
            Some(15),
            None,
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(3),
            Some(2),
            Some(4),
            Some(5),
        ]))),
    ];
    let ColumnarValue::Array(result_array) = spark_mod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let remainders: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // 10 % 3 = 1, 7 % 2 = 1, 15 % 4 = 3
    for (row, expected) in [1, 1, 3].into_iter().enumerate() {
        assert_eq!(remainders.value(row), expected);
    }
    // NULL % 5 = NULL
    assert!(remainders.is_null(3));
}
#[test]
fn test_mod_int64() {
    // One (dividend, divisor) pair per row.
    let rows: [(i64, i64); 3] = [(100, 30), (50, 25), (200, 60)];
    let (dividends, divisors): (Vec<Option<i64>>, Vec<Option<i64>>) =
        rows.iter().map(|&(l, r)| (Some(l), Some(r))).unzip();
    let inputs = [
        ColumnarValue::Array(Arc::new(Int64Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Int64Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_mod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let remainders: &Int64Array = result_array.as_any().downcast_ref().unwrap();
    // 100 % 30 = 10, 50 % 25 = 0, 200 % 60 = 20
    for (row, expected) in [10, 0, 20].into_iter().enumerate() {
        assert_eq!(remainders.value(row), expected);
    }
}
#[test]
fn test_mod_float64() {
    // One (dividend, divisor) pair per row, covering NaN and infinity mixes.
    let rows: [(f64, f64); 9] = [
        (10.5, 3.0),
        (7.2, 2.5),
        (15.8, 4.2),
        (f64::NAN, 2.0),
        (f64::INFINITY, 2.0),
        (5.0, f64::NAN),
        (5.0, f64::INFINITY),
        (f64::NAN, f64::INFINITY),
        (f64::INFINITY, f64::NAN),
    ];
    let (dividends, divisors): (Vec<Option<f64>>, Vec<Option<f64>>) =
        rows.iter().map(|&(l, r)| (Some(l), Some(r))).unzip();
    let inputs = [
        ColumnarValue::Array(Arc::new(Float64Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Float64Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_mod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let remainders: &Float64Array = result_array.as_any().downcast_ref().unwrap();
    // `Some(x)` rows must lie within `f64::EPSILON` of `x`; `None` rows must be NaN.
    let expectations: [Option<f64>; 9] = [
        Some(1.5),
        Some(2.2),
        Some(3.2),
        None,
        None,
        None,
        Some(5.0),
        None,
        None,
    ];
    for (row, expectation) in expectations.into_iter().enumerate() {
        let actual = remainders.value(row);
        match expectation {
            Some(expected) => assert!((actual - expected).abs() < f64::EPSILON),
            None => assert!(actual.is_nan()),
        }
    }
}
#[test]
fn test_mod_float32() {
    // One (dividend, divisor) pair per row, covering NaN and infinity mixes.
    let rows: [(f32, f32); 9] = [
        (10.5, 3.0),
        (7.2, 2.5),
        (15.8, 4.2),
        (f32::NAN, 2.0),
        (f32::INFINITY, 2.0),
        (5.0, f32::NAN),
        (5.0, f32::INFINITY),
        (f32::NAN, f32::INFINITY),
        (f32::INFINITY, f32::NAN),
    ];
    let (dividends, divisors): (Vec<Option<f32>>, Vec<Option<f32>>) =
        rows.iter().map(|&(l, r)| (Some(l), Some(r))).unzip();
    let inputs = [
        ColumnarValue::Array(Arc::new(Float32Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Float32Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_mod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let remainders: &Float32Array = result_array.as_any().downcast_ref().unwrap();
    // `Some((x, tol))` rows must lie within `tol` of `x`; `None` rows must be NaN.
    let expectations: [Option<(f32, f32)>; 9] = [
        Some((1.5, f32::EPSILON)),
        Some((2.2, f32::EPSILON * 3.0)),
        Some((3.2, f32::EPSILON * 10.0)),
        None,
        None,
        None,
        Some((5.0, f32::EPSILON)),
        None,
        None,
    ];
    for (row, expectation) in expectations.into_iter().enumerate() {
        let actual = remainders.value(row);
        match expectation {
            Some((expected, tolerance)) => {
                assert!((actual - expected).abs() < tolerance)
            }
            None => assert!(actual.is_nan()),
        }
    }
}
#[test]
fn test_mod_scalar() {
    // Array dividend against a scalar divisor of 3.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(7),
            Some(15),
        ]))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
    ];
    let ColumnarValue::Array(result_array) = spark_mod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let remainders: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // 10 % 3 = 1, 7 % 3 = 1, 15 % 3 = 0
    for (row, expected) in [1, 1, 0].into_iter().enumerate() {
        assert_eq!(remainders.value(row), expected);
    }
}
#[test]
fn test_mod_wrong_arg_count() {
    // A single argument must be rejected.
    let only_arg = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(10)])));
    let outcome = spark_mod(&[only_arg], false);
    assert!(outcome.is_err());
}
#[test]
fn test_mod_zero_division_legacy() {
    // ANSI mode off: the zero divisor in row 0 must produce NULL, not an error.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(7),
            Some(15),
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(0),
            Some(2),
            Some(4),
        ]))),
    ];
    let ColumnarValue::Array(result_array) = spark_mod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let remainders: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // 10 % 0 = NULL
    assert!(remainders.is_null(0));
    // 7 % 2 = 1
    assert_eq!(remainders.value(1), 1);
    // 15 % 4 = 3
    assert_eq!(remainders.value(2), 3);
}
#[test]
fn test_mod_zero_division_ansi() {
    // ANSI mode on: the zero divisor in row 0 must fail the whole call.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(7),
            Some(15),
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(0),
            Some(2),
            Some(4),
        ]))),
    ];
    let outcome = spark_mod(&inputs, true);
    assert!(outcome.is_err());
}
#[test]
fn test_pmod_int32() {
    // Mix of positive and negative dividends; the last row has a NULL dividend.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
            Some(-15),
            None,
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(3),
            Some(3),
            Some(4),
            Some(4),
            Some(5),
        ]))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // pmod(10, 3) = 1, pmod(-7, 3) = 2, pmod(15, 4) = 3, pmod(-15, 4) = 1
    for (row, expected) in [1, 2, 3, 1].into_iter().enumerate() {
        assert_eq!(moduli.value(row), expected);
    }
    // pmod(NULL, 5) = NULL
    assert!(moduli.is_null(4));
}
#[test]
fn test_pmod_int64() {
    // One (dividend, divisor) pair per row.
    let rows: [(i64, i64); 4] = [(100, 30), (-50, 30), (200, 60), (-200, 60)];
    let (dividends, divisors): (Vec<Option<i64>>, Vec<Option<i64>>) =
        rows.iter().map(|&(l, r)| (Some(l), Some(r))).unzip();
    let inputs = [
        ColumnarValue::Array(Arc::new(Int64Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Int64Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Int64Array = result_array.as_any().downcast_ref().unwrap();
    // pmod(100, 30) = 10, pmod(-50, 30) = 10, pmod(200, 60) = 20, pmod(-200, 60) = 40
    for (row, expected) in [10, 10, 20, 40].into_iter().enumerate() {
        assert_eq!(moduli.value(row), expected);
    }
}
#[test]
fn test_pmod_float64() {
    // One (dividend, divisor) pair per row, covering NaN and infinity mixes.
    let rows: [(f64, f64); 8] = [
        (10.5, 3.0),
        (-7.2, 3.0),
        (15.8, 4.2),
        (-15.8, 4.2),
        (f64::NAN, 2.0),
        (f64::INFINITY, 2.0),
        (5.0, f64::INFINITY),
        (-5.0, f64::INFINITY),
    ];
    let (dividends, divisors): (Vec<Option<f64>>, Vec<Option<f64>>) =
        rows.iter().map(|&(l, r)| (Some(l), Some(r))).unzip();
    let inputs = [
        ColumnarValue::Array(Arc::new(Float64Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Float64Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Float64Array = result_array.as_any().downcast_ref().unwrap();
    // `Some((x, tol))` rows must lie within `tol` of `x`; `None` rows must be NaN.
    let expectations: [Option<(f64, f64)>; 8] = [
        Some((1.5, f64::EPSILON)),
        Some((1.8, f64::EPSILON * 3.0)),
        Some((3.2, f64::EPSILON * 3.0)),
        Some((1.0, f64::EPSILON * 3.0)),
        None,
        None,
        Some((5.0, f64::EPSILON)),
        None,
    ];
    for (row, expectation) in expectations.into_iter().enumerate() {
        let actual = moduli.value(row);
        match expectation {
            Some((expected, tolerance)) => {
                assert!((actual - expected).abs() < tolerance)
            }
            None => assert!(actual.is_nan()),
        }
    }
}
#[test]
fn test_pmod_float32() {
    // One (dividend, divisor) pair per row, covering NaN and infinity mixes.
    let rows: [(f32, f32); 8] = [
        (10.5, 3.0),
        (-7.2, 3.0),
        (15.8, 4.2),
        (-15.8, 4.2),
        (f32::NAN, 2.0),
        (f32::INFINITY, 2.0),
        (5.0, f32::INFINITY),
        (-5.0, f32::INFINITY),
    ];
    let (dividends, divisors): (Vec<Option<f32>>, Vec<Option<f32>>) =
        rows.iter().map(|&(l, r)| (Some(l), Some(r))).unzip();
    let inputs = [
        ColumnarValue::Array(Arc::new(Float32Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Float32Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Float32Array = result_array.as_any().downcast_ref().unwrap();
    // `Some((x, tol))` rows must lie within `tol` of `x`; `None` rows must be NaN.
    let expectations: [Option<(f32, f32)>; 8] = [
        Some((1.5, f32::EPSILON)),
        Some((1.8, f32::EPSILON * 3.0)),
        Some((3.2, f32::EPSILON * 10.0)),
        Some((1.0, f32::EPSILON * 10.0)),
        None,
        None,
        Some((5.0, f32::EPSILON * 10.0)),
        None,
    ];
    for (row, expectation) in expectations.into_iter().enumerate() {
        let actual = moduli.value(row);
        match expectation {
            Some((expected, tolerance)) => {
                assert!((actual - expected).abs() < tolerance)
            }
            None => assert!(actual.is_nan()),
        }
    }
}
#[test]
fn test_pmod_scalar() {
    // Array dividend against a scalar divisor of 3.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
            Some(-15),
        ]))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // pmod(10, 3) = 1, pmod(-7, 3) = 2, pmod(15, 3) = 0, pmod(-15, 3) = 0
    for (row, expected) in [1, 2, 0, 0].into_iter().enumerate() {
        assert_eq!(moduli.value(row), expected);
    }
}
#[test]
fn test_pmod_wrong_arg_count() {
    // A single argument must be rejected.
    let only_arg = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(10)])));
    let outcome = spark_pmod(&[only_arg], false);
    assert!(outcome.is_err());
}
#[test]
fn test_pmod_zero_division_legacy() {
    // ANSI mode off: zero divisors in rows 0 and 1 must produce NULL, not an error.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(0),
            Some(0),
            Some(4),
        ]))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // pmod(10, 0) = NULL
    assert!(moduli.is_null(0));
    // pmod(-7, 0) = NULL
    assert!(moduli.is_null(1));
    // pmod(15, 4) = 3
    assert_eq!(moduli.value(2), 3);
}
#[test]
fn test_pmod_zero_division_ansi() {
    // ANSI mode on: zero divisors must fail the whole call.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(0),
            Some(0),
            Some(4),
        ]))),
    ];
    let outcome = spark_pmod(&inputs, true);
    assert!(outcome.is_err());
}
#[test]
fn test_pmod_negative_divisor() {
    // With a negative divisor the result is allowed to stay negative.
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
        ]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(-3),
            Some(-3),
            Some(-4),
        ]))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // pmod(10, -3) = 1, pmod(-7, -3) = -1, pmod(15, -4) = 3
    for (row, expected) in [1, -1, 3].into_iter().enumerate() {
        assert_eq!(moduli.value(row), expected);
    }
}
#[test]
fn test_pmod_edge_cases() {
    // Dividends around zero and around multiples of the divisor, all modulo 5.
    let dividends: Vec<Option<i32>> = [0, -1, 1, -5, 5, -6, 6].into_iter().map(Some).collect();
    let divisors: Vec<Option<i32>> = vec![Some(5); 7];
    let inputs = [
        ColumnarValue::Array(Arc::new(Int32Array::from(dividends))),
        ColumnarValue::Array(Arc::new(Int32Array::from(divisors))),
    ];
    let ColumnarValue::Array(result_array) = spark_pmod(&inputs, false).unwrap() else {
        panic!("Expected array result");
    };
    let moduli: &Int32Array = result_array.as_any().downcast_ref().unwrap();
    // pmod(0, 5) = 0, pmod(-1, 5) = 4, pmod(1, 5) = 1, pmod(-5, 5) = 0,
    // pmod(5, 5) = 0, pmod(-6, 5) = 4, pmod(6, 5) = 1
    for (row, expected) in [0, 4, 1, 0, 0, 4, 1].into_iter().enumerate() {
        assert_eq!(moduli.value(row), expected);
    }
}
}