use crate::arity::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::i256;
use arrow_buffer::ArrowNativeType;
use arrow_schema::*;
use std::cmp::min;
use std::sync::Arc;
/// Computes the parameters shared by the fixed-point multiplication kernels.
///
/// Given the `(precision, scale)` of both operands and the scale requested for the result,
/// returns `(precision, product_scale, divisor)` where:
/// * `precision` is `left.precision + right.precision + 1`, capped at
///   [`DECIMAL128_MAX_PRECISION`];
/// * `product_scale` is `left.scale + right.scale`, i.e. the scale of the exact product;
/// * `divisor` is `10^(product_scale - required_scale)`, the factor by which the exact
///   product must be divided to express it at `required_scale`.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] if:
/// * `required_scale` is greater than the product scale (digits can only be dropped, never
///   added);
/// * the product scale does not fit in an `i8`;
/// * the divisor `10^(product_scale - required_scale)` does not fit in an `i256`.
fn get_fixed_point_info(
    left: (u8, i8),
    right: (u8, i8),
    required_scale: i8,
) -> Result<(u8, i8, i256), ArrowError> {
    // Widen before adding: decimal scales may be negative, and e.g. -100 + -100 overflows i8.
    let product_scale = i16::from(left.1) + i16::from(right.1);
    // Saturating keeps this panic-free for any input; for Decimal128 precisions (<= 38) it is
    // exactly `left.0 + right.0 + 1` before the cap is applied.
    let precision = min(
        left.0.saturating_add(right.0).saturating_add(1),
        DECIMAL128_MAX_PRECISION,
    );

    if i16::from(required_scale) > product_scale {
        return Err(ArrowError::ComputeError(format!(
            "Required scale {} is greater than product scale {}",
            required_scale, product_scale
        )));
    }

    // Here `product_scale >= required_scale >= i8::MIN`, so only the upper bound can fail.
    let product_scale = i8::try_from(product_scale).map_err(|_| {
        ArrowError::ComputeError(format!(
            "Product scale {} does not fit in a decimal scale",
            product_scale
        ))
    })?;

    // The difference is non-negative (checked above), so `unsigned_abs` is the value itself.
    let exponent = u32::from((i16::from(product_scale) - i16::from(required_scale)).unsigned_abs());
    // A wrapping power would silently yield a meaningless divisor once 10^exponent exceeds
    // the i256 range; report it instead.
    let divisor = i256::from_i128(10).pow_checked(exponent).map_err(|_| {
        ArrowError::ComputeError(format!(
            "Cannot rescale product from scale {} to scale {}: 10^{} overflows i256",
            product_scale, required_scale, exponent
        ))
    })?;

    Ok((precision, product_scale, divisor))
}
/// Dynamically-typed entry point for [`multiply_fixed_point`].
///
/// Both `left` and `right` must be `Decimal128` arrays. They are multiplied element-wise, and
/// each product is rounded to `required_scale`. Overflow wraps rather than erroring; see
/// [`multiply_fixed_point`] for details.
///
/// # Errors
///
/// * [`ArrowError::CastError`] if either input is not a `Decimal128` array.
/// * Any error returned by [`multiply_fixed_point`], e.g. when `required_scale` exceeds the
///   sum of the input scales.
pub fn multiply_fixed_point_dyn(
    left: &dyn Array,
    right: &dyn Array,
    required_scale: i8,
) -> Result<ArrayRef, ArrowError> {
    let supported = matches!(
        (left.data_type(), right.data_type()),
        (DataType::Decimal128(_, _), DataType::Decimal128(_, _))
    );
    if !supported {
        return Err(ArrowError::CastError(format!(
            "Unsupported data type {}, {}",
            left.data_type(),
            right.data_type()
        )));
    }

    // The data types were checked above, so both downcasts succeed.
    let lhs = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
    let rhs = right.as_any().downcast_ref::<Decimal128Array>().unwrap();

    let product = multiply_fixed_point(lhs, rhs, required_scale)?;
    Ok(Arc::new(product))
}
/// Multiplies two `Decimal128` arrays element-wise, producing results at `required_scale`.
///
/// The exact product of two decimals has scale `left.scale + right.scale`. This function
/// accepts precision loss: when `required_scale` is smaller than that product scale, each
/// exact product is computed in 256 bits and divided by `10^(product_scale - required_scale)`.
/// The division rounds to nearest, with ties away from zero. The output precision is
/// `left.precision + right.precision + 1`, capped at 38.
///
/// Unlike [`multiply_fixed_point`], overflow is reported as an error instead of wrapping.
///
/// # Errors
///
/// * If `required_scale` is greater than `left.scale + right.scale`.
/// * If the scales are equal and `mul_checked` fails for a pair of values.
/// * [`ArrowError::ArithmeticOverflow`] if a rounded product does not fit in an `i128`.
pub fn multiply_fixed_point_checked(
    left: &PrimitiveArray<Decimal128Type>,
    right: &PrimitiveArray<Decimal128Type>,
    required_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
    let (precision, product_scale, divisor) = get_fixed_point_info(
        (left.precision(), left.scale()),
        (right.precision(), right.scale()),
        required_scale,
    )?;

    let product = if required_scale == product_scale {
        // Nothing to rescale: a checked i128 multiplication is already exact.
        try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
    } else {
        try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
            // Widen first: the exact product of two i128 values always fits in an i256.
            let (wide_a, wide_b) = (i256::from_i128(a), i256::from_i128(b));
            let rounded =
                divide_and_round::<Decimal256Type>(wide_a.wrapping_mul(wide_b), divisor);
            rounded.to_i128().ok_or_else(|| {
                ArrowError::ArithmeticOverflow(format!(
                    "Overflow happened on: {:?} * {:?}",
                    wide_a, wide_b
                ))
            })
        })?
    };

    product.with_precision_and_scale(precision, required_scale)
}
/// Multiplies two `Decimal128` arrays element-wise, producing results at `required_scale`.
///
/// The exact product of two decimals has scale `left.scale + right.scale`. This function
/// accepts precision loss: when `required_scale` is smaller than that product scale, each
/// exact product is computed in 256 bits and divided by `10^(product_scale - required_scale)`.
/// The division rounds to nearest, with ties away from zero. The output precision is
/// `left.precision + right.precision + 1`, capped at 38.
///
/// Overflow is NOT detected:
/// * when no rescaling is needed, the i128 multiplication wraps;
/// * otherwise, the rounded 256-bit result is truncated to its low 128 bits.
///
/// Use [`multiply_fixed_point_checked`] to get an error instead.
///
/// # Errors
///
/// Returns an error if `required_scale` is greater than `left.scale + right.scale`.
pub fn multiply_fixed_point(
    left: &PrimitiveArray<Decimal128Type>,
    right: &PrimitiveArray<Decimal128Type>,
    required_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
    let (precision, product_scale, divisor) = get_fixed_point_info(
        (left.precision(), left.scale()),
        (right.precision(), right.scale()),
        required_scale,
    )?;

    let product = if required_scale == product_scale {
        binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_wrapping(b))?
    } else {
        binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
            // The exact product of two i128 values always fits in an i256.
            let exact = i256::from_i128(a).wrapping_mul(i256::from_i128(b));
            // `as_i128` keeps the low 128 bits, i.e. it wraps on overflow.
            divide_and_round::<Decimal256Type>(exact, divisor).as_i128()
        })?
    };

    product.with_precision_and_scale(precision, required_scale)
}
/// Divides `input` by `div` and rounds the quotient to the nearest integer, ties away from
/// zero: `15 / 10 -> 2`, `14 / 10 -> 1`, `-15 / 10 -> -2`.
///
/// `div` must be strictly positive (the fixed-point kernels pass a power of ten). All
/// intermediate arithmetic uses the wrapping operations of [`ArrowNativeTypeOp`].
fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
where
    I: DecimalType,
    I::Native: ArrowNativeTypeOp,
{
    // Truncating division: `r` carries the sign of `input` (or is zero).
    let d = input.div_wrapping(div);
    let r = input.mod_wrapping(div);
    // Round away from zero when |r| >= div - |r|, i.e. 2|r| >= div.
    // Comparing against `div - |r|` avoids the truncation of `div / 2` for odd divisors
    // (which made e.g. 1 / 1 round to 2 and 1 / 3 round to 1) without risking overflow in `2r`.
    match input >= I::Native::ZERO {
        true if r >= div.sub_wrapping(r) => d.add_wrapping(I::Native::ONE),
        false if r.neg_wrapping() >= div.add_wrapping(r) => d.sub_wrapping(I::Native::ONE),
        _ => d,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numeric::mul;

    /// Builds a `Decimal128Array` from raw (unscaled) values with the given precision and scale.
    fn decimal_array(values: Vec<i128>, precision: u8, scale: i8) -> Decimal128Array {
        Decimal128Array::from(values)
            .with_precision_and_scale(precision, scale)
            .unwrap()
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss() {
        // The exact product needs scale 36 and overflows i128; rescaled to 28 it fits.
        let lhs = decimal_array(vec![123456789000000000000000000], 38, 18);
        let rhs = decimal_array(vec![10000000000000000000], 38, 18);

        let overflow = mul(&lhs, &rhs).unwrap_err().to_string();
        assert!(overflow
            .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000"));

        let product = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap();
        let expected = decimal_array(vec![12345678900000000000000000000000000000], 38, 28);
        assert_eq!(&expected, &product);
        assert_eq!(
            product.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );

        // Very small and very large products are rounded to the requested scale.
        let lhs = decimal_array(
            vec![1, 123456789555555555555555555, 1555555555555555555],
            38,
            18,
        );
        let rhs = decimal_array(vec![1555555555555555555, 11222222222222222222, 1], 38, 18);
        let product = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap();
        let expected = decimal_array(
            vec![
                15555555556,
                13854595272345679012071330528765432099,
                15555555556,
            ],
            38,
            28,
        );
        assert_eq!(&expected, &product);
        assert_eq!(
            product.value_as_string(1),
            "1385459527.2345679012071330528765432099"
        );
        assert_eq!(product.value_as_string(0), "0.0000000000000000015555555556");
        assert_eq!(product.value_as_string(2), "0.0000000000000000015555555556");

        // When the required scale equals the product scale the result matches `mul`.
        let lhs = decimal_array(vec![1230], 4, 2);
        let rhs = decimal_array(vec![1000], 4, 2);
        let product = multiply_fixed_point_checked(&lhs, &rhs, 4).unwrap();
        assert_eq!(product.precision(), 9);
        assert_eq!(product.scale(), 4);
        let expected = mul(&lhs, &rhs).unwrap();
        assert_eq!(expected.as_ref(), &product);

        // Asking for more fractional digits than the product has is rejected.
        let err = multiply_fixed_point_checked(&lhs, &rhs, 5).unwrap_err();
        assert!(err
            .to_string()
            .contains("Required scale 5 is greater than product scale 4"));
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss_overflow() {
        let lhs = decimal_array(vec![99999999999123456789000000000000000000], 38, 18);
        let rhs = decimal_array(vec![9999999999910000000000000000000], 38, 18);

        let err = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap_err();
        assert!(err.to_string().contains(
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
        ));

        // The unchecked variant truncates instead of failing.
        let wrapped = multiply_fixed_point(&lhs, &rhs, 28).unwrap();
        let expected = decimal_array(vec![62946009661555981610246871926660136960], 38, 28);
        assert_eq!(&expected, &wrapped);
    }

    #[test]
    fn test_decimal_multiply_fixed_point() {
        let lhs = decimal_array(vec![123456789000000000000000000], 38, 18);
        let rhs = decimal_array(vec![10000000000000000000], 38, 18);

        let err = mul(&lhs, &rhs).unwrap_err();
        assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000");

        let product = multiply_fixed_point(&lhs, &rhs, 28).unwrap();
        let expected = decimal_array(vec![12345678900000000000000000000000000000], 38, 28);
        assert_eq!(&expected, &product);
        assert_eq!(
            product.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );
    }
}