1use super::error::{CoreError, AMOUNT_EXCEEDS_MAX_U64, ARITHMETIC_OVERFLOW, BPS_EXCEEDS_MAX_U16};
12
/// Basis points in one whole (100%). A rate equal to `BPS_DENOMINATOR`
/// withholds the entire amount; larger rates are rejected by the fee helpers
/// with `BPS_EXCEEDS_MAX_U16`.
const BPS_DENOMINATOR: u16 = 10_000;
14
/// A single transfer-fee schedule: a basis-point rate, a per-transfer cap on
/// the fee, and the epoch the schedule is associated with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransferFeeRate {
    /// Epoch from which this rate applies. `TransferFeeConfig::rate_for_epoch`
    /// only consults this value on the `newer` rate, switching to it once the
    /// current epoch reaches it.
    pub epoch: u64,
    /// Upper bound on the fee charged for a single transfer, in the same units
    /// as the transferred amount (the computed fee is `min`-ed against it).
    pub maximum_fee: u64,
    /// Fee rate in basis points (1 bp = 1/10_000). Values above 10_000 are
    /// rejected by the fee calculations with `BPS_EXCEEDS_MAX_U16`.
    pub basis_points: u16,
}
31
/// A pair of fee schedules selected by epoch: `older` applies before
/// `newer.epoch`, and `newer` applies from that epoch onward.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransferFeeConfig {
    /// Rate in force for epochs strictly before `newer.epoch`.
    pub older: TransferFeeRate,
    /// Rate in force for epochs greater than or equal to `newer.epoch`.
    pub newer: TransferFeeRate,
}
38
39impl TransferFeeConfig {
40 pub fn rate_for_epoch(&self, current_epoch: u64) -> &TransferFeeRate {
45 if current_epoch >= self.newer.epoch {
46 &self.newer
47 } else {
48 &self.older
49 }
50 }
51
52 pub fn calculate_fee(&self, amount: u64, current_epoch: u64) -> Result<u64, CoreError> {
55 let rate = self.rate_for_epoch(current_epoch);
56 calculate_fee_for_rate(amount, rate.basis_points, rate.maximum_fee)
57 }
58
59 pub fn calculate_pre_fee_amount(
66 &self,
67 post_fee_amount: u64,
68 current_epoch: u64,
69 ) -> Result<Option<u64>, CoreError> {
70 let rate = self.rate_for_epoch(current_epoch);
71 calculate_pre_fee_amount_for_rate(post_fee_amount, rate.basis_points, rate.maximum_fee)
72 }
73}
74
75pub fn calculate_fee_for_rate(
85 amount: u64,
86 basis_points: u16,
87 maximum_fee: u64,
88) -> Result<u64, CoreError> {
89 Ok(fee_from_pre_fee_amount(amount, basis_points)?.min(maximum_fee))
90}
91
92pub fn calculate_pre_fee_amount_for_rate(
99 post_fee_amount: u64,
100 basis_points: u16,
101 maximum_fee: u64,
102) -> Result<Option<u64>, CoreError> {
103 if post_fee_amount == 0 {
104 return Ok(Some(0));
105 }
106 let fee_amount = fee_from_post_fee_amount(post_fee_amount, basis_points)?;
107 let fee_amount = fee_amount.min(maximum_fee);
108 Ok(post_fee_amount.checked_add(fee_amount))
109}
110
111fn fee_from_pre_fee_amount(pre_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
116 if fee_bps > BPS_DENOMINATOR {
117 Err(BPS_EXCEEDS_MAX_U16)
118 } else if fee_bps == 0 || pre_fee_amount == 0 {
119 Ok(0)
120 } else {
121 let numerator = <u128>::from(pre_fee_amount)
122 .checked_mul(fee_bps.into())
123 .ok_or(ARITHMETIC_OVERFLOW)?;
124 let fee_amount: u64 = numerator
125 .div_ceil(BPS_DENOMINATOR.into())
126 .try_into()
127 .map_err(|_| AMOUNT_EXCEEDS_MAX_U64)?;
128 Ok(fee_amount)
129 }
130}
131
132fn fee_from_post_fee_amount(post_fee_amount: u64, fee_bps: u16) -> Result<u64, CoreError> {
133 if fee_bps > BPS_DENOMINATOR {
134 Err(BPS_EXCEEDS_MAX_U16)
135 } else if fee_bps == 0 || post_fee_amount == 0 {
136 Ok(0)
137 } else if fee_bps == BPS_DENOMINATOR {
138 Ok(u64::MAX)
139 } else {
140 let numerator = <u128>::from(post_fee_amount)
141 .checked_mul(BPS_DENOMINATOR.into())
142 .ok_or(ARITHMETIC_OVERFLOW)?;
143 let denominator = <u128>::from(BPS_DENOMINATOR) - <u128>::from(fee_bps);
144 let pre_fee_amount = numerator.div_ceil(denominator);
145 let fee_amount: u64 = pre_fee_amount
146 .checked_sub(post_fee_amount.into())
147 .ok_or(ARITHMETIC_OVERFLOW)?
148 .try_into()
149 .map_err(|_| AMOUNT_EXCEEDS_MAX_U64)?;
150 Ok(fee_amount)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds a `TransferFeeRate` from positional arguments to keep case tables terse.
    fn fee_rate(epoch: u64, basis_points: u16, maximum_fee: u64) -> TransferFeeRate {
        TransferFeeRate {
            epoch,
            maximum_fee,
            basis_points,
        }
    }

    /// Config whose rate switches from 100 bps to 200 bps at epoch 50.
    fn two_tier_config() -> TransferFeeConfig {
        TransferFeeConfig {
            older: fee_rate(0, 100, u64::MAX),
            newer: fee_rate(50, 200, u64::MAX),
        }
    }

    #[rstest]
    // A zero rate or a zero amount charges nothing.
    #[case(0, 0, u64::MAX, 0)]
    #[case(1_000, 0, u64::MAX, 0)]
    #[case(0, 100, u64::MAX, 0)]
    // 1%: exact, and rounded up to the next whole unit.
    #[case(100, 100, u64::MAX, 1)]
    #[case(101, 100, u64::MAX, 2)]
    #[case(99, 100, u64::MAX, 1)]
    // The cap wins over a larger computed fee.
    #[case(1_000_000, 500, 100, 100)]
    // 100% withholds the whole amount, still subject to the cap.
    #[case(1_000, 10_000, u64::MAX, 1_000)]
    #[case(1_000, 10_000, 100, 100)]
    #[case(1_000, 10_000, 0, 0)]
    // A cap above the computed fee leaves it unchanged.
    #[case(200, 100, 50, 2)]
    fn fee_for_rate(
        #[case] amount: u64,
        #[case] bps: u16,
        #[case] max: u64,
        #[case] expected: u64,
    ) {
        assert_eq!(calculate_fee_for_rate(amount, bps, max), Ok(expected));
    }

    #[test]
    fn fee_for_rate_overflow_safe_at_u64_max() {
        // u64::MAX * 10_000 only fits in the u128 intermediate.
        assert_eq!(
            calculate_fee_for_rate(u64::MAX, 10_000, u64::MAX),
            Ok(u64::MAX)
        );
    }

    #[test]
    fn invalid_basis_points_are_rejected() {
        const OVER_LIMIT_BPS: u16 = 10_001;
        assert_eq!(
            calculate_fee_for_rate(1_000, OVER_LIMIT_BPS, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
        assert_eq!(
            calculate_pre_fee_amount_for_rate(1_000, OVER_LIMIT_BPS, u64::MAX),
            Err(BPS_EXCEEDS_MAX_U16)
        );
    }

    #[rstest]
    // A zero rate or a zero amount needs nothing extra.
    #[case(0, 0, u64::MAX, Some(0))]
    #[case(1_000, 0, u64::MAX, Some(1_000))]
    #[case(0, 100, u64::MAX, Some(0))]
    // 100%: only the cap is added on top.
    #[case(1, 10_000, 5_000, Some(5_001))]
    #[case(1, 10_000, 0, Some(1))]
    // Gross amount would exceed u64::MAX.
    #[case(u64::MAX, 10_000, 1, None)]
    // Exact division vs. rounding up the gross amount.
    #[case(99, 100, u64::MAX, Some(100))]
    #[case(200, 100, u64::MAX, Some(203))]
    // The cap limits the added fee.
    #[case(1_000, 500, 10, Some(1_010))]
    fn pre_fee_for_rate(
        #[case] post: u64,
        #[case] bps: u16,
        #[case] max: u64,
        #[case] expected: Option<u64>,
    ) {
        assert_eq!(
            calculate_pre_fee_amount_for_rate(post, bps, max),
            Ok(expected)
        );
    }

    #[rstest]
    #[case(99, 100, u64::MAX)]
    #[case(1, 100, u64::MAX)]
    #[case(1_000_000, 250, u64::MAX)]
    #[case(1_000, 500, 10)]
    #[case(1_000, 500, 1_000_000)]
    fn pre_fee_round_trip(#[case] post: u64, #[case] bps: u16, #[case] max: u64) {
        // Charging the forward fee on the computed gross amount must leave the
        // recipient with at least the requested net amount.
        let pre = calculate_pre_fee_amount_for_rate(post, bps, max)
            .unwrap()
            .unwrap();
        let fee = calculate_fee_for_rate(pre, bps, max).unwrap();
        let net = pre.saturating_sub(fee);
        assert!(
            net >= post,
            "pre={pre} fee={fee} net={net} should be >= post={post}"
        );
    }

    #[test]
    fn epoch_routes_to_older_or_newer() {
        let cfg = two_tier_config();
        // The switch to the newer rate is inclusive of its epoch.
        let cases: [(u64, u16); 4] = [(0, 100), (49, 100), (50, 200), (u64::MAX, 200)];
        for (epoch, expected_bps) in cases {
            assert_eq!(cfg.rate_for_epoch(epoch).basis_points, expected_bps);
        }
    }

    #[test]
    fn epoch_aware_calculate_fee() {
        let cfg = two_tier_config();
        // 10_000 units at 1% before epoch 50, 2% from epoch 50.
        let cases: [(u64, u64); 3] = [(0, 100), (49, 100), (50, 200)];
        for (epoch, expected_fee) in cases {
            assert_eq!(cfg.calculate_fee(10_000, epoch).unwrap(), expected_fee);
        }
    }
}