use crate::instruction::utils::raydium_amm_v4::accounts::{
SWAP_FEE_DENOMINATOR, SWAP_FEE_NUMERATOR, TRADE_FEE_DENOMINATOR, TRADE_FEE_NUMERATOR,
};
/// Computes `amount * fee_rate / fee_denominator`, rounded up to the next
/// whole unit.
///
/// The product is formed in `u128`, so it cannot overflow for any pair of
/// `u64` operands. The final quotient is narrowed back to `u64` with `as`.
///
/// Panics if `fee_denominator` is zero.
fn compute_trading_fee(amount: u64, fee_rate: u64, fee_denominator: u64) -> u64 {
    let divisor = u128::from(fee_denominator);
    let product = u128::from(amount) * u128::from(fee_rate);
    let whole = product / divisor;
    // Any non-zero remainder bumps the fee up by one unit (ceiling division).
    let round_up = u128::from(product % divisor != 0);
    (whole + round_up) as u64
}
/// Computes `amount * fee_rate / fee_denominator`, rounded down.
///
/// The product is formed in `u128`, so it cannot overflow for any pair of
/// `u64` operands. The final quotient is narrowed back to `u64` with `as`.
///
/// Panics if `fee_denominator` is zero.
fn compute_protocol_fund_fee(amount: u64, fee_rate: u64, fee_denominator: u64) -> u64 {
    let product = u128::from(amount) * u128::from(fee_rate);
    // Integer division truncates, which is the intended floor rounding.
    let floored = product / u128::from(fee_denominator);
    floored as u64
}
/// Quote returned by `compute_swap_amount` for a swap with a fixed input amount.
#[derive(Debug, Clone)]
pub struct ComputeSwapParams {
/// Whether the simulated swap's input amount equals the requested `amount_in`.
pub all_trade: bool,
/// Input amount the quote was requested for.
pub amount_in: u64,
/// Output amount reported by the simulated swap.
pub amount_out: u64,
/// `amount_out` scaled down by the slippage tolerance given in basis points.
pub min_amount_out: u64,
/// Trade fee reported by the simulated swap.
pub fee: u64,
}
/// Outcome of a simulated swap as produced by `swap_base_input`.
#[derive(Debug, Clone)]
pub struct SwapResult {
/// Input vault balance plus the input amount net of the trade fee.
pub new_input_vault_amount: u64,
/// Output vault balance minus the swapped amount (before `swap_fee` is deducted).
pub new_output_vault_amount: u64,
/// Gross input amount exactly as supplied by the caller.
pub input_amount: u64,
/// Swapped amount minus `swap_fee`.
pub output_amount: u64,
/// Fee computed on the gross input amount.
pub trade_fee: u64,
/// Fee computed as a share of `trade_fee`.
pub swap_fee: u64,
}
/// Simulates a constant-product swap for an exact input amount.
///
/// The calculation proceeds as follows:
/// 1. `trade_fee` is `input_amount * trade_fee_rate / TRADE_FEE_DENOMINATOR`
///    (via `compute_trading_fee`) and is deducted from the input.
/// 2. `swap_fee` is `trade_fee * swap_fee_rate / SWAP_FEE_DENOMINATOR`
///    (via `compute_protocol_fund_fee`).
/// 3. The swapped amount is
///    `output_vault * net_input / (input_vault + net_input)`, rounded down.
/// 4. `output_amount` is the swapped amount minus `swap_fee` (saturating at 0).
///
/// When both the input vault and the net input are zero the divisor in step 3
/// is zero; in that case the swapped amount is defined as zero instead of
/// panicking.
///
/// Returns a [`SwapResult`] with the updated vault balances, the gross input,
/// the net output and both fees.
fn swap_base_input(
    input_amount: u64,
    input_vault_amount: u64,
    output_vault_amount: u64,
    trade_fee_rate: u64,
    swap_fee_rate: u64,
) -> SwapResult {
    let trade_fee = compute_trading_fee(input_amount, trade_fee_rate, TRADE_FEE_DENOMINATOR);
    let input_amount_less_fees = input_amount.saturating_sub(trade_fee);
    let swap_fee = compute_protocol_fund_fee(trade_fee, swap_fee_rate, SWAP_FEE_DENOMINATOR);

    // Both operands are widened from u64, so neither the product nor the sum
    // can overflow u128.
    let numerator = u128::from(output_vault_amount) * u128::from(input_amount_less_fees);
    let denominator = u128::from(input_vault_amount) + u128::from(input_amount_less_fees);

    // `checked_div` guards the empty-pool / zero-net-input case, which would
    // otherwise panic with a division by zero. The quotient never exceeds
    // `output_vault_amount`, so narrowing back to u64 is lossless.
    let output_amount_swapped = numerator.checked_div(denominator).unwrap_or(0) as u64;

    // NOTE(review): `swap_fee` is derived from the input-side trade fee but is
    // subtracted from the output amount here — confirm this matches the
    // on-chain program's accounting.
    let output_amount = output_amount_swapped.saturating_sub(swap_fee);

    SwapResult {
        new_input_vault_amount: input_vault_amount.saturating_add(input_amount_less_fees),
        new_output_vault_amount: output_vault_amount.saturating_sub(output_amount_swapped),
        input_amount,
        output_amount,
        trade_fee,
        swap_fee,
    }
}
/// Quotes a swap of `amount_in` against the given pool reserves.
///
/// # Parameters
/// - `base_reserve` / `quote_reserve`: current pool balances of each side.
/// - `is_base_in`: when `true` the base side is the input and the quote side
///   the output; when `false` the roles are reversed.
/// - `amount_in`: gross input amount to swap.
/// - `slippage_basis_points`: tolerated slippage, where 10000 basis points is
///   100%. Values of 10000 or more yield a `min_amount_out` of zero.
///
/// # Returns
/// A [`ComputeSwapParams`] holding the expected output, the slippage-adjusted
/// minimum output and the trade fee. The swap itself is simulated by
/// `swap_base_input` using `TRADE_FEE_NUMERATOR` and `SWAP_FEE_NUMERATOR` as
/// the fee rates.
pub fn compute_swap_amount(
    base_reserve: u64,
    quote_reserve: u64,
    is_base_in: bool,
    amount_in: u64,
    slippage_basis_points: u64,
) -> ComputeSwapParams {
    const BASIS_POINTS_DENOMINATOR: u64 = 10_000;

    let (input_reserve, output_reserve) = if is_base_in {
        (base_reserve, quote_reserve)
    } else {
        (quote_reserve, base_reserve)
    };

    let swap_result = swap_base_input(
        amount_in,
        input_reserve,
        output_reserve,
        TRADE_FEE_NUMERATOR,
        SWAP_FEE_NUMERATOR,
    );

    // Exact integer arithmetic: an f64 round-trip only carries 53 bits of
    // mantissa, so for large amounts it could round the minimum *above* the
    // quoted output and make the slippage bound unsatisfiable.
    let kept_basis_points = BASIS_POINTS_DENOMINATOR.saturating_sub(slippage_basis_points);
    // The result never exceeds `output_amount`, so narrowing to u64 is lossless.
    let min_amount_out = (u128::from(swap_result.output_amount) * u128::from(kept_basis_points)
        / u128::from(BASIS_POINTS_DENOMINATOR)) as u64;

    let all_trade = swap_result.input_amount == amount_in;

    ComputeSwapParams {
        all_trade,
        amount_in,
        amount_out: swap_result.output_amount,
        min_amount_out,
        fee: swap_result.trade_fee,
    }
}