use crate::decimal::{Decimal, MAX_SCALE};
use crate::error::ArithmeticError;
/// Powers of ten: `POW10[k] == 10^k` for every `k` in `0..=28`.
///
/// The table is filled in at compile time by repeated multiplication, so every entry is
/// exactly ten times its predecessor. Any overflow here would be a compile error rather
/// than a runtime bug. `10^28` is the largest entry and fits comfortably in an `i128`.
pub(crate) const POW10: [i128; 29] = {
    let mut table = [1i128; 29];
    let mut exp = 1;
    while exp < 29 {
        table[exp] = table[exp - 1] * 10;
        exp += 1;
    }
    table
};
/// Looks up `10^exp` in [`POW10`].
///
/// # Errors
/// Returns [`ArithmeticError::ScaleExceeded`] when `exp` is past the end of the table.
#[inline]
pub(crate) fn pow10(exp: u8) -> Result<i128, ArithmeticError> {
    match POW10.get(usize::from(exp)) {
        Some(&power) => Ok(power),
        None => Err(ArithmeticError::ScaleExceeded),
    }
}
/// Brings two decimals to a common scale so their mantissas can be combined directly.
///
/// The operand with the smaller scale has its mantissa multiplied by
/// `10^(scale difference)`, and the other operand is passed through untouched. The
/// returned tuple is `(a_mantissa, b_mantissa, common_scale)`, where `common_scale` is
/// the larger of the two input scales.
///
/// # Errors
/// - [`ArithmeticError::ScaleExceeded`] if the scale difference has no entry in the
///   power-of-ten table.
/// - [`ArithmeticError::Overflow`] if upscaling the mantissa overflows `i128`.
#[inline]
pub(crate) fn align_scales(
    a: Decimal,
    b: Decimal,
) -> Result<(i128, i128, u8), ArithmeticError> {
    // Multiplies `mantissa` by 10^`digits`, reporting overflow instead of wrapping.
    let upscale = |mantissa: i128, digits: u8| -> Result<i128, ArithmeticError> {
        let factor = pow10(digits)?;
        mantissa.checked_mul(factor).ok_or(ArithmeticError::Overflow)
    };

    if a.scale == b.scale {
        Ok((a.mantissa, b.mantissa, a.scale))
    } else if a.scale < b.scale {
        let widened = upscale(a.mantissa, b.scale - a.scale)?;
        Ok((widened, b.mantissa, b.scale))
    } else {
        let widened = upscale(b.mantissa, a.scale - b.scale)?;
        Ok((a.mantissa, widened, a.scale))
    }
}
/// Sign classification of a computed quantity, as produced by [`sign3`].
///
/// Derives `Debug`/`PartialEq`/`Eq` so values can be printed and compared, e.g. in
/// assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Sign {
    /// Strictly greater than zero.
    Positive,
    /// Strictly less than zero.
    Negative,
    /// Exactly zero.
    Zero,
}
/// Classifies the sign of `a * b / c` without forming the product.
///
/// Returns [`Sign::Zero`] when either factor `a` or `b` is zero. Otherwise the result is
/// [`Sign::Negative`] exactly when an odd number of `a`, `b`, `c` are negative. `c` is
/// not checked for zero here; callers are expected to reject a zero divisor first.
#[inline]
pub(crate) fn sign3(a: i128, b: i128, c: i128) -> Sign {
    match (a, b) {
        (0, _) | (_, 0) => Sign::Zero,
        _ => {
            let negative_count = [a, b, c].iter().filter(|v| v.is_negative()).count();
            if negative_count % 2 == 0 {
                Sign::Positive
            } else {
                Sign::Negative
            }
        }
    }
}
/// Unsigned 256-bit integer held as two 128-bit halves: `value = hi * 2^128 + lo`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct U256 {
    /// Least-significant 128 bits.
    pub lo: u128,
    /// Most-significant 128 bits.
    pub hi: u128,
}
impl U256 {
    /// The 256-bit value zero.
    pub const ZERO: Self = Self { lo: 0, hi: 0 };

    /// Computes the exact 256-bit product of two `u128` values.
    ///
    /// The product of two 128-bit numbers always fits in 256 bits, so this cannot
    /// overflow.
    pub fn mul(a: u128, b: u128) -> Self {
        const MASK64: u128 = u64::MAX as u128;
        // Split each factor into 64-bit limbs so every partial product fits in a u128.
        let (a_lo, a_hi) = (a & MASK64, a >> 64);
        let (b_lo, b_hi) = (b & MASK64, b >> 64);
        let ll = a_lo * b_lo; // weight 2^0
        let lh = a_lo * b_hi; // weight 2^64
        let hl = a_hi * b_lo; // weight 2^64
        let hh = a_hi * b_hi; // weight 2^128
        // The two middle terms share weight 2^64. A carry out of their sum has weight
        // 2^192, which is bit 64 of `hi`.
        let (mid, mid_carry) = lh.overflowing_add(hl);
        let (lo, lo_carry) = ll.overflowing_add(mid << 64);
        let hi = hh
            .wrapping_add(mid >> 64)
            .wrapping_add(if mid_carry { 1u128 << 64 } else { 0 })
            .wrapping_add(u128::from(lo_carry));
        U256 { lo, hi }
    }

    /// Divides this 256-bit value by `d`, returning `(quotient, remainder)`.
    ///
    /// Returns `None` if `d == 0`, or if the quotient would not fit in a `u128`. The
    /// latter happens exactly when `self.hi >= d`.
    pub fn checked_div(self, d: u128) -> Option<(u128, u128)> {
        if d == 0 {
            return None;
        }
        if self.hi == 0 {
            return Some((self.lo / d, self.lo % d));
        }
        // self / d < 2^128  <=>  self < d * 2^128  <=>  self.hi < d.
        if self.hi >= d {
            return None;
        }
        // Invariant from here on: self.hi < d. Dividing the high word alone yields a zero
        // quotient digit and leaves self.hi as the running remainder. Both paths below
        // therefore start from that remainder and only consume the bits of self.lo.
        if d <= u128::from(u64::MAX) {
            // Fast path: d (and hence self.hi) fits in 64 bits, so two rounds of
            // 128-by-64 schoolbook division over the 64-bit limbs of `lo` suffice. Each
            // partial dividend is below d * 2^64, so each quotient digit fits in 64 bits.
            const LIMB_MASK: u128 = u64::MAX as u128;
            let upper = (self.hi << 64) | (self.lo >> 64);
            let (q_upper, r_upper) = (upper / d, upper % d);
            let lower = (r_upper << 64) | (self.lo & LIMB_MASK);
            let (q_lower, r_lower) = (lower / d, lower % d);
            return Some(((q_upper << 64) | q_lower, r_lower));
        }
        // Slow path: d needs more than 64 bits. Use restoring binary long division over
        // the bits of `lo`, most significant first.
        let mut quotient: u128 = 0;
        let mut remainder: u128 = self.hi;
        for bit in (0..128u32).rev() {
            // Shifting may push a set bit out of the top. The true value is then
            // 2^128 + shifted, which is >= d because d < 2^128.
            let overflowed = (remainder >> 127) == 1;
            let shifted = (remainder << 1) | ((self.lo >> bit) & 1);
            if overflowed || shifted >= d {
                // The exact difference is < d < 2^128, so the wrapping subtraction lands
                // on the true value.
                remainder = shifted.wrapping_sub(d);
                quotient |= 1u128 << bit;
            } else {
                remainder = shifted;
            }
        }
        Some((quotient, remainder))
    }
}
impl Decimal {
    /// Exact addition at the larger of the two operand scales.
    ///
    /// # Errors
    /// Propagates [`ArithmeticError::ScaleExceeded`] / [`ArithmeticError::Overflow`] from
    /// scale alignment. Returns [`ArithmeticError::Overflow`] if the sum overflows `i128`,
    /// and forwards any error from [`Decimal::new`].
    pub fn checked_add(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
        let (a, b, scale) = align_scales(self, rhs)?;
        let mantissa = a.checked_add(b).ok_or(ArithmeticError::Overflow)?;
        Decimal::new(mantissa, scale)
    }

    /// Exact subtraction (`self - rhs`) at the larger of the two operand scales.
    ///
    /// # Errors
    /// Same as [`Decimal::checked_add`], with overflow checked on the difference.
    pub fn checked_sub(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
        let (a, b, scale) = align_scales(self, rhs)?;
        let mantissa = a.checked_sub(b).ok_or(ArithmeticError::Overflow)?;
        Decimal::new(mantissa, scale)
    }

    /// Exact multiplication: the mantissas are multiplied and the scales are added.
    ///
    /// No rounding is attempted, so the result keeps every fractional digit.
    ///
    /// # Errors
    /// - [`ArithmeticError::ScaleExceeded`] if the summed scale exceeds [`MAX_SCALE`].
    /// - [`ArithmeticError::Overflow`] if the mantissa product overflows `i128`.
    pub fn checked_mul(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
        let new_scale = self
            .scale
            .checked_add(rhs.scale)
            .filter(|&s| s <= MAX_SCALE)
            .ok_or(ArithmeticError::ScaleExceeded)?;
        let mantissa = self
            .mantissa
            .checked_mul(rhs.mantissa)
            .ok_or(ArithmeticError::Overflow)?;
        Decimal::new(mantissa, new_scale)
    }

    /// Division `self / rhs`, truncated toward zero.
    ///
    /// The dividend is padded with as many extra decimal digits as the scale limit
    /// allows, so the result is normally produced at scale `MAX_SCALE - rhs.scale`. The
    /// padded dividend is formed in 256 bits, so padding itself never overflows. If the
    /// resulting mantissa does not fit in an `i128` (e.g. dividends with many integer
    /// digits), precision is reduced one digit at a time until it does, rather than
    /// failing outright. Any result the fixed-precision approach could represent is
    /// produced unchanged.
    ///
    /// # Errors
    /// - [`ArithmeticError::DivisionByZero`] if `rhs` is zero.
    /// - [`ArithmeticError::Underflow`] if no non-negative result scale is reachable.
    ///   This only happens when `rhs.scale` exceeds [`MAX_SCALE`].
    /// - [`ArithmeticError::ScaleExceeded`] if `self.scale` already exceeds
    ///   [`MAX_SCALE`] and so would the result scale.
    /// - [`ArithmeticError::Overflow`] if the quotient does not fit even at the lowest
    ///   admissible scale.
    pub fn checked_div(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
        if rhs.mantissa == 0 {
            return Err(ArithmeticError::DivisionByZero);
        }
        // `extra` is the number of decimal digits appended to the dividend before the
        // integer division. The result scale is `self.scale + extra - rhs.scale`.
        let max_extra = MAX_SCALE.saturating_sub(self.scale);
        // Fewest padding digits that still keep the result scale non-negative.
        let min_extra = rhs.scale.saturating_sub(self.scale);
        if min_extra > max_extra {
            return Err(ArithmeticError::Underflow);
        }
        let negative = (self.mantissa < 0) != (rhs.mantissa < 0);
        let dividend = self.mantissa.unsigned_abs();
        let divisor = rhs.mantissa.unsigned_abs();
        // Most precise first; the first padding whose quotient fits wins.
        for extra in (min_extra..=max_extra).rev() {
            let factor = pow10(extra)?.unsigned_abs();
            let quotient = match U256::mul(dividend, factor).checked_div(divisor) {
                Some((q, _remainder)) => q,
                None => continue,
            };
            let mantissa = match signed_from_magnitude(quotient, negative) {
                Ok(m) => m,
                Err(_) => continue,
            };
            let raw_scale = i32::from(self.scale) + i32::from(extra) - i32::from(rhs.scale);
            let scale = u8::try_from(raw_scale)
                .ok()
                .filter(|&s| s <= MAX_SCALE)
                .ok_or(ArithmeticError::ScaleExceeded)?;
            return Decimal::new(mantissa, scale);
        }
        Err(ArithmeticError::Overflow)
    }

    /// Negates the value.
    ///
    /// # Errors
    /// [`ArithmeticError::Overflow`] if the mantissa is `i128::MIN`.
    pub fn checked_neg(self) -> Result<Decimal, ArithmeticError> {
        let mantissa = self.mantissa.checked_neg().ok_or(ArithmeticError::Overflow)?;
        Decimal::new(mantissa, self.scale)
    }

    /// Absolute value. Non-negative inputs are returned unchanged.
    ///
    /// # Errors
    /// [`ArithmeticError::Overflow`] if the mantissa is `i128::MIN`.
    pub fn checked_abs(self) -> Result<Decimal, ArithmeticError> {
        if self.mantissa >= 0 {
            return Ok(self);
        }
        self.checked_neg()
    }

    /// Computes `self * numerator / denominator` with a 256-bit intermediate product,
    /// truncating toward zero.
    ///
    /// The product is never rounded before the division, so the only precision loss is
    /// the final truncation. This holds even when `self * numerator` alone would
    /// overflow `i128`.
    ///
    /// The result scale is `self.scale + numerator.scale - denominator.scale`. When that
    /// is negative, the product is multiplied by the missing power of ten and the result
    /// is returned at scale 0.
    ///
    /// # Errors
    /// - [`ArithmeticError::DivisionByZero`] if `denominator` is zero.
    /// - [`ArithmeticError::Overflow`] if the quotient does not fit in an `i128`.
    /// - [`ArithmeticError::ScaleExceeded`] if the result scale exceeds [`MAX_SCALE`], or
    ///   if the negative-scale correction needs a power of ten outside the table.
    pub fn checked_mul_div(
        self,
        numerator: Decimal,
        denominator: Decimal,
    ) -> Result<Decimal, ArithmeticError> {
        if denominator.mantissa == 0 {
            return Err(ArithmeticError::DivisionByZero);
        }
        let sign = sign3(self.mantissa, numerator.mantissa, denominator.mantissa);
        let a = self.mantissa.unsigned_abs();
        let b = numerator.mantissa.unsigned_abs();
        let c = denominator.mantissa.unsigned_abs();

        let raw_scale = i32::from(self.scale) + i32::from(numerator.scale)
            - i32::from(denominator.scale);
        // A negative scale cannot be represented. Scale the product up by 10^-raw_scale
        // instead and report the result at scale 0, which keeps the value exact.
        let (upscale, result_scale) = if raw_scale < 0 {
            let digits =
                u8::try_from(-raw_scale).map_err(|_| ArithmeticError::ScaleExceeded)?;
            (pow10(digits)?.unsigned_abs(), 0)
        } else {
            (1u128, raw_scale)
        };

        let product =
            mul_u256_u128(U256::mul(a, b), upscale).ok_or(ArithmeticError::Overflow)?;
        let (quotient, _remainder) =
            product.checked_div(c).ok_or(ArithmeticError::Overflow)?;
        let signed_mantissa = match sign {
            Sign::Zero => 0i128,
            Sign::Positive => signed_from_magnitude(quotient, false)?,
            Sign::Negative => signed_from_magnitude(quotient, true)?,
        };
        let scale = u8::try_from(result_scale)
            .ok()
            .filter(|&s| s <= MAX_SCALE)
            .ok_or(ArithmeticError::ScaleExceeded)?;
        Decimal::new(signed_mantissa, scale)
    }
}

/// Re-applies a sign to an unsigned magnitude produced by 256-bit arithmetic.
///
/// Accepts magnitudes up to `i128::MAX`, plus `2^127` when `negative`, which is exactly
/// `i128::MIN`.
///
/// # Errors
/// [`ArithmeticError::Overflow`] if the signed value does not fit in an `i128`.
fn signed_from_magnitude(magnitude: u128, negative: bool) -> Result<i128, ArithmeticError> {
    if negative && magnitude == i128::MIN.unsigned_abs() {
        return Ok(i128::MIN);
    }
    let value = i128::try_from(magnitude).map_err(|_| ArithmeticError::Overflow)?;
    Ok(if negative { -value } else { value })
}

/// Multiplies a 256-bit value by a `u128`.
///
/// Returns `None` if the exact product needs more than 256 bits.
fn mul_u256_u128(x: U256, m: u128) -> Option<U256> {
    // x * m = x.hi * m * 2^128 + x.lo * m
    let low = U256::mul(x.lo, m);
    let high = U256::mul(x.hi, m);
    if high.hi != 0 {
        return None;
    }
    let hi = low.hi.checked_add(high.lo)?;
    Some(U256 { lo: low.lo, hi })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Divides the 256-bit product `x * y` by `d`.
    /// Then checks that `q * d + r == x * y` and `r < d`.
    fn assert_div_roundtrip(x: u128, y: u128, d: u128) {
        let product = U256::mul(x, y);
        let (q, r) = product
            .checked_div(d)
            .expect("quotient should fit in u128 for these inputs");
        assert!(r < d, "remainder {} not below divisor {}", r, d);
        let rebuilt = U256::mul(q, d);
        let (lo, carry) = rebuilt.lo.overflowing_add(r);
        assert_eq!(lo, product.lo);
        assert_eq!(rebuilt.hi + u128::from(carry), product.hi);
    }

    #[test]
    fn pow10_table_spot_checks() {
        assert_eq!(pow10(0).unwrap(), 1);
        assert_eq!(pow10(6).unwrap(), 1_000_000);
        assert_eq!(pow10(18).unwrap(), 1_000_000_000_000_000_000);
        assert_eq!(
            pow10(28).unwrap(),
            10_000_000_000_000_000_000_000_000_000i128
        );
        assert!(pow10(29).is_err());
    }

    #[test]
    fn pow10_table_is_consecutive_powers() {
        for (exp, &value) in POW10.iter().enumerate() {
            let exp = u32::try_from(exp).expect("table index fits in u32");
            assert_eq!(value, 10i128.pow(exp));
        }
    }

    #[test]
    fn u256_mul_small() {
        assert_eq!(U256::mul(3, 7), U256 { lo: 21, hi: 0 });
    }

    #[test]
    fn u256_mul_max_times_max() {
        let r = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(r.lo, 1);
        assert_eq!(r.hi, u128::MAX - 1);
    }

    #[test]
    fn u256_div_basic() {
        assert_eq!(U256 { lo: 21, hi: 0 }.checked_div(7), Some((3, 0)));
    }

    #[test]
    fn u256_div_by_zero() {
        assert_eq!(U256::ZERO.checked_div(0), None);
    }

    #[test]
    fn u256_div_overflow_check() {
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(50), None);
        let p = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(p.checked_div(u128::MAX - 1), None);
    }

    #[test]
    fn u256_div_fast_path_with_nonzero_high_word() {
        // Divisor fits in 64 bits while the dividend spans both halves.
        assert_div_roundtrip(u128::MAX, 999, 1000);
        assert_div_roundtrip(u128::MAX, u128::from(u64::MAX) - 1, u128::from(u64::MAX));
        assert_eq!(
            U256::mul(u128::MAX, 1000).checked_div(1000),
            Some((u128::MAX, 0))
        );
    }

    #[test]
    fn u256_div_slow_path() {
        // Divisors wider than 64 bits take the bit-by-bit path.
        assert_div_roundtrip(u128::MAX - 12_345, (1u128 << 100) + 7, (1u128 << 100) + 9);
        // Divisor above 2^127 forces the remainder's top bit to shift out.
        assert_div_roundtrip(u128::MAX, u128::MAX - 6, u128::MAX - 5);
        let x = 123_456_789_012_345_678_901u128;
        assert_eq!(
            U256::mul(x, 1u128 << 90).checked_div(1u128 << 90),
            Some((x, 0))
        );
    }

    #[test]
    fn sign3_cases() {
        assert!(matches!(sign3(0, 5, 3), Sign::Zero));
        assert!(matches!(sign3(5, 0, -3), Sign::Zero));
        assert!(matches!(sign3(2, 3, 4), Sign::Positive));
        assert!(matches!(sign3(-2, 3, 4), Sign::Negative));
        assert!(matches!(sign3(2, 3, -4), Sign::Negative));
        assert!(matches!(sign3(-2, -3, 4), Sign::Positive));
        assert!(matches!(sign3(-2, -3, -4), Sign::Negative));
    }
}