use super::error::{CoreError, AMOUNT_EXCEEDS_MAX_U64, ARITHMETIC_OVERFLOW, BPS_EXCEEDS_MAX_U16};
/// Basis points in a 100% rate: a rate equal to this value withholds the whole amount.
/// Written as 100 basis points per percent times 100 percent.
const BPS_DENOMINATOR: u16 = 100 * 100;
/// One entry of a transfer-fee schedule: a basis-point rate, a per-transfer cap,
/// and the epoch it is associated with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransferFeeRate {
    /// Epoch from which this rate applies. `TransferFeeConfig::rate_for_epoch`
    /// only compares against the newer rate's epoch.
    pub epoch: u64,
    /// Upper bound on the computed fee, in the same units as the transferred amount.
    pub maximum_fee: u64,
    /// Fee rate in basis points (1/10_000); the fee functions reject values above 10_000.
    pub basis_points: u16,
}
/// Two transfer-fee rates: `newer` supersedes `older` starting at `newer.epoch`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransferFeeConfig {
    /// Rate used for every epoch strictly before `newer.epoch`.
    pub older: TransferFeeRate,
    /// Rate used from `newer.epoch` onward (inclusive).
    pub newer: TransferFeeRate,
}
impl TransferFeeConfig {
pub fn rate_for_epoch(&self, current_epoch: u64) -> &TransferFeeRate {
if current_epoch >= self.newer.epoch {
&self.newer
} else {
&self.older
}
}
pub fn calculate_fee(&self, amount: u64, current_epoch: u64) -> Result<u64, CoreError> {
let rate = self.rate_for_epoch(current_epoch);
calculate_fee_for_rate(amount, rate.basis_points, rate.maximum_fee)
}
pub fn calculate_pre_fee_amount(
&self,
post_fee_amount: u64,
current_epoch: u64,
) -> Result<Option<u64>, CoreError> {
let rate = self.rate_for_epoch(current_epoch);
calculate_pre_fee_amount_for_rate(post_fee_amount, rate.basis_points, rate.maximum_fee)
}
}
/// Fee charged on a transfer of `amount` at `basis_points`, capped at `maximum_fee`.
///
/// The uncapped fee is `amount * basis_points / 10_000`, rounded up.
///
/// # Errors
/// Returns `BPS_EXCEEDS_MAX_U16` when `basis_points` is greater than 10_000; any
/// other error from the underlying computation is propagated unchanged.
pub fn calculate_fee_for_rate(
    amount: u64,
    basis_points: u16,
    maximum_fee: u64,
) -> Result<u64, CoreError> {
    fee_from_pre_fee_amount(amount, basis_points).map(|uncapped| uncapped.min(maximum_fee))
}
/// Gross amount to send so that, after the fee at `basis_points` (capped at
/// `maximum_fee`) is withheld, `post_fee_amount` remains.
///
/// The fee is derived from `post_fee_amount` by inverting the rounded-up forward
/// fee, capped at `maximum_fee`, and added back to `post_fee_amount`.
///
/// # Returns
/// `Ok(Some(pre_fee_amount))` on success, or `Ok(None)` when `post_fee_amount`
/// plus the capped fee does not fit in a `u64`.
///
/// # Errors
/// Returns `BPS_EXCEEDS_MAX_U16` when `basis_points` exceeds 10_000, including
/// for a zero `post_fee_amount`. This matches [`calculate_fee_for_rate`], which
/// rejects an invalid rate regardless of amount. Other errors from the fee
/// computation are propagated.
pub fn calculate_pre_fee_amount_for_rate(
    post_fee_amount: u64,
    basis_points: u16,
    maximum_fee: u64,
) -> Result<Option<u64>, CoreError> {
    // No zero-amount short-circuit here: the helper validates `basis_points` first
    // and yields a zero fee for a zero amount, so `0 + 0` still gives `Some(0)`.
    let fee_amount = fee_from_post_fee_amount(post_fee_amount, basis_points)?.min(maximum_fee);
    Ok(post_fee_amount.checked_add(fee_amount))
}
/// Uncapped fee for sending `pre_fee_amount` at `fee_bps` basis points.
///
/// Computes `ceil(pre_fee_amount * fee_bps / BPS_DENOMINATOR)` in `u128`, so the
/// intermediate product cannot overflow.
///
/// # Errors
/// - `BPS_EXCEEDS_MAX_U16` if `fee_bps > BPS_DENOMINATOR`.
/// - `ARITHMETIC_OVERFLOW` / `AMOUNT_EXCEEDS_MAX_U64` are defensive only: with a
///   valid `fee_bps` the product fits in `u128` and the quotient never exceeds
///   `pre_fee_amount`.
fn fee_from_pre_fee_amount(pre_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
    if fee_bps > BPS_DENOMINATOR {
        return Err(BPS_EXCEEDS_MAX_U16);
    }
    if pre_fee_amount == 0 || fee_bps == 0 {
        return Ok(0);
    }
    let scaled = u128::from(pre_fee_amount)
        .checked_mul(u128::from(fee_bps))
        .ok_or(ARITHMETIC_OVERFLOW)?;
    let rounded_up = scaled.div_ceil(u128::from(BPS_DENOMINATOR));
    u64::try_from(rounded_up).map_err(|_| AMOUNT_EXCEEDS_MAX_U64)
}
/// Uncapped fee that must be added to `post_fee_amount` so that, after the fee
/// at `fee_bps` is withheld, at least `post_fee_amount` remains.
///
/// Computes `ceil(post * BPS_DENOMINATOR / (BPS_DENOMINATOR - fee_bps)) - post`
/// in `u128`. No maximum-fee cap is applied here; callers apply it with `min`.
///
/// Returns `u64::MAX` as an "unbounded" sentinel in two cases:
/// - a 100% rate, which can never leave anything behind;
/// - an uncapped fee that does not fit in a `u64`.
///
/// In both cases, capping with `min(maximum_fee)` still yields the correct fee.
/// Failing here instead would hide a valid answer whenever the cap is small.
///
/// # Errors
/// - `BPS_EXCEEDS_MAX_U16` if `fee_bps > BPS_DENOMINATOR`.
/// - `ARITHMETIC_OVERFLOW` is defensive only: the `u128` product cannot overflow,
///   and the rounded-up gross amount is never below `post_fee_amount`.
fn fee_from_post_fee_amount(post_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
    if fee_bps > BPS_DENOMINATOR {
        return Err(BPS_EXCEEDS_MAX_U16);
    }
    if fee_bps == 0 || post_fee_amount == 0 {
        return Ok(0);
    }
    if fee_bps == BPS_DENOMINATOR {
        return Ok(u64::MAX);
    }
    let numerator = u128::from(post_fee_amount)
        .checked_mul(u128::from(BPS_DENOMINATOR))
        .ok_or(ARITHMETIC_OVERFLOW)?;
    // fee_bps < BPS_DENOMINATOR here, so this u16 subtraction cannot underflow.
    let denominator = u128::from(BPS_DENOMINATOR - fee_bps);
    // Round the gross amount up so the withheld fee never eats into `post_fee_amount`.
    let pre_fee_amount = numerator.div_ceil(denominator);
    let fee_amount = pre_fee_amount
        .checked_sub(u128::from(post_fee_amount))
        .ok_or(ARITHMETIC_OVERFLOW)?;
    // Saturate instead of erroring so the caller's maximum-fee cap can still apply.
    Ok(u64::try_from(fee_amount).unwrap_or(u64::MAX))
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds a [`TransferFeeRate`] from positional arguments to keep case tables short.
    fn rate(epoch: u64, bp: u16, max: u64) -> TransferFeeRate {
        TransferFeeRate {
            epoch,
            maximum_fee: max,
            basis_points: bp,
        }
    }

    /// Schedule shared by the epoch tests: 1% before epoch 50, 2% from epoch 50 on.
    fn two_rate_config() -> TransferFeeConfig {
        TransferFeeConfig {
            older: rate(0, 100, u64::MAX),
            newer: rate(50, 200, u64::MAX),
        }
    }

    #[rstest]
    #[case(0, 0, u64::MAX, 0)]
    #[case(1_000, 0, u64::MAX, 0)]
    #[case(0, 100, u64::MAX, 0)]
    #[case(100, 100, u64::MAX, 1)]
    #[case(101, 100, u64::MAX, 2)] // 1.01 rounds up
    #[case(99, 100, u64::MAX, 1)] // 0.99 rounds up
    #[case(1_000_000, 500, 100, 100)] // capped by maximum_fee
    #[case(1_000, 10_000, u64::MAX, 1_000)] // 100% rate withholds everything
    #[case(1_000, 10_000, 100, 100)]
    #[case(1_000, 10_000, 0, 0)]
    #[case(200, 100, 50, 2)]
    fn fee_for_rate(
        #[case] amount: u64,
        #[case] bp: u16,
        #[case] max: u64,
        #[case] expected: u64,
    ) {
        assert_eq!(calculate_fee_for_rate(amount, bp, max).unwrap(), expected);
    }

    #[test]
    fn fee_for_rate_overflow_safe_at_u64_max() {
        let fee = calculate_fee_for_rate(u64::MAX, 10_000, u64::MAX).unwrap();
        assert_eq!(fee, u64::MAX);
    }

    #[test]
    fn invalid_basis_points_are_rejected() {
        assert_eq!(
            calculate_fee_for_rate(1_000, 10_001, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
        assert_eq!(
            calculate_pre_fee_amount_for_rate(1_000, 10_001, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
    }

    #[rstest]
    #[case(0, 0, u64::MAX, Some(0))]
    #[case(1_000, 0, u64::MAX, Some(1_000))]
    #[case(0, 100, u64::MAX, Some(0))]
    #[case(1, 10_000, 5_000, Some(5_001))] // 100% rate: fee is exactly the cap
    #[case(1, 10_000, 0, Some(1))]
    #[case(u64::MAX, 10_000, 1, None)] // post + fee overflows u64
    #[case(99, 100, u64::MAX, Some(100))]
    #[case(200, 100, u64::MAX, Some(203))]
    #[case(1000, 500, 10, Some(1010))] // capped by maximum_fee
    fn pre_fee_for_rate(
        #[case] post: u64,
        #[case] bp: u16,
        #[case] max: u64,
        #[case] expected: Option<u64>,
    ) {
        assert_eq!(
            calculate_pre_fee_amount_for_rate(post, bp, max).unwrap(),
            expected
        );
    }

    /// Sending the computed pre-fee amount must leave at least `post` after the fee.
    #[rstest]
    #[case(99, 100, u64::MAX)]
    #[case(1, 100, u64::MAX)]
    #[case(1_000_000, 250, u64::MAX)]
    #[case(1_000, 500, 10)]
    #[case(1_000, 500, 1_000_000)]
    #[case(10_000, 100, 101)] // the cap binds only after rounding up
    fn pre_fee_round_trip(#[case] post: u64, #[case] bp: u16, #[case] max: u64) {
        let pre = calculate_pre_fee_amount_for_rate(post, bp, max)
            .unwrap()
            .unwrap();
        let fee = calculate_fee_for_rate(pre, bp, max).unwrap();
        let net = pre.saturating_sub(fee);
        assert!(
            net >= post,
            "pre={pre} fee={fee} net={net} should be >= post={post}"
        );
    }

    #[test]
    fn epoch_routes_to_older_or_newer() {
        let cfg = two_rate_config();
        assert_eq!(cfg.rate_for_epoch(0).basis_points, 100);
        assert_eq!(cfg.rate_for_epoch(49).basis_points, 100);
        assert_eq!(cfg.rate_for_epoch(50).basis_points, 200);
        assert_eq!(cfg.rate_for_epoch(u64::MAX).basis_points, 200);
    }

    #[test]
    fn epoch_aware_calculate_fee() {
        let cfg = two_rate_config();
        assert_eq!(cfg.calculate_fee(10_000, 0).unwrap(), 100);
        assert_eq!(cfg.calculate_fee(10_000, 49).unwrap(), 100);
        assert_eq!(cfg.calculate_fee(10_000, 50).unwrap(), 200);
    }

    #[test]
    fn epoch_aware_calculate_pre_fee_amount() {
        let cfg = two_rate_config();
        // 1%: receiving 9_900 requires sending 10_000 (fee 100).
        assert_eq!(cfg.calculate_pre_fee_amount(9_900, 0).unwrap(), Some(10_000));
        assert_eq!(cfg.calculate_pre_fee_amount(9_900, 49).unwrap(), Some(10_000));
        // 2%: receiving 9_800 requires sending 10_000 (fee 200).
        assert_eq!(cfg.calculate_pre_fee_amount(9_800, 50).unwrap(), Some(10_000));
    }
}