use crate::{
array::PrimitiveArray,
compute::{
arithmetics::{ArrayCheckedSub, ArraySaturatingSub, ArraySub},
arity::{binary, binary_checked},
utils::{check_same_len, combine_validities},
},
datatypes::DataType,
error::{Error, Result},
};
use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
/// Subtracts two decimal arrays element-wise (`lhs - rhs`).
///
/// The output keeps `lhs`'s data type. The precision used for the range
/// check comes from `get_parameters` on the two input data types.
///
/// # Panics
/// - if `get_parameters` returns an error for the pair of data types;
/// - if any difference overflows `i128` or its magnitude exceeds the
///   maximum value for the precision (`max_value(precision)`).
pub fn sub(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
    let max = max_value(precision);

    let op = move |a: i128, b: i128| -> i128 {
        // `checked_sub` avoids silent wrap-around in release builds. A range
        // test (instead of `abs()`) avoids overflowing on `i128::MIN`.
        match a.checked_sub(b) {
            Some(res) if (-max..=max).contains(&res) => res,
            _ => panic!("Overflow in subtract presented for precision {precision}"),
        }
    };

    binary(lhs, rhs, lhs.data_type().clone(), op)
}
/// Subtracts two decimal arrays element-wise (`lhs - rhs`), saturating at the
/// precision bounds.
///
/// Each result is clamped into `[-max, max]`, where `max` is
/// `max_value(precision)` and the precision comes from `get_parameters`.
/// The output keeps `lhs`'s data type.
///
/// # Panics
/// If `get_parameters` returns an error for the pair of data types.
pub fn saturating_sub(
    lhs: &PrimitiveArray<i128>,
    rhs: &PrimitiveArray<i128>,
) -> PrimitiveArray<i128> {
    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
    let max = max_value(precision);

    let op = move |a: i128, b: i128| -> i128 {
        match a.checked_sub(b) {
            Some(res) => res.clamp(-max, max),
            // The exact difference overflowed i128. Its sign is still known
            // (positive iff a > b), so saturate to the matching bound rather
            // than trusting a wrapped value.
            None if a > b => max,
            None => -max,
        }
    };

    binary(lhs, rhs, lhs.data_type().clone(), op)
}
impl ArraySub<PrimitiveArray<i128>> for PrimitiveArray<i128> {
fn sub(&self, rhs: &PrimitiveArray<i128>) -> Self {
sub(self, rhs)
}
}
impl ArrayCheckedSub<PrimitiveArray<i128>> for PrimitiveArray<i128> {
fn checked_sub(&self, rhs: &PrimitiveArray<i128>) -> Self {
checked_sub(self, rhs)
}
}
impl ArraySaturatingSub<PrimitiveArray<i128>> for PrimitiveArray<i128> {
fn saturating_sub(&self, rhs: &PrimitiveArray<i128>) -> Self {
saturating_sub(self, rhs)
}
}
/// Subtracts two decimal arrays element-wise (`lhs - rhs`), producing `None`
/// (passed to `binary_checked`) for any result that does not fit.
///
/// A result "does not fit" when the `i128` subtraction overflows, or when its
/// magnitude exceeds `max_value(precision)`. The precision comes from
/// `get_parameters`. The output keeps `lhs`'s data type.
///
/// # Panics
/// If `get_parameters` returns an error for the pair of data types.
pub fn checked_sub(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
    let max = max_value(precision);

    // `checked_sub` prevents a wrapped (wrong) value from slipping back into
    // range in release builds. The range test avoids `abs()` on `i128::MIN`.
    let op = move |a: i128, b: i128| a.checked_sub(b).filter(|res| (-max..=max).contains(res));

    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
}
pub fn adaptive_sub(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> Result<PrimitiveArray<i128>> {
check_same_len(lhs, rhs)?;
let (lhs_p, lhs_s, rhs_p, rhs_s) =
if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.data_type(), rhs.data_type())
{
(*lhs_p, *lhs_s, *rhs_p, *rhs_s)
} else {
return Err(Error::InvalidArgumentError(
"Incorrect data type for the array".to_string(),
));
};
let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s);
let shift = 10i128.pow(diff as u32);
let mut max = max_value(res_p);
let values = lhs
.values()
.iter()
.zip(rhs.values().iter())
.map(|(l, r)| {
let res: i128 = if lhs_s > rhs_s {
l - r * shift
} else {
l * shift - r
};
if res.abs() > max {
res_p = number_digits(res);
max = max_value(res_p);
}
res
})
.collect::<Vec<_>>();
let validity = combine_validities(lhs.validity(), rhs.validity());
Ok(PrimitiveArray::<i128>::new(
DataType::Decimal(res_p, res_s),
values.into(),
validity,
))
}