use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrow::Datum;
use crate::arrow::from_arrow_array_with_len;
use crate::scalar::NumericOperator;
/// Evaluates the numeric operation `op` over `lhs` and `rhs`.
///
/// The constant-folding path ([`constant_numeric`]) is tried first; when it
/// declines by returning `None`, the work is delegated to [`arrow_numeric`].
///
/// # Errors
///
/// Propagates any error produced by either of the two paths.
pub(crate) fn execute_numeric(
    lhs: &ArrayRef,
    rhs: &ArrayRef,
    op: NumericOperator,
) -> VortexResult<ArrayRef> {
    match constant_numeric(lhs, rhs, op)? {
        // Both operands were constants and the scalar computation succeeded.
        Some(folded) => Ok(folded),
        // Otherwise compute the result with the Arrow kernels.
        None => arrow_numeric(lhs, rhs, op),
    }
}
/// Computes `lhs <operator> rhs` with the `arrow_arith::numeric` kernels.
///
/// Both operands are wrapped as Arrow [`Datum`]s; the right-hand side is built
/// with the left-hand side's Arrow data type as its target. The kernel output
/// is converted back with `lhs.len()` as the length, and is marked nullable
/// when either input dtype is nullable.
///
/// # Errors
///
/// Returns an error if either operand cannot be wrapped as a [`Datum`], if the
/// Arrow kernel fails, or if the kernel output cannot be converted back.
pub(crate) fn arrow_numeric(
    lhs: &ArrayRef,
    rhs: &ArrayRef,
    operator: NumericOperator,
) -> VortexResult<ArrayRef> {
    let lhs_datum = Datum::try_new(lhs)?;
    let rhs_datum = Datum::try_new_with_target_datatype(rhs, lhs_datum.data_type())?;

    // Select the kernel first, then propagate its error (if any) in one place.
    let computed = match operator {
        NumericOperator::Add => arrow_arith::numeric::add(&lhs_datum, &rhs_datum),
        NumericOperator::Sub => arrow_arith::numeric::sub(&lhs_datum, &rhs_datum),
        NumericOperator::Mul => arrow_arith::numeric::mul(&lhs_datum, &rhs_datum),
        NumericOperator::Div => arrow_arith::numeric::div(&lhs_datum, &rhs_datum),
    }?;

    // NOTE(review): the length is taken from `lhs` only; presumably callers
    // guarantee that `rhs` has the same length — confirm.
    let row_count = lhs.len();
    let result_nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
    from_arrow_array_with_len(computed.as_ref(), row_count, result_nullable)
}
/// Attempts to evaluate `lhs <op> rhs` directly on scalars when both operands
/// are [`Constant`] arrays.
///
/// Returns `Ok(None)` when either operand is not a [`Constant`] array, or when
/// `checked_binary_numeric` produces no value (presumably overflow or a
/// similar arithmetic failure — confirm against its definition). On success the
/// result is a [`ConstantArray`] with the same length as `lhs`.
///
/// NOTE(review): both scalars are viewed through `as_primitive()`; how that
/// behaves for non-primitive scalars is not visible here — confirm.
fn constant_numeric(
    lhs: &ArrayRef,
    rhs: &ArrayRef,
    op: NumericOperator,
) -> VortexResult<Option<ArrayRef>> {
    let Some(left_const) = lhs.as_opt::<Constant>() else {
        return Ok(None);
    };
    let Some(right_const) = rhs.as_opt::<Constant>() else {
        return Ok(None);
    };

    let folded = left_const
        .scalar()
        .as_primitive()
        .checked_binary_numeric(&right_const.scalar().as_primitive(), op);

    // `None` here tells the caller to take its fallback path.
    Ok(folded.map(|value| ConstantArray::new(value, left_const.len()).into_array()))
}
#[cfg(test)]
// NOTE(review): presumably required because `LEGACY_SESSION` is deprecated — confirm.
#[allow(deprecated)]
mod test {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::numeric::ConstantArray;
    use crate::scalar_fn::fns::operators::Operator;

    /// Builds `array - scalar` (with `scalar` broadcast as a constant array of
    /// the same length), executes it to a `RecursiveCanonical`, and returns the
    /// canonical result as an array.
    fn sub_scalar(array: &ArrayRef, scalar: impl Into<Scalar>) -> VortexResult<ArrayRef> {
        let subtrahend = ConstantArray::new(scalar, array.len()).into_array();
        let expression = array.binary(subtrahend, Operator::Sub)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical = expression.execute::<RecursiveCanonical>(&mut ctx)?;
        Ok(canonical.0.into_array())
    }

    /// Unsigned subtraction that stays within range.
    #[test]
    fn test_scalar_subtract_unsigned() -> VortexResult<()> {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16)?;
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([0u16, 1, 2]));
        Ok(())
    }

    /// Subtracting a negative signed scalar increases each element.
    #[test]
    fn test_scalar_subtract_signed() -> VortexResult<()> {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, -1i64)?;
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2i64, 3, 4]));
        Ok(())
    }

    /// Null elements stay null in the output.
    #[test]
    fn test_scalar_subtract_nullable() -> VortexResult<()> {
        let input =
            PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]).into_array();
        let difference = sub_scalar(&input, Some(1u16))?;
        assert_arrays_eq!(
            difference,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
        Ok(())
    }

    /// Floating-point subtraction of a negative scalar.
    #[test]
    fn test_scalar_subtract_float() -> VortexResult<()> {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let difference = sub_scalar(&input, -1f64)?;
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
        Ok(())
    }

    /// Subtracting from `f32::MIN` must not produce an error; only success is
    /// checked, not the resulting values.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() -> VortexResult<()> {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        let _minus_one = sub_scalar(&input, 1.0f32)?;
        let _minus_max = sub_scalar(&input, f32::MAX)?;
        Ok(())
    }
}