andromeda_std/common/
withdraw.rs1use crate::error::ContractError;
2use cosmwasm_schema::cw_serde;
3use cosmwasm_std::{ensure, Decimal, Uint128};
4use std::cmp;
5
6#[cw_serde]
7pub struct Withdrawal {
8 pub token: String,
9 pub withdrawal_type: Option<WithdrawalType>,
10}
11
12#[cw_serde]
13pub enum WithdrawalType {
14 Amount(Uint128),
15 Percentage(Decimal),
16}
17
18impl Withdrawal {
19 pub fn get_amount(&self, balance: Uint128) -> Result<Uint128, ContractError> {
21 match self.withdrawal_type.clone() {
22 None => Ok(balance),
23 Some(withdrawal_type) => withdrawal_type.get_amount(balance),
24 }
25 }
26}
27
28impl WithdrawalType {
29 pub fn get_amount(&self, balance: Uint128) -> Result<Uint128, ContractError> {
31 match self {
32 WithdrawalType::Percentage(percent) => {
33 ensure!(*percent <= Decimal::one(), ContractError::InvalidRate {});
34 Ok(balance * *percent)
35 }
36 WithdrawalType::Amount(amount) => Ok(cmp::min(*amount, balance)),
37 }
38 }
39
40 pub fn is_zero(&self) -> bool {
42 match self {
43 WithdrawalType::Percentage(percent) => percent.is_zero(),
44 WithdrawalType::Amount(amount) => amount.is_zero(),
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Withdrawal` of a dummy token with the given withdrawal type.
    fn withdrawal(withdrawal_type: Option<WithdrawalType>) -> Withdrawal {
        Withdrawal {
            token: "token".to_string(),
            withdrawal_type,
        }
    }

    #[test]
    fn test_get_amount_no_withdrawal_type() {
        let balance = Uint128::from(100u128);
        assert_eq!(balance, withdrawal(None).get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_amount_percentage() {
        let w = withdrawal(Some(WithdrawalType::Percentage(Decimal::percent(10))));
        let balance = Uint128::from(100u128);
        assert_eq!(10u128, w.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_amount_full_percentage() {
        // 100% is the inclusive upper bound and must return the whole balance.
        let w = withdrawal(Some(WithdrawalType::Percentage(Decimal::one())));
        let balance = Uint128::from(100u128);
        assert_eq!(balance, w.get_amount(balance).unwrap());
    }

    #[test]
    fn test_get_amount_invalid_percentage() {
        let w = withdrawal(Some(WithdrawalType::Percentage(Decimal::percent(101))));
        let balance = Uint128::from(100u128);
        assert_eq!(
            ContractError::InvalidRate {},
            w.get_amount(balance).unwrap_err()
        );
    }

    #[test]
    fn test_get_amount_amount() {
        let w = withdrawal(Some(WithdrawalType::Amount(5u128.into())));
        let balance = Uint128::from(10u128);
        assert_eq!(5u128, w.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_too_high_amount() {
        let balance = Uint128::from(10u128);
        let w = withdrawal(Some(WithdrawalType::Amount(
            balance + Uint128::from(1u128),
        )));
        assert_eq!(10, w.get_amount(balance).unwrap().u128());
    }

    #[test]
    fn test_get_amount_zero_balance() {
        let balance = Uint128::zero();
        let by_percent = withdrawal(Some(WithdrawalType::Percentage(Decimal::percent(50))));
        let by_amount = withdrawal(Some(WithdrawalType::Amount(Uint128::from(5u128))));
        assert_eq!(Uint128::zero(), by_percent.get_amount(balance).unwrap());
        assert_eq!(Uint128::zero(), by_amount.get_amount(balance).unwrap());
    }

    #[test]
    fn test_is_zero() {
        assert!(WithdrawalType::Amount(Uint128::zero()).is_zero());
        assert!(WithdrawalType::Percentage(Decimal::zero()).is_zero());
        assert!(!WithdrawalType::Amount(Uint128::from(1u128)).is_zero());
        assert!(!WithdrawalType::Percentage(Decimal::percent(1)).is_zero());
    }
}