// abstract_core/objects/deposit_manager.rs

1use cosmwasm_std::{Storage, Uint64};
2use cw_storage_plus::{Item, Map};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::{error::AbstractError, AbstractResult};
7
8#[derive(Default, Serialize, Deserialize, Clone, Debug, PartialEq, Eq, JsonSchema)]
9pub struct Deposit {
10    value: Uint64,
11}
12
13impl Deposit {
14    pub const fn new() -> Deposit {
15        Deposit {
16            value: Uint64::zero(),
17        }
18    }
19
20    pub fn increase(&mut self, amount: Uint64) -> Self {
21        self.value += amount;
22        self.clone()
23    }
24
25    pub fn decrease(&mut self, amount: Uint64) -> AbstractResult<Self> {
26        if amount > self.value {
27            return Err(AbstractError::Deposit(format!(
28                "Cannot decrease {} by {}",
29                self.value, amount
30            )));
31        }
32
33        self.value = self.value.checked_sub(amount)?;
34        Ok(self.clone())
35    }
36
37    pub fn get(&self) -> Uint64 {
38        self.value
39    }
40}
41
42pub struct UserDeposit<'a> {
43    map: Map<'a, &'a [u8], Deposit>,
44}
45
46impl UserDeposit<'_> {
47    pub const fn new(namespace: &'static str) -> UserDeposit<'static> {
48        UserDeposit {
49            map: Map::new(namespace),
50        }
51    }
52
53    pub fn increase(
54        &self,
55        storage: &mut dyn Storage,
56        key: &[u8],
57        amount: Uint64,
58    ) -> AbstractResult<()> {
59        let user_deposit = &mut self.map.may_load(storage, key)?.unwrap_or_default();
60        self.map
61            .save(storage, key, &user_deposit.increase(amount))?;
62        Ok(())
63    }
64
65    pub fn decrease(
66        &self,
67        storage: &mut dyn Storage,
68        key: &[u8],
69        amount: Uint64,
70    ) -> AbstractResult<()> {
71        let mut user_deposit: Deposit = self.map.may_load(storage, key)?.unwrap_or_default();
72        self.map
73            .save(storage, key, &user_deposit.decrease(amount)?)?;
74        let new_deposit = user_deposit.get();
75        if new_deposit == Uint64::zero() {
76            self.map.remove(storage, key);
77        }
78        Ok(())
79    }
80
81    pub fn get(&self, storage: &dyn Storage, key: &[u8]) -> AbstractResult<Uint64> {
82        Ok(self.map.may_load(storage, key)?.unwrap_or_default().get())
83    }
84}
85
86pub struct DepositManager {
87    total_deposits: Item<'static, Deposit>,
88    user_deposits: UserDeposit<'static>, // TODO: Check if lifetime can be improved
89}
90
91impl DepositManager {
92    pub const fn new(total_namespace: &'static str, deposit_namespace: &'static str) -> Self {
93        Self {
94            total_deposits: Item::new(total_namespace),
95            user_deposits: UserDeposit::new(deposit_namespace),
96        }
97    }
98
99    pub fn increase(
100        &self,
101        storage: &mut dyn Storage,
102        key: &[u8],
103        amount: Uint64,
104    ) -> AbstractResult<()> {
105        let deposit = self.total_deposits.load(storage);
106        if deposit.is_err() {
107            self.total_deposits.save(storage, &Deposit::new())?;
108        }
109        let mut total_deposits = self.total_deposits.load(storage)?;
110        self.total_deposits
111            .save(storage, &total_deposits.increase(amount))?;
112        self.user_deposits.increase(storage, key, amount)
113    }
114
115    pub fn decrease(
116        &self,
117        storage: &mut dyn Storage,
118        key: &[u8],
119        amount: Uint64,
120    ) -> AbstractResult<()> {
121        self.user_deposits.decrease(storage, key, amount)?;
122        let mut total_deposits = self.total_deposits.load(storage)?;
123        self.total_deposits
124            .save(storage, &total_deposits.decrease(amount)?)?;
125        Ok(())
126    }
127
128    pub fn get(&self, storage: &dyn Storage, key: &[u8]) -> AbstractResult<Uint64> {
129        self.user_deposits.get(storage, key)
130    }
131
132    pub fn get_total_deposits(&self, storage: &dyn Storage) -> AbstractResult<Uint64> {
133        Ok(self.total_deposits.load(storage)?.get())
134    }
135}
136
#[cfg(test)]
mod tests {
    use cosmwasm_std::testing::MockStorage;

    use super::*;

    /// Shorthand for building a `Uint64` from a `u64` literal.
    fn amt(v: u64) -> Uint64 {
        Uint64::from(v)
    }

    /// Drives a single key of `UserDeposit` through deposits and withdrawals,
    /// then checks that over-withdrawing from an emptied balance is rejected.
    #[test]
    fn test_user_deposits() {
        let mut storage = MockStorage::default();
        let user_deposits = UserDeposit::new("test");
        let key: &[u8] = b"key";

        // An untouched key reads as zero.
        assert_eq!(user_deposits.get(&storage, key).unwrap(), amt(0));

        // (is_increase, amount, expected balance afterwards)
        let steps = [(true, 10, 10), (true, 10, 20), (false, 5, 15), (false, 15, 0)];
        for &(is_increase, amount, expected) in &steps {
            if is_increase {
                user_deposits
                    .increase(&mut storage, key, amt(amount))
                    .unwrap();
            } else {
                user_deposits
                    .decrease(&mut storage, key, amt(amount))
                    .unwrap();
            }
            assert_eq!(user_deposits.get(&storage, key).unwrap(), amt(expected));
        }

        // Nothing left to withdraw.
        assert!(user_deposits
            .decrease(&mut storage, key, amt(15))
            .is_err());
    }

    /// Checks per-key balances and the running total across two keys.
    #[test]
    fn test_deposit_manager() {
        let mut storage = MockStorage::default();
        let deposits = DepositManager::new("test", "test2");
        let key: &[u8] = b"key";
        let other_key: &[u8] = b"other_key";

        assert_eq!(deposits.get(&storage, key).unwrap(), amt(0));

        deposits.increase(&mut storage, key, amt(10)).unwrap();
        assert_eq!(deposits.get(&storage, key).unwrap(), amt(10));
        assert_eq!(deposits.get_total_deposits(&storage).unwrap(), amt(10));

        deposits.increase(&mut storage, key, amt(10)).unwrap();
        assert_eq!(deposits.get(&storage, key).unwrap(), amt(20));
        assert_eq!(deposits.get_total_deposits(&storage).unwrap(), amt(20));

        deposits.increase(&mut storage, other_key, amt(10)).unwrap();
        assert_eq!(deposits.get(&storage, key).unwrap(), amt(20));
        assert_eq!(deposits.get(&storage, other_key).unwrap(), amt(10));
        assert_eq!(deposits.get_total_deposits(&storage).unwrap(), amt(30));

        // other_key only holds 10, so withdrawing 15 must fail.
        assert!(deposits
            .decrease(&mut storage, other_key, amt(15))
            .is_err());

        // (key, amount withdrawn, expected total afterwards)
        let withdrawals = [(key, 15, 15), (key, 5, 10), (other_key, 10, 0)];
        for &(who, amount, expected_total) in &withdrawals {
            deposits.decrease(&mut storage, who, amt(amount)).unwrap();
            assert_eq!(
                deposits.get_total_deposits(&storage).unwrap(),
                amt(expected_total)
            );
        }
    }
}