1use std::fmt::{Display, Formatter};
2
3use cosmwasm_schema::cw_serde;
4use cosmwasm_std::{Decimal, Decimal256, StdError, StdResult, Uint128, Uint256};
5
/// A fee expressed as a fractional share of an amount.
#[cw_serde]
pub struct Fee {
    /// Fraction of an amount that is charged. `Fee::compute` multiplies the amount by this
    /// value, and `Fee::is_valid` only accepts shares strictly below 100%.
    pub share: Decimal,
}
10
11impl Fee {
12 pub fn compute(&self, amount: Uint256) -> StdResult<Uint256> {
14 Ok(Decimal256::from_ratio(amount, Uint256::one())
15 .checked_mul(self.to_decimal_256())
16 .map_err(|e| StdError::generic_err(e.to_string()))?
17 .to_uint_floor())
18 }
19
20 pub fn to_decimal_256(&self) -> Decimal256 {
22 Decimal256::from(self.share)
23 }
24
25 pub fn is_valid(&self) -> StdResult<()> {
27 if self.share >= Decimal::percent(100) {
28 return Err(StdError::generic_err("Invalid fee"));
29 }
30 Ok(())
31 }
32}
33
34impl Display for Fee {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 write!(f, "{}%", self.share * Decimal::percent(100))
37 }
38}
39
/// A bundle of fees: three named fees plus any number of extra fees.
///
/// `PoolFee::is_valid` requires every contained fee to be valid on its own and the sum of
/// all shares not to exceed 20%.
#[cw_serde]
pub struct PoolFee {
    /// Fee named "protocol". NOTE(review): who receives it is not visible in this file.
    pub protocol_fee: Fee,

    /// Fee named "swap". NOTE(review): who receives it is not visible in this file.
    pub swap_fee: Fee,

    /// Fee named "burn". NOTE(review): presumably the charged amount is burned; confirm
    /// against the caller.
    pub burn_fee: Fee,

    /// Additional fees, validated and charged in the same way as the named fees.
    pub extra_fees: Vec<Fee>,
}
71
72impl PoolFee {
73 pub fn is_valid(&self) -> StdResult<()> {
75 let mut total_share = Decimal::zero();
76
77 let predefined_fees = [&self.protocol_fee, &self.swap_fee, &self.burn_fee];
79
80 for fee in predefined_fees.iter().copied() {
81 fee.is_valid()?; total_share += fee.share;
83 }
84
85 for fee in &self.extra_fees {
87 fee.is_valid()?; total_share += fee.share;
89 }
90
91 if total_share > Decimal::percent(20) {
93 return Err(StdError::generic_err("Total fees cannot exceed 20%"));
94 }
95
96 Ok(())
97 }
98
99 pub fn compute_and_apply_fees(&self, amount: Uint256) -> StdResult<Uint128> {
102 let mut total_fee_amount = Uint256::zero();
103
104 let protocol_fee_amount = self.protocol_fee.compute(amount)?;
106 total_fee_amount = total_fee_amount.checked_add(protocol_fee_amount)?;
107
108 let swap_fee_amount = self.swap_fee.compute(amount)?;
110 total_fee_amount = total_fee_amount.checked_add(swap_fee_amount)?;
111
112 let burn_fee_amount = self.burn_fee.compute(amount)?;
114 total_fee_amount = total_fee_amount.checked_add(burn_fee_amount)?;
115
116 for extra_fee in &self.extra_fees {
118 let extra_fee_amount = extra_fee.compute(amount)?;
119 total_fee_amount = total_fee_amount.checked_add(extra_fee_amount)?;
120 }
121
122 Uint128::try_from(total_fee_amount)
124 .map_err(|_| StdError::generic_err("Fee conversion error"))
125 }
126}
127
#[cfg(test)]
mod tests {
    use cosmwasm_std::{Decimal, StdError, Uint128, Uint256};
    use test_case::test_case;

    use crate::fee::{Fee, PoolFee};

    /// Shares strictly below 100%, including zero, are accepted.
    #[test]
    fn valid_fee() {
        let accepted_shares = [
            Decimal::from_ratio(9u128, 10u128),
            Decimal::from_ratio(Uint128::new(2u128), Uint128::new(100u128)),
            Decimal::zero(),
        ];

        for share in accepted_shares {
            assert!(Fee { share }.is_valid().is_ok(), "this fee shouldn't fail");
        }
    }

    /// Shares of 100% or more are rejected with the "Invalid fee" error.
    #[test]
    fn invalid_fee() {
        let rejected_shares = [
            Decimal::one(),
            Decimal::from_ratio(Uint128::new(2u128), Uint128::new(1u128)),
        ];

        for share in rejected_shares {
            assert_eq!(
                Fee { share }.is_valid(),
                Err(StdError::generic_err("Invalid fee"))
            );
        }
    }

    /// The total returned is the sum of each fee computed on the full amount.
    #[test_case(
        Decimal::permille(1), Decimal::permille(2), Decimal::permille(1), Uint256::from(1000u128), Uint128::from(4u128); "low fee scenario"
    )]
    #[test_case(
        Decimal::percent(1), Decimal::percent(2), Decimal::zero(), Uint256::from(1000u128), Uint128::from(30u128); "higher fee scenario"
    )]
    fn pool_fee_application(
        protocol_fee_share: Decimal,
        swap_fee_share: Decimal,
        burn_fee_share: Decimal,
        amount: Uint256,
        expected_fee_deducted: Uint128,
    ) {
        let pool_fee = PoolFee {
            protocol_fee: Fee {
                share: protocol_fee_share,
            },
            swap_fee: Fee {
                share: swap_fee_share,
            },
            burn_fee: Fee {
                share: burn_fee_share,
            },
            extra_fees: Vec::new(),
        };

        let total_fee_deducted = pool_fee.compute_and_apply_fees(amount).unwrap();

        assert_eq!(
            total_fee_deducted, expected_fee_deducted,
            "The total deducted fees did not match the expected value."
        );
    }

    /// 10% + 5% + 5% + 1% = 21%, which is above the 20% cap.
    #[test]
    fn pool_fee_exceeds_limit() {
        let pool_fee = PoolFee {
            protocol_fee: Fee {
                share: Decimal::percent(10),
            },
            swap_fee: Fee {
                share: Decimal::percent(5),
            },
            burn_fee: Fee {
                share: Decimal::percent(5),
            },
            extra_fees: vec![Fee {
                share: Decimal::percent(1),
            }],
        };

        let outcome = pool_fee.is_valid();

        assert_eq!(
            outcome,
            Err(StdError::generic_err("Total fees cannot exceed 20%"))
        );
    }
}