evm_dex_pool/v3/
full_math.rs1use super::{Q96, THREE, TWO};
2use alloy::primitives::{uint, Uint, U256};
3use anyhow::{anyhow, Result};
4
5const ONE: U256 = uint!(1_U256);
6
7pub trait FullMath {
9 fn mul_div(self, b: U256, denominator: U256) -> Result<U256>;
10 fn mul_div_rounding_up(self, b: U256, denominator: U256) -> Result<U256>;
11 fn mul_div_q96(self, b: U256) -> Result<U256>;
12}
13
14impl<const BITS: usize, const LIMBS: usize> FullMath for Uint<BITS, LIMBS> {
15 #[inline]
16 fn mul_div(self, b: U256, denominator: U256) -> Result<U256> {
17 mul_div(U256::from(self), b, denominator)
18 }
19
20 #[inline]
21 fn mul_div_rounding_up(self, b: U256, denominator: U256) -> Result<U256> {
22 mul_div_rounding_up(U256::from(self), b, denominator)
23 }
24
25 #[inline]
26 fn mul_div_q96(self, b: U256) -> Result<U256> {
27 mul_div_q96(U256::from(self), b)
28 }
29}
30
31#[inline]
34pub fn mul_div(a: U256, b: U256, mut denominator: U256) -> Result<U256> {
35 let mm = a.mul_mod(b, U256::MAX);
36 let mut prod_0 = a * b;
37 let mut prod_1 = mm - prod_0 - U256::from_limbs([(mm < prod_0) as u64, 0, 0, 0]);
38
39 if denominator <= prod_1 {
40 return Err(anyhow!("MulDiv overflow"));
41 }
42
43 if prod_1.is_zero() {
44 return Ok(prod_0 / denominator);
45 }
46
47 let remainder = a.mul_mod(b, denominator);
48 prod_1 -= U256::from_limbs([(remainder > prod_0) as u64, 0, 0, 0]);
49 prod_0 -= remainder;
50
51 let mut twos = (-denominator) & denominator;
52 denominator /= twos;
53 prod_0 /= twos;
54 twos = (-twos) / twos + ONE;
55 prod_0 |= prod_1 * twos;
56
57 let mut inv = (THREE * denominator) ^ TWO;
58 inv *= TWO - denominator * inv; inv *= TWO - denominator * inv; inv *= TWO - denominator * inv; inv *= TWO - denominator * inv; inv *= TWO - denominator * inv; inv *= TWO - denominator * inv; Ok(prod_0 * inv)
66}
67
68#[inline]
71pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> Result<U256> {
72 let result = mul_div(a, b, denominator)?;
73
74 if a.mul_mod(b, denominator).is_zero() {
75 Ok(result)
76 } else if result == U256::MAX {
77 Err(anyhow!("MulDiv overflow"))
78 } else {
79 Ok(result + ONE)
80 }
81}
82
83#[inline]
85pub fn mul_div_q96(a: U256, b: U256) -> Result<U256> {
86 let prod0 = a * b;
87 let mm = a.mul_mod(b, U256::MAX);
88 let prod1 = mm - prod0 - U256::from_limbs([(mm < prod0) as u64, 0, 0, 0]);
89 if prod1 >= Q96 {
90 return Err(anyhow!("MulDiv overflow"));
91 }
92 Ok((prod0 >> 96) | (prod1 << 160))
93}