use alloy_primitives::{I256, U160, U256};
/// The constant `2^128` as a [`U256`].
///
/// Limbs are 64-bit words in little-endian order, so a `1` in limb index 2 is `2^(2 * 64)`.
pub const Q128: U256 = U256::from_limbs([0, 0, 1, 0]);
/// The constant `2^96` as a [`U160`].
///
/// Bit 32 of limb index 1 (little-endian 64-bit limbs) is `2^(64 + 32)`.
pub const Q96_U160: U160 = U160::from_limbs([0, 0x0000_0001_0000_0000, 0]);
/// Stateless unit struct used as a namespace for 256-bit integer helpers: multiply-then-divide
/// with a 512-bit intermediate product, rounding-up division, integer square root, and bit
/// truncation/reinterpretation between `U256`, `I256` and `u128`.
#[derive(Debug)]
pub struct FullMath;
impl FullMath {
    /// Computes `floor(a * b / denominator)` with a full 512-bit intermediate product.
    ///
    /// The result is exact even when `a * b` itself does not fit in 256 bits ("phantom
    /// overflow"); only the final quotient must fit. This is the 512-bit `mulDiv` algorithm
    /// popularised by Uniswap V3's `FullMath` library.
    ///
    /// # Errors
    ///
    /// Returns an error if `denominator` is zero, or if the quotient does not fit in 256 bits.
    pub fn mul_div(a: U256, b: U256, mut denominator: U256) -> anyhow::Result<U256> {
        if denominator.is_zero() {
            anyhow::bail!("Cannot divide by zero");
        }

        // Build the 512-bit product `prod_1 * 2^256 + prod_0` via the Chinese Remainder
        // Theorem: `mm` is the product mod (2^256 - 1), `prod_0` the wrapping product mod 2^256.
        let mm = a.mul_mod(b, U256::MAX);
        let mut prod_0 = a * b;
        let mut prod_1 = mm - prod_0 - U256::from_limbs([(mm < prod_0) as u64, 0, 0, 0]);

        // The quotient fits in 256 bits iff the high word is strictly below the denominator.
        if denominator <= prod_1 {
            anyhow::bail!("Result would overflow 256 bits");
        }

        // Fast path: the product fits in 256 bits, so ordinary division is already exact.
        if prod_1.is_zero() {
            return Ok(prod_0 / denominator);
        }

        // Subtract the remainder so the 512-bit numerator becomes an exact multiple of
        // `denominator` (borrowing from the high word when the low word underflows).
        let remainder = a.mul_mod(b, denominator);
        prod_1 -= U256::from_limbs([(remainder > prod_0) as u64, 0, 0, 0]);
        prod_0 -= remainder;

        // Factor out the largest power of two dividing `denominator` (its lowest set bit).
        let mut twos = (-denominator) & denominator;
        denominator /= twos;
        prod_0 /= twos;

        // Shift the high word's bits into the low word by multiplying with 2^256 / twos.
        // When twos == 1 this wraps to 0, which is correct: only the result mod 2^256 matters.
        twos = (-twos) / twos + U256::from(1);
        prod_0 |= prod_1 * twos;

        // `denominator` is now odd and therefore invertible mod 2^256. The seed is correct to
        // 4 bits; each Newton-Raphson step doubles that: 8, 16, 32, 64, 128, 256 bits.
        let mut inv = (U256::from(3) * denominator) ^ U256::from(2);
        for _ in 0..6 {
            inv *= U256::from(2) - denominator * inv;
        }

        // The division is exact and the quotient is known to be < 2^256, so multiplying the
        // low word by the modular inverse yields the quotient directly.
        Ok(prod_0 * inv)
    }

    /// Computes `ceil(a * b / denominator)` with a full 512-bit intermediate product.
    ///
    /// # Errors
    ///
    /// Returns an error if `denominator` is zero, or if the rounded-up quotient does not fit in
    /// 256 bits.
    pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> anyhow::Result<U256> {
        let result = Self::mul_div(a, b, denominator)?;
        // `mul_div` has already rejected a zero denominator, so this `mul_mod` is well-defined.
        if a.mul_mod(b, denominator).is_zero() {
            return Ok(result);
        }
        result
            .checked_add(U256::from(1))
            .ok_or_else(|| anyhow::anyhow!("Result would overflow 256 bits"))
    }

    /// Computes `ceil(a / b)`.
    ///
    /// # Errors
    ///
    /// Returns an error if `b` is zero.
    pub fn div_rounding_up(a: U256, b: U256) -> anyhow::Result<U256> {
        if b.is_zero() {
            anyhow::bail!("Cannot divide by zero");
        }
        let quotient = a / b;
        // Rounding up can never overflow: a nonzero remainder implies `b >= 2`, and then
        // `quotient <= U256::MAX / 2`.
        if (a % b).is_zero() {
            Ok(quotient)
        } else {
            Ok(quotient + U256::from(1))
        }
    }

    /// Returns the integer square root `floor(sqrt(x))`, computed with Newton's method.
    #[must_use]
    pub fn sqrt(x: U256) -> U256 {
        if x < U256::from(2) {
            // 0 and 1 are their own square roots.
            return x;
        }
        // Initial guess ceil(x / 2), computed without `x + 1` so that `x == U256::MAX` cannot
        // wrap to zero (which would lead to a division by zero inside the loop).
        let mut z = x;
        let mut y = (x >> 1) + (x & U256::from(1));
        while y < z {
            z = y;
            y = (x / z + z) >> 1;
        }
        z
    }

    /// Returns the low 128 bits of `value`, discarding the upper 128 bits.
    #[must_use]
    pub fn truncate_to_u128(value: U256) -> u128 {
        (value & U256::from(u128::MAX)).to::<u128>()
    }

    /// Reinterprets the two's-complement bits of `value` as an unsigned integer
    /// (for example, `-1` becomes `U256::MAX`).
    #[must_use]
    pub fn truncate_to_u256(value: I256) -> U256 {
        value.into_raw()
    }

    /// Reinterprets the bits of `value` as a two's-complement signed integer
    /// (values with the top bit set become negative).
    #[must_use]
    pub fn truncate_to_i256(value: U256) -> I256 {
        I256::from_raw(value)
    }
}
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    #[rstest]
    fn test_mul_div_reverts_denominator_zero() {
        assert!(FullMath::mul_div(Q128, U256::from(5), U256::ZERO).is_err());
        assert!(FullMath::mul_div(Q128, Q128, U256::ZERO).is_err());
    }

    #[rstest]
    fn test_mul_div_reverts_output_overflow() {
        assert!(FullMath::mul_div(Q128, Q128, U256::from(1)).is_err());
        assert!(FullMath::mul_div(U256::MAX, U256::MAX, U256::from(1)).is_err());
        assert!(FullMath::mul_div(U256::MAX, U256::MAX, U256::from(2)).is_err());
        assert!(FullMath::mul_div(U256::MAX, U256::MAX, U256::MAX - U256::from(1)).is_err());
    }

    #[rstest]
    fn test_mul_div_all_max_inputs() {
        let result = FullMath::mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(result, U256::MAX);
    }

    #[rstest]
    fn test_mul_div_accurate_without_phantom_overflow() {
        // 2^128 * (0.5 * 2^128) / (1.5 * 2^128) = 2^128 / 3.
        let numerator_b = Q128 * U256::from(50) / U256::from(100);
        let denominator = Q128 * U256::from(150) / U256::from(100);
        let expected = Q128 / U256::from(3);
        let result = FullMath::mul_div(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_accurate_with_phantom_overflow() {
        let numerator_b = U256::from(35) * Q128;
        let denominator = U256::from(8) * Q128;
        let expected = U256::from(4375) * Q128 / U256::from(1000);
        let result = FullMath::mul_div(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_accurate_with_phantom_overflow_repeating_decimal() {
        let numerator_b = U256::from(1000) * Q128;
        let denominator = U256::from(3000) * Q128;
        let expected = Q128 / U256::from(3);
        let result = FullMath::mul_div(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_power_of_two_denominator() {
        // 2^128 * 2^128 = 2^256 overflows the low word entirely; dividing by 2^128 exercises
        // the power-of-two factoring and high-word shift.
        assert_eq!(FullMath::mul_div(Q128, Q128, Q128).unwrap(), Q128);
    }

    #[rstest]
    fn test_mul_div_basic_cases() {
        assert_eq!(
            FullMath::mul_div(U256::from(100), U256::from(200), U256::from(50)).unwrap(),
            U256::from(400)
        );
        assert_eq!(
            FullMath::mul_div(U256::from(1000), U256::from(1), U256::from(4)).unwrap(),
            U256::from(250)
        );
        assert_eq!(
            FullMath::mul_div(U256::from(1), U256::from(1), U256::from(3)).unwrap(),
            U256::ZERO
        );
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_denominator_zero() {
        assert!(FullMath::mul_div_rounding_up(Q128, U256::from(5), U256::ZERO).is_err());
        assert!(FullMath::mul_div_rounding_up(Q128, Q128, U256::ZERO).is_err());
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_output_overflow() {
        assert!(FullMath::mul_div_rounding_up(Q128, Q128, U256::from(1)).is_err());
        assert!(FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::from(2)).is_err());
        assert!(
            FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX - U256::from(1)).is_err()
        );
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_overflow_after_rounding_case_1() {
        let a = U256::from_str_radix("535006138814359", 10).unwrap();
        let b = U256::from_str_radix(
            "432862656469423142931042426214547535783388063929571229938474969",
            10,
        )
        .unwrap();
        let denominator = U256::from(2);
        assert!(FullMath::mul_div_rounding_up(a, b, denominator).is_err());
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_overflow_after_rounding_case_2() {
        let a = U256::from_str_radix(
            "115792089237316195423570985008687907853269984659341747863450311749907997002549",
            10,
        )
        .unwrap();
        let b = U256::from_str_radix(
            "115792089237316195423570985008687907853269984659341747863450311749907997002550",
            10,
        )
        .unwrap();
        let denominator = U256::from_str_radix(
            "115792089237316195423570985008687907853269984653042931687443039491902864365164",
            10,
        )
        .unwrap();
        assert!(FullMath::mul_div_rounding_up(a, b, denominator).is_err());
    }

    #[rstest]
    fn test_mul_div_rounding_up_all_max_inputs() {
        let result = FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(result, U256::MAX);
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_without_phantom_overflow() {
        // 2^128 * (0.5 * 2^128) / (1.5 * 2^128) = 2^128 / 3, rounded up.
        let numerator_b = Q128 * U256::from(50) / U256::from(100);
        let denominator = Q128 * U256::from(150) / U256::from(100);
        let expected = Q128 / U256::from(3) + U256::from(1);
        let result = FullMath::mul_div_rounding_up(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_with_phantom_overflow() {
        let numerator_b = U256::from(35) * Q128;
        let denominator = U256::from(8) * Q128;
        let expected = U256::from(4375) * Q128 / U256::from(1000);
        let result = FullMath::mul_div_rounding_up(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_with_phantom_overflow_repeating_decimal() {
        let numerator_b = U256::from(1000) * Q128;
        let denominator = U256::from(3000) * Q128;
        let expected = Q128 / U256::from(3) + U256::from(1);
        let result = FullMath::mul_div_rounding_up(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_rounding_up_basic_cases() {
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::from(100), U256::from(200), U256::from(50))
                .unwrap(),
            U256::from(400)
        );
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::from(1), U256::from(1), U256::from(3)).unwrap(),
            U256::from(1)
        );
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::from(7), U256::from(3), U256::from(4)).unwrap(),
            U256::from(6)
        );
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::ZERO, U256::from(100), U256::from(3)).unwrap(),
            U256::ZERO
        );
    }

    #[rstest]
    fn test_mul_div_rounding_up_overflow_at_max() {
        assert!(FullMath::mul_div_rounding_up(U256::MAX, U256::from(2), U256::from(2)).is_ok());
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::MAX, U256::from(1), U256::from(1)).unwrap(),
            U256::MAX
        );
    }

    #[rstest]
    fn test_div_rounding_up() {
        assert!(FullMath::div_rounding_up(U256::from(1), U256::ZERO).is_err());

        let up = |a: u64, b: u64| FullMath::div_rounding_up(U256::from(a), U256::from(b)).unwrap();
        assert_eq!(up(0, 3), U256::ZERO);
        assert_eq!(up(8, 2), U256::from(4));
        assert_eq!(up(7, 2), U256::from(4));
        assert_eq!(up(1, 3), U256::from(1));

        assert_eq!(
            FullMath::div_rounding_up(U256::MAX, U256::from(1)).unwrap(),
            U256::MAX
        );
        assert_eq!(
            FullMath::div_rounding_up(U256::MAX, U256::from(2)).unwrap(),
            U256::from(1) << 255
        );
    }

    #[rstest]
    fn test_sqrt_small_values() {
        let cases = [(0u64, 0u64), (1, 1), (2, 1), (3, 1), (4, 2), (15, 3), (16, 4), (17, 4)];
        for (x, expected) in cases {
            assert_eq!(FullMath::sqrt(U256::from(x)), U256::from(expected));
        }
        assert_eq!(
            FullMath::sqrt(U256::from(10u64.pow(18))),
            U256::from(10u64.pow(9))
        );
    }

    #[rstest]
    fn test_sqrt_near_max() {
        // floor(sqrt(2^256 - 2)) = 2^128 - 1.
        assert_eq!(
            FullMath::sqrt(U256::MAX - U256::from(1)),
            U256::from(u128::MAX)
        );
    }

    #[rstest]
    fn test_truncate_to_u128_preserves_small_values() {
        let value = U256::from(12345u128);
        assert_eq!(FullMath::truncate_to_u128(value), 12345u128);
        let max_value = U256::from(u128::MAX);
        assert_eq!(FullMath::truncate_to_u128(max_value), u128::MAX);
    }

    #[rstest]
    fn test_truncate_to_u128_discards_upper_bits() {
        let value = U256::from(u128::MAX) + U256::from(1);
        assert_eq!(FullMath::truncate_to_u128(value), 0);
        let value = (U256::from(u128::MAX) << 128) | U256::from(0x1234u128);
        assert_eq!(FullMath::truncate_to_u128(value), 0x1234u128);
    }

    #[rstest]
    fn test_truncate_signed_round_trip() {
        for value in [U256::ZERO, U256::from(1), U256::MAX, U256::from(1) << 255] {
            assert_eq!(
                FullMath::truncate_to_u256(FullMath::truncate_to_i256(value)),
                value
            );
        }
        assert!(FullMath::truncate_to_i256(U256::MAX).is_negative());
        assert!(!FullMath::truncate_to_i256(U256::from(1)).is_negative());
    }
}