wp-evm-amm-math 0.1.0

Native Rust CLMM/AMM math (Uniswap V3 compatible, zero SDK deps)
Documentation
//! Full-precision math: `mul_div` and `mul_div_rounding_up`.
//!
//! Equivalent to Uniswap V3 `FullMath.sol`.
//! Uses 512-bit intermediate to avoid overflow on U256 x U256.

use alloy_primitives::U256;

use crate::AmmMathError;

/// Computes `floor(a * b / denominator)` using a 512-bit intermediate product.
///
/// The product `a * b` is held as two `U256` limbs, so it may exceed
/// `U256::MAX` as long as the final quotient still fits in a `U256`.
///
/// # Errors
///
/// Returns [`AmmMathError::DivisionByZero`] when `denominator` is zero.
/// When the product spills into the high limb, any error reported by the
/// 512-by-256 division (`div_512`) is passed through unchanged.
pub fn mul_div(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    // `widening_mul` returns the low limb first:
    //   a * b = high * 2^256 + low
    let (low, high) = widening_mul(a, b);

    if high.is_zero() {
        // The whole product fits in one limb; plain 256-bit division suffices.
        Ok(low / denominator)
    } else {
        // The product needs both limbs; use the 512-by-256 division.
        div_512(high, low, denominator)
    }
}

/// Computes `ceil(a * b / denominator)` using a 512-bit intermediate product.
///
/// The widening multiply and the 512-by-256 division are each performed
/// exactly once; the quotient and the remainder both come from that single
/// division.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] when `denominator` is zero.
/// * [`AmmMathError::MulDivOverflow`] when the floor quotient does not fit
///   in a `U256`, or when it equals `U256::MAX` and must be rounded up.
pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    // a * b = hi * 2^256 + lo
    let (lo, hi) = widening_mul(a, b);

    // If hi >= denominator the dividend is at least denominator * 2^256,
    // so the quotient cannot fit in 256 bits. (hi == 0 never trips this,
    // because denominator is non-zero here.)
    if hi >= denominator {
        return Err(AmmMathError::MulDivOverflow);
    }

    let (floor, rem_hi, rem_lo) = div_mod_512(hi, lo, denominator);

    // Exact division: nothing to round.
    if rem_hi.is_zero() && rem_lo.is_zero() {
        return Ok(floor);
    }

    // Non-zero remainder: bump by one, reporting overflow if floor == U256::MAX.
    floor
        .checked_add(U256::from(1u64))
        .ok_or(AmmMathError::MulDivOverflow)
}

// ---------- internal helpers ----------

/// Multiplies two `U256` values into a full 512-bit product.
///
/// Returns `(lo, hi)` — low limb first — with `a * b = hi * 2^256 + lo`.
fn widening_mul(a: U256, b: U256) -> (U256, U256) {
    // Schoolbook multiplication on 128-bit halves:
    //   a = a1 * 2^128 + a0,  b = b1 * 2^128 + b0
    let low_half = U256::from(u128::MAX);
    let (a1, a0) = (a >> 128, a & low_half);
    let (b1, b0) = (b >> 128, b & low_half);

    // Each partial product is at most (2^128 - 1)^2, so it fits in a U256.
    let low = a0 * b0; // weight 2^0
    let high = a1 * b1; // weight 2^256

    // The two cross terms both carry weight 2^128; their sum can need 257 bits.
    let (cross, cross_overflowed) = (a0 * b1).overflowing_add(a1 * b0);

    // Low limb: `low` plus the bottom 128 bits of the cross sum, shifted into place.
    let (lo, low_overflowed) = low.overflowing_add((cross & low_half) << 128);

    // A lost bit 256 of the cross sum has weight 2^(256 + 128) in the product,
    // which is 2^128 inside the high limb.
    let cross_carry = if cross_overflowed {
        U256::from(1u64) << 128
    } else {
        U256::ZERO
    };
    // A carry out of the low limb is worth exactly one unit of the high limb.
    let low_carry = if low_overflowed {
        U256::from(1u64)
    } else {
        U256::ZERO
    };

    let hi = high + (cross >> 128) + cross_carry + low_carry;

    (lo, hi)
}

/// Divides the 512-bit value `hi * 2^256 + lo` by `d`, returning the quotient.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] when `d` is zero.
/// * [`AmmMathError::MulDivOverflow`] when `hi >= d`. In that case the
///   dividend is at least `d * 2^256`, so the quotient needs more than
///   256 bits. Solidity's `FullMath.mulDiv` reverts on the same condition.
fn div_512(hi: U256, lo: U256, d: U256) -> crate::Result<U256> {
    if d.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    // Reject overflow up front so the division is never run for a result
    // that would be discarded.
    if hi >= d {
        return Err(AmmMathError::MulDivOverflow);
    }

    let (quotient, ..) = div_mod_512(hi, lo, d);
    Ok(quotient)
}

/// Divides the 512-bit value `hi * 2^256 + lo` by `d`, with remainder.
///
/// Returns `(quotient, remainder_hi, remainder_lo)`. Every return path sets
/// `remainder_hi` to zero, so the remainder is carried entirely in
/// `remainder_lo`.
///
/// Special cases:
/// * `d == 0`: returns all zeros; callers are expected to reject a zero
///   divisor before calling.
/// * `hi >= d`: the true quotient does not fit in a `U256`. `hi` is first
///   reduced modulo `d`, so the returned quotient is that of the reduced
///   dividend `(hi % d) * 2^256 + lo`. The remainder is unaffected by the
///   reduction. Callers that need the quotient must check `hi < d`.
fn div_mod_512(hi: U256, lo: U256, d: U256) -> (U256, U256, U256) {
    if d.is_zero() {
        return (U256::ZERO, U256::ZERO, U256::ZERO);
    }

    // Dividend fits in one limb: ordinary 256-bit division.
    if hi.is_zero() {
        return (lo / d, U256::ZERO, lo % d);
    }

    // The inner routine requires its high limb to be strictly below `d`.
    // Reducing `hi` modulo `d` preserves the remainder of the full dividend.
    let reduced_hi = if hi >= d { hi % d } else { hi };

    div_mod_512_inner(reduced_hi, lo, d)
}

/// Divides `hi * 2^256 + lo` by `d`; callers must pass `hi < d` and `d != 0`
/// (the body divides by `d`).
///
/// Returns `(quotient, remainder_hi, remainder_lo)` with `remainder_hi`
/// always zero. The quotient is accumulated with wrapping arithmetic, so it
/// is the true quotient only when that value fits in 256 bits.
///
/// Method: write `2^256 = k * d + m` with `k = floor(2^256 / d)` and
/// `m = 2^256 mod d`. Then
///
/// ```text
/// hi * 2^256 + lo = (hi * k) * d + (hi * m + lo)
/// ```
///
/// so `hi * k` is added to the quotient and `hi * m + lo` becomes the next
/// dividend. This repeats until the dividend's high limb is zero, at which
/// point one ordinary 256-bit division finishes the job.
fn div_mod_512_inner(hi: U256, lo: U256, d: U256) -> (U256, U256, U256) {
    // 2^256 itself is not representable, but 2^256 - d is: it is the
    // wrapping negation of d.
    let two_256_minus_d = U256::ZERO.wrapping_sub(d);

    // k = floor((2^256 - d) / d) + 1 = floor(2^256 / d)
    // NOTE(review): for d == 1 the `+ 1` exceeds U256::MAX; this relies on
    // the operator wrapping rather than panicking — confirm. With hi < d,
    // hi is necessarily 0 in that case, so k is multiplied by zero.
    let k = two_256_minus_d / d + U256::from(1u64);
    // m = (2^256 - d) mod d = 2^256 mod d
    let m = two_256_minus_d % d;

    let mut quotient = U256::ZERO;
    let (mut cur_hi, mut cur_lo) = (hi, lo);

    loop {
        // Quotient contribution of the current high limb (low 256 bits only).
        quotient = quotient.wrapping_add(cur_hi.wrapping_mul(k));

        // Next dividend: cur_hi * m + cur_lo, as a two-limb value.
        let (prod_lo, prod_hi) = widening_mul(cur_hi, m);
        let (next_lo, carried) = prod_lo.overflowing_add(cur_lo);
        let next_hi = if carried {
            prod_hi + U256::from(1u64)
        } else {
            prod_hi
        };

        // High limb is gone: finish with plain 256-bit division.
        if next_hi.is_zero() {
            return (
                quotient.wrapping_add(next_lo / d),
                U256::ZERO,
                next_lo % d,
            );
        }

        cur_hi = next_hi;
        cur_lo = next_lo;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `U256` from a small literal.
    fn n(value: u64) -> U256 {
        U256::from(value)
    }

    #[test]
    fn test_simple_mul_div() {
        // 10 * 20 / 5 = 40
        assert_eq!(mul_div(n(10), n(20), n(5)).unwrap(), n(40));
    }

    #[test]
    fn test_mul_div_large() {
        // 2^128 * 2^128 / 2^128 = 2^128; the product itself needs 257 bits.
        let two_pow_128 = n(1) << 128;
        assert_eq!(
            mul_div(two_pow_128, two_pow_128, two_pow_128).unwrap(),
            two_pow_128
        );
    }

    #[test]
    fn test_mul_div_max() {
        // MAX * 1 / 1 = MAX
        assert_eq!(mul_div(U256::MAX, n(1), n(1)).unwrap(), U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up() {
        // 10 * 3 / 7 = 4.28.. -> 5
        assert_eq!(mul_div_rounding_up(n(10), n(3), n(7)).unwrap(), n(5));
    }

    #[test]
    fn test_mul_div_rounding_up_exact() {
        // 10 * 2 / 4 = 5 exactly, so no rounding is applied.
        assert_eq!(mul_div_rounding_up(n(10), n(2), n(4)).unwrap(), n(5));
    }

    #[test]
    fn test_div_by_zero() {
        assert!(mul_div(n(1), n(1), U256::ZERO).is_err());
    }

    #[test]
    fn test_mul_div_q96_style() {
        // amount * sqrt_price / Q96 with sqrt_price == Q96 (price 1.0)
        // must return the amount unchanged.
        let q96 = n(1) << 96;
        let amount = n(1_000_000);
        assert_eq!(mul_div(amount, q96, q96).unwrap(), amount);
    }

    #[test]
    fn test_widening_mul_max() {
        // (2^256 - 1)^2 = (2^256 - 2) * 2^256 + 1
        let (lo, hi) = widening_mul(U256::MAX, U256::MAX);
        assert_eq!((lo, hi), (n(1), U256::MAX - n(1)));
    }

    #[test]
    fn test_mul_div_100_3_7() {
        // 100 * 3 / 7 = 42.857.. -> 42 (floor)
        assert_eq!(mul_div(n(100), n(3), n(7)).unwrap(), n(42));
    }

    #[test]
    fn test_mul_div_rounding_up_100_3_7() {
        // 100 * 3 / 7 = 42.857.. -> 43 (ceiling)
        assert_eq!(mul_div_rounding_up(n(100), n(3), n(7)).unwrap(), n(43));
    }

    #[test]
    fn test_mul_div_max_max_max() {
        // MAX * MAX / MAX = MAX
        assert_eq!(
            mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap(),
            U256::MAX
        );
    }

    #[test]
    fn test_mul_div_rounding_up_div_by_zero() {
        assert!(mul_div_rounding_up(n(1), n(1), U256::ZERO).is_err());
    }
}