spl_stake_pool/
state.rs

1//! State transition types
2
3use {
4    crate::{
5        big_vec::BigVec, error::StakePoolError, MAX_WITHDRAWAL_FEE_INCREASE,
6        WITHDRAWAL_BASELINE_FEE,
7    },
8    borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
9    num_derive::FromPrimitive,
10    num_traits::FromPrimitive,
11    solana_program::{
12        account_info::AccountInfo,
13        borsh::get_instance_packed_len,
14        msg,
15        program_error::ProgramError,
16        program_memory::sol_memcmp,
17        program_pack::{Pack, Sealed},
18        pubkey::{Pubkey, PUBKEY_BYTES},
19        stake::state::Lockup,
20    },
21    spl_token_2022::{
22        extension::{BaseStateWithExtensions, ExtensionType, StateWithExtensions},
23        state::{Account, AccountState, Mint},
24    },
25    std::{borrow::Borrow, convert::TryFrom, fmt, matches},
26};
27
28/// Enum representing the account type managed by the program
29#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
30pub enum AccountType {
31    /// If the account has not been initialized, the enum will be 0
32    Uninitialized,
33    /// Stake pool
34    StakePool,
35    /// Validator stake list
36    ValidatorList,
37}
38
39impl Default for AccountType {
40    fn default() -> Self {
41        AccountType::Uninitialized
42    }
43}
44
/// Initialized program details.
///
/// Main state of a stake pool account. Fields are (de)serialized with Borsh
/// in declaration order, so the order below is part of the account layout
/// and must not be changed.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct StakePool {
    /// Account type, must be StakePool currently
    pub account_type: AccountType,

    /// Manager authority, allows for updating the staker, manager, and fee account
    ///
    /// `check_manager` requires both a matching key and a signature.
    pub manager: Pubkey,

    /// Staker authority, allows for adding and removing validators, and managing stake
    /// distribution
    ///
    /// `check_staker` requires both a matching key and a signature.
    pub staker: Pubkey,

    /// Stake deposit authority
    ///
    /// If a depositor pubkey is specified on initialization, then deposits must be
    /// signed by this authority. If no deposit authority is specified,
    /// then the stake pool will default to the result of:
    /// `Pubkey::find_program_address(
    ///     &[&stake_pool_address.as_ref(), b"deposit"],
    ///     program_id,
    /// )`
    pub stake_deposit_authority: Pubkey,

    /// Stake withdrawal authority bump seed
    ///
    /// Combined with the stake pool address and the `crate::AUTHORITY_WITHDRAW`
    /// seed in `create_program_address` to derive the withdraw authority.
    pub stake_withdraw_bump_seed: u8,

    /// Validator stake list storage account
    pub validator_list: Pubkey,

    /// Reserve stake account, holds deactivated stake
    pub reserve_stake: Pubkey,

    /// Pool Mint
    pub pool_mint: Pubkey,

    /// Manager fee account
    ///
    /// Must be a token account of `pool_mint` owned by `token_program_id`.
    pub manager_fee_account: Pubkey,

    /// Pool token program id
    pub token_program_id: Pubkey,

    /// Total stake under management.
    /// Note that if `last_update_epoch` does not match the current epoch then
    /// this field may not be accurate
    pub total_lamports: u64,

    /// Total supply of pool tokens (should always match the supply in the Pool Mint)
    pub pool_token_supply: u64,

    /// Last epoch the `total_lamports` field was updated
    pub last_update_epoch: u64,

    /// Lockup that all stakes in the pool must have
    pub lockup: Lockup,

    /// Fee taken as a proportion of rewards each epoch
    pub epoch_fee: Fee,

    /// Fee for next epoch
    ///
    /// New values are scheduled here by `update_fee` rather than applied
    /// immediately.
    pub next_epoch_fee: FutureEpoch<Fee>,

    /// Preferred deposit validator vote account pubkey
    pub preferred_deposit_validator_vote_address: Option<Pubkey>,

    /// Preferred withdraw validator vote account pubkey
    pub preferred_withdraw_validator_vote_address: Option<Pubkey>,

    /// Fee assessed on stake deposits
    pub stake_deposit_fee: Fee,

    /// Fee assessed on stake withdrawals
    pub stake_withdrawal_fee: Fee,

    /// Future stake withdrawal fee, to be set for the following epoch
    pub next_stake_withdrawal_fee: FutureEpoch<Fee>,

    /// Fees paid out to referrers on referred stake deposits.
    /// Expressed as a percentage (0 - 100) of deposit fees.
    /// i.e. `stake_deposit_fee`% of stake deposited is collected as deposit fees for every deposit
    /// and `stake_referral_fee`% of the collected stake deposit fees is paid out to the referrer
    pub stake_referral_fee: u8,

    /// Optional authority for the `DepositSol` instruction.
    ///
    /// When set, `DepositSol` requires a signature from this key; when `None`,
    /// SOL deposits are permissionless.
    pub sol_deposit_authority: Option<Pubkey>,

    /// Fee assessed on SOL deposits
    pub sol_deposit_fee: Fee,

    /// Fees paid out to referrers on referred SOL deposits.
    /// Expressed as a percentage (0 - 100) of SOL deposit fees.
    /// i.e. `sol_deposit_fee`% of SOL deposited is collected as deposit fees for every deposit
    /// and `sol_referral_fee`% of the collected SOL deposit fees is paid out to the referrer
    pub sol_referral_fee: u8,

    /// Optional authority for the `WithdrawSol` instruction.
    ///
    /// When set, `WithdrawSol` requires a signature from this key (the
    /// `sol_withdraw_authority`); when `None`, SOL withdrawals are
    /// permissionless.
    pub sol_withdraw_authority: Option<Pubkey>,

    /// Fee assessed on SOL withdrawals
    pub sol_withdrawal_fee: Fee,

    /// Future SOL withdrawal fee, to be set for the following epoch
    pub next_sol_withdrawal_fee: FutureEpoch<Fee>,

    /// Last epoch's total pool tokens, used only for APR estimation
    pub last_epoch_pool_token_supply: u64,

    /// Last epoch's total lamports, used only for APR estimation
    pub last_epoch_total_lamports: u64,
}
159impl StakePool {
160    /// calculate the pool tokens that should be minted for a deposit of `stake_lamports`
161    #[inline]
162    pub fn calc_pool_tokens_for_deposit(&self, stake_lamports: u64) -> Option<u64> {
163        if self.total_lamports == 0 || self.pool_token_supply == 0 {
164            return Some(stake_lamports);
165        }
166        u64::try_from(
167            (stake_lamports as u128)
168                .checked_mul(self.pool_token_supply as u128)?
169                .checked_div(self.total_lamports as u128)?,
170        )
171        .ok()
172    }
173
174    /// calculate lamports amount on withdrawal
175    #[inline]
176    pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
177        // `checked_div` returns `None` for a 0 quotient result, but in this
178        // case, a return of 0 is valid for small amounts of pool tokens. So
179        // we check for that separately
180        let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
181        let denominator = self.pool_token_supply as u128;
182        if numerator < denominator || denominator == 0 {
183            Some(0)
184        } else {
185            u64::try_from(numerator.checked_div(denominator)?).ok()
186        }
187    }
188
189    /// calculate pool tokens to be deducted as withdrawal fees
190    #[inline]
191    pub fn calc_pool_tokens_stake_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
192        u64::try_from(self.stake_withdrawal_fee.apply(pool_tokens)?).ok()
193    }
194
195    /// calculate pool tokens to be deducted as withdrawal fees
196    #[inline]
197    pub fn calc_pool_tokens_sol_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
198        u64::try_from(self.sol_withdrawal_fee.apply(pool_tokens)?).ok()
199    }
200
201    /// calculate pool tokens to be deducted as stake deposit fees
202    #[inline]
203    pub fn calc_pool_tokens_stake_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
204        u64::try_from(self.stake_deposit_fee.apply(pool_tokens_minted)?).ok()
205    }
206
207    /// calculate pool tokens to be deducted from deposit fees as referral fees
208    #[inline]
209    pub fn calc_pool_tokens_stake_referral_fee(&self, stake_deposit_fee: u64) -> Option<u64> {
210        u64::try_from(
211            (stake_deposit_fee as u128)
212                .checked_mul(self.stake_referral_fee as u128)?
213                .checked_div(100u128)?,
214        )
215        .ok()
216    }
217
218    /// calculate pool tokens to be deducted as SOL deposit fees
219    #[inline]
220    pub fn calc_pool_tokens_sol_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
221        u64::try_from(self.sol_deposit_fee.apply(pool_tokens_minted)?).ok()
222    }
223
224    /// calculate pool tokens to be deducted from SOL deposit fees as referral fees
225    #[inline]
226    pub fn calc_pool_tokens_sol_referral_fee(&self, sol_deposit_fee: u64) -> Option<u64> {
227        u64::try_from(
228            (sol_deposit_fee as u128)
229                .checked_mul(self.sol_referral_fee as u128)?
230                .checked_div(100u128)?,
231        )
232        .ok()
233    }
234
235    /// Calculate the fee in pool tokens that goes to the manager
236    ///
237    /// This function assumes that `reward_lamports` has not already been added
238    /// to the stake pool's `total_lamports`
239    #[inline]
240    pub fn calc_epoch_fee_amount(&self, reward_lamports: u64) -> Option<u64> {
241        if reward_lamports == 0 {
242            return Some(0);
243        }
244        let total_lamports = (self.total_lamports as u128).checked_add(reward_lamports as u128)?;
245        let fee_lamports = self.epoch_fee.apply(reward_lamports)?;
246        if total_lamports == fee_lamports || self.pool_token_supply == 0 {
247            Some(reward_lamports)
248        } else {
249            u64::try_from(
250                (self.pool_token_supply as u128)
251                    .checked_mul(fee_lamports)?
252                    .checked_div(total_lamports.checked_sub(fee_lamports)?)?,
253            )
254            .ok()
255        }
256    }
257
258    /// Get the current value of pool tokens, rounded up
259    #[inline]
260    pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
261        self.total_lamports
262            .checked_add(self.pool_token_supply)?
263            .checked_sub(1)?
264            .checked_div(self.pool_token_supply)
265    }
266
267    /// Checks that the withdraw or deposit authority is valid
268    fn check_program_derived_authority(
269        authority_address: &Pubkey,
270        program_id: &Pubkey,
271        stake_pool_address: &Pubkey,
272        authority_seed: &[u8],
273        bump_seed: u8,
274    ) -> Result<(), ProgramError> {
275        let expected_address = Pubkey::create_program_address(
276            &[stake_pool_address.as_ref(), authority_seed, &[bump_seed]],
277            program_id,
278        )?;
279
280        if *authority_address == expected_address {
281            Ok(())
282        } else {
283            msg!(
284                "Incorrect authority provided, expected {}, received {}",
285                expected_address,
286                authority_address
287            );
288            Err(StakePoolError::InvalidProgramAddress.into())
289        }
290    }
291
292    /// Check if the manager fee info is a valid token program account
293    /// capable of receiving tokens from the mint.
294    pub(crate) fn check_manager_fee_info(
295        &self,
296        manager_fee_info: &AccountInfo,
297    ) -> Result<(), ProgramError> {
298        let account_data = manager_fee_info.try_borrow_data()?;
299        let token_account = StateWithExtensions::<Account>::unpack(&account_data)?;
300        if manager_fee_info.owner != &self.token_program_id
301            || token_account.base.state != AccountState::Initialized
302            || token_account.base.mint != self.pool_mint
303        {
304            msg!("Manager fee account is not owned by token program, is not initialized, or does not match stake pool's mint");
305            return Err(StakePoolError::InvalidFeeAccount.into());
306        }
307        let extensions = token_account.get_extension_types()?;
308        if extensions
309            .iter()
310            .any(|x| !is_extension_supported_for_fee_account(x))
311        {
312            return Err(StakePoolError::UnsupportedFeeAccountExtension.into());
313        }
314        Ok(())
315    }
316
317    /// Checks that the withdraw authority is valid
318    #[inline]
319    pub(crate) fn check_authority_withdraw(
320        &self,
321        withdraw_authority: &Pubkey,
322        program_id: &Pubkey,
323        stake_pool_address: &Pubkey,
324    ) -> Result<(), ProgramError> {
325        Self::check_program_derived_authority(
326            withdraw_authority,
327            program_id,
328            stake_pool_address,
329            crate::AUTHORITY_WITHDRAW,
330            self.stake_withdraw_bump_seed,
331        )
332    }
333    /// Checks that the deposit authority is valid
334    #[inline]
335    pub(crate) fn check_stake_deposit_authority(
336        &self,
337        stake_deposit_authority: &Pubkey,
338    ) -> Result<(), ProgramError> {
339        if self.stake_deposit_authority == *stake_deposit_authority {
340            Ok(())
341        } else {
342            Err(StakePoolError::InvalidStakeDepositAuthority.into())
343        }
344    }
345
346    /// Checks that the deposit authority is valid
347    /// Does nothing if `sol_deposit_authority` is currently not set
348    #[inline]
349    pub(crate) fn check_sol_deposit_authority(
350        &self,
351        maybe_sol_deposit_authority: Result<&AccountInfo, ProgramError>,
352    ) -> Result<(), ProgramError> {
353        if let Some(auth) = self.sol_deposit_authority {
354            let sol_deposit_authority = maybe_sol_deposit_authority?;
355            if auth != *sol_deposit_authority.key {
356                msg!("Expected {}, received {}", auth, sol_deposit_authority.key);
357                return Err(StakePoolError::InvalidSolDepositAuthority.into());
358            }
359            if !sol_deposit_authority.is_signer {
360                msg!("SOL Deposit authority signature missing");
361                return Err(StakePoolError::SignatureMissing.into());
362            }
363        }
364        Ok(())
365    }
366
367    /// Checks that the sol withdraw authority is valid
368    /// Does nothing if `sol_withdraw_authority` is currently not set
369    #[inline]
370    pub(crate) fn check_sol_withdraw_authority(
371        &self,
372        maybe_sol_withdraw_authority: Result<&AccountInfo, ProgramError>,
373    ) -> Result<(), ProgramError> {
374        if let Some(auth) = self.sol_withdraw_authority {
375            let sol_withdraw_authority = maybe_sol_withdraw_authority?;
376            if auth != *sol_withdraw_authority.key {
377                return Err(StakePoolError::InvalidSolWithdrawAuthority.into());
378            }
379            if !sol_withdraw_authority.is_signer {
380                msg!("SOL withdraw authority signature missing");
381                return Err(StakePoolError::SignatureMissing.into());
382            }
383        }
384        Ok(())
385    }
386
387    /// Check mint is correct
388    #[inline]
389    pub(crate) fn check_mint(&self, mint_info: &AccountInfo) -> Result<u8, ProgramError> {
390        if *mint_info.key != self.pool_mint {
391            Err(StakePoolError::WrongPoolMint.into())
392        } else {
393            let mint_data = mint_info.try_borrow_data()?;
394            let mint = StateWithExtensions::<Mint>::unpack(&mint_data)?;
395            Ok(mint.base.decimals)
396        }
397    }
398
399    /// Check manager validity and signature
400    pub(crate) fn check_manager(&self, manager_info: &AccountInfo) -> Result<(), ProgramError> {
401        if *manager_info.key != self.manager {
402            msg!(
403                "Incorrect manager provided, expected {}, received {}",
404                self.manager,
405                manager_info.key
406            );
407            return Err(StakePoolError::WrongManager.into());
408        }
409        if !manager_info.is_signer {
410            msg!("Manager signature missing");
411            return Err(StakePoolError::SignatureMissing.into());
412        }
413        Ok(())
414    }
415
416    /// Check staker validity and signature
417    pub(crate) fn check_staker(&self, staker_info: &AccountInfo) -> Result<(), ProgramError> {
418        if *staker_info.key != self.staker {
419            msg!(
420                "Incorrect staker provided, expected {}, received {}",
421                self.staker,
422                staker_info.key
423            );
424            return Err(StakePoolError::WrongStaker.into());
425        }
426        if !staker_info.is_signer {
427            msg!("Staker signature missing");
428            return Err(StakePoolError::SignatureMissing.into());
429        }
430        Ok(())
431    }
432
433    /// Check the validator list is valid
434    pub fn check_validator_list(
435        &self,
436        validator_list_info: &AccountInfo,
437    ) -> Result<(), ProgramError> {
438        if *validator_list_info.key != self.validator_list {
439            msg!(
440                "Invalid validator list provided, expected {}, received {}",
441                self.validator_list,
442                validator_list_info.key
443            );
444            Err(StakePoolError::InvalidValidatorStakeList.into())
445        } else {
446            Ok(())
447        }
448    }
449
450    /// Check the reserve stake is valid
451    pub fn check_reserve_stake(
452        &self,
453        reserve_stake_info: &AccountInfo,
454    ) -> Result<(), ProgramError> {
455        if *reserve_stake_info.key != self.reserve_stake {
456            msg!(
457                "Invalid reserve stake provided, expected {}, received {}",
458                self.reserve_stake,
459                reserve_stake_info.key
460            );
461            Err(StakePoolError::InvalidProgramAddress.into())
462        } else {
463            Ok(())
464        }
465    }
466
467    /// Check if StakePool is actually initialized as a stake pool
468    pub fn is_valid(&self) -> bool {
469        self.account_type == AccountType::StakePool
470    }
471
472    /// Check if StakePool is currently uninitialized
473    pub fn is_uninitialized(&self) -> bool {
474        self.account_type == AccountType::Uninitialized
475    }
476
477    /// Updates one of the StakePool's fees.
478    pub fn update_fee(&mut self, fee: &FeeType) -> Result<(), StakePoolError> {
479        match fee {
480            FeeType::SolReferral(new_fee) => self.sol_referral_fee = *new_fee,
481            FeeType::StakeReferral(new_fee) => self.stake_referral_fee = *new_fee,
482            FeeType::Epoch(new_fee) => self.next_epoch_fee = FutureEpoch::new(*new_fee),
483            FeeType::StakeWithdrawal(new_fee) => {
484                new_fee.check_withdrawal(&self.stake_withdrawal_fee)?;
485                self.next_stake_withdrawal_fee = FutureEpoch::new(*new_fee)
486            }
487            FeeType::SolWithdrawal(new_fee) => {
488                new_fee.check_withdrawal(&self.sol_withdrawal_fee)?;
489                self.next_sol_withdrawal_fee = FutureEpoch::new(*new_fee)
490            }
491            FeeType::SolDeposit(new_fee) => self.sol_deposit_fee = *new_fee,
492            FeeType::StakeDeposit(new_fee) => self.stake_deposit_fee = *new_fee,
493        };
494        Ok(())
495    }
496}
497
498/// Checks if the given extension is supported for the stake pool mint
499pub fn is_extension_supported_for_mint(extension_type: &ExtensionType) -> bool {
500    const SUPPORTED_EXTENSIONS: [ExtensionType; 5] = [
501        ExtensionType::Uninitialized,
502        ExtensionType::TransferFeeConfig,
503        ExtensionType::ConfidentialTransferMint,
504        ExtensionType::DefaultAccountState, // ok, but a freeze authority is not
505        ExtensionType::InterestBearingConfig,
506    ];
507    if !SUPPORTED_EXTENSIONS.contains(extension_type) {
508        msg!(
509            "Stake pool mint account cannot have the {:?} extension",
510            extension_type
511        );
512        false
513    } else {
514        true
515    }
516}
517
518/// Checks if the given extension is supported for the stake pool's fee account
519pub fn is_extension_supported_for_fee_account(extension_type: &ExtensionType) -> bool {
520    // Note: this does not include the `ConfidentialTransferAccount` extension
521    // because it is possible to block non-confidential transfers with the
522    // extension enabled.
523    const SUPPORTED_EXTENSIONS: [ExtensionType; 4] = [
524        ExtensionType::Uninitialized,
525        ExtensionType::TransferFeeAmount,
526        ExtensionType::ImmutableOwner,
527        ExtensionType::CpiGuard,
528    ];
529    if !SUPPORTED_EXTENSIONS.contains(extension_type) {
530        msg!("Fee account cannot have the {:?} extension", extension_type);
531        false
532    } else {
533        true
534    }
535}
536
/// Storage list for all validator stake accounts in the pool.
///
/// Serialized as the header followed by a Borsh `Vec` of fixed-size
/// `ValidatorStakeInfo` entries.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorList {
    /// Data outside of the validator list, separated out for cheaper deserializations
    pub header: ValidatorListHeader,

    /// List of stake info for each validator in the pool
    pub validators: Vec<ValidatorStakeInfo>,
}
547
/// Helper type to deserialize just the start of a ValidatorList
///
/// Serialized size is 5 bytes: a 1-byte `AccountType` tag followed by the
/// `u32` `max_validators`.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorListHeader {
    /// Account type, must be ValidatorList currently
    pub account_type: AccountType,

    /// Maximum allowable number of validators
    pub max_validators: u32,
}
558
/// Status of the stake account in the validator list, for accounting
///
/// Variant order fixes the discriminants (0..=4) used both for Borsh
/// serialization and by `FromPrimitive` when a raw status byte is decoded,
/// so variants must not be reordered.
#[derive(
    FromPrimitive, Copy, Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema,
)]
pub enum StakeStatus {
    /// Stake account is active, there may be a transient stake as well
    Active,
    /// Only transient stake account exists, when a transient stake is
    /// deactivating during validator removal
    DeactivatingTransient,
    /// No more validator stake accounts exist, entry ready for removal during
    /// `UpdateStakePoolBalance`
    ReadyForRemoval,
    /// Only the validator stake account is deactivating, no transient stake
    /// account exists
    DeactivatingValidator,
    /// Both the transient and validator stake account are deactivating, when
    /// a validator is removed with a transient stake active
    DeactivatingAll,
}
579impl StakeStatus {
580    /// Downgrade the status towards ready for removal by removing the validator stake
581    pub fn remove_validator_stake(&mut self) {
582        let new_self = match self {
583            Self::Active | Self::DeactivatingTransient | Self::ReadyForRemoval => *self,
584            Self::DeactivatingAll => Self::DeactivatingTransient,
585            Self::DeactivatingValidator => Self::ReadyForRemoval,
586        };
587        *self = new_self;
588    }
589    /// Downgrade the status towards ready for removal by removing the transient stake
590    pub fn remove_transient_stake(&mut self) {
591        let new_self = match self {
592            Self::Active | Self::DeactivatingValidator | Self::ReadyForRemoval => *self,
593            Self::DeactivatingAll => Self::DeactivatingValidator,
594            Self::DeactivatingTransient => Self::ReadyForRemoval,
595        };
596        *self = new_self;
597    }
598}
599impl Default for StakeStatus {
600    fn default() -> Self {
601        Self::Active
602    }
603}
604
/// Withdrawal type, figured out during process_withdraw_stake
///
/// Crate-internal only. Identifies which kind of validator stake account a
/// stake withdrawal is taken from.
#[derive(Debug, PartialEq)]
pub(crate) enum StakeWithdrawSource {
    /// Some of an active stake account, but not all
    Active,
    /// Some of a transient stake account
    Transient,
    /// Take a whole validator stake account
    ValidatorRemoval,
}
615
/// Information about a validator in the pool
///
/// NOTE: ORDER IS VERY IMPORTANT HERE, PLEASE DO NOT RE-ORDER THE FIELDS UNLESS
/// THERE'S AN EXTREMELY GOOD REASON.
///
/// To save on BPF instructions, the serialized bytes are reinterpreted with an
/// unsafe pointer cast, which means that this structure cannot have any
/// undeclared alignment-padding in its representation.
///
/// Serialized layout (73 bytes, matching `Pack::LEN`):
///
/// | bytes  | field                      |
/// |--------|----------------------------|
/// | 0..8   | `active_stake_lamports`    |
/// | 8..16  | `transient_stake_lamports` |
/// | 16..24 | `last_update_epoch`        |
/// | 24..32 | `transient_seed_suffix`    |
/// | 32..36 | `unused`                   |
/// | 36..40 | `validator_seed_suffix`    |
/// | 40     | `status`                   |
/// | 41..73 | `vote_account_address`     |
///
/// The raw-byte helpers (`memcmp_pubkey`, `active_lamports_greater_than`,
/// `transient_lamports_greater_than`, `is_not_removed`) hard-code these offsets.
///
/// NOTE(review): with `u64` fields this `repr(C)` struct is 8-byte aligned, so
/// its in-memory size likely rounds the 73 serialized bytes up to 80 (trailing
/// padding). Confirm that the pointer cast done when reading the list accounts
/// for this.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorStakeInfo {
    /// Amount of lamports on the validator stake account, including rent
    ///
    /// Note that if `last_update_epoch` does not match the current epoch then
    /// this field may not be accurate
    pub active_stake_lamports: u64,

    /// Amount of transient stake delegated to this validator
    ///
    /// Note that if `last_update_epoch` does not match the current epoch then
    /// this field may not be accurate
    pub transient_stake_lamports: u64,

    /// Last epoch the active and transient stake lamports fields were updated
    pub last_update_epoch: u64,

    /// Transient account seed suffix, used to derive the transient stake account address
    pub transient_seed_suffix: u64,

    /// Unused space, initially meant to specify the end of seed suffixes
    pub unused: u32,

    /// Validator account seed suffix
    pub validator_seed_suffix: u32, // really `Option<NonZeroU32>` so 0 is `None`

    /// Status of the validator stake account
    pub status: StakeStatus,

    /// Validator vote account address
    pub vote_account_address: Pubkey,
}
657
658impl ValidatorStakeInfo {
659    /// Get the total lamports on this validator (active and transient)
660    pub fn stake_lamports(&self) -> Result<u64, StakePoolError> {
661        self.active_stake_lamports
662            .checked_add(self.transient_stake_lamports)
663            .ok_or(StakePoolError::CalculationFailure)
664    }
665
666    /// Performs a very cheap comparison, for checking if this validator stake
667    /// info matches the vote account address
668    pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
669        sol_memcmp(
670            &data[41..41_usize.saturating_add(PUBKEY_BYTES)],
671            vote_address.as_ref(),
672            PUBKEY_BYTES,
673        ) == 0
674    }
675
676    /// Performs a comparison, used to check if this validator stake
677    /// info has more active lamports than some limit
678    pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
679        // without this unwrap, compute usage goes up significantly
680        u64::try_from_slice(&data[0..8]).unwrap() > *lamports
681    }
682
683    /// Performs a comparison, used to check if this validator stake
684    /// info has more transient lamports than some limit
685    pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
686        // without this unwrap, compute usage goes up significantly
687        u64::try_from_slice(&data[8..16]).unwrap() > *lamports
688    }
689
690    /// Check that the validator stake info is valid
691    pub fn is_not_removed(data: &[u8]) -> bool {
692        FromPrimitive::from_u8(data[40]) != Some(StakeStatus::ReadyForRemoval)
693    }
694}
695
696impl Sealed for ValidatorStakeInfo {}
697
698impl Pack for ValidatorStakeInfo {
699    const LEN: usize = 73;
700    fn pack_into_slice(&self, data: &mut [u8]) {
701        let mut data = data;
702        // Removing this unwrap would require changing from `Pack` to some other
703        // trait or `bytemuck`, so it stays in for now
704        self.serialize(&mut data).unwrap();
705    }
706    fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
707        let unpacked = Self::try_from_slice(src)?;
708        Ok(unpacked)
709    }
710}
711
712impl ValidatorList {
713    /// Create an empty instance containing space for `max_validators` and preferred validator keys
714    pub fn new(max_validators: u32) -> Self {
715        Self {
716            header: ValidatorListHeader {
717                account_type: AccountType::ValidatorList,
718                max_validators,
719            },
720            validators: vec![ValidatorStakeInfo::default(); max_validators as usize],
721        }
722    }
723
724    /// Calculate the number of validator entries that fit in the provided length
725    pub fn calculate_max_validators(buffer_length: usize) -> usize {
726        let header_size = ValidatorListHeader::LEN.saturating_add(4);
727        buffer_length
728            .saturating_sub(header_size)
729            .saturating_div(ValidatorStakeInfo::LEN)
730    }
731
732    /// Check if contains validator with particular pubkey
733    pub fn contains(&self, vote_account_address: &Pubkey) -> bool {
734        self.validators
735            .iter()
736            .any(|x| x.vote_account_address == *vote_account_address)
737    }
738
739    /// Check if contains validator with particular pubkey
740    pub fn find_mut(&mut self, vote_account_address: &Pubkey) -> Option<&mut ValidatorStakeInfo> {
741        self.validators
742            .iter_mut()
743            .find(|x| x.vote_account_address == *vote_account_address)
744    }
745    /// Check if contains validator with particular pubkey
746    pub fn find(&self, vote_account_address: &Pubkey) -> Option<&ValidatorStakeInfo> {
747        self.validators
748            .iter()
749            .find(|x| x.vote_account_address == *vote_account_address)
750    }
751
752    /// Check if the list has any active stake
753    pub fn has_active_stake(&self) -> bool {
754        self.validators.iter().any(|x| x.active_stake_lamports > 0)
755    }
756}
757
758impl ValidatorListHeader {
759    const LEN: usize = 1 + 4;
760
761    /// Check if validator stake list is actually initialized as a validator stake list
762    pub fn is_valid(&self) -> bool {
763        self.account_type == AccountType::ValidatorList
764    }
765
766    /// Check if the validator stake list is uninitialized
767    pub fn is_uninitialized(&self) -> bool {
768        self.account_type == AccountType::Uninitialized
769    }
770
771    /// Extracts a slice of ValidatorStakeInfo types from the vec part
772    /// of the ValidatorList
773    pub fn deserialize_mut_slice(
774        data: &mut [u8],
775        skip: usize,
776        len: usize,
777    ) -> Result<(Self, Vec<&mut ValidatorStakeInfo>), ProgramError> {
778        let (header, mut big_vec) = Self::deserialize_vec(data)?;
779        let validator_list = big_vec.deserialize_mut_slice::<ValidatorStakeInfo>(skip, len)?;
780        Ok((header, validator_list))
781    }
782
783    /// Extracts the validator list into its header and internal BigVec
784    pub fn deserialize_vec(data: &mut [u8]) -> Result<(Self, BigVec), ProgramError> {
785        let mut data_mut = data.borrow();
786        let header = ValidatorListHeader::deserialize(&mut data_mut)?;
787        let length = get_instance_packed_len(&header)?;
788
789        let big_vec = BigVec {
790            data: &mut data[length..],
791        };
792        Ok((header, big_vec))
793    }
794}
795
/// Wrapper type that "counts down" epochs, which is Borsh-compatible with the
/// native `Option`
///
/// Borsh encodes the variant index as a one-byte tag, so `None` (0) and
/// `One(T)` (1) serialize exactly like `Option::None` and `Option::Some(T)`.
/// `Two(T)` (2) is the extra "not ready yet" state. Do not reorder variants.
///
/// Values are created as `Two` by `new`, step `Two -> One -> None` on each
/// `update_epoch`, and are only returned by `get` while in `One`.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub enum FutureEpoch<T> {
    /// Nothing is set
    None,
    /// Value is ready after the next epoch boundary
    One(T),
    /// Value is ready after two epoch boundaries
    Two(T),
}
808impl<T> Default for FutureEpoch<T> {
809    fn default() -> Self {
810        Self::None
811    }
812}
813impl<T> FutureEpoch<T> {
814    /// Create a new value to be unlocked in a two epochs
815    pub fn new(value: T) -> Self {
816        Self::Two(value)
817    }
818}
819impl<T: Clone> FutureEpoch<T> {
820    /// Update the epoch, to be done after `get`ting the underlying value
821    pub fn update_epoch(&mut self) {
822        match self {
823            Self::None => {}
824            Self::One(_) => {
825                // The value has waited its last epoch
826                *self = Self::None;
827            }
828            // The value still has to wait one more epoch after this
829            Self::Two(v) => {
830                *self = Self::One(v.clone());
831            }
832        }
833    }
834
835    /// Get the value if it's ready, which is only at `One` epoch remaining
836    pub fn get(&self) -> Option<&T> {
837        match self {
838            Self::None | Self::Two(_) => None,
839            Self::One(v) => Some(v),
840        }
841    }
842}
843impl<T> From<FutureEpoch<T>> for Option<T> {
844    fn from(v: FutureEpoch<T>) -> Option<T> {
845        match v {
846            FutureEpoch::None => None,
847            FutureEpoch::One(inner) | FutureEpoch::Two(inner) => Some(inner),
848        }
849    }
850}
851
/// Fee rate expressed as the ratio `numerator / denominator`.
///
/// For the epoch fee, this is minted on `UpdateStakePoolBalance` as a
/// proportion of the rewards; the same type also describes the deposit and
/// withdrawal fees (see `FeeType`).
/// If either the numerator or the denominator is 0, the fee is considered to be 0
/// (both `Fee::apply` and the `Display` impl treat it that way).
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub struct Fee {
    /// denominator of the fee ratio
    // Declared before `numerator`: Borsh serializes fields in declaration
    // order, so this order is part of the on-chain layout.
    pub denominator: u64,
    /// numerator of the fee ratio
    pub numerator: u64,
}
863
864impl Fee {
865    /// Applies the Fee's rates to a given amount, `amt`
866    /// returning the amount to be subtracted from it as fees
867    /// (0 if denominator is 0 or amt is 0),
868    /// or None if overflow occurs
869    #[inline]
870    pub fn apply(&self, amt: u64) -> Option<u128> {
871        if self.denominator == 0 {
872            return Some(0);
873        }
874        (amt as u128)
875            .checked_mul(self.numerator as u128)?
876            .checked_div(self.denominator as u128)
877    }
878
879    /// Withdrawal fees have some additional restrictions,
880    /// this fn checks if those are met, returning an error if not.
881    /// Does nothing and returns Ok if fee type is not withdrawal
882    pub fn check_withdrawal(&self, old_withdrawal_fee: &Fee) -> Result<(), StakePoolError> {
883        // If the previous withdrawal fee was 0, we allow the fee to be set to a
884        // maximum of (WITHDRAWAL_BASELINE_FEE * MAX_WITHDRAWAL_FEE_INCREASE)
885        let (old_num, old_denom) =
886            if old_withdrawal_fee.denominator == 0 || old_withdrawal_fee.numerator == 0 {
887                (
888                    WITHDRAWAL_BASELINE_FEE.numerator,
889                    WITHDRAWAL_BASELINE_FEE.denominator,
890                )
891            } else {
892                (old_withdrawal_fee.numerator, old_withdrawal_fee.denominator)
893            };
894
895        // Check that new_fee / old_fee <= MAX_WITHDRAWAL_FEE_INCREASE
896        // Program fails if provided numerator or denominator is too large, resulting in overflow
897        if (old_num as u128)
898            .checked_mul(self.denominator as u128)
899            .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.numerator as u128))
900            .ok_or(StakePoolError::CalculationFailure)?
901            < (self.numerator as u128)
902                .checked_mul(old_denom as u128)
903                .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.denominator as u128))
904                .ok_or(StakePoolError::CalculationFailure)?
905        {
906            msg!(
907                "Fee increase exceeds maximum allowed, proposed increase factor ({} / {})",
908                self.numerator.saturating_mul(old_denom),
909                old_num.saturating_mul(self.denominator),
910            );
911            return Err(StakePoolError::FeeIncreaseTooHigh);
912        }
913        Ok(())
914    }
915}
916
917impl fmt::Display for Fee {
918    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
919        if self.numerator > 0 && self.denominator > 0 {
920            write!(f, "{}/{}", self.numerator, self.denominator)
921        } else {
922            write!(f, "none")
923        }
924    }
925}
926
/// The type of fees that can be set on the stake pool
///
/// Each variant carries the value for that fee. Variant order determines the
/// Borsh discriminant, so variants must not be reordered.
#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum FeeType {
    /// Referral fees for SOL deposits, as a percentage
    /// (`check_too_high` rejects values above 100)
    SolReferral(u8),
    /// Referral fees for stake deposits, as a percentage
    /// (`check_too_high` rejects values above 100)
    StakeReferral(u8),
    /// Management fee paid per epoch
    Epoch(Fee),
    /// Stake withdrawal fee
    StakeWithdrawal(Fee),
    /// Deposit fee for SOL deposits
    SolDeposit(Fee),
    /// Deposit fee for stake deposits
    StakeDeposit(Fee),
    /// SOL withdrawal fee
    SolWithdrawal(Fee),
}
945
946impl FeeType {
947    /// Checks if the provided fee is too high, returning an error if so
948    pub fn check_too_high(&self) -> Result<(), StakePoolError> {
949        let too_high = match self {
950            Self::SolReferral(pct) => *pct > 100u8,
951            Self::StakeReferral(pct) => *pct > 100u8,
952            Self::Epoch(fee) => fee.numerator > fee.denominator,
953            Self::StakeWithdrawal(fee) => fee.numerator > fee.denominator,
954            Self::SolWithdrawal(fee) => fee.numerator > fee.denominator,
955            Self::SolDeposit(fee) => fee.numerator > fee.denominator,
956            Self::StakeDeposit(fee) => fee.numerator > fee.denominator,
957        };
958        if too_high {
959            msg!("Fee greater than 100%: {:?}", self);
960            return Err(StakePoolError::FeeTooHigh);
961        }
962        Ok(())
963    }
964
965    /// Returns if the contained fee can only be updated earliest on the next epoch
966    #[inline]
967    pub fn can_only_change_next_epoch(&self) -> bool {
968        matches!(
969            self,
970            Self::StakeWithdrawal(_) | Self::SolWithdrawal(_) | Self::Epoch(_)
971        )
972    }
973}
974
#[cfg(test)]
mod test {
    #![allow(clippy::integer_arithmetic)]
    use {
        super::*,
        proptest::prelude::*,
        solana_program::{
            borsh::{get_instance_packed_len, get_packed_len, try_from_slice_unchecked},
            clock::{DEFAULT_SLOTS_PER_EPOCH, DEFAULT_S_PER_SLOT, SECONDS_PER_DAY},
            native_token::LAMPORTS_PER_SOL,
        },
    };

    fn uninitialized_validator_list() -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::Uninitialized,
                max_validators: 0,
            },
            validators: vec![],
        }
    }

    fn test_validator_list(max_validators: u32) -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators,
            },
            validators: vec![
                ValidatorStakeInfo {
                    status: StakeStatus::Active,
                    vote_account_address: Pubkey::new_from_array([1; 32]),
                    active_stake_lamports: u64::from_le_bytes([255; 8]),
                    transient_stake_lamports: u64::from_le_bytes([128; 8]),
                    last_update_epoch: u64::from_le_bytes([64; 8]),
                    transient_seed_suffix: 0,
                    unused: 0,
                    validator_seed_suffix: 0,
                },
                ValidatorStakeInfo {
                    status: StakeStatus::DeactivatingTransient,
                    vote_account_address: Pubkey::new_from_array([2; 32]),
                    active_stake_lamports: 998877665544,
                    transient_stake_lamports: 222222222,
                    last_update_epoch: 11223445566,
                    transient_seed_suffix: 0,
                    unused: 0,
                    validator_seed_suffix: 0,
                },
                ValidatorStakeInfo {
                    status: StakeStatus::ReadyForRemoval,
                    vote_account_address: Pubkey::new_from_array([3; 32]),
                    active_stake_lamports: 0,
                    transient_stake_lamports: 0,
                    last_update_epoch: 999999999999999,
                    transient_seed_suffix: 0,
                    unused: 0,
                    validator_seed_suffix: 0,
                },
            ],
        }
    }

    #[test]
    fn state_packing() {
        let max_validators = 10_000;
        let size = get_instance_packed_len(&ValidatorList::new(max_validators)).unwrap();
        let stake_list = uninitialized_validator_list();
        let mut byte_vec = vec![0u8; size];
        let mut bytes = byte_vec.as_mut_slice();
        stake_list.serialize(&mut bytes).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);

        // Empty, one preferred key
        let stake_list = ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators: 0,
            },
            validators: vec![],
        };
        let mut byte_vec = vec![0u8; size];
        let mut bytes = byte_vec.as_mut_slice();
        stake_list.serialize(&mut bytes).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);

        // With several accounts
        let stake_list = test_validator_list(max_validators);
        let mut byte_vec = vec![0u8; size];
        let mut bytes = byte_vec.as_mut_slice();
        stake_list.serialize(&mut bytes).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);
    }

    #[test]
    fn validator_list_active_stake() {
        let max_validators = 10_000;
        let mut validator_list = test_validator_list(max_validators);
        assert!(validator_list.has_active_stake());
        for validator in validator_list.validators.iter_mut() {
            validator.active_stake_lamports = 0;
        }
        assert!(!validator_list.has_active_stake());
    }

    #[test]
    fn validator_list_deserialize_mut_slice() {
        let max_validators = 10;
        let stake_list = test_validator_list(max_validators);
        let mut serialized = stake_list.try_to_vec().unwrap();
        let (header, list) = ValidatorListHeader::deserialize_mut_slice(
            &mut serialized,
            0,
            stake_list.validators.len(),
        )
        .unwrap();
        assert_eq!(header.account_type, AccountType::ValidatorList);
        assert_eq!(header.max_validators, max_validators);
        assert!(list
            .iter()
            .zip(stake_list.validators.iter())
            .all(|(a, b)| *a == b));

        let (_, list) = ValidatorListHeader::deserialize_mut_slice(&mut serialized, 1, 2).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[1..].iter())
            .all(|(a, b)| *a == b));
        let (_, list) = ValidatorListHeader::deserialize_mut_slice(&mut serialized, 2, 1).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[2..].iter())
            .all(|(a, b)| *a == b));
        let (_, list) = ValidatorListHeader::deserialize_mut_slice(&mut serialized, 0, 2).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[..2].iter())
            .all(|(a, b)| *a == b));

        assert_eq!(
            ValidatorListHeader::deserialize_mut_slice(&mut serialized, 0, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            ValidatorListHeader::deserialize_mut_slice(&mut serialized, 1, 3).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn validator_list_iter() {
        let max_validators = 10;
        let stake_list = test_validator_list(max_validators);
        let mut serialized = stake_list.try_to_vec().unwrap();
        let (_, big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
        for (a, b) in big_vec
            .iter::<ValidatorStakeInfo>()
            .zip(stake_list.validators.iter())
        {
            assert_eq!(a, b);
        }
    }

    proptest! {
        #[test]
        fn stake_list_size_calculation(test_amount in 0..=100_000_u32) {
            let validators = ValidatorList::new(test_amount);
            let size = get_instance_packed_len(&validators).unwrap();
            assert_eq!(ValidatorList::calculate_max_validators(size), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(1)), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(get_packed_len::<ValidatorStakeInfo>())), (test_amount + 1)as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_sub(1)), (test_amount.saturating_sub(1)) as usize);
        }
    }

    prop_compose! {
        fn fee()(denominator in 1..=u16::MAX)(
            denominator in Just(denominator),
            numerator in 0..=denominator,
        ) -> (u64, u64) {
            (numerator as u64, denominator as u64)
        }
    }

    // `rewards` is drawn from the exclusive range `0..total_lamports` so the
    // remaining stake `total_lamports - rewards` is always at least 1:
    // `fee_calculation` divides by it when computing its epsilon, and a pool
    // with zero stake would also break the max-fee assertion there.
    prop_compose! {
        fn total_stake_and_rewards()(total_lamports in 1..u64::MAX)(
            total_lamports in Just(total_lamports),
            rewards in 0..total_lamports,
        ) -> (u64, u64) {
            (total_lamports - rewards, rewards)
        }
    }

    #[test]
    fn specific_fee_calculation() {
        // 10% of 10 SOL in rewards should be 1 SOL in fees
        let epoch_fee = Fee {
            numerator: 1,
            denominator: 10,
        };
        let mut stake_pool = StakePool {
            total_lamports: 100 * LAMPORTS_PER_SOL,
            pool_token_supply: 100 * LAMPORTS_PER_SOL,
            epoch_fee,
            ..StakePool::default()
        };
        let reward_lamports = 10 * LAMPORTS_PER_SOL;
        let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

        stake_pool.total_lamports += reward_lamports;
        stake_pool.pool_token_supply += pool_token_fee;

        let fee_lamports = stake_pool
            .calc_lamports_withdraw_amount(pool_token_fee)
            .unwrap();
        assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); // off-by-one due to truncation
    }

    #[test]
    fn zero_withdraw_calculation() {
        let epoch_fee = Fee {
            numerator: 0,
            denominator: 1,
        };
        let stake_pool = StakePool {
            epoch_fee,
            ..StakePool::default()
        };
        let fee_lamports = stake_pool.calc_lamports_withdraw_amount(0).unwrap();
        assert_eq!(fee_lamports, 0);
    }

    #[test]
    fn divide_by_zero_fee() {
        let stake_pool = StakePool {
            total_lamports: 0,
            epoch_fee: Fee {
                numerator: 1,
                denominator: 10,
            },
            ..StakePool::default()
        };
        let rewards = 10;
        let fee = stake_pool.calc_epoch_fee_amount(rewards).unwrap();
        assert_eq!(fee, rewards);
    }

    #[test]
    fn approximate_apr_calculation() {
        // 8% / year means roughly .044% / epoch
        let stake_pool = StakePool {
            last_epoch_total_lamports: 100_000,
            last_epoch_pool_token_supply: 100_000,
            total_lamports: 100_044,
            pool_token_supply: 100_000,
            ..StakePool::default()
        };
        let pool_token_value =
            stake_pool.total_lamports as f64 / stake_pool.pool_token_supply as f64;
        let last_epoch_pool_token_value = stake_pool.last_epoch_total_lamports as f64
            / stake_pool.last_epoch_pool_token_supply as f64;
        let epoch_rate = pool_token_value / last_epoch_pool_token_value - 1.0;
        const SECONDS_PER_EPOCH: f64 = DEFAULT_SLOTS_PER_EPOCH as f64 * DEFAULT_S_PER_SLOT;
        const EPOCHS_PER_YEAR: f64 = SECONDS_PER_DAY as f64 * 365.25 / SECONDS_PER_EPOCH;
        const EPSILON: f64 = 0.00001;
        let yearly_rate = epoch_rate * EPOCHS_PER_YEAR;
        assert!((yearly_rate - 0.080355).abs() < EPSILON);
    }

    proptest! {
        #[test]
        fn fee_calculation(
            (numerator, denominator) in fee(),
            (total_lamports, reward_lamports) in total_stake_and_rewards(),
        ) {
            let epoch_fee = Fee { denominator, numerator };
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply: total_lamports,
                epoch_fee,
                ..StakePool::default()
            };
            let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

            stake_pool.total_lamports += reward_lamports;
            stake_pool.pool_token_supply += pool_token_fee;

            let fee_lamports = stake_pool.calc_lamports_withdraw_amount(pool_token_fee).unwrap();
            let max_fee_lamports = u64::try_from((reward_lamports as u128) * (epoch_fee.numerator as u128) / (epoch_fee.denominator as u128)).unwrap();
            assert!(max_fee_lamports >= fee_lamports,
                "Max possible fee must always be greater than or equal to what is actually withdrawn, max {} actual {}",
                max_fee_lamports,
                fee_lamports);

            // since we do two "flooring" conversions, the max epsilon should be
            // correct up to 2 lamports (one for each floor division), plus a
            // correction for huge discrepancies between rewards and total stake
            let epsilon = 2 + reward_lamports / total_lamports;
            assert!(max_fee_lamports - fee_lamports <= epsilon,
                "Max expected fee in lamports {}, actually receive {}, epsilon {}",
                max_fee_lamports, fee_lamports, epsilon);
        }
    }

    prop_compose! {
        fn total_tokens_and_deposit()(total_lamports in 1..u64::MAX)(
            total_lamports in Just(total_lamports),
            pool_token_supply in 1..=total_lamports,
            deposit_lamports in 1..total_lamports,
        ) -> (u64, u64, u64) {
            (total_lamports - deposit_lamports, pool_token_supply.saturating_sub(deposit_lamports).max(1), deposit_lamports)
        }
    }

    proptest! {
        #[test]
        fn deposit_and_withdraw(
            (total_lamports, pool_token_supply, deposit_stake) in total_tokens_and_deposit()
        ) {
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply,
                ..StakePool::default()
            };
            let deposit_result = stake_pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
            prop_assume!(deposit_result > 0);
            stake_pool.total_lamports += deposit_stake;
            stake_pool.pool_token_supply += deposit_result;
            let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
            assert!(withdraw_result <= deposit_stake);

            // also test splitting the withdrawal in two operations
            if deposit_result >= 2 {
                let first_half_deposit = deposit_result / 2;
                let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
                stake_pool.total_lamports -= first_withdraw_result;
                stake_pool.pool_token_supply -= first_half_deposit;
                let second_half_deposit = deposit_result - first_half_deposit; // do the whole thing
                let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
                assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
            }
        }
    }

    #[test]
    fn specific_split_withdrawal() {
        let total_lamports = 1_100_000_000_000;
        let pool_token_supply = 1_000_000_000_000;
        let deposit_stake = 3;
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        let deposit_result = stake_pool
            .calc_pool_tokens_for_deposit(deposit_stake)
            .unwrap();
        assert!(deposit_result > 0);
        stake_pool.total_lamports += deposit_stake;
        stake_pool.pool_token_supply += deposit_result;
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(deposit_result / 2)
            .unwrap();
        assert!(withdraw_result * 2 <= deposit_stake);
    }

    #[test]
    fn withdraw_all() {
        let total_lamports = 1_100_000_000_000;
        let pool_token_supply = 1_000_000_000_000;
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        // take everything out at once
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(pool_token_supply)
            .unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);

        // take out 1, then the rest
        let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
        stake_pool.total_lamports -= withdraw_result;
        stake_pool.pool_token_supply -= 1;
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
            .unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);

        // take out all except 1, then the rest
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(pool_token_supply - 1)
            .unwrap();
        stake_pool.total_lamports -= withdraw_result;
        stake_pool.pool_token_supply = 1;
        assert_ne!(stake_pool.total_lamports, 0);

        let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);
    }
}