#[cfg(feature = "fuzz")]
use arbitrary::Arbitrary;
use {crate::error::SwapError, spl_math::precise_number::PreciseNumber, std::fmt::Debug};
/// Default pool-token supply for a newly created pool; this is the value the
/// default `CurveCalculator::new_pool_supply` implementation returns.
pub const INITIAL_SWAP_POOL_AMOUNT: u128 = 1_000_000_000_u128;

/// Number of tokens held by a pool (presumably token A and token B).
pub const TOKENS_IN_POOL: u128 = 2_u128;
/// Maps a zero amount to `None` and passes any non-zero amount through as `Some`.
pub fn map_zero_to_none(x: u128) -> Option<u128> {
    Some(x).filter(|&amount| amount != 0)
}
/// Direction of a trade through the pool.
#[cfg_attr(feature = "fuzz", derive(Arbitrary))]
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TradeDirection {
    /// Token A is the source (provided) and token B the destination.
    AtoB,
    /// Token B is the source (provided) and token A the destination.
    BtoA,
}

impl TradeDirection {
    /// Returns the reverse direction: `AtoB` becomes `BtoA` and vice versa.
    pub fn opposite(&self) -> TradeDirection {
        match *self {
            Self::BtoA => Self::AtoB,
            Self::AtoB => Self::BtoA,
        }
    }
}

/// Rounding mode requested from curve calculations that produce a
/// non-integral intermediate result.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RoundDirection {
    /// Round down.
    Floor,
    /// Round up.
    Ceiling,
}
/// Token amounts exchanged by a swap, computed without trading fees.
#[derive(PartialEq, Debug)]
pub struct SwapWithoutFeesResult {
    /// Amount of the source token added to the pool's source balance.
    pub source_amount_swapped: u128,
    /// Amount of the destination token removed from the pool's destination balance.
    pub destination_amount_swapped: u128,
}

/// A pair of amounts, one per pool token.
#[derive(PartialEq, Debug)]
pub struct TradingTokenResult {
    /// Amount of token A.
    pub token_a_amount: u128,
    /// Amount of token B.
    pub token_b_amount: u128,
}
/// Object-safe packing hook: serializes `self` into a caller-provided byte slice.
///
/// Being object-safe lets it be a supertrait of traits used as `dyn` objects.
pub trait DynPack {
    /// Writes the packed representation of `self` into `dst`.
    fn pack_into_slice(&self, dst: &mut [u8]);
}
/// Curve-specific math behind a two-token swap pool.
///
/// Used as `&dyn CurveCalculator`, so it must remain object-safe. `Debug` and
/// `DynPack` are supertraits so any curve can be printed and packed without
/// knowing its concrete type.
pub trait CurveCalculator: Debug + DynPack {
    /// Amounts exchanged when `source_amount` of the source token is swapped
    /// against pool balances `swap_source_amount` / `swap_destination_amount`,
    /// with no trading fees applied.
    ///
    /// `trade_direction` says which pool token is the source. `None` means the
    /// implementation could not produce a result.
    fn swap_without_fees(
        &self,
        source_amount: u128,
        swap_source_amount: u128,
        swap_destination_amount: u128,
        trade_direction: TradeDirection,
    ) -> Option<SwapWithoutFeesResult>;

    /// Pool-token supply for a freshly created pool.
    ///
    /// Default: `INITIAL_SWAP_POOL_AMOUNT`.
    fn new_pool_supply(&self) -> u128 {
        INITIAL_SWAP_POOL_AMOUNT
    }

    /// Token A / token B amounts that correspond to `pool_tokens` out of a total
    /// supply of `pool_token_supply`, given the pool's current balances.
    ///
    /// `round_direction` selects whether fractional results round down or up.
    fn pool_tokens_to_trading_tokens(
        &self,
        pool_tokens: u128,
        pool_token_supply: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        round_direction: RoundDirection,
    ) -> Option<TradingTokenResult>;

    /// Pool tokens obtained by depositing `source_amount` of a single token (the
    /// source side of `trade_direction`) into a pool with the given balances and
    /// `pool_supply`.
    fn deposit_single_token_type(
        &self,
        source_amount: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        pool_supply: u128,
        trade_direction: TradeDirection,
    ) -> Option<u128>;

    /// Pool tokens needed to withdraw exactly `source_amount` of a single token
    /// (selected by `trade_direction`) from a pool with the given balances and
    /// `pool_supply`; `round_direction` picks the rounding mode.
    fn withdraw_single_token_type_exact_out(
        &self,
        source_amount: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        pool_supply: u128,
        trade_direction: TradeDirection,
        round_direction: RoundDirection,
    ) -> Option<u128>;

    /// Validates the curve's own configuration (implementation-defined).
    fn validate(&self) -> Result<(), SwapError>;

    /// Rejects pool balances in which either token amount is zero.
    ///
    /// # Errors
    /// `SwapError::EmptySupply` if `token_a_amount` or `token_b_amount` is `0`.
    fn validate_supply(&self, token_a_amount: u64, token_b_amount: u64) -> Result<(), SwapError> {
        let either_side_empty = token_a_amount == 0 || token_b_amount == 0;
        if either_side_empty {
            Err(SwapError::EmptySupply)
        } else {
            Ok(())
        }
    }

    /// Whether this curve accepts deposits. Default: `true`.
    fn allows_deposits(&self) -> bool {
        true
    }

    /// Curve-specific measure of the pool's total value for the given balances,
    /// or `None` if it cannot be computed.
    fn normalized_value(
        &self,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
    ) -> Option<PreciseNumber>;
}
/// Shared invariant checks for `CurveCalculator` implementations.
///
/// Each `check_*` helper drives a curve through the trait object and panics
/// (via `unwrap` or `assert!`) when an invariant does not hold.
#[cfg(test)]
pub mod test {
    use {super::*, proptest::prelude::*, spl_math::uint::U256};

    /// Tolerance in basis points (1/10000), i.e. 50 bps = 0.5%; presumably
    /// passed as `epsilon_in_basis_points` to the conversion checks below.
    pub const CONVERSION_BASIS_POINTS_GUARANTEE: u128 = 50;

    /// Checks that depositing `source_token_amount` of one token in a single
    /// step yields, within `epsilon_in_basis_points` (never less than 1), the
    /// same number of pool tokens as the two-step route. That route swaps half
    /// of the amount for the other token, then deposits each token separately.
    ///
    /// Panics if a curve call returns `None`, if intermediate balance arithmetic
    /// overflows/underflows, or if the difference exceeds the tolerance.
    pub fn check_deposit_token_conversion(
        curve: &dyn CurveCalculator,
        source_token_amount: u128,
        swap_source_amount: u128,
        swap_destination_amount: u128,
        trade_direction: TradeDirection,
        pool_supply: u128,
        epsilon_in_basis_points: u128,
    ) {
        // Swap half of the deposit so the two-step route holds both tokens.
        let amount_to_swap = source_token_amount / 2;
        let results = curve
            .swap_without_fees(
                amount_to_swap,
                swap_source_amount,
                swap_destination_amount,
                trade_direction,
            )
            .unwrap();
        let opposite_direction = trade_direction.opposite();
        // Map source/destination balances back onto token A / token B.
        let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
            TradeDirection::AtoB => (swap_source_amount, swap_destination_amount),
            TradeDirection::BtoA => (swap_destination_amount, swap_source_amount),
        };

        // Baseline: deposit everything on one side against the original balances.
        let pool_tokens_from_one_side = curve
            .deposit_single_token_type(
                source_token_amount,
                swap_token_a_amount,
                swap_token_b_amount,
                pool_supply,
                trade_direction,
            )
            .unwrap();

        // Balances after the half-swap: the pool gained source tokens and paid
        // out destination tokens.
        let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
            TradeDirection::AtoB => (
                swap_source_amount + results.source_amount_swapped,
                swap_destination_amount - results.destination_amount_swapped,
            ),
            TradeDirection::BtoA => (
                swap_destination_amount - results.destination_amount_swapped,
                swap_source_amount + results.source_amount_swapped,
            ),
        };
        // Deposit the source tokens the swap did not consume...
        let pool_tokens_from_source = curve
            .deposit_single_token_type(
                source_token_amount - results.source_amount_swapped,
                swap_token_a_amount,
                swap_token_b_amount,
                pool_supply,
                trade_direction,
            )
            .unwrap();
        // ...then the swapped-out tokens on the other side, with the supply grown
        // by the pool tokens minted in the previous step.
        let pool_tokens_from_destination = curve
            .deposit_single_token_type(
                results.destination_amount_swapped,
                swap_token_a_amount,
                swap_token_b_amount,
                pool_supply + pool_tokens_from_source,
                opposite_direction,
            )
            .unwrap();

        let pool_tokens_total_separate = pool_tokens_from_source + pool_tokens_from_destination;

        // Tolerance: `epsilon_in_basis_points` of the two-step total, at least 1.
        let epsilon = std::cmp::max(
            1,
            pool_tokens_total_separate * epsilon_in_basis_points / 10000,
        );
        // Absolute difference, ordered to avoid unsigned underflow.
        let difference = if pool_tokens_from_one_side >= pool_tokens_total_separate {
            pool_tokens_from_one_side - pool_tokens_total_separate
        } else {
            pool_tokens_total_separate - pool_tokens_from_one_side
        };
        assert!(
            difference <= epsilon,
            "difference expected to be less than {}, actually {}",
            epsilon,
            difference
        );
    }

    /// Checks that burning `pool_token_amount` for a proportional withdrawal
    /// (rounded down) and swapping one withdrawn token into the other yields a
    /// single-token total. Withdrawing that total single-sided (exact out,
    /// rounded up) from the original pool must cost about `pool_token_amount`,
    /// within `epsilon_in_basis_points` (never less than 1).
    ///
    /// `AtoB` swaps the token A share into token B; `BtoA` swaps token B into A.
    pub fn check_withdraw_token_conversion(
        curve: &dyn CurveCalculator,
        pool_token_amount: u128,
        pool_token_supply: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        trade_direction: TradeDirection,
        epsilon_in_basis_points: u128,
    ) {
        // Proportional withdrawal of both tokens.
        let withdraw_result = curve
            .pool_tokens_to_trading_tokens(
                pool_token_amount,
                pool_token_supply,
                swap_token_a_amount,
                swap_token_b_amount,
                RoundDirection::Floor,
            )
            .unwrap();

        // Pool balances once the proportional withdrawal has left the pool.
        let new_swap_token_a_amount = swap_token_a_amount - withdraw_result.token_a_amount;
        let new_swap_token_b_amount = swap_token_b_amount - withdraw_result.token_b_amount;

        // Swap one withdrawn share into the other token against the reduced pool;
        // the result is the total held in the destination token.
        let source_token_amount = match trade_direction {
            TradeDirection::AtoB => {
                let results = curve
                    .swap_without_fees(
                        withdraw_result.token_a_amount,
                        new_swap_token_a_amount,
                        new_swap_token_b_amount,
                        trade_direction,
                    )
                    .unwrap();
                withdraw_result.token_b_amount + results.destination_amount_swapped
            }
            TradeDirection::BtoA => {
                let results = curve
                    .swap_without_fees(
                        withdraw_result.token_b_amount,
                        new_swap_token_b_amount,
                        new_swap_token_a_amount,
                        trade_direction,
                    )
                    .unwrap();
                withdraw_result.token_a_amount + results.destination_amount_swapped
            }
        };

        // Price a single-sided withdrawal of that total against the ORIGINAL pool;
        // the opposite direction's source side is the token the total is held in.
        let opposite_direction = trade_direction.opposite();
        let pool_token_amount_from_single_side_withdraw = curve
            .withdraw_single_token_type_exact_out(
                source_token_amount,
                swap_token_a_amount,
                swap_token_b_amount,
                pool_token_supply,
                opposite_direction,
                RoundDirection::Ceiling,
            )
            .unwrap();

        // Tolerance: `epsilon_in_basis_points` of `pool_token_amount`, at least 1.
        let epsilon = std::cmp::max(1, pool_token_amount * epsilon_in_basis_points / 10000);
        // Absolute difference, ordered to avoid unsigned underflow.
        let difference = if pool_token_amount >= pool_token_amount_from_single_side_withdraw {
            pool_token_amount - pool_token_amount_from_single_side_withdraw
        } else {
            pool_token_amount_from_single_side_withdraw - pool_token_amount
        };
        assert!(
            difference <= epsilon,
            "difference expected to be less than {}, actually {}",
            epsilon,
            difference
        );
    }

    /// Checks that a fee-less swap of `source_token_amount` never decreases the
    /// curve's normalized value, and raises it by at most 1 (after
    /// `to_imprecise`).
    pub fn check_curve_value_from_swap(
        curve: &dyn CurveCalculator,
        source_token_amount: u128,
        swap_source_amount: u128,
        swap_destination_amount: u128,
        trade_direction: TradeDirection,
    ) {
        let results = curve
            .swap_without_fees(
                source_token_amount,
                swap_source_amount,
                swap_destination_amount,
                trade_direction,
            )
            .unwrap();

        // Value before the swap, with balances mapped onto token A / token B.
        let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
            TradeDirection::AtoB => (swap_source_amount, swap_destination_amount),
            TradeDirection::BtoA => (swap_destination_amount, swap_source_amount),
        };
        let previous_value = curve
            .normalized_value(swap_token_a_amount, swap_token_b_amount)
            .unwrap();

        // Balances after the swap (checked: overflow/underflow fails the check).
        let new_swap_source_amount = swap_source_amount
            .checked_add(results.source_amount_swapped)
            .unwrap();
        let new_swap_destination_amount = swap_destination_amount
            .checked_sub(results.destination_amount_swapped)
            .unwrap();
        let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
            TradeDirection::AtoB => (new_swap_source_amount, new_swap_destination_amount),
            TradeDirection::BtoA => (new_swap_destination_amount, new_swap_source_amount),
        };

        let new_value = curve
            .normalized_value(swap_token_a_amount, swap_token_b_amount)
            .unwrap();
        assert!(new_value.greater_than_or_equal(&previous_value));

        // Largest allowed increase of the normalized value.
        let epsilon = 1;
        let difference = new_value
            .checked_sub(&previous_value)
            .unwrap()
            .to_imprecise()
            .unwrap();
        assert!(difference <= epsilon);
    }

    /// Checks that a proportional deposit (rounded up) for `pool_token_amount`
    /// new pool tokens does not lower either token's balance per pool token.
    pub fn check_pool_value_from_deposit(
        curve: &dyn CurveCalculator,
        pool_token_amount: u128,
        pool_token_supply: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
    ) {
        let deposit_result = curve
            .pool_tokens_to_trading_tokens(
                pool_token_amount,
                pool_token_supply,
                swap_token_a_amount,
                swap_token_b_amount,
                RoundDirection::Ceiling,
            )
            .unwrap();
        let new_swap_token_a_amount = swap_token_a_amount + deposit_result.token_a_amount;
        let new_swap_token_b_amount = swap_token_b_amount + deposit_result.token_b_amount;
        let new_pool_token_supply = pool_token_supply + pool_token_amount;

        // Required: new_x / new_supply >= x / supply, checked by cross-multiplying
        // as new_x * supply >= x * new_supply. Widening to U256 means the product
        // of two u128 values cannot overflow.
        let pool_token_supply = U256::from(pool_token_supply);
        let new_pool_token_supply = U256::from(new_pool_token_supply);
        let swap_token_a_amount = U256::from(swap_token_a_amount);
        let new_swap_token_a_amount = U256::from(new_swap_token_a_amount);
        let swap_token_b_amount = U256::from(swap_token_b_amount);
        let new_swap_token_b_amount = U256::from(new_swap_token_b_amount);

        assert!(
            new_swap_token_a_amount * pool_token_supply
                >= swap_token_a_amount * new_pool_token_supply
        );
        assert!(
            new_swap_token_b_amount * pool_token_supply
                >= swap_token_b_amount * new_pool_token_supply
        );
    }

    /// Checks that a proportional withdrawal (rounded down) of
    /// `pool_token_amount` does not lower the normalized pool value per pool
    /// token left outstanding.
    pub fn check_pool_value_from_withdraw(
        curve: &dyn CurveCalculator,
        pool_token_amount: u128,
        pool_token_supply: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
    ) {
        let withdraw_result = curve
            .pool_tokens_to_trading_tokens(
                pool_token_amount,
                pool_token_supply,
                swap_token_a_amount,
                swap_token_b_amount,
                RoundDirection::Floor,
            )
            .unwrap();
        let new_swap_token_a_amount = swap_token_a_amount - withdraw_result.token_a_amount;
        let new_swap_token_b_amount = swap_token_b_amount - withdraw_result.token_b_amount;
        let new_pool_token_supply = pool_token_supply - pool_token_amount;

        let value = curve
            .normalized_value(swap_token_a_amount, swap_token_b_amount)
            .unwrap();
        let new_value = curve
            .normalized_value(new_swap_token_a_amount, new_swap_token_b_amount)
            .unwrap();

        // Required: new_value / new_supply >= value / supply, checked by
        // cross-multiplying as new_value * supply >= value * new_supply.
        let pool_token_supply = PreciseNumber::new(pool_token_supply).unwrap();
        let new_pool_token_supply = PreciseNumber::new(new_pool_token_supply).unwrap();
        assert!(new_value
            .checked_mul(&pool_token_supply)
            .unwrap()
            .greater_than_or_equal(&value.checked_mul(&new_pool_token_supply).unwrap()));
    }

    // Strategy yielding `(total, intermediate)` with `2 <= total < max_value`
    // and `1 <= intermediate < total`; `max_value` must be greater than 2.
    //
    // `total` starts at 2, not 1. With `total == 1` the inner range `1..total`
    // would be the empty `1..1`, and sampling an empty range panics. That can
    // happen both at generation time and while shrinking, because shrinking
    // moves `total` toward the start of its range.
    prop_compose! {
        pub fn total_and_intermediate(max_value: u64)(total in 2..max_value)
            (intermediate in 1..total, total in Just(total))
            -> (u64, u64) {
            (total, intermediate)
        }
    }
}