use fpdec_core::{
i128_div_rounded, i128_mul_div_ten_pow_rounded, ten_pow, MAX_N_FRAC_DIGITS,
};
use crate::{Decimal, DecimalError};
/// Multiplication whose result is rounded to a caller-chosen number of
/// fractional digits.
///
/// `Rhs` defaults to `Self`, so `impl MulRounded for T` covers `T * T`.
pub trait MulRounded<Rhs = Self> {
    /// The type of the product.
    type Output;

    /// Returns `self * rhs`, rounded to (at most) `n_frac_digits`
    /// fractional digits.
    ///
    /// The receiver is taken by value and left untouched; the product is
    /// only available through the return value, so ignoring it is a bug.
    #[must_use = "this returns the result of the operation, without modifying the original"]
    fn mul_rounded(self, rhs: Rhs, n_frac_digits: u8) -> Self::Output;
}
/// Multiplies `x` by `y`, rounding the product to `n_frac_digits`
/// fractional digits if it would otherwise carry more.
///
/// If `n_frac_digits` is at least the exact number of fractional digits of
/// the product (the sum of both operands' digits), the exact product is
/// returned with that exact number of digits. Otherwise the exact product is
/// divided by the matching power of ten and rounded with the default
/// rounding mode (`None`). When the raw product of the coefficients does not
/// fit into an `i128`, the multiply-and-divide is delegated to
/// `i128_mul_div_ten_pow_rounded`.
///
/// Returns `None` if the (rounded) result cannot be represented.
pub(crate) fn checked_mul_rounded(
    x: Decimal,
    y: Decimal,
    n_frac_digits: u8,
) -> Option<Decimal> {
    let exact_n_frac_digits = x.n_frac_digits + y.n_frac_digits;
    let exact_coeff = x.coeff.checked_mul(y.coeff);

    // No rounding required: the exact product must fit as is.
    if n_frac_digits >= exact_n_frac_digits {
        return exact_coeff.map(|coeff| Decimal {
            coeff,
            n_frac_digits: exact_n_frac_digits,
        });
    }

    let shift = exact_n_frac_digits - n_frac_digits;
    let coeff = match exact_coeff {
        // The raw product fits: just scale it down.
        Some(product) => i128_div_rounded(product, ten_pow(shift), None),
        // The raw product overflows i128, but the scaled-down result may
        // still fit; let the helper do the wide multiply-and-divide.
        None => i128_mul_div_ten_pow_rounded(x.coeff, y.coeff, shift, None)?,
    };
    Some(Decimal {
        coeff,
        n_frac_digits,
    })
}
impl MulRounded<Self> for Decimal {
    type Output = Self;

    /// Returns `self * rhs` rounded to at most `n_frac_digits` fractional
    /// digits (see `checked_mul_rounded`). If either operand is zero,
    /// `Decimal::ZERO` is returned.
    ///
    /// # Panics
    ///
    /// * if `n_frac_digits` exceeds `MAX_N_FRAC_DIGITS`
    ///   (`DecimalError::MaxNFracDigitsExceeded`);
    /// * if the result cannot be represented
    ///   (`DecimalError::InternalOverflow`).
    #[inline]
    fn mul_rounded(self, rhs: Self, n_frac_digits: u8) -> Self::Output {
        assert!(
            n_frac_digits <= MAX_N_FRAC_DIGITS,
            "{}",
            DecimalError::MaxNFracDigitsExceeded
        );
        // Zero short-circuit: no multiplication or rounding needed.
        if self.eq_zero() || rhs.eq_zero() {
            return Self::ZERO;
        }
        match checked_mul_rounded(self, rhs, n_frac_digits) {
            Some(product) => product,
            None => panic!("{}", DecimalError::InternalOverflow),
        }
    }
}
// Macro defined elsewhere in the crate. Presumably it generates the
// by-reference variants of `MulRounded` for `Decimal` (a `&Decimal` operand
// on either or both sides) by forwarding to the by-value impl above. The
// test `test_mul_rounded_ref` calls `MulRounded::mul_rounded` with `&x`
// and/or `&y`, which relies on such impls. TODO confirm against the macro
// definition.
forward_ref_binop_rounded!(impl MulRounded, mul_rounded);
#[cfg(test)]
mod mul_rounded_decimal_tests {
    use super::*;

    #[test]
    fn test_mul_rounded_less_n_frac_digits() {
        let x = Decimal::new_raw(12345, 2);
        let z = x.mul_rounded(x, 2);
        assert_eq!(z.coefficient(), 1523990);
        assert_eq!(z.n_frac_digits(), 2);
        let y = Decimal::new_raw(5781, 4);
        let z = x.mul_rounded(y, 1);
        assert_eq!(z.coefficient(), 714);
        assert_eq!(z.n_frac_digits(), 1);
        let z = y.mul_rounded(x, 1);
        assert_eq!(z.coefficient(), 714);
        assert_eq!(z.n_frac_digits(), 1);
    }

    #[test]
    fn test_mul_rounded_no_adj_needed() {
        let x = Decimal::new_raw(12345, 2);
        let z = x.mul_rounded(x, 4);
        assert_eq!(z.coefficient(), 152399025);
        assert_eq!(z.n_frac_digits(), 4);
        let y = Decimal::new_raw(5781, 4);
        let z = x.mul_rounded(y, 10);
        assert_eq!(z.coefficient(), 71366445);
        assert_eq!(z.n_frac_digits(), 6);
        let z = y.mul_rounded(x, 7);
        assert_eq!(z.coefficient(), 71366445);
        assert_eq!(z.n_frac_digits(), 6);
    }

    #[test]
    fn test_mul_rounded_ref() {
        let x = Decimal::new_raw(12345, 3);
        let y = Decimal::new_raw(12345, 1);
        let z = x.mul_rounded(y, 2);
        let a = MulRounded::mul_rounded(&x, y, 2);
        assert_eq!(a.coefficient(), z.coefficient());
        let a = MulRounded::mul_rounded(x, &y, 2);
        assert_eq!(a.coefficient(), z.coefficient());
        let a = MulRounded::mul_rounded(&x, &y, 2);
        assert_eq!(a.coefficient(), z.coefficient());
    }

    // Covers the zero short-circuit in `mul_rounded`.
    #[test]
    fn test_mul_rounded_zero_operand() {
        let x = Decimal::new_raw(12345, 2);
        let z = x.mul_rounded(Decimal::ZERO, 3);
        assert!(z.eq_zero());
        let z = Decimal::ZERO.mul_rounded(x, 3);
        assert!(z.eq_zero());
    }

    // Covers the fallback path in `checked_mul_rounded`. 100 * 100 with 18
    // fractional digits on each side gives a coefficient product of 10^40,
    // which overflows i128. The result rounded to 2 fractional digits
    // (10000.00) still fits.
    #[test]
    fn test_mul_rounded_intermediate_overflow() {
        let x = Decimal::new_raw(100_000_000_000_000_000_000, 18);
        let z = x.mul_rounded(x, 2);
        assert_eq!(z.coefficient(), 1000000);
        assert_eq!(z.n_frac_digits(), 2);
    }

    // The exact product needs no rounding but does not fit into an i128.
    #[test]
    #[should_panic]
    fn test_mul_rounded_result_overflow() {
        let x = Decimal::new_raw(i128::MAX, 0);
        let _ = x.mul_rounded(x, 0);
    }

    #[test]
    #[should_panic]
    fn test_mul_rounded_too_many_frac_digits() {
        let x = Decimal::new_raw(12345, 2);
        let _ = x.mul_rounded(x, MAX_N_FRAC_DIGITS + 1);
    }
}