Skip to main content

riptide_amm_math/
transfer_fee.rs

//! Token-2022 `TransferFee` extension math.
//!
//! This module mirrors the rounding semantics of
//! `spl_token_2022::extension::transfer_fee::TransferFee::calculate_fee` so that
//! the on-chain swap instruction and any off-chain quoting code agree on the
//! exact fee the SPL Token-2022 program deducts at transfer time. The
//! structures here only describe the data; parsing the `TransferFeeConfig`
//! extension out of a mint account lives in
//! `riptide-amm-program::utility::token`.
10
11use super::error::{CoreError, AMOUNT_EXCEEDS_MAX_U64, ARITHMETIC_OVERFLOW, BPS_EXCEEDS_MAX_U16};
12
/// Number of basis points in 100%: a fee rate is `basis_points / BPS_DENOMINATOR`.
const BPS_DENOMINATOR: u16 = 10_000;
14
/// One rate of a Token-2022 transfer-fee schedule.
///
/// Token-2022 keeps an older and a newer rate side by side, so a fee change
/// scheduled for epoch `N` cannot retroactively affect transfers submitted in
/// epoch `N - 1`; `epoch` records when this particular rate takes over.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransferFeeRate {
    /// First epoch in which this rate applies.
    pub epoch: u64,
    /// Upper bound on the withheld fee, in raw token units. The basis-point
    /// fee is computed first and then clamped to this value.
    pub maximum_fee: u64,
    /// Fee rate in basis points, where 10_000 means 100%. Token-2022 refuses
    /// larger values when the fee is configured, and the math in this module
    /// reports them as invalid input.
    pub basis_points: u16,
}
31
32/// Two-rate fee schedule.
33#[derive(Debug, Clone, Copy, PartialEq, Eq)]
34pub struct TransferFeeConfig {
35    pub older: TransferFeeRate,
36    pub newer: TransferFeeRate,
37}
38
39impl TransferFeeConfig {
40    /// Pick the rate that will be applied at `current_epoch`.
41    ///
42    /// Token-2022 activates `newer` once `current_epoch >= newer.epoch`,
43    /// otherwise it keeps using `older`.
44    pub fn rate_for_epoch(&self, current_epoch: u64) -> &TransferFeeRate {
45        if current_epoch >= self.newer.epoch {
46            &self.newer
47        } else {
48            &self.older
49        }
50    }
51
52    /// Fee that the Token-2022 program will withhold when transferring
53    /// `amount` raw units of this mint at `current_epoch`.
54    pub fn calculate_fee(&self, amount: u64, current_epoch: u64) -> Result<u64, CoreError> {
55        let rate = self.rate_for_epoch(current_epoch);
56        calculate_fee_for_rate(amount, rate.basis_points, rate.maximum_fee)
57    }
58
59    /// Inverse of [`calculate_fee`]: given the *net* amount the recipient
60    /// should observe, return the smallest *gross* amount that, after fee
61    /// deduction, yields at least `post_fee_amount`.
62    ///
63    /// Returns `None` only when no finite gross amount can deliver the
64    /// requested net without overflowing `u64`.
65    pub fn calculate_pre_fee_amount(
66        &self,
67        post_fee_amount: u64,
68        current_epoch: u64,
69    ) -> Result<Option<u64>, CoreError> {
70        let rate = self.rate_for_epoch(current_epoch);
71        calculate_pre_fee_amount_for_rate(post_fee_amount, rate.basis_points, rate.maximum_fee)
72    }
73}
74
75/// Fee that the Token-2022 program will withhold for a single rate.
76///
77/// Behaviour is bit-identical to `TransferFee::calculate_fee` in
78/// spl-token-2022:
79/// * 0 fee when `amount == 0` or `basis_points == 0`.
80/// * `min(ceil(amount * basis_points / 10_000), maximum_fee)`.
81///
82/// Token-2022 rejects rates above 10_000 at configuration time. This helper
83/// returns an error for structurally invalid rates.
84pub fn calculate_fee_for_rate(
85    amount: u64,
86    basis_points: u16,
87    maximum_fee: u64,
88) -> Result<u64, CoreError> {
89    Ok(fee_from_pre_fee_amount(amount, basis_points)?.min(maximum_fee))
90}
91
92/// Smallest pre-fee amount whose post-fee value is at least
93/// `post_fee_amount`, for a single rate.
94///
95/// This is needed when the swap caller asks for an *exact-out* result on a
96/// fee-bearing mint: we have to gross up the amount we send out of the vault
97/// so the user actually receives the requested net.
98pub fn calculate_pre_fee_amount_for_rate(
99    post_fee_amount: u64,
100    basis_points: u16,
101    maximum_fee: u64,
102) -> Result<Option<u64>, CoreError> {
103    if post_fee_amount == 0 {
104        return Ok(Some(0));
105    }
106    let fee_amount = fee_from_post_fee_amount(post_fee_amount, basis_points)?;
107    let fee_amount = fee_amount.min(maximum_fee);
108    Ok(post_fee_amount.checked_add(fee_amount))
109}
110
// The helpers below are intentionally aligned with Wavebreak's
// `math-lib/src/fee.rs` and compute the *uncapped* fee only. Token-2022's
// `maximum_fee` cap is layered on top by `calculate_fee_for_rate` and
// `calculate_pre_fee_amount_for_rate`.
114
115fn fee_from_pre_fee_amount(pre_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
116    if fee_bps > BPS_DENOMINATOR {
117        Err(BPS_EXCEEDS_MAX_U16)
118    } else if fee_bps == 0 || pre_fee_amount == 0 {
119        Ok(0)
120    } else {
121        let numerator = <u128>::from(pre_fee_amount)
122            .checked_mul(fee_bps.into())
123            .ok_or(ARITHMETIC_OVERFLOW)?;
124        let fee_amount: u64 = numerator
125            .div_ceil(BPS_DENOMINATOR.into())
126            .try_into()
127            .map_err(|_| AMOUNT_EXCEEDS_MAX_U64)?;
128        Ok(fee_amount)
129    }
130}
131
132fn fee_from_post_fee_amount(post_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
133    if fee_bps > BPS_DENOMINATOR {
134        Err(BPS_EXCEEDS_MAX_U16)
135    } else if fee_bps == 0 || post_fee_amount == 0 {
136        Ok(0)
137    } else if fee_bps == BPS_DENOMINATOR {
138        Ok(u64::MAX)
139    } else {
140        let numerator = <u128>::from(post_fee_amount)
141            .checked_mul(BPS_DENOMINATOR.into())
142            .ok_or(ARITHMETIC_OVERFLOW)?;
143        let denominator = <u128>::from(BPS_DENOMINATOR) - <u128>::from(fee_bps);
144        let pre_fee_amount = numerator.div_ceil(denominator);
145        let fee_amount: u64 = pre_fee_amount
146            .checked_sub(post_fee_amount.into())
147            .ok_or(ARITHMETIC_OVERFLOW)?
148            .try_into()
149            .map_err(|_| AMOUNT_EXCEEDS_MAX_U64)?;
150        Ok(fee_amount)
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    fn rate(epoch: u64, bp: u16, max: u64) -> TransferFeeRate {
        TransferFeeRate {
            epoch,
            maximum_fee: max,
            basis_points: bp,
        }
    }

    #[rstest]
    // No fee.
    #[case(0, 0, u64::MAX, 0)]
    #[case(1_000, 0, u64::MAX, 0)]
    // Zero amount.
    #[case(0, 100, u64::MAX, 0)]
    // 1% fee, ceiling rounding.
    #[case(100, 100, u64::MAX, 1)]
    #[case(101, 100, u64::MAX, 2)] // ceil(1.01) = 2
    #[case(99, 100, u64::MAX, 1)] // ceil(0.99) = 1
    // Cap binds.
    #[case(1_000_000, 500, 100, 100)] // 5% would be 50_000, capped to 100
    // 100% fee.
    #[case(1_000, 10_000, u64::MAX, 1_000)]
    #[case(1_000, 10_000, 100, 100)]
    #[case(1_000, 10_000, 0, 0)]
    // Cap doesn't bind (rate fee < max).
    #[case(200, 100, 50, 2)]
    fn fee_for_rate(#[case] amount: u64, #[case] bp: u16, #[case] max: u64, #[case] expected: u64) {
        assert_eq!(calculate_fee_for_rate(amount, bp, max).unwrap(), expected);
    }

    #[test]
    fn fee_for_rate_overflow_safe_at_u64_max() {
        // u64::MAX * 10_000 fits in u128, so this must not overflow.
        let fee = calculate_fee_for_rate(u64::MAX, 10_000, u64::MAX).unwrap();
        assert_eq!(fee, u64::MAX);
    }

    #[test]
    fn invalid_basis_points_are_rejected() {
        assert_eq!(
            calculate_fee_for_rate(1_000, 10_001, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
        assert_eq!(
            calculate_pre_fee_amount_for_rate(1_000, 10_001, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
    }

    #[rstest]
    // No fee — net == gross.
    #[case(0, 0, u64::MAX, Some(0))]
    #[case(1_000, 0, u64::MAX, Some(1_000))]
    // Zero post-fee, positive rate — gross is 0.
    #[case(0, 100, u64::MAX, Some(0))]
    // 100% fee — reachable once the cap binds.
    #[case(1, 10_000, 5_000, Some(5_001))]
    #[case(1, 10_000, 0, Some(1))]
    #[case(u64::MAX, 10_000, 1, None)]
    // 1% fee. Net 99 -> gross 100 (fee 1, net 99). Round-trip exact.
    #[case(99, 100, u64::MAX, Some(100))]
    // 1% fee, cap not binding. Net 200 -> ceil(200*10000/9900)=ceil(202.02..)=203;
    // alt cap path = 200 + max. Picks min.
    #[case(200, 100, u64::MAX, Some(203))]
    // Cap-bound case: bp=500 (5%), max_fee=10. Rate gross = ceil(1000*10000/9500)=1053
    // Cap gross = 1000 + 10 = 1010. Min wins.
    #[case(1000, 500, 10, Some(1010))]
    fn pre_fee_for_rate(
        #[case] post: u64,
        #[case] bp: u16,
        #[case] max: u64,
        #[case] expected: Option<u64>,
    ) {
        assert_eq!(
            calculate_pre_fee_amount_for_rate(post, bp, max).unwrap(),
            expected
        );
    }

    #[rstest]
    #[case(99, 100, u64::MAX)]
    #[case(1, 100, u64::MAX)]
    #[case(1_000_000, 250, u64::MAX)]
    #[case(1_000, 500, 10)]
    #[case(1_000, 500, 1_000_000)]
    fn pre_fee_round_trip(#[case] post: u64, #[case] bp: u16, #[case] max: u64) {
        // Round-trip: pre = pre_of(post); fee(pre) yields some f; pre - f >= post
        // (Token-2022's calculate_fee is monotone non-decreasing in amount, so
        // the gross we picked must deliver at least the requested net.)
        let pre = calculate_pre_fee_amount_for_rate(post, bp, max)
            .unwrap()
            .unwrap();
        let fee = calculate_fee_for_rate(pre, bp, max).unwrap();
        let net = pre.saturating_sub(fee);
        assert!(
            net >= post,
            "pre={pre} fee={fee} net={net} should be >= post={post}"
        );
    }

    #[test]
    fn epoch_routes_to_older_or_newer() {
        let cfg = TransferFeeConfig {
            older: rate(0, 100, u64::MAX),  // 1%
            newer: rate(50, 200, u64::MAX), // 2% from epoch 50
        };
        // Before activation epoch -> older.
        assert_eq!(cfg.rate_for_epoch(0).basis_points, 100);
        assert_eq!(cfg.rate_for_epoch(49).basis_points, 100);
        // At and after activation epoch -> newer.
        assert_eq!(cfg.rate_for_epoch(50).basis_points, 200);
        assert_eq!(cfg.rate_for_epoch(u64::MAX).basis_points, 200);
    }

    #[test]
    fn epoch_aware_calculate_fee() {
        let cfg = TransferFeeConfig {
            older: rate(0, 100, u64::MAX),  // 1%
            newer: rate(50, 200, u64::MAX), // 2% from epoch 50
        };
        // 1% of 10_000 = 100
        assert_eq!(cfg.calculate_fee(10_000, 0).unwrap(), 100);
        assert_eq!(cfg.calculate_fee(10_000, 49).unwrap(), 100);
        // 2% of 10_000 = 200
        assert_eq!(cfg.calculate_fee(10_000, 50).unwrap(), 200);
    }

    #[test]
    fn epoch_aware_calculate_pre_fee_amount() {
        let cfg = TransferFeeConfig {
            older: rate(0, 100, u64::MAX),  // 1%
            newer: rate(50, 200, u64::MAX), // 2% from epoch 50
        };
        // 1%: ceil(9_900 * 10_000 / 9_900) = 10_000 (fee 100, net 9_900).
        assert_eq!(cfg.calculate_pre_fee_amount(9_900, 0).unwrap(), Some(10_000));
        assert_eq!(cfg.calculate_pre_fee_amount(9_900, 49).unwrap(), Some(10_000));
        // 2%: ceil(9_800 * 10_000 / 9_800) = 10_000 (fee 200, net 9_800).
        assert_eq!(cfg.calculate_pre_fee_amount(9_800, 50).unwrap(), Some(10_000));
        // 2%: a 9_900 net needs ceil(9_900 * 10_000 / 9_800) = 10_103
        // (fee ceil(202.06) = 203, net 9_900).
        assert_eq!(cfg.calculate_pre_fee_amount(9_900, 50).unwrap(), Some(10_103));
    }
}
285}