use crate::error::ErrorCode;
use crate::math::bn::Downcast;
use crate::math::get_sqrt_price_at_tick;
use std::{ops::Shl, u128};
use super::{
bn::{Shift, U256},
full_math::{DivRoundUpIf, FullMath},
tick_math::{MAX_SQRT_PRICE_X64, MIN_SQRT_PRICE_X64},
};
/// Denominator for `fee_rate` values: a fee rate `r` means a fraction `r / FEE_RATE_DENOMINATOR`
/// (one million, i.e. fee rates are expressed in parts per million).
pub const FEE_RATE_DENOMINATOR: u64 = 10_u64.pow(6);
/// Outcome of a single [`compute_swap_step`] call within one liquidity range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SwapStepResult {
    /// Sqrt price after the step, in the same X64 fixed-point format as the inputs.
    pub next_sqrt_price: u128,
    /// Input token consumed by the price move, excluding `fee_amount`.
    pub amount_in: u64,
    /// Output token released by the price move.
    pub amount_out: u64,
    /// Fee charged on top of `amount_in`, denominated in the input token.
    pub fee_amount: u64,
}
/// Computes the liquidity obtainable from a deposit into the tick range
/// `[lower_index, upper_index]` at the given current sqrt price.
///
/// * Price below the range (`current_sqrt_price < lower price`): only `amount_a` is used.
/// * Price at/above the upper bound: only `amount_b` is used.
/// * Price inside the range: exactly one of `amount_a` / `amount_b` must be supplied; the
///   other side's required amount is derived (rounded up) from the resulting liquidity.
///
/// Liquidity is always rounded down.
///
/// # Returns
/// `(liquidity, amount_a, amount_b)`.
///
/// # Errors
/// `ErrorCode::InvalidAmountInput` when the required amount is missing (or, inside the range,
/// when both or neither amounts are given); otherwise propagates errors from the
/// liquidity/delta helpers.
pub fn get_liquidity_from_amount(
    lower_index: i32,
    upper_index: i32,
    current_sqrt_price: u128,
    amount_a: Option<u64>,
    amount_b: Option<u64>,
) -> Result<(u128, u64, u64), ErrorCode> {
    let lower_price = get_sqrt_price_at_tick(lower_index);
    let upper_price = get_sqrt_price_at_tick(upper_index);

    if current_sqrt_price < lower_price {
        // Entire range is above the current price: the position is all token A.
        let amount_a = amount_a.ok_or(ErrorCode::InvalidAmountInput)?;
        let liquidity = get_liquidity_from_a(lower_price, upper_price, amount_a, false)?;
        return Ok((liquidity, amount_a, 0));
    }

    if current_sqrt_price >= upper_price {
        // Entire range is at/below the current price: the position is all token B.
        let amount_b = amount_b.ok_or(ErrorCode::InvalidAmountInput)?;
        let liquidity = get_liquidity_from_b(lower_price, upper_price, amount_b, false)?;
        return Ok((liquidity, 0, amount_b));
    }

    // Price is inside the range: one side fixes liquidity, the other side is derived.
    match (amount_a, amount_b) {
        (Some(amount_a), None) => {
            let liquidity =
                get_liquidity_from_a(current_sqrt_price, upper_price, amount_a, false)?;
            let amount_b = get_delta_b(lower_price, current_sqrt_price, liquidity, true)?;
            Ok((liquidity, amount_a, amount_b))
        }
        (None, Some(amount_b)) => {
            let liquidity =
                get_liquidity_from_b(lower_price, current_sqrt_price, amount_b, false)?;
            let amount_a = get_delta_a(current_sqrt_price, upper_price, liquidity, true)?;
            Ok((liquidity, amount_a, amount_b))
        }
        // Both or neither amounts supplied: ambiguous input.
        _ => Err(ErrorCode::InvalidAmountInput),
    }
}
/// Amount of token A spanned by `liquidity` between two sqrt prices:
/// `liquidity * |p1 - p0| * 2^64 / (p0 * p1)`, computed in 256-bit arithmetic.
///
/// Argument order of the two prices does not matter. When `round_up` is set, any non-zero
/// remainder bumps the result by one.
///
/// # Errors
/// * `ErrorCode::DivisorIsZero` if either sqrt price is zero (the denominator would be zero).
/// * `ErrorCode::MultiplicationOverflow` if the shifted numerator does not fit in 256 bits.
/// * `ErrorCode::IntegerDowncastOverflow` if the result does not fit in `u64`.
pub fn get_delta_a(
    sqrt_price_0: u128,
    sqrt_price_1: u128,
    liquidity: u128,
    round_up: bool,
) -> Result<u64, ErrorCode> {
    // `div_mod` below would panic on a zero denominator; report it as an error instead.
    if sqrt_price_0 == 0 || sqrt_price_1 == 0 {
        return Err(ErrorCode::DivisorIsZero);
    }
    let sqrt_price_diff = if sqrt_price_0 > sqrt_price_1 {
        sqrt_price_0 - sqrt_price_1
    } else {
        sqrt_price_1 - sqrt_price_0
    };
    // A zero numerator always yields zero; skip the 256-bit division (mirrors `get_delta_b`).
    if liquidity == 0 || sqrt_price_diff == 0 {
        return Ok(0);
    }
    let numerator = liquidity
        .full_mul(sqrt_price_diff)
        .checked_shift_word_left()
        .ok_or(ErrorCode::MultiplicationOverflow)?;
    let denominator = sqrt_price_0.full_mul(sqrt_price_1);
    let (quotient, remainder) = numerator.div_mod(denominator);
    let rounded = if round_up && !remainder.is_zero() {
        quotient + 1
    } else {
        quotient
    };
    rounded
        .checked_as_u64()
        .ok_or(ErrorCode::IntegerDowncastOverflow)
}
/// Amount of token B spanned by `liquidity` between two sqrt prices:
/// `liquidity * |p1 - p0| >> 64`.
///
/// Price order is irrelevant. With `round_up`, a non-zero low 64-bit fraction adds one.
///
/// # Errors
/// `ErrorCode::MultiplicationOverflow` if `liquidity * |p1 - p0|` exceeds `u128`
/// (equivalently, the result exceeds `u64`) or if rounding up overflows `u64`.
pub fn get_delta_b(
    sqrt_price_0: u128,
    sqrt_price_1: u128,
    liquidity: u128,
    round_up: bool,
) -> Result<u64, ErrorCode> {
    // Low 64 bits of the X64 product hold the fractional part.
    const FRACTION_MASK: u128 = u64::MAX as u128;

    let price_gap = sqrt_price_0.max(sqrt_price_1) - sqrt_price_0.min(sqrt_price_1);
    if price_gap == 0 || liquidity == 0 {
        return Ok(0);
    }

    let scaled = match liquidity.checked_mul(price_gap) {
        Some(value) => value,
        None => return Err(ErrorCode::MultiplicationOverflow),
    };
    let has_fraction = (scaled & FRACTION_MASK) != 0;
    // `scaled < 2^128`, so the high half always fits in u64.
    let floor = (scaled >> 64u32) as u64;

    if round_up && has_fraction {
        floor.checked_add(1).ok_or(ErrorCode::MultiplicationOverflow)
    } else {
        Ok(floor)
    }
}
/// Liquidity provided by `amount_a` of token A across the sqrt-price interval
/// `[sqrt_price_0, sqrt_price_1]` (either order):
/// `amount_a * ((p0 * p1) >> 64) / |p1 - p0|`.
///
/// # Errors
/// * `ErrorCode::MultiplicationOverflow` if the numerator exceeds 256 bits.
/// * `ErrorCode::DivisorIsZero` if the two prices are equal.
/// * `ErrorCode::IntegerDowncastOverflow` if the liquidity does not fit in `u128`
///   (previously this panicked inside `as_u128`).
pub fn get_liquidity_from_a(
    sqrt_price_0: u128,
    sqrt_price_1: u128,
    amount_a: u64,
    round_up: bool,
) -> Result<u128, ErrorCode> {
    let sqrt_price_diff = if sqrt_price_0 > sqrt_price_1 {
        sqrt_price_0 - sqrt_price_1
    } else {
        sqrt_price_1 - sqrt_price_0
    };
    let numerator = sqrt_price_0
        .full_mul(sqrt_price_1)
        .shift_word_right()
        .checked_mul(U256::from(amount_a))
        .ok_or(ErrorCode::MultiplicationOverflow)?;
    numerator
        .checked_div_round_up_if(U256::from(sqrt_price_diff), round_up)
        .ok_or(ErrorCode::DivisorIsZero)?
        // Large amounts over narrow intervals can exceed u128; surface that as an error.
        .checked_as_u128()
        .ok_or(ErrorCode::IntegerDowncastOverflow)
}
/// Liquidity provided by `amount_b` of token B across the sqrt-price interval
/// `[sqrt_price_0, sqrt_price_1]` (either order): `(amount_b << 64) / |p1 - p0|`.
///
/// # Errors
/// * `ErrorCode::DivisorIsZero` if the two prices are equal.
/// * `ErrorCode::MultiplicationOverflow` / `ErrorCode::IntegerDowncastOverflow` are
///   reported instead of panicking; neither is expected for a `u64` amount, since
///   `amount_b << 64` already fits in `u128`.
pub fn get_liquidity_from_b(
    sqrt_price_0: u128,
    sqrt_price_1: u128,
    amount_b: u64,
    round_up: bool,
) -> Result<u128, ErrorCode> {
    let sqrt_price_diff = if sqrt_price_0 > sqrt_price_1 {
        sqrt_price_0 - sqrt_price_1
    } else {
        sqrt_price_1 - sqrt_price_0
    };
    U256::from(amount_b)
        .checked_shift_word_left()
        .ok_or(ErrorCode::MultiplicationOverflow)?
        .checked_div_round_up_if(U256::from(sqrt_price_diff), round_up)
        .ok_or(ErrorCode::DivisorIsZero)?
        .checked_as_u128()
        .ok_or(ErrorCode::IntegerDowncastOverflow)
}
/// Next sqrt price after adding (`by_amount_input == true`) or removing (`false`) `amount`
/// of token A at constant `liquidity`, rounded up:
/// `p' = ceil(p * L * 2^64 / (L * 2^64 ± p * amount))`.
///
/// Returns `sqrt_price` unchanged when `amount == 0`.
///
/// # Errors
/// * `ErrorCode::MultiplicationOverflow` if `p * L << 64` exceeds 256 bits.
/// * `ErrorCode::DivisorIsZero` when removing at least as much token A as the liquidity can
///   supply (`p * amount >= L << 64`). Previously this case panicked on an `unwrap`.
/// * `ErrorCode::IntegerDowncastOverflow` if the quotient exceeds `u128`.
/// * `ErrorCode::TokenAmountMaxExceeded` / `ErrorCode::TokenAmountMinSubceeded` if the new
///   price leaves `[MIN_SQRT_PRICE_X64, MAX_SQRT_PRICE_X64]`.
pub fn get_next_sqrt_price_a_up(
    sqrt_price: u128,
    liquidity: u128,
    amount: u64,
    by_amount_input: bool,
) -> Result<u128, ErrorCode> {
    if amount == 0 {
        return Ok(sqrt_price);
    }
    let numerator = sqrt_price
        .full_mul(liquidity)
        .checked_shift_word_left()
        .ok_or(ErrorCode::MultiplicationOverflow)?;
    let liquidity_shl_64 = U256::from(liquidity).shift_word_left();
    let product = sqrt_price.full_mul(amount as u128);
    let denominator = if by_amount_input {
        // Both terms are < 2^192, so this cannot overflow 256 bits; checked for safety.
        liquidity_shl_64
            .checked_add(product)
            .ok_or(ErrorCode::MultiplicationOverflow)?
    } else {
        // Removing more token A than the range can supply would make the denominator
        // negative; treat it like the zero-denominator case.
        liquidity_shl_64
            .checked_sub(product)
            .ok_or(ErrorCode::DivisorIsZero)?
    };
    let quotient = numerator
        .checked_div_round_up_if(denominator, true)
        .ok_or(ErrorCode::DivisorIsZero)?;
    let new_sqrt_price = quotient
        .checked_as_u128()
        .ok_or(ErrorCode::IntegerDowncastOverflow)?;
    if new_sqrt_price > MAX_SQRT_PRICE_X64 {
        return Err(ErrorCode::TokenAmountMaxExceeded);
    } else if new_sqrt_price < MIN_SQRT_PRICE_X64 {
        return Err(ErrorCode::TokenAmountMinSubceeded);
    }
    Ok(new_sqrt_price)
}
/// Next sqrt price after adding (`by_amount_input == true`) or removing (`false`) `amount`
/// of token B at constant `liquidity`: `p' = p ± (amount << 64) / L`.
///
/// The price step is rounded down when adding and rounded up when removing.
///
/// # Errors
/// * `ErrorCode::DivisorIsZero` if `liquidity == 0`.
/// * `ErrorCode::SqrtPriceOutOfBounds` if the add/subtract over/underflows or the result
///   falls outside `[MIN_SQRT_PRICE_X64, MAX_SQRT_PRICE_X64]`.
pub fn get_next_sqrt_price_b_down(
    sqrt_price: u128,
    liquidity: u128,
    amount: u64,
    by_amount_input: bool,
) -> Result<u128, ErrorCode> {
    let amount_x64 = (amount as u128).shl(64u32);
    let round_step_up = !by_amount_input;
    let price_step = amount_x64
        .checked_div_round_up_if(liquidity, round_step_up)
        .ok_or(ErrorCode::DivisorIsZero)?;

    let moved = if by_amount_input {
        sqrt_price.checked_add(price_step)
    } else {
        sqrt_price.checked_sub(price_step)
    };
    let new_sqrt_price = moved.ok_or(ErrorCode::SqrtPriceOutOfBounds)?;

    if !(MIN_SQRT_PRICE_X64..=MAX_SQRT_PRICE_X64).contains(&new_sqrt_price) {
        return Err(ErrorCode::SqrtPriceOutOfBounds);
    }
    Ok(new_sqrt_price)
}
/// Next sqrt price when `amount` of the input token enters the pool.
///
/// For `a_to_b`, token A is the input (handled by `get_next_sqrt_price_a_up`); otherwise
/// token B is the input (handled by `get_next_sqrt_price_b_down`).
pub fn get_next_sqrt_price_from_input(
    sqrt_price: u128,
    liquidity: u128,
    amount: u64,
    a_to_b: bool,
) -> Result<u128, ErrorCode> {
    if a_to_b {
        get_next_sqrt_price_a_up(sqrt_price, liquidity, amount, true)
    } else {
        get_next_sqrt_price_b_down(sqrt_price, liquidity, amount, true)
    }
}
/// Next sqrt price when `amount` of the output token leaves the pool.
///
/// For `a_to_b`, token B is the output (removed via `get_next_sqrt_price_b_down`);
/// otherwise token A is the output (removed via `get_next_sqrt_price_a_up`).
pub fn get_next_sqrt_price_from_output(
    sqrt_price: u128,
    liquidity: u128,
    amount: u64,
    a_to_b: bool,
) -> Result<u128, ErrorCode> {
    if a_to_b {
        get_next_sqrt_price_b_down(sqrt_price, liquidity, amount, false)
    } else {
        get_next_sqrt_price_a_up(sqrt_price, liquidity, amount, false)
    }
}
/// Input-token amount (rounded up) needed to move the price from `current_sqrt_price`
/// to `target_sqrt_price`: token A when `a_to_b`, token B otherwise.
pub fn get_delta_up_from_input(
    current_sqrt_price: u128,
    target_sqrt_price: u128,
    liquidity: u128,
    a_to_b: bool,
) -> Result<u64, ErrorCode> {
    if a_to_b {
        get_delta_a(target_sqrt_price, current_sqrt_price, liquidity, true)
    } else {
        get_delta_b(current_sqrt_price, target_sqrt_price, liquidity, true)
    }
}
/// Output-token amount (rounded down) released by moving the price from
/// `current_sqrt_price` to `target_sqrt_price`: token B when `a_to_b`, token A otherwise.
pub fn get_delta_down_from_output(
    current_sqrt_price: u128,
    target_sqrt_price: u128,
    liquidity: u128,
    a_to_b: bool,
) -> Result<u64, ErrorCode> {
    if a_to_b {
        get_delta_b(target_sqrt_price, current_sqrt_price, liquidity, false)
    } else {
        get_delta_a(current_sqrt_price, target_sqrt_price, liquidity, false)
    }
}
/// Computes one swap step inside a single liquidity range, moving from
/// `current_sqrt_price` toward (but not past) `target_sqrt_price`.
///
/// The direction is A→B when `current_sqrt_price >= target_sqrt_price`.
///
/// * `by_amount_input == true`: `amount` is the gross input including fees. The fee share
///   `fee_rate / FEE_RATE_DENOMINATOR` is set aside first, and the remainder drives the price.
/// * `by_amount_input == false`: `amount` is the desired output. The required net input is
///   derived, and the fee is charged on top of it.
///
/// Fees are always charged on the gross input. For a net input `n`, the fee is therefore
/// `ceil(n * fee_rate / (FEE_RATE_DENOMINATOR - fee_rate))`. This matches the uncapped
/// exact-input branch, where `fee = amount - floor(amount * (D - r) / D)`.
///
/// With zero liquidity the price jumps to `target_sqrt_price` and all amounts are zero.
///
/// # Errors
/// Propagates errors from the delta / next-price helpers.
pub fn compute_swap_step(
    current_sqrt_price: u128,
    target_sqrt_price: u128,
    liquidity: u128,
    amount: u64,
    fee_rate: u16,
    by_amount_input: bool,
) -> Result<SwapStepResult, ErrorCode> {
    if liquidity == 0 {
        return Ok(SwapStepResult {
            amount_in: 0u64,
            amount_out: 0u64,
            next_sqrt_price: target_sqrt_price,
            fee_amount: 0u64,
        });
    }
    let a_to_b = current_sqrt_price >= target_sqrt_price;
    let fee_rate = u64::from(fee_rate);
    // u16::MAX < FEE_RATE_DENOMINATOR, so this subtraction cannot underflow and is never zero.
    let fee_complement = FEE_RATE_DENOMINATOR - fee_rate;

    let (amount_in, amount_out, next_sqrt_price, fee_amount) = if by_amount_input {
        // Net amount available to move the price after removing the fee from the gross input.
        let amount_remain = amount.mul_div_floor(fee_complement, FEE_RATE_DENOMINATOR);
        let max_amount_in =
            get_delta_up_from_input(current_sqrt_price, target_sqrt_price, liquidity, a_to_b)?;
        let (amount_in, fee_amount, next_sqrt_price) = if max_amount_in >= amount_remain {
            // Whole input is consumed before reaching the target.
            let next_sqrt_price = get_next_sqrt_price_from_input(
                current_sqrt_price,
                liquidity,
                amount_remain,
                a_to_b,
            )?;
            // amount_remain = floor(amount * (D - r) / D) <= amount, so no underflow.
            (amount_remain, amount - amount_remain, next_sqrt_price)
        } else {
            // Target reached first: charge the fee on the net amount actually used.
            let fee_amount = max_amount_in.mul_div_ceil(fee_rate, fee_complement);
            (max_amount_in, fee_amount, target_sqrt_price)
        };
        let amount_out =
            get_delta_down_from_output(current_sqrt_price, next_sqrt_price, liquidity, a_to_b)?;
        (amount_in, amount_out, next_sqrt_price, fee_amount)
    } else {
        let max_amount_out = get_delta_down_from_output(
            current_sqrt_price,
            target_sqrt_price,
            liquidity,
            a_to_b,
        )?;
        let (amount_out, next_sqrt_price) = if max_amount_out >= amount {
            let next_sqrt_price =
                get_next_sqrt_price_from_output(current_sqrt_price, liquidity, amount, a_to_b)?;
            (amount, next_sqrt_price)
        } else {
            (max_amount_out, target_sqrt_price)
        };
        let amount_in =
            get_delta_up_from_input(current_sqrt_price, next_sqrt_price, liquidity, a_to_b)?;
        // Fee is a share of the gross input (amount_in + fee), hence the (D - r) denominator.
        let fee_amount = amount_in.mul_div_ceil(fee_rate, fee_complement);
        (amount_in, amount_out, next_sqrt_price, fee_amount)
    };

    Ok(SwapStepResult {
        amount_in,
        amount_out,
        next_sqrt_price,
        fee_amount,
    })
}
#[cfg(test)]
mod test_get_next_sqrt_price_from_b_round_down {
    use super::*;

    /// Exact-input B swap with a tiny price step: the step `(amount << 64) / liquidity`
    /// is about 3.43 and is floored to 3.
    #[test]
    fn get_next_sqrt_price_from_b_round_down_ok() {
        let (sqrt_price, liquidity, amount, by_amount_input) = (
            62058032627749460283664515388u128,
            56315830353026631512438212669420532741u128,
            10476203047244913035u64,
            true,
        );
        let r = get_next_sqrt_price_b_down(sqrt_price, liquidity, amount, by_amount_input)
            .unwrap();
        assert_eq!(r, sqrt_price + 3);
    }
}
// Property-based (proptest) checks for the next-sqrt-price and token-delta helpers.
#[cfg(test)]
mod fuzz_tests {
use proptest::prelude::*;
use super::*;
// Returns the two prices ordered as (smaller, larger).
pub fn lower_upper_sqrt_price(sqrt_price_0: u128, sqrt_price_1: u128) -> (u128, u128) {
if sqrt_price_0 < sqrt_price_1 {
(sqrt_price_0, sqrt_price_1)
} else {
(sqrt_price_1, sqrt_price_0)
}
}
proptest! {
// Adding token A must not over-credit the price move; removing token A must
// either fail (not enough liquidity) or cost at least `amount`.
#[test]
fn test_get_next_sqrt_price_from_a_round_up (
sqrt_price in MIN_SQRT_PRICE_X64..MAX_SQRT_PRICE_X64,
liquidity in 1..u128::MAX,
amount in 0..u64::MAX,
) {
// Guard only: sqrt_price is drawn from [MIN_SQRT_PRICE_X64, ..), presumably never 0.
prop_assume!(sqrt_price != 0);
let case_1_price = get_next_sqrt_price_a_up(sqrt_price, liquidity, amount, true);
// Operand bit widths summing past 192 are treated as the overflow region of
// `(sqrt_price * liquidity) << 64`. NOTE(review): this is an approximation of the exact
// 256-bit overflow condition; confirm it never misclassifies boundary cases.
if liquidity.leading_zeros() + sqrt_price.leading_zeros() < 64 {
assert!(case_1_price.is_err());
} else {
println!("{} {} {}", sqrt_price, case_1_price.unwrap(), liquidity);
// The token A required to reach the new price (rounded up) must not exceed what was supplied.
assert!(amount >= get_delta_a(sqrt_price, case_1_price.unwrap(), liquidity, true).unwrap());
let case_2_price = get_next_sqrt_price_a_up(sqrt_price, liquidity, amount, false);
// Removing A needs liquidity << 64 > sqrt_price * amount (positive denominator).
let liquidity_x64 = U256::from(liquidity) << 64;
let product = U256::from(sqrt_price) * U256::from(amount);
if liquidity_x64 <= product {
assert!(case_2_price.is_err());
} else {
// Output of A (rounded down) must cover `amount`; the price moves up.
assert!(amount <= get_delta_a(sqrt_price, case_2_price.unwrap(), liquidity, false).unwrap());
assert!(case_2_price.unwrap() >= sqrt_price);
}
// A zero amount leaves the price unchanged in both directions.
if amount == 0 {
assert!(case_1_price.unwrap() == case_2_price.unwrap());
}
}
}
// Mirror of the A test for token B: adding B raises the price, removing B lowers it.
#[test]
fn test_get_next_sqrt_price_from_b_round_down (
sqrt_price in MIN_SQRT_PRICE_X64..MAX_SQRT_PRICE_X64,
liquidity in 1..u128::MAX,
amount in 0..u64::MAX,
) {
prop_assume!(sqrt_price != 0);
let case_3_price = get_next_sqrt_price_b_down(sqrt_price, liquidity, amount, true).unwrap();
assert!(case_3_price >= sqrt_price);
assert!(amount >= get_delta_b(sqrt_price, case_3_price, liquidity, true).unwrap());
let case_4_price = get_next_sqrt_price_b_down(sqrt_price, liquidity, amount, false);
// Reference price step for removal, rounded up as in the implementation.
let amount_x64 = u128::from(amount) << 64;
let delta = amount_x64.checked_div_round_up_if(liquidity, true).unwrap();
if sqrt_price < delta {
// The step would take the price below zero.
assert!(case_4_price.is_err());
} else {
let calc_delta = get_delta_b(sqrt_price, case_4_price.unwrap(), liquidity, false);
if calc_delta.is_ok() {
assert!(amount <= calc_delta.unwrap());
}
assert!(case_4_price.unwrap() <= sqrt_price);
}
if amount == 0 {
assert!(case_3_price == case_4_price.unwrap());
}
}
// get_delta_a is symmetric in its price arguments; rounding up adds at most one.
#[test]
fn test_get_amount_delta_a(
sqrt_price_0 in MIN_SQRT_PRICE_X64..MAX_SQRT_PRICE_X64,
sqrt_price_1 in MIN_SQRT_PRICE_X64..MAX_SQRT_PRICE_X64,
liquidity in 0..u128::MAX,
) {
let (sqrt_price_lower, sqrt_price_upper) = lower_upper_sqrt_price(sqrt_price_0, sqrt_price_1);
let rounded = get_delta_a(sqrt_price_0, sqrt_price_1, liquidity, true);
// Same width-based overflow heuristic as above, applied to liquidity * price_diff.
if liquidity.leading_zeros() + (sqrt_price_upper - sqrt_price_lower).leading_zeros() < 64 {
assert!(rounded.is_err())
} else {
let unrounded = get_delta_a(sqrt_price_0, sqrt_price_1, liquidity, false).unwrap();
assert_eq!(rounded.unwrap(), get_delta_a(sqrt_price_1, sqrt_price_0, liquidity, true).unwrap());
assert_eq!(unrounded, get_delta_a(sqrt_price_1, sqrt_price_0, liquidity, false).unwrap());
assert!(unrounded <= rounded.unwrap());
assert!(rounded.unwrap() - unrounded <= 1);
}
}
// get_delta_b is compared against a U256 reference model: floor and ceil of
// liquidity * price_diff / 2^64, with u64 overflow expected to surface as Err.
#[test]
fn test_get_amount_delta_b(
sqrt_price_0 in MIN_SQRT_PRICE_X64..MAX_SQRT_PRICE_X64,
sqrt_price_1 in MIN_SQRT_PRICE_X64..MAX_SQRT_PRICE_X64,
liquidity in 0..u128::MAX,
) {
let (price_lower, price_upper) = lower_upper_sqrt_price(sqrt_price_0, sqrt_price_1);
let n_0 = U256::from(liquidity); let n_1 = U256::from(price_upper - price_lower);
let m = n_0 * n_1; let delta = m >> 64; let has_mod = m % (1u128 << 64) > U256::zero();
let round_up_delta = if has_mod { delta + U256::from(1) } else { delta };
let rounded = get_delta_b(sqrt_price_0, sqrt_price_1, liquidity, true);
let unrounded = get_delta_b(sqrt_price_0, sqrt_price_1, liquidity, false);
let u64_max_in_u256 = U256::from(u64::MAX);
if delta > u64_max_in_u256 {
assert!(rounded.is_err());
assert!(unrounded.is_err());
} else if round_up_delta > u64_max_in_u256 {
// Only the rounded-up value overflows u64.
assert!(rounded.is_err());
assert_eq!(unrounded.unwrap(), get_delta_b(sqrt_price_1, sqrt_price_0, liquidity, false).unwrap());
} else {
assert_eq!(rounded.unwrap(), get_delta_b(sqrt_price_1, sqrt_price_0, liquidity, true).unwrap());
assert_eq!(unrounded.unwrap(), get_delta_b(sqrt_price_1, sqrt_price_0, liquidity, false).unwrap());
assert!(unrounded.unwrap() <= rounded.unwrap() );
assert!(rounded.unwrap() - unrounded.unwrap() <= 1);
}
}
}
}
#[cfg(test)]
mod test_get_amount_delta {
    use super::{get_delta_a, get_delta_b};

    /// Prices 4 and 2 (X64) with liquidity 4: delta_a = 4 * 2 / (4 * 2) = 1,
    /// delta_b = 4 * 2 = 8, exactly (so rounding does not matter).
    #[test]
    fn test_get_amount_delta_ok() {
        assert_eq!(get_delta_a(4 << 64, 2 << 64, 4, true).unwrap(), 1);
        assert_eq!(get_delta_a(4 << 64, 2 << 64, 4, false).unwrap(), 1);
        assert_eq!(get_delta_b(4 << 64, 2 << 64, 4, true).unwrap(), 8);
        assert_eq!(get_delta_b(4 << 64, 2 << 64, 4, false).unwrap(), 8);
    }

    /// Equal prices span no tokens.
    #[test]
    fn test_get_amount_delta_price_diff_zero_ok() {
        assert_eq!(get_delta_a(4 << 64, 4 << 64, 4, true).unwrap(), 0);
        assert_eq!(get_delta_a(4 << 64, 4 << 64, 4, false).unwrap(), 0);
        assert_eq!(get_delta_b(4 << 64, 4 << 64, 4, true).unwrap(), 0);
        assert_eq!(get_delta_b(4 << 64, 4 << 64, 4, false).unwrap(), 0);
    }

    /// For prices 1 and 2 (X64), delta_a = liquidity / 2, so the u64 boundary sits at
    /// liquidity = 2 * u64::MAX.
    #[test]
    fn test_get_amount_delta_a_overflow() {
        assert!(get_delta_a(1 << 64, 2 << 64, u128::MAX, true).is_err());
        // `<<` binds looser than `+` in Rust, so the parentheses are required for `2 * MAX + 1`.
        // floor((2 * MAX + 1) / 2) = MAX, and rounding up the odd remainder overflows u64.
        assert!(get_delta_a(1 << 64, 2 << 64, ((u64::MAX as u128) << 1) + 1, true).is_err());
        assert!(get_delta_a(1 << 64, 2 << 64, (u64::MAX as u128) << 1, true).is_ok());
        assert!(get_delta_a(1 << 64, 2 << 64, u64::MAX as u128, true).is_ok());
    }
}
#[cfg(test)]
mod test_next_sqrt_price {
    use super::get_next_sqrt_price_a_up;

    /// Fixed-point regression for adding vs. removing 10_000_000 of token A at
    /// sqrt price 10 (X64) and liquidity 200 (X64).
    #[test]
    fn test_get_next_sqrt_price_from_a_round_up() {
        let sqrt_price = 10u128 << 64;
        let liquidity = 200u128 << 64;
        let amount = 10000000u64;

        let price_after_input =
            get_next_sqrt_price_a_up(sqrt_price, liquidity, amount, true).unwrap();
        let price_after_output =
            get_next_sqrt_price_a_up(sqrt_price, liquidity, amount, false).unwrap();

        assert_eq!(price_after_input, 184467440737090516161u128);
        assert_eq!(price_after_output, 184467440737100516161u128);
        println!("{}", price_after_input);
        println!("{}", price_after_output);
    }
}
#[cfg(test)]
mod test_compute_swap_step {
    use super::compute_swap_step;

    /// Exact-output B→A step (current < target) well inside the range: the full requested
    /// output is delivered and the price moves toward, but not onto, the target.
    #[test]
    fn test_compute_swap_step() {
        let (current_sqrt_price, target_sqrt_price, liquidity, amount, fee_rate, is_input) = (
            1u128 << 64,
            2u128 << 64,
            1000u128 << 32,
            20000,
            1000u16,
            false,
        );
        let result = compute_swap_step(
            current_sqrt_price,
            target_sqrt_price,
            liquidity,
            amount,
            fee_rate,
            is_input,
        )
        .unwrap();
        println!(
            "{} {} {} {}",
            result.amount_in, result.amount_out, result.next_sqrt_price, result.fee_amount
        );
        // The max output toward the target is liquidity / 2, far above 20000, so the full
        // amount is paid out.
        assert_eq!(result.amount_out, amount);
        assert!(result.next_sqrt_price > current_sqrt_price);
        assert!(result.next_sqrt_price < target_sqrt_price);
        // The price is about 1 and moves against the taker, so the input is at least the output.
        assert!(result.amount_in >= amount);
        // A non-zero fee rate on a non-zero input always yields a (rounded-up) fee.
        assert!(result.fee_amount > 0);
    }
}