use alloy_primitives::U256;
use crate::{
full_math::{mul_div, mul_div_rounding_up},
AmmMathError,
};
/// 2^96: the fixed-point scale used by every `*_x96` sqrt-price value in this module.
///
/// `from_limbs` takes little-endian `u64` limbs, so bit 32 of limb 1 is 2^(64 + 32).
const Q96: U256 = U256::from_limbs([0, 0x1_0000_0000, 0, 0]);
/// Amount of token0 spanned by `liquidity` between two sqrt prices (each scaled by 2^96).
///
/// Computes `(liquidity << 96) * (upper - lower) / upper / lower`, where `upper` and `lower`
/// are the larger and smaller of the two inputs, so argument order does not matter.
/// `round_up` selects ceiling (true) or floor (false) rounding for both divisions.
///
/// # Errors
/// Returns `AmmMathError::SqrtPriceDiffZero` when the smaller sqrt price is zero, and
/// propagates any error from the full-precision multiply/divide helpers.
pub fn get_amount_0_delta(
    sqrt_ratio_a_x96: U256,
    sqrt_ratio_b_x96: U256,
    liquidity: u128,
    round_up: bool,
) -> crate::Result<U256> {
    let lower = sqrt_ratio_a_x96.min(sqrt_ratio_b_x96);
    let upper = sqrt_ratio_a_x96.max(sqrt_ratio_b_x96);
    // Dividing by the lower sqrt price below requires it to be non-zero.
    if lower.is_zero() {
        return Err(AmmMathError::SqrtPriceDiffZero);
    }

    let scaled_liquidity = U256::from(liquidity) << 96;
    let spread = upper - lower;

    if !round_up {
        return Ok(mul_div(scaled_liquidity, spread, upper)? / lower);
    }
    let partial = mul_div_rounding_up(scaled_liquidity, spread, upper)?;
    div_rounding_up(partial, lower)
}
/// Amount of token1 spanned by `liquidity` between two sqrt prices (each scaled by 2^96).
///
/// Computes `liquidity * |a - b| / 2^96`; argument order does not matter.
/// `round_up` selects ceiling (true) or floor (false) rounding.
///
/// # Errors
/// Propagates any error from the full-precision multiply/divide helpers.
pub fn get_amount_1_delta(
    sqrt_ratio_a_x96: U256,
    sqrt_ratio_b_x96: U256,
    liquidity: u128,
    round_up: bool,
) -> crate::Result<U256> {
    // Absolute distance between the two sqrt prices.
    let spread = if sqrt_ratio_a_x96 > sqrt_ratio_b_x96 {
        sqrt_ratio_a_x96 - sqrt_ratio_b_x96
    } else {
        sqrt_ratio_b_x96 - sqrt_ratio_a_x96
    };
    let wide_liquidity = U256::from(liquidity);

    if round_up {
        mul_div_rounding_up(wide_liquidity, spread, Q96)
    } else {
        mul_div(wide_liquidity, spread, Q96)
    }
}
/// Ceiling division `ceil(a / b)`.
///
/// The `+ 1` cannot overflow: a non-zero remainder implies `b >= 2`, so the quotient is at
/// most `U256::MAX / 2`.
///
/// # Errors
/// Returns `AmmMathError::DivisionByZero` when `b` is zero.
pub(crate) fn div_rounding_up(a: U256, b: U256) -> crate::Result<U256> {
    if b.is_zero() {
        return Err(AmmMathError::DivisionByZero);
    }
    let floor = a / b;
    let has_remainder = !(a % b).is_zero();
    Ok(if has_remainder {
        floor + U256::from(1u64)
    } else {
        floor
    })
}
/// The largest 160-bit value, `2^160 - 1`; used in this module as the upper bound for
/// computed sqrt prices.
fn u160_max() -> U256 {
    // Clearing the top 96 bits of an all-ones 256-bit word leaves 160 set bits.
    U256::MAX >> 96
}
/// Next sqrt price after adding (`add == true`) or removing `amount` of token0, rounded up.
///
/// Uses `L*2^96 * sqrtP / (L*2^96 ± amount * sqrtP)`. When adding, and the 256-bit
/// intermediate would overflow, it falls back to `L*2^96 / (L*2^96 / sqrtP + amount)`.
/// A zero `amount` returns `sqrt_p_x96` unchanged.
///
/// # Errors
/// - `AmmMathError::MulDivOverflow` if the fallback denominator overflows (add).
/// - `AmmMathError::PriceUnderflow` if `amount * sqrtP` overflows or is `>= L*2^96` (remove).
/// - `AmmMathError::SqrtPriceOutOfRange` if the removed-token result exceeds 160 bits.
/// - Any error from the multiply/divide helpers.
pub(crate) fn get_next_sqrt_price_from_amount_0_rounding_up(
    sqrt_p_x96: U256,
    liquidity: u128,
    amount: U256,
    add: bool,
) -> crate::Result<U256> {
    if amount.is_zero() {
        return Ok(sqrt_p_x96);
    }
    let scaled_liquidity: U256 = U256::from(liquidity) << 96;

    if !add {
        // An overflowing product means the subtraction would underflow as well; an exactly
        // zero difference would leave nothing to divide by. Both are price underflows.
        let denominator = amount
            .checked_mul(sqrt_p_x96)
            .and_then(|product| scaled_liquidity.checked_sub(product))
            .filter(|remaining| !remaining.is_zero())
            .ok_or(AmmMathError::PriceUnderflow)?;
        let next = mul_div_rounding_up(scaled_liquidity, sqrt_p_x96, denominator)?;
        return if next > u160_max() {
            Err(AmmMathError::SqrtPriceOutOfRange)
        } else {
            Ok(next)
        };
    }

    // Prefer the precise form whenever `amount * sqrtP` and the sum both fit in 256 bits.
    let precise_denominator = amount
        .checked_mul(sqrt_p_x96)
        .and_then(|product| scaled_liquidity.checked_add(product));
    match precise_denominator {
        Some(denominator) => mul_div_rounding_up(scaled_liquidity, sqrt_p_x96, denominator),
        None => {
            let term = (scaled_liquidity / sqrt_p_x96)
                .checked_add(amount)
                .ok_or(AmmMathError::MulDivOverflow)?;
            div_rounding_up(scaled_liquidity, term)
        }
    }
}
/// Next sqrt price after adding (`add == true`) or removing `amount` of token1, rounded down.
///
/// Moves the price by `amount * 2^96 / liquidity`. When adding, the quotient is floored and
/// added. When removing, it is ceiled and subtracted.
///
/// # Errors
/// - `AmmMathError::LiquidityZero` if `liquidity` is zero (reported the same way on every path).
/// - `AmmMathError::SqrtPriceOutOfRange` if the added-token result exceeds 160 bits.
/// - `AmmMathError::PriceUnderflow` if removing would take the price to zero or below.
/// - Any error from the multiply/divide helpers.
pub(crate) fn get_next_sqrt_price_from_amount_1_rounding_down(
    sqrt_p_x96: U256,
    liquidity: u128,
    amount: U256,
    add: bool,
) -> crate::Result<U256> {
    // Reject zero liquidity up front so every branch reports `LiquidityZero`. The plain `/`
    // below would otherwise panic on a zero divisor. The helper-based branches would surface a
    // different division error.
    if liquidity == 0 {
        return Err(AmmMathError::LiquidityZero);
    }
    let liquidity = U256::from(liquidity);

    // For amounts that fit in 160 bits, `amount << 96` fits in 256 bits, so the cheap
    // shift-and-divide is exact. Larger amounts need the full-precision helpers.
    let fits_u160 = amount <= u160_max();

    if add {
        let quotient = if fits_u160 {
            (amount << 96) / liquidity
        } else {
            mul_div(amount, Q96, liquidity)?
        };
        let next = sqrt_p_x96
            .checked_add(quotient)
            .ok_or(AmmMathError::SqrtPriceOutOfRange)?;
        if next > u160_max() {
            return Err(AmmMathError::SqrtPriceOutOfRange);
        }
        Ok(next)
    } else {
        let quotient = if fits_u160 {
            div_rounding_up(amount << 96, liquidity)?
        } else {
            mul_div_rounding_up(amount, Q96, liquidity)?
        };
        // The resulting price must remain strictly positive.
        if sqrt_p_x96 <= quotient {
            return Err(AmmMathError::PriceUnderflow);
        }
        Ok(sqrt_p_x96 - quotient)
    }
}
/// Sqrt price after swapping `amount_in` of the input token into the pool.
///
/// With `zero_for_one` the input is token0, applied with rounding up. Otherwise the input is
/// token1, applied with rounding down.
///
/// # Errors
/// - `AmmMathError::SqrtPriceZero` if `sqrt_p_x96` is zero (checked first).
/// - `AmmMathError::LiquidityZero` if `liquidity` is zero.
/// - Any error from the token-specific price step.
pub fn get_next_sqrt_price_from_input(
    sqrt_p_x96: U256,
    liquidity: u128,
    amount_in: U256,
    zero_for_one: bool,
) -> crate::Result<U256> {
    if sqrt_p_x96.is_zero() {
        Err(AmmMathError::SqrtPriceZero)
    } else if liquidity == 0 {
        Err(AmmMathError::LiquidityZero)
    } else if zero_for_one {
        get_next_sqrt_price_from_amount_0_rounding_up(sqrt_p_x96, liquidity, amount_in, true)
    } else {
        get_next_sqrt_price_from_amount_1_rounding_down(sqrt_p_x96, liquidity, amount_in, true)
    }
}
/// Sqrt price after withdrawing `amount_out` of the output token from the pool.
///
/// With `zero_for_one` the output is token1, removed with rounding down. Otherwise the output
/// is token0, removed with rounding up.
///
/// # Errors
/// - `AmmMathError::SqrtPriceZero` if `sqrt_p_x96` is zero (checked first).
/// - `AmmMathError::LiquidityZero` if `liquidity` is zero.
/// - Any error from the token-specific price step (e.g. an underflow on a too-large output).
pub fn get_next_sqrt_price_from_output(
    sqrt_p_x96: U256,
    liquidity: u128,
    amount_out: U256,
    zero_for_one: bool,
) -> crate::Result<U256> {
    if sqrt_p_x96.is_zero() {
        Err(AmmMathError::SqrtPriceZero)
    } else if liquidity == 0 {
        Err(AmmMathError::LiquidityZero)
    } else if zero_for_one {
        get_next_sqrt_price_from_amount_1_rounding_down(sqrt_p_x96, liquidity, amount_out, false)
    } else {
        get_next_sqrt_price_from_amount_0_rounding_up(sqrt_p_x96, liquidity, amount_out, false)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tick_math::get_sqrt_ratio_at_tick;

    #[test]
    fn test_amount_0_simple() {
        let sqrt_a = get_sqrt_ratio_at_tick(0).unwrap();
        let sqrt_b = get_sqrt_ratio_at_tick(100).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000;
        let a0 = get_amount_0_delta(sqrt_a, sqrt_b, liq, false).unwrap();
        assert!(a0 > U256::ZERO);
    }

    #[test]
    fn test_amount_1_simple() {
        let sqrt_a = get_sqrt_ratio_at_tick(0).unwrap();
        let sqrt_b = get_sqrt_ratio_at_tick(100).unwrap();
        let liq: u128 = 1_000_000_000_000_000_000;
        let a1 = get_amount_1_delta(sqrt_a, sqrt_b, liq, false).unwrap();
        assert!(a1 > U256::ZERO);
    }

    #[test]
    fn test_amount_0_round_up_geq_round_down() {
        let sqrt_a = get_sqrt_ratio_at_tick(-100).unwrap();
        let sqrt_b = get_sqrt_ratio_at_tick(100).unwrap();
        let liq: u128 = 999_999_999;
        let down = get_amount_0_delta(sqrt_a, sqrt_b, liq, false).unwrap();
        let up = get_amount_0_delta(sqrt_a, sqrt_b, liq, true).unwrap();
        assert!(up >= down);
    }

    #[test]
    fn test_amount_1_round_up_geq_round_down() {
        let sqrt_a = get_sqrt_ratio_at_tick(-100).unwrap();
        let sqrt_b = get_sqrt_ratio_at_tick(100).unwrap();
        let liq: u128 = 999_999_999;
        let down = get_amount_1_delta(sqrt_a, sqrt_b, liq, false).unwrap();
        let up = get_amount_1_delta(sqrt_a, sqrt_b, liq, true).unwrap();
        assert!(up >= down);
    }

    #[test]
    fn test_zero_liquidity() {
        let sqrt_a = get_sqrt_ratio_at_tick(0).unwrap();
        let sqrt_b = get_sqrt_ratio_at_tick(100).unwrap();
        let a0 = get_amount_0_delta(sqrt_a, sqrt_b, 0, false).unwrap();
        let a1 = get_amount_1_delta(sqrt_a, sqrt_b, 0, false).unwrap();
        assert_eq!(a0, U256::ZERO);
        assert_eq!(a1, U256::ZERO);
    }

    #[test]
    fn test_same_sqrt_price() {
        let sqrt = get_sqrt_ratio_at_tick(42).unwrap();
        let a0 = get_amount_0_delta(sqrt, sqrt, 1_000_000, false).unwrap();
        let a1 = get_amount_1_delta(sqrt, sqrt, 1_000_000, false).unwrap();
        assert_eq!(a0, U256::ZERO);
        assert_eq!(a1, U256::ZERO);
    }

    #[test]
    fn test_reversed_args() {
        let sqrt_a = get_sqrt_ratio_at_tick(0).unwrap();
        let sqrt_b = get_sqrt_ratio_at_tick(100).unwrap();
        let liq: u128 = 10_000_000;
        let normal = get_amount_0_delta(sqrt_a, sqrt_b, liq, false).unwrap();
        let reversed = get_amount_0_delta(sqrt_b, sqrt_a, liq, false).unwrap();
        assert_eq!(normal, reversed);
        let normal_1 = get_amount_1_delta(sqrt_a, sqrt_b, liq, false).unwrap();
        let reversed_1 = get_amount_1_delta(sqrt_b, sqrt_a, liq, false).unwrap();
        assert_eq!(normal_1, reversed_1);
    }

    /// sqrt(1.0) at the 2^96 scale, i.e. exactly 2^96.
    fn price_1() -> U256 {
        U256::from_str_radix("79228162514264337593543950336", 10).unwrap()
    }

    #[test]
    fn price_1_is_q96() {
        assert_eq!(price_1(), Q96);
    }

    #[test]
    fn from_input_zero_amount_returns_input_price() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000;
        let next = get_next_sqrt_price_from_input(p, liq, U256::ZERO, true).unwrap();
        assert_eq!(next, p);
        let next2 = get_next_sqrt_price_from_input(p, liq, U256::ZERO, false).unwrap();
        assert_eq!(next2, p);
    }

    #[test]
    fn from_input_rejects_zero_price() {
        let liq: u128 = 1;
        let err = get_next_sqrt_price_from_input(U256::ZERO, liq, U256::from(1u64), true)
            .expect_err("should reject zero price");
        assert!(matches!(err, AmmMathError::SqrtPriceZero));
    }

    #[test]
    fn from_input_rejects_zero_liquidity() {
        let p = price_1();
        let err = get_next_sqrt_price_from_input(p, 0, U256::from(1u64), true)
            .expect_err("should reject zero liquidity");
        assert!(matches!(err, AmmMathError::LiquidityZero));
    }

    #[test]
    fn from_input_zero_for_one_decreases_price() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000;
        let amount_in = U256::from(100_000_000_000_000_000u64);
        let next = get_next_sqrt_price_from_input(p, liq, amount_in, true).unwrap();
        assert!(next < p, "zeroForOne should decrease price; before={p}, after={next}");
    }

    #[test]
    fn from_input_one_for_zero_increases_price() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000;
        let amount_in = U256::from(100_000_000_000_000_000u64);
        let next = get_next_sqrt_price_from_input(p, liq, amount_in, false).unwrap();
        assert!(next > p, "oneForZero should increase price; before={p}, after={next}");
    }

    #[test]
    fn from_output_one_for_zero_increases_price_for_token0_out() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000_000_000u128;
        let amount_out = U256::from(1_000_000u64);
        let next = get_next_sqrt_price_from_output(p, liq, amount_out, false).unwrap();
        assert!(next > p, "removing token0 should increase price; before={p}, after={next}");
    }

    #[test]
    fn from_output_zero_for_one_decreases_price_for_token1_out() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000_000_000u128;
        let amount_out = U256::from(1_000_000u64);
        let next = get_next_sqrt_price_from_output(p, liq, amount_out, true).unwrap();
        assert!(next < p, "removing token1 should decrease price; before={p}, after={next}");
    }

    #[test]
    fn from_output_underflow_rejected() {
        let p = price_1();
        let liq: u128 = 1;
        // amount * sqrtP = 2^196 far exceeds L * 2^96 = 2^96, so token0 cannot be removed.
        let huge = U256::from(1u64) << 100;
        let result = get_next_sqrt_price_from_output(p, liq, huge, false);
        assert!(
            matches!(result, Err(AmmMathError::PriceUnderflow)),
            "huge output withdrawal should underflow, got {result:?}"
        );
    }

    #[test]
    fn from_input_zero_for_one_then_one_for_zero() {
        let p = price_1();
        let liq: u128 = 1_000_000_000_000_000_000_000_000u128;
        let amount_in = U256::from(1_000_000_000_000_000u64);
        let next = get_next_sqrt_price_from_input(p, liq, amount_in, true).unwrap();
        assert!(next < p);
        // The new price is rounded up (toward the start), so reaching it never costs more
        // token0 than was actually paid in.
        let needed = get_amount_0_delta(next, p, liq, true).unwrap();
        assert!(needed <= amount_in, "needed={needed}, paid={amount_in}");
        let recovered = get_next_sqrt_price_from_input(next, liq, amount_in, false).unwrap();
        assert!(recovered > next);
    }
}