#[cfg(feature = "fuzz")]
use arbitrary::Arbitrary;
use {
crate::curve::{
calculator::{CurveCalculator, RoundDirection, SwapWithoutFeesResult, TradeDirection},
constant_price::ConstantPriceCurve,
constant_product::ConstantProductCurve,
fees::Fees,
offset::OffsetCurve,
},
arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs},
solana_program::{
program_error::ProgramError,
program_pack::{Pack, Sealed},
},
std::{
convert::{TryFrom, TryInto},
fmt::Debug,
sync::Arc,
},
};
/// Identifies which calculator a [`SwapCurve`] uses.
///
/// The discriminants are spelled out explicitly because they are the packed
/// encoding: `SwapCurve::pack_into_slice` writes `curve_type as u8` and
/// `CurveType::try_from(u8)` maps `0`/`1`/`2` back to these variants. Keeping
/// the numbers explicit stops a variant reordering from silently changing the
/// serialized format.
#[cfg_attr(feature = "fuzz", derive(Arbitrary))]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CurveType {
    /// Backed by `ConstantProductCurve` (packed as `0`).
    ConstantProduct = 0,
    /// Backed by `ConstantPriceCurve` (packed as `1`).
    ConstantPrice = 1,
    /// Backed by `OffsetCurve` (packed as `2`).
    Offset = 2,
}
/// Outcome of [`SwapCurve::swap`]: the pool balances after the trade, the
/// amounts moved, and the fees charged.
#[derive(Debug, PartialEq)]
pub struct SwapResult {
    /// Source-side pool balance after the swap (old balance + `source_amount_swapped`).
    pub new_swap_source_amount: u128,
    /// Destination-side pool balance after the swap (old balance - `destination_amount_swapped`).
    pub new_swap_destination_amount: u128,
    /// Source tokens consumed by the swap, including `trade_fee` and `owner_fee`.
    pub source_amount_swapped: u128,
    /// Destination tokens removed from the pool by the swap.
    pub destination_amount_swapped: u128,
    /// Trading fee computed on the full source amount.
    pub trade_fee: u128,
    /// Owner trading fee computed on the full source amount.
    pub owner_fee: u128,
}
/// A curve-type tag paired with the calculator that implements it.
///
/// The serialized form is produced by the `Pack` impl (one type byte followed
/// by 32 calculator bytes), not by this in-memory layout.
#[repr(C)]
#[derive(Debug)]
pub struct SwapCurve {
    /// Which calculator implementation `calculator` holds; packed as the first byte.
    pub curve_type: CurveType,
    /// Shared, thread-safe calculator doing the actual swap/deposit/withdraw math.
    pub calculator: Arc<dyn CurveCalculator + Sync + Send>,
}
impl SwapCurve {
    /// Swaps `source_amount` of the source token for the destination token,
    /// charging the trade and owner fees described by `fees`.
    ///
    /// Both fees are computed on the full `source_amount` and deducted before
    /// the calculator sees the input. They are then added back onto
    /// `source_amount_swapped`, so they end up in the pool's new source balance.
    ///
    /// Returns `None` if a fee computation fails, if any checked arithmetic
    /// over/underflows, or if the calculator itself returns `None`.
    pub fn swap(
        &self,
        source_amount: u128,
        swap_source_amount: u128,
        swap_destination_amount: u128,
        trade_direction: TradeDirection,
        fees: &Fees,
    ) -> Option<SwapResult> {
        // Debit the fees first so the curve only prices the net input.
        let trade_fee = fees.trading_fee(source_amount)?;
        let owner_fee = fees.owner_trading_fee(source_amount)?;
        let total_fees = trade_fee.checked_add(owner_fee)?;
        let source_amount_less_fees = source_amount.checked_sub(total_fees)?;

        let SwapWithoutFeesResult {
            source_amount_swapped,
            destination_amount_swapped,
        } = self.calculator.swap_without_fees(
            source_amount_less_fees,
            swap_source_amount,
            swap_destination_amount,
            trade_direction,
        )?;

        // Fees are paid into the pool along with the swapped amount.
        let source_amount_swapped = source_amount_swapped.checked_add(total_fees)?;
        Some(SwapResult {
            new_swap_source_amount: swap_source_amount.checked_add(source_amount_swapped)?,
            new_swap_destination_amount: swap_destination_amount
                .checked_sub(destination_amount_swapped)?,
            source_amount_swapped,
            destination_amount_swapped,
            trade_fee,
            owner_fee,
        })
    }

    /// Returns the pool-token amount the calculator assigns to a one-sided
    /// deposit of `source_amount` of a single token.
    ///
    /// The deposit is treated as if half of it (at least 1) were swapped for
    /// the other side. Trade and owner fees are charged on that half only and
    /// deducted before delegating to the calculator.
    ///
    /// Returns `Some(0)` for a zero deposit and `None` if a fee computation or
    /// checked arithmetic fails, or if the calculator returns `None`.
    pub fn deposit_single_token_type(
        &self,
        source_amount: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        pool_supply: u128,
        trade_direction: TradeDirection,
        fees: &Fees,
    ) -> Option<u128> {
        if source_amount == 0 {
            return Some(0);
        }
        // Fee as if *half* the deposit were swapped; never less than 1 token.
        let half_source_amount = std::cmp::max(1, source_amount.checked_div(2)?);
        let trade_fee = fees.trading_fee(half_source_amount)?;
        let owner_fee = fees.owner_trading_fee(half_source_amount)?;
        let total_fees = trade_fee.checked_add(owner_fee)?;
        let source_amount = source_amount.checked_sub(total_fees)?;
        self.calculator.deposit_single_token_type(
            source_amount,
            swap_token_a_amount,
            swap_token_b_amount,
            pool_supply,
            trade_direction,
        )
    }

    /// Returns the pool-token amount (as computed by the calculator, rounded
    /// up) needed to withdraw exactly `source_amount` of a single token.
    ///
    /// This mirrors `deposit_single_token_type`. Half of the requested amount
    /// (rounded up) is treated as swapped and grossed up with
    /// `Fees::pre_trading_fee_amount`. That is presumably the input amount which,
    /// after the trading fee, nets out to that half (see `Fees` to confirm).
    /// The calculator is asked to round with `RoundDirection::Ceiling`.
    ///
    /// Returns `Some(0)` for a zero amount and `None` if a fee computation or
    /// checked arithmetic fails, or if the calculator returns `None`.
    pub fn withdraw_single_token_type_exact_out(
        &self,
        source_amount: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        pool_supply: u128,
        trade_direction: TradeDirection,
        fees: &Fees,
    ) -> Option<u128> {
        if source_amount == 0 {
            return Some(0);
        }
        // Round the half *up* so the withdrawer never underpays the fee.
        let half_source_amount = source_amount.checked_add(1)?.checked_div(2)?;
        let pre_fee_source_amount = fees.pre_trading_fee_amount(half_source_amount)?;
        let source_amount = source_amount
            .checked_sub(half_source_amount)?
            .checked_add(pre_fee_source_amount)?;
        self.calculator.withdraw_single_token_type_exact_out(
            source_amount,
            swap_token_a_amount,
            swap_token_b_amount,
            pool_supply,
            trade_direction,
            RoundDirection::Ceiling,
        )
    }
}
impl Default for SwapCurve {
fn default() -> Self {
let curve_type: CurveType = Default::default();
let calculator: ConstantProductCurve = Default::default();
Self {
curve_type,
calculator: Arc::new(calculator),
}
}
}
#[cfg(any(test, feature = "fuzz"))]
impl Clone for SwapCurve {
    /// Produces a copy by packing `self` into bytes and unpacking them again,
    /// so the result owns a freshly unpacked calculator rather than sharing
    /// the original `Arc`.
    fn clone(&self) -> Self {
        let mut bytes = [0u8; Self::LEN];
        self.pack_into_slice(&mut bytes);
        // Bytes we just packed are expected to unpack; a failure here is a bug.
        Self::unpack_from_slice(&bytes).unwrap()
    }
}
impl PartialEq for SwapCurve {
    /// Two curves are equal when their packed `Self::LEN`-byte encodings are
    /// identical (type byte plus calculator parameters).
    fn eq(&self, other: &Self) -> bool {
        let encode = |curve: &Self| {
            let mut bytes = [0u8; Self::LEN];
            curve.pack_into_slice(&mut bytes);
            bytes
        };
        encode(self) == encode(other)
    }
}
// Marker impl: `solana_program`'s `Pack` trait has `Sealed` as a supertrait.
impl Sealed for SwapCurve {}
impl Pack for SwapCurve {
const LEN: usize = 33;
fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> {
let input = array_ref![input, 0, 33];
#[allow(clippy::ptr_offset_with_cast)]
let (curve_type, calculator) = array_refs![input, 1, 32];
let curve_type = curve_type[0].try_into()?;
Ok(Self {
curve_type,
calculator: match curve_type {
CurveType::ConstantProduct => {
Arc::new(ConstantProductCurve::unpack_from_slice(calculator)?)
}
CurveType::ConstantPrice => {
Arc::new(ConstantPriceCurve::unpack_from_slice(calculator)?)
}
CurveType::Offset => Arc::new(OffsetCurve::unpack_from_slice(calculator)?),
},
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 33];
let (curve_type, calculator) = mut_array_refs![output, 1, 32];
curve_type[0] = self.curve_type as u8;
self.calculator.pack_into_slice(&mut calculator[..]);
}
}
impl Default for CurveType {
fn default() -> Self {
CurveType::ConstantProduct
}
}
impl TryFrom<u8> for CurveType {
    type Error = ProgramError;

    /// Decodes the curve-type byte written by `SwapCurve::pack_into_slice`.
    ///
    /// # Errors
    /// Any byte other than `0`, `1` or `2` yields `ProgramError::InvalidAccountData`.
    fn try_from(curve_type: u8) -> Result<Self, Self::Error> {
        // Position in this table is the encoded byte value.
        const BY_TAG: [CurveType; 3] = [
            CurveType::ConstantProduct,
            CurveType::ConstantPrice,
            CurveType::Offset,
        ];
        BY_TAG
            .get(usize::from(curve_type))
            .copied()
            .ok_or(ProgramError::InvalidAccountData)
    }
}
#[cfg(test)]
mod test {
    use {super::*, crate::curve::calculator::test::total_and_intermediate, proptest::prelude::*};

    #[test]
    fn pack_swap_curve() {
        let curve = ConstantProductCurve {};
        let curve_type = CurveType::ConstantProduct;
        let swap_curve = SwapCurve {
            curve_type,
            calculator: Arc::new(curve),
        };

        // Round-trip through `Pack`.
        let mut packed = [0u8; SwapCurve::LEN];
        Pack::pack_into_slice(&swap_curve, &mut packed[..]);
        let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap();
        assert_eq!(swap_curve, unpacked);

        // Hand-built encoding: the type byte padded with 32 zero bytes.
        let mut packed = vec![curve_type as u8];
        packed.extend_from_slice(&[0u8; 32]);
        let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap();
        assert_eq!(swap_curve, unpacked);
    }

    #[test]
    fn constant_product_trade_fee() {
        let swap_source_amount = 1000;
        let swap_destination_amount = 50000;
        let trade_fee_numerator = 1;
        let trade_fee_denominator = 100;
        let owner_trade_fee_numerator = 0;
        let owner_trade_fee_denominator = 0;
        let owner_withdraw_fee_numerator = 0;
        let owner_withdraw_fee_denominator = 0;
        let host_fee_numerator = 0;
        let host_fee_denominator = 0;
        let fees = Fees {
            trade_fee_numerator,
            trade_fee_denominator,
            owner_trade_fee_numerator,
            owner_trade_fee_denominator,
            owner_withdraw_fee_numerator,
            owner_withdraw_fee_denominator,
            host_fee_numerator,
            host_fee_denominator,
        };
        let source_amount = 100;
        let curve = ConstantProductCurve {};
        let swap_curve = SwapCurve {
            curve_type: CurveType::ConstantProduct,
            calculator: Arc::new(curve),
        };
        let result = swap_curve
            .swap(
                source_amount,
                swap_source_amount,
                swap_destination_amount,
                TradeDirection::AtoB,
                &fees,
            )
            .unwrap();
        assert_eq!(result.new_swap_source_amount, 1100);
        assert_eq!(result.destination_amount_swapped, 4504);
        assert_eq!(result.new_swap_destination_amount, 45496);
        assert_eq!(result.trade_fee, 1);
        assert_eq!(result.owner_fee, 0);
    }

    #[test]
    fn constant_product_owner_fee() {
        let swap_source_amount = 1000;
        let swap_destination_amount = 50000;
        let trade_fee_numerator = 0;
        let trade_fee_denominator = 0;
        let owner_trade_fee_numerator = 1;
        let owner_trade_fee_denominator = 100;
        let owner_withdraw_fee_numerator = 0;
        let owner_withdraw_fee_denominator = 0;
        let host_fee_numerator = 0;
        let host_fee_denominator = 0;
        let fees = Fees {
            trade_fee_numerator,
            trade_fee_denominator,
            owner_trade_fee_numerator,
            owner_trade_fee_denominator,
            owner_withdraw_fee_numerator,
            owner_withdraw_fee_denominator,
            host_fee_numerator,
            host_fee_denominator,
        };
        let source_amount: u128 = 100;
        let curve = ConstantProductCurve {};
        let swap_curve = SwapCurve {
            curve_type: CurveType::ConstantProduct,
            calculator: Arc::new(curve),
        };
        let result = swap_curve
            .swap(
                source_amount,
                swap_source_amount,
                swap_destination_amount,
                TradeDirection::AtoB,
                &fees,
            )
            .unwrap();
        assert_eq!(result.new_swap_source_amount, 1100);
        assert_eq!(result.destination_amount_swapped, 4504);
        assert_eq!(result.new_swap_destination_amount, 45496);
        assert_eq!(result.trade_fee, 0);
        assert_eq!(result.owner_fee, 1);
    }

    #[test]
    fn constant_product_no_fee() {
        let swap_source_amount: u128 = 1_000;
        let swap_destination_amount: u128 = 50_000;
        let source_amount: u128 = 100;
        let curve = ConstantProductCurve;
        let fees = Fees::default();
        let swap_curve = SwapCurve {
            curve_type: CurveType::ConstantProduct,
            calculator: Arc::new(curve),
        };
        let result = swap_curve
            .swap(
                source_amount,
                swap_source_amount,
                swap_destination_amount,
                TradeDirection::AtoB,
                &fees,
            )
            .unwrap();
        assert_eq!(result.new_swap_source_amount, 1100);
        assert_eq!(result.destination_amount_swapped, 4545);
        assert_eq!(result.new_swap_destination_amount, 45455);
    }

    /// Performs an A->B swap of `source_amount`, then compares:
    /// - the pool tokens received for depositing the swapped A amount one-sided, and
    /// - the pool tokens needed to withdraw the received B amount one-sided from
    ///   the post-deposit pool.
    ///
    /// Returns `(withdraw_pool_tokens, deposit_pool_tokens)`.
    fn one_sided_deposit_vs_swap(
        source_amount: u128,
        swap_source_amount: u128,
        swap_destination_amount: u128,
        pool_supply: u128,
        fees: Fees,
    ) -> (u128, u128) {
        let curve = ConstantProductCurve;
        let swap_curve = SwapCurve {
            curve_type: CurveType::ConstantProduct,
            calculator: Arc::new(curve),
        };
        // Do the A -> B swap.
        let results = swap_curve
            .swap(
                source_amount,
                swap_source_amount,
                swap_destination_amount,
                TradeDirection::AtoB,
                &fees,
            )
            .unwrap();
        // Deposit just A, then withdraw just B; the two amounts should be close.
        let deposit_pool_tokens = swap_curve
            .deposit_single_token_type(
                results.source_amount_swapped,
                swap_source_amount,
                swap_destination_amount,
                pool_supply,
                TradeDirection::AtoB,
                &fees,
            )
            .unwrap();
        let withdraw_pool_tokens = swap_curve
            .withdraw_single_token_type_exact_out(
                results.destination_amount_swapped,
                swap_source_amount + results.source_amount_swapped,
                swap_destination_amount,
                pool_supply + deposit_pool_tokens,
                TradeDirection::BtoA,
                &fees,
            )
            .unwrap();
        (withdraw_pool_tokens, deposit_pool_tokens)
    }

    #[test]
    fn one_sided_equals_swap_with_fee_specific() {
        let pool_supply: u128 = 1_000_000;
        let swap_source_amount: u128 = 1_000_000;
        let swap_destination_amount: u128 = 50_000_000;
        let source_amount: u128 = 10_000;
        let fees = Fees {
            trade_fee_numerator: 25,
            trade_fee_denominator: 1_000,
            owner_trade_fee_numerator: 5,
            owner_trade_fee_denominator: 1_000,
            ..Fees::default()
        };
        let (withdraw_pool_tokens, deposit_pool_tokens) = one_sided_deposit_vs_swap(
            source_amount,
            swap_source_amount,
            swap_destination_amount,
            pool_supply,
            fees,
        );
        assert!(withdraw_pool_tokens >= deposit_pool_tokens);
        let epsilon = 2;
        assert!(withdraw_pool_tokens - deposit_pool_tokens <= epsilon);
        assert_eq!(withdraw_pool_tokens, 4914);
        assert_eq!(deposit_pool_tokens, 4912);
    }

    // In the proptests below, `total_and_intermediate` yields
    // `(swap_source_amount, source_amount)` and the arguments are passed to
    // `one_sided_deposit_vs_swap` in its declared order:
    // (source_amount, swap_source_amount, swap_destination_amount, pool_supply, fees).
    proptest! {
        #[test]
        fn one_sided_equals_swap_with_fee(
            (swap_source_amount, source_amount) in total_and_intermediate(u64::MAX),
            swap_destination_amount in 1..u64::MAX,
            pool_supply in 1..u64::MAX,
        ) {
            let fees = Fees {
                trade_fee_numerator: 25,
                trade_fee_denominator: 1_000,
                owner_trade_fee_numerator: 5,
                owner_trade_fee_denominator: 1_000,
                ..Fees::default()
            };
            let (withdraw_pool_tokens, deposit_pool_tokens) = one_sided_deposit_vs_swap(
                source_amount.into(),
                swap_source_amount.into(),
                swap_destination_amount.into(),
                pool_supply.into(),
                fees
            );
            prop_assert!(withdraw_pool_tokens >= deposit_pool_tokens);
        }

        #[test]
        fn one_sided_equals_swap_with_withdrawal_fee(
            (swap_source_amount, source_amount) in total_and_intermediate(u64::MAX),
            swap_destination_amount in 1..u64::MAX,
            pool_supply in 1..u64::MAX,
        ) {
            let fees = Fees {
                trade_fee_numerator: 25,
                trade_fee_denominator: 1_000,
                owner_trade_fee_numerator: 5,
                owner_trade_fee_denominator: 1_000,
                owner_withdraw_fee_numerator: 1,
                owner_withdraw_fee_denominator: 1_000,
                ..Fees::default()
            };
            let (withdraw_pool_tokens, deposit_pool_tokens) = one_sided_deposit_vs_swap(
                source_amount.into(),
                swap_source_amount.into(),
                swap_destination_amount.into(),
                pool_supply.into(),
                fees
            );
            prop_assert!(withdraw_pool_tokens >= deposit_pool_tokens);
        }

        #[test]
        fn one_sided_equals_swap_without_fee(
            (swap_source_amount, source_amount) in total_and_intermediate(u64::MAX),
            swap_destination_amount in 1..u64::MAX,
            pool_supply in 1..u64::MAX,
        ) {
            let fees = Fees::default();
            let (withdraw_pool_tokens, deposit_pool_tokens) = one_sided_deposit_vs_swap(
                source_amount.into(),
                swap_source_amount.into(),
                swap_destination_amount.into(),
                pool_supply.into(),
                fees
            );
            let difference = if withdraw_pool_tokens >= deposit_pool_tokens {
                withdraw_pool_tokens - deposit_pool_tokens
            } else {
                deposit_pool_tokens - withdraw_pool_tokens
            };
            // Deposit and withdrawal round in opposite directions, so allow a
            // small relative tolerance (at least 1 token).
            let epsilon = std::cmp::max(1, withdraw_pool_tokens / 1_000_000);
            prop_assert!(
                difference <= epsilon,
                "difference between {} and {} expected to be less than {}, actually {}",
                withdraw_pool_tokens,
                deposit_pool_tokens,
                epsilon,
                difference
            );
        }
    }
}