use num_traits::ToPrimitive;
use stable_swap_client::fees::Fees;
/// 2^32: the operand threshold `mul_div` compares both `a` and `b` against when
/// choosing between plain `u64` arithmetic and the widened `u128` path.
const MAX: u64 = 1 << 32;
/// 2^48: the threshold `mul_div_imbalanced` compares its first operand (`a`) against.
const MAX_BIG: u64 = 1 << 48;
/// 2^16: the threshold `mul_div_imbalanced` compares its second operand (`b`) against.
const MAX_SMALL: u64 = 1 << 16;
/// Computes `a * b / c` with integer (truncating) division.
///
/// The multiplication is done in `u64` when both operands are small enough
/// that the product cannot overflow, and in `u128` otherwise.
///
/// Returns `None` when `c` is zero or when the quotient does not fit in a `u64`.
#[inline(always)]
pub fn mul_div(a: u64, b: u64, c: u64) -> Option<u64> {
    // The comparison must be `>=`, not `>`: with both operands equal to
    // `MAX` (2^32) the product is exactly 2^64, which overflows `u64`. Taking
    // the narrow path there made `checked_mul` fail and returned `None` even
    // though the final quotient may be representable.
    if a >= MAX || b >= MAX {
        (a as u128)
            .checked_mul(b as u128)?
            .checked_div(c as u128)?
            .to_u64()
    } else {
        // Both operands are strictly below 2^32, so the product is below 2^64.
        a.checked_mul(b)?.checked_div(c)
    }
}
/// Computes `a * b / c` with integer (truncating) division, for the case where
/// `a` is expected to be large and `b` small.
///
/// The multiplication is done in `u64` when `a` is below 2^48 and `b` is below
/// 2^16 (so the product cannot overflow), and in `u128` otherwise.
///
/// Returns `None` when `c` is zero or when the quotient does not fit in a `u64`.
#[inline(always)]
pub fn mul_div_imbalanced(a: u64, b: u64, c: u64) -> Option<u64> {
    // The comparison must be `>=`, not `>`: `MAX_BIG * MAX_SMALL` is exactly
    // 2^64, which overflows `u64`. Taking the narrow path for that pair made
    // `checked_mul` fail and returned `None` even though the final quotient
    // may be representable.
    if a >= MAX_BIG || b >= MAX_SMALL {
        (a as u128)
            .checked_mul(b as u128)?
            .checked_div(c as u128)?
            .to_u64()
    } else {
        // a < 2^48 and b < 2^16, so the product is below 2^64.
        a.checked_mul(b)?.checked_div(c)
    }
}
/// Fee computations expressed as numerator/denominator ratios.
///
/// Every method returns `Option<u64>`. The implementation for `Fees` in this
/// file returns `None` when the underlying checked multiply/divide fails
/// (for example a zero denominator or a result that does not fit in `u64`).
pub trait FeeCalculator {
/// For `Fees`: `fee_amount * admin_trade_fee_numerator / admin_trade_fee_denominator`.
fn admin_trade_fee(&self, fee_amount: u64) -> Option<u64>;
/// For `Fees`: `fee_amount * admin_withdraw_fee_numerator / admin_withdraw_fee_denominator`.
fn admin_withdraw_fee(&self, fee_amount: u64) -> Option<u64>;
/// For `Fees`: `trade_amount * trade_fee_numerator / trade_fee_denominator`.
fn trade_fee(&self, trade_amount: u64) -> Option<u64>;
/// For `Fees`: `withdraw_amount * withdraw_fee_numerator / withdraw_fee_denominator`.
fn withdraw_fee(&self, withdraw_amount: u64) -> Option<u64>;
/// For `Fees`: first scales the trade fee numerator by
/// `n_coins / (4 * (n_coins - 1))` (truncating), then applies
/// `amount * adjusted_numerator / trade_fee_denominator`.
/// Yields `None` when `n_coins` is 0 or 1.
fn normalized_trade_fee(&self, n_coins: u8, amount: u64) -> Option<u64>;
}
/// Fee arithmetic for `Fees`.
///
/// Each method applies one numerator/denominator pair stored on `Fees` to the
/// given amount using checked arithmetic, returning `None` if the denominator
/// is zero or the result does not fit in a `u64`.
impl FeeCalculator for Fees {
    /// Returns `fee_amount * admin_trade_fee_numerator / admin_trade_fee_denominator`.
    fn admin_trade_fee(&self, fee_amount: u64) -> Option<u64> {
        mul_div_imbalanced(
            fee_amount,
            self.admin_trade_fee_numerator,
            self.admin_trade_fee_denominator,
        )
    }

    /// Returns `fee_amount * admin_withdraw_fee_numerator / admin_withdraw_fee_denominator`.
    fn admin_withdraw_fee(&self, fee_amount: u64) -> Option<u64> {
        mul_div_imbalanced(
            fee_amount,
            self.admin_withdraw_fee_numerator,
            self.admin_withdraw_fee_denominator,
        )
    }

    /// Returns `trade_amount * trade_fee_numerator / trade_fee_denominator`.
    fn trade_fee(&self, trade_amount: u64) -> Option<u64> {
        mul_div_imbalanced(
            trade_amount,
            self.trade_fee_numerator,
            self.trade_fee_denominator,
        )
    }

    /// Returns `withdraw_amount * withdraw_fee_numerator / withdraw_fee_denominator`.
    fn withdraw_fee(&self, withdraw_amount: u64) -> Option<u64> {
        mul_div_imbalanced(
            withdraw_amount,
            self.withdraw_fee_numerator,
            self.withdraw_fee_denominator,
        )
    }

    /// Applies the trade fee after scaling its numerator by
    /// `n_coins / (4 * (n_coins - 1))`.
    ///
    /// The scaled numerator is truncated to an integer before it is applied to
    /// `amount`. Returns `None` when `n_coins` is 0 (the subtraction
    /// underflows) or 1 (the divisor is zero), or when any intermediate
    /// result does not fit in a `u64`.
    fn normalized_trade_fee(&self, n_coins: u8, amount: u64) -> Option<u64> {
        // Widen to u64 *before* multiplying by 4. Doing the multiplication in
        // `u8` overflowed for `n_coins >= 65` and produced a spurious `None`.
        let coins = u64::from(n_coins);
        let divisor = coins.checked_sub(1)?.checked_mul(4)?;
        let adjusted_trade_fee_numerator = mul_div(self.trade_fee_numerator, coins, divisor)?;
        mul_div(
            amount,
            adjusted_trade_fee_numerator,
            self.trade_fee_denominator,
        )
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Reference value: `amount * numerator / denominator` in plain `u64`
    /// arithmetic, used to cross-check the `FeeCalculator` results.
    fn ratio(amount: u64, numerator: u64, denominator: u64) -> u64 {
        amount * numerator / denominator
    }

    /// Checks every `FeeCalculator` method on `Fees` against the plain-integer
    /// reference computation for one fixed set of ratios.
    #[test]
    fn fee_results() {
        let fees = Fees {
            admin_trade_fee_numerator: 1,
            admin_trade_fee_denominator: 2,
            admin_withdraw_fee_numerator: 3,
            admin_withdraw_fee_denominator: 4,
            trade_fee_numerator: 5,
            trade_fee_denominator: 6,
            withdraw_fee_numerator: 7,
            withdraw_fee_denominator: 8,
        };

        // Trade fee, then the admin share of the fee that was actually returned.
        let trade_amount: u64 = 1_000_000_000;
        let want_trade = ratio(
            trade_amount,
            fees.trade_fee_numerator,
            fees.trade_fee_denominator,
        );
        let got_trade = fees.trade_fee(trade_amount).unwrap();
        assert_eq!(got_trade, want_trade);

        let want_admin_trade = ratio(
            want_trade,
            fees.admin_trade_fee_numerator,
            fees.admin_trade_fee_denominator,
        );
        assert_eq!(fees.admin_trade_fee(got_trade).unwrap(), want_admin_trade);

        // Withdraw fee, then the admin share computed from the expected fee.
        let withdraw_amount: u64 = 100_000_000_000;
        let want_withdraw = ratio(
            withdraw_amount,
            fees.withdraw_fee_numerator,
            fees.withdraw_fee_denominator,
        );
        let got_withdraw = fees.withdraw_fee(withdraw_amount).unwrap();
        assert_eq!(got_withdraw, want_withdraw);

        let want_admin_withdraw = ratio(
            want_withdraw,
            fees.admin_withdraw_fee_numerator,
            fees.admin_withdraw_fee_denominator,
        );
        assert_eq!(
            fees.admin_withdraw_fee(want_withdraw).unwrap(),
            want_admin_withdraw
        );

        // Normalized trade fee: the numerator is first scaled by
        // n / (4 * (n - 1)), truncating, and then applied to the amount.
        let n_coins: u8 = 2;
        let coins = u64::from(n_coins);
        let scaled_numerator = ratio(fees.trade_fee_numerator, coins, 4 * (coins - 1));
        let want_normalized = ratio(trade_amount, scaled_numerator, fees.trade_fee_denominator);
        assert_eq!(
            fees.normalized_trade_fee(n_coins, trade_amount).unwrap(),
            want_normalized
        );
    }
}