1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use crate::error::ContractError;
use cosmwasm_schema::cw_serde;
use cosmwasm_std::{ensure, Decimal, Uint128};
use std::cmp;

/// A request to withdraw some or all of the balance held for `token`.
///
/// The amount actually withdrawn is computed by `Withdrawal::get_amount`.
#[cw_serde]
pub struct Withdrawal {
    /// Identifier of the token to withdraw.
    /// NOTE(review): format (native denom vs. CW20 contract address) is not
    /// established in this module — confirm against callers.
    pub token: String,
    /// How much to withdraw. `None` means the entire balance.
    pub withdrawal_type: Option<WithdrawalType>,
}

/// How much of an available balance a [`Withdrawal`] should take.
#[cw_serde]
pub enum WithdrawalType {
    /// A fixed amount; `get_amount` caps it at the available balance.
    Amount(Uint128),
    /// A fraction of the balance. It must not exceed `Decimal::one()` (100%);
    /// otherwise `get_amount` returns `ContractError::InvalidRate`.
    Percentage(Decimal),
}

impl Withdrawal {
    /// Calculates the amount to withdraw from the passed-in `balance`.
    ///
    /// With no `withdrawal_type` the whole `balance` is withdrawn; otherwise the
    /// calculation is delegated to [`WithdrawalType::get_amount`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WithdrawalType::get_amount`], e.g.
    /// `ContractError::InvalidRate` for a percentage above 100%.
    pub fn get_amount(&self, balance: Uint128) -> Result<Uint128, ContractError> {
        // Borrow rather than clone the Option: `WithdrawalType::get_amount` only needs `&self`.
        match &self.withdrawal_type {
            None => Ok(balance),
            Some(withdrawal_type) => withdrawal_type.get_amount(balance),
        }
    }
}

impl WithdrawalType {
    /// Returns how much of `balance` this withdrawal type would take.
    ///
    /// * `Amount(a)` yields `a`, capped at `balance`.
    /// * `Percentage(p)` yields `balance * p`.
    ///
    /// # Errors
    ///
    /// Returns `ContractError::InvalidRate` when a percentage exceeds `Decimal::one()`.
    pub fn get_amount(&self, balance: Uint128) -> Result<Uint128, ContractError> {
        let withdrawn = match self {
            WithdrawalType::Amount(requested) => cmp::min(*requested, balance),
            WithdrawalType::Percentage(rate) => {
                // Rates above 100% are rejected before any multiplication happens.
                ensure!(*rate <= Decimal::one(), ContractError::InvalidRate {});
                balance * *rate
            }
        };
        Ok(withdrawn)
    }

    /// Reports whether the wrapped amount or percentage equals zero.
    pub fn is_zero(&self) -> bool {
        match self {
            WithdrawalType::Amount(value) => value.is_zero(),
            WithdrawalType::Percentage(rate) => rate.is_zero(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_amount_no_withdrawal_type() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: None,
        };
        let balance = Uint128::from(100u128);
        assert_eq!(balance, withdrawal.get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_amount_percentage() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::percent(10))),
        };
        let balance = Uint128::from(100u128);
        assert_eq!(10u128, withdrawal.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_amount_full_percentage() {
        // 100% is the inclusive upper bound accepted by the rate check.
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::one())),
        };
        let balance = Uint128::from(100u128);
        assert_eq!(balance, withdrawal.get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_amount_invalid_percentage() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::percent(101))),
        };
        let balance = Uint128::from(100u128);
        assert_eq!(
            ContractError::InvalidRate {},
            withdrawal.get_amount(balance).unwrap_err()
        );
    }

    #[test]
    fn test_get_amount_zero_balance() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::percent(50))),
        };
        assert_eq!(
            Uint128::zero(),
            withdrawal.get_amount(Uint128::zero()).unwrap()
        );
    }

    #[test]
    fn test_get_amount_amount() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Amount(5u128.into())),
        };
        let balance = Uint128::from(10u128);
        assert_eq!(5u128, withdrawal.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_amount_exact_balance() {
        let balance = Uint128::from(10u128);
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Amount(balance)),
        };
        assert_eq!(balance, withdrawal.get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_too_high_amount() {
        let balance = Uint128::from(10u128);
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Amount(balance + Uint128::from(1u128))),
        };
        assert_eq!(10, withdrawal.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_is_zero() {
        assert!(WithdrawalType::Amount(Uint128::zero()).is_zero());
        assert!(!WithdrawalType::Amount(Uint128::from(1u128)).is_zero());
        assert!(WithdrawalType::Percentage(Decimal::zero()).is_zero());
        assert!(!WithdrawalType::Percentage(Decimal::percent(1)).is_zero());
    }
}