use std::ops::{Add, Mul, Shl, Shr};
use crate::math::bn::Downcast;
use super::bn::{LowHigh, Shift, U128, U256};
/// Multiply-then-divide and multiply-then-shift arithmetic in which the
/// intermediate product `self * num` is formed at double width and only the
/// final result is narrowed.
///
/// `Output` is the narrowed result type; `FullOutput` is the type of the exact,
/// un-truncated product returned by [`FullMath::full_mul`].
pub trait FullMath<RHS = Self> {
    /// Result type of the narrowing operations.
    type Output;
    /// Type wide enough to hold the exact product of two operands.
    type FullOutput;

    /// Returns `floor(self * num / denom)`.
    fn mul_div_floor(self, num: RHS, denom: RHS) -> Self::Output;

    /// Returns `self * num / denom` rounded to the nearest integer (ties round up).
    fn mul_div_round(self, num: RHS, denom: RHS) -> Self::Output;

    /// Returns `ceil(self * num / denom)`.
    fn mul_div_ceil(self, num: RHS, denom: RHS) -> Self::Output;

    /// Returns `(self * num) >> shift`, narrowed to `Output`.
    fn mul_shift_right(self, num: RHS, shift: u32) -> Self::Output;

    /// Returns `(self * num) << shift`, narrowed to `Output`.
    fn mul_shift_left(self, num: RHS, shift: u32) -> Self::Output;

    /// Returns the exact product `self * num` without any narrowing.
    fn full_mul(self, num: RHS) -> Self::FullOutput;
}
/// [`FullMath`] for `u128`: the product `self * num` is formed exactly as a
/// [`U256`] and only narrowed back to `u128` (via `as_u128`) at the end.
impl FullMath for u128 {
    type Output = u128;
    type FullOutput = U256;

    /// Returns `floor(self * num / denom)`.
    ///
    /// NOTE(review): presumably panics when `denom == 0` (U256 division) or when
    /// the quotient exceeds `u128::MAX` (`as_u128`) — confirm against `bn`.
    fn mul_div_floor(self, num: Self, denom: Self) -> Self::Output {
        (self.full_mul(num) / denom).as_u128()
    }

    /// Returns `self * num / denom` rounded to the nearest integer, ties up.
    ///
    /// Adds `denom / 2` to the exact product before the flooring division.
    /// In Rust `+` binds tighter than `>>`, so the half-denominator is computed
    /// on its own rather than written inline as `product + denom >> 1`.
    fn mul_div_round(self, num: Self, denom: Self) -> Self::Output {
        let half_denom = denom >> 1;
        ((self.full_mul(num) + half_denom) / denom).as_u128()
    }

    /// Returns `ceil(self * num / denom)`.
    ///
    /// Adds `denom - 1` to the exact product before the flooring division. The
    /// sum must be parenthesised because `/` binds tighter than `+`. It cannot
    /// overflow 256 bits because the product is at most `(2^128 - 1)^2`.
    ///
    /// `denom == 0` is not supported: `denom - 1` underflows (a panic in debug
    /// builds) and the division is by zero.
    fn mul_div_ceil(self, num: Self, denom: Self) -> Self::Output {
        let round_up_bias = denom - 1;
        ((self.full_mul(num) + round_up_bias) / denom).as_u128()
    }

    /// Returns `(self * num) >> shift`, narrowed to `u128`.
    fn mul_shift_right(self, num: Self, shift: u32) -> Self::Output {
        self.full_mul(num).shift_right(shift).as_u128()
    }

    /// Returns `(self * num) << shift`, narrowed to `u128`.
    ///
    /// NOTE(review): behaviour when the shifted value exceeds `u128::MAX` depends
    /// on `Shift::shift_left` / `as_u128` in `bn` — confirm.
    fn mul_shift_left(self, num: Self, shift: u32) -> Self::Output {
        self.full_mul(num).shift_left(shift).as_u128()
    }

    /// Exact 256-bit product of two `u128` values.
    ///
    /// Schoolbook multiplication on 64-bit halves. This assumes
    /// `LowHigh::lo`/`hi` return the low/high 64 bits and the `*_u128`
    /// variants widen them to `u128` — confirm against `bn`. With
    /// `self = sh·2^64 + sl` and `num = nh·2^64 + nl`:
    /// `self·num = sh·nh·2^128 + (sl·nh + sh·nl)·2^64 + sl·nl`.
    fn full_mul(self, num: Self) -> Self::FullOutput {
        // low × low: occupies bits 0..128.
        let mut c0 = self.lo_u128() * num.lo_u128();
        // Cross products: each occupies bits 64..192.
        let a1 = self.lo_u128() * num.hi_u128();
        let b0 = self.hi_u128() * num.lo_u128();
        // Column for bits 64..128. It is at most 3·(2^64 - 1), so it fits;
        // its high half is the carry into bit 128.
        let mut c1 = c0.hi_u128() + a1.lo_u128() + b0.lo_u128();
        // Final low 128 bits of the product.
        c0 = u128::from_hi_lo(c1.lo(), c0.lo());
        // High 128 bits: hi × hi, plus the carry, plus the upper halves of the
        // cross products. The exact product is < 2^256, so this cannot overflow.
        c1 = self.hi_u128() * num.hi_u128() + c1.hi_u128() + a1.hi_u128() + b0.hi_u128();
        // Limbs are presumably least-significant first (uint convention).
        U256([c0.lo(), c0.hi(), c1.lo(), c1.hi()])
    }
}
/// [`FullMath`] for `u64`: each product is widened into a [`U128`] before
/// dividing or shifting, then narrowed back with `as_u64`.
impl FullMath for u64 {
    type Output = u64;
    type FullOutput = u128;

    /// `floor(self * num / denom)`.
    ///
    /// Panics via `unwrap` if `checked_div` reports failure (e.g. `denom == 0`).
    fn mul_div_floor(self, num: Self, denom: Self) -> Self::Output {
        let product = U128::from(self).mul(U128::from(num));
        let quotient = product.checked_div(U128::from(denom)).unwrap();
        quotient.as_u64()
    }

    /// `self * num / denom` rounded to nearest, ties up.
    ///
    /// `denom / 2` is added to the product before the flooring division.
    fn mul_div_round(self, num: Self, denom: Self) -> Self::Output {
        let product = U128::from(self).mul(U128::from(num));
        let biased = product.add(U128::from(denom >> 1));
        let quotient = biased.checked_div(U128::from(denom)).unwrap();
        quotient.as_u64()
    }

    /// `ceil(self * num / denom)`.
    ///
    /// `denom - 1` is added to the product before the flooring division.
    /// `denom == 0` is not supported.
    fn mul_div_ceil(self, num: Self, denom: Self) -> Self::Output {
        let product = U128::from(self).mul(U128::from(num));
        let biased = product.add(U128::from(denom - 1));
        let quotient = biased.checked_div(U128::from(denom)).unwrap();
        quotient.as_u64()
    }

    /// `(self * num) >> shift`, narrowed to `u64`.
    fn mul_shift_right(self, num: Self, shift: u32) -> Self::Output {
        let product = U128::from(self).mul(U128::from(num));
        product.shr(shift).as_u64()
    }

    /// `(self * num) << shift`, narrowed to `u64`.
    fn mul_shift_left(self, num: Self, shift: u32) -> Self::Output {
        let product = U128::from(self).mul(U128::from(num));
        product.shl(shift).as_u64()
    }

    /// Exact product as a native `u128`; `(2^64 - 1)^2` always fits.
    fn full_mul(self, num: Self) -> Self::FullOutput {
        let widened = U128::from(self);
        widened.mul(num).as_u128()
    }
}
/// Checked division whose quotient is optionally rounded up instead of truncated.
pub trait DivRoundUpIf<RHS = Self> {
    /// Type of the quotient.
    type Output;

    /// Divides `self` by `divisor`.
    ///
    /// Returns `None` when the division cannot be performed (e.g. a zero divisor).
    /// Otherwise the result is the truncated quotient. One is added to it when
    /// `round_up` is `true` and the division leaves a non-zero remainder.
    fn checked_div_round_up_if(self, divisor: RHS, round_up: bool) -> Option<Self::Output>;
}
impl DivRoundUpIf for u128 {
    type Output = u128;

    /// Divides `self` by `divisor`. When `round_up` is set and the division is
    /// inexact, returns the truncated quotient plus one.
    ///
    /// Returns `None` when `divisor` is zero. The `+ 1` cannot overflow: a
    /// non-zero remainder implies `divisor >= 2`, so the quotient is at most
    /// `u128::MAX / 2`.
    fn checked_div_round_up_if(self, divisor: Self, round_up: bool) -> Option<Self::Output> {
        let truncated = self.checked_div(divisor)?;
        // `checked_div` succeeded, so `divisor != 0` and `%` cannot panic here.
        let is_inexact = self % divisor != 0;
        Some(truncated + u128::from(round_up && is_inexact))
    }
}
impl DivRoundUpIf for U256 {
    type Output = U256;

    /// 256-bit counterpart of the `u128` impl. Divides `self` by `divisor` and
    /// adds one to the truncated quotient when `round_up` is set and the
    /// remainder is non-zero.
    ///
    /// Returns `None` when `divisor` is zero.
    fn checked_div_round_up_if(self, divisor: Self, round_up: bool) -> Option<Self::Output> {
        if divisor.is_zero() {
            None
        } else {
            let (truncated, leftover) = self.div_mod(divisor);
            let bump = round_up && !leftover.is_zero();
            Some(if bump { truncated.add(1) } else { truncated })
        }
    }
}
#[cfg(test)]
mod test_full_mul {
    use proptest::prelude::*;

    use crate::math::full_math::FullMath;

    use super::*;

    /// Reference product computed directly with 256-bit bignum multiplication.
    fn reference(n: u128, v: u128) -> U256 {
        U256::from(n) * U256::from(v)
    }

    #[test]
    fn test_zero() {
        assert_eq!(reference(0, 0), 0u128.full_mul(0));
    }

    #[test]
    fn test_u64_max() {
        let n = u64::MAX as u128;
        assert_eq!(reference(n, n), n.full_mul(n));
    }

    #[test]
    fn test_u128_max() {
        // Cover the true maximum (worst case for the carry chain) and its predecessor.
        for n in [u128::MAX, u128::MAX - 1] {
            assert_eq!(reference(n, n), n.full_mul(n));
        }
        // (2^128 - 1)^2 = (2^128 - 2)·2^128 + 1, pinned limb by limb.
        assert_eq!(
            u128::MAX.full_mul(u128::MAX),
            U256([1, 0, u64::MAX - 1, u64::MAX])
        );
    }

    #[test]
    fn test_carry_into_upper_half() {
        // 2^64 · 2^64 = 2^128: the only set bit is the lowest bit of limb 2.
        let two_pow_64 = 1u128 << 64;
        assert_eq!(two_pow_64.full_mul(two_pow_64), U256([0, 0, 1, 0]));
    }

    #[test]
    fn test_u64_full_mul_extremes() {
        assert_eq!(
            u64::MAX.full_mul(u64::MAX),
            u128::from(u64::MAX) * u128::from(u64::MAX)
        );
        assert_eq!(0u64.full_mul(u64::MAX), 0);
    }

    proptest! {
        // `any::<T>()` covers the full domain, including `T::MAX`, which the
        // half-open `MIN..MAX` range would exclude.
        #[test]
        fn fuzz_test(n in any::<u128>(), v in any::<u128>()) {
            let expected = U128::from(n).as_u256() * U128::from(v).as_u256();
            prop_assert_eq!(expected, n.full_mul(v));
        }

        #[test]
        fn fuzz_test_u64(n in any::<u64>(), v in any::<u64>()) {
            prop_assert_eq!(u128::from(n) * u128::from(v), n.full_mul(v));
        }
    }
}