wp-evm-amm-math 0.1.1

Native Rust CLMM/AMM math (Uniswap V3 compatible, zero SDK deps)
Documentation
//! Encode a price ratio as sqrt_price_x96.
//!
//! `encode_sqrt_ratio_x96(amount1, amount0)` computes:
//!   `sqrt(amount1 / amount0) * 2^96`

use alloy_primitives::U256;

use crate::AmmMathError;

/// Compute `sqrt(amount1 / amount0) * 2^96`, i.e. the square root of the
/// price ratio as a fixed-point number with 96 fractional bits.
///
/// The ratio is first scaled to `floor(amount1 * 2^192 / amount0)` so that no
/// precision is lost to an early division. Then the floor integer square root
/// is taken, because `sqrt(amount1 * 2^192 / amount0) == sqrt(amount1 / amount0) * 2^96`.
/// Like the Uniswap SDK, the result is rounded down; no rounding-up
/// adjustment is applied.
///
/// # Arguments
/// * `amount1` - numerator of the price ratio (may be zero, which yields `0`)
/// * `amount0` - denominator of the price ratio (must be non-zero)
///
/// # Errors
/// * [`AmmMathError::ZeroAmounts`] if `amount0` is zero.
/// * Any error propagated from `crate::full_math::mul_div`. That function is
///   only used when `amount1 > u64::MAX`.
pub fn encode_sqrt_ratio_x96(amount1: U256, amount0: U256) -> crate::Result<U256> {
    if amount0.is_zero() {
        return Err(AmmMathError::ZeroAmounts);
    }

    // ratio_x192 = floor(amount1 * 2^192 / amount0).
    //
    // `amount1 << 192` only fits in 256 bits when amount1 < 2^64, so that case
    // takes the cheap shift-then-divide path. Larger numerators go through
    // `full_math::mul_div`, which is relied on to form `amount1 * 2^192 / amount0`
    // without overflowing the intermediate product.
    let ratio_x192 = if amount1 <= U256::from(u64::MAX) {
        (amount1 << 192) / amount0
    } else {
        let q192 = U256::from(1u64) << 192;
        crate::full_math::mul_div(amount1, q192, amount0)?
    };

    // Floor square root; intentionally no +1 adjustment, matching the SDK.
    Ok(sqrt_u256(ratio_x192))
}

/// Floor integer square root of a `U256`.
///
/// Returns the largest `r` with `r * r <= x`. It uses Newton's method,
/// starting from a power-of-two estimate that is never below `sqrt(x)`.
fn sqrt_u256(x: U256) -> U256 {
    let one = U256::from(1u64);

    // Tiny inputs: floor(sqrt(0)) = 0 and floor(sqrt(1..=3)) = 1.
    if x < U256::from(4u64) {
        return if x.is_zero() { U256::ZERO } else { one };
    }

    // x < 2^bits, so 2^ceil(bits / 2) >= sqrt(x). This gives an
    // over-estimate to descend from.
    let bits = 256 - x.leading_zeros() as u32;
    let mut root = one << bits.div_ceil(2);

    // Starting above the true root, each Newton step shrinks the estimate
    // until it reaches floor(sqrt(x)). Stop at the first step that does not
    // shrink it.
    let mut step = (root + x / root) >> 1;
    while step < root {
        root = step;
        step = (root + x / root) >> 1;
    }
    root
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_1_to_1() {
        // sqrt(1/1) * 2^96 = 2^96
        let result = encode_sqrt_ratio_x96(U256::from(1u64), U256::from(1u64)).unwrap();
        let q96 = U256::from(1u64) << 96;
        assert_eq!(result, q96);
    }

    #[test]
    fn test_encode_4_to_1() {
        // sqrt(4/1) * 2^96 = 2 * 2^96 = 2^97
        let result = encode_sqrt_ratio_x96(U256::from(4u64), U256::from(1u64)).unwrap();
        let expected = U256::from(1u64) << 97;
        assert_eq!(result, expected);
    }

    #[test]
    fn test_encode_1_to_4() {
        // sqrt(1/4) * 2^96 = 0.5 * 2^96 = 2^95
        let result = encode_sqrt_ratio_x96(U256::from(1u64), U256::from(4u64)).unwrap();
        let expected = U256::from(1u64) << 95;
        assert_eq!(result, expected);
    }

    #[test]
    fn test_encode_zero_denominator() {
        let result = encode_sqrt_ratio_x96(U256::from(1u64), U256::ZERO);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_zero_numerator() {
        // A zero numerator is accepted and encodes a zero price.
        let result = encode_sqrt_ratio_x96(U256::ZERO, U256::from(1u64)).unwrap();
        assert_eq!(result, U256::ZERO);
    }

    #[test]
    fn test_encode_large_equal_amounts() {
        // 10^18 : 10^18 = 1:1 -> 2^96
        let big = U256::from(1_000_000_000_000_000_000u64);
        let result = encode_sqrt_ratio_x96(big, big).unwrap();
        let q96 = U256::from(1u64) << 96;
        assert_eq!(result, q96);
    }

    #[test]
    fn test_encode_large_ratio() {
        // sqrt(2^62 / 1) * 2^96 = 2^31 * 2^96 = 2^127
        let result =
            encode_sqrt_ratio_x96(U256::from(1u64) << 62, U256::from(1u64)).unwrap();
        assert_eq!(result, U256::from(1u64) << 127);
    }

    #[test]
    fn test_encode_amount1_above_u64_uses_mul_div() {
        // amount1 > u64::MAX takes the mul_div branch.
        // sqrt(2^100 / 2^98) * 2^96 = 2 * 2^96 = 2^97
        let result =
            encode_sqrt_ratio_x96(U256::from(1u64) << 100, U256::from(1u64) << 98).unwrap();
        assert_eq!(result, U256::from(1u64) << 97);
    }

    #[test]
    fn test_encode_branches_agree_at_boundary() {
        // u64::MAX is the last numerator on the shift path; 2^64 is the first
        // on the mul_div path. Both 1:1 ratios must encode to 2^96.
        let q96 = U256::from(1u64) << 96;
        let max = U256::from(u64::MAX);
        assert_eq!(encode_sqrt_ratio_x96(max, max).unwrap(), q96);
        let two_64 = U256::from(1u64) << 64;
        assert_eq!(encode_sqrt_ratio_x96(two_64, two_64).unwrap(), q96);
    }

    #[test]
    fn test_encode_rounds_down() {
        // sqrt(1/2) * 2^96 is irrational; the result must be floor(sqrt(2^191)).
        let r = encode_sqrt_ratio_x96(U256::from(1u64), U256::from(2u64)).unwrap();
        let ratio_x192 = U256::from(1u64) << 191;
        assert!(r * r <= ratio_x192);
        let r1 = r + U256::from(1u64);
        assert!(r1 * r1 > ratio_x192);
    }

    #[test]
    fn test_sqrt_u256() {
        assert_eq!(sqrt_u256(U256::ZERO), U256::ZERO);
        assert_eq!(sqrt_u256(U256::from(1u64)), U256::from(1u64));
        assert_eq!(sqrt_u256(U256::from(2u64)), U256::from(1u64));
        assert_eq!(sqrt_u256(U256::from(3u64)), U256::from(1u64));
        assert_eq!(sqrt_u256(U256::from(4u64)), U256::from(2u64));
        assert_eq!(sqrt_u256(U256::from(9u64)), U256::from(3u64));
        assert_eq!(sqrt_u256(U256::from(15u64)), U256::from(3u64));
        assert_eq!(sqrt_u256(U256::from(16u64)), U256::from(4u64));
        assert_eq!(sqrt_u256(U256::from(100u64)), U256::from(10u64));
        // Non-perfect square: floor(sqrt(8)) = 2
        assert_eq!(sqrt_u256(U256::from(8u64)), U256::from(2u64));
    }

    #[test]
    fn test_sqrt_u256_is_floor_for_small_values() {
        // r = floor(sqrt(x)) must satisfy r^2 <= x < (r + 1)^2.
        for v in 0u64..2000 {
            let x = U256::from(v);
            let r = sqrt_u256(x);
            assert!(r * r <= x, "r^2 > x for x = {v}");
            let r1 = r + U256::from(1u64);
            assert!(r1 * r1 > x, "(r+1)^2 <= x for x = {v}");
        }
    }

    #[test]
    fn test_sqrt_u256_max() {
        // floor(sqrt(2^256 - 1)) = 2^128 - 1; exercises the largest initial guess.
        let expected = (U256::from(1u64) << 128) - U256::from(1u64);
        assert_eq!(sqrt_u256(U256::MAX), expected);
    }
}