bvs_vault_base/
offset.rs

1use crate::VaultError;
2use cosmwasm_std::{Deps, StdError, StdResult, Storage, Uint128};
3use cw_storage_plus::Item;
4
5/// The offset is used to mitigate the common 'share inflation' attack vector.
6///
7/// See [https://docs.openzeppelin.com/contracts/5.x/erc4626#inflation-attack]
8///
9/// This 1 offset will be used in exchange rate computation to reduce the impact of the attack.
10/// When the vault is empty, the virtual shares and virtual assets enforce the conversion rate 1/1.
11///
12/// Share inflation attack will not be as prevalent because of async withdrawal,
13/// hence offset of 1 will be enough to mitigate the attack.
14const OFFSET: Uint128 = Uint128::new(1);
15
16/// The total shares of the contract held by all stakers.
17/// [`OFFSET`] value is not included in the total shares, only the real shares are counted.
18const TOTAL_SHARES: Item<Uint128> = Item::new("total_shares");
19
20/// Get the total shares in circulation
21pub fn get_total_shares(storage: &dyn Storage) -> StdResult<Uint128> {
22    TOTAL_SHARES
23        .may_load(storage)
24        .map(|shares| shares.unwrap_or(Uint128::zero()))
25}
26
27/// Follows the OpenZeppelin's ERC4626 mitigation strategy for inflation attack.
28/// Using a "virtual" offset to +1 to both total shares and assets representing the virtual total shares and virtual total assets.
29/// A donation of 1 and under will be completely captured by the vault—without affecting the user.
30/// A donation greater than 1, the attacker will suffer loss greater than the user.
31/// [https://github.com/OpenZeppelin/openzeppelin-contracts/blob/fa995ef1fe66e1447783cb6038470aba23a6343f/contracts/token/ERC20/extensions/ERC4626.sol#L30-L37]
32#[derive(Debug)]
33pub struct VirtualOffset {
34    total_shares: Uint128,
35    total_assets: Uint128,
36    virtual_total_shares: Uint128,
37    virtual_total_assets: Uint128,
38}
39
40impl VirtualOffset {
41    /// Create a new [VirtualOffset] with the given total shares and total assets.
42    pub fn new(total_shares: Uint128, total_assets: Uint128) -> StdResult<Self> {
43        let virtual_total_shares = total_shares.checked_add(OFFSET).map_err(StdError::from)?;
44        let virtual_total_assets = total_assets.checked_add(OFFSET).map_err(StdError::from)?;
45
46        Ok(Self {
47            total_shares,
48            total_assets,
49            virtual_total_shares,
50            virtual_total_assets,
51        })
52    }
53
54    /// Shares to underlying assets
55    pub fn shares_to_assets(&self, shares: Uint128) -> StdResult<Uint128> {
56        // (shares * self.virtual_total_assets) / self.virtual_total_shares
57        shares
58            .checked_mul(self.virtual_total_assets)
59            .map_err(StdError::from)?
60            .checked_div(self.virtual_total_shares)
61            .map_err(StdError::from)
62    }
63
64    /// Underlying assets to shares
65    pub fn assets_to_shares(&self, assets: Uint128) -> StdResult<Uint128> {
66        // (assets * self.virtual_total_shares) / self.virtual_total_assets
67        assets
68            .checked_mul(self.virtual_total_shares)
69            .map_err(StdError::from)?
70            .checked_div(self.virtual_total_assets)
71            .map_err(StdError::from)
72    }
73
74    /// Get the total shares in circulation
75    pub fn total_shares(&self) -> Uint128 {
76        self.total_shares
77    }
78
79    /// Get the total assets under management
80    pub fn total_assets(&self) -> Uint128 {
81        self.total_assets
82    }
83}
84
85/// This struct wraps the [VirtualOffset] struct with [TOTAL_SHARES] storage features
86/// `checked_add_shares` and `checked_sub_shares` implemented.
87/// Other methods are mapped to the underlying [VirtualOffset] instance.
88///
89/// [TotalShares] is only used to account for the total shares (and total assets).
90/// Individual staker shares are stored here to allow for different staking strategies (e.g., Tokenized Vault).
91#[derive(Debug)]
92pub struct TotalShares(VirtualOffset);
93
94impl TotalShares {
95    /// Load the virtual total shares from storage (supports rebasing, by default).
96    /// A fixed [`OFFSET`] of 1 will be added to both total shares and total assets
97    /// to mitigate against inflation attack.
98    /// Use [shares_to_assets] and [assets_to_shares] to convert between shares and assets.
99    pub fn load(deps: &Deps, total_assets: Uint128) -> StdResult<Self> {
100        let total_shares = get_total_shares(deps.storage)?;
101        let offset = VirtualOffset::new(total_shares, total_assets)?;
102        Ok(Self(offset))
103    }
104
105    /// Shares to underlying assets
106    pub fn shares_to_assets(&self, shares: Uint128) -> StdResult<Uint128> {
107        self.0.shares_to_assets(shares)
108    }
109
110    /// Underlying assets to shares
111    pub fn assets_to_shares(&self, assets: Uint128) -> StdResult<Uint128> {
112        self.0.assets_to_shares(assets)
113    }
114
115    /// Get the total shares in circulation
116    pub fn total_shares(&self) -> Uint128 {
117        self.0.total_shares
118    }
119
120    /// Get the total assets under management
121    pub fn total_assets(&self) -> Uint128 {
122        self.0.total_assets
123    }
124
125    /// Add the new shares to the total shares and refresh the virtual shares and virtual assets.
126    /// This method is checked:
127    ///  - New shares cannot be zero.
128    ///  - Total shares cannot overflow.
129    ///  - Virtual shares cannot overflow.
130    pub fn checked_add_shares(
131        &mut self,
132        storage: &mut dyn Storage,
133        shares: Uint128,
134    ) -> Result<(), VaultError> {
135        if shares.is_zero() {
136            return Err(VaultError::zero("Add shares cannot be zero"));
137        }
138
139        self.0.total_shares = self
140            .0
141            .total_shares
142            .checked_add(shares)
143            .map_err(StdError::from)?;
144        self.0.virtual_total_shares = self
145            .0
146            .total_shares
147            .checked_add(OFFSET)
148            .map_err(StdError::from)?;
149        TOTAL_SHARES.save(storage, &self.0.total_shares)?;
150        Ok(())
151    }
152
153    /// Subtract the shares from the total shares and refresh the virtual shares and virtual assets.
154    pub fn checked_sub_shares(
155        &mut self,
156        storage: &mut dyn Storage,
157        shares: Uint128,
158    ) -> Result<(), VaultError> {
159        if shares.is_zero() {
160            return Err(VaultError::zero("Sub shares cannot be zero"));
161        }
162
163        self.0.total_shares = self
164            .0
165            .total_shares
166            .checked_sub(shares)
167            .map_err(StdError::from)?;
168        self.0.virtual_total_shares = self
169            .0
170            .total_shares
171            .checked_add(OFFSET)
172            .map_err(StdError::from)?;
173        TOTAL_SHARES.save(storage, &self.0.total_shares)?;
174        Ok(())
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn one_to_one() {
        let total_assets = Uint128::new(1000);
        let total_shares = Uint128::new(1000);
        // Virtual balance: 1000 + 1 = 1001; virtual shares: 1000 + 1 = 1001 (rate 1:1).
        let vault = VirtualOffset::new(total_shares, total_assets).unwrap();

        for value in [Uint128::new(1000), Uint128::new(100), Uint128::new(10_000)] {
            let assets = vault.shares_to_assets(value).unwrap();
            assert_eq!(assets, value);

            let shares = vault.assets_to_shares(value).unwrap();
            assert_eq!(shares, value);
        }
    }

    #[test]
    fn inflation_attack_over_1() {
        // Attacker deposits 1 to get 1 share,
        // then donates 99,999 moving the balance to 100,000.
        let attacker_donation = Uint128::new(99_999);

        let balance = Uint128::new(1) + attacker_donation;
        let total_shares = Uint128::new(1);
        // Virtual balance: 100,000 + 1 = 100,001; virtual shares: 1 + 1 = 2.
        let vault = VirtualOffset::new(total_shares, balance).unwrap();

        // Attacker's 1 share is worth 100,001 / 2 = 50,000 (the rest is captured by the vault).
        let amount = vault.shares_to_assets(Uint128::new(1)).unwrap();
        assert_eq!(amount, Uint128::new(50_000));

        // Normal user depositing 10,000 would get 0 shares (not executed).
        let amount = Uint128::new(10_000);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        {
            // Normal user deposits 50,001 to get 1 share: 50,001 * 2 / 100,001 = 1
            // (anything below 50,001 receives 0 shares).
            let amount = Uint128::new(50_001);
            let shares = vault.assets_to_shares(amount).unwrap();
            assert_eq!(shares, Uint128::new(1));

            // Vault after the deposit: 100,000 + 50,001 assets, 1 + 1 shares.
            let balance = balance + amount;
            assert_eq!(balance, Uint128::new(150_001));
            let vault = VirtualOffset::new(total_shares + shares, balance).unwrap();

            // 1 share = 150,002 / 3 = 50,000.
            // Attacker's 1 share is worth 50,000 - attacker lost 50% of 100,000.
            let amount = vault.shares_to_assets(Uint128::new(1)).unwrap();
            assert_eq!(amount, Uint128::new(50_000));

            // User's 1 share is worth 50,000 - user lost 1 of 50,001.
            let amount = vault.shares_to_assets(shares).unwrap();
            assert_eq!(amount, Uint128::new(50_000));
        }
        {
            // Normal user deposits 100,000 to get 1 share: 100,000 * 2 / 100,001 = 1.
            let amount = Uint128::new(100_000);
            let shares = vault.assets_to_shares(amount).unwrap();
            assert_eq!(shares, Uint128::new(1));

            // Vault after the deposit: 100,000 + 100,000 assets, 1 + 1 shares.
            let balance = balance + amount;
            assert_eq!(balance, Uint128::new(200_000));
            let vault = VirtualOffset::new(total_shares + shares, balance).unwrap();

            // 1 share = 200,001 / 3 = 66,667.
            // Attacker's 1 share is worth 66,667 - attacker lost ~33% of 100,000.
            let amount = vault.shares_to_assets(Uint128::new(1)).unwrap();
            assert_eq!(amount, Uint128::new(66_667));

            // User's 1 share is worth 66,667 - user lost ~33% of 100,000.
            let amount = vault.shares_to_assets(shares).unwrap();
            assert_eq!(amount, Uint128::new(66_667));
        }
        {
            // Normal user deposits 100,001 to get 2 shares: 100,001 * 2 / 100,001 = 2.
            let amount = Uint128::new(100_001);
            let shares = vault.assets_to_shares(amount).unwrap();
            assert_eq!(shares, Uint128::new(2));

            // Vault after the deposit: 100,000 + 100,001 assets, 1 + 2 shares.
            let balance = balance + amount;
            assert_eq!(balance, Uint128::new(200_001));
            let vault = VirtualOffset::new(total_shares + shares, balance).unwrap();

            // 1 share = 200,002 / 4 = 50,000.5, rounded down to 50,000.
            // Attacker's 1 share is worth 50,000 - attacker lost 50% of 100,000.
            let amount = vault.shares_to_assets(Uint128::new(1)).unwrap();
            assert_eq!(amount, Uint128::new(50_000));

            // User's 2 shares are worth 2 * 200,002 / 4 = 100,001 - user lost nothing.
            let amount = vault.shares_to_assets(shares).unwrap();
            assert_eq!(amount, Uint128::new(100_001));
        }
    }

    #[test]
    fn imbalance_1000_to_1() {
        let balance = Uint128::new(1000);
        let total_shares = Uint128::new(1);

        // Virtual balance: (1000) + 1 = 1001
        // Virtual shares: (1) + 1 = 2
        let vault = VirtualOffset::new(total_shares, balance).unwrap();

        // Low amounts
        {
            let shares = Uint128::new(500);
            let amount = vault.shares_to_assets(shares).unwrap();
            // Amount: (500) * 1001 / 2 = 250,250
            assert_eq!(amount, Uint128::new(250_250));

            let amount = Uint128::new(250);
            let shares = vault.assets_to_shares(amount).unwrap();
            // Shares: (250) * 2 / 1001 = 0.499
            assert_eq!(shares, Uint128::new(0));
        }

        // High amounts
        {
            let shares = Uint128::new(10_000);
            let amount = vault.shares_to_assets(shares).unwrap();
            // Amount: (10,000) * 1001 / 2 = 5,005,000
            assert_eq!(amount, Uint128::new(5_005_000));

            let amount = Uint128::new(10_000_000);
            let shares = vault.assets_to_shares(amount).unwrap();
            // Shares: (10,000,000) * 2 / 1001 = 19,980.01998002
            assert_eq!(shares, Uint128::new(19_980));
        }
    }

    #[test]
    fn imbalance_1000_to_2() {
        let balance = Uint128::new(1000);
        let total_shares = Uint128::new(2);

        // Virtual balance: (1000) + 1 = 1001
        // Virtual shares: (2) + 1 = 3
        let vault = VirtualOffset::new(total_shares, balance).unwrap();

        // Low amounts
        {
            let shares = Uint128::new(1000);
            let amount = vault.shares_to_assets(shares).unwrap();
            // Amount: (1000) * 1001 / 3 = 333,666.67
            assert_eq!(amount, Uint128::new(333_666));

            let amount = Uint128::new(1);
            let shares = vault.assets_to_shares(amount).unwrap();
            // Shares: (1) * 3 / 1001 = 0.003
            assert_eq!(shares, Uint128::new(0));

            let amount = Uint128::new(10);
            let shares = vault.assets_to_shares(amount).unwrap();
            // Shares: (10) * 3 / 1001 = 0.03
            assert_eq!(shares, Uint128::new(0));
        }

        // High amounts
        {
            let shares = Uint128::new(100_444);
            let amount = vault.shares_to_assets(shares).unwrap();
            // Amount: (100,444) * 1001 / 3 = 33,514,814.67
            assert_eq!(amount, Uint128::new(33_514_814));

            let amount = Uint128::new(10_000_000);
            let shares = vault.assets_to_shares(amount).unwrap();
            // Shares: (10,000,000) * 3 / 1001 = 29,970.03
            assert_eq!(shares, Uint128::new(29_970));
        }
    }

    /// This is 100_000x over the offset amount
    #[test]
    fn shares_imbalance_100_000_to_1() {
        let balance = Uint128::new(100_000);
        let total_shares = Uint128::new(1);

        // Virtual balance: (100,000) + 1 = 100,001
        // Virtual shares: (1) + 1 = 2
        let vault = VirtualOffset::new(total_shares, balance).unwrap();

        // With 500 shares, they get 25,000,250
        // Amount: (500) * 100,001 / 2 = 25,000,250
        let shares = Uint128::new(500);
        let amount = vault.shares_to_assets(shares).unwrap();
        assert_eq!(amount, Uint128::new(25_000_250));

        // With 1 share, they get 50,000
        // Amount: (1) * 100,001 / 2 = 50,000.5
        let shares = Uint128::new(1);
        let amount = vault.shares_to_assets(shares).unwrap();
        assert_eq!(amount, Uint128::new(50_000));

        // With 10,000 shares, they get 500,005,000
        // Amount: (10,000) * 100,001 / 2 = 500,005,000
        let shares = Uint128::new(10_000);
        let amount = vault.shares_to_assets(shares).unwrap();
        assert_eq!(amount, Uint128::new(500_005_000));
    }

    /// This is 100_000x over the offset amount
    #[test]
    fn amount_imbalance_100_000_to_1() {
        let balance = Uint128::new(100_000);
        let total_shares = Uint128::new(1);

        // Virtual balance: (100,000) + 1 = 100,001
        // Virtual shares: (1) + 1 = 2
        let vault = VirtualOffset::new(total_shares, balance).unwrap();

        // With 1 amount, they get 0 share
        // (1) * 2 / 100,001 = 0.0000199998
        let amount = Uint128::new(1);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        // (100) * 2 / 100,001 = 0.00199998
        let amount = Uint128::new(100);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        // 50,000 is still just short of 1 share
        // (50,000) * 2 / 100,001 = 0.99999
        let amount = Uint128::new(50_000);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        // 50,001 is the smallest amount that gets 1 share
        // (50,001) * 2 / 100,001 = 1.0000099999
        let amount = Uint128::new(50_001);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(1));
    }

    #[test]
    fn extreme_inflation_1e20_to_1() {
        let balance = Uint128::new(1e20 as u128);
        let total_shares = Uint128::new(1);

        // Virtual balance: (1e20) + 1
        // Virtual shares: (1) + 1 = 2
        let vault = VirtualOffset::new(total_shares, balance).unwrap();

        // With 999, they get 0 shares
        // Shares: (999) * (1 + 1) / (1e20 + 1) = 1.998e-17
        let amount = Uint128::new(999);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        // Same for 1,000,000
        let amount = Uint128::new(1_000_000);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        // You need at least ceil((1e20 + 1) / 2) = 5e19 + 1 amount to get 1 share
        let amount = Uint128::new(50_000_000_000_000_000_000);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(0));

        let amount = Uint128::new(50_000_000_000_000_000_001);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(1));

        // Depositing 1e20 still only gets 1 share: (1e20) * 2 / (1e20 + 1) = 1.99999...
        let amount = Uint128::new(1e20 as u128);
        let shares = vault.assets_to_shares(amount).unwrap();
        assert_eq!(shares, Uint128::new(1));

        // But the cost of attack is crazy.
        // Using 1e20, you get 1 share
        {
            // New vault with +1 share and +1e20 balance
            let new_share = Uint128::new(1) + Uint128::new(1);
            let new_balance = Uint128::new(1e20 as u128) + Uint128::new(1e20 as u128);
            let vault = VirtualOffset::new(new_share, new_balance).unwrap();

            // That one share is worth (2e20 + 1) / 3, i.e. less than 1e20
            let shares = Uint128::new(1);
            let amount = vault.shares_to_assets(shares).unwrap();
            assert!(amount < Uint128::new(1e20 as u128));
        }
    }

    #[test]
    fn overflow() {
        let max = Uint128::new(u128::MAX);

        {
            let error = VirtualOffset::new(max, max).unwrap_err();
            assert_eq!(
                error.to_string(),
                "Overflow: Cannot Add with given operands"
            )
        }

        {
            let max_div_1e10 = Uint128::new(u128::MAX / 1e10 as u128);
            let vault = VirtualOffset::new(max_div_1e10, max_div_1e10).unwrap();

            vault.shares_to_assets(Uint128::new(1)).unwrap();
            vault.assets_to_shares(Uint128::new(1)).unwrap();

            vault.shares_to_assets(Uint128::new(1e9 as u128)).unwrap();
            vault.assets_to_shares(Uint128::new(1e9 as u128)).unwrap();

            vault
                .shares_to_assets(Uint128::new((1e10 as u128) - 1))
                .unwrap();
            vault
                .assets_to_shares(Uint128::new((1e10 as u128) - 1))
                .unwrap();

            let error = vault
                .shares_to_assets(Uint128::new(1e10 as u128))
                .unwrap_err();
            assert_eq!(
                error.to_string(),
                "Overflow: Cannot Mul with given operands"
            );

            let error = vault
                .assets_to_shares(Uint128::new(1e10 as u128))
                .unwrap_err();
            assert_eq!(
                error.to_string(),
                "Overflow: Cannot Mul with given operands"
            );
        }
    }
}