use crate::{
CoreError, TransferFee, AMOUNT_EXCEEDS_MAX_U64, ARITHMETIC_OVERFLOW, BPS_DENOMINATOR, FEE_RATE_MUL_VALUE, INVALID_SLIPPAGE_TOLERANCE,
INVALID_TRANSFER_FEE, MAX_SQRT_PRICE, MIN_SQRT_PRICE, SQRT_PRICE_OUT_OF_BOUNDS,
};
use crate::math::liquidity::{get_amount_a_from_liquidity, get_amount_b_from_liquidity};
use crate::math::u256_math::{mul_u256, U256Muldiv};
use ethnum::U256;
#[cfg(feature = "wasm")]
use fusionamm_macros::wasm_expose;
/// Amount of token A covered by `liquidity` between two sqrt prices.
///
/// The two prices may be given in either order: they are sorted into `(lower, upper)` first and
/// then handed to `get_amount_a_from_liquidity` together with the unchanged `round_up` flag.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_get_amount_delta_a(sqrt_price_1: u128, sqrt_price_2: u128, liquidity: u128, round_up: bool) -> Result<u64, CoreError> {
    let (lower, upper) = order_prices(sqrt_price_1, sqrt_price_2);
    get_amount_a_from_liquidity(liquidity, lower, upper, round_up)
}
/// Amount of token B covered by `liquidity` between two sqrt prices.
///
/// The two prices may be given in either order: they are sorted into `(lower, upper)` first and
/// then handed to `get_amount_b_from_liquidity` together with the unchanged `round_up` flag.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_get_amount_delta_b(sqrt_price_1: u128, sqrt_price_2: u128, liquidity: u128, round_up: bool) -> Result<u64, CoreError> {
    let (lower, upper) = order_prices(sqrt_price_1, sqrt_price_2);
    get_amount_b_from_liquidity(liquidity, lower, upper, round_up)
}
/// Computes the Q64.64 sqrt price reached after `amount` of token A is added to
/// (`specified_input == true`) or removed from (`specified_input == false`) the pool.
///
/// With liquidity `L` and current sqrt price `sqrt(P)`, the result is
/// `L * sqrt(P) / (L + amount * sqrt(P))` for an input and
/// `L * sqrt(P) / (L - amount * sqrt(P))` for an output, always rounded up.
/// A zero `amount` returns `current_sqrt_price` unchanged.
///
/// # Errors
/// - `ARITHMETIC_OVERFLOW` if `L * sqrt(P) * 2^64` does not fit in 256 bits.
/// - `SQRT_PRICE_OUT_OF_BOUNDS` if an output of `amount` cannot be supported by the liquidity
///   (the denominator would be zero or negative), or if the resulting sqrt price lies outside
///   `[MIN_SQRT_PRICE, MAX_SQRT_PRICE]`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_get_next_sqrt_price_from_a(
    current_sqrt_price: u128,
    current_liquidity: u128,
    amount: u64,
    specified_input: bool,
) -> Result<u128, CoreError> {
    if amount == 0 {
        return Ok(current_sqrt_price);
    }
    let p = <U256>::from(current_sqrt_price).checked_mul(amount.into()).ok_or(ARITHMETIC_OVERFLOW)?;
    // `checked_shl` only rejects shift amounts >= 256 and silently discards bits shifted out of
    // the top, so the Q64 scaling of the numerator uses a checked multiplication instead.
    let numerator = <U256>::from(current_liquidity)
        .checked_mul(current_sqrt_price.into())
        .ok_or(ARITHMETIC_OVERFLOW)?
        .checked_mul(<U256>::from(1u128 << 64))
        .ok_or(ARITHMETIC_OVERFLOW)?;
    // A u128 shifted left by 64 bits always fits in 256 bits.
    let current_liquidity_shifted = <U256>::from(current_liquidity).checked_shl(64).ok_or(ARITHMETIC_OVERFLOW)?;
    let denominator = if specified_input {
        current_liquidity_shifted.checked_add(p).ok_or(ARITHMETIC_OVERFLOW)?
    } else {
        // Removing at least `L / sqrt(P)` of token A would push the price to infinity.
        current_liquidity_shifted.checked_sub(p).ok_or(SQRT_PRICE_OUT_OF_BOUNDS)?
    };
    // Zero for an output that exactly exhausts the liquidity, or for zero liquidity with a
    // zero sqrt price; neither has a finite, in-range result.
    if denominator == 0 {
        return Err(SQRT_PRICE_OUT_OF_BOUNDS);
    }
    let quotient: U256 = numerator / denominator;
    let remainder: U256 = numerator % denominator;
    let result = if remainder != 0 { quotient + 1 } else { quotient };
    if !(MIN_SQRT_PRICE..=MAX_SQRT_PRICE).contains(&result) {
        return Err(SQRT_PRICE_OUT_OF_BOUNDS);
    }
    Ok(result.as_u128())
}
/// Computes the Q64.64 sqrt price reached after `amount` of token B is added to
/// (`specified_input == true`) or removed from (`specified_input == false`) the pool.
///
/// The sqrt price moves by `amount / L`, scaled to Q64.64. The delta is rounded down for an
/// input and up for an output, so the resulting sqrt price is always rounded down.
/// A zero `amount` returns `current_sqrt_price` unchanged.
///
/// # Errors
/// - `SQRT_PRICE_OUT_OF_BOUNDS` in any of these cases:
///   - `current_liquidity` is zero, so any non-zero amount would move the price without bound.
///   - An output of `amount` would drive the sqrt price below zero.
///   - The resulting sqrt price lies outside `[MIN_SQRT_PRICE, MAX_SQRT_PRICE]`.
/// - `ARITHMETIC_OVERFLOW` from the 256-bit intermediate arithmetic.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_get_next_sqrt_price_from_b(
    current_sqrt_price: u128,
    current_liquidity: u128,
    amount: u64,
    specified_input: bool,
) -> Result<u128, CoreError> {
    if amount == 0 {
        return Ok(current_sqrt_price);
    }
    // Guard the division below; without liquidity the price change is unbounded.
    if current_liquidity == 0 {
        return Err(SQRT_PRICE_OUT_OF_BOUNDS);
    }
    let current_sqrt_price = <U256>::from(current_sqrt_price);
    let current_liquidity = <U256>::from(current_liquidity);
    let amount_shifted = <U256>::from(amount).checked_shl(64).ok_or(ARITHMETIC_OVERFLOW)?;
    let quotient: U256 = amount_shifted / current_liquidity;
    let remainder: U256 = amount_shifted % current_liquidity;
    // Round the delta up only for outputs so the resulting price is always rounded down.
    let delta = if !specified_input && remainder != 0 { quotient + 1 } else { quotient };
    let result = if specified_input {
        current_sqrt_price.checked_add(delta).ok_or(ARITHMETIC_OVERFLOW)?
    } else {
        // Removing more token B than the price can absorb would make the sqrt price negative.
        current_sqrt_price.checked_sub(delta).ok_or(SQRT_PRICE_OUT_OF_BOUNDS)?
    };
    if !(MIN_SQRT_PRICE..=MAX_SQRT_PRICE).contains(&result) {
        return Err(SQRT_PRICE_OUT_OF_BOUNDS);
    }
    Ok(result.as_u128())
}
/// Deducts a transfer fee from `amount` and returns what remains after the fee.
///
/// The fee is `ceil(amount * fee_bps / BPS_DENOMINATOR)`, capped at `transfer_fee.max_fee`.
/// A zero amount or a zero fee rate returns `amount` unchanged.
///
/// # Errors
/// - `INVALID_TRANSFER_FEE` if `fee_bps` exceeds `BPS_DENOMINATOR`.
/// - `ARITHMETIC_OVERFLOW` / `AMOUNT_EXCEEDS_MAX_U64` from the intermediate arithmetic.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_apply_transfer_fee(amount: u64, transfer_fee: TransferFee) -> Result<u64, CoreError> {
    let fee_bps = transfer_fee.fee_bps;
    if fee_bps > BPS_DENOMINATOR {
        return Err(INVALID_TRANSFER_FEE);
    }
    if amount == 0 || fee_bps == 0 {
        return Ok(amount);
    }
    let scaled_fee = u128::from(amount).checked_mul(u128::from(fee_bps)).ok_or(ARITHMETIC_OVERFLOW)?;
    let uncapped_fee = u64::try_from(scaled_fee.div_ceil(u128::from(BPS_DENOMINATOR))).map_err(|_| AMOUNT_EXCEEDS_MAX_U64)?;
    // fee_bps <= BPS_DENOMINATOR, so the fee never exceeds `amount` and the subtraction is safe.
    let charged_fee = uncapped_fee.min(transfer_fee.max_fee);
    Ok(amount - charged_fee)
}
/// Returns the pre-fee amount that must be sent so that `amount` remains after the transfer fee.
///
/// This inverts the fee `ceil(x * fee_bps / BPS_DENOMINATOR)` capped at `max_fee`. If the
/// uncapped gross-up would charge at least `max_fee`, the result is simply `amount + max_fee`.
///
/// # Errors
/// - `INVALID_TRANSFER_FEE` if `fee_bps` exceeds `BPS_DENOMINATOR`.
/// - `AMOUNT_EXCEEDS_MAX_U64` if the pre-fee amount does not fit in a `u64`.
/// - `ARITHMETIC_OVERFLOW` from the intermediate arithmetic.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_reverse_apply_transfer_fee(amount: u64, transfer_fee: TransferFee) -> Result<u64, CoreError> {
    let fee_bps = transfer_fee.fee_bps;
    let max_fee = transfer_fee.max_fee;
    if fee_bps > BPS_DENOMINATOR {
        return Err(INVALID_TRANSFER_FEE);
    }
    if fee_bps == 0 {
        return Ok(amount);
    }
    if amount == 0 {
        return Ok(0);
    }
    // Gross-up used whenever the fee is known to be at its cap.
    let with_max_fee = || amount.checked_add(max_fee).ok_or(AMOUNT_EXCEEDS_MAX_U64);
    if fee_bps == BPS_DENOMINATOR {
        // A 100% fee would swallow everything, so a non-zero remainder implies the capped fee.
        return with_max_fee();
    }
    let bps_total = u128::from(BPS_DENOMINATOR);
    let scaled_amount = u128::from(amount).checked_mul(bps_total).ok_or(ARITHMETIC_OVERFLOW)?;
    let pre_fee_amount = scaled_amount.div_ceil(bps_total - u128::from(fee_bps));
    let implied_fee = pre_fee_amount.checked_sub(u128::from(amount)).ok_or(AMOUNT_EXCEEDS_MAX_U64)?;
    if implied_fee < u128::from(max_fee) {
        u64::try_from(pre_fee_amount).map_err(|_| AMOUNT_EXCEEDS_MAX_U64)
    } else {
        with_max_fee()
    }
}
/// Upper bound for `amount` after allowing `slippage_tolerance_bps` of slippage.
///
/// Computes `ceil(amount * (BPS_DENOMINATOR + slippage_tolerance_bps) / BPS_DENOMINATOR)`.
///
/// # Errors
/// - `INVALID_SLIPPAGE_TOLERANCE` if the tolerance exceeds `BPS_DENOMINATOR`.
/// - Errors from `try_mul_div`, e.g. `AMOUNT_EXCEEDS_MAX_U64` when the bound exceeds `u64::MAX`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_get_max_amount_with_slippage_tolerance(amount: u64, slippage_tolerance_bps: u16) -> Result<u64, CoreError> {
    if slippage_tolerance_bps > BPS_DENOMINATOR {
        return Err(INVALID_SLIPPAGE_TOLERANCE);
    }
    let bps_total = <u128>::from(BPS_DENOMINATOR);
    let widened_bps = bps_total + <u128>::from(slippage_tolerance_bps);
    try_mul_div(amount, widened_bps, bps_total, true)
}
/// Lower bound for `amount` after allowing `slippage_tolerance_bps` of slippage.
///
/// Computes `floor(amount * (BPS_DENOMINATOR - slippage_tolerance_bps) / BPS_DENOMINATOR)`.
///
/// # Errors
/// - `INVALID_SLIPPAGE_TOLERANCE` if the tolerance exceeds `BPS_DENOMINATOR`.
/// - Errors propagated from `try_mul_div`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_get_min_amount_with_slippage_tolerance(amount: u64, slippage_tolerance_bps: u16) -> Result<u64, CoreError> {
    if slippage_tolerance_bps > BPS_DENOMINATOR {
        return Err(INVALID_SLIPPAGE_TOLERANCE);
    }
    let bps_total = <u128>::from(BPS_DENOMINATOR);
    let retained_bps = bps_total - <u128>::from(slippage_tolerance_bps);
    try_mul_div(amount, retained_bps, bps_total, false)
}
/// Returns what remains of `amount` after deducting a swap fee of `fee_rate / FEE_RATE_MUL_VALUE`.
///
/// Computes `floor(amount * (FEE_RATE_MUL_VALUE - fee_rate) / FEE_RATE_MUL_VALUE)`. The module's
/// tests pass `u16::MAX` here, which implies `FEE_RATE_MUL_VALUE` exceeds any `u16` fee rate.
///
/// # Errors
/// Errors propagated from `try_mul_div`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_apply_swap_fee(amount: u64, fee_rate: u16) -> Result<u64, CoreError> {
    let scale = <u128>::from(FEE_RATE_MUL_VALUE);
    let retained = scale - <u128>::from(fee_rate);
    try_mul_div(amount, retained, scale, false)
}
/// Returns the pre-fee amount whose post-fee value is `amount`, for a swap fee of
/// `fee_rate / FEE_RATE_MUL_VALUE`.
///
/// Computes `ceil(amount * FEE_RATE_MUL_VALUE / (FEE_RATE_MUL_VALUE - fee_rate))`.
///
/// # Errors
/// Errors propagated from `try_mul_div`, e.g. `AMOUNT_EXCEEDS_MAX_U64` when the result exceeds
/// `u64::MAX`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub fn try_reverse_apply_swap_fee(amount: u64, fee_rate: u16) -> Result<u64, CoreError> {
    let scale = <u128>::from(FEE_RATE_MUL_VALUE);
    let retained = scale - <u128>::from(fee_rate);
    try_mul_div(amount, scale, retained, true)
}
/// Computes `amount * product / denominator` in 128-bit arithmetic.
///
/// The result is rounded up when `round_up` is set and rounded down otherwise. Returns `Ok(0)`
/// immediately when `amount` or `product` is zero.
///
/// # Errors
/// - `ARITHMETIC_OVERFLOW` if `amount * product` overflows `u128`, or if `denominator` is zero.
/// - `AMOUNT_EXCEEDS_MAX_U64` if the result does not fit in a `u64`.
pub fn try_mul_div(amount: u64, product: u128, denominator: u128, round_up: bool) -> Result<u64, CoreError> {
    if amount == 0 || product == 0 {
        return Ok(0);
    }
    // This is a fallible public helper, so report a zero divisor instead of panicking.
    if denominator == 0 {
        return Err(ARITHMETIC_OVERFLOW);
    }
    let numerator = <u128>::from(amount).checked_mul(product).ok_or(ARITHMETIC_OVERFLOW)?;
    let result = if round_up { numerator.div_ceil(denominator) } else { numerator / denominator };
    result.try_into().map_err(|_| AMOUNT_EXCEEDS_MAX_U64)
}
/// Multiplies `amount` by the price `sqrt_price^2`, where `sqrt_price` is a Q64.64 fixed-point value.
///
/// `sqrt_price` is squared at 256-bit width and shifted right by one word to form the price.
/// That price is multiplied by `amount`, and the product is shifted right by one more word.
/// Assuming `U256Muldiv` words are 64 bits, the result is `amount * (sqrt_price / 2^64)^2`.
/// When `round_up` is set and the word dropped by the final shift is non-zero, the result is
/// incremented by one.
///
/// NOTE(review): the intermediate price is truncated before the multiplication, so bits lost
/// there never trigger the round-up. This is presumably intentional; confirm before changing.
///
/// # Errors
/// - `ARITHMETIC_OVERFLOW` if the result, including the round-up, does not fit in a `u64`.
/// - Any error returned by `try_into_u128`.
pub fn mul_by_sqrt_price_squared(amount: u64, sqrt_price: u128, round_up: bool) -> Result<u64, CoreError> {
    let price = mul_u256(sqrt_price, sqrt_price).shift_word_right();
    let value_u256 = U256Muldiv::new(0, amount as u128).mul(price);
    // The lowest word is the fractional part that the final shift discards.
    let needs_round_up = round_up && value_u256.get_word(0) > 0;
    let floor_u128 = value_u256.shift_word_right().try_into_u128()?;
    let floor = u64::try_from(floor_u128).or(Err(ARITHMETIC_OVERFLOW))?;
    // Narrow before rounding so the increment is checked and can never overflow a wider type.
    if needs_round_up {
        floor.checked_add(1).ok_or(ARITHMETIC_OVERFLOW)
    } else {
        Ok(floor)
    }
}
/// Divides `amount` by the price `sqrt_price^2`, where `sqrt_price` is a Q64.64 fixed-point value.
///
/// The dividend is `U256Muldiv::new(amount, 0)`, presumably `amount` placed in the high 128-bit
/// half, i.e. `amount << 128`. It is divided by the full 256-bit `sqrt_price^2`. The result is
/// rounded up when `round_up` is set and the division leaves a non-zero remainder.
///
/// # Errors
/// - `ARITHMETIC_OVERFLOW` if `sqrt_price` is zero, or if the result (including the round-up)
///   does not fit in a `u64`.
/// - Any error returned by `try_into_u128`.
pub fn div_by_sqrt_price_squared(amount: u64, sqrt_price: u128, round_up: bool) -> Result<u64, CoreError> {
    // A zero sqrt price would make the divisor below zero.
    if sqrt_price == 0 {
        return Err(ARITHMETIC_OVERFLOW);
    }
    let price = mul_u256(sqrt_price, sqrt_price);
    let (value_u256, reminder_u256) = U256Muldiv::new(amount as u128, 0).div(price, true);
    let needs_round_up = round_up && reminder_u256.gt(U256Muldiv::new(0, 0));
    let floor_u128 = value_u256.try_into_u128()?;
    let floor = u64::try_from(floor_u128).or(Err(ARITHMETIC_OVERFLOW))?;
    // Narrow before rounding so the increment is checked and can never overflow a wider type.
    if needs_round_up {
        floor.checked_add(1).ok_or(ARITHMETIC_OVERFLOW)
    } else {
        Ok(floor)
    }
}
/// Returns the two sqrt prices as `(lower, upper)`, regardless of the order they were passed in.
fn order_prices(a: u128, b: u128) -> (u128, u128) {
    (a.min(b), a.max(b))
}
// Unit tests for the token math helpers in this module. They are compiled only for test builds
// with the `wasm` feature disabled (see the `cfg` attribute).
#[cfg(all(test, not(feature = "wasm")))]
mod tests {
    use super::*;

    // Sqrt prices here are Q64.64 values: `n << 64` stands for a sqrt price of `n`.
    // The expected values match dA = L * (1/sqrt_lower - 1/sqrt_upper) = 4 * (1/2 - 1/4) = 1,
    // and equal prices must yield zero.
    #[test]
    fn test_get_amount_delta_a() {
        assert_eq!(try_get_amount_delta_a(4 << 64, 2 << 64, 4, true), Ok(1));
        assert_eq!(try_get_amount_delta_a(4 << 64, 2 << 64, 4, false), Ok(1));
        assert_eq!(try_get_amount_delta_a(4 << 64, 4 << 64, 4, true), Ok(0));
        assert_eq!(try_get_amount_delta_a(4 << 64, 4 << 64, 4, false), Ok(0));
    }

    // The expected values match dB = L * (sqrt_upper - sqrt_lower) = 4 * (4 - 2) = 8.
    #[test]
    fn test_get_amount_delta_b() {
        assert_eq!(try_get_amount_delta_b(4 << 64, 2 << 64, 4, true), Ok(8));
        assert_eq!(try_get_amount_delta_b(4 << 64, 2 << 64, 4, false), Ok(8));
        assert_eq!(try_get_amount_delta_b(4 << 64, 4 << 64, 4, true), Ok(0));
        assert_eq!(try_get_amount_delta_b(4 << 64, 4 << 64, 4, false), Ok(0));
    }

    // Input:  4 * 4 / (4 + 1 * 4) = 2.  Output: 4 * 2 / (4 - 1 * 2) = 4.
    // A zero amount leaves the price unchanged.
    #[test]
    fn test_get_next_sqrt_price_from_a() {
        assert_eq!(try_get_next_sqrt_price_from_a(4 << 64, 4, 1, true), Ok(2 << 64));
        assert_eq!(try_get_next_sqrt_price_from_a(2 << 64, 4, 1, false), Ok(4 << 64));
        assert_eq!(try_get_next_sqrt_price_from_a(4 << 64, 4, 0, true), Ok(4 << 64));
        assert_eq!(try_get_next_sqrt_price_from_a(4 << 64, 4, 0, false), Ok(4 << 64));
    }

    // Input:  2 + 8 / 4 = 4.  Output: 4 - 8 / 4 = 2.  A zero amount leaves the price unchanged.
    #[test]
    fn test_get_next_sqrt_price_from_b() {
        assert_eq!(try_get_next_sqrt_price_from_b(2 << 64, 4, 8, true), Ok(4 << 64));
        assert_eq!(try_get_next_sqrt_price_from_b(4 << 64, 4, 8, false), Ok(2 << 64));
        assert_eq!(try_get_next_sqrt_price_from_b(4 << 64, 4, 0, true), Ok(4 << 64));
        assert_eq!(try_get_next_sqrt_price_from_b(4 << 64, 4, 0, false), Ok(4 << 64));
    }

    // NOTE(review): `TransferFee::new(bps)` presumably has an uncapped max fee (u64::MAX); the
    // 100%-fee cases returning 0 rely on that. Confirm against `TransferFee::new`.
    #[test]
    fn test_apply_transfer_fee() {
        assert_eq!(try_apply_transfer_fee(0, TransferFee::new(100)), Ok(0));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new(0)), Ok(10000));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new(100)), Ok(9900));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new(1000)), Ok(9000));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new(10000)), Ok(0));
        // ceil(u64::MAX * 10%) is deducted, leaving 16602069666338596453.
        assert_eq!(try_apply_transfer_fee(u64::MAX, TransferFee::new(1000)), Ok(16602069666338596453));
        assert_eq!(try_apply_transfer_fee(u64::MAX, TransferFee::new(10000)), Ok(0));
        // Fee rates above 100% (10000 bps) are rejected.
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new(10001)), Err(INVALID_TRANSFER_FEE));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new(u16::MAX)), Err(INVALID_TRANSFER_FEE));
    }

    // With a cap of 500, any fee that would exceed 500 is clamped to 500.
    #[test]
    fn test_apply_transfer_fee_with_max() {
        assert_eq!(try_apply_transfer_fee(0, TransferFee::new_with_max(100, 500)), Ok(0));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new_with_max(0, 500)), Ok(10000));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new_with_max(100, 500)), Ok(9900));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new_with_max(1000, 500)), Ok(9500));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new_with_max(10000, 500)), Ok(9500));
        assert_eq!(try_apply_transfer_fee(u64::MAX, TransferFee::new_with_max(1000, 500)), Ok(18446744073709551115));
        assert_eq!(try_apply_transfer_fee(u64::MAX, TransferFee::new_with_max(10000, 500)), Ok(18446744073709551115));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new_with_max(10001, 500)), Err(INVALID_TRANSFER_FEE));
        assert_eq!(try_apply_transfer_fee(10000, TransferFee::new_with_max(u16::MAX, 500)), Err(INVALID_TRANSFER_FEE));
    }

    // Inverse of `try_apply_transfer_fee`. With a 100% fee the gross-up is `amount + max_fee`;
    // presumably `TransferFee::new` sets max_fee to u64::MAX, which is why those cases overflow.
    #[test]
    fn test_reverse_apply_transfer_fee() {
        assert_eq!(try_reverse_apply_transfer_fee(0, TransferFee::new(100)), Ok(0));
        assert_eq!(try_reverse_apply_transfer_fee(10000, TransferFee::new(0)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(9900, TransferFee::new(100)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(9000, TransferFee::new(1000)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(5000, TransferFee::new(10000)), Err(AMOUNT_EXCEEDS_MAX_U64));
        assert_eq!(try_reverse_apply_transfer_fee(0, TransferFee::new(10000)), Ok(0));
        assert_eq!(try_reverse_apply_transfer_fee(u64::MAX, TransferFee::new(10000)), Err(AMOUNT_EXCEEDS_MAX_U64));
        assert_eq!(try_reverse_apply_transfer_fee(10000, TransferFee::new(10001)), Err(INVALID_TRANSFER_FEE));
        assert_eq!(try_reverse_apply_transfer_fee(10000, TransferFee::new(u16::MAX)), Err(INVALID_TRANSFER_FEE));
    }

    // With a cap of 500, a capped gross-up is `amount + 500`, which must still fit in a u64.
    #[test]
    fn test_reverse_apply_transfer_fee_with_max() {
        assert_eq!(try_reverse_apply_transfer_fee(0, TransferFee::new_with_max(100, 500)), Ok(0));
        assert_eq!(try_reverse_apply_transfer_fee(10000, TransferFee::new_with_max(0, 500)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(9900, TransferFee::new_with_max(100, 500)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(9500, TransferFee::new_with_max(1000, 500)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(9500, TransferFee::new_with_max(10000, 500)), Ok(10000));
        assert_eq!(try_reverse_apply_transfer_fee(0, TransferFee::new_with_max(10000, 500)), Ok(0));
        assert_eq!(try_reverse_apply_transfer_fee(u64::MAX - 500, TransferFee::new_with_max(10000, 500)), Ok(u64::MAX));
        assert_eq!(try_reverse_apply_transfer_fee(u64::MAX, TransferFee::new_with_max(10000, 500)), Err(AMOUNT_EXCEEDS_MAX_U64));
        assert_eq!(try_reverse_apply_transfer_fee(10000, TransferFee::new_with_max(10001, 500)), Err(INVALID_TRANSFER_FEE));
        assert_eq!(try_reverse_apply_transfer_fee(10000, TransferFee::new_with_max(u16::MAX, 500)), Err(INVALID_TRANSFER_FEE));
    }

    // The maximum bound scales by (10000 + bps) / 10000 and rounds up; tolerances above 100% are rejected.
    #[test]
    fn test_get_max_amount_with_slippage_tolerance() {
        assert_eq!(try_get_max_amount_with_slippage_tolerance(0, 100), Ok(0));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(10000, 0), Ok(10000));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(10000, 100), Ok(10100));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(10000, 1000), Ok(11000));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(10000, 10000), Ok(20000));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(u64::MAX, 10000), Err(AMOUNT_EXCEEDS_MAX_U64));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(10000, 10001), Err(INVALID_SLIPPAGE_TOLERANCE));
        assert_eq!(try_get_max_amount_with_slippage_tolerance(10000, u16::MAX), Err(INVALID_SLIPPAGE_TOLERANCE));
    }

    // The minimum bound scales by (10000 - bps) / 10000 and rounds down; tolerances above 100% are rejected.
    #[test]
    fn test_get_min_amount_with_slippage_tolerance() {
        assert_eq!(try_get_min_amount_with_slippage_tolerance(0, 100), Ok(0));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(10000, 0), Ok(10000));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(10000, 100), Ok(9900));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(10000, 1000), Ok(9000));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(10000, 10000), Ok(0));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(u64::MAX, 10000), Ok(0));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(u64::MAX, 1000), Ok(16602069666338596453));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(10000, 10001), Err(INVALID_SLIPPAGE_TOLERANCE));
        assert_eq!(try_get_min_amount_with_slippage_tolerance(10000, u16::MAX), Err(INVALID_SLIPPAGE_TOLERANCE));
    }

    // The expected values imply FEE_RATE_MUL_VALUE = 1_000_000 (fee_rate 1000 -> 0.1%).
    // Every u16 fee rate is accepted.
    #[test]
    fn test_apply_swap_fee() {
        assert_eq!(try_apply_swap_fee(0, 1000), Ok(0));
        assert_eq!(try_apply_swap_fee(10000, 0), Ok(10000));
        assert_eq!(try_apply_swap_fee(10000, 1000), Ok(9990));
        assert_eq!(try_apply_swap_fee(10000, 10000), Ok(9900));
        assert_eq!(try_apply_swap_fee(10000, u16::MAX), Ok(9344));
        assert_eq!(try_apply_swap_fee(u64::MAX, 1000), Ok(18428297329635842063));
        assert_eq!(try_apply_swap_fee(u64::MAX, 10000), Ok(18262276632972456098));
    }

    // Inverse of `try_apply_swap_fee` (rounded up); grossing up u64::MAX cannot fit in a u64.
    #[test]
    fn test_reverse_apply_swap_fee() {
        assert_eq!(try_reverse_apply_swap_fee(0, 1000), Ok(0));
        assert_eq!(try_reverse_apply_swap_fee(10000, 0), Ok(10000));
        assert_eq!(try_reverse_apply_swap_fee(9990, 1000), Ok(10000));
        assert_eq!(try_reverse_apply_swap_fee(9900, 10000), Ok(10000));
        assert_eq!(try_reverse_apply_swap_fee(9344, u16::MAX), Ok(10000));
        assert_eq!(try_reverse_apply_swap_fee(u64::MAX, 1000), Err(AMOUNT_EXCEEDS_MAX_U64));
        assert_eq!(try_reverse_apply_swap_fee(u64::MAX, 10000), Err(AMOUNT_EXCEEDS_MAX_U64));
    }
}