#[cfg(feature = "production")]
use std::env;
use {
crate::{
curve::{
base::{CurveType, SwapCurve},
fees::Fees,
},
error::SwapError,
},
solana_program::program_error::ProgramError,
};
/// Limits that a deployment may place on the swap pools it accepts.
///
/// All fields are borrowed for `'a`. The validation methods below read
/// `valid_curve_types` and `fees`; `owner_key` is carried for callers and is
/// not read by those methods.
pub struct SwapConstraints<'a> {
    /// Owner key, stored as a string. In `production` builds this is taken from
    /// the `SWAP_PROGRAM_OWNER_FEE_ADDRESS` compile-time environment variable.
    pub owner_key: &'a str,
    /// Curve types that [`SwapConstraints::validate_curve`] accepts.
    pub valid_curve_types: &'a [CurveType],
    /// Reference fee schedule that [`SwapConstraints::validate_fees`] checks
    /// candidate fees against.
    pub fees: &'a Fees,
}

impl SwapConstraints<'_> {
    /// Checks that the curve type of `swap_curve` is one of `valid_curve_types`.
    ///
    /// # Errors
    ///
    /// Returns [`SwapError::UnsupportedCurveType`], converted into a
    /// [`ProgramError`], when the curve type is not in the allowed list.
    pub fn validate_curve(&self, swap_curve: &SwapCurve) -> Result<(), ProgramError> {
        let supported = self.valid_curve_types.contains(&swap_curve.curve_type);
        if !supported {
            return Err(SwapError::UnsupportedCurveType.into());
        }
        Ok(())
    }

    /// Checks candidate `fees` against the reference schedule in `self.fees`.
    ///
    /// The candidate is accepted only when all of the following hold:
    /// - the trade, owner-trade and owner-withdraw numerators are each at least
    ///   the corresponding reference numerator;
    /// - the host fee numerator equals the reference exactly;
    /// - every denominator equals the corresponding reference denominator.
    ///
    /// Requiring equal denominators is what makes comparing numerators alone
    /// equivalent to comparing the fee fractions (for positive denominators).
    ///
    /// # Errors
    ///
    /// Returns [`SwapError::InvalidFee`], converted into a [`ProgramError`],
    /// if any of the conditions above fails.
    pub fn validate_fees(&self, fees: &Fees) -> Result<(), ProgramError> {
        let minimum = self.fees;

        let denominators_match = fees.trade_fee_denominator == minimum.trade_fee_denominator
            && fees.owner_trade_fee_denominator == minimum.owner_trade_fee_denominator
            && fees.owner_withdraw_fee_denominator == minimum.owner_withdraw_fee_denominator
            && fees.host_fee_denominator == minimum.host_fee_denominator;

        let numerators_at_least_minimum = fees.trade_fee_numerator
            >= minimum.trade_fee_numerator
            && fees.owner_trade_fee_numerator >= minimum.owner_trade_fee_numerator
            && fees.owner_withdraw_fee_numerator >= minimum.owner_withdraw_fee_numerator;

        // Unlike the other numerators, the host fee is pinned, not bounded below.
        let host_fee_matches = fees.host_fee_numerator == minimum.host_fee_numerator;

        if denominators_match && numerators_at_least_minimum && host_fee_matches {
            Ok(())
        } else {
            Err(SwapError::InvalidFee.into())
        }
    }
}
/// Curve types permitted by the production constraint set.
#[cfg(feature = "production")]
const VALID_CURVE_TYPES: &[CurveType] = &[CurveType::ConstantPrice, CurveType::ConstantProduct];

/// Owner key for production builds, embedded at compile time from the
/// `SWAP_PROGRAM_OWNER_FEE_ADDRESS` environment variable. `env!` makes the
/// build fail if that variable is unset.
#[cfg(feature = "production")]
const OWNER_KEY: &str = env!("SWAP_PROGRAM_OWNER_FEE_ADDRESS");

/// Reference fee schedule for production builds. Every numerator is zero and
/// every denominator is 10_000.
#[cfg(feature = "production")]
const FEES: &Fees = &Fees {
    trade_fee_numerator: 0,
    owner_trade_fee_numerator: 0,
    owner_withdraw_fee_numerator: 0,
    host_fee_numerator: 0,
    trade_fee_denominator: 10_000,
    owner_trade_fee_denominator: 10_000,
    owner_withdraw_fee_denominator: 10_000,
    host_fee_denominator: 10_000,
};
/// Constraint set enforced in `production` builds, assembled from the
/// production-only `OWNER_KEY`, `VALID_CURVE_TYPES` and `FEES` constants.
#[cfg(feature = "production")]
pub const SWAP_CONSTRAINTS: Option<SwapConstraints> = Some(SwapConstraints {
    owner_key: OWNER_KEY,
    valid_curve_types: VALID_CURVE_TYPES,
    fees: FEES,
});

/// Without the `production` feature no constraint set is defined.
#[cfg(not(feature = "production"))]
pub const SWAP_CONSTRAINTS: Option<SwapConstraints> = None;
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::curve::{base::CurveType, constant_product::ConstantProductCurve},
        std::sync::Arc,
    };

    /// Exercises `SwapConstraints::validate_curve` and
    /// `SwapConstraints::validate_fees` against a constraint set built from
    /// known fee values. It checks that:
    /// - lower numerators are rejected and higher numerators are accepted;
    /// - mismatched denominators are rejected;
    /// - the host fee numerator must match exactly;
    /// - an unlisted curve type is rejected.
    #[test]
    fn validate_fees() {
        let trade_fee_numerator = 1;
        let trade_fee_denominator = 4;
        let owner_trade_fee_numerator = 2;
        let owner_trade_fee_denominator = 5;
        let owner_withdraw_fee_numerator = 4;
        let owner_withdraw_fee_denominator = 10;
        let host_fee_numerator = 10;
        let host_fee_denominator = 100;
        let owner_key = "";
        let curve_type = CurveType::ConstantProduct;
        let valid_fees = Fees {
            trade_fee_numerator,
            trade_fee_denominator,
            owner_trade_fee_numerator,
            owner_trade_fee_denominator,
            owner_withdraw_fee_numerator,
            owner_withdraw_fee_denominator,
            host_fee_numerator,
            host_fee_denominator,
        };
        let calculator = ConstantProductCurve {};
        let swap_curve = SwapCurve {
            curve_type,
            calculator: Arc::new(calculator.clone()),
        };
        let constraints = SwapConstraints {
            owner_key,
            valid_curve_types: &[curve_type],
            fees: &valid_fees,
        };

        constraints.validate_curve(&swap_curve).unwrap();
        constraints.validate_fees(&valid_fees).unwrap();

        let mut fees = valid_fees.clone();

        // Trade fee: a lower numerator is rejected, a higher one is accepted,
        // and the denominator must match exactly.
        fees.trade_fee_numerator = trade_fee_numerator - 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.trade_fee_numerator = trade_fee_numerator + 1;
        assert_eq!(constraints.validate_fees(&fees), Ok(()));
        fees.trade_fee_numerator = trade_fee_numerator;
        fees.trade_fee_denominator = trade_fee_denominator - 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.trade_fee_denominator = trade_fee_denominator + 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.trade_fee_denominator = trade_fee_denominator;

        // Owner trade fee: same rules as the trade fee.
        fees.owner_trade_fee_numerator = owner_trade_fee_numerator - 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.owner_trade_fee_numerator = owner_trade_fee_numerator + 1;
        assert_eq!(constraints.validate_fees(&fees), Ok(()));
        fees.owner_trade_fee_numerator = owner_trade_fee_numerator;
        fees.owner_trade_fee_denominator = owner_trade_fee_denominator - 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.owner_trade_fee_denominator = owner_trade_fee_denominator + 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.owner_trade_fee_denominator = owner_trade_fee_denominator;

        // Owner withdraw fee: same rules as the trade fee.
        fees.owner_withdraw_fee_numerator = owner_withdraw_fee_numerator - 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.owner_withdraw_fee_numerator = owner_withdraw_fee_numerator + 1;
        assert_eq!(constraints.validate_fees(&fees), Ok(()));
        fees.owner_withdraw_fee_numerator = owner_withdraw_fee_numerator;
        fees.owner_withdraw_fee_denominator = owner_withdraw_fee_denominator + 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.owner_withdraw_fee_denominator = owner_withdraw_fee_denominator;

        // Host fee: the numerator must match exactly in both directions.
        fees.host_fee_numerator = host_fee_numerator - 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.host_fee_numerator = host_fee_numerator + 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.host_fee_numerator = host_fee_numerator;
        fees.host_fee_denominator = host_fee_denominator + 1;
        assert_eq!(
            Err(SwapError::InvalidFee.into()),
            constraints.validate_fees(&fees),
        );
        fees.host_fee_denominator = host_fee_denominator;

        // Every field has been restored, so the schedule must validate again.
        assert_eq!(constraints.validate_fees(&fees), Ok(()));

        // A curve type outside `valid_curve_types` is rejected.
        let swap_curve = SwapCurve {
            curve_type: CurveType::ConstantPrice,
            calculator: Arc::new(calculator),
        };
        assert_eq!(
            Err(SwapError::UnsupportedCurveType.into()),
            constraints.validate_curve(&swap_curve),
        );
    }
}