use alloy_primitives::{U256, U512};
use crate::AmmMathError;
/// Widens a `U256` to a `U512` without changing its numeric value.
///
/// The four source limbs are copied into limb positions 0..4 of the result;
/// limb positions 4..8 are left as zero.
#[inline]
fn to_u512(x: U256) -> U512 {
    let mut widened = [0u64; 8];
    widened[..4].copy_from_slice(x.as_limbs());
    U512::from_limbs(widened)
}
/// Computes `a * b / denominator`, truncating any fractional part.
///
/// Both factors are widened to 512 bits before multiplying, so the
/// intermediate product is never truncated; only the final quotient has to
/// fit back into 256 bits.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `denominator` is zero.
/// * [`AmmMathError::MulDivOverflow`] if the quotient does not fit in a `U256`.
pub fn mul_div(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    let wide_product = to_u512(a) * to_u512(b);
    let wide_quotient = wide_product / to_u512(denominator);

    match u512_to_u256(wide_quotient) {
        Some(quotient) => Ok(quotient),
        None => Err(AmmMathError::MulDivOverflow),
    }
}
/// Computes `a * b / denominator`, rounding the result up when the division
/// is not exact.
///
/// Both factors are widened to 512 bits before multiplying, so the
/// intermediate product is never truncated. The quotient is bumped by one
/// whenever the remainder is non-zero, and the (possibly bumped) value must
/// then fit back into 256 bits.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `denominator` is zero.
/// * [`AmmMathError::MulDivOverflow`] if the rounded-up quotient does not fit
///   in a `U256`. This can happen even when the truncated quotient would fit.
pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> crate::Result<U256> {
    if denominator.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }

    let product = to_u512(a) * to_u512(b);

    // One long division yields both quotient and remainder; computing `/` and
    // `%` separately would run the 512-bit division twice. The divisor is
    // known to be non-zero at this point.
    let (quotient, remainder) = product.div_rem(to_u512(denominator));

    // The increment cannot wrap in 512 bits: a non-zero remainder implies the
    // divisor is at least 2, so the quotient is at most half of the product.
    let rounded = if remainder.is_zero() {
        quotient
    } else {
        quotient + U512::from(1u64)
    };

    u512_to_u256(rounded).ok_or(AmmMathError::MulDivOverflow)
}
/// Narrows a `U512` to a `U256` when no information would be lost.
///
/// Returns `None` if any of limb positions 4..8 is non-zero; otherwise
/// returns a `U256` built from limb positions 0..4.
#[inline]
fn u512_to_u256(x: U512) -> Option<U256> {
    match *x.as_limbs() {
        // All four upper limbs must be exactly zero for the value to fit.
        [w0, w1, w2, w3, 0, 0, 0, 0] => Some(U256::from_limbs([w0, w1, w2, w3])),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when limb positions 4..8 of `v` are all zero, i.e. the
    /// value is representable in 256 bits. Kept independent of
    /// `u512_to_u256` so the property test has its own reference check.
    fn fits_u256(v: U512) -> bool {
        let l = v.as_limbs();
        (l[4] | l[5] | l[6] | l[7]) == 0
    }

    #[test]
    fn test_simple_mul_div() {
        let r = mul_div(U256::from(10u64), U256::from(20u64), U256::from(5u64)).unwrap();
        assert_eq!(r, U256::from(40u64));
    }

    #[test]
    fn test_mul_div_large() {
        // 2^128 * 2^128 = 2^256 does not fit in U256 but must not be lost.
        let big = U256::from(1u64) << 128;
        let r = mul_div(big, big, big).unwrap();
        assert_eq!(r, big);
    }

    #[test]
    fn test_mul_div_max() {
        let r = mul_div(U256::MAX, U256::from(1u64), U256::from(1u64)).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up() {
        // 30 / 7 = 4 remainder 2, so the rounded-up result is 5.
        let r = mul_div_rounding_up(U256::from(10u64), U256::from(3u64), U256::from(7u64))
            .unwrap();
        assert_eq!(r, U256::from(5u64));
    }

    #[test]
    fn test_mul_div_rounding_up_exact() {
        // 20 / 4 is exact, so no increment is applied.
        let r = mul_div_rounding_up(U256::from(10u64), U256::from(2u64), U256::from(4u64))
            .unwrap();
        assert_eq!(r, U256::from(5u64));
    }

    #[test]
    fn test_div_by_zero() {
        let r = mul_div(U256::from(1u64), U256::from(1u64), U256::ZERO);
        // Pin the variant: a bare `is_err()` would also accept an overflow error.
        assert!(matches!(r, Err(AmmMathError::DivisionByZero)));
    }

    #[test]
    fn test_mul_div_q96_style() {
        let q96 = U256::from(1u64) << 96;
        let amount = U256::from(1_000_000u64);
        let sqrt_price = q96;
        let r = mul_div(amount, sqrt_price, q96).unwrap();
        assert_eq!(r, amount);
    }

    #[test]
    fn test_mul_div_100_3_7() {
        let r = mul_div(U256::from(100u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(42u64));
    }

    #[test]
    fn test_mul_div_rounding_up_100_3_7() {
        let r =
            mul_div_rounding_up(U256::from(100u64), U256::from(3u64), U256::from(7u64)).unwrap();
        assert_eq!(r, U256::from(43u64));
    }

    #[test]
    fn test_mul_div_max_max_max() {
        let r = mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(r, U256::MAX);
    }

    #[test]
    fn test_mul_div_rounding_up_div_by_zero() {
        let r = mul_div_rounding_up(U256::from(1u64), U256::from(1u64), U256::ZERO);
        assert!(matches!(r, Err(AmmMathError::DivisionByZero)));
    }

    #[test]
    fn test_mul_div_overflow() {
        // MAX * MAX / 1 is far larger than U256::MAX.
        let one = U256::from(1u64);
        assert!(matches!(
            mul_div(U256::MAX, U256::MAX, one),
            Err(AmmMathError::MulDivOverflow)
        ));
        assert!(matches!(
            mul_div_rounding_up(U256::MAX, U256::MAX, one),
            Err(AmmMathError::MulDivOverflow)
        ));
    }

    #[test]
    fn test_rounding_up_overflows_when_floor_is_max() {
        // With M = U256::MAX: (M - 1)^2 = M * (M - 2) + 1, so the truncated
        // quotient is exactly M (remainder 1) and rounding up needs M + 1.
        let a = U256::MAX - U256::from(1u64);
        let d = U256::MAX - U256::from(2u64);
        assert_eq!(mul_div(a, a, d).unwrap(), U256::MAX);
        assert!(matches!(
            mul_div_rounding_up(a, a, d),
            Err(AmmMathError::MulDivOverflow)
        ));
    }

    #[test]
    fn prop_mul_div_matches_u512_definition() {
        use rand::{rngs::StdRng, RngCore, SeedableRng};

        fn rand_u256(rng: &mut impl RngCore) -> U256 {
            let mut b = [0u8; 32];
            rng.fill_bytes(&mut b);
            U256::from_be_bytes(b)
        }

        // Fixed seed keeps the run reproducible.
        let mut rng = StdRng::seed_from_u64(0xC0FFEE);
        for _ in 0..100_000 {
            let a = rand_u256(&mut rng);
            let b = rand_u256(&mut rng);
            let d = rand_u256(&mut rng);
            if d.is_zero() {
                assert!(mul_div(a, b, d).is_err());
                assert!(mul_div_rounding_up(a, b, d).is_err());
                continue;
            }

            // Reference values computed directly in 512-bit arithmetic.
            let prod = to_u512(a) * to_u512(b);
            let d512 = to_u512(d);
            let q512 = prod / d512;
            let rem = prod % d512;

            let fits = fits_u256(q512);
            match mul_div(a, b, d) {
                Ok(q) => {
                    assert!(fits, "mul_div returned Ok but q overflows U256");
                    // q is the floor quotient iff q*d <= prod < (q+1)*d.
                    let qd = to_u512(q) * d512;
                    assert!(qd <= prod, "q*d > a*b");
                    assert!(prod - qd < d512, "remainder >= d (q too small)");
                }
                Err(_) => assert!(!fits, "mul_div errored but q fits U256"),
            }

            let up512 = if rem == U512::ZERO { q512 } else { q512 + U512::from(1u64) };
            let up_fits = fits_u256(up512);
            match mul_div_rounding_up(a, b, d) {
                Ok(up) => {
                    assert!(up_fits, "rounding_up returned Ok but overflows U256");
                    assert_eq!(to_u512(up), up512, "round-up value mismatch");
                }
                Err(_) => assert!(!up_fits, "rounding_up errored but value fits U256"),
            }
        }
    }
}