use std::fmt::{Display, Formatter};
use cosmwasm_schema::cw_serde;
use cosmwasm_std::{Decimal, Decimal256, StdError, StdResult, Uint128, Uint256};
/// A single fee expressed as a fraction of an amount.
#[cw_serde]
pub struct Fee {
    /// Fraction of an amount charged by this fee (e.g. `Decimal::percent(1)` charges 1%).
    pub share: Decimal,
}

impl Fee {
    /// Computes this fee on `amount`, rounding down to a whole unit.
    ///
    /// # Errors
    ///
    /// Returns a generic [`StdError`] if the product does not fit into a `Uint256`
    /// (only possible when `share` exceeds 1).
    pub fn compute(&self, amount: Uint256) -> StdResult<Uint256> {
        // amount * share == floor(amount * share.atomics() / 10^18). Multiplying the raw
        // atomics (with a 512-bit intermediate) avoids first converting `amount` into a
        // Decimal256, which panics for amounts above Uint256::MAX / 10^18.
        let share = self.to_decimal_256();
        amount
            .checked_multiply_ratio(share.atomics(), Decimal256::one().atomics())
            .map_err(|e| StdError::generic_err(e.to_string()))
    }

    /// Returns the fee share widened to a `Decimal256`.
    pub fn to_decimal_256(&self) -> Decimal256 {
        Decimal256::from(self.share)
    }

    /// Checks that the share is strictly below 100%.
    ///
    /// # Errors
    ///
    /// Returns `StdError::generic_err("Invalid fee")` when the share is 100% or more.
    pub fn is_valid(&self) -> StdResult<()> {
        if self.share >= Decimal::percent(100) {
            return Err(StdError::generic_err("Invalid fee"));
        }
        Ok(())
    }
}
impl Display for Fee {
    /// Formats the fee as a percentage, e.g. a share of `0.05` renders as `5%`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `Decimal::percent(100)` is 1.0, so multiplying by it would print the raw ratio
        // ("0.05%"). Scale by the integer 100 instead.
        match self.share.checked_mul(Decimal::from_ratio(100u128, 1u128)) {
            Ok(percentage) => write!(f, "{}%", percentage),
            // Only reachable for shares far beyond any valid fee; fall back to the raw ratio
            // rather than panicking inside a formatter.
            Err(_) => write!(f, "{}", self.share),
        }
    }
}
#[cw_serde]
pub struct PoolFee {
pub protocol_fee: Fee,
pub swap_fee: Fee,
pub burn_fee: Fee,
pub extra_fees: Vec<Fee>,
}
impl PoolFee {
pub fn is_valid(&self) -> StdResult<()> {
let mut total_share = Decimal::zero();
let predefined_fees = [&self.protocol_fee, &self.swap_fee, &self.burn_fee];
for fee in predefined_fees.iter().copied() {
fee.is_valid()?; total_share += fee.share;
}
for fee in &self.extra_fees {
fee.is_valid()?; total_share += fee.share;
}
if total_share > Decimal::percent(20) {
return Err(StdError::generic_err("Total fees cannot exceed 20%"));
}
Ok(())
}
pub fn compute_and_apply_fees(&self, amount: Uint256) -> StdResult<Uint128> {
let mut total_fee_amount = Uint256::zero();
let protocol_fee_amount = self.protocol_fee.compute(amount)?;
total_fee_amount = total_fee_amount.checked_add(protocol_fee_amount)?;
let swap_fee_amount = self.swap_fee.compute(amount)?;
total_fee_amount = total_fee_amount.checked_add(swap_fee_amount)?;
let burn_fee_amount = self.burn_fee.compute(amount)?;
total_fee_amount = total_fee_amount.checked_add(burn_fee_amount)?;
for extra_fee in &self.extra_fees {
let extra_fee_amount = extra_fee.compute(amount)?;
total_fee_amount = total_fee_amount.checked_add(extra_fee_amount)?;
}
Uint128::try_from(total_fee_amount)
.map_err(|_| StdError::generic_err("Fee conversion error"))
}
}
#[cfg(test)]
mod tests {
    use cosmwasm_std::{Decimal, StdError, Uint128, Uint256};
    use test_case::test_case;

    use crate::fee::{Fee, PoolFee};

    /// Shares strictly below 100% (including zero) pass validation.
    #[test]
    fn valid_fee() {
        let shares = [
            Decimal::from_ratio(9u128, 10u128),
            Decimal::from_ratio(Uint128::new(2u128), Uint128::new(100u128)),
            Decimal::zero(),
        ];
        for share in shares {
            assert!(Fee { share }.is_valid().is_ok(), "this fee shouldn't fail");
        }
    }

    /// Shares of 100% or more are rejected.
    #[test]
    fn invalid_fee() {
        let expected: Result<(), StdError> = Err(StdError::generic_err("Invalid fee"));
        let shares = [
            Decimal::one(),
            Decimal::from_ratio(Uint128::new(2u128), Uint128::new(1u128)),
        ];
        for share in shares {
            assert_eq!(Fee { share }.is_valid(), expected);
        }
    }

    #[test_case(
        Decimal::permille(1), Decimal::permille(2), Decimal::permille(1), Uint256::from(1000u128), Uint128::from(4u128); "low fee scenario"
    )]
    #[test_case(
        Decimal::percent(1), Decimal::percent(2), Decimal::zero(), Uint256::from(1000u128), Uint128::from(30u128); "higher fee scenario"
    )]
    fn pool_fee_application(
        protocol_fee_share: Decimal,
        swap_fee_share: Decimal,
        burn_fee_share: Decimal,
        amount: Uint256,
        expected_fee_deducted: Uint128,
    ) {
        let pool_fee = PoolFee {
            protocol_fee: Fee {
                share: protocol_fee_share,
            },
            swap_fee: Fee {
                share: swap_fee_share,
            },
            burn_fee: Fee {
                share: burn_fee_share,
            },
            extra_fees: Vec::new(),
        };

        assert_eq!(
            pool_fee.compute_and_apply_fees(amount).unwrap(),
            expected_fee_deducted,
            "The total deducted fees did not match the expected value."
        );
    }

    /// 10% + 5% + 5% + 1% = 21% exceeds the 20% cap.
    #[test]
    fn pool_fee_exceeds_limit() {
        let pool_fee = PoolFee {
            protocol_fee: Fee {
                share: Decimal::percent(10),
            },
            swap_fee: Fee {
                share: Decimal::percent(5),
            },
            burn_fee: Fee {
                share: Decimal::percent(5),
            },
            extra_fees: vec![Fee {
                share: Decimal::percent(1),
            }],
        };

        assert_eq!(
            pool_fee.is_valid(),
            Err(StdError::generic_err("Total fees cannot exceed 20%"))
        );
    }
}