//! Full-precision multiply/divide helpers (evm_dex_pool/v3/full_math.rs).

1use super::{Q96, THREE, TWO};
2use alloy::primitives::{uint, Uint, U256};
3use anyhow::{anyhow, Result};
4
/// The value `1` as a [`U256`], used for the `+ 1` adjustments in this module.
const ONE: U256 = uint!(1_U256);
6
7/// Full precision arithmetic operations for [`Uint`] types.
8pub trait FullMath {
9    fn mul_div(self, b: U256, denominator: U256) -> Result<U256>;
10    fn mul_div_rounding_up(self, b: U256, denominator: U256) -> Result<U256>;
11    fn mul_div_q96(self, b: U256) -> Result<U256>;
12}
13
14impl<const BITS: usize, const LIMBS: usize> FullMath for Uint<BITS, LIMBS> {
15    #[inline]
16    fn mul_div(self, b: U256, denominator: U256) -> Result<U256> {
17        mul_div(U256::from(self), b, denominator)
18    }
19
20    #[inline]
21    fn mul_div_rounding_up(self, b: U256, denominator: U256) -> Result<U256> {
22        mul_div_rounding_up(U256::from(self), b, denominator)
23    }
24
25    #[inline]
26    fn mul_div_q96(self, b: U256) -> Result<U256> {
27        mul_div_q96(U256::from(self), b)
28    }
29}
30
31/// Calculates floor(a*b/denominator) with full precision. Throws if result overflows a uint256 or
32/// denominator == 0
33#[inline]
34pub fn mul_div(a: U256, b: U256, mut denominator: U256) -> Result<U256> {
35    let mm = a.mul_mod(b, U256::MAX);
36    let mut prod_0 = a * b;
37    let mut prod_1 = mm - prod_0 - U256::from_limbs([(mm < prod_0) as u64, 0, 0, 0]);
38
39    if denominator <= prod_1 {
40        return Err(anyhow!("MulDiv overflow"));
41    }
42
43    if prod_1.is_zero() {
44        return Ok(prod_0 / denominator);
45    }
46
47    let remainder = a.mul_mod(b, denominator);
48    prod_1 -= U256::from_limbs([(remainder > prod_0) as u64, 0, 0, 0]);
49    prod_0 -= remainder;
50
51    let mut twos = (-denominator) & denominator;
52    denominator /= twos;
53    prod_0 /= twos;
54    twos = (-twos) / twos + ONE;
55    prod_0 |= prod_1 * twos;
56
57    let mut inv = (THREE * denominator) ^ TWO;
58    inv *= TWO - denominator * inv; // inverse mod 2**8
59    inv *= TWO - denominator * inv; // inverse mod 2**16
60    inv *= TWO - denominator * inv; // inverse mod 2**32
61    inv *= TWO - denominator * inv; // inverse mod 2**64
62    inv *= TWO - denominator * inv; // inverse mod 2**128
63    inv *= TWO - denominator * inv; // inverse mod 2**256
64
65    Ok(prod_0 * inv)
66}
67
68/// Calculates ceil(a*b/denominator) with full precision. Throws if result overflows a uint256 or
69/// denominator == 0
70#[inline]
71pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> Result<U256> {
72    let result = mul_div(a, b, denominator)?;
73
74    if a.mul_mod(b, denominator).is_zero() {
75        Ok(result)
76    } else if result == U256::MAX {
77        Err(anyhow!("MulDiv overflow"))
78    } else {
79        Ok(result + ONE)
80    }
81}
82
83/// Calculates a * b / 2^96 with full precision.
84#[inline]
85pub fn mul_div_q96(a: U256, b: U256) -> Result<U256> {
86    let prod0 = a * b;
87    let mm = a.mul_mod(b, U256::MAX);
88    let prod1 = mm - prod0 - U256::from_limbs([(mm < prod0) as u64, 0, 0, 0]);
89    if prod1 >= Q96 {
90        return Err(anyhow!("MulDiv overflow"));
91    }
92    Ok((prod0 >> 96) | (prod1 << 160))
93}