use itertools::Itertools;
use num_traits::Num;
use vortex_error::VortexExpect;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use crate::ArrayRef;
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::RecursiveCanonical;
use crate::ToCanonical;
use crate::VortexSessionExecute;
use crate::arrays::ConstantArray;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::NativePType;
use crate::dtype::PType;
use crate::scalar::NumericOperator;
use crate::scalar::PrimitiveScalar;
use crate::scalar::Scalar;
/// Reads every element of `array` via `scalar_at`, in index order, into a `Vec<Scalar>`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if `scalar_at` fails for any index.
fn to_vec_of_scalar(array: &ArrayRef) -> Vec<Scalar> {
    let element_at = |position: usize| {
        array
            .scalar_at(position)
            .vortex_expect("scalar_at should succeed in conformance test")
    };
    (0..array.len()).map(element_at).collect_vec()
}
/// Runs the complete binary-numeric conformance suite for an array whose native type is `T`.
///
/// First the standard checks against the constant `1` are run, then the per-type
/// edge-case checks.
fn test_binary_numeric_conformance<T>(array: ArrayRef)
where
    T: NativePType + Num + Copy,
    Scalar: From<T>,
{
    let edge_case_input = array.clone();
    test_standard_binary_numeric::<T>(array);
    test_binary_numeric_edge_cases(edge_case_input);
}
/// Checks `array <op> 1` and `1 <op> array` for `Add`, `Sub`, `Mul` and `Div` against a
/// per-element reference computed with [`PrimitiveScalar::checked_binary_numeric`].
///
/// The constant `1` is converted to `T` and cast to `array`'s dtype. The two operand orders of
/// each operator are checked independently: if executing one of them returns an error, only
/// that order is skipped. Indices for which the reference computation returns `None` are not
/// compared.
///
/// # Panics
///
/// Panics if `1` cannot be converted to `T` or cast to `array`'s dtype, if building
/// `array <op> 1` fails, if a result's length differs from `array`'s, or if any compared
/// element differs from the reference value.
fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
where
    Scalar: From<T>,
{
    let canonicalized_array = array.to_primitive();
    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
    let one = T::from(1)
        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
        .vortex_expect("operation should succeed in conformance test");
    let scalar_one = Scalar::from(one)
        .cast(array.dtype())
        .vortex_expect("operation should succeed in conformance test");
    let operators: [NumericOperator; 4] = [
        NumericOperator::Add,
        NumericOperator::Sub,
        NumericOperator::Mul,
        NumericOperator::Div,
    ];
    for operator in operators {
        let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();

        // `array <op> 1`: building the expression must succeed. An execution error only skips
        // this order; it must not also skip the `1 <op> array` check below.
        let result = array
            .binary(rhs_const.clone(), operator.into())
            .vortex_expect("apply shouldn't fail")
            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
            .map(|c| c.0.into_array());
        if let Ok(result) = result {
            let actual_values = to_vec_of_scalar(&result);
            // `zip` below stops at the shorter side, so check the length explicitly; otherwise
            // a truncated (e.g. empty) result would pass without comparing anything.
            assert_eq!(
                actual_values.len(),
                original_values.len(),
                "Binary numeric operation returned wrong length for encoding {}: \
                 ({array:?}) {operator:?} {scalar_one}",
                array.encoding_id(),
            );
            let expected_results: Vec<Option<Scalar>> = original_values
                .iter()
                .map(|x| {
                    x.as_primitive()
                        .checked_binary_numeric(&scalar_one.as_primitive(), operator)
                        .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
                })
                .collect();
            for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate()
            {
                // `None` means the reference computation itself failed; nothing to compare.
                if let Some(expected_value) = expected {
                    assert_eq!(
                        actual,
                        expected_value,
                        "Binary numeric operation failed for encoding {} at index {}: \
                         ({array:?})[{idx}] {operator:?} {scalar_one} \
                         expected {expected_value:?}, got {actual:?}",
                        array.encoding_id(),
                        idx,
                    );
                }
            }
        }

        // `1 <op> array`: an error from either building or executing the expression skips
        // this order.
        let result = rhs_const.binary(array.clone(), operator.into()).and_then(|a| {
            a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
                .map(|c| c.0.into_array())
        });
        if let Ok(result) = result {
            let actual_values = to_vec_of_scalar(&result);
            assert_eq!(
                actual_values.len(),
                original_values.len(),
                "Binary numeric operation returned wrong length for encoding {}: \
                 {scalar_one} {operator:?} ({array:?})",
                array.encoding_id(),
            );
            let expected_results: Vec<Option<Scalar>> = original_values
                .iter()
                .map(|x| {
                    scalar_one
                        .as_primitive()
                        .checked_binary_numeric(&x.as_primitive(), operator)
                        .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
                })
                .collect();
            for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate()
            {
                if let Some(expected_value) = expected {
                    assert_eq!(
                        actual,
                        expected_value,
                        "Binary numeric operation failed for encoding {} at index {}: \
                         {scalar_one} {operator:?} ({array:?})[{idx}] \
                         expected {expected_value:?}, got {actual:?}",
                        array.encoding_id(),
                        idx,
                    );
                }
            }
        }
    }
}
/// Runs the binary numeric conformance tests on `array`, choosing the native element type
/// from its primitive dtype.
///
/// `f16` arrays are not tested; a notice is printed to stderr instead.
///
/// # Panics
///
/// Panics if `array`'s dtype is not `DType::Primitive`, or if any conformance check fails.
pub fn test_binary_numeric_array(array: ArrayRef) {
    let dtype = array.dtype();
    let DType::Primitive(ptype, _) = dtype else {
        vortex_panic!(
            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
        )
    };
    match ptype {
        PType::I8 => test_binary_numeric_conformance::<i8>(array),
        PType::I16 => test_binary_numeric_conformance::<i16>(array),
        PType::I32 => test_binary_numeric_conformance::<i32>(array),
        PType::I64 => test_binary_numeric_conformance::<i64>(array),
        PType::U8 => test_binary_numeric_conformance::<u8>(array),
        PType::U16 => test_binary_numeric_conformance::<u16>(array),
        PType::U32 => test_binary_numeric_conformance::<u32>(array),
        PType::U64 => test_binary_numeric_conformance::<u64>(array),
        PType::F16 => {
            eprintln!("Skipping f16 binary numeric tests (not supported)");
        }
        PType::F32 => test_binary_numeric_conformance::<f32>(array),
        PType::F64 => test_binary_numeric_conformance::<f64>(array),
    }
}
/// Dispatches `array` to the signed, unsigned or float edge-case suite that matches its
/// primitive type.
///
/// `f16` arrays are skipped, with a notice printed to stderr.
///
/// # Panics
///
/// Panics if `array`'s dtype is not `DType::Primitive`, or if any edge-case check fails.
fn test_binary_numeric_edge_cases(array: ArrayRef) {
    let dtype = array.dtype();
    let DType::Primitive(ptype, _) = dtype else {
        vortex_panic!(
            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
        )
    };
    match ptype {
        PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
        PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
        PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
        PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
        PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
        PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
        PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
        PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
        PType::F16 => {
            eprintln!("Skipping f16 edge case tests (not supported)");
        }
        PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
        PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
    }
}
/// Runs `array <op> value` checks for each signed edge value, in this order:
/// `0`, `-1`, `T::max_value()` and `T::min_value()`.
fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
where
    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
    Scalar: From<T>,
{
    let edge_values = [T::zero(), -T::one(), T::max_value(), T::min_value()];
    for value in edge_values {
        test_binary_numeric_with_scalar(array.clone(), value);
    }
}
/// Runs `array <op> value` checks for each unsigned edge value, in this order:
/// `0` and `T::max_value()`.
fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
where
    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
    Scalar: From<T>,
{
    let edge_values = [T::zero(), T::max_value()];
    for value in edge_values {
        test_binary_numeric_with_scalar(array.clone(), value);
    }
}
/// Runs `array <op> value` checks for each floating-point edge value, in this order:
/// `0`, `-1`, the largest and smallest finite values, epsilon, the smallest positive
/// normal value, NaN, `+inf` and `-inf`.
fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
where
    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
    Scalar: From<T>,
{
    let edge_values = [
        T::zero(),
        -T::one(),
        T::max_value(),
        T::min_value(),
        T::epsilon(),
        T::min_positive_value(),
        T::nan(),
        T::infinity(),
        T::neg_infinity(),
    ];
    for value in edge_values {
        test_binary_numeric_with_scalar(array.clone(), value);
    }
}
/// Checks `array <op> scalar_value` against a per-element reference computed with
/// [`PrimitiveScalar::checked_binary_numeric`].
///
/// The operators checked are `Add`, `Sub`, `Mul`, and also `Div` unless `scalar_value`
/// is zero. `scalar_value` is cast to `array`'s dtype and used as the right-hand operand,
/// wrapped in a [`ConstantArray`] of `array.len()` elements. Operators whose execution
/// returns an error are skipped. Indices for which the reference computation returns
/// `None` are not compared.
///
/// # Panics
///
/// Panics if `scalar_value` cannot be cast to `array`'s dtype, if building the expression
/// fails, if a result's length differs from `array`'s, or if any compared element differs
/// from the reference value.
fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
where
    T: NativePType + Num + Copy + std::fmt::Debug,
    Scalar: From<T>,
{
    let canonicalized_array = array.to_primitive();
    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
    let scalar = Scalar::from(scalar_value)
        .cast(array.dtype())
        .vortex_expect("operation should succeed in conformance test");
    // Division by a zero scalar is deliberately not exercised.
    let operators = if scalar_value == T::zero() {
        vec![
            NumericOperator::Add,
            NumericOperator::Sub,
            NumericOperator::Mul,
        ]
    } else {
        vec![
            NumericOperator::Add,
            NumericOperator::Sub,
            NumericOperator::Mul,
            NumericOperator::Div,
        ]
    };
    for operator in operators {
        let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
        let result = array
            .binary(rhs_const, operator.into())
            .vortex_expect("apply failed")
            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
            .map(|x| x.0.into_array());
        // An execution error skips this operator for this encoding.
        let Ok(result) = result else {
            continue;
        };
        let actual_values = to_vec_of_scalar(&result);
        // `zip` below stops at the shorter side, so check the length explicitly; otherwise a
        // truncated (e.g. empty) result would pass without comparing anything.
        assert_eq!(
            actual_values.len(),
            original_values.len(),
            "Binary numeric operation returned wrong length for encoding {} with scalar {:?}: \
             ({array:?}) {operator:?} {scalar}",
            array.encoding_id(),
            scalar_value,
        );
        let expected_results: Vec<Option<Scalar>> = original_values
            .iter()
            .map(|x| {
                x.as_primitive()
                    .checked_binary_numeric(&scalar.as_primitive(), operator)
                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
            })
            .collect();
        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
            // `None` means the reference computation itself failed; nothing to compare.
            if let Some(expected_value) = expected {
                assert_eq!(
                    actual,
                    expected_value,
                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
                     ({array:?})[{idx}] {operator:?} {scalar} \
                     expected {expected_value:?}, got {actual:?}",
                    array.encoding_id(),
                    idx,
                    scalar_value,
                );
            }
        }
    }
}