use alloy_primitives::{I256, U256};
use thiserror::Error;
use wp_evm_amm_math::{
swap_math::compute_swap_step,
tick_math::{
get_sqrt_ratio_at_tick, get_tick_at_sqrt_ratio, MAX_SQRT_RATIO, MAX_TICK, MIN_SQRT_RATIO,
MIN_TICK,
},
AmmMathError,
};
use crate::data::{PoolState, TickInfo};
/// Outcome of a simulated [`swap`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SwapResult {
    /// Total amount of the input token consumed, fee included.
    pub amount_in: U256,
    /// Total amount of the output token produced.
    pub amount_out: U256,
    /// The pool's `sqrt_price_x96` value after the swap; equals the caller's price
    /// limit when the swap was clamped there.
    pub sqrt_price_x96_after: U256,
}
#[derive(Debug, Error)]
pub(crate) enum SwapError {
#[error("amount_specified must be non-zero")]
ZeroAmount,
#[error("sqrt_price_limit_x96 invalid (must be on the correct side of current price and within [MIN_SQRT_RATIO, MAX_SQRT_RATIO])")]
InvalidPriceLimit,
#[error("ran out of liquidity before satisfying amount_specified")]
InsufficientLiquidity,
#[error("amm-math: {0}")]
Math(#[from] AmmMathError),
#[error("swap loop internal invariant violated: {0}")]
Internal(&'static str),
}
/// The most permissive price limit for a `zero_for_one` swap.
///
/// [`swap`] rejects limits `<= MIN_SQRT_RATIO`, so one above it is the lowest
/// accepted value.
pub(crate) fn min_sqrt_ratio_plus_one() -> U256 {
    let one = U256::from(1u64);
    MIN_SQRT_RATIO + one
}

/// The most permissive price limit for a one-for-zero swap.
///
/// [`swap`] rejects limits `>= MAX_SQRT_RATIO`, so one below it is the highest
/// accepted value.
pub(crate) fn max_sqrt_ratio_minus_one() -> U256 {
    let one = U256::from(1u64);
    MAX_SQRT_RATIO - one
}
/// Applies a signed liquidity change `y` to the unsigned liquidity `x`.
///
/// # Errors
/// * `InsufficientLiquidity` if a negative `y` would take the result below zero.
/// * `Math(LiquidityOverflow)` if a non-negative `y` would overflow `u128`.
fn add_delta(x: u128, y: i128) -> core::result::Result<u128, SwapError> {
    // `u128::try_from` succeeds exactly when `y` is non-negative.
    match u128::try_from(y) {
        Ok(increase) => {
            x.checked_add(increase).ok_or(SwapError::Math(AmmMathError::LiquidityOverflow))
        }
        // `unsigned_abs` is total, including for `i128::MIN`.
        Err(_) => x.checked_sub(y.unsigned_abs()).ok_or(SwapError::InsufficientLiquidity),
    }
}
/// Returns the nearest initialized tick the price reaches next when moving in the
/// swap direction, or `None` if there is none.
///
/// `ticks` must be sorted ascending by `tick`; the lookups use `partition_point`,
/// which assumes that ordering.
///
/// * Moving up (`!zero_for_one`): the first tick strictly above the current tick.
/// * Moving down (`zero_for_one`): the last tick at or below the current tick. A
///   tick the price sits exactly on is excluded. Inside the swap loop that is the
///   tick just crossed, whose `liquidity_net` has already been applied.
///
/// # Errors
/// Propagates tick-math errors for `current_sqrt_x96` (only when `ticks` is non-empty).
fn next_initialized_tick(
    ticks: &[TickInfo],
    current_sqrt_x96: U256,
    zero_for_one: bool,
) -> core::result::Result<Option<&TickInfo>, SwapError> {
    if ticks.is_empty() {
        return Ok(None);
    }

    let tick = get_tick_at_sqrt_ratio(current_sqrt_x96)?;
    let exactly_on_tick = get_sqrt_ratio_at_tick(tick)? == current_sqrt_x96;

    if !zero_for_one {
        let first_above = ticks.partition_point(|info| info.tick <= tick);
        return Ok(ticks.get(first_above));
    }

    let split = if exactly_on_tick {
        ticks.partition_point(|info| info.tick < tick)
    } else {
        ticks.partition_point(|info| info.tick <= tick)
    };
    Ok(split.checked_sub(1).and_then(|last_below| ticks.get(last_below)))
}
/// Module-local shorthand for results whose error type is [`SwapError`].
type Result<T> = core::result::Result<T, SwapError>;
pub(crate) fn swap(
state: &PoolState,
zero_for_one: bool,
amount_specified: I256,
sqrt_price_limit_x96: U256,
fee_pips: u32,
) -> Result<SwapResult> {
if amount_specified.is_zero() {
return Err(SwapError::ZeroAmount);
}
if zero_for_one {
if sqrt_price_limit_x96 >= state.sqrt_price_x96 || sqrt_price_limit_x96 <= MIN_SQRT_RATIO {
return Err(SwapError::InvalidPriceLimit);
}
} else if sqrt_price_limit_x96 <= state.sqrt_price_x96 || sqrt_price_limit_x96 >= MAX_SQRT_RATIO
{
return Err(SwapError::InvalidPriceLimit);
}
let exact_in = !amount_specified.is_negative();
let mut amount_specified_remaining: I256 = amount_specified;
let mut amount_calculated: I256 = I256::ZERO;
let mut sqrt_price_x96: U256 = state.sqrt_price_x96;
let mut liquidity: u128 = state.liquidity;
let boundary_sqrt = if zero_for_one {
get_sqrt_ratio_at_tick(MIN_TICK)?
} else {
get_sqrt_ratio_at_tick(MAX_TICK)?
};
while !amount_specified_remaining.is_zero() && sqrt_price_x96 != sqrt_price_limit_x96 {
let next_tick_opt = next_initialized_tick(&state.ticks, sqrt_price_x96, zero_for_one)?;
let next_tick_sqrt = match next_tick_opt {
Some(t) => get_sqrt_ratio_at_tick(t.tick)?,
None => boundary_sqrt,
};
let target_sqrt = if zero_for_one {
core::cmp::max(next_tick_sqrt, sqrt_price_limit_x96)
} else {
core::cmp::min(next_tick_sqrt, sqrt_price_limit_x96)
};
if liquidity == 0 {
return Err(SwapError::InsufficientLiquidity);
}
let step = compute_swap_step(
sqrt_price_x96,
target_sqrt,
liquidity,
amount_specified_remaining,
fee_pips,
)?;
let step_in_with_fee_u: U256 = step
.amount_in
.checked_add(step.fee_amount)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
let step_in_with_fee: I256 = I256::try_from(step_in_with_fee_u)
.map_err(|_| SwapError::Math(AmmMathError::MulDivOverflow))?;
let step_out: I256 = I256::try_from(step.amount_out)
.map_err(|_| SwapError::Math(AmmMathError::MulDivOverflow))?;
if exact_in {
amount_specified_remaining = amount_specified_remaining
.checked_sub(step_in_with_fee)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
amount_calculated = amount_calculated
.checked_sub(step_out)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
} else {
amount_specified_remaining = amount_specified_remaining
.checked_add(step_out)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
amount_calculated = amount_calculated
.checked_add(step_in_with_fee)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
}
sqrt_price_x96 = step.sqrt_ratio_next_x96;
if sqrt_price_x96 == next_tick_sqrt {
if let Some(next_tick) = next_tick_opt {
let delta = if zero_for_one {
next_tick
.liquidity_net
.checked_neg()
.ok_or(SwapError::Internal("liquidity_net negation overflow (i128::MIN)"))?
} else {
next_tick.liquidity_net
};
liquidity = add_delta(liquidity, delta)?;
}
}
}
if !amount_specified_remaining.is_zero() && sqrt_price_x96 != sqrt_price_limit_x96 {
return Err(SwapError::InsufficientLiquidity);
}
let (amount_in_i, amount_out_i) = if exact_in {
let consumed = amount_specified
.checked_sub(amount_specified_remaining)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
let produced =
amount_calculated.checked_neg().ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
(consumed, produced)
} else {
let produced_neg = amount_specified
.checked_sub(amount_specified_remaining)
.ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
let produced =
produced_neg.checked_neg().ok_or(SwapError::Math(AmmMathError::MulDivOverflow))?;
(amount_calculated, produced)
};
let amount_in: U256 = if amount_in_i.is_negative() {
return Err(SwapError::Math(AmmMathError::MulDivOverflow));
} else {
amount_in_i.into_raw()
};
let amount_out: U256 = if amount_out_i.is_negative() {
return Err(SwapError::Math(AmmMathError::MulDivOverflow));
} else {
amount_out_i.into_raw()
};
Ok(SwapResult { amount_in, amount_out, sqrt_price_x96_after: sqrt_price_x96 })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{PoolState, TickInfo};
    use alloy_primitives::address;

    /// USDC/WETH 0.3% pool snapshot at tick ~76012 with three initialized ticks:
    /// 74940 (+1e21), 75960 (+1e21) and 76020 (-2e21); active liquidity 2e21.
    fn fixture_usdc_weth_03() -> PoolState {
        let sqrt_price_x96: U256 =
            U256::from_str_radix("3543191142285914205922034323214", 10).unwrap();
        PoolState {
            token0: address!("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"),
            token1: address!("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"),
            fee: 3000,
            tick_spacing: 60,
            sqrt_price_x96,
            liquidity: 2_000_000_000_000_000_000_000u128,
            tick: 76012,
            ticks: vec![
                TickInfo {
                    tick: 74940,
                    liquidity_net: 1_000_000_000_000_000_000_000i128,
                    liquidity_gross: 1_000_000_000_000_000_000_000u128,
                },
                TickInfo {
                    tick: 75960,
                    liquidity_net: 1_000_000_000_000_000_000_000i128,
                    liquidity_gross: 1_000_000_000_000_000_000_000u128,
                },
                TickInfo {
                    tick: 76020,
                    liquidity_net: -2_000_000_000_000_000_000_000i128,
                    liquidity_gross: 2_000_000_000_000_000_000_000u128,
                },
            ],
        }
    }

    #[test]
    fn swap_no_tick_cross_partial_fill() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000u64)).unwrap();
        let r = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap();
        assert!(r.amount_in > U256::ZERO);
        assert!(r.amount_out > U256::ZERO);
        let sqrt_at_75960 = U256::from_str_radix("3533845506420911390540068078527", 10).unwrap();
        assert!(r.sqrt_price_x96_after > sqrt_at_75960);
        assert!(r.sqrt_price_x96_after < s.sqrt_price_x96);
    }

    #[test]
    fn swap_crosses_one_tick() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000_000_000_000_000_000u64)).unwrap();
        let r = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap();
        let sqrt_at_75960 = U256::from_str_radix("3533845506420911390540068078527", 10).unwrap();
        assert!(
            r.sqrt_price_x96_after < sqrt_at_75960,
            "should cross tick 75960, ended at {}",
            r.sqrt_price_x96_after
        );
        assert_eq!(r.amount_in, U256::from(1_000_000_000_000_000_000u64));
    }

    #[test]
    fn swap_clamps_at_price_limit() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000_000_000_000_000_000u64)).unwrap();
        let custom_limit = U256::from_str_radix("3540000000000000000000000000000", 10).unwrap();
        assert!(custom_limit < s.sqrt_price_x96);
        let r = swap(&s, true, amt, custom_limit, 3000).unwrap();
        assert_eq!(r.sqrt_price_x96_after, custom_limit, "should clamp at caller-provided limit");
        assert!(
            r.amount_in < U256::from(1_000_000_000_000_000_000u64),
            "should leave input unconsumed: got {}",
            r.amount_in
        );
    }

    #[test]
    fn swap_empty_ticks_zero_for_one_errors() {
        let mut s = fixture_usdc_weth_03();
        s.ticks.clear();
        s.liquidity = 0;
        let amt = I256::try_from(U256::from(1_000u64)).unwrap();
        let err = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap_err();
        assert!(matches!(err, SwapError::InsufficientLiquidity));
    }

    #[test]
    fn swap_direction_zero_for_one_decreases_price() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000_000u64)).unwrap();
        let r = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap();
        assert!(r.sqrt_price_x96_after < s.sqrt_price_x96);
    }

    #[test]
    fn swap_direction_one_for_zero_increases_price() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000_000_000_000_000u64)).unwrap();
        let r = swap(&s, false, amt, max_sqrt_ratio_minus_one(), 3000).unwrap();
        assert!(r.sqrt_price_x96_after > s.sqrt_price_x96);
    }

    /// Exact-in with a permissive limit must consume the whole input (fee included).
    #[test]
    fn swap_exact_in_conserves_input_and_fee() {
        let s = fixture_usdc_weth_03();
        let amt_u = U256::from(1_000_000u64);
        let amt = I256::try_from(amt_u).unwrap();
        let r = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap();
        assert_eq!(r.amount_in, amt_u, "exact-in should fully consume input");
    }

    #[test]
    fn swap_rejects_zero_amount() {
        let s = fixture_usdc_weth_03();
        let err = swap(&s, true, I256::ZERO, min_sqrt_ratio_plus_one(), 3000).unwrap_err();
        assert!(matches!(err, SwapError::ZeroAmount));
    }

    #[test]
    fn swap_exact_out_no_tick_cross() {
        let s = fixture_usdc_weth_03();
        let req_out = U256::from(1_000_000_000u64);
        let amt = I256::try_from(req_out).unwrap().checked_neg().unwrap();
        let r = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap();
        assert_eq!(r.amount_out, req_out);
        assert!(r.amount_in > U256::ZERO);
        assert!(r.sqrt_price_x96_after < s.sqrt_price_x96);
        let sqrt_at_75960 = U256::from_str_radix("3533845506420911390540068078527", 10).unwrap();
        assert!(r.sqrt_price_x96_after > sqrt_at_75960, "should not cross tick 75960");
    }

    #[test]
    fn swap_exact_out_crosses_one_tick() {
        let s = fixture_usdc_weth_03();
        let req_out = U256::from_str_radix("300000000000000000000", 10).unwrap();
        let amt = I256::try_from(req_out).unwrap().checked_neg().unwrap();
        let r = swap(&s, true, amt, min_sqrt_ratio_plus_one(), 3000).unwrap();
        assert_eq!(r.amount_out, req_out, "exact-out must supply full request");
        let sqrt_at_75960 = U256::from_str_radix("3533845506420911390540068078527", 10).unwrap();
        assert!(
            r.sqrt_price_x96_after < sqrt_at_75960,
            "should cross tick 75960, ended at {}",
            r.sqrt_price_x96_after
        );
    }

    #[test]
    fn swap_rejects_invalid_price_limit() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000u64)).unwrap();
        let bad_limit = s.sqrt_price_x96 + U256::from(1u64);
        let err = swap(&s, true, amt, bad_limit, 3000).unwrap_err();
        assert!(matches!(err, SwapError::InvalidPriceLimit));
    }

    /// The bounds themselves are not valid limits.
    #[test]
    fn swap_rejects_price_limit_at_min_sqrt_ratio() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000u64)).unwrap();
        let err = swap(&s, true, amt, MIN_SQRT_RATIO, 3000).unwrap_err();
        assert!(matches!(err, SwapError::InvalidPriceLimit));
    }

    #[test]
    fn swap_one_for_zero_rejects_limit_below_price() {
        let s = fixture_usdc_weth_03();
        let amt = I256::try_from(U256::from(1_000u64)).unwrap();
        let bad_limit = s.sqrt_price_x96 - U256::from(1u64);
        let err = swap(&s, false, amt, bad_limit, 3000).unwrap_err();
        assert!(matches!(err, SwapError::InvalidPriceLimit));
    }

    #[test]
    fn add_delta_applies_signed_changes() {
        assert_eq!(add_delta(5, 3).unwrap(), 8);
        assert_eq!(add_delta(5, -3).unwrap(), 2);
        assert!(matches!(add_delta(5, -6), Err(SwapError::InsufficientLiquidity)));
        assert!(matches!(add_delta(5, i128::MIN), Err(SwapError::InsufficientLiquidity)));
        assert!(matches!(
            add_delta(u128::MAX, 1),
            Err(SwapError::Math(AmmMathError::LiquidityOverflow))
        ));
    }

    #[test]
    fn next_initialized_tick_picks_neighbours() {
        let s = fixture_usdc_weth_03();
        let below = next_initialized_tick(&s.ticks, s.sqrt_price_x96, true).unwrap();
        assert_eq!(below.map(|t| t.tick), Some(75960));
        let above = next_initialized_tick(&s.ticks, s.sqrt_price_x96, false).unwrap();
        assert_eq!(above.map(|t| t.tick), Some(76020));
        // Exactly on an initialized tick while moving down: that tick is skipped.
        let on_tick = get_sqrt_ratio_at_tick(75960).unwrap();
        let below = next_initialized_tick(&s.ticks, on_tick, true).unwrap();
        assert_eq!(below.map(|t| t.tick), Some(74940));
        assert!(next_initialized_tick(&[], s.sqrt_price_x96, true).unwrap().is_none());
    }
}