Skip to main content

evm_dex_pool/lb/
math.rs

1//! TraderJoe Liquidity Book math library.
2//!
3//! Ports 128.128 fixed-point arithmetic from the Solidity LB v2 contracts:
4//! - `Uint128x128Math.pow()` — binary exponentiation
5//! - `PriceHelper.getPriceFromId()` / `getBase()` — bin price calculations
6//! - `FeeHelper` — fee calculations
7//! - `PackedUint128Math.decode()` — bytes32 unpacking
8
9use alloy::primitives::{FixedBytes, Uint, U256};
10
11/// 512-bit unsigned integer for intermediate calculations.
12type U512 = Uint<512, 8>;
13
14// ─── Constants ───────────────────────────────────────────────────────────────
15
16/// 128.128 fixed-point scale offset
17pub const SCALE_OFFSET: u32 = 128;
18
19/// 1 << 128 (the "1.0" in 128.128 fixed-point)
20pub const SCALE: U256 = U256::from_limbs([0, 0, 1, 0]);
21
22/// 1e18 — precision for fee calculations
23pub const PRECISION: u128 = 1_000_000_000_000_000_000;
24
/// 1e36 — squared precision, equal to `PRECISION * PRECISION`.
///
/// Previously defined as `0`, which contradicted this documentation and would
/// panic if ever used as a divisor. 1e36 fits comfortably in a `u128`
/// (max ≈ 3.4e38).
pub const SQUARED_PRECISION: u128 = 10u128.pow(36);
27
/// Upper bound on the total fee: 10%, i.e. 0.1 expressed in 1e18 precision.
pub const MAX_FEE: u128 = 10u128.pow(17);

/// Number of basis points in 100% (10_000 bps).
pub const BASIS_POINT_MAX: u128 = 10u128.pow(4);

/// Bin id whose price is exactly 1.0 (2^23 = 8_388_608); bin exponents are
/// measured relative to this id.
pub const REAL_ID_SHIFT: i32 = 2i32.pow(23);
36
37// ─── Price Math ──────────────────────────────────────────────────────────────
38
39/// Calculate base = 1 + binStep/10000 in 128.128 fixed-point.
40///
41/// Port of `PriceHelper.getBase()`:
42/// `SCALE + (uint256(binStep) << SCALE_OFFSET) / BASIS_POINT_MAX`
43pub fn get_base(bin_step: u16) -> U256 {
44    SCALE + (U256::from(bin_step) << SCALE_OFFSET) / U256::from(BASIS_POINT_MAX)
45}
46
47/// Calculate exponent = id - REAL_ID_SHIFT
48pub fn get_exponent(id: u32) -> i32 {
49    id as i32 - REAL_ID_SHIFT
50}
51
52/// Calculate price from bin ID and bin step in 128.128 fixed-point.
53///
54/// Port of `PriceHelper.getPriceFromId()`:
55/// `getBase(binStep).pow(getExponent(id))`
56pub fn get_price_from_id(id: u32, bin_step: u16) -> U256 {
57    let base = get_base(bin_step);
58    let exponent = get_exponent(id);
59    pow_128x128(base, exponent)
60}
61
62/// 128.128 fixed-point exponentiation via binary exponentiation.
63///
64/// Port of `Uint128x128Math.pow()`. Computes `x^y` where x is a 128.128
65/// fixed-point number and y is a signed integer. Uses 20-bit binary
66/// exponentiation with intermediate 512-bit products.
67///
68/// For negative y, computes `1 / x^|y|`.
69pub fn pow_128x128(x: U256, y: i32) -> U256 {
70    if y == 0 {
71        return SCALE;
72    }
73
74    let abs_y = y.unsigned_abs();
75    let mut invert = y < 0;
76
77    // If x > 2^128, work with 1/x and flip inversion
78    let mut squared = x;
79    if squared > SCALE {
80        squared = U256::MAX / squared;
81        invert = !invert;
82    }
83
84    let mut result = SCALE;
85
86    // Binary exponentiation: iterate over 20 bits
87    for bit in 0..20u32 {
88        if abs_y & (1 << bit) != 0 {
89            result = mul_128x128(result, squared);
90        }
91        squared = mul_128x128(squared, squared);
92    }
93
94    if result.is_zero() {
95        // Underflow
96        return U256::ZERO;
97    }
98
99    if invert {
100        U256::MAX / result
101    } else {
102        result
103    }
104}
105
106/// Multiply two 128.128 fixed-point numbers: (a * b) >> 128.
107///
108/// Uses 512-bit intermediate to avoid overflow.
109fn mul_128x128(a: U256, b: U256) -> U256 {
110    let product: U512 = a.widening_mul(b);
111    let shifted = product >> SCALE_OFFSET;
112    let limbs = shifted.as_limbs();
113    U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]])
114}
115
116// ─── Fee Math ────────────────────────────────────────────────────────────────
117
/// Returns the static (base) fee in 1e18 precision.
///
/// Mirrors `PairParameterHelper.getBaseFee()`: `baseFactor * binStep * 1e10`.
/// The largest possible result (`u16::MAX² * 1e10`) fits easily in a `u128`.
pub fn get_base_fee(base_factor: u16, bin_step: u16) -> u128 {
    // The 1e10 multiplier from the Solidity formula.
    const FEE_UNIT: u128 = 10_000_000_000;
    u128::from(base_factor) * u128::from(bin_step) * FEE_UNIT
}
125
/// Returns the volatility-dependent fee in 1e18 precision.
///
/// Mirrors `PairParameterHelper.getVariableFee()`:
/// `ceil((volatilityAccumulator * binStep)^2 * variableFeeControl / 100)`,
/// and is zero whenever `variable_fee_control` is zero.
pub fn get_variable_fee(
    volatility_accumulator: u32,
    bin_step: u16,
    variable_fee_control: u32,
) -> u128 {
    if variable_fee_control == 0 {
        return 0;
    }

    // Overflow bound for the full input domain: prod < 2^32 * 2^16 = 2^48,
    // so prod^2 < 2^96 and prod^2 * vfc < 2^128 (with ample headroom for +99).
    let prod = u128::from(volatility_accumulator) * u128::from(bin_step);
    let scaled = prod * prod * u128::from(variable_fee_control);

    // `(x + 99) / 100` is ceiling division by 100, as in the Solidity source.
    (scaled + 99) / 100
}
144
145/// Calculate total fee (base + variable), capped at MAX_FEE.
146///
147/// Port of `PairParameterHelper.getTotalFee()`
148pub fn get_total_fee(
149    base_factor: u16,
150    bin_step: u16,
151    volatility_accumulator: u32,
152    variable_fee_control: u32,
153) -> u128 {
154    let total = get_base_fee(base_factor, bin_step)
155        + get_variable_fee(volatility_accumulator, bin_step, variable_fee_control);
156    total.min(MAX_FEE)
157}
158
159/// Calculate fee amount from an amount that already includes fees.
160///
161/// Port of `FeeHelper.getFeeAmountFrom()`:
162/// `(amountWithFees * totalFee + PRECISION - 1) / PRECISION`
163///
164/// Rounds up.
165pub fn get_fee_amount_from(amount_with_fees: u128, total_fee: u128) -> u128 {
166    debug_assert!(total_fee <= MAX_FEE, "Fee too large");
167    // Use U256 to avoid u128 overflow: max amount * max fee ~ 2^128 * 1e17
168    let numerator =
169        U256::from(amount_with_fees) * U256::from(total_fee) + U256::from(PRECISION - 1);
170    let result = numerator / U256::from(PRECISION);
171    result.to::<u128>()
172}
173
174/// Calculate fee amount to add to a net amount.
175///
176/// Port of `FeeHelper.getFeeAmount()`:
177/// `(amount * totalFee + (PRECISION - totalFee - 1)) / (PRECISION - totalFee)`
178///
179/// Rounds up.
180pub fn get_fee_amount(amount: u128, total_fee: u128) -> u128 {
181    debug_assert!(total_fee <= MAX_FEE, "Fee too large");
182    let denominator = PRECISION - total_fee;
183    let numerator = U256::from(amount) * U256::from(total_fee) + U256::from(denominator - 1);
184    let result = numerator / U256::from(denominator);
185    result.to::<u128>()
186}
187
188// ─── Packed Amount Helpers ───────────────────────────────────────────────────
189
190/// Decode a packed bytes32 into (amount_x, amount_y).
191///
192/// Port of `PackedUint128Math.decode()`:
193/// Low 128 bits = X, high 128 bits = Y.
194pub fn decode_amounts(packed: FixedBytes<32>) -> (u128, u128) {
195    let bytes = packed.as_slice();
196    // FixedBytes is big-endian. In Solidity, low 128 bits = X.
197    // bytes[0..16] = high 128 bits (Y), bytes[16..32] = low 128 bits (X)
198    let y = u128::from_be_bytes(bytes[0..16].try_into().unwrap());
199    let x = u128::from_be_bytes(bytes[16..32].try_into().unwrap());
200    (x, y)
201}
202
203/// Decode a specific side from packed bytes32.
204/// If `decode_x` is true, returns the low 128 bits (X); otherwise the high 128 bits (Y).
205pub fn decode_amount(packed: FixedBytes<32>, decode_x: bool) -> u128 {
206    let (x, y) = decode_amounts(packed);
207    if decode_x {
208        x
209    } else {
210        y
211    }
212}
213
214// ─── Swap Amount Helpers ─────────────────────────────────────────────────────
215
216/// Multiply then shift right, rounding down: `(x * y) >> offset`
217///
218/// Port of `Uint256x256Math.mulShiftRoundDown()`
219pub fn mul_shift_round_down(x: U256, y: U256, offset: u32) -> U256 {
220    let product: U512 = x.widening_mul(y);
221    let shifted = product >> offset;
222    let limbs = shifted.as_limbs();
223    U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]])
224}
225
226/// Multiply then shift right, rounding up: `ceil((x * y) >> offset)`
227///
228/// Port of `Uint256x256Math.mulShiftRoundUp()`
229pub fn mul_shift_round_up(x: U256, y: U256, offset: u32) -> U256 {
230    let result = mul_shift_round_down(x, y, offset);
231    let product: U512 = x.widening_mul(y);
232    let mask = if offset >= 512 {
233        U512::ZERO
234    } else {
235        (U512::from(1u64) << offset) - U512::from(1u64)
236    };
237    if (product & mask) > U512::ZERO {
238        result + U256::from(1u64)
239    } else {
240        result
241    }
242}
243
244/// Shift left then divide, rounding down: `(x << offset) / y`
245///
246/// Port of `Uint256x256Math.shiftDivRoundDown()`
247pub fn shift_div_round_down(x: U256, offset: u32, denominator: U256) -> U256 {
248    if denominator.is_zero() {
249        return U256::ZERO;
250    }
251    let x_wide = U512::from(x);
252    let shifted = x_wide << offset;
253    let denom_wide = U512::from(denominator);
254    let result = shifted / denom_wide;
255    let limbs = result.as_limbs();
256    U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]])
257}
258
259/// Shift left then divide, rounding up: `ceil((x << offset) / y)`
260///
261/// Port of `Uint256x256Math.shiftDivRoundUp()`
262pub fn shift_div_round_up(x: U256, offset: u32, denominator: U256) -> U256 {
263    if denominator.is_zero() {
264        return U256::ZERO;
265    }
266    let result = shift_div_round_down(x, offset, denominator);
267    let x_wide = U512::from(x);
268    let shifted = x_wide << offset;
269    let denom_wide = U512::from(denominator);
270    if shifted % denom_wide > U512::ZERO {
271        result + U256::from(1u64)
272    } else {
273        result
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scale_constant() {
        assert_eq!(SCALE, U256::from(1u64) << 128);
    }

    #[test]
    fn test_get_base() {
        // binStep = 1 (0.01%): base = 1 + 1/10000 in 128.128
        let base = get_base(1);
        // Should be slightly above SCALE
        assert!(base > SCALE);
        // base - SCALE = SCALE / 10000
        let diff = base - SCALE;
        let expected_diff = SCALE / U256::from(10000u64);
        assert_eq!(diff, expected_diff);
    }

    #[test]
    fn test_get_base_25() {
        // binStep = 25 (0.25%): base = 1 + 25/10000 = 1.0025
        let base = get_base(25);
        let diff = base - SCALE;
        let expected_diff = U256::from(25u64) * SCALE / U256::from(10000u64);
        assert_eq!(diff, expected_diff);
    }

    #[test]
    fn test_get_exponent() {
        // Exponent is the signed distance from the price-1.0 bin.
        let center = REAL_ID_SHIFT as u32;
        assert_eq!(get_exponent(center), 0);
        assert_eq!(get_exponent(center + 5), 5);
        assert_eq!(get_exponent(center - 5), -5);
    }

    #[test]
    fn test_pow_identity() {
        // Any base^0 = SCALE (1.0 in 128.128)
        assert_eq!(pow_128x128(get_base(25), 0), SCALE);
    }

    #[test]
    fn test_pow_one() {
        // base^1 = base
        let base = get_base(25);
        let result = pow_128x128(base, 1);
        assert_eq!(result, base);
    }

    #[test]
    fn test_pow_negative_one() {
        // base^(-1) = 1/base
        let base = get_base(25);
        let result = pow_128x128(base, -1);
        let expected = U256::MAX / base;
        // Allow 1 unit of rounding error
        let diff = if result > expected {
            result - expected
        } else {
            expected - result
        };
        assert!(diff <= U256::from(1u64));
    }

    #[test]
    fn test_price_at_center_bin() {
        // At bin_id = 2^23 (REAL_ID_SHIFT), exponent = 0, price = 1.0 = SCALE
        let price = get_price_from_id(REAL_ID_SHIFT as u32, 25);
        assert_eq!(price, SCALE);
    }

    #[test]
    fn test_price_monotonic() {
        // Higher bin ID → higher price
        let center = REAL_ID_SHIFT as u32;
        let p1 = get_price_from_id(center, 25);
        let p2 = get_price_from_id(center + 1, 25);
        let p3 = get_price_from_id(center + 10, 25);
        assert!(p2 > p1);
        assert!(p3 > p2);

        let p0 = get_price_from_id(center - 1, 25);
        assert!(p0 < p1);
    }

    #[test]
    fn test_base_fee() {
        // baseFactor=15, binStep=25: 15 * 25 * 1e10 = 3.75e12
        let fee = get_base_fee(15, 25);
        assert_eq!(fee, 3_750_000_000_000u128);
    }

    #[test]
    fn test_variable_fee_zero_control() {
        assert_eq!(get_variable_fee(1000, 25, 0), 0);
    }

    #[test]
    fn test_variable_fee_rounds_up() {
        // prod = 1000 * 25 = 25_000; 25_000^2 * 120_000 = 7.5e13; ceil(/100) = 7.5e11
        assert_eq!(get_variable_fee(1000, 25, 120_000), 750_000_000_000u128);
        // 1 * 1 * 1 = 1 → ceil(1 / 100) = 1
        assert_eq!(get_variable_fee(1, 1, 1), 1);
    }

    #[test]
    fn test_total_fee_uncapped() {
        // base 3.75e12 + variable 7.5e11 = 4.5e12, well below MAX_FEE
        assert_eq!(get_total_fee(15, 25, 1000, 120_000), 4_500_000_000_000u128);
    }

    #[test]
    fn test_total_fee_capped() {
        // Very large parameters should be capped at MAX_FEE
        let fee = get_total_fee(u16::MAX, u16::MAX, u32::MAX, u32::MAX);
        assert!(fee <= MAX_FEE);
    }

    #[test]
    fn test_fee_amount_from() {
        // If totalFee = 0.003e18 (0.3%), and amount = 1000e18 with fees,
        // fee should be ~3e18
        let total_fee = 3_000_000_000_000_000u128; // 0.3%
        let amount_with_fees = 1_000_000_000_000_000_000_000u128; // 1000e18
        let fee = get_fee_amount_from(amount_with_fees, total_fee);
        // fee = (1000e18 * 0.003e18 + 1e18 - 1) / 1e18 = 3e18
        assert_eq!(fee, 3_000_000_000_000_000_000u128);
    }

    #[test]
    fn test_fee_amount_from_rounds_up() {
        // Any non-zero fee on a non-zero amount costs at least one unit.
        assert_eq!(get_fee_amount_from(1, 1), 1);
        // A zero amount carries no fee.
        assert_eq!(get_fee_amount_from(0, MAX_FEE), 0);
    }

    #[test]
    fn test_fee_amount() {
        // 9e18 net at a 10% fee: 9e18 * 0.1 / 0.9 = 1e18 exactly.
        assert_eq!(
            get_fee_amount(9_000_000_000_000_000_000u128, MAX_FEE),
            1_000_000_000_000_000_000u128
        );
        // ceil(1 * 0.1e18 / 0.9e18) = 1
        assert_eq!(get_fee_amount(1, MAX_FEE), 1);
        // Zero fee adds nothing.
        assert_eq!(get_fee_amount(1_000, 0), 0);
    }

    #[test]
    fn test_decode_amounts() {
        // Construct a packed bytes32: X=100, Y=200
        let mut bytes = [0u8; 32];
        // Y in high 128 bits (big-endian bytes 0..16)
        bytes[0..16].copy_from_slice(&200u128.to_be_bytes());
        // X in low 128 bits (big-endian bytes 16..32)
        bytes[16..32].copy_from_slice(&100u128.to_be_bytes());
        let packed = FixedBytes::<32>::from(bytes);
        let (x, y) = decode_amounts(packed);
        assert_eq!(x, 100);
        assert_eq!(y, 200);
    }

    #[test]
    fn test_decode_amount_selects_side() {
        let mut bytes = [0u8; 32];
        bytes[0..16].copy_from_slice(&200u128.to_be_bytes());
        bytes[16..32].copy_from_slice(&100u128.to_be_bytes());
        let packed = FixedBytes::<32>::from(bytes);
        assert_eq!(decode_amount(packed, true), 100);
        assert_eq!(decode_amount(packed, false), 200);
    }

    #[test]
    fn test_mul_shift_round_down() {
        // Simple test: (SCALE * SCALE) >> 128 = SCALE (1.0 * 1.0 = 1.0)
        let result = mul_shift_round_down(SCALE, SCALE, SCALE_OFFSET);
        assert_eq!(result, SCALE);
    }

    #[test]
    fn test_mul_shift_rounding() {
        // 3 * 1 / 2 = 1.5 → floor 1, ceil 2
        let one = U256::from(1u64);
        let three = U256::from(3u64);
        assert_eq!(mul_shift_round_down(three, one, 1), one);
        assert_eq!(mul_shift_round_up(three, one, 1), U256::from(2u64));
        // Exact results are not rounded up.
        assert_eq!(mul_shift_round_up(U256::from(4u64), one, 1), U256::from(2u64));
    }

    #[test]
    fn test_shift_div_round_down() {
        // (100 << 128) / SCALE = 100
        let result = shift_div_round_down(U256::from(100u64), SCALE_OFFSET, SCALE);
        assert_eq!(result, U256::from(100u64));
    }

    #[test]
    fn test_shift_div_rounding() {
        // 1 / 3: floor 0, ceil 1
        let one = U256::from(1u64);
        let three = U256::from(3u64);
        assert_eq!(shift_div_round_down(one, 0, three), U256::ZERO);
        assert_eq!(shift_div_round_up(one, 0, three), one);
        // Exact: (3 << 128) / SCALE = 3, no rounding up.
        assert_eq!(shift_div_round_up(three, SCALE_OFFSET, SCALE), three);
        // Zero denominator yields zero rather than panicking.
        assert_eq!(shift_div_round_down(one, SCALE_OFFSET, U256::ZERO), U256::ZERO);
        assert_eq!(shift_div_round_up(one, SCALE_OFFSET, U256::ZERO), U256::ZERO);
    }
}