neptune_common/
pool.rs

1use cosmwasm_schema::cw_serde;
2use cosmwasm_std::Uint256;
3
4use crate::traits::Zeroed;
5
/// This data type helps to keep track of pooling together assets between multiple accounts.
///
/// `balance` is the pool's total holding and `shares` is the total number of shares issued
/// against it. An account's claim on the pool is its share count scaled by `balance / shares`
/// (see `get_account_balance`).
#[cw_serde]
#[derive(Copy, Default)]
pub struct Pool {
    /// Total amount held by the pool across all accounts.
    pub balance: Uint256,
    /// Total number of shares outstanding across all accounts.
    pub shares: Uint256,
}
13
14impl GetPoolMut for Pool {
15    fn get_pool_mut(&mut self) -> PoolMut {
16        PoolMut { balance: &mut self.balance, shares: &mut self.shares }
17    }
18}
19
20impl GetPoolRef for Pool {
21    fn get_pool_ref(&self) -> PoolRef {
22        PoolRef { balance: &self.balance, shares: &self.shares }
23    }
24}
25
/// This serves the same purpose as Pool, but can be constructed directly from immutable references.
///
/// A read-only view of a pool's totals, produced by `GetPoolRef::get_pool_ref` and consumed by
/// `get_account_balance`.
pub struct PoolRef<'a> {
    /// Shared handle to the pool's total balance.
    pub balance: &'a Uint256,
    /// Shared handle to the pool's total share count.
    pub shares: &'a Uint256,
}
31
/// This serves the same purpose as Pool, but can be constructed directly from mutable references.
///
/// The mutating functions in this module (`add_shares`, `add_amount`, `remove_shares`,
/// `remove_amount`, `increase_balance`, `decrease_balance`) all update the pool through this
/// view, obtained from `GetPoolMut::get_pool_mut`.
pub struct PoolMut<'a> {
    /// Mutable handle to the pool's total balance.
    pub balance: &'a mut Uint256,
    /// Mutable handle to the pool's total share count.
    pub shares: &'a mut Uint256,
}
37
38impl GetPoolMut for PoolMut<'_> {
39    fn get_pool_mut(&mut self) -> PoolMut {
40        PoolMut { balance: self.balance, shares: self.shares }
41    }
42}
43
44impl GetPoolRef for PoolMut<'_> {
45    fn get_pool_ref(&self) -> PoolRef {
46        PoolRef { balance: self.balance, shares: self.shares }
47    }
48}
49
/// Types that can expose a pool's `balance` and `shares` as mutable references.
///
/// This lets the free functions in this module operate on any type that stores those two
/// totals, not only on [`Pool`].
pub trait GetPoolMut {
    /// Returns a [`PoolMut`] whose references borrow from `self`.
    fn get_pool_mut(&mut self) -> PoolMut;
}
53
/// Types that can expose a pool's `balance` and `shares` as shared references.
///
/// Used by read-only queries such as `get_account_balance`.
pub trait GetPoolRef {
    /// Returns a [`PoolRef`] whose references borrow from `self`.
    fn get_pool_ref(&self) -> PoolRef;
}
57
58/// Adds shares to an account and calculates the corresponding balance.
59pub fn add_shares(pool: &mut dyn GetPoolMut, shares: Uint256, account: &mut PoolAccount) -> AddSharesResponse {
60    let pool_mut = pool.get_pool_mut();
61    let pool_balance = pool_mut.balance;
62    let pool_shares = pool_mut.shares;
63    let account_principal = &mut account.principal;
64    let account_shares = &mut account.shares;
65
66    let shares_to_issue = shares;
67    let balance_to_issue = shares_to_issue.multiply_ratio(*pool_balance, *pool_shares);
68
69    *account_shares += shares_to_issue;
70    *account_principal += balance_to_issue;
71
72    *pool_shares += shares_to_issue;
73    *pool_balance += balance_to_issue;
74
75    AddSharesResponse { balance_added: balance_to_issue }
76}
77
78/// Adds a balance to an account and calculates the corresponding shares to issue.
79pub fn add_amount(pool: &mut dyn GetPoolMut, amount: Uint256, account: &mut PoolAccount) -> AddAmountResponse {
80    let balance_to_issue = amount;
81
82    let pool_mut = pool.get_pool_mut();
83    let pool_balance = pool_mut.balance;
84    let pool_shares = pool_mut.shares;
85    let account_principal = &mut account.principal;
86    let account_shares = &mut account.shares;
87
88    let shares_to_issue = if pool_balance.is_zero() {
89        amount
90    } else {
91        amount.multiply_ratio(*pool_shares, *pool_balance)
92    };
93
94    *account_shares += shares_to_issue;
95    *account_principal += balance_to_issue;
96
97    *pool_shares += shares_to_issue;
98    *pool_balance += balance_to_issue;
99
100    AddAmountResponse { shares_added: shares_to_issue }
101}
102
103/// Removes shares from an account and calculates the corresponding balance to return.
104pub fn remove_shares(pool: &mut dyn GetPoolMut, shares: Uint256, account: &mut PoolAccount) -> RemoveSharesResponse {
105    let pool_mut = pool.get_pool_mut();
106    let pool_balance = pool_mut.balance;
107    let pool_shares = pool_mut.shares;
108    let account_principal = &mut account.principal;
109    let account_shares = &mut account.shares;
110
111    let shares_to_remove = if shares > *account_shares {
112        *account_shares
113    } else {
114        shares
115    };
116
117    let amount_to_remove = shares_to_remove.multiply_ratio(*pool_balance, *pool_shares);
118
119    *account_shares -= shares_to_remove;
120    *account_principal = account_principal.saturating_sub(amount_to_remove);
121
122    *pool_shares -= shares_to_remove;
123    *pool_balance -= amount_to_remove;
124
125    RemoveSharesResponse { balance_removed: amount_to_remove }
126}
127
128/// Removes a balance from an account and calculates the corresponding shares to return.
129pub fn remove_amount(pool: &mut dyn GetPoolMut, amount: Uint256, account: &mut PoolAccount) -> RemoveAmountResponse {
130    let pool_mut = pool.get_pool_mut();
131    let pool_balance = pool_mut.balance;
132    let pool_shares = pool_mut.shares;
133    let account_principal = &mut account.principal;
134    let account_shares = &mut account.shares;
135
136    if pool_balance.is_zero() || pool_shares.is_zero() || account_shares.is_zero() {
137        return RemoveAmountResponse { amount_removed: Uint256::zero(), shares_removed: Uint256::zero() };
138    }
139
140    let amount_to_remove;
141    let shares_to_remove;
142    let account_amount = account_shares.multiply_ratio(*pool_balance, *pool_shares);
143    if amount > account_amount {
144        amount_to_remove = account_amount;
145        shares_to_remove = *account_shares;
146    } else {
147        amount_to_remove = amount;
148        shares_to_remove = account_shares.multiply_ratio(amount, account_amount);
149    }
150
151    *account_shares -= shares_to_remove;
152    *account_principal = account_principal.saturating_sub(amount_to_remove);
153
154    *pool_shares -= shares_to_remove;
155    *pool_balance -= amount_to_remove;
156
157    RemoveAmountResponse { amount_removed: amount_to_remove, shares_removed: shares_to_remove }
158}
159
160/// Increases the balance of the pool by the amount specified.
161pub fn increase_balance(pool: &mut dyn GetPoolMut, amount: Uint256) {
162    let pool_mut = pool.get_pool_mut();
163    let pool_balance = pool_mut.balance;
164    *pool_balance += amount;
165}
166
167/// Decreases the balance of the pool by the amount specified.
168pub fn decrease_balance(pool: &mut dyn GetPoolMut, amount: Uint256) {
169    let pool_mut = pool.get_pool_mut();
170    let pool_balance = pool_mut.balance;
171    *pool_balance = pool_balance.saturating_sub(amount);
172}
173
174/// Returns the balance of an account
175pub fn get_account_balance(pool: &dyn GetPoolRef, account: PoolAccount) -> Uint256 {
176    let pool_ref = pool.get_pool_ref();
177    let pool_balance = pool_ref.balance;
178    let pool_shares = pool_ref.shares;
179    account.shares.checked_multiply_ratio(*pool_balance, *pool_shares).unwrap_or_default()
180}
181
/// One account's position in a [`Pool`].
#[cw_serde]
#[derive(Copy, Default)]
pub struct PoolAccount {
    /// Running total of balance credited to this account. It grows by the balance issued in
    /// `add_shares`/`add_amount` and shrinks (saturating at zero) by the balance removed in
    /// `remove_shares`/`remove_amount`. It is not touched by `increase_balance` or
    /// `decrease_balance`, so it can differ from the account's current value as reported by
    /// `get_account_balance`.
    pub principal: Uint256,
    /// Number of pool shares held by this account.
    pub shares: Uint256,
}
188
189pub struct AddSharesResponse {
190    pub balance_added: Uint256,
191}
192
193pub struct AddAmountResponse {
194    pub shares_added: Uint256,
195}
196
197pub struct RemoveSharesResponse {
198    pub balance_removed: Uint256,
199}
200
201pub struct RemoveAmountResponse {
202    pub amount_removed: Uint256,
203    pub shares_removed: Uint256,
204}
205
206impl Zeroed for PoolAccount {
207    fn is_zeroed(&self) -> bool {
208        self.shares.is_zero()
209    }
210
211    fn remove_zeroed(&mut self) {}
212}
213
#[cfg(test)]
mod test {
    use cosmwasm_std::Uint256;
    use rand::random;

    use super::*;

    /// Deposit, let the pool grow, then withdraw the account's full reported balance. The
    /// withdrawal must return exactly that balance and leave the account with nothing.
    #[test]
    fn test_add_and_remove() {
        for _ in 0..1000 {
            let start_pool_balance = Uint256::from(random::<u64>());
            let start_pool_shares = Uint256::from(random::<u64>());
            let amount = Uint256::from(random::<u64>());

            let mut account = PoolAccount::default();
            let mut pool = Pool { balance: start_pool_balance, shares: start_pool_shares };

            add_amount(&mut pool, amount, &mut account);
            // Simulate yield accruing to the pool between deposit and withdrawal.
            pool.balance += Uint256::from(random::<u64>());

            let balance = get_account_balance(&pool, account);
            let removed = remove_amount(&mut pool, balance, &mut account);

            assert_eq!(removed.amount_removed, balance);
            assert_eq!(
                get_account_balance(&pool, account),
                Uint256::zero(),
                "start_pool_balance: {start_pool_balance}, start_pool_shares: {start_pool_shares}, \
                 amount: {amount}, account {account:#?}"
            );
        }
    }

    /// Walks two accounts through deposit, yield, share purchase and withdrawal, checking the
    /// pool and account totals after each step.
    #[test]
    fn pool_test() {
        let mut pool: Pool = Pool::default();
        let mut account1: PoolAccount = PoolAccount::default();
        let mut account2: PoolAccount = PoolAccount::default();

        // First deposit into an empty pool is issued 1:1.
        add_amount(&mut pool, Uint256::from(100u64), &mut account1);
        assert_eq!(pool.balance, Uint256::from(100u64));
        assert_eq!(pool.shares, Uint256::from(100u64));
        assert_eq!(account1.principal, Uint256::from(100u64));
        assert_eq!(account1.shares, Uint256::from(100u64));

        // Yield doubles the value of each share without touching principal.
        increase_balance(&mut pool, Uint256::from(100u64));
        assert_eq!(pool.balance, Uint256::from(200u64));
        assert_eq!(pool.shares, Uint256::from(100u64));
        assert_eq!(account1.principal, Uint256::from(100u64));
        assert_eq!(account1.shares, Uint256::from(100u64));
        assert_eq!(get_account_balance(&pool, account1), Uint256::from(200u64));

        // 50 shares now cost 100.
        add_shares(&mut pool, Uint256::from(50u64), &mut account2);
        assert_eq!(pool.balance, Uint256::from(300u64));
        assert_eq!(pool.shares, Uint256::from(150u64));
        assert_eq!(account2.principal, Uint256::from(100u64));
        assert_eq!(account2.shares, Uint256::from(50u64));

        // Withdrawing account2's entire value burns all of its shares.
        remove_amount(&mut pool, Uint256::from(100u64), &mut account2);
        assert_eq!(pool.balance, Uint256::from(200u64));
        assert_eq!(pool.shares, Uint256::from(100u64));
        assert_eq!(account2.principal, Uint256::from(0u64));
        assert_eq!(account2.shares, Uint256::from(0u64));

        // Removing shares from an empty account is a no-op.
        remove_shares(&mut pool, Uint256::from(100u64), &mut account2);
        assert_eq!(pool.balance, Uint256::from(200u64));
        assert_eq!(pool.shares, Uint256::from(100u64));
        assert_eq!(account2.principal, Uint256::from(0u64));
        assert_eq!(account2.shares, Uint256::from(0u64));
    }
}