riptide-amm-math 2.0.1

The Riptide program math library
Documentation
//! Token-2022 `TransferFee` extension math.
//!
//! This module mirrors the rounding semantics of
//! `spl_token_2022::extension::transfer_fee::TransferFee::calculate_fee` so that
//! the on-chain swap instruction and any off-chain quoting code can agree on the
//! exact fee that will be deducted by the SPL Token-2022 program at transfer
//! time. The structures here only describe the data; parsing the
//! `TransferFeeConfig` extension out of a mint account lives in
//! `riptide-amm-program::utility::token`.

use super::error::{CoreError, AMOUNT_EXCEEDS_MAX_U64, ARITHMETIC_OVERFLOW, BPS_EXCEEDS_MAX_U16};

const BPS_DENOMINATOR: u16 = 10_000;

/// A single Token-2022 transfer-fee rate together with the epoch from which
/// it is effective.
///
/// The fee withheld on a transfer of `amount` raw units at this rate is
/// `min(ceil(amount * basis_points / 10_000), maximum_fee)`; see
/// [`calculate_fee_for_rate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransferFeeRate {
    /// First epoch in which this rate applies.
    pub epoch: u64,
    /// Upper bound on the withheld fee, in raw token units. Applied after the
    /// basis-points calculation.
    pub maximum_fee: u64,
    /// Fee rate in basis points (`10_000` = 100%). Token-2022 rejects values
    /// above `10_000` at configuration time; the fee helpers in this module
    /// return an error for them.
    pub basis_points: u16,
}

/// Token-2022 two-rate transfer-fee schedule.
///
/// Token-2022 keeps the previous (`older`) and upcoming (`newer`) rates
/// side-by-side so that a rate change scheduled for epoch `N` does not
/// surprise transfers already submitted in epoch `N - 1`. The `newer` rate
/// takes over once the current epoch reaches `newer.epoch`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransferFeeConfig {
    /// Rate in force before `newer.epoch`.
    pub older: TransferFeeRate,
    /// Rate in force from `newer.epoch` onwards.
    pub newer: TransferFeeRate,
}

impl TransferFeeConfig {
    /// Rate that Token-2022 applies to transfers made during `current_epoch`.
    ///
    /// `newer` becomes active once `current_epoch >= newer.epoch`; before
    /// that, `older` stays in force.
    pub fn rate_for_epoch(&self, current_epoch: u64) -> &TransferFeeRate {
        if current_epoch >= self.newer.epoch {
            &self.newer
        } else {
            &self.older
        }
    }

    /// Fee that the Token-2022 program will withhold when transferring
    /// `amount` raw units of this mint during `current_epoch`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`calculate_fee_for_rate`] when the active
    /// rate is invalid (basis points above `10_000`).
    pub fn calculate_fee(&self, amount: u64, current_epoch: u64) -> Result<u64, CoreError> {
        let rate = self.rate_for_epoch(current_epoch);
        calculate_fee_for_rate(amount, rate.basis_points, rate.maximum_fee)
    }

    /// Inverse of [`Self::calculate_fee`]: given the *net* amount the
    /// recipient should observe, return the smallest *gross* amount that,
    /// after fee deduction at `current_epoch`, yields at least
    /// `post_fee_amount`.
    ///
    /// `Ok(None)` means the required gross amount does not fit in `u64`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`calculate_pre_fee_amount_for_rate`] for
    /// the active rate.
    pub fn calculate_pre_fee_amount(
        &self,
        post_fee_amount: u64,
        current_epoch: u64,
    ) -> Result<Option<u64>, CoreError> {
        let rate = self.rate_for_epoch(current_epoch);
        calculate_pre_fee_amount_for_rate(post_fee_amount, rate.basis_points, rate.maximum_fee)
    }
}

/// Fee withheld by Token-2022 on a transfer of `amount` raw units at a single
/// rate.
///
/// Intended to match `TransferFee::calculate_fee` in spl-token-2022 exactly:
/// * the fee is `0` when `amount == 0` or `basis_points == 0`;
/// * otherwise it is `ceil(amount * basis_points / 10_000)`, clamped to
///   `maximum_fee`.
///
/// # Errors
///
/// Token-2022 refuses rates above `10_000` basis points at configuration
/// time; such structurally invalid rates produce an error here.
pub fn calculate_fee_for_rate(
    amount: u64,
    basis_points: u16,
    maximum_fee: u64,
) -> Result<u64, CoreError> {
    fee_from_pre_fee_amount(amount, basis_points).map(|uncapped| Ord::min(uncapped, maximum_fee))
}

/// Smallest pre-fee (gross) amount whose post-fee value is at least
/// `post_fee_amount`, for a single rate.
///
/// This is needed when the swap caller asks for an *exact-out* result on a
/// fee-bearing mint: the amount sent out of the vault has to be grossed up so
/// the user actually receives the requested net.
///
/// The gross amount is `post_fee_amount + min(rate_fee, maximum_fee)`, where
/// `rate_fee` is the fee obtained by inverting the basis-points rate. At a
/// 100% rate the rate fee is unbounded, so only the `maximum_fee` cap can
/// make a positive net reachable.
///
/// Returns `Ok(None)` when the gross amount does not fit in `u64`.
///
/// # Errors
///
/// Propagates errors from the rate inversion. In particular, rates above
/// `10_000` basis points are rejected even when `post_fee_amount == 0`, the
/// same as [`calculate_fee_for_rate`].
pub fn calculate_pre_fee_amount_for_rate(
    post_fee_amount: u64,
    basis_points: u16,
    maximum_fee: u64,
) -> Result<Option<u64>, CoreError> {
    // No early return for `post_fee_amount == 0`: the helper validates
    // `basis_points` first and yields a zero fee for a zero amount, so a valid
    // rate still produces `Some(0)` while an invalid one is reported.
    let fee_amount = fee_from_post_fee_amount(post_fee_amount, basis_points)?.min(maximum_fee);
    Ok(post_fee_amount.checked_add(fee_amount))
}

// The helpers below are intentionally aligned with Wavebreak's
// `math-lib/src/fee.rs`. TransferFee support layers Token-2022's `maximum_fee`
// cap on top in `calculate_fee_for_rate` and `calculate_pre_fee_amount_for_rate`.

/// Fee charged on a *gross* amount: `ceil(pre_fee_amount * fee_bps / 10_000)`.
///
/// A zero rate or a zero amount yields a zero fee. Rates above `10_000`
/// basis points are rejected with `BPS_EXCEEDS_MAX_U16`.
fn fee_from_pre_fee_amount(pre_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
    if fee_bps > BPS_DENOMINATOR {
        return Err(BPS_EXCEEDS_MAX_U16);
    }
    if pre_fee_amount == 0 || fee_bps == 0 {
        return Ok(0);
    }
    // A u64 times a u16 always fits in u128; the checked multiply is defensive.
    let scaled = u128::from(pre_fee_amount)
        .checked_mul(u128::from(fee_bps))
        .ok_or(ARITHMETIC_OVERFLOW)?;
    let rounded_up = scaled.div_ceil(u128::from(BPS_DENOMINATOR));
    u64::try_from(rounded_up).map_err(|_| AMOUNT_EXCEEDS_MAX_U64)
}

/// Fee to add to a *net* amount so that the resulting gross amount still
/// yields `post_fee_amount` after a `fee_bps` fee is deducted:
/// `ceil(post_fee_amount * 10_000 / (10_000 - fee_bps)) - post_fee_amount`.
///
/// `u64::MAX` acts as an "unbounded" sentinel. It is returned for a 100% rate,
/// where no finite gross amount delivers a positive net. It is also returned
/// whenever the exact fee does not fit in `u64`. Callers in this module clamp
/// the result with a `u64` `maximum_fee`, so saturating loses no information.
/// Erroring instead would reject exact-out requests whose grossed-up amount is
/// perfectly representable once the cap binds.
///
/// A zero rate or a zero amount yields a zero fee.
///
/// # Errors
///
/// `BPS_EXCEEDS_MAX_U16` when `fee_bps > 10_000`.
fn fee_from_post_fee_amount(post_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
    if fee_bps > BPS_DENOMINATOR {
        Err(BPS_EXCEEDS_MAX_U16)
    } else if fee_bps == 0 || post_fee_amount == 0 {
        Ok(0)
    } else if fee_bps == BPS_DENOMINATOR {
        // 100% rate: no finite gross amount delivers a positive net.
        Ok(u64::MAX)
    } else {
        // A u64 times 10_000 always fits in u128; the checked multiply is defensive.
        let numerator = <u128>::from(post_fee_amount)
            .checked_mul(BPS_DENOMINATOR.into())
            .ok_or(ARITHMETIC_OVERFLOW)?;
        let denominator = <u128>::from(BPS_DENOMINATOR) - <u128>::from(fee_bps);
        let pre_fee_amount = numerator.div_ceil(denominator);
        let fee_amount = pre_fee_amount
            .checked_sub(post_fee_amount.into())
            .ok_or(ARITHMETIC_OVERFLOW)?;
        // An uncapped fee above u64::MAX exceeds any u64 `maximum_fee`, so
        // saturating preserves `min(fee, maximum_fee)` for the caller.
        Ok(u64::try_from(fee_amount).unwrap_or(u64::MAX))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Shorthand constructor for a [`TransferFeeRate`].
    fn rate(epoch: u64, bp: u16, max: u64) -> TransferFeeRate {
        TransferFeeRate {
            epoch,
            maximum_fee: max,
            basis_points: bp,
        }
    }

    #[rstest]
    // No fee.
    #[case(0, 0, u64::MAX, 0)]
    #[case(1_000, 0, u64::MAX, 0)]
    // Zero amount.
    #[case(0, 100, u64::MAX, 0)]
    // 1% fee, ceiling rounding.
    #[case(100, 100, u64::MAX, 1)]
    #[case(101, 100, u64::MAX, 2)] // ceil(1.01) = 2
    #[case(99, 100, u64::MAX, 1)] // ceil(0.99) = 1
    // Smallest non-zero rate on the largest amount: ceil(u64::MAX / 10_000).
    #[case(u64::MAX, 1, u64::MAX, 1_844_674_407_370_956)]
    // Cap binds.
    #[case(1_000_000, 500, 100, 100)] // 5% would be 50_000, capped to 100
    // 100% fee.
    #[case(1_000, 10_000, u64::MAX, 1_000)]
    #[case(1_000, 10_000, 100, 100)]
    #[case(1_000, 10_000, 0, 0)]
    // Cap doesn't bind (rate fee < max).
    #[case(200, 100, 50, 2)]
    fn fee_for_rate(#[case] amount: u64, #[case] bp: u16, #[case] max: u64, #[case] expected: u64) {
        assert_eq!(calculate_fee_for_rate(amount, bp, max).unwrap(), expected);
    }

    #[test]
    fn fee_for_rate_overflow_safe_at_u64_max() {
        // u64::MAX * 10_000 fits in u128, so this must not overflow.
        let fee = calculate_fee_for_rate(u64::MAX, 10_000, u64::MAX).unwrap();
        assert_eq!(fee, u64::MAX);
    }

    #[test]
    fn invalid_basis_points_are_rejected() {
        assert_eq!(
            calculate_fee_for_rate(1_000, 10_001, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
        assert_eq!(
            calculate_pre_fee_amount_for_rate(1_000, 10_001, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
    }

    #[rstest]
    // No fee — net == gross.
    #[case(0, 0, u64::MAX, Some(0))]
    #[case(1_000, 0, u64::MAX, Some(1_000))]
    // Zero post-fee, positive rate — gross is 0.
    #[case(0, 100, u64::MAX, Some(0))]
    // 100% fee — only the maximum_fee cap makes a positive net reachable.
    #[case(1, 10_000, 5_000, Some(5_001))]
    #[case(1, 10_000, 0, Some(1))]
    #[case(u64::MAX, 10_000, 1, None)]
    // 1% fee. Net 99 -> gross 100 (fee 1, net 99). Round-trip exact.
    #[case(99, 100, u64::MAX, Some(100))]
    // 1% fee, cap not binding: gross = ceil(200 * 10_000 / 9_900) = ceil(202.02..) = 203,
    // i.e. a fee of 3 that stays below the (unbounded) cap.
    #[case(200, 100, u64::MAX, Some(203))]
    // Cap-bound case: bp=500 (5%), max_fee=10. The uncapped gross would be
    // ceil(1000 * 10_000 / 9_500) = 1053 (fee 53); the fee is capped to 10, so gross = 1010.
    #[case(1000, 500, 10, Some(1010))]
    fn pre_fee_for_rate(
        #[case] post: u64,
        #[case] bp: u16,
        #[case] max: u64,
        #[case] expected: Option<u64>,
    ) {
        assert_eq!(
            calculate_pre_fee_amount_for_rate(post, bp, max).unwrap(),
            expected
        );
    }

    #[rstest]
    #[case(99, 100, u64::MAX)]
    #[case(1, 100, u64::MAX)]
    #[case(1_000_000, 250, u64::MAX)]
    #[case(1_000, 500, 10)]
    #[case(1_000, 500, 1_000_000)]
    // Zero rate: gross == net.
    #[case(1_000, 0, u64::MAX)]
    // 100% rate: only the cap keeps the gross finite.
    #[case(1, 10_000, 5_000)]
    // Near-100% rate where the cap binds hard.
    #[case(1_000_000, 9_999, 5)]
    fn pre_fee_round_trip(#[case] post: u64, #[case] bp: u16, #[case] max: u64) {
        // Round-trip: pre = pre_of(post); deducting Token-2022's fee on `pre`
        // must leave the recipient with at least the requested net `post`.
        let pre = calculate_pre_fee_amount_for_rate(post, bp, max)
            .unwrap()
            .unwrap();
        let fee = calculate_fee_for_rate(pre, bp, max).unwrap();
        let net = pre.saturating_sub(fee);
        assert!(
            net >= post,
            "pre={pre} fee={fee} net={net} should be >= post={post}"
        );
    }

    #[test]
    fn epoch_routes_to_older_or_newer() {
        let cfg = TransferFeeConfig {
            older: rate(0, 100, u64::MAX),  // 1%
            newer: rate(50, 200, u64::MAX), // 2% from epoch 50
        };
        // Before activation epoch -> older.
        assert_eq!(cfg.rate_for_epoch(0).basis_points, 100);
        assert_eq!(cfg.rate_for_epoch(49).basis_points, 100);
        // At and after activation epoch -> newer.
        assert_eq!(cfg.rate_for_epoch(50).basis_points, 200);
        assert_eq!(cfg.rate_for_epoch(u64::MAX).basis_points, 200);
    }

    #[test]
    fn epoch_aware_calculate_fee() {
        let cfg = TransferFeeConfig {
            older: rate(0, 100, u64::MAX),  // 1%
            newer: rate(50, 200, u64::MAX), // 2% from epoch 50
        };
        // 1% of 10_000 = 100
        assert_eq!(cfg.calculate_fee(10_000, 0).unwrap(), 100);
        assert_eq!(cfg.calculate_fee(10_000, 49).unwrap(), 100);
        // 2% of 10_000 = 200
        assert_eq!(cfg.calculate_fee(10_000, 50).unwrap(), 200);
    }

    #[test]
    fn epoch_aware_calculate_pre_fee_amount() {
        let cfg = TransferFeeConfig {
            older: rate(0, 100, u64::MAX),  // 1%
            newer: rate(50, 200, u64::MAX), // 2% from epoch 50
        };
        // 1%: net 99 -> gross ceil(99 * 10_000 / 9_900) = 100.
        assert_eq!(cfg.calculate_pre_fee_amount(99, 49).unwrap(), Some(100));
        // 2%: net 98 -> gross ceil(98 * 10_000 / 9_800) = 100.
        assert_eq!(cfg.calculate_pre_fee_amount(98, 50).unwrap(), Some(100));
    }
}