andromeda_std/common/withdraw.rs

1use crate::error::ContractError;
2use cosmwasm_schema::cw_serde;
3use cosmwasm_std::{ensure, Decimal, Uint128};
4use std::cmp;
5
6#[cw_serde]
7pub struct Withdrawal {
8    pub token: String,
9    pub withdrawal_type: Option<WithdrawalType>,
10}
11
12#[cw_serde]
13pub enum WithdrawalType {
14    Amount(Uint128),
15    Percentage(Decimal),
16}
17
18impl Withdrawal {
19    /// Calculates the amount to withdraw given the withdrawal type and passed in `balance`.
20    pub fn get_amount(&self, balance: Uint128) -> Result<Uint128, ContractError> {
21        match self.withdrawal_type.clone() {
22            None => Ok(balance),
23            Some(withdrawal_type) => withdrawal_type.get_amount(balance),
24        }
25    }
26}
27
28impl WithdrawalType {
29    /// Calculates the amount to withdraw given the withdrawal type and passed in `balance`.
30    pub fn get_amount(&self, balance: Uint128) -> Result<Uint128, ContractError> {
31        match self {
32            WithdrawalType::Percentage(percent) => {
33                ensure!(*percent <= Decimal::one(), ContractError::InvalidRate {});
34                Ok(balance * *percent)
35            }
36            WithdrawalType::Amount(amount) => Ok(cmp::min(*amount, balance)),
37        }
38    }
39
40    /// Checks if the underlying value is zero or not.
41    pub fn is_zero(&self) -> bool {
42        match self {
43            WithdrawalType::Percentage(percent) => percent.is_zero(),
44            WithdrawalType::Amount(amount) => amount.is_zero(),
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_amount_no_withdrawal_type() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: None,
        };
        let balance = Uint128::from(100u128);
        assert_eq!(balance, withdrawal.get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_amount_percentage() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::percent(10))),
        };
        let balance = Uint128::from(100u128);
        assert_eq!(10u128, withdrawal.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_amount_full_percentage() {
        // Exactly 100% is the upper bound accepted by the rate check.
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::one())),
        };
        let balance = Uint128::from(100u128);
        assert_eq!(balance, withdrawal.get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_amount_percentage_rounds_down() {
        // 33% of 10 is 3.3; the fractional part is discarded.
        let withdrawal_type = WithdrawalType::Percentage(Decimal::percent(33));
        assert_eq!(
            3u128,
            withdrawal_type
                .get_amount(Uint128::from(10u128))
                .unwrap()
                .u128()
        );
    }

    #[test]
    fn test_get_amount_invalid_percentage() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Percentage(Decimal::percent(101))),
        };
        let balance = Uint128::from(100u128);
        assert_eq!(
            ContractError::InvalidRate {},
            withdrawal.get_amount(balance).unwrap_err()
        );
    }

    #[test]
    fn test_get_amount_amount() {
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Amount(5u128.into())),
        };
        let balance = Uint128::from(10u128);
        assert_eq!(5u128, withdrawal.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_too_high_amount() {
        let balance = Uint128::from(10u128);
        let withdrawal = Withdrawal {
            token: "token".to_string(),
            withdrawal_type: Some(WithdrawalType::Amount(balance + Uint128::from(1u128))),
        };
        assert_eq!(10, withdrawal.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_amount_zero_balance() {
        let balance = Uint128::zero();
        assert!(WithdrawalType::Amount(Uint128::from(5u128))
            .get_amount(balance)
            .unwrap()
            .is_zero());
        assert!(WithdrawalType::Percentage(Decimal::percent(50))
            .get_amount(balance)
            .unwrap()
            .is_zero());
    }

    #[test]
    fn test_is_zero() {
        assert!(WithdrawalType::Amount(Uint128::zero()).is_zero());
        assert!(WithdrawalType::Percentage(Decimal::zero()).is_zero());
        assert!(!WithdrawalType::Amount(Uint128::from(1u128)).is_zero());
        assert!(!WithdrawalType::Percentage(Decimal::percent(1)).is_zero());
    }
}
105}