use cosmwasm_schema::cw_serde;
use cosmwasm_std::{Uint256, Uint512};

use crate::traits::Zeroed;
5
/// Aggregate state of a share-based pool.
///
/// Accounts hold shares of the pool; an account's claim on the pool is
/// `account.shares * balance / shares` (see [`get_account_balance`]).
#[cw_serde]
#[derive(Copy, Default)]
pub struct Pool {
    /// Total balance held by the pool, shared pro rata between share holders.
    pub balance: Uint256,
    /// Total number of shares outstanding across all accounts.
    pub shares: Uint256,
}
13
14impl GetPoolMut for Pool {
15 fn get_pool_mut(&mut self) -> PoolMut {
16 PoolMut { balance: &mut self.balance, shares: &mut self.shares }
17 }
18}
19
20impl GetPoolRef for Pool {
21 fn get_pool_ref(&self) -> PoolRef {
22 PoolRef { balance: &self.balance, shares: &self.shares }
23 }
24}
25
/// Read-only view of a pool's `balance` and `shares`.
///
/// The view borrows the two fields from whatever value backs the pool (see [`GetPoolRef`]).
pub struct PoolRef<'a> {
    /// Borrowed total pool balance.
    pub balance: &'a Uint256,
    /// Borrowed total number of outstanding shares.
    pub shares: &'a Uint256,
}
31
/// Mutable view of a pool's `balance` and `shares`.
///
/// The pool functions in this module operate through this view. That lets them update any value
/// that can lend out these two fields (see [`GetPoolMut`]), not just a [`Pool`].
pub struct PoolMut<'a> {
    /// Mutably borrowed total pool balance.
    pub balance: &'a mut Uint256,
    /// Mutably borrowed total number of outstanding shares.
    pub shares: &'a mut Uint256,
}
37
38impl GetPoolMut for PoolMut<'_> {
39 fn get_pool_mut(&mut self) -> PoolMut {
40 PoolMut { balance: self.balance, shares: self.shares }
41 }
42}
43
44impl GetPoolRef for PoolMut<'_> {
45 fn get_pool_ref(&self) -> PoolRef {
46 PoolRef { balance: self.balance, shares: self.shares }
47 }
48}
49
50pub trait GetPoolMut {
51 fn get_pool_mut(&mut self) -> PoolMut;
52}
53
54pub trait GetPoolRef {
55 fn get_pool_ref(&self) -> PoolRef;
56}
57
58pub fn add_shares(pool: &mut dyn GetPoolMut, shares: Uint256, account: &mut PoolAccount) -> AddSharesResponse {
60 let pool_mut = pool.get_pool_mut();
61 let pool_balance = pool_mut.balance;
62 let pool_shares = pool_mut.shares;
63 let account_principal = &mut account.principal;
64 let account_shares = &mut account.shares;
65
66 let shares_to_issue = shares;
67 let balance_to_issue = shares_to_issue.multiply_ratio(*pool_balance, *pool_shares);
68
69 *account_shares += shares_to_issue;
70 *account_principal += balance_to_issue;
71
72 *pool_shares += shares_to_issue;
73 *pool_balance += balance_to_issue;
74
75 AddSharesResponse { balance_added: balance_to_issue }
76}
77
78pub fn add_amount(pool: &mut dyn GetPoolMut, amount: Uint256, account: &mut PoolAccount) -> AddAmountResponse {
80 let balance_to_issue = amount;
81
82 let pool_mut = pool.get_pool_mut();
83 let pool_balance = pool_mut.balance;
84 let pool_shares = pool_mut.shares;
85 let account_principal = &mut account.principal;
86 let account_shares = &mut account.shares;
87
88 let shares_to_issue = if pool_balance.is_zero() {
89 amount
90 } else {
91 amount.multiply_ratio(*pool_shares, *pool_balance)
92 };
93
94 *account_shares += shares_to_issue;
95 *account_principal += balance_to_issue;
96
97 *pool_shares += shares_to_issue;
98 *pool_balance += balance_to_issue;
99
100 AddAmountResponse { shares_added: shares_to_issue }
101}
102
103pub fn remove_shares(pool: &mut dyn GetPoolMut, shares: Uint256, account: &mut PoolAccount) -> RemoveSharesResponse {
105 let pool_mut = pool.get_pool_mut();
106 let pool_balance = pool_mut.balance;
107 let pool_shares = pool_mut.shares;
108 let account_principal = &mut account.principal;
109 let account_shares = &mut account.shares;
110
111 let shares_to_remove = if shares > *account_shares {
112 *account_shares
113 } else {
114 shares
115 };
116
117 let amount_to_remove = shares_to_remove.multiply_ratio(*pool_balance, *pool_shares);
118
119 *account_shares -= shares_to_remove;
120 *account_principal = account_principal.saturating_sub(amount_to_remove);
121
122 *pool_shares -= shares_to_remove;
123 *pool_balance -= amount_to_remove;
124
125 RemoveSharesResponse { balance_removed: amount_to_remove }
126}
127
128pub fn remove_amount(pool: &mut dyn GetPoolMut, amount: Uint256, account: &mut PoolAccount) -> RemoveAmountResponse {
130 let pool_mut = pool.get_pool_mut();
131 let pool_balance = pool_mut.balance;
132 let pool_shares = pool_mut.shares;
133 let account_principal = &mut account.principal;
134 let account_shares = &mut account.shares;
135
136 if pool_balance.is_zero() || pool_shares.is_zero() || account_shares.is_zero() {
137 return RemoveAmountResponse { amount_removed: Uint256::zero(), shares_removed: Uint256::zero() };
138 }
139
140 let amount_to_remove;
141 let shares_to_remove;
142 let account_amount = account_shares.multiply_ratio(*pool_balance, *pool_shares);
143 if amount > account_amount {
144 amount_to_remove = account_amount;
145 shares_to_remove = *account_shares;
146 } else {
147 amount_to_remove = amount;
148 shares_to_remove = account_shares.multiply_ratio(amount, account_amount);
149 }
150
151 *account_shares -= shares_to_remove;
152 *account_principal = account_principal.saturating_sub(amount_to_remove);
153
154 *pool_shares -= shares_to_remove;
155 *pool_balance -= amount_to_remove;
156
157 RemoveAmountResponse { amount_removed: amount_to_remove, shares_removed: shares_to_remove }
158}
159
160pub fn increase_balance(pool: &mut dyn GetPoolMut, amount: Uint256) {
162 let pool_mut = pool.get_pool_mut();
163 let pool_balance = pool_mut.balance;
164 *pool_balance += amount;
165}
166
167pub fn decrease_balance(pool: &mut dyn GetPoolMut, amount: Uint256) {
169 let pool_mut = pool.get_pool_mut();
170 let pool_balance = pool_mut.balance;
171 *pool_balance = pool_balance.saturating_sub(amount);
172}
173
174pub fn get_account_balance(pool: &dyn GetPoolRef, account: PoolAccount) -> Uint256 {
176 let pool_ref = pool.get_pool_ref();
177 let pool_balance = pool_ref.balance;
178 let pool_shares = pool_ref.shares;
179 account.shares.checked_multiply_ratio(*pool_balance, *pool_shares).unwrap_or_default()
180}
181
/// A single account's position in a [`Pool`].
#[cw_serde]
#[derive(Copy, Default)]
pub struct PoolAccount {
    /// Running total of balance credited by the `add_*` functions.
    /// The `remove_*` functions reduce it by the balance removed, saturating at zero.
    pub principal: Uint256,
    /// Number of pool shares the account owns; their current value is given by
    /// [`get_account_balance`].
    pub shares: Uint256,
}
188
/// Result of [`add_shares`].
pub struct AddSharesResponse {
    /// Balance credited to both the pool and the account for the newly issued shares.
    pub balance_added: Uint256,
}
192
/// Result of [`add_amount`].
pub struct AddAmountResponse {
    /// Number of shares minted to the account for the deposited amount.
    pub shares_added: Uint256,
}
196
/// Result of [`remove_shares`].
pub struct RemoveSharesResponse {
    /// Balance deducted from the pool for the shares that were burned.
    pub balance_removed: Uint256,
}
200
/// Result of [`remove_amount`].
pub struct RemoveAmountResponse {
    /// Balance actually withdrawn from the pool. It is capped at what the account's shares are
    /// worth, so it may be less than requested.
    pub amount_removed: Uint256,
    /// Number of shares burned from the account for this withdrawal.
    pub shares_removed: Uint256,
}
205
206impl Zeroed for PoolAccount {
207 fn is_zeroed(&self) -> bool {
208 self.shares.is_zero()
209 }
210
211 fn remove_zeroed(&mut self) {}
212}
213
#[cfg(test)]
mod test {
    use cosmwasm_std::Uint256;
    use rand::random;

    use super::*;

    /// Shorthand for building a `Uint256` from a `u64`.
    fn uint(value: u64) -> Uint256 {
        Uint256::from(value)
    }

    /// Randomised round trip: after depositing, and after extra balance is credited to the pool,
    /// withdrawing the account's full value leaves it worth nothing.
    #[test]
    fn test_add_and_remove() {
        for _ in 0..1000 {
            let start_pool_balance = uint(random::<u64>());
            let start_pool_shares = uint(random::<u64>());
            let amount = uint(random::<u64>());

            let mut pool = Pool { balance: start_pool_balance, shares: start_pool_shares };
            let mut account = PoolAccount::default();
            add_amount(&mut pool, amount, &mut account);

            // Credit extra balance to the pool between the deposit and the withdrawal.
            pool.balance += uint(random::<u64>());

            let balance = get_account_balance(&pool, account);
            let response = remove_amount(&mut pool, balance, &mut account);

            assert_eq!(response.amount_removed, balance);
            assert_eq!(
                get_account_balance(&pool, account),
                Uint256::zero(),
                "start_pool_balance: {start_pool_balance}, start_pool_shares:
{start_pool_shares}, amount: {amount}, account {account:#?}"
            );
        }
    }

    /// Walks two accounts through deposits, a balance increase and withdrawals, checking pool and
    /// account state after every step.
    #[test]
    fn pool_test() {
        let mut pool = Pool::default();
        let mut account1 = PoolAccount::default();
        let mut account2 = PoolAccount::default();

        // The first deposit into an empty pool mints shares 1:1.
        add_amount(&mut pool, uint(100), &mut account1);
        assert_eq!(pool.balance, uint(100));
        assert_eq!(pool.shares, uint(100));
        assert_eq!(account1.principal, uint(100));
        assert_eq!(account1.shares, uint(100));

        // Doubling the balance doubles account1's claim without touching shares or principal.
        increase_balance(&mut pool, uint(100));
        assert_eq!(pool.balance, uint(200));
        assert_eq!(pool.shares, uint(100));
        assert_eq!(account1.principal, uint(100));
        assert_eq!(account1.shares, uint(100));
        assert_eq!(get_account_balance(&pool, account1), uint(200));

        // At 2 balance per share, 50 shares cost 100.
        add_shares(&mut pool, uint(50), &mut account2);
        assert_eq!(pool.balance, uint(300));
        assert_eq!(pool.shares, uint(150));
        assert_eq!(account2.principal, uint(100));
        assert_eq!(account2.shares, uint(50));

        // Withdrawing account2's whole value burns all of its shares.
        remove_amount(&mut pool, uint(100), &mut account2);
        assert_eq!(pool.balance, uint(200));
        assert_eq!(pool.shares, uint(100));
        assert_eq!(account2.principal, uint(0));
        assert_eq!(account2.shares, uint(0));

        // Burning shares from an emptied account changes nothing.
        remove_shares(&mut pool, uint(100), &mut account2);
        assert_eq!(pool.balance, uint(200));
        assert_eq!(pool.shares, uint(100));
        assert_eq!(account2.principal, uint(0));
        assert_eq!(account2.shares, uint(0));
    }
}