use crate::{
array::PrimitiveArray,
compute::{
arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul},
arity::{binary, binary_checked, unary},
utils::{check_same_len, combine_validities},
},
datatypes::DataType,
error::{Error, Result},
scalar::{PrimitiveScalar, Scalar},
};
use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
/// Multiplies two decimal arrays element-wise.
///
/// The common precision and scale are obtained from `get_parameters` (presumably it
/// requires both inputs to be decimals with identical precision/scale — verify
/// against its definition). The output reuses `lhs`'s data type. Each raw `i128`
/// product is divided by `10^scale` to bring it back to the input scale.
///
/// # Panics
/// * if `get_parameters` rejects the two data types (it is `unwrap`ped);
/// * if the raw `i128` product overflows ("Mayor overflow for multiplication");
/// * if a rescaled product falls outside `[-max, max]` for the precision.
pub fn mul(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

    let scale = 10i128.pow(scale as u32);
    let max = max_value(precision);

    let op = move |a: i128, b: i128| {
        let res: i128 = a.checked_mul(b).expect("Mayor overflow for multiplication");
        // Both operands carry `scale` decimal places, so the raw product carries
        // twice as many; one division restores the original scale.
        let res = res / scale;

        // Range check instead of `res.abs() <= max`: `i128::MIN.abs()` overflows
        // (panicking in debug builds, wrapping to a negative value in release builds,
        // which would let an out-of-range value pass the check).
        assert!(
            (-max..=max).contains(&res),
            "Overflow in multiplication presented for precision {precision}"
        );

        res
    };

    binary(lhs, rhs, lhs.data_type().clone(), op)
}
/// Multiplies every element of a decimal array by a decimal scalar.
///
/// The common precision and scale are obtained from `get_parameters`. The output
/// reuses `lhs`'s data type. If the scalar is null, the result is an all-null
/// array of `lhs.len()` elements.
///
/// # Panics
/// * if `get_parameters` rejects the two data types (it is `unwrap`ped);
/// * if a raw `i128` product overflows ("Mayor overflow for multiplication");
/// * if a rescaled product falls outside `[-max, max]` for the precision.
pub fn mul_scalar(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveScalar<i128>) -> PrimitiveArray<i128> {
    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

    // A null scalar makes every product null.
    let rhs = match *rhs.value() {
        Some(rhs) => rhs,
        None => return PrimitiveArray::<i128>::new_null(lhs.data_type().clone(), lhs.len()),
    };

    let scale = 10i128.pow(scale as u32);
    let max = max_value(precision);

    let op = move |a: i128| {
        let res: i128 = a
            .checked_mul(rhs)
            .expect("Mayor overflow for multiplication");
        // Undo the doubled scale introduced by multiplying two scaled integers.
        let res = res / scale;

        // Range check instead of `res.abs() <= max`: `i128::MIN.abs()` overflows.
        assert!(
            (-max..=max).contains(&res),
            "Overflow in multiplication presented for precision {precision}"
        );

        res
    };

    unary(lhs, op, lhs.data_type().clone())
}
/// Multiplies two decimal arrays element-wise, saturating at the precision bounds.
///
/// Each rescaled product is clamped to `[-max, max]`, where `max` is
/// `max_value(precision)`. If the raw `i128` product itself overflows, the result
/// saturates towards the bound whose sign matches the true product.
///
/// # Panics
/// Panics if `get_parameters` rejects the two data types (it is `unwrap`ped).
pub fn saturating_mul(
    lhs: &PrimitiveArray<i128>,
    rhs: &PrimitiveArray<i128>,
) -> PrimitiveArray<i128> {
    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

    let scale = 10i128.pow(scale as u32);
    let max = max_value(precision);

    let op = move |a: i128, b: i128| match a.checked_mul(b) {
        // `clamp` avoids `res.abs()`, which overflows for `i128::MIN`.
        Some(res) => (res / scale).clamp(-max, max),
        // The raw product overflowed `i128`. Its sign is still determined by the
        // operands (both are non-zero here), so saturate to the matching bound
        // rather than always to `+max`.
        None => {
            if (a < 0) != (b < 0) {
                -max
            } else {
                max
            }
        }
    };

    binary(lhs, rhs, lhs.data_type().clone(), op)
}
/// Multiplies two decimal arrays element-wise, without panicking on overflow.
///
/// For each pair, the op yields `None` when the raw `i128` product overflows or when
/// the rescaled product falls outside `[-max, max]` for the precision. Those slots
/// are presumably emitted as nulls by `binary_checked` — confirm against its docs.
///
/// # Panics
/// Panics if `get_parameters` rejects the two data types (it is `unwrap`ped).
pub fn checked_mul(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

    let scale = 10i128.pow(scale as u32);
    let max = max_value(precision);

    let op = move |a: i128, b: i128| {
        a.checked_mul(b)
            .map(|res| res / scale)
            // Range check instead of `res.abs() <= max`: `i128::MIN.abs()` overflows,
            // which would panic (debug) or accept the value (release).
            .filter(|res| (-max..=max).contains(res))
    };

    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
}
impl ArrayMul<PrimitiveArray<i128>> for PrimitiveArray<i128> {
fn mul(&self, rhs: &PrimitiveArray<i128>) -> Self {
mul(self, rhs)
}
}
impl ArrayCheckedMul<PrimitiveArray<i128>> for PrimitiveArray<i128> {
fn checked_mul(&self, rhs: &PrimitiveArray<i128>) -> Self {
checked_mul(self, rhs)
}
}
impl ArraySaturatingMul<PrimitiveArray<i128>> for PrimitiveArray<i128> {
fn saturating_mul(&self, rhs: &PrimitiveArray<i128>) -> Self {
saturating_mul(self, rhs)
}
}
/// Multiplies two decimal arrays whose precisions/scales may differ, widening the
/// result precision as needed.
///
/// The initial result precision, result scale and scale difference come from
/// `adjusted_precision_scale`. The operand with the smaller scale is shifted left
/// by `diff` digits before multiplying, and the product is divided by `10^res_s`.
/// Whenever a product exceeds `max_value(res_p)`, `res_p` is raised to
/// `number_digits(res)`, so the returned array's precision reflects the largest
/// value seen.
///
/// Note: values are computed for every slot, including null ones. The validity of
/// the result is the combination of both inputs' validities.
///
/// # Errors
/// * if the arrays have different lengths (`check_same_len`);
/// * `Error::InvalidArgumentError` if either array is not `DataType::Decimal`.
///
/// # Panics
/// Panics ("Mayor overflow for multiplication") if shifting an operand or
/// multiplying the operands overflows `i128`.
pub fn adaptive_mul(
    lhs: &PrimitiveArray<i128>,
    rhs: &PrimitiveArray<i128>,
) -> Result<PrimitiveArray<i128>> {
    check_same_len(lhs, rhs)?;

    let (lhs_p, lhs_s, rhs_p, rhs_s) =
        if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
            (lhs.data_type(), rhs.data_type())
        {
            (*lhs_p, *lhs_s, *rhs_p, *rhs_s)
        } else {
            return Err(Error::InvalidArgumentError(
                "Incorrect data type for the array".to_string(),
            ));
        };

    // The result precision is mutable because it may grow while iterating.
    let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s);

    let shift = 10i128.pow(diff as u32);
    let shift_1 = 10i128.pow(res_s as u32);
    let mut max = max_value(res_p);

    let values = lhs
        .values()
        .iter()
        .zip(rhs.values().iter())
        .map(|(&l, &r)| {
            // Shift the operand with the smaller scale so both share a scale. The
            // shift is checked too: an unchecked `x * shift` would panic with a
            // generic overflow in debug builds and silently wrap in release builds.
            let res = if lhs_s > rhs_s {
                r.checked_mul(shift).and_then(|r| l.checked_mul(r))
            } else {
                l.checked_mul(shift).and_then(|l| l.checked_mul(r))
            }
            .expect("Mayor overflow for multiplication");

            let res = res / shift_1;

            // Widen the precision when a product no longer fits, e.g.
            //   10.0000 (6, 4) * 10.0000 (6, 4) = 100.0000 (7, 4).
            // Range check instead of `res.abs() > max`, which overflows on `i128::MIN`.
            if !(-max..=max).contains(&res) {
                res_p = number_digits(res);
                max = max_value(res_p);
            }

            res
        })
        .collect::<Vec<_>>();

    let validity = combine_validities(lhs.validity(), rhs.validity());

    Ok(PrimitiveArray::<i128>::new(
        DataType::Decimal(res_p, res_s),
        values.into(),
        validity,
    ))
}