wp-evm-amm-math 0.1.2

Native Rust CLMM/AMM math (Uniswap V3 compatible, zero SDK deps)
Documentation
//! Single-step swap math.
//!
//! Direct port of Uniswap V3 `SwapMath.sol::computeSwapStep`. Used by the
//! swap loop to compute, for one segment of a swap (between two adjacent
//! initialized ticks or between current price and target price), how much
//! of the input is consumed, how much output is produced, the resulting
//! sqrt price, and the fee charged.
//!
//! Native replacement for the path that previously went through
//! `uniswap-v3-sdk`'s `Pool::get_output_amount_mut` /
//! `get_input_amount_mut`.

use alloy_primitives::{I256, U256};

use crate::{
    full_math::{mul_div, mul_div_rounding_up},
    sqrt_price_math::{
        get_amount_0_delta, get_amount_1_delta, get_next_sqrt_price_from_input,
        get_next_sqrt_price_from_output,
    },
    AmmMathError,
};

/// Result of a single swap step, as produced by [`compute_swap_step`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SwapStep {
    /// New sqrt price (X96 fixed point) after this step.
    ///
    /// Equals `sqrt_ratio_target_x96` when the step reached the target.
    /// Otherwise it lies between the starting price and the target: it never
    /// equals the target, but it may equal the starting price when the
    /// remaining amount is too small to move the price at all.
    pub sqrt_ratio_next_x96: U256,
    /// Input-token amount consumed by this step, excluding the fee.
    pub amount_in: U256,
    /// Output-token amount produced by this step.
    pub amount_out: U256,
    /// Fee charged for this step, denominated in the input token.
    pub fee_amount: U256,
}

/// Fee denominator: `fee_pips` are expressed in millionths (1e-6), so a fee of
/// `FEE_DENOMINATOR` pips would be 100%.
const FEE_DENOMINATOR: u32 = 10_u32.pow(6);

/// Solidity: `SwapMath.computeSwapStep`.
///
/// Compute one step of a swap. Consumes at most `amount_remaining` of the
/// input token (or produces at most `|amount_remaining|` of the output
/// token if `amount_remaining` is negative — the V3 sign convention for
/// exact-out swaps).
///
/// `fee_pips` is the fee in hundredths of a basis point (1e-6), e.g. `3000`
/// for 0.30%. Must be strictly less than `1_000_000` (100%).
pub fn compute_swap_step(
    sqrt_ratio_current_x96: U256,
    sqrt_ratio_target_x96: U256,
    liquidity: u128,
    amount_remaining: I256,
    fee_pips: u32,
) -> crate::Result<SwapStep> {
    if fee_pips >= FEE_DENOMINATOR {
        return Err(AmmMathError::InvalidFeePips);
    }

    let zero_for_one = sqrt_ratio_current_x96 >= sqrt_ratio_target_x96;
    let exact_in = !amount_remaining.is_negative();

    let mut amount_in = U256::ZERO;
    let mut amount_out = U256::ZERO;
    let sqrt_ratio_next_x96: U256;

    if exact_in {
        // amount_remaining is non-negative; reinterpret as U256 safely via raw().
        let amount_remaining_u: U256 = amount_remaining.into_raw();
        // amountRemainingLessFee = amountRemaining * (1e6 - feePips) / 1e6
        let amount_remaining_less_fee = mul_div(
            amount_remaining_u,
            U256::from(FEE_DENOMINATOR - fee_pips),
            U256::from(FEE_DENOMINATOR),
        )?;

        // amountIn = full delta from current to target (rounded UP — we may
        // overpay slightly so we never overshoot the target).
        amount_in = if zero_for_one {
            get_amount_0_delta(sqrt_ratio_target_x96, sqrt_ratio_current_x96, liquidity, true)?
        } else {
            get_amount_1_delta(sqrt_ratio_current_x96, sqrt_ratio_target_x96, liquidity, true)?
        };

        sqrt_ratio_next_x96 = if amount_remaining_less_fee >= amount_in {
            // We have enough input to cross the entire target — snap to it.
            sqrt_ratio_target_x96
        } else {
            // Partial step: solve for the next price.
            get_next_sqrt_price_from_input(
                sqrt_ratio_current_x96,
                liquidity,
                amount_remaining_less_fee,
                zero_for_one,
            )?
        };
    } else {
        // amount_remaining is negative — the magnitude is the desired output.
        // I256::unsigned_abs is the right primitive here.
        let amount_remaining_abs: U256 = amount_remaining.unsigned_abs();

        amount_out = if zero_for_one {
            get_amount_1_delta(sqrt_ratio_target_x96, sqrt_ratio_current_x96, liquidity, false)?
        } else {
            get_amount_0_delta(sqrt_ratio_current_x96, sqrt_ratio_target_x96, liquidity, false)?
        };

        sqrt_ratio_next_x96 = if amount_remaining_abs >= amount_out {
            sqrt_ratio_target_x96
        } else {
            get_next_sqrt_price_from_output(
                sqrt_ratio_current_x96,
                liquidity,
                amount_remaining_abs,
                zero_for_one,
            )?
        };
    }

    let max_reached = sqrt_ratio_target_x96 == sqrt_ratio_next_x96;

    // Recompute amount_in / amount_out using the *actual* sqrt ratio reached.
    // The Solidity code does: if we hit max and we know one side already
    // (because we set it above), keep that side; otherwise recompute from the
    // (possibly partial) price step.
    if zero_for_one {
        amount_in = if max_reached && exact_in {
            amount_in
        } else {
            get_amount_0_delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, true)?
        };
        amount_out = if max_reached && !exact_in {
            amount_out
        } else {
            get_amount_1_delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, false)?
        };
    } else {
        amount_in = if max_reached && exact_in {
            amount_in
        } else {
            get_amount_1_delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, true)?
        };
        amount_out = if max_reached && !exact_in {
            amount_out
        } else {
            get_amount_0_delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, false)?
        };
    }

    // Cap output at the requested amount in exact-out mode. This prevents the
    // rounding-up on `get_amount_*_delta(round_up=false)` (yes, false) from
    // ever producing more than the user asked for. Mirrors Solidity exactly.
    if !exact_in {
        let amount_remaining_abs: U256 = amount_remaining.unsigned_abs();
        if amount_out > amount_remaining_abs {
            amount_out = amount_remaining_abs;
        }
    }

    let fee_amount = if exact_in && sqrt_ratio_next_x96 != sqrt_ratio_target_x96 {
        // Partial step on exact-in: the leftover input becomes fee.
        let amount_remaining_u: U256 = amount_remaining.into_raw();
        // amount_remaining_u >= amount_in is invariant from the branch above
        // (we only entered the "partial" path when amount_remaining_less_fee
        // < amount_in, and amount_remaining_less_fee <= amount_remaining_u).
        // Use checked_sub to surface any future invariant break.
        amount_remaining_u.checked_sub(amount_in).ok_or(AmmMathError::PriceUnderflow)?
    } else {
        // Reached the target (or exact-out): fee is computed from amount_in.
        // Note feePips < 1e6 by the guard at function entry, so denominator > 0.
        mul_div_rounding_up(
            amount_in,
            U256::from(fee_pips),
            U256::from(FEE_DENOMINATOR - fee_pips),
        )?
    };

    Ok(SwapStep { sqrt_ratio_next_x96, amount_in, amount_out, fee_amount })
}

#[cfg(test)]
mod tests {
    // Unit tests for `compute_swap_step`: reference vectors with exact expected
    // values, plus direction, capping, fee and no-op property checks.
    use super::*;
    use crate::tick_math::get_sqrt_ratio_at_tick;

    // 2^96 (Q96): a sqrt price of exactly 1.0 in X96 fixed point.
    fn price_1() -> U256 {
        U256::from_str_radix("79228162514264337593543950336", 10).unwrap()
    }

    #[test]
    fn ref_vector_compute_swap_step_exact_in_partial() {
        // Oracle: sc=Q96, st=Q96-1e27, L=1e24, amount_remaining=+1e15, fee=3000.
        // Partial step (does not reach target); fee = amount_remaining - amount_in.
        let sc = price_1();
        let st = U256::from_str_radix("78228162514264337593543950336", 10).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000_000_000;
        let amt = I256::try_from(U256::from(1_000_000_000_000_000_u64)).unwrap();
        let step = compute_swap_step(sc, st, liq, amt, 3000).unwrap();
        assert_eq!(
            step.sqrt_ratio_next_x96,
            U256::from_str_radix("79228162435273859645575912270", 10).unwrap()
        );
        assert_eq!(step.amount_in, U256::from(997_000_000_000_000_u64));
        assert_eq!(step.amount_out, U256::from(996_999_999_005_991_u64));
        assert_eq!(step.fee_amount, U256::from(3_000_000_000_000_u64));
        assert_eq!(step.amount_in + step.fee_amount, U256::from(1_000_000_000_000_000_u64));
    }

    #[test]
    fn ref_vector_compute_swap_step_exact_in_capped() {
        // Oracle: sc=Q96, st=Q96+5e26, L=2e18, amount_remaining=+1e20, fee=600.
        // The input is huge, so the step snaps to the target.
        let sc = price_1();
        let st = U256::from_str_radix("79728162514264337593543950336", 10).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let amt = I256::try_from(U256::from(10u128.pow(20))).unwrap();
        let step = compute_swap_step(sc, st, liq, amt, 600).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, st, "should snap to target");
        assert_eq!(step.amount_in, U256::from(12_621_774_483_536_189_u64));
        assert_eq!(step.amount_out, U256::from(12_542_619_426_618_390_u64));
        assert_eq!(step.fee_amount, U256::from(7_577_611_256_876_u64));
    }

    #[test]
    fn rejects_invalid_fee_pips() {
        // fee_pips == FEE_DENOMINATOR (100%) must hit the entry guard.
        let p = price_1();
        let err = compute_swap_step(p, p, 1, I256::ZERO, FEE_DENOMINATOR)
            .expect_err("fee_pips == 1e6 must be rejected");
        assert!(matches!(err, AmmMathError::InvalidFeePips));
    }

    #[test]
    fn exact_in_capped_at_target_when_input_exceeds() {
        // SwapMath.spec.ts: "exact amount in that gets capped at price
        // target in one for zero" (we test zero for one; symmetric).
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-100).unwrap(); // lower tick → lower sqrt
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_in = I256::try_from(U256::from(10u128.pow(20))).unwrap();
        let step = compute_swap_step(p_current, p_target, liq, huge_in, 600).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p_target, "should snap to target");
        assert!(step.amount_in > U256::ZERO);
        assert!(step.amount_out > U256::ZERO);
        assert!(step.fee_amount > U256::ZERO);
    }

    #[test]
    fn exact_in_partial_when_input_insufficient() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-1000).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000_000_000u128;
        // Tiny input that cannot reach the target.
        let small_in = I256::try_from(U256::from(1_000u64)).unwrap();
        let step = compute_swap_step(p_current, p_target, liq, small_in, 3000).unwrap();
        assert_ne!(step.sqrt_ratio_next_x96, p_target, "should not reach target");
        assert!(step.sqrt_ratio_next_x96 < p_current);
        // Partial step: amount_in + fee_amount must equal the requested input.
        // Solidity: feeAmount = amountRemaining - amountIn  (the partial branch).
        assert_eq!(step.amount_in + step.fee_amount, U256::from(1_000u64));
    }

    #[test]
    fn exact_out_capped_at_target_when_output_exceeds() {
        // Requested output (1e20) exceeds what the range can deliver, so the
        // step snaps to the target.
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-100).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_out_neg =
            I256::try_from(U256::from(10u128.pow(20))).unwrap().checked_neg().unwrap();
        let step = compute_swap_step(p_current, p_target, liq, huge_out_neg, 600).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p_target, "should snap to target");
    }

    #[test]
    fn exact_out_output_capped_at_request_magnitude() {
        // Exercises the "cap output at request magnitude" branch: when the step
        // does NOT reach the target, output must not exceed the requested
        // |amount_remaining|.
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-1000).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000_000_000u128;
        let req_out = U256::from(1_000u64);
        let neg = I256::try_from(req_out).unwrap().checked_neg().unwrap();
        let step = compute_swap_step(p_current, p_target, liq, neg, 3000).unwrap();
        assert!(step.amount_out <= req_out, "output must not exceed request");
    }

    #[test]
    fn zero_for_one_decreases_price() {
        // sqrt_target < sqrt_current means zero_for_one (selling token0).
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-50).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000_000;
        let in_amt = I256::try_from(U256::from(1_000_000u64)).unwrap();
        let step = compute_swap_step(p_current, p_target, liq, in_amt, 3000).unwrap();
        assert!(step.sqrt_ratio_next_x96 < p_current);
        assert!(step.sqrt_ratio_next_x96 >= p_target);
    }

    #[test]
    fn one_for_zero_increases_price() {
        // sqrt_target > sqrt_current means one_for_zero (selling token1).
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(50).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000_000;
        let in_amt = I256::try_from(U256::from(1_000_000u64)).unwrap();
        let step = compute_swap_step(p_current, p_target, liq, in_amt, 3000).unwrap();
        assert!(step.sqrt_ratio_next_x96 > p_current);
        assert!(step.sqrt_ratio_next_x96 <= p_target);
    }

    #[test]
    fn fee_proportional_when_capped_at_target() {
        // When the step reaches the target,
        // fee = ceil(amount_in * fee_pips / (1e6 - fee_pips)).
        // For a 1% fee (10_000 pips) that is roughly amount_in / 99.
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-100).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_in = I256::try_from(U256::from(10u128.pow(20))).unwrap();
        let step = compute_swap_step(p_current, p_target, liq, huge_in, 10_000).unwrap(); // 1%
        // Expected fee = ceil(amount_in * 10_000 / 990_000)
        let expected =
            mul_div_rounding_up(step.amount_in, U256::from(10_000u64), U256::from(990_000u64))
                .unwrap();
        assert_eq!(step.fee_amount, expected);
    }

    #[test]
    fn zero_amount_remaining_exact_in_no_op_at_same_price() {
        // amount_remaining = 0, exact-in (non-negative). Target == current.
        // Should produce zero amounts and no movement. Mirrors SDK's no-op.
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000;
        let step = compute_swap_step(p, p, liq, I256::ZERO, 3000).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p);
        assert_eq!(step.amount_in, U256::ZERO);
        assert_eq!(step.amount_out, U256::ZERO);
        assert_eq!(step.fee_amount, U256::ZERO);
    }
}