1use crate::VaultError;
2use cosmwasm_std::{Deps, StdError, StdResult, Storage, Uint128};
3use cw_storage_plus::Item;
4
/// Virtual offset added to both the real share total and the real asset total
/// when deriving the "virtual" totals used for share/asset conversions.
///
/// Because the offset is non-zero, the virtual totals are never zero, so the
/// divisions in the conversion functions cannot divide by zero.
// NOTE(review): presumably this is the usual mitigation against share-inflation
// (donation) attacks, as the test names in this file suggest — confirm.
const OFFSET: Uint128 = Uint128::new(1);

/// Storage item holding the real total number of shares, kept under the
/// `"total_shares"` key.
const TOTAL_SHARES: Item<Uint128> = Item::new("total_shares");
19
20pub fn get_total_shares(storage: &dyn Storage) -> StdResult<Uint128> {
22 TOTAL_SHARES
23 .may_load(storage)
24 .map(|shares| shares.unwrap_or(Uint128::zero()))
25}
26
/// Share/asset totals together with their offset-adjusted ("virtual")
/// counterparts, which are what the conversion methods actually divide and
/// multiply by.
#[derive(Debug)]
pub struct VirtualOffset {
    /// Real total shares, exactly as supplied by the caller.
    total_shares: Uint128,
    /// Real total assets, exactly as supplied by the caller.
    total_assets: Uint128,
    /// `total_shares + OFFSET`.
    virtual_total_shares: Uint128,
    /// `total_assets + OFFSET`.
    virtual_total_assets: Uint128,
}
39
40impl VirtualOffset {
41 pub fn new(total_shares: Uint128, total_assets: Uint128) -> StdResult<Self> {
43 let virtual_total_shares = total_shares.checked_add(OFFSET).map_err(StdError::from)?;
44 let virtual_total_assets = total_assets.checked_add(OFFSET).map_err(StdError::from)?;
45
46 Ok(Self {
47 total_shares,
48 total_assets,
49 virtual_total_shares,
50 virtual_total_assets,
51 })
52 }
53
54 pub fn shares_to_assets(&self, shares: Uint128) -> StdResult<Uint128> {
56 shares
58 .checked_mul(self.virtual_total_assets)
59 .map_err(StdError::from)?
60 .checked_div(self.virtual_total_shares)
61 .map_err(StdError::from)
62 }
63
64 pub fn assets_to_shares(&self, assets: Uint128) -> StdResult<Uint128> {
66 assets
68 .checked_mul(self.virtual_total_shares)
69 .map_err(StdError::from)?
70 .checked_div(self.virtual_total_assets)
71 .map_err(StdError::from)
72 }
73
74 pub fn total_shares(&self) -> Uint128 {
76 self.total_shares
77 }
78
79 pub fn total_assets(&self) -> Uint128 {
81 self.total_assets
82 }
83}
84
/// Wrapper around [`VirtualOffset`] whose share total is read from contract
/// storage on `load` and written back to storage whenever shares are added or
/// subtracted.
#[derive(Debug)]
pub struct TotalShares(VirtualOffset);
93
94impl TotalShares {
95 pub fn load(deps: &Deps, total_assets: Uint128) -> StdResult<Self> {
100 let total_shares = get_total_shares(deps.storage)?;
101 let offset = VirtualOffset::new(total_shares, total_assets)?;
102 Ok(Self(offset))
103 }
104
105 pub fn shares_to_assets(&self, shares: Uint128) -> StdResult<Uint128> {
107 self.0.shares_to_assets(shares)
108 }
109
110 pub fn assets_to_shares(&self, assets: Uint128) -> StdResult<Uint128> {
112 self.0.assets_to_shares(assets)
113 }
114
115 pub fn total_shares(&self) -> Uint128 {
117 self.0.total_shares
118 }
119
120 pub fn total_assets(&self) -> Uint128 {
122 self.0.total_assets
123 }
124
125 pub fn checked_add_shares(
131 &mut self,
132 storage: &mut dyn Storage,
133 shares: Uint128,
134 ) -> Result<(), VaultError> {
135 if shares.is_zero() {
136 return Err(VaultError::zero("Add shares cannot be zero"));
137 }
138
139 self.0.total_shares = self
140 .0
141 .total_shares
142 .checked_add(shares)
143 .map_err(StdError::from)?;
144 self.0.virtual_total_shares = self
145 .0
146 .total_shares
147 .checked_add(OFFSET)
148 .map_err(StdError::from)?;
149 TOTAL_SHARES.save(storage, &self.0.total_shares)?;
150 Ok(())
151 }
152
153 pub fn checked_sub_shares(
155 &mut self,
156 storage: &mut dyn Storage,
157 shares: Uint128,
158 ) -> Result<(), VaultError> {
159 if shares.is_zero() {
160 return Err(VaultError::zero("Sub shares cannot be zero"));
161 }
162
163 self.0.total_shares = self
164 .0
165 .total_shares
166 .checked_sub(shares)
167 .map_err(StdError::from)?;
168 self.0.virtual_total_shares = self
169 .0
170 .total_shares
171 .checked_add(OFFSET)
172 .map_err(StdError::from)?;
173 TOTAL_SHARES.save(storage, &self.0.total_shares)?;
174 Ok(())
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a raw integer in `Uint128`.
    fn u(raw: u128) -> Uint128 {
        Uint128::new(raw)
    }

    /// Builds a `VirtualOffset` from raw totals, panicking if construction fails.
    fn new_vault(total_shares: u128, total_assets: u128) -> VirtualOffset {
        VirtualOffset::new(u(total_shares), u(total_assets)).unwrap()
    }

    /// `shares_to_assets` on raw input, panicking on error.
    fn assets_for(vault: &VirtualOffset, shares: u128) -> Uint128 {
        vault.shares_to_assets(u(shares)).unwrap()
    }

    /// `assets_to_shares` on raw input, panicking on error.
    fn shares_for(vault: &VirtualOffset, assets: u128) -> Uint128 {
        vault.assets_to_shares(u(assets)).unwrap()
    }

    #[test]
    fn one_to_one() {
        // Equal totals give equal virtual totals (1001 / 1001), so every
        // conversion is the identity.
        let vault = new_vault(1000, 1000);

        for value in [1000, 100, 10000] {
            assert_eq!(assets_for(&vault, value), u(value));
            assert_eq!(shares_for(&vault, value), u(value));
        }
    }

    #[test]
    fn inflation_attack_over_1() {
        let attacker_donation: u128 = 99_999;

        // One real share backed by 100_000 assets: virtual totals are
        // 100_001 assets / 2 shares.
        let vault = new_vault(1, 1 + attacker_donation);

        assert_eq!(assets_for(&vault, 1), u(50_000));
        // 10_000 * 2 / 100_001 rounds down to nothing.
        assert_eq!(shares_for(&vault, 10_000), u(0));

        // (deposit, shares minted, total shares afterwards, balance afterwards,
        //  value of one share afterwards, value of the minted shares afterwards)
        let cases = [
            (50_001, 1, 2, 150_001, 50_000, 50_000),
            (100_000, 1, 2, 150_000, 50_000, 50_000),
            (100_001, 2, 3, 150_001, 37_500, 75_001),
        ];

        for (deposit, minted, shares_after, balance_after, one_share, redeemable) in cases {
            let shares = shares_for(&vault, deposit);
            assert_eq!(shares, u(minted));

            let after = new_vault(shares_after, balance_after);
            assert_eq!(assets_for(&after, 1), u(one_share));
            assert_eq!(after.shares_to_assets(shares).unwrap(), u(redeemable));
        }
    }

    #[test]
    fn imbalance_1000_to_1() {
        // Virtual totals: 1001 assets / 2 shares.
        let vault = new_vault(1, 1000);

        assert_eq!(assets_for(&vault, 500), u(250_250));
        assert_eq!(shares_for(&vault, 250), u(0));

        assert_eq!(assets_for(&vault, 10_000), u(5_005_000));
        assert_eq!(shares_for(&vault, 10_000_000), u(19_980));
    }

    #[test]
    fn imbalance_1000_to_2() {
        // Virtual totals: 1001 assets / 3 shares.
        let vault = new_vault(2, 1000);

        assert_eq!(assets_for(&vault, 1000), u(333_666));
        assert_eq!(shares_for(&vault, 1), u(0));
        assert_eq!(shares_for(&vault, 10), u(0));

        assert_eq!(assets_for(&vault, 100_444), u(33_514_814));
        assert_eq!(shares_for(&vault, 10_000_000), u(29_970));
    }

    #[test]
    fn shares_imbalance_100_000_to_1() {
        // Virtual totals: 100_001 assets / 2 shares.
        let vault = new_vault(1, 100_000);

        let expectations = [(500, 25_000_250), (1, 50_000), (10_000, 500_005_000)];
        for (shares, assets) in expectations {
            assert_eq!(assets_for(&vault, shares), u(assets));
        }
    }

    #[test]
    fn amount_imbalance_100_000_to_1() {
        // Virtual totals: 100_001 assets / 2 shares.
        let vault = new_vault(1, 100_000);

        let expectations = [(1, 0), (100, 0), (50_001, 1)];
        for (assets, shares) in expectations {
            assert_eq!(shares_for(&vault, assets), u(shares));
        }
    }

    #[test]
    fn extreme_inflation_1e20_to_1() {
        let e20 = 10u128.pow(20);
        let vault = new_vault(1, e20);

        // Deposits far below the per-share value round down to zero shares.
        assert_eq!(shares_for(&vault, 999), u(0));
        assert_eq!(shares_for(&vault, 1_000_000), u(0));

        // Matching the whole balance yields exactly one share.
        assert_eq!(shares_for(&vault, e20), u(1));

        // After such a deposit (2 shares, 2e20 assets) a single share
        // converts to less than the 1e20 that was paid for it.
        let after = new_vault(1 + 1, e20 + e20);
        assert!(assets_for(&after, 1) < u(e20));
    }

    #[test]
    fn overflow() {
        // Adding the offset to `u128::MAX` cannot be represented.
        let max = u(u128::MAX);
        let error = VirtualOffset::new(max, max).unwrap_err();
        assert_eq!(
            error.to_string(),
            "Overflow: Cannot Add with given operands"
        );

        // With totals of `u128::MAX / 1e10`, multiplying by anything below
        // 1e10 still fits, while 1e10 itself overflows the product.
        let e10 = 10u128.pow(10);
        let total = u128::MAX / e10;
        let vault = new_vault(total, total);

        for amount in [1, 10u128.pow(9), e10 - 1] {
            assets_for(&vault, amount);
            shares_for(&vault, amount);
        }

        let error = vault.shares_to_assets(u(e10)).unwrap_err();
        assert_eq!(
            error.to_string(),
            "Overflow: Cannot Mul with given operands"
        );

        let error = vault.assets_to_shares(u(e10)).unwrap_err();
        assert_eq!(
            error.to_string(),
            "Overflow: Cannot Mul with given operands"
        );
    }
}