1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use cosmwasm_std::{Addr, CosmosMsg, Decimal, StdError, StdResult, Uint128};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use cw_asset::Asset;

/// A wrapper around Decimal to help handle fractional fees.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq, JsonSchema)]
pub struct Fee {
    /// fraction of asset to take as fee.
    share: Decimal,
}

impl Fee {
    pub fn new(share: Decimal) -> StdResult<Self> {
        if share >= Decimal::percent(100) {
            return Err(StdError::generic_err("fee share must be lesser than 100%"));
        }
        Ok(Fee { share })
    }
    pub fn compute(&self, amount: Uint128) -> Uint128 {
        amount * self.share
    }

    pub fn msg(&self, asset: Asset, recipient: Addr) -> StdResult<CosmosMsg> {
        asset.transfer_msg(recipient)
    }
    pub fn share(&self) -> Decimal {
        self.share
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deposit amount shared by the fee-computation tests.
    fn one_million() -> Uint128 {
        Uint128::from(1_000_000u64)
    }

    /// A 20% fee on one million must be two hundred thousand.
    fn expected_twenty_percent_fee() -> Uint128 {
        Uint128::from(200_000u64)
    }

    #[test]
    fn test_fee_manual_construction() {
        let twenty_percent = Fee {
            share: Decimal::percent(20),
        };
        assert_eq!(
            twenty_percent.compute(one_million()),
            expected_twenty_percent_fee()
        );
    }

    #[test]
    fn test_fee_new() {
        let twenty_percent = Fee::new(Decimal::percent(20)).unwrap();
        assert_eq!(
            twenty_percent.compute(one_million()),
            expected_twenty_percent_fee()
        );
    }

    #[test]
    fn test_fee_new_gte_100() {
        // Exactly 100% and anything above it must be refused.
        for &pct in &[100u64, 101u64] {
            assert!(Fee::new(Decimal::percent(pct)).is_err());
        }
    }

    #[test]
    fn test_fee_share() {
        let pct = 20u64;
        let share = Fee::new(Decimal::percent(pct)).unwrap().share();
        assert_eq!(share, Decimal::percent(pct));
    }

    #[test]
    fn test_fee_msg() {
        let recipient = Addr::unchecked("recipient");
        let asset = Asset::native("uusd", one_million());

        // `msg` should be indistinguishable from a plain transfer of the same asset.
        let expected = asset.transfer_msg(recipient.clone()).unwrap();
        let actual = Fee::new(Decimal::percent(20))
            .unwrap()
            .msg(asset, recipient)
            .unwrap();

        assert_eq!(actual, expected);
    }
}