use alloy_primitives::U256;
use crate::AmmMathError;
/// Computes `floor(a * b / denominator)` using a full 512-bit intermediate product,
/// so `a * b` may exceed 256 bits as long as the final quotient fits.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `denominator` is zero.
/// * [`AmmMathError::MulDivOverflow`] if the quotient does not fit in 256 bits.
pub fn mul_div(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }
    match widening_mul(a, b) {
        // The product fits in 256 bits: a single native division suffices.
        (product, high) if high.is_zero() => Ok(product / denominator),
        // Otherwise fall back to the 512-by-256-bit division.
        (low, high) => div_512(high, low, denominator),
    }
}
/// Computes `ceil(a * b / denominator)` using a full 512-bit intermediate product.
///
/// The product and the 512-bit division are each computed exactly once. The
/// remainder of that division decides whether the floored quotient must be
/// bumped by one.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `denominator` is zero.
/// * [`AmmMathError::MulDivOverflow`] if the floored quotient does not fit in
///   256 bits, or if it is exactly `U256::MAX` and rounding up would overflow.
pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }
    let (lo, hi) = widening_mul(a, b);
    // The quotient of `hi * 2^256 + lo` by `denominator` fits in 256 bits
    // exactly when `hi < denominator`.
    if hi >= denominator {
        return Err(AmmMathError::MulDivOverflow);
    }
    let (quotient, _, remainder) = div_mod_512(hi, lo, denominator);
    if remainder.is_zero() {
        return Ok(quotient);
    }
    // Inexact division: round up, rejecting the case where the floor is U256::MAX.
    quotient
        .checked_add(U256::from(1u64))
        .ok_or(AmmMathError::MulDivOverflow)
}
/// Returns the full 512-bit product of `a` and `b` as `(low 256 bits, high 256 bits)`.
///
/// Schoolbook multiplication on 128-bit halves:
/// `a * b = a1*b1 * 2^256 + (a0*b1 + a1*b0) * 2^128 + a0*b0`,
/// where each half-by-half product fits in 256 bits.
fn widening_mul(a: U256, b: U256) -> (U256, U256) {
    let half_mask = U256::from(u128::MAX);
    let (a0, a1) = (a & half_mask, a >> 128);
    let (b0, b1) = (b & half_mask, b >> 128);

    let low_product = a0 * b0;
    let high_product = a1 * b1;
    // The two cross terms can together exceed 256 bits by one carry bit; that bit
    // sits at position 256 + 128 of the full product, i.e. bit 128 of the high word.
    let (cross, cross_carry) = (a0 * b1).overflowing_add(a1 * b0);

    // The lower 128 bits of the cross sum land in the upper half of the low word.
    let (lo, lo_carry) = low_product.overflowing_add((cross & half_mask) << 128);

    let cross_carry_term = if cross_carry {
        U256::from(1u64) << 128
    } else {
        U256::ZERO
    };
    let lo_carry_term = if lo_carry {
        U256::from(1u64)
    } else {
        U256::ZERO
    };
    // The full product is below 2^512, so none of these additions overflow.
    let hi = high_product + (cross >> 128) + cross_carry_term + lo_carry_term;
    (lo, hi)
}
/// Divides the 512-bit value `hi * 2^256 + lo` by `d`, requiring the quotient to
/// fit in 256 bits.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `d` is zero (previously this silently
///   returned `Ok(0)`).
/// * [`AmmMathError::MulDivOverflow`] if `hi >= d`, i.e. the quotient would need
///   more than 256 bits.
fn div_512(hi: U256, lo: U256, d: U256) -> crate::Result<U256> {
    if d.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }
    // The quotient fits in 256 bits exactly when hi < d. Checking this first
    // avoids running the 512-bit division only to discard its result.
    if hi >= d {
        return Err(AmmMathError::MulDivOverflow);
    }
    let (quotient, _, _) = div_mod_512(hi, lo, d);
    Ok(quotient)
}
/// Divides the 512-bit value `hi * 2^256 + lo` by `d`, returning
/// `(quotient, 0, remainder)`. The middle element is always zero.
///
/// The remainder is always exact. When `hi >= d`, `hi` is first reduced modulo
/// `d`. That preserves the remainder, but the returned quotient is then the true
/// quotient modulo 2^256, so callers that need the quotient must reject
/// `hi >= d` themselves. A zero divisor yields `(0, 0, 0)`.
fn div_mod_512(hi: U256, lo: U256, d: U256) -> (U256, U256, U256) {
    if d.is_zero() {
        return (U256::ZERO, U256::ZERO, U256::ZERO);
    }
    if hi.is_zero() {
        return (lo / d, U256::ZERO, lo % d);
    }
    // (hi * 2^256 + lo) mod d == ((hi mod d) * 2^256 + lo) mod d.
    let reduced_hi = if hi < d { hi } else { hi % d };
    div_mod_512_inner(reduced_hi, lo, d)
}
/// Divides the 512-bit value `hi * 2^256 + lo` by `d`, returning
/// `(quotient, 0, remainder)`.
///
/// The middle element is always zero; it only keeps the tuple shape of
/// `div_mod_512`.
///
/// Preconditions: `d != 0` and `hi < d`, which together ensure the quotient fits
/// in 256 bits.
///
/// Method: write `2^256 = q_unit * d + r_unit`. Then
/// `hi * 2^256 + lo = hi * q_unit * d + (hi * r_unit + lo)`.
/// So `hi * q_unit` is added to the quotient, and the 512-bit term
/// `hi * r_unit + lo` becomes the new `(hi, lo)`. Since `r_unit < 2^255`, the
/// high word roughly halves each round. The loop therefore ends after at most a
/// few hundred rounds, using constant stack, unlike the former recursion.
fn div_mod_512_inner(hi: U256, lo: U256, d: U256) -> (U256, U256, U256) {
    // With hi == 0 this is a plain 256-bit division. Returning here also covers
    // d == 1, where hi < d forces hi == 0 and 2^256 / d would not fit in 256 bits.
    if hi.is_zero() {
        return (lo / d, U256::ZERO, lo % d);
    }
    // 2^256 = q_unit * d + r_unit, derived from 2^256 - d (i.e. -d mod 2^256).
    let neg_d = U256::ZERO.wrapping_sub(d);
    let q_unit = neg_d / d + U256::from(1u64);
    let r_unit = neg_d % d;

    let mut quotient = U256::ZERO;
    let (mut hi, mut lo) = (hi, lo);
    while !hi.is_zero() {
        // hi < d and q_unit <= 2^256 / d, so this product fits in 256 bits.
        quotient = quotient.wrapping_add(hi.wrapping_mul(q_unit));
        // Fold the leftover term hi * r_unit + lo back into (hi, lo). It stays
        // below d * 2^256, so the incremented high word cannot overflow.
        let (prod_lo, prod_hi) = widening_mul(hi, r_unit);
        let (sum_lo, carried) = prod_lo.overflowing_add(lo);
        hi = if carried {
            prod_hi + U256::from(1u64)
        } else {
            prod_hi
        };
        lo = sum_lo;
    }
    (quotient.wrapping_add(lo / d), U256::ZERO, lo % d)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_mul_div() {
        let r = mul_div(U256::from(10u64), U256::from(20u64), U256::from(5u64)).unwrap();
        assert_eq!(r, U256::from(40u64));
    }

    #[test]
    fn test_mul_div_large() {
        let big = U256::from(1u64) << 128;
        let r = mul_div(big, big, big).unwrap();
        assert_eq!(r, big);
    }

    #[test]
    fn test_mul_div_max() {
        let r = mul_div(U256::MAX, U256::from(1u64), U256::from(1u64)).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up() {
        let r = mul_div_rounding_up(U256::from(10u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(5u64));
    }

    #[test]
    fn test_mul_div_rounding_up_exact() {
        let r = mul_div_rounding_up(U256::from(10u64), U256::from(2u64), U256::from(4u64)).unwrap();
        assert_eq!(r, U256::from(5u64));
    }

    #[test]
    fn test_div_by_zero() {
        let r = mul_div(U256::from(1u64), U256::from(1u64), U256::ZERO);
        assert!(r.is_err());
    }

    #[test]
    fn test_mul_div_q96_style() {
        let q96 = U256::from(1u64) << 96;
        let amount = U256::from(1_000_000u64);
        let sqrt_price = q96;
        let r = mul_div(amount, sqrt_price, q96).unwrap();
        assert_eq!(r, amount);
    }

    #[test]
    fn test_widening_mul_max() {
        let (lo, hi) = widening_mul(U256::MAX, U256::MAX);
        assert_eq!(lo, U256::from(1u64));
        assert_eq!(hi, U256::MAX - U256::from(1u64));
    }

    #[test]
    fn test_mul_div_100_3_7() {
        let r = mul_div(U256::from(100u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(42u64));
    }

    #[test]
    fn test_mul_div_rounding_up_100_3_7() {
        let r =
            mul_div_rounding_up(U256::from(100u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(43u64));
    }

    #[test]
    fn test_mul_div_max_max_max() {
        let r = mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up_div_by_zero() {
        let r = mul_div_rounding_up(U256::from(1u64), U256::from(1u64), U256::ZERO);
        assert!(r.is_err());
    }

    #[test]
    fn test_mul_div_overflow_with_unit_denominator() {
        // 2^128 * 2^128 / 1 = 2^256, which does not fit in 256 bits.
        let big = U256::from(1u64) << 128;
        let r = mul_div(big, big, U256::from(1u64));
        assert!(matches!(r, Err(AmmMathError::MulDivOverflow)));
    }

    #[test]
    fn test_mul_div_512_bit_intermediate_exact() {
        // 2^200 * 2^100 / 2^60 = 2^240; the product itself needs 301 bits.
        let r = mul_div(
            U256::from(1u64) << 200,
            U256::from(1u64) << 100,
            U256::from(1u64) << 60,
        )
        .unwrap();
        assert_eq!(r, U256::from(1u64) << 240);
    }

    #[test]
    fn test_mul_div_512_bit_intermediate_inexact() {
        // 2^255 * 4 = 2^257 and floor(2^257 / 3) = 2 * (2^256 - 1) / 3,
        // with remainder 2, so rounding up adds one.
        let a = U256::from(1u64) << 255;
        let four = U256::from(4u64);
        let three = U256::from(3u64);
        let expected = U256::MAX / three * U256::from(2u64);
        assert_eq!(mul_div(a, four, three).unwrap(), expected);
        assert_eq!(
            mul_div_rounding_up(a, four, three).unwrap(),
            expected + U256::from(1u64)
        );
    }

    #[test]
    fn test_mul_div_rounding_up_overflow_at_max() {
        // (2^256 - 2) * (2^255 + 1) = (2^256 - 1) * 2^255 + (2^255 - 2):
        // the floored quotient is exactly U256::MAX with a non-zero remainder.
        let a = U256::MAX - U256::from(1u64);
        let b = (U256::from(1u64) << 255) + U256::from(1u64);
        let d = U256::from(1u64) << 255;
        assert_eq!(mul_div(a, b, d).unwrap(), U256::MAX);
        let r = mul_div_rounding_up(a, b, d);
        assert!(matches!(r, Err(AmmMathError::MulDivOverflow)));
    }
}