use ethnum::U256;
use crate::{
fee_math::fee_amount_from_input,
liquidity_math::{get_amount_0_delta, get_amount_1_delta},
AmmMathError,
};
/// Outcome of a single swap step, as returned by `compute_swap_step`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SwapStepResult {
/// Square-root price after the step. The price helpers in this file scale
/// amounts by 2^64, so this is presumably a 64.64 fixed-point value — confirm
/// against `tick_math`.
pub sqrt_price_next: u128,
/// Amount of the input token consumed by the step, not including `fee_amount`.
pub amount_in: u64,
/// Amount of the output token produced by the step.
pub amount_out: u64,
/// Fee charged for the step; for a partially filled exact-input step this is
/// the remainder of the input budget after `amount_in`.
pub fee_amount: u64,
}
/// Computes the square-root price reached after `amount_in` of the input token
/// is swapped against `liquidity`, starting from `sqrt_price`.
///
/// The arithmetic scales liquidity and amounts by 2^64, i.e. it treats
/// `sqrt_price` as a fixed-point number with 64 fractional bits.
///
/// * `a_to_b == true`: the price moves down and is rounded up:
///   `ceil((L << 64) * P / ((L << 64) + amount_in * P))`.
/// * `a_to_b == false`: the price moves up and the increment is rounded down:
///   `P + floor((amount_in << 64) / L)`.
///
/// A zero `amount_in` returns `sqrt_price` unchanged.
///
/// # Errors
///
/// * [`AmmMathError::SqrtPriceOutOfRange`] if `sqrt_price` is zero.
/// * [`AmmMathError::DivisionByZero`] if `liquidity` is zero.
/// * [`AmmMathError::Overflow`] if `(L << 64) * P` does not fit in 256 bits or
///   the resulting price does not fit in a `u128`.
pub fn get_next_sqrt_price_from_input(
    sqrt_price: u128,
    liquidity: u128,
    amount_in: u64,
    a_to_b: bool,
) -> Result<u128, AmmMathError> {
    if sqrt_price == 0 {
        return Err(AmmMathError::SqrtPriceOutOfRange(0));
    }
    if liquidity == 0 {
        return Err(AmmMathError::DivisionByZero);
    }
    if amount_in == 0 {
        return Ok(sqrt_price);
    }

    let price = U256::from(sqrt_price);
    let amount = U256::from(amount_in);

    let next: U256 = if a_to_b {
        let l_shifted = U256::from(liquidity) << 64;
        // Up to 2^192 * 2^128: may exceed 256 bits, so the product must be checked
        // instead of panicking (checked builds) or wrapping (unchecked builds).
        let numerator = l_shifted
            .checked_mul(price)
            .ok_or(AmmMathError::Overflow)?;
        // Bounded by 2^192 + 2^192, so this cannot overflow; it is also non-zero
        // because `liquidity != 0`.
        let denominator = l_shifted + amount * price;
        // Ceiling division without forming `numerator + denominator - 1`, which
        // could itself overflow for a numerator close to 2^256.
        let quotient = numerator / denominator;
        if numerator % denominator == U256::ZERO {
            quotient
        } else {
            quotient + U256::ONE
        }
    } else {
        price + (amount << 64) / U256::from(liquidity)
    };

    if next > U256::from(u128::MAX) {
        return Err(AmmMathError::Overflow);
    }
    Ok(next.as_u128())
}
/// Computes the square-root price reached after `amount_out` of the output token
/// is taken out of `liquidity`, starting from `sqrt_price`.
///
/// The arithmetic scales liquidity and amounts by 2^64, i.e. it treats
/// `sqrt_price` as a fixed-point number with 64 fractional bits.
///
/// * `a_to_b == true`: the price moves down by a rounded-up decrement:
///   `P - ceil((amount_out << 64) / L)`.
/// * `a_to_b == false`: the price moves up and is rounded up:
///   `ceil((L << 64) * P / ((L << 64) - amount_out * P))`.
///
/// A zero `amount_out` returns `sqrt_price` unchanged.
///
/// # Errors
///
/// * [`AmmMathError::SqrtPriceOutOfRange`] if `sqrt_price` is zero.
/// * [`AmmMathError::DivisionByZero`] if `liquidity` is zero.
/// * [`AmmMathError::Overflow`] if the requested output would drive the price to
///   zero or below (`a_to_b`), if `amount_out * P >= L << 64` (b-to-a), if
///   `(L << 64) * P` does not fit in 256 bits, or if the resulting price does
///   not fit in a `u128`.
pub fn get_next_sqrt_price_from_output(
    sqrt_price: u128,
    liquidity: u128,
    amount_out: u64,
    a_to_b: bool,
) -> Result<u128, AmmMathError> {
    if sqrt_price == 0 {
        return Err(AmmMathError::SqrtPriceOutOfRange(0));
    }
    if liquidity == 0 {
        return Err(AmmMathError::DivisionByZero);
    }
    if amount_out == 0 {
        return Ok(sqrt_price);
    }

    let price = U256::from(sqrt_price);
    let amount = U256::from(amount_out);
    let l = U256::from(liquidity);

    if a_to_b {
        // Both operands are below 2^128, so the round-up addition cannot overflow.
        let delta = ((amount << 64) + l - U256::ONE) / l;
        // The decrement must leave a strictly positive price.
        if delta >= price {
            return Err(AmmMathError::Overflow);
        }
        // `price - delta` is below `price`, which itself came from a u128.
        return Ok((price - delta).as_u128());
    }

    let product = amount * price;
    let l_shifted = l << 64;
    if product >= l_shifted {
        return Err(AmmMathError::Overflow);
    }
    let denominator = l_shifted - product;
    // Up to 2^192 * 2^128: may exceed 256 bits, so the product must be checked
    // instead of panicking (checked builds) or wrapping (unchecked builds).
    let numerator = l_shifted
        .checked_mul(price)
        .ok_or(AmmMathError::Overflow)?;
    // Ceiling division without forming `numerator + denominator - 1`, which
    // could itself overflow for a numerator close to 2^256.
    let quotient = numerator / denominator;
    let next = if numerator % denominator == U256::ZERO {
        quotient
    } else {
        quotient + U256::ONE
    };

    if next > U256::from(u128::MAX) {
        return Err(AmmMathError::Overflow);
    }
    Ok(next.as_u128())
}
/// Computes one step of a swap between `sqrt_price_current` and
/// `sqrt_price_target` against a fixed amount of `liquidity`.
///
/// The direction is derived from the prices: the step is A-to-B (price moving
/// down) when `sqrt_price_current >= sqrt_price_target`, otherwise B-to-A.
///
/// * `by_amount_in == true`: `amount_remaining` is the input budget including
///   the fee. The step either reaches the target price or spends the whole
///   budget; in the latter case the fee is whatever is left of the budget after
///   `amount_in`.
/// * `by_amount_in == false`: `amount_remaining` is the desired output. It is
///   capped at the output available before the target price, and the fee is
///   derived from the resulting `amount_in`.
///
/// When `amount_remaining` or `liquidity` is zero, or when the capped output is
/// zero, the result leaves the price at `sqrt_price_current` with all amounts
/// set to zero.
///
/// # Errors
///
/// Propagates any [`AmmMathError`] returned by the fee, delta, or next-price
/// helpers.
pub fn compute_swap_step(
    sqrt_price_current: u128,
    sqrt_price_target: u128,
    liquidity: u128,
    amount_remaining: u64,
    fee_rate_bps: u16,
    by_amount_in: bool,
) -> Result<SwapStepResult, AmmMathError> {
    let a_to_b = sqrt_price_current >= sqrt_price_target;

    // Result used whenever the step cannot move anything.
    let unchanged = SwapStepResult {
        sqrt_price_next: sqrt_price_current,
        amount_in: 0,
        amount_out: 0,
        fee_amount: 0,
    };
    if amount_remaining == 0 || liquidity == 0 {
        return Ok(unchanged);
    }

    if !by_amount_in {
        // Exact-output: cap the request at what is available before the target.
        let available_out = if a_to_b {
            get_amount_1_delta(sqrt_price_target, sqrt_price_current, liquidity, false)?
        } else {
            get_amount_0_delta(sqrt_price_current, sqrt_price_target, liquidity, false)?
        };
        let amount_out = amount_remaining.min(available_out);
        if amount_out == 0 {
            return Ok(unchanged);
        }

        let sqrt_price_next = if amount_remaining < available_out {
            get_next_sqrt_price_from_output(sqrt_price_current, liquidity, amount_out, a_to_b)?
        } else {
            sqrt_price_target
        };
        let amount_in = if a_to_b {
            get_amount_0_delta(sqrt_price_next, sqrt_price_current, liquidity, true)?
        } else {
            get_amount_1_delta(sqrt_price_current, sqrt_price_next, liquidity, true)?
        };
        let fee_amount = fee_amount_from_input(amount_in, fee_rate_bps)?;

        return Ok(SwapStepResult {
            sqrt_price_next,
            amount_in,
            amount_out,
            fee_amount,
        });
    }

    // Exact-input: set the fee aside first, then see how far the rest moves the price.
    let provisional_fee = fee_amount_from_input(amount_remaining, fee_rate_bps)?;
    let spendable = amount_remaining.saturating_sub(provisional_fee);
    let input_to_target = if a_to_b {
        get_amount_0_delta(sqrt_price_target, sqrt_price_current, liquidity, true)?
    } else {
        get_amount_1_delta(sqrt_price_current, sqrt_price_target, liquidity, true)?
    };

    let (sqrt_price_next, amount_in) = if spendable >= input_to_target {
        (sqrt_price_target, input_to_target)
    } else {
        let reached = get_next_sqrt_price_from_input(
            sqrt_price_current,
            liquidity,
            spendable,
            a_to_b,
        )?;
        let consumed = if a_to_b {
            get_amount_0_delta(reached, sqrt_price_current, liquidity, true)?
        } else {
            get_amount_1_delta(sqrt_price_current, reached, liquidity, true)?
        };
        // Rounding up must never charge more than the post-fee budget.
        (reached, consumed.min(spendable))
    };

    let amount_out = if a_to_b {
        get_amount_1_delta(sqrt_price_next, sqrt_price_current, liquidity, false)?
    } else {
        get_amount_0_delta(sqrt_price_current, sqrt_price_next, liquidity, false)?
    };

    let fee_amount = if sqrt_price_next == sqrt_price_target {
        fee_amount_from_input(amount_in, fee_rate_bps)?
    } else {
        // Target not reached: everything left of the budget is taken as the fee.
        amount_remaining.saturating_sub(amount_in)
    };

    Ok(SwapStepResult {
        sqrt_price_next,
        amount_in,
        amount_out,
        fee_amount,
    })
}
// Unit tests for the next-price helpers and `compute_swap_step`.
#[cfg(test)]
mod tests {
use super::*;
use crate::tick_math::tick_to_sqrt_price_x64;
// 2^64, used throughout the tests as the starting square-root price.
const Q64: u128 = 1u128 << 64;
// A zero input amount leaves the price untouched.
#[test]
fn test_next_price_from_input_zero_amount() {
let price = Q64;
let result = get_next_sqrt_price_from_input(price, 1_000_000, 0, true).unwrap();
assert_eq!(result, price);
}
// Exact-input A-to-B moves the price down.
#[test]
fn test_next_price_from_input_a_to_b() {
let price = Q64; let liquidity = 1_000_000u128;
let result = get_next_sqrt_price_from_input(price, liquidity, 100, true).unwrap();
assert!(result < price, "a_to_b should decrease price");
}
// Exact-input B-to-A moves the price up.
#[test]
fn test_next_price_from_input_b_to_a() {
let price = Q64;
let liquidity = 1_000_000u128;
let result = get_next_sqrt_price_from_input(price, liquidity, 100, false).unwrap();
assert!(result > price, "b_to_a should increase price");
}
// Zero liquidity is rejected.
#[test]
fn test_next_price_from_input_zero_liquidity() {
assert!(get_next_sqrt_price_from_input(Q64, 0, 100, true).is_err());
}
// A zero starting price is rejected.
#[test]
fn test_next_price_from_input_zero_price() {
assert!(get_next_sqrt_price_from_input(0, 1000, 100, true).is_err());
}
// A zero output amount leaves the price untouched.
#[test]
fn test_next_price_from_output_zero_amount() {
let result = get_next_sqrt_price_from_output(Q64, 1_000_000, 0, true).unwrap();
assert_eq!(result, Q64);
}
// Exact-output A-to-B moves the price down.
#[test]
fn test_next_price_from_output_a_to_b() {
let price = Q64;
let liquidity = 1_000_000_000u128;
let result = get_next_sqrt_price_from_output(price, liquidity, 100, true).unwrap();
assert!(result < price, "a_to_b output should decrease price");
}
// Exact-output B-to-A moves the price up.
#[test]
fn test_next_price_from_output_b_to_a() {
let price = Q64;
let liquidity = 1_000_000_000u128;
let result = get_next_sqrt_price_from_output(price, liquidity, 100, false).unwrap();
assert!(result > price, "b_to_a output should increase price");
}
// Requesting more output than the price can absorb is an error.
#[test]
fn test_next_price_from_output_excessive_amount_a_to_b() {
let result = get_next_sqrt_price_from_output(Q64, 1, u64::MAX, true);
assert!(result.is_err());
}
// With nothing remaining, the step is a no-op at the current price.
#[test]
fn test_swap_step_zero_amount() {
let result = compute_swap_step(Q64, Q64 / 2, 1_000_000, 0, 30, true).unwrap();
assert_eq!(result.amount_in, 0);
assert_eq!(result.amount_out, 0);
assert_eq!(result.fee_amount, 0);
assert_eq!(result.sqrt_price_next, Q64);
}
// With no liquidity, nothing is swapped.
#[test]
fn test_swap_step_zero_liquidity() {
let result = compute_swap_step(Q64, Q64 / 2, 0, 1000, 30, true).unwrap();
assert_eq!(result.amount_in, 0);
assert_eq!(result.amount_out, 0);
}
// An effectively unlimited exact-input budget reaches the target price (A-to-B).
#[test]
fn test_swap_step_exact_in_a_to_b_reaches_target() {
let price_current = tick_to_sqrt_price_x64(100).unwrap();
let price_target = tick_to_sqrt_price_x64(0).unwrap();
let liquidity = 10_000_000_000u128;
let result =
compute_swap_step(price_current, price_target, liquidity, u64::MAX, 30, true).unwrap();
assert_eq!(result.sqrt_price_next, price_target, "should reach target price");
assert!(result.amount_in > 0);
assert!(result.amount_out > 0);
assert!(result.fee_amount > 0);
}
// A small exact-input budget stops short of the target and is fully consumed
// by amount_in plus fee.
#[test]
fn test_swap_step_exact_in_a_to_b_partial() {
let price_current = tick_to_sqrt_price_x64(1000).unwrap();
let price_target = tick_to_sqrt_price_x64(0).unwrap();
let liquidity = 10_000_000_000_000u128;
let result =
compute_swap_step(price_current, price_target, liquidity, 100, 30, true).unwrap();
assert!(result.sqrt_price_next > price_target, "should not reach target");
assert!(result.sqrt_price_next < price_current);
assert_eq!(
result.fee_amount + result.amount_in,
100,
"fee + amount_in should equal amount_remaining"
);
}
// An effectively unlimited exact-input budget reaches the target price (B-to-A).
#[test]
fn test_swap_step_exact_in_b_to_a() {
let price_current = tick_to_sqrt_price_x64(0).unwrap();
let price_target = tick_to_sqrt_price_x64(100).unwrap();
let liquidity = 10_000_000_000u128;
let result =
compute_swap_step(price_current, price_target, liquidity, u64::MAX, 30, true).unwrap();
assert_eq!(result.sqrt_price_next, price_target);
}
// Exact-output A-to-B never yields more than requested and charges input and fee.
#[test]
fn test_swap_step_exact_out_a_to_b() {
let price_current = tick_to_sqrt_price_x64(100).unwrap();
let price_target = tick_to_sqrt_price_x64(0).unwrap();
let liquidity = 10_000_000_000u128;
let result =
compute_swap_step(price_current, price_target, liquidity, 1_000_000, 30, false)
.unwrap();
assert!(result.amount_out <= 1_000_000);
assert!(result.amount_in > 0);
assert!(result.fee_amount > 0);
}
// Exact-output B-to-A never yields more than requested.
#[test]
fn test_swap_step_exact_out_b_to_a() {
let price_current = tick_to_sqrt_price_x64(0).unwrap();
let price_target = tick_to_sqrt_price_x64(100).unwrap();
let liquidity = 10_000_000_000u128;
let result =
compute_swap_step(price_current, price_target, liquidity, 100, 30, false).unwrap();
assert!(result.amount_out <= 100);
}
// For a range of budgets, an A-to-B step never drops below the target price.
#[test]
fn test_swap_step_price_never_exceeds_target_a_to_b() {
let price_current = tick_to_sqrt_price_x64(500).unwrap();
let price_target = tick_to_sqrt_price_x64(-500).unwrap();
let liquidity = 1_000_000_000u128;
for amount in [1, 100, 10_000, 1_000_000, u64::MAX] {
let result =
compute_swap_step(price_current, price_target, liquidity, amount, 30, true)
.unwrap();
assert!(
result.sqrt_price_next >= price_target,
"price went below target for amount={amount}"
);
}
}
// For a range of budgets, a B-to-A step never rises above the target price.
#[test]
fn test_swap_step_price_never_exceeds_target_b_to_a() {
let price_current = tick_to_sqrt_price_x64(-500).unwrap();
let price_target = tick_to_sqrt_price_x64(500).unwrap();
let liquidity = 1_000_000_000u128;
for amount in [1, 100, 10_000, 1_000_000, u64::MAX] {
let result =
compute_swap_step(price_current, price_target, liquidity, amount, 30, true)
.unwrap();
assert!(
result.sqrt_price_next <= price_target,
"price went above target for amount={amount}"
);
}
}
// A zero fee rate produces a zero fee.
#[test]
fn test_swap_step_fee_zero() {
let price_current = tick_to_sqrt_price_x64(100).unwrap();
let price_target = tick_to_sqrt_price_x64(0).unwrap();
let liquidity = 10_000_000_000u128;
let result =
compute_swap_step(price_current, price_target, liquidity, 1_000_000, 0, true).unwrap();
assert_eq!(result.fee_amount, 0);
}
// When the step does not reach the target, amount_in plus fee equals the budget.
// The assertion is skipped if the target happens to be reached.
#[test]
fn test_swap_step_exact_in_full_consume_fee_identity() {
let price_current = tick_to_sqrt_price_x64(10000).unwrap();
let price_target = tick_to_sqrt_price_x64(0).unwrap();
let liquidity = 1_000_000_000_000_000u128;
let amount = 50_000u64;
let result =
compute_swap_step(price_current, price_target, liquidity, amount, 30, true).unwrap();
if result.sqrt_price_next != price_target {
assert_eq!(
result.fee_amount + result.amount_in,
amount,
"partial fill: fee + in should equal remaining"
);
}
}
}