use gemachain_program::{
program_error::ProgramError,
program_pack::{Pack, Sealed},
};
use crate::curve::{
calculator::{CurveCalculator, SwapWithoutFeesResult, TradeDirection},
constant_price::ConstantPriceCurve,
constant_product::ConstantProductCurve,
fees::Fees,
offset::OffsetCurve,
stable::StableCurve,
};
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use std::convert::{TryFrom, TryInto};
use std::fmt::Debug;
#[cfg(feature = "fuzz")]
use arbitrary::Arbitrary;
/// Tag identifying which curve calculator a `SwapCurve` carries.
///
/// The numeric value of each variant is the byte written as the first byte of
/// the packed `SwapCurve`, and the byte accepted by `TryFrom<u8>`.
#[cfg_attr(feature = "fuzz", derive(Arbitrary))]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CurveType {
    /// Tag byte 0; unpacks the calculator as a `ConstantProductCurve`.
    ConstantProduct = 0,
    /// Tag byte 1; unpacks the calculator as a `ConstantPriceCurve`.
    ConstantPrice = 1,
    /// Tag byte 2; unpacks the calculator as a `StableCurve`.
    Stable = 2,
    /// Tag byte 3; unpacks the calculator as an `OffsetCurve`.
    Offset = 3,
}
/// Outcome of a fee-inclusive swap, as produced by `SwapCurve::swap`.
#[derive(Debug, PartialEq)]
pub struct SwapResult {
/// Source-side pool balance after the swap: the prior balance plus
/// `source_amount_swapped`.
pub new_swap_source_amount: u128,
/// Destination-side pool balance after the swap: the prior balance minus
/// `destination_amount_swapped`.
pub new_swap_destination_amount: u128,
/// Source amount taken by the swap, including both the trade fee and the
/// owner fee.
pub source_amount_swapped: u128,
/// Destination amount paid out, as reported by the curve calculator.
pub destination_amount_swapped: u128,
/// Trading fee computed on the full input source amount.
pub trade_fee: u128,
/// Owner trading fee computed on the full input source amount.
pub owner_fee: u128,
}
/// A curve-type tag paired with the calculator that implements that curve.
///
/// The tag is what gets written as the first packed byte, and on unpack it
/// decides which concrete calculator type is reconstructed.
#[repr(C)]
#[derive(Debug)]
pub struct SwapCurve {
/// Which kind of curve `calculator` is; serialized as a single `u8`.
pub curve_type: CurveType,
/// Boxed trait object that performs the actual curve math and packs its own
/// state into the 32 bytes following the tag.
pub calculator: Box<dyn CurveCalculator>,
}
impl SwapCurve {
    /// Computes a swap of `source_amount` against the given pool balances,
    /// with fees applied.
    ///
    /// The trade fee and the owner fee are each computed on the full
    /// `source_amount`; their sum is removed before the calculator prices the
    /// trade and added back to the reported source amount afterwards.
    ///
    /// Returns `None` if any fee computation, checked arithmetic step, or the
    /// calculator itself yields `None`.
    pub fn swap(
        &self,
        source_amount: u128,
        swap_source_amount: u128,
        swap_destination_amount: u128,
        trade_direction: TradeDirection,
        fees: &Fees,
    ) -> Option<SwapResult> {
        let trade_fee = fees.trading_fee(source_amount)?;
        let owner_fee = fees.owner_trading_fee(source_amount)?;
        let total_fees = trade_fee.checked_add(owner_fee)?;

        // Only the post-fee remainder is priced by the curve.
        let amount_after_fees = source_amount.checked_sub(total_fees)?;
        let fee_free = self.calculator.swap_without_fees(
            amount_after_fees,
            swap_source_amount,
            swap_destination_amount,
            trade_direction,
        )?;

        // Fees are part of what the source side receives.
        let destination_amount_swapped = fee_free.destination_amount_swapped;
        let source_amount_swapped = fee_free.source_amount_swapped.checked_add(total_fees)?;
        let new_swap_source_amount = swap_source_amount.checked_add(source_amount_swapped)?;
        let new_swap_destination_amount =
            swap_destination_amount.checked_sub(destination_amount_swapped)?;

        Some(SwapResult {
            new_swap_source_amount,
            new_swap_destination_amount,
            source_amount_swapped,
            destination_amount_swapped,
            trade_fee,
            owner_fee,
        })
    }

    /// Handles a deposit of a single token type.
    ///
    /// A zero `source_amount` short-circuits to `Some(0)`. Otherwise the
    /// trading fee on half of `source_amount` (half is clamped to at least 1)
    /// is subtracted, and the remainder is forwarded to the calculator's
    /// `deposit_single_token_type`, whose result is returned unchanged.
    ///
    /// Returns `None` if the fee computation, the subtraction, or the
    /// calculator yields `None`.
    pub fn deposit_single_token_type(
        &self,
        source_amount: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        pool_supply: u128,
        trade_direction: TradeDirection,
        fees: &Fees,
    ) -> Option<u128> {
        if source_amount == 0 {
            return Some(0);
        }
        let net_source_amount = Self::source_amount_less_half_trade_fee(source_amount, fees)?;
        self.calculator.deposit_single_token_type(
            net_source_amount,
            swap_token_a_amount,
            swap_token_b_amount,
            pool_supply,
            trade_direction,
        )
    }

    /// Handles a single-token withdrawal for an exact output amount.
    ///
    /// A zero `source_amount` short-circuits to `Some(0)`. Otherwise the
    /// trading fee on half of `source_amount` (half is clamped to at least 1)
    /// is subtracted, and the remainder is forwarded to the calculator's
    /// `withdraw_single_token_type_exact_out`, whose result is returned
    /// unchanged.
    ///
    /// NOTE(review): the fee is subtracted here exactly as in the deposit
    /// path; for an exact-out withdrawal one might expect the fee to increase
    /// the amount required instead. Confirm this is intended.
    ///
    /// Returns `None` if the fee computation, the subtraction, or the
    /// calculator yields `None`.
    pub fn withdraw_single_token_type_exact_out(
        &self,
        source_amount: u128,
        swap_token_a_amount: u128,
        swap_token_b_amount: u128,
        pool_supply: u128,
        trade_direction: TradeDirection,
        fees: &Fees,
    ) -> Option<u128> {
        if source_amount == 0 {
            return Some(0);
        }
        let net_source_amount = Self::source_amount_less_half_trade_fee(source_amount, fees)?;
        self.calculator.withdraw_single_token_type_exact_out(
            net_source_amount,
            swap_token_a_amount,
            swap_token_b_amount,
            pool_supply,
            trade_direction,
        )
    }

    /// Subtracts from `source_amount` the trading fee charged on half of it.
    ///
    /// The half is clamped to a minimum of 1 so the fee is still computed on a
    /// non-zero amount when `source_amount` is 1.
    fn source_amount_less_half_trade_fee(source_amount: u128, fees: &Fees) -> Option<u128> {
        let half_source_amount = source_amount.checked_div(2)?.max(1);
        let trade_fee = fees.trading_fee(half_source_amount)?;
        source_amount.checked_sub(trade_fee)
    }
}
impl Default for SwapCurve {
fn default() -> Self {
let curve_type: CurveType = Default::default();
let calculator: ConstantProductCurve = Default::default();
Self {
curve_type,
calculator: Box::new(calculator),
}
}
}
/// Test/fuzz-only `Clone`: serializes the curve and deserializes a fresh copy,
/// which sidesteps cloning the boxed calculator trait object directly.
#[cfg(any(test, feature = "fuzz"))]
impl Clone for SwapCurve {
    /// Returns a copy rebuilt from this curve's packed bytes.
    ///
    /// Panics if the packed bytes fail to unpack.
    fn clone(&self) -> Self {
        let mut buffer = [0u8; Self::LEN];
        Pack::pack_into_slice(self, &mut buffer);
        let restored = Self::unpack_from_slice(&buffer);
        restored.unwrap()
    }
}
impl PartialEq for SwapCurve {
fn eq(&self, other: &Self) -> bool {
let mut packed_self = [0u8; Self::LEN];
Self::pack_into_slice(self, &mut packed_self);
let mut packed_other = [0u8; Self::LEN];
Self::pack_into_slice(other, &mut packed_other);
packed_self[..] == packed_other[..]
}
}
// Marker impl with no methods; presumably required because `Pack` has
// `Sealed` as a supertrait — confirm against `program_pack`.
impl Sealed for SwapCurve {}
impl Pack for SwapCurve {
const LEN: usize = 33;
fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> {
let input = array_ref![input, 0, 33];
#[allow(clippy::ptr_offset_with_cast)]
let (curve_type, calculator) = array_refs![input, 1, 32];
let curve_type = curve_type[0].try_into()?;
Ok(Self {
curve_type,
calculator: match curve_type {
CurveType::ConstantProduct => {
Box::new(ConstantProductCurve::unpack_from_slice(calculator)?)
}
CurveType::ConstantPrice => {
Box::new(ConstantPriceCurve::unpack_from_slice(calculator)?)
}
CurveType::Stable => Box::new(StableCurve::unpack_from_slice(calculator)?),
CurveType::Offset => Box::new(OffsetCurve::unpack_from_slice(calculator)?),
},
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 33];
let (curve_type, calculator) = mut_array_refs![output, 1, 32];
curve_type[0] = self.curve_type as u8;
self.calculator.pack_into_slice(&mut calculator[..]);
}
}
impl Default for CurveType {
fn default() -> Self {
CurveType::ConstantProduct
}
}
impl TryFrom<u8> for CurveType {
    type Error = ProgramError;

    /// Decodes a tag byte into a `CurveType`.
    ///
    /// Bytes 0 through 3 map to `ConstantProduct`, `ConstantPrice`, `Stable`
    /// and `Offset` respectively; any other value yields
    /// `ProgramError::InvalidAccountData`.
    fn try_from(curve_type: u8) -> Result<Self, Self::Error> {
        let variant = match curve_type {
            0 => Self::ConstantProduct,
            1 => Self::ConstantPrice,
            2 => Self::Stable,
            3 => Self::Offset,
            _ => return Err(ProgramError::InvalidAccountData),
        };
        Ok(variant)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Source-side pool balance used by the constant-product swap tests.
    const POOL_SOURCE_AMOUNT: u128 = 1_000;
    /// Destination-side pool balance used by the constant-product swap tests.
    const POOL_DESTINATION_AMOUNT: u128 = 50_000;
    /// Amount of source token traded in by the constant-product swap tests.
    const TRADE_SOURCE_AMOUNT: u128 = 100;

    /// Runs the shared A-to-B swap through a constant-product `SwapCurve`
    /// built around `curve`, using `fees`, and unwraps the result.
    fn swap_a_to_b(curve: ConstantProductCurve, fees: &Fees) -> SwapResult {
        let swap_curve = SwapCurve {
            curve_type: CurveType::ConstantProduct,
            calculator: Box::new(curve),
        };
        swap_curve
            .swap(
                TRADE_SOURCE_AMOUNT,
                POOL_SOURCE_AMOUNT,
                POOL_DESTINATION_AMOUNT,
                TradeDirection::AtoB,
                fees,
            )
            .unwrap()
    }

    #[test]
    fn pack_swap_curve() {
        let curve_type = CurveType::ConstantProduct;
        let swap_curve = SwapCurve {
            curve_type,
            calculator: Box::new(ConstantProductCurve {}),
        };

        // Round trip through a buffer of exactly `LEN` bytes.
        let mut buffer = [0u8; SwapCurve::LEN];
        Pack::pack_into_slice(&swap_curve, &mut buffer[..]);
        let round_tripped = SwapCurve::unpack_from_slice(&buffer).unwrap();
        assert_eq!(swap_curve, round_tripped);

        // Hand-built encoding: the tag byte followed by 32 zero bytes.
        let mut hand_built = vec![0u8; 33];
        hand_built[0] = curve_type as u8;
        let from_hand_built = SwapCurve::unpack_from_slice(&hand_built).unwrap();
        assert_eq!(swap_curve, from_hand_built);
    }

    #[test]
    fn constant_product_trade_fee() {
        // 1% trade fee, every other fee disabled.
        let fees = Fees {
            trade_fee_numerator: 1,
            trade_fee_denominator: 100,
            owner_trade_fee_numerator: 0,
            owner_trade_fee_denominator: 0,
            owner_withdraw_fee_numerator: 0,
            owner_withdraw_fee_denominator: 0,
            host_fee_numerator: 0,
            host_fee_denominator: 0,
        };

        let result = swap_a_to_b(ConstantProductCurve {}, &fees);

        assert_eq!(result.new_swap_source_amount, 1100);
        assert_eq!(result.destination_amount_swapped, 4504);
        assert_eq!(result.new_swap_destination_amount, 45496);
        assert_eq!(result.trade_fee, 1);
        assert_eq!(result.owner_fee, 0);
    }

    #[test]
    fn constant_product_owner_fee() {
        // 1% owner trade fee, every other fee disabled.
        let fees = Fees {
            trade_fee_numerator: 0,
            trade_fee_denominator: 0,
            owner_trade_fee_numerator: 1,
            owner_trade_fee_denominator: 100,
            owner_withdraw_fee_numerator: 0,
            owner_withdraw_fee_denominator: 0,
            host_fee_numerator: 0,
            host_fee_denominator: 0,
        };

        let result = swap_a_to_b(ConstantProductCurve {}, &fees);

        assert_eq!(result.new_swap_source_amount, 1100);
        assert_eq!(result.destination_amount_swapped, 4504);
        assert_eq!(result.new_swap_destination_amount, 45496);
        assert_eq!(result.trade_fee, 0);
        assert_eq!(result.owner_fee, 1);
    }

    #[test]
    fn constant_product_no_fee() {
        let fees = Fees::default();

        let result = swap_a_to_b(ConstantProductCurve::default(), &fees);

        assert_eq!(result.new_swap_source_amount, 1100);
        assert_eq!(result.destination_amount_swapped, 4545);
        assert_eq!(result.new_swap_destination_amount, 45455);
    }
}