use alloy_primitives::{I256, U256};
use crate::{
full_math::{mul_div, mul_div_rounding_up},
sqrt_price_math::{
get_amount_0_delta, get_amount_1_delta, get_next_sqrt_price_from_input,
get_next_sqrt_price_from_output,
},
AmmMathError,
};
/// Result of a single [`compute_swap_step`] call: where the price ended up and the token
/// amounts exchanged to get there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SwapStep {
    /// Square-root price after the step, in the same 2^96-scaled form as the input prices.
    /// Equal to the target price when the remaining amount covered the whole move.
    pub sqrt_ratio_next_x96: U256,
    /// Input tokens consumed by the price move, not including `fee_amount`.
    pub amount_in: U256,
    /// Output tokens produced; in exact-output mode this never exceeds the requested amount.
    pub amount_out: U256,
    /// Fee charged on top of `amount_in`, in input-token units.
    pub fee_amount: U256,
}
/// Fee denominator: `fee_pips` is measured in parts per million of the swap input,
/// so valid fees are strictly below this value.
const FEE_DENOMINATOR: u32 = 10u32.pow(6);
/// Computes one step of a concentrated-liquidity swap within a single price range.
///
/// The price moves from `sqrt_ratio_current_x96` toward `sqrt_ratio_target_x96` and stops at
/// the target if `amount_remaining` is large enough to get there. The swap direction comes from
/// the two prices: `current >= target` means token0 is sold for token1 (`zero_for_one`);
/// otherwise token1 is sold for token0. The algorithm is the same as Uniswap V3's
/// `SwapMath.computeSwapStep`.
///
/// # Parameters
/// * `sqrt_ratio_current_x96` / `sqrt_ratio_target_x96`: square-root prices scaled by 2^96.
/// * `liquidity`: liquidity used for the amount and price calculations of this step.
/// * `amount_remaining`:
///   * non-negative for exact input (tokens the caller will spend, fee included);
///   * negative for exact output (`|amount_remaining|` tokens requested).
/// * `fee_pips`: fee in parts per million of the input. Must be below `1_000_000`.
///
/// # Returns
/// A [`SwapStep`] containing:
/// * the price reached;
/// * the input consumed (excluding fee);
/// * the output produced (capped at the request in exact-output mode);
/// * the fee.
///
/// # Errors
/// * [`AmmMathError::InvalidFeePips`] when `fee_pips >= 1_000_000`.
/// * [`AmmMathError::PriceUnderflow`] when, in exact-input mode without reaching the target,
///   the recomputed input would exceed `amount_remaining`.
/// * Any error returned by the `full_math` or `sqrt_price_math` helpers.
pub fn compute_swap_step(
    sqrt_ratio_current_x96: U256,
    sqrt_ratio_target_x96: U256,
    liquidity: u128,
    amount_remaining: I256,
    fee_pips: u32,
) -> crate::Result<SwapStep> {
    if fee_pips >= FEE_DENOMINATOR {
        return Err(AmmMathError::InvalidFeePips);
    }

    let zero_for_one = sqrt_ratio_current_x96 >= sqrt_ratio_target_x96;
    let exact_in = !amount_remaining.is_negative();
    // Magnitude of the request, computed once for both modes. For non-negative values this is
    // the raw value; for negative ones it is `-amount_remaining` (I256::MIN maps to 2^255).
    let amount_remaining_abs: U256 = amount_remaining.unsigned_abs();

    // Phase 1: decide where the price lands. First price the amount needed to move all the way
    // to the target; it is reused in phase 2 if the target turns out to be reached.
    let (sqrt_ratio_next_x96, amount_in_to_target, amount_out_to_target) = if exact_in {
        // Only the post-fee part of the input is allowed to move the price.
        let amount_remaining_less_fee = mul_div(
            amount_remaining_abs,
            U256::from(FEE_DENOMINATOR - fee_pips),
            U256::from(FEE_DENOMINATOR),
        )?;
        let amount_in_to_target = if zero_for_one {
            get_amount_0_delta(sqrt_ratio_target_x96, sqrt_ratio_current_x96, liquidity, true)?
        } else {
            get_amount_1_delta(sqrt_ratio_current_x96, sqrt_ratio_target_x96, liquidity, true)?
        };
        let next = if amount_remaining_less_fee >= amount_in_to_target {
            sqrt_ratio_target_x96
        } else {
            get_next_sqrt_price_from_input(
                sqrt_ratio_current_x96,
                liquidity,
                amount_remaining_less_fee,
                zero_for_one,
            )?
        };
        (next, amount_in_to_target, U256::ZERO)
    } else {
        let amount_out_to_target = if zero_for_one {
            get_amount_1_delta(sqrt_ratio_target_x96, sqrt_ratio_current_x96, liquidity, false)?
        } else {
            get_amount_0_delta(sqrt_ratio_current_x96, sqrt_ratio_target_x96, liquidity, false)?
        };
        let next = if amount_remaining_abs >= amount_out_to_target {
            sqrt_ratio_target_x96
        } else {
            get_next_sqrt_price_from_output(
                sqrt_ratio_current_x96,
                liquidity,
                amount_remaining_abs,
                zero_for_one,
            )?
        };
        (next, U256::ZERO, amount_out_to_target)
    };

    // Phase 2: settle the actual amounts for the move current -> next.
    // The phase-1 full-range amount is reused only on the side being solved for when the
    // target was hit; everything else is recomputed. The trailing flag is `true` for input
    // amounts and `false` for outputs (presumably round-up vs round-down; see sqrt_price_math).
    // `amount_in` is evaluated before `amount_out` so error precedence is unchanged.
    let max_reached = sqrt_ratio_next_x96 == sqrt_ratio_target_x96;
    let amount_in = if max_reached && exact_in {
        amount_in_to_target
    } else if zero_for_one {
        get_amount_0_delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, true)?
    } else {
        get_amount_1_delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, true)?
    };
    let amount_out = if max_reached && !exact_in {
        amount_out_to_target
    } else if zero_for_one {
        get_amount_1_delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, false)?
    } else {
        get_amount_0_delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, false)?
    };

    // Exact output: never hand out more than was requested.
    let amount_out = if exact_in { amount_out } else { amount_out.min(amount_remaining_abs) };

    let fee_amount = if exact_in && !max_reached {
        // Target not reached, so the whole request was consumed. Whatever did not go into
        // the price move is taken as the fee.
        amount_remaining_abs.checked_sub(amount_in).ok_or(AmmMathError::PriceUnderflow)?
    } else {
        // Gross the fee up from the net input: fee / (amount_in + fee) = fee_pips / 1e6.
        mul_div_rounding_up(
            amount_in,
            U256::from(fee_pips),
            U256::from(FEE_DENOMINATOR - fee_pips),
        )?
    };

    Ok(SwapStep { sqrt_ratio_next_x96, amount_in, amount_out, fee_amount })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tick_math::get_sqrt_ratio_at_tick;

    /// Square-root price 1.0 scaled by 2^96, i.e. exactly 2^96
    /// (= 79228162514264337593543950336).
    fn price_1() -> U256 {
        U256::from(1u128 << 96)
    }

    /// `amount_remaining` for an exact-input swap of `amount` tokens.
    fn exact_in_amount(amount: u128) -> I256 {
        I256::try_from(U256::from(amount)).unwrap()
    }

    /// `amount_remaining` for an exact-output swap requesting `amount` tokens.
    fn exact_out_amount(amount: u128) -> I256 {
        exact_in_amount(amount).checked_neg().unwrap()
    }

    #[test]
    fn rejects_invalid_fee_pips() {
        let p = price_1();
        let err = compute_swap_step(p, p, 1, I256::ZERO, FEE_DENOMINATOR)
            .expect_err("fee_pips == 1e6 must be rejected");
        assert!(matches!(err, AmmMathError::InvalidFeePips));
    }

    #[test]
    fn exact_in_capped_at_target_when_input_exceeds() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-100).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_in = exact_in_amount(10u128.pow(20));
        let step = compute_swap_step(p_current, p_target, liq, huge_in, 600).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p_target, "should snap to target");
        assert!(step.amount_in > U256::ZERO);
        assert!(step.amount_out > U256::ZERO);
        assert!(step.fee_amount > U256::ZERO);
    }

    #[test]
    fn exact_in_partial_when_input_insufficient() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-1000).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000_000_000;
        let small_in = exact_in_amount(1_000);
        let step = compute_swap_step(p_current, p_target, liq, small_in, 3000).unwrap();
        assert_ne!(step.sqrt_ratio_next_x96, p_target, "should not reach target");
        assert!(step.sqrt_ratio_next_x96 < p_current);
        // Without reaching the target the whole input is spent: net input plus fee.
        assert_eq!(step.amount_in + step.fee_amount, U256::from(1_000u64));
    }

    #[test]
    fn exact_out_capped_at_target_when_output_exceeds() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-100).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_out_neg = exact_out_amount(10u128.pow(20));
        let step = compute_swap_step(p_current, p_target, liq, huge_out_neg, 600).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p_target, "should snap to target");
        assert!(step.amount_in > U256::ZERO);
        assert!(step.amount_out > U256::ZERO);
        assert!(step.fee_amount > U256::ZERO);
    }

    #[test]
    fn exact_out_one_for_zero_capped_at_target() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(100).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_out_neg = exact_out_amount(10u128.pow(20));
        let step = compute_swap_step(p_current, p_target, liq, huge_out_neg, 600).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p_target, "should snap to target");
        assert!(step.amount_in > U256::ZERO);
        assert!(step.amount_out > U256::ZERO);
    }

    #[test]
    fn exact_out_output_capped_at_request_magnitude() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-1000).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000_000_000;
        let req_out = U256::from(1_000u64);
        let neg = exact_out_amount(1_000);
        let step = compute_swap_step(p_current, p_target, liq, neg, 3000).unwrap();
        assert!(step.amount_out <= req_out, "output must not exceed request");
    }

    #[test]
    fn zero_for_one_decreases_price() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-50).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000_000;
        let in_amt = exact_in_amount(1_000_000);
        let step = compute_swap_step(p_current, p_target, liq, in_amt, 3000).unwrap();
        assert!(step.sqrt_ratio_next_x96 < p_current);
        assert!(step.sqrt_ratio_next_x96 >= p_target);
    }

    #[test]
    fn one_for_zero_increases_price() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(50).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000_000;
        let in_amt = exact_in_amount(1_000_000);
        let step = compute_swap_step(p_current, p_target, liq, in_amt, 3000).unwrap();
        assert!(step.sqrt_ratio_next_x96 > p_current);
        assert!(step.sqrt_ratio_next_x96 <= p_target);
    }

    #[test]
    fn fee_proportional_when_capped_at_target() {
        let p_current = price_1();
        let p_target = get_sqrt_ratio_at_tick(-100).unwrap();
        let liq: u128 = 2_000_000_000_000_000_000;
        let huge_in = exact_in_amount(10u128.pow(20));
        let step = compute_swap_step(p_current, p_target, liq, huge_in, 10_000).unwrap();
        let expected =
            mul_div_rounding_up(step.amount_in, U256::from(10_000u64), U256::from(990_000u64))
                .unwrap();
        assert_eq!(step.fee_amount, expected);
    }

    #[test]
    fn zero_amount_remaining_exact_in_no_op_at_same_price() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000;
        let step = compute_swap_step(p, p, liq, I256::ZERO, 3000).unwrap();
        assert_eq!(step.sqrt_ratio_next_x96, p);
        assert_eq!(step.amount_in, U256::ZERO);
        assert_eq!(step.amount_out, U256::ZERO);
        assert_eq!(step.fee_amount, U256::ZERO);
    }
}