// wp-evm-amm-math 0.1.6
//
// Native Rust CLMM/AMM math (Uniswap V3 compatible, zero SDK deps)
// Documentation
//! Full-precision math: `mul_div` and `mul_div_rounding_up`.
//!
//! Equivalent to Uniswap V3 `FullMath.sol`.
//! Uses 512-bit intermediate to avoid overflow on U256 x U256.

use alloy_primitives::{U256, U512};

use crate::AmmMathError;

/// Zero-extend a `U256` into a `U512`.
///
/// The four limbs of `x` are copied into limb positions 0..4 of the result and
/// positions 4..8 are left at zero. Working on raw limbs keeps this helper
/// independent of whatever `From` conversions exist between `Uint` widths.
#[inline]
fn to_u512(x: U256) -> U512 {
    let mut wide = [0u64; 8];
    wide[..4].copy_from_slice(x.as_limbs());
    U512::from_limbs(wide)
}

/// Computes `floor(a * b / denominator)` without losing precision.
///
/// Both operands are widened to `U512` before multiplying, so the
/// intermediate product cannot overflow. The quotient is then narrowed back
/// to `U256`.
///
/// # Errors
///
/// * `AmmMathError::DivisionByZero` when `denominator` is zero.
/// * `AmmMathError::MulDivOverflow` when the quotient does not fit in `U256`
///   (mirrors Solidity `FullMath.mulDiv`'s implicit revert).
pub fn mul_div(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    let wide_product = to_u512(a) * to_u512(b);
    let wide_quotient = wide_product / to_u512(denominator);

    match u512_to_u256(wide_quotient) {
        Some(quotient) => Ok(quotient),
        None => Err(AmmMathError::MulDivOverflow),
    }
}

/// Computes `ceil(a * b / denominator)` without losing precision.
///
/// The product is formed in `U512`, so it cannot overflow. Quotient and
/// remainder are obtained from a single `div_rem` call (one long division
/// instead of a separate `/` and `%`), and the quotient is bumped by one
/// whenever the remainder is non-zero.
///
/// # Errors
///
/// * `AmmMathError::DivisionByZero` when `denominator` is zero.
/// * `AmmMathError::MulDivOverflow` when the rounded-up quotient does not fit
///   in `U256`. This includes the case where the floored quotient is exactly
///   `U256::MAX` and the remainder is non-zero.
pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    let (floor, rem) = (to_u512(a) * to_u512(b)).div_rem(to_u512(denominator));

    // `floor` is at most (2^256 - 1)^2, which is below `U512::MAX`, so adding
    // one can never wrap the 512-bit value.
    let ceil = if rem.is_zero() { floor } else { floor + U512::from(1u64) };

    u512_to_u256(ceil).ok_or(AmmMathError::MulDivOverflow)
}

// ---------- internal helpers ----------

/// Truncate a `U512` to `U256` when no information would be lost.
///
/// Returns `Some` built from limbs 0..4 if limbs 4..8 are all zero, and
/// `None` as soon as any of the upper four limbs is non-zero.
#[inline]
fn u512_to_u256(x: U512) -> Option<U256> {
    let &[w0, w1, w2, w3, h0, h1, h2, h3] = x.as_limbs();
    let upper_is_clear = h0 == 0 && h1 == 0 && h2 == 0 && h3 == 0;
    upper_is_clear.then(|| U256::from_limbs([w0, w1, w2, w3]))
}

/// Unit and property tests for `mul_div` / `mul_div_rounding_up`.
///
/// Error-path tests pin the exact `AmmMathError` variant (not just
/// `is_err()`), and the overflow boundary is covered deterministically in
/// addition to the randomized property test.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_mul_div() {
        let r = mul_div(U256::from(10u64), U256::from(20u64), U256::from(5u64)).unwrap();
        assert_eq!(r, U256::from(40u64));
    }

    #[test]
    fn test_mul_div_large() {
        // (2^128) * (2^128) / (2^128) = 2^128
        let big = U256::from(1u64) << 128;
        let r = mul_div(big, big, big).unwrap();
        assert_eq!(r, big);
    }

    #[test]
    fn test_mul_div_max() {
        // U256::MAX * 1 / 1 = U256::MAX
        let r = mul_div(U256::MAX, U256::from(1u64), U256::from(1u64)).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up() {
        // 10 * 3 / 7 = 4.28.. -> 5
        let r = mul_div_rounding_up(U256::from(10u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(5u64));
    }

    #[test]
    fn test_mul_div_rounding_up_exact() {
        // 10 * 2 / 4 = 5.0 -> 5 (no rounding)
        let r = mul_div_rounding_up(U256::from(10u64), U256::from(2u64), U256::from(4u64)).unwrap();
        assert_eq!(r, U256::from(5u64));
    }

    #[test]
    fn test_div_by_zero() {
        let r = mul_div(U256::from(1u64), U256::from(1u64), U256::ZERO);
        assert!(matches!(r, Err(AmmMathError::DivisionByZero)));
    }

    #[test]
    fn test_mul_div_q96_style() {
        // Simulate: amount * sqrt_price / Q96
        let q96 = U256::from(1u64) << 96;
        let amount = U256::from(1_000_000u64); // 1M
        let sqrt_price = q96; // price = 1.0
        let r = mul_div(amount, sqrt_price, q96).unwrap();
        assert_eq!(r, amount);
    }

    #[test]
    fn test_mul_div_100_3_7() {
        // 100 * 3 / 7 = 42 (floor)
        let r = mul_div(U256::from(100u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(42u64));
    }

    #[test]
    fn test_mul_div_rounding_up_100_3_7() {
        // 100 * 3 / 7 = 42.857.. -> 43 (ceiling)
        let r =
            mul_div_rounding_up(U256::from(100u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(43u64));
    }

    #[test]
    fn test_mul_div_max_max_max() {
        // MAX * MAX / MAX = MAX
        let r = mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up_max_max_max() {
        // MAX * MAX / MAX = MAX exactly, so rounding up must not bump it.
        let r = mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up_div_by_zero() {
        let r = mul_div_rounding_up(U256::from(1u64), U256::from(1u64), U256::ZERO);
        assert!(matches!(r, Err(AmmMathError::DivisionByZero)));
    }

    #[test]
    fn test_mul_div_zero_numerator() {
        // 0 * MAX / 7 = 0 for both rounding modes (remainder is zero).
        let d = U256::from(7u64);
        assert_eq!(mul_div(U256::ZERO, U256::MAX, d).unwrap(), U256::ZERO);
        assert_eq!(mul_div_rounding_up(U256::ZERO, U256::MAX, d).unwrap(), U256::ZERO);
    }

    #[test]
    fn test_mul_div_overflow() {
        // MAX * 2 / 1 = 2^257 - 2, which does not fit in U256.
        let r = mul_div(U256::MAX, U256::from(2u64), U256::from(1u64));
        assert!(matches!(r, Err(AmmMathError::MulDivOverflow)));
    }

    #[test]
    fn test_mul_div_rounding_up_overflow() {
        // Same quotient as above; the ceiling variant must report overflow too.
        let r = mul_div_rounding_up(U256::MAX, U256::from(2u64), U256::from(1u64));
        assert!(matches!(r, Err(AmmMathError::MulDivOverflow)));
    }

    #[test]
    fn test_rounding_up_overflows_when_floor_is_max() {
        // Pick a, b, d with a * b = 6 * MAX + 1, so that
        //   floor(a * b / 6) = MAX  (fits)
        //   ceil (a * b / 6) = MAX + 1 = 2^256  (does not fit).
        // With b = 7 this needs a = (6 * MAX + 1) / 7, which is exact because
        // 6 * 2^256 - 5 is divisible by 7; the assertion below re-checks that.
        let target = to_u512(U256::MAX) * U512::from(6u64) + U512::from(1u64);
        let a = u512_to_u256(target / U512::from(7u64)).unwrap();
        let b = U256::from(7u64);
        let d = U256::from(6u64);
        assert_eq!(to_u512(a) * to_u512(b), target, "test setup: a * 7 != 6 * MAX + 1");

        assert_eq!(mul_div(a, b, d).unwrap(), U256::MAX);
        assert!(matches!(
            mul_div_rounding_up(a, b, d),
            Err(AmmMathError::MulDivOverflow)
        ));
    }

    #[test]
    fn prop_mul_div_matches_u512_definition() {
        // Independent oracle: a/b/d are random; the assertions check the
        // DEFINITION of floored division (q*d <= a*b < (q+1)*d) computed in
        // U512, plus the exact round-up relationship. Not circular: the
        // properties hold for the correct q regardless of how it was computed.
        use rand::{rngs::StdRng, RngCore, SeedableRng};

        fn rand_u256(rng: &mut impl RngCore) -> U256 {
            let mut b = [0u8; 32];
            rng.fill_bytes(&mut b);
            U256::from_be_bytes(b)
        }

        // Fixed seed keeps the test reproducible.
        let mut rng = StdRng::seed_from_u64(0xC0FFEE);
        for _ in 0..100_000 {
            let a = rand_u256(&mut rng);
            let b = rand_u256(&mut rng);
            let d = rand_u256(&mut rng);

            if d.is_zero() {
                assert!(matches!(mul_div(a, b, d), Err(AmmMathError::DivisionByZero)));
                assert!(matches!(
                    mul_div_rounding_up(a, b, d),
                    Err(AmmMathError::DivisionByZero)
                ));
                continue;
            }

            let prod = to_u512(a) * to_u512(b);
            let d512 = to_u512(d);
            let q512 = prod / d512;
            let rem = prod % d512;
            // The floored quotient fits in U256 iff its upper four limbs are zero.
            let fits = {
                let l = q512.as_limbs();
                (l[4] | l[5] | l[6] | l[7]) == 0
            };

            match mul_div(a, b, d) {
                Ok(q) => {
                    assert!(fits, "mul_div returned Ok but q overflows U256");
                    let qd = to_u512(q) * d512;
                    assert!(qd <= prod, "q*d > a*b");
                    assert!(prod - qd < d512, "remainder >= d (q too small)");
                }
                Err(e) => {
                    assert!(!fits, "mul_div errored but q fits U256");
                    assert!(
                        matches!(e, AmmMathError::MulDivOverflow),
                        "mul_div returned the wrong error variant"
                    );
                }
            }

            let up512 = if rem == U512::ZERO { q512 } else { q512 + U512::from(1u64) };
            let up_fits = {
                let l = up512.as_limbs();
                (l[4] | l[5] | l[6] | l[7]) == 0
            };
            match mul_div_rounding_up(a, b, d) {
                Ok(up) => {
                    assert!(up_fits, "rounding_up returned Ok but overflows U256");
                    assert_eq!(to_u512(up), up512, "round-up value mismatch");
                }
                Err(e) => {
                    assert!(!up_fits, "rounding_up errored but value fits U256");
                    assert!(
                        matches!(e, AmmMathError::MulDivOverflow),
                        "rounding_up returned the wrong error variant"
                    );
                }
            }
        }
    }
}