//! State transition types
2
3use {
4    crate::{
5        big_vec::BigVec, error::StakePoolError, MAX_WITHDRAWAL_FEE_INCREASE,
6        WITHDRAWAL_BASELINE_FEE,
7    },
8    borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
9    bytemuck::{Pod, Zeroable},
10    num_derive::{FromPrimitive, ToPrimitive},
11    num_traits::{FromPrimitive, ToPrimitive},
12    solana_program::{
13        account_info::AccountInfo,
14        borsh1::get_instance_packed_len,
15        msg,
16        program_error::ProgramError,
17        program_memory::sol_memcmp,
18        program_pack::{Pack, Sealed},
19        pubkey::{Pubkey, PUBKEY_BYTES},
20        stake::state::Lockup,
21    },
22    spl_pod::primitives::{PodU32, PodU64},
23    spl_token_2022::{
24        extension::{BaseStateWithExtensions, ExtensionType, StateWithExtensions},
25        state::{Account, AccountState, Mint},
26    },
27    std::{borrow::Borrow, convert::TryFrom, fmt, matches},
28};
29
30/// Enum representing the account type managed by the program
31#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
32pub enum AccountType {
33    /// If the account has not been initialized, the enum will be 0
34    #[default]
35    Uninitialized,
36    /// Stake pool
37    StakePool,
38    /// Validator stake list
39    ValidatorList,
40}
41
42/// Initialized program details.
43#[repr(C)]
44#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
45pub struct StakePool {
46    /// Account type, must be `StakePool` currently
47    pub account_type: AccountType,
48
49    /// Manager authority, allows for updating the staker, manager, and fee
50    /// account
51    pub manager: Pubkey,
52
53    /// Staker authority, allows for adding and removing validators, and
54    /// managing stake distribution
55    pub staker: Pubkey,
56
57    /// Stake deposit authority
58    ///
59    /// If a depositor pubkey is specified on initialization, then deposits must
60    /// be signed by this authority. If no deposit authority is specified,
61    /// then the stake pool will default to the result of:
62    /// `Pubkey::find_program_address(
63    ///     &[&stake_pool_address.as_ref(), b"deposit"],
64    ///     program_id,
65    /// )`
66    pub stake_deposit_authority: Pubkey,
67
68    /// Stake withdrawal authority bump seed
69    /// for `create_program_address(&[state::StakePool account, "withdrawal"])`
70    pub stake_withdraw_bump_seed: u8,
71
72    /// Validator stake list storage account
73    pub validator_list: Pubkey,
74
75    /// Reserve stake account, holds deactivated stake
76    pub reserve_stake: Pubkey,
77
78    /// Pool Mint
79    pub pool_mint: Pubkey,
80
81    /// Manager fee account
82    pub manager_fee_account: Pubkey,
83
84    /// Pool token program id
85    pub token_program_id: Pubkey,
86
87    /// Total stake under management.
88    /// Note that if `last_update_epoch` does not match the current epoch then
89    /// this field may not be accurate
90    pub total_lamports: u64,
91
92    /// Total supply of pool tokens (should always match the supply in the Pool
93    /// Mint)
94    pub pool_token_supply: u64,
95
96    /// Last epoch the `total_lamports` field was updated
97    pub last_update_epoch: u64,
98
99    /// Lockup that all stakes in the pool must have
100    pub lockup: Lockup,
101
102    /// Fee taken as a proportion of rewards each epoch
103    pub epoch_fee: Fee,
104
105    /// Fee for next epoch
106    pub next_epoch_fee: FutureEpoch<Fee>,
107
108    /// Preferred deposit validator vote account pubkey
109    pub preferred_deposit_validator_vote_address: Option<Pubkey>,
110
111    /// Preferred withdraw validator vote account pubkey
112    pub preferred_withdraw_validator_vote_address: Option<Pubkey>,
113
114    /// Fee assessed on stake deposits
115    pub stake_deposit_fee: Fee,
116
117    /// Fee assessed on withdrawals
118    pub stake_withdrawal_fee: Fee,
119
120    /// Future stake withdrawal fee, to be set for the following epoch
121    pub next_stake_withdrawal_fee: FutureEpoch<Fee>,
122
123    /// Fees paid out to referrers on referred stake deposits.
124    /// Expressed as a percentage (0 - 100) of deposit fees.
125    /// i.e. `stake_deposit_fee`% of stake deposited is collected as deposit
126    /// fees for every deposit and `stake_referral_fee`% of the collected
127    /// stake deposit fees is paid out to the referrer
128    pub stake_referral_fee: u8,
129
130    /// Toggles whether the `DepositSol` instruction requires a signature from
131    /// this `sol_deposit_authority`
132    pub sol_deposit_authority: Option<Pubkey>,
133
134    /// Fee assessed on SOL deposits
135    pub sol_deposit_fee: Fee,
136
137    /// Fees paid out to referrers on referred SOL deposits.
138    /// Expressed as a percentage (0 - 100) of SOL deposit fees.
139    /// i.e. `sol_deposit_fee`% of SOL deposited is collected as deposit fees
140    /// for every deposit and `sol_referral_fee`% of the collected SOL
141    /// deposit fees is paid out to the referrer
142    pub sol_referral_fee: u8,
143
144    /// Toggles whether the `WithdrawSol` instruction requires a signature from
145    /// the `deposit_authority`
146    pub sol_withdraw_authority: Option<Pubkey>,
147
148    /// Fee assessed on SOL withdrawals
149    pub sol_withdrawal_fee: Fee,
150
151    /// Future SOL withdrawal fee, to be set for the following epoch
152    pub next_sol_withdrawal_fee: FutureEpoch<Fee>,
153
154    /// Last epoch's total pool tokens, used only for APR estimation
155    pub last_epoch_pool_token_supply: u64,
156
157    /// Last epoch's total lamports, used only for APR estimation
158    pub last_epoch_total_lamports: u64,
159}
160impl StakePool {
161    /// calculate the pool tokens that should be minted for a deposit of
162    /// `stake_lamports`
163    #[inline]
164    pub fn calc_pool_tokens_for_deposit(&self, stake_lamports: u64) -> Option<u64> {
165        if self.total_lamports == 0 || self.pool_token_supply == 0 {
166            return Some(stake_lamports);
167        }
168        u64::try_from(
169            (stake_lamports as u128)
170                .checked_mul(self.pool_token_supply as u128)?
171                .checked_div(self.total_lamports as u128)?,
172        )
173        .ok()
174    }
175
176    /// calculate lamports amount on withdrawal
177    #[inline]
178    pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
179        // `checked_div` returns `None` for a 0 quotient result, but in this
180        // case, a return of 0 is valid for small amounts of pool tokens. So
181        // we check for that separately
182        let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
183        let denominator = self.pool_token_supply as u128;
184        if numerator < denominator || denominator == 0 {
185            Some(0)
186        } else {
187            u64::try_from(numerator.checked_div(denominator)?).ok()
188        }
189    }
190
191    /// calculate pool tokens to be deducted as withdrawal fees
192    #[inline]
193    pub fn calc_pool_tokens_stake_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
194        u64::try_from(self.stake_withdrawal_fee.apply(pool_tokens)?).ok()
195    }
196
197    /// calculate pool tokens to be deducted as withdrawal fees
198    #[inline]
199    pub fn calc_pool_tokens_sol_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
200        u64::try_from(self.sol_withdrawal_fee.apply(pool_tokens)?).ok()
201    }
202
203    /// calculate pool tokens to be deducted as stake deposit fees
204    #[inline]
205    pub fn calc_pool_tokens_stake_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
206        u64::try_from(self.stake_deposit_fee.apply(pool_tokens_minted)?).ok()
207    }
208
209    /// calculate pool tokens to be deducted from deposit fees as referral fees
210    #[inline]
211    pub fn calc_pool_tokens_stake_referral_fee(&self, stake_deposit_fee: u64) -> Option<u64> {
212        u64::try_from(
213            (stake_deposit_fee as u128)
214                .checked_mul(self.stake_referral_fee as u128)?
215                .checked_div(100u128)?,
216        )
217        .ok()
218    }
219
220    /// calculate pool tokens to be deducted as SOL deposit fees
221    #[inline]
222    pub fn calc_pool_tokens_sol_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
223        u64::try_from(self.sol_deposit_fee.apply(pool_tokens_minted)?).ok()
224    }
225
226    /// calculate pool tokens to be deducted from SOL deposit fees as referral
227    /// fees
228    #[inline]
229    pub fn calc_pool_tokens_sol_referral_fee(&self, sol_deposit_fee: u64) -> Option<u64> {
230        u64::try_from(
231            (sol_deposit_fee as u128)
232                .checked_mul(self.sol_referral_fee as u128)?
233                .checked_div(100u128)?,
234        )
235        .ok()
236    }
237
238    /// Calculate the fee in pool tokens that goes to the manager
239    ///
240    /// This function assumes that `reward_lamports` has not already been added
241    /// to the stake pool's `total_lamports`
242    #[inline]
243    pub fn calc_epoch_fee_amount(&self, reward_lamports: u64) -> Option<u64> {
244        if reward_lamports == 0 {
245            return Some(0);
246        }
247        let total_lamports = (self.total_lamports as u128).checked_add(reward_lamports as u128)?;
248        let fee_lamports = self.epoch_fee.apply(reward_lamports)?;
249        if total_lamports == fee_lamports || self.pool_token_supply == 0 {
250            Some(reward_lamports)
251        } else {
252            u64::try_from(
253                (self.pool_token_supply as u128)
254                    .checked_mul(fee_lamports)?
255                    .checked_div(total_lamports.checked_sub(fee_lamports)?)?,
256            )
257            .ok()
258        }
259    }
260
261    /// Get the current value of pool tokens, rounded up
262    #[inline]
263    pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
264        self.total_lamports
265            .checked_add(self.pool_token_supply)?
266            .checked_sub(1)?
267            .checked_div(self.pool_token_supply)
268    }
269
270    /// Checks that the withdraw or deposit authority is valid
271    fn check_program_derived_authority(
272        authority_address: &Pubkey,
273        program_id: &Pubkey,
274        stake_pool_address: &Pubkey,
275        authority_seed: &[u8],
276        bump_seed: u8,
277    ) -> Result<(), ProgramError> {
278        let expected_address = Pubkey::create_program_address(
279            &[stake_pool_address.as_ref(), authority_seed, &[bump_seed]],
280            program_id,
281        )?;
282
283        if *authority_address == expected_address {
284            Ok(())
285        } else {
286            msg!(
287                "Incorrect authority provided, expected {}, received {}",
288                expected_address,
289                authority_address
290            );
291            Err(StakePoolError::InvalidProgramAddress.into())
292        }
293    }
294
295    /// Check if the manager fee info is a valid token program account
296    /// capable of receiving tokens from the mint.
297    pub(crate) fn check_manager_fee_info(
298        &self,
299        manager_fee_info: &AccountInfo,
300    ) -> Result<(), ProgramError> {
301        let account_data = manager_fee_info.try_borrow_data()?;
302        let token_account = StateWithExtensions::<Account>::unpack(&account_data)?;
303        if manager_fee_info.owner != &self.token_program_id
304            || token_account.base.state != AccountState::Initialized
305            || token_account.base.mint != self.pool_mint
306        {
307            msg!("Manager fee account is not owned by token program, is not initialized, or does not match stake pool's mint");
308            return Err(StakePoolError::InvalidFeeAccount.into());
309        }
310        let extensions = token_account.get_extension_types()?;
311        if extensions
312            .iter()
313            .any(|x| !is_extension_supported_for_fee_account(x))
314        {
315            return Err(StakePoolError::UnsupportedFeeAccountExtension.into());
316        }
317        Ok(())
318    }
319
320    /// Checks that the withdraw authority is valid
321    #[inline]
322    pub(crate) fn check_authority_withdraw(
323        &self,
324        withdraw_authority: &Pubkey,
325        program_id: &Pubkey,
326        stake_pool_address: &Pubkey,
327    ) -> Result<(), ProgramError> {
328        Self::check_program_derived_authority(
329            withdraw_authority,
330            program_id,
331            stake_pool_address,
332            crate::AUTHORITY_WITHDRAW,
333            self.stake_withdraw_bump_seed,
334        )
335    }
336    /// Checks that the deposit authority is valid
337    #[inline]
338    pub(crate) fn check_stake_deposit_authority(
339        &self,
340        stake_deposit_authority: &Pubkey,
341    ) -> Result<(), ProgramError> {
342        if self.stake_deposit_authority == *stake_deposit_authority {
343            Ok(())
344        } else {
345            Err(StakePoolError::InvalidStakeDepositAuthority.into())
346        }
347    }
348
349    /// Checks that the deposit authority is valid
350    /// Does nothing if `sol_deposit_authority` is currently not set
351    #[inline]
352    pub(crate) fn check_sol_deposit_authority(
353        &self,
354        maybe_sol_deposit_authority: Result<&AccountInfo, ProgramError>,
355    ) -> Result<(), ProgramError> {
356        if let Some(auth) = self.sol_deposit_authority {
357            let sol_deposit_authority = maybe_sol_deposit_authority?;
358            if auth != *sol_deposit_authority.key {
359                msg!("Expected {}, received {}", auth, sol_deposit_authority.key);
360                return Err(StakePoolError::InvalidSolDepositAuthority.into());
361            }
362            if !sol_deposit_authority.is_signer {
363                msg!("SOL Deposit authority signature missing");
364                return Err(StakePoolError::SignatureMissing.into());
365            }
366        }
367        Ok(())
368    }
369
370    /// Checks that the sol withdraw authority is valid
371    /// Does nothing if `sol_withdraw_authority` is currently not set
372    #[inline]
373    pub(crate) fn check_sol_withdraw_authority(
374        &self,
375        maybe_sol_withdraw_authority: Result<&AccountInfo, ProgramError>,
376    ) -> Result<(), ProgramError> {
377        if let Some(auth) = self.sol_withdraw_authority {
378            let sol_withdraw_authority = maybe_sol_withdraw_authority?;
379            if auth != *sol_withdraw_authority.key {
380                return Err(StakePoolError::InvalidSolWithdrawAuthority.into());
381            }
382            if !sol_withdraw_authority.is_signer {
383                msg!("SOL withdraw authority signature missing");
384                return Err(StakePoolError::SignatureMissing.into());
385            }
386        }
387        Ok(())
388    }
389
390    /// Check mint is correct
391    #[inline]
392    pub(crate) fn check_mint(&self, mint_info: &AccountInfo) -> Result<u8, ProgramError> {
393        if *mint_info.key != self.pool_mint {
394            Err(StakePoolError::WrongPoolMint.into())
395        } else {
396            let mint_data = mint_info.try_borrow_data()?;
397            let mint = StateWithExtensions::<Mint>::unpack(&mint_data)?;
398            Ok(mint.base.decimals)
399        }
400    }
401
402    /// Check manager validity and signature
403    pub(crate) fn check_manager(&self, manager_info: &AccountInfo) -> Result<(), ProgramError> {
404        if *manager_info.key != self.manager {
405            msg!(
406                "Incorrect manager provided, expected {}, received {}",
407                self.manager,
408                manager_info.key
409            );
410            return Err(StakePoolError::WrongManager.into());
411        }
412        if !manager_info.is_signer {
413            msg!("Manager signature missing");
414            return Err(StakePoolError::SignatureMissing.into());
415        }
416        Ok(())
417    }
418
419    /// Check staker validity and signature
420    pub(crate) fn check_staker(&self, staker_info: &AccountInfo) -> Result<(), ProgramError> {
421        if *staker_info.key != self.staker {
422            msg!(
423                "Incorrect staker provided, expected {}, received {}",
424                self.staker,
425                staker_info.key
426            );
427            return Err(StakePoolError::WrongStaker.into());
428        }
429        if !staker_info.is_signer {
430            msg!("Staker signature missing");
431            return Err(StakePoolError::SignatureMissing.into());
432        }
433        Ok(())
434    }
435
436    /// Check the validator list is valid
437    pub fn check_validator_list(
438        &self,
439        validator_list_info: &AccountInfo,
440    ) -> Result<(), ProgramError> {
441        if *validator_list_info.key != self.validator_list {
442            msg!(
443                "Invalid validator list provided, expected {}, received {}",
444                self.validator_list,
445                validator_list_info.key
446            );
447            Err(StakePoolError::InvalidValidatorStakeList.into())
448        } else {
449            Ok(())
450        }
451    }
452
453    /// Check the reserve stake is valid
454    pub fn check_reserve_stake(
455        &self,
456        reserve_stake_info: &AccountInfo,
457    ) -> Result<(), ProgramError> {
458        if *reserve_stake_info.key != self.reserve_stake {
459            msg!(
460                "Invalid reserve stake provided, expected {}, received {}",
461                self.reserve_stake,
462                reserve_stake_info.key
463            );
464            Err(StakePoolError::InvalidProgramAddress.into())
465        } else {
466            Ok(())
467        }
468    }
469
470    /// Check if `StakePool` is actually initialized as a stake pool
471    pub fn is_valid(&self) -> bool {
472        self.account_type == AccountType::StakePool
473    }
474
475    /// Check if `StakePool` is currently uninitialized
476    pub fn is_uninitialized(&self) -> bool {
477        self.account_type == AccountType::Uninitialized
478    }
479
480    /// Updates one of the `StakePool`'s fees.
481    pub fn update_fee(&mut self, fee: &FeeType) -> Result<(), StakePoolError> {
482        match fee {
483            FeeType::SolReferral(new_fee) => self.sol_referral_fee = *new_fee,
484            FeeType::StakeReferral(new_fee) => self.stake_referral_fee = *new_fee,
485            FeeType::Epoch(new_fee) => self.next_epoch_fee = FutureEpoch::new(*new_fee),
486            FeeType::StakeWithdrawal(new_fee) => {
487                new_fee.check_withdrawal(&self.stake_withdrawal_fee)?;
488                self.next_stake_withdrawal_fee = FutureEpoch::new(*new_fee)
489            }
490            FeeType::SolWithdrawal(new_fee) => {
491                new_fee.check_withdrawal(&self.sol_withdrawal_fee)?;
492                self.next_sol_withdrawal_fee = FutureEpoch::new(*new_fee)
493            }
494            FeeType::SolDeposit(new_fee) => self.sol_deposit_fee = *new_fee,
495            FeeType::StakeDeposit(new_fee) => self.stake_deposit_fee = *new_fee,
496        };
497        Ok(())
498    }
499}
500
501/// Checks if the given extension is supported for the stake pool mint
502pub fn is_extension_supported_for_mint(extension_type: &ExtensionType) -> bool {
503    const SUPPORTED_EXTENSIONS: [ExtensionType; 8] = [
504        ExtensionType::Uninitialized,
505        ExtensionType::TransferFeeConfig,
506        ExtensionType::ConfidentialTransferMint,
507        ExtensionType::ConfidentialTransferFeeConfig,
508        ExtensionType::DefaultAccountState, // ok, but a freeze authority is not
509        ExtensionType::InterestBearingConfig,
510        ExtensionType::MetadataPointer,
511        ExtensionType::TokenMetadata,
512    ];
513    if !SUPPORTED_EXTENSIONS.contains(extension_type) {
514        msg!(
515            "Stake pool mint account cannot have the {:?} extension",
516            extension_type
517        );
518        false
519    } else {
520        true
521    }
522}
523
524/// Checks if the given extension is supported for the stake pool's fee account
525pub fn is_extension_supported_for_fee_account(extension_type: &ExtensionType) -> bool {
526    // Note: this does not include the `ConfidentialTransferAccount` extension
527    // because it is possible to block non-confidential transfers with the
528    // extension enabled.
529    const SUPPORTED_EXTENSIONS: [ExtensionType; 4] = [
530        ExtensionType::Uninitialized,
531        ExtensionType::TransferFeeAmount,
532        ExtensionType::ImmutableOwner,
533        ExtensionType::CpiGuard,
534    ];
535    if !SUPPORTED_EXTENSIONS.contains(extension_type) {
536        msg!("Fee account cannot have the {:?} extension", extension_type);
537        false
538    } else {
539        true
540    }
541}
542
543/// Storage list for all validator stake accounts in the pool.
544#[repr(C)]
545#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
546pub struct ValidatorList {
547    /// Data outside of the validator list, separated out for cheaper
548    /// deserialization
549    pub header: ValidatorListHeader,
550
551    /// List of stake info for each validator in the pool
552    pub validators: Vec<ValidatorStakeInfo>,
553}
554
555/// Helper type to deserialize just the start of a `ValidatorList`
556#[repr(C)]
557#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
558pub struct ValidatorListHeader {
559    /// Account type, must be `ValidatorList` currently
560    pub account_type: AccountType,
561
562    /// Maximum allowable number of validators
563    pub max_validators: u32,
564}
565
566/// Status of the stake account in the validator list, for accounting
567#[derive(
568    ToPrimitive,
569    FromPrimitive,
570    Copy,
571    Clone,
572    Debug,
573    Default,
574    PartialEq,
575    BorshDeserialize,
576    BorshSerialize,
577    BorshSchema,
578)]
579pub enum StakeStatus {
580    /// Stake account is active, there may be a transient stake as well
581    #[default]
582    Active,
583    /// Only transient stake account exists, when a transient stake is
584    /// deactivating during validator removal
585    DeactivatingTransient,
586    /// No more validator stake accounts exist, entry ready for removal during
587    /// `UpdateStakePoolBalance`
588    ReadyForRemoval,
589    /// Only the validator stake account is deactivating, no transient stake
590    /// account exists
591    DeactivatingValidator,
592    /// Both the transient and validator stake account are deactivating, when
593    /// a validator is removed with a transient stake active
594    DeactivatingAll,
595}
596
597/// Wrapper struct that can be `Pod`, containing a byte that *should* be a valid
598/// `StakeStatus` underneath.
599#[repr(transparent)]
600#[derive(
601    Clone,
602    Copy,
603    Debug,
604    Default,
605    PartialEq,
606    Pod,
607    Zeroable,
608    BorshDeserialize,
609    BorshSerialize,
610    BorshSchema,
611)]
612pub struct PodStakeStatus(u8);
613impl PodStakeStatus {
614    /// Downgrade the status towards ready for removal by removing the validator
615    /// stake
616    pub fn remove_validator_stake(&mut self) -> Result<(), ProgramError> {
617        let status = StakeStatus::try_from(*self)?;
618        let new_self = match status {
619            StakeStatus::Active
620            | StakeStatus::DeactivatingTransient
621            | StakeStatus::ReadyForRemoval => status,
622            StakeStatus::DeactivatingAll => StakeStatus::DeactivatingTransient,
623            StakeStatus::DeactivatingValidator => StakeStatus::ReadyForRemoval,
624        };
625        *self = new_self.into();
626        Ok(())
627    }
628    /// Downgrade the status towards ready for removal by removing the transient
629    /// stake
630    pub fn remove_transient_stake(&mut self) -> Result<(), ProgramError> {
631        let status = StakeStatus::try_from(*self)?;
632        let new_self = match status {
633            StakeStatus::Active
634            | StakeStatus::DeactivatingValidator
635            | StakeStatus::ReadyForRemoval => status,
636            StakeStatus::DeactivatingAll => StakeStatus::DeactivatingValidator,
637            StakeStatus::DeactivatingTransient => StakeStatus::ReadyForRemoval,
638        };
639        *self = new_self.into();
640        Ok(())
641    }
642}
643impl TryFrom<PodStakeStatus> for StakeStatus {
644    type Error = ProgramError;
645    fn try_from(pod: PodStakeStatus) -> Result<Self, Self::Error> {
646        FromPrimitive::from_u8(pod.0).ok_or(ProgramError::InvalidAccountData)
647    }
648}
649impl From<StakeStatus> for PodStakeStatus {
650    fn from(status: StakeStatus) -> Self {
651        // unwrap is safe here because the variants of `StakeStatus` fit very
652        // comfortably within a `u8`
653        PodStakeStatus(status.to_u8().unwrap())
654    }
655}
656
/// Withdrawal type, figured out during `process_withdraw_stake`
#[derive(Debug, PartialEq)]
pub(crate) enum StakeWithdrawSource {
    /// Some of an active stake account, but not all
    Active,
    /// Some of a transient stake account
    Transient,
    /// Take a whole validator stake account
    ValidatorRemoval,
}
667
668/// Information about a validator in the pool
669///
670/// NOTE: ORDER IS VERY IMPORTANT HERE, PLEASE DO NOT RE-ORDER THE FIELDS UNLESS
671/// THERE'S AN EXTREMELY GOOD REASON.
672///
673/// To save on BPF instructions, the serialized bytes are reinterpreted with a
674/// `bytemuck` transmute, which means that this structure cannot have any
675/// undeclared alignment-padding in its representation.
676#[repr(C)]
677#[derive(
678    Clone,
679    Copy,
680    Debug,
681    Default,
682    PartialEq,
683    Pod,
684    Zeroable,
685    BorshDeserialize,
686    BorshSerialize,
687    BorshSchema,
688)]
689pub struct ValidatorStakeInfo {
690    /// Amount of lamports on the validator stake account, including rent
691    ///
692    /// Note that if `last_update_epoch` does not match the current epoch then
693    /// this field may not be accurate
694    pub active_stake_lamports: PodU64,
695
696    /// Amount of transient stake delegated to this validator
697    ///
698    /// Note that if `last_update_epoch` does not match the current epoch then
699    /// this field may not be accurate
700    pub transient_stake_lamports: PodU64,
701
702    /// Last epoch the active and transient stake lamports fields were updated
703    pub last_update_epoch: PodU64,
704
705    /// Transient account seed suffix, used to derive the transient stake
706    /// account address
707    pub transient_seed_suffix: PodU64,
708
709    /// Unused space, initially meant to specify the end of seed suffixes
710    pub unused: PodU32,
711
712    /// Validator account seed suffix
713    pub validator_seed_suffix: PodU32, // really `Option<NonZeroU32>` so 0 is `None`
714
715    /// Status of the validator stake account
716    pub status: PodStakeStatus,
717
718    /// Validator vote account address
719    pub vote_account_address: Pubkey,
720}
721
722impl ValidatorStakeInfo {
723    /// Get the total lamports on this validator (active and transient)
724    pub fn stake_lamports(&self) -> Result<u64, StakePoolError> {
725        u64::from(self.active_stake_lamports)
726            .checked_add(self.transient_stake_lamports.into())
727            .ok_or(StakePoolError::CalculationFailure)
728    }
729
730    /// Performs a very cheap comparison, for checking if this validator stake
731    /// info matches the vote account address
732    pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
733        sol_memcmp(
734            &data[41..41_usize.saturating_add(PUBKEY_BYTES)],
735            vote_address.as_ref(),
736            PUBKEY_BYTES,
737        ) == 0
738    }
739
740    /// Performs a comparison, used to check if this validator stake
741    /// info has more active lamports than some limit
742    pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
743        // without this unwrap, compute usage goes up significantly
744        u64::try_from_slice(&data[0..8]).unwrap() > *lamports
745    }
746
747    /// Performs a comparison, used to check if this validator stake
748    /// info has more transient lamports than some limit
749    pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
750        // without this unwrap, compute usage goes up significantly
751        u64::try_from_slice(&data[8..16]).unwrap() > *lamports
752    }
753
754    /// Check that the validator stake info is totally removed
755    pub fn is_removed(data: &[u8]) -> bool {
756        FromPrimitive::from_u8(data[40]) == Some(StakeStatus::ReadyForRemoval)
757            && data[0..16] == [0; 16] // active and transient stake lamports are 0
758    }
759
760    /// Check that the validator stake info is active
761    pub fn is_active(data: &[u8]) -> bool {
762        FromPrimitive::from_u8(data[40]) == Some(StakeStatus::Active)
763    }
764}
765
766impl Sealed for ValidatorStakeInfo {}
767
768impl Pack for ValidatorStakeInfo {
769    const LEN: usize = 73;
770    fn pack_into_slice(&self, data: &mut [u8]) {
771        // Removing this unwrap would require changing from `Pack` to some other
772        // trait or `bytemuck`, so it stays in for now
773        borsh::to_writer(data, self).unwrap();
774    }
775    fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
776        let unpacked = Self::try_from_slice(src)?;
777        Ok(unpacked)
778    }
779}
780
781impl ValidatorList {
782    /// Create an empty instance containing space for `max_validators` and
783    /// preferred validator keys
784    pub fn new(max_validators: u32) -> Self {
785        Self {
786            header: ValidatorListHeader {
787                account_type: AccountType::ValidatorList,
788                max_validators,
789            },
790            validators: vec![ValidatorStakeInfo::default(); max_validators as usize],
791        }
792    }
793
794    /// Calculate the number of validator entries that fit in the provided
795    /// length
796    pub fn calculate_max_validators(buffer_length: usize) -> usize {
797        let header_size = ValidatorListHeader::LEN.saturating_add(4);
798        buffer_length
799            .saturating_sub(header_size)
800            .saturating_div(ValidatorStakeInfo::LEN)
801    }
802
803    /// Check if contains validator with particular pubkey
804    pub fn contains(&self, vote_account_address: &Pubkey) -> bool {
805        self.validators
806            .iter()
807            .any(|x| x.vote_account_address == *vote_account_address)
808    }
809
810    /// Check if contains validator with particular pubkey
811    pub fn find_mut(&mut self, vote_account_address: &Pubkey) -> Option<&mut ValidatorStakeInfo> {
812        self.validators
813            .iter_mut()
814            .find(|x| x.vote_account_address == *vote_account_address)
815    }
816    /// Check if contains validator with particular pubkey
817    pub fn find(&self, vote_account_address: &Pubkey) -> Option<&ValidatorStakeInfo> {
818        self.validators
819            .iter()
820            .find(|x| x.vote_account_address == *vote_account_address)
821    }
822
823    /// Check if the list has any active stake
824    pub fn has_active_stake(&self) -> bool {
825        self.validators
826            .iter()
827            .any(|x| u64::from(x.active_stake_lamports) > 0)
828    }
829}
830
831impl ValidatorListHeader {
832    const LEN: usize = 1 + 4;
833
834    /// Check if validator stake list is actually initialized as a validator
835    /// stake list
836    pub fn is_valid(&self) -> bool {
837        self.account_type == AccountType::ValidatorList
838    }
839
840    /// Check if the validator stake list is uninitialized
841    pub fn is_uninitialized(&self) -> bool {
842        self.account_type == AccountType::Uninitialized
843    }
844
845    /// Extracts a slice of `ValidatorStakeInfo` types from the vec part
846    /// of the `ValidatorList`
847    pub fn deserialize_mut_slice<'a>(
848        big_vec: &'a mut BigVec,
849        skip: usize,
850        len: usize,
851    ) -> Result<&'a mut [ValidatorStakeInfo], ProgramError> {
852        big_vec.deserialize_mut_slice::<ValidatorStakeInfo>(skip, len)
853    }
854
855    /// Extracts the validator list into its header and internal `BigVec`
856    pub fn deserialize_vec(data: &mut [u8]) -> Result<(Self, BigVec<'_>), ProgramError> {
857        let mut data_mut = data.borrow();
858        let header = ValidatorListHeader::deserialize(&mut data_mut)?;
859        let length = get_instance_packed_len(&header)?;
860
861        let big_vec = BigVec {
862            data: &mut data[length..],
863        };
864        Ok((header, big_vec))
865    }
866}
867
/// Wrapper type that "counts down" epochs, which is Borsh-compatible with the
/// native `Option`
///
/// A value starts as `Two` (see `FutureEpoch::new`). Each `update_epoch`
/// moves it one step toward `None`, and `get` only exposes it while in `One`.
///
/// Borsh encodes the variant by its declaration index, so the variant order
/// below is part of the on-chain format and must not change.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub enum FutureEpoch<T> {
    /// Nothing is set
    #[default]
    None,
    /// Value is ready after the next epoch boundary
    One(T),
    /// Value is ready after two epoch boundaries
    Two(T),
}
881impl<T> FutureEpoch<T> {
882    /// Create a new value to be unlocked in a two epochs
883    pub fn new(value: T) -> Self {
884        Self::Two(value)
885    }
886}
887impl<T: Clone> FutureEpoch<T> {
888    /// Update the epoch, to be done after `get`ting the underlying value
889    pub fn update_epoch(&mut self) {
890        match self {
891            Self::None => {}
892            Self::One(_) => {
893                // The value has waited its last epoch
894                *self = Self::None;
895            }
896            // The value still has to wait one more epoch after this
897            Self::Two(v) => {
898                *self = Self::One(v.clone());
899            }
900        }
901    }
902
903    /// Get the value if it's ready, which is only at `One` epoch remaining
904    pub fn get(&self) -> Option<&T> {
905        match self {
906            Self::None | Self::Two(_) => None,
907            Self::One(v) => Some(v),
908        }
909    }
910}
911impl<T> From<FutureEpoch<T>> for Option<T> {
912    fn from(v: FutureEpoch<T>) -> Option<T> {
913        match v {
914            FutureEpoch::None => None,
915            FutureEpoch::One(inner) | FutureEpoch::Two(inner) => Some(inner),
916        }
917    }
918}
919
/// Fee rate as a ratio `numerator / denominator`.
///
/// This type carries every fee the pool stores (epoch, deposit and withdrawal
/// fees; see `FeeType`). The epoch fee is minted on `UpdateStakePoolBalance`
/// as a proportion of the rewards.
/// If either the numerator or the denominator is 0, the fee is considered to be
/// 0. `Fee::apply` then returns 0, and `Display` prints "none".
///
/// Borsh serializes fields in declaration order, so do not reorder them.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub struct Fee {
    /// denominator of the fee ratio
    pub denominator: u64,
    /// numerator of the fee ratio
    pub numerator: u64,
}
932
933impl Fee {
934    /// Applies the Fee's rates to a given amount, `amt`
935    /// returning the amount to be subtracted from it as fees
936    /// (0 if denominator is 0 or amt is 0),
937    /// or None if overflow occurs
938    #[inline]
939    pub fn apply(&self, amt: u64) -> Option<u128> {
940        if self.denominator == 0 {
941            return Some(0);
942        }
943        let numerator = (amt as u128).checked_mul(self.numerator as u128)?;
944        // ceiling the calculation by adding (denominator - 1) to the numerator
945        let denominator = self.denominator as u128;
946        numerator
947            .checked_add(denominator)?
948            .checked_sub(1)?
949            .checked_div(denominator)
950    }
951
952    /// Withdrawal fees have some additional restrictions, this function checks
953    /// if those are met, returning an error if not.
954    pub fn check_withdrawal(&self, old_withdrawal_fee: &Fee) -> Result<(), StakePoolError> {
955        // If the previous withdrawal fee was 0, we allow the fee to be set to a
956        // maximum of (WITHDRAWAL_BASELINE_FEE * MAX_WITHDRAWAL_FEE_INCREASE)
957        let (old_num, old_denom) =
958            if old_withdrawal_fee.denominator == 0 || old_withdrawal_fee.numerator == 0 {
959                (
960                    WITHDRAWAL_BASELINE_FEE.numerator,
961                    WITHDRAWAL_BASELINE_FEE.denominator,
962                )
963            } else {
964                (old_withdrawal_fee.numerator, old_withdrawal_fee.denominator)
965            };
966
967        // Check that new_fee / old_fee <= MAX_WITHDRAWAL_FEE_INCREASE
968        // Program fails if provided numerator or denominator is too large, resulting in
969        // overflow
970        if (old_num as u128)
971            .checked_mul(self.denominator as u128)
972            .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.numerator as u128))
973            .ok_or(StakePoolError::CalculationFailure)?
974            < (self.numerator as u128)
975                .checked_mul(old_denom as u128)
976                .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.denominator as u128))
977                .ok_or(StakePoolError::CalculationFailure)?
978        {
979            msg!(
980                "Fee increase exceeds maximum allowed, proposed increase factor ({} / {})",
981                self.numerator.saturating_mul(old_denom),
982                old_num.saturating_mul(self.denominator),
983            );
984            return Err(StakePoolError::FeeIncreaseTooHigh);
985        }
986        Ok(())
987    }
988}
989
990impl fmt::Display for Fee {
991    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
992        if self.numerator > 0 && self.denominator > 0 {
993            write!(f, "{}/{}", self.numerator, self.denominator)
994        } else {
995            write!(f, "none")
996        }
997    }
998}
999
/// The type of fees that can be set on the stake pool
///
/// The referral variants carry a percentage (`FeeType::check_too_high`
/// rejects values above 100). All other variants carry a `Fee` ratio.
///
/// Borsh encodes the variant by its declaration index, so the variant order
/// below is part of the instruction format and must not change.
#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum FeeType {
    /// Referral fees for SOL deposits, as a percentage (at most 100)
    SolReferral(u8),
    /// Referral fees for stake deposits, as a percentage (at most 100)
    StakeReferral(u8),
    /// Management fee paid per epoch
    Epoch(Fee),
    /// Stake withdrawal fee
    StakeWithdrawal(Fee),
    /// Deposit fee for SOL deposits
    SolDeposit(Fee),
    /// Deposit fee for stake deposits
    StakeDeposit(Fee),
    /// SOL withdrawal fee
    SolWithdrawal(Fee),
}
1018
1019impl FeeType {
1020    /// Checks if the provided fee is too high, returning an error if so
1021    pub fn check_too_high(&self) -> Result<(), StakePoolError> {
1022        let too_high = match self {
1023            Self::SolReferral(pct) => *pct > 100u8,
1024            Self::StakeReferral(pct) => *pct > 100u8,
1025            Self::Epoch(fee) => fee.numerator > fee.denominator,
1026            Self::StakeWithdrawal(fee) => fee.numerator > fee.denominator,
1027            Self::SolWithdrawal(fee) => fee.numerator > fee.denominator,
1028            Self::SolDeposit(fee) => fee.numerator > fee.denominator,
1029            Self::StakeDeposit(fee) => fee.numerator > fee.denominator,
1030        };
1031        if too_high {
1032            msg!("Fee greater than 100%: {:?}", self);
1033            return Err(StakePoolError::FeeTooHigh);
1034        }
1035        Ok(())
1036    }
1037
1038    /// Returns if the contained fee can only be updated earliest on the next
1039    /// epoch
1040    #[inline]
1041    pub fn can_only_change_next_epoch(&self) -> bool {
1042        matches!(
1043            self,
1044            Self::StakeWithdrawal(_) | Self::SolWithdrawal(_) | Self::Epoch(_)
1045        )
1046    }
1047}
1048
#[cfg(test)]
mod test {
    #![allow(clippy::arithmetic_side_effects)]
    use {
        super::*,
        proptest::prelude::*,
        solana_program::{
            borsh1::{get_packed_len, try_from_slice_unchecked},
            clock::{DEFAULT_SLOTS_PER_EPOCH, DEFAULT_S_PER_SLOT, SECONDS_PER_DAY},
            native_token::LAMPORTS_PER_SOL,
        },
    };

    /// Builds a list whose header is `Uninitialized`, with no entries.
    fn uninitialized_validator_list() -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::Uninitialized,
                max_validators: 0,
            },
            validators: vec![],
        }
    }

    /// Builds an initialized list with `max_validators` in the header and
    /// three fixed entries (statuses `Active`, `DeactivatingTransient`,
    /// `ReadyForRemoval`). The entry count does not depend on
    /// `max_validators`.
    fn test_validator_list(max_validators: u32) -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators,
            },
            validators: vec![
                ValidatorStakeInfo {
                    status: StakeStatus::Active.into(),
                    vote_account_address: Pubkey::new_from_array([1; 32]),
                    active_stake_lamports: u64::from_le_bytes([255; 8]).into(),
                    transient_stake_lamports: u64::from_le_bytes([128; 8]).into(),
                    last_update_epoch: u64::from_le_bytes([64; 8]).into(),
                    transient_seed_suffix: 0.into(),
                    unused: 0.into(),
                    validator_seed_suffix: 0.into(),
                },
                ValidatorStakeInfo {
                    status: StakeStatus::DeactivatingTransient.into(),
                    vote_account_address: Pubkey::new_from_array([2; 32]),
                    active_stake_lamports: 998877665544.into(),
                    transient_stake_lamports: 222222222.into(),
                    last_update_epoch: 11223445566.into(),
                    transient_seed_suffix: 0.into(),
                    unused: 0.into(),
                    validator_seed_suffix: 0.into(),
                },
                ValidatorStakeInfo {
                    status: StakeStatus::ReadyForRemoval.into(),
                    vote_account_address: Pubkey::new_from_array([3; 32]),
                    active_stake_lamports: 0.into(),
                    transient_stake_lamports: 0.into(),
                    last_update_epoch: 999999999999999.into(),
                    transient_seed_suffix: 0.into(),
                    unused: 0.into(),
                    validator_seed_suffix: 0.into(),
                },
            ],
        }
    }

    /// Round-trips several lists through Borsh, each written into a buffer
    /// sized for 10_000 entries.
    #[test]
    fn state_packing() {
        let max_validators = 10_000;
        let size = get_instance_packed_len(&ValidatorList::new(max_validators)).unwrap();
        let stake_list = uninitialized_validator_list();
        let mut byte_vec = vec![0u8; size];
        let bytes = byte_vec.as_mut_slice();
        borsh::to_writer(bytes, &stake_list).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);

        // Initialized, but with no entries
        let stake_list = ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators: 0,
            },
            validators: vec![],
        };
        let mut byte_vec = vec![0u8; size];
        let bytes = byte_vec.as_mut_slice();
        borsh::to_writer(bytes, &stake_list).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);

        // With several accounts
        let stake_list = test_validator_list(max_validators);
        let mut byte_vec = vec![0u8; size];
        let bytes = byte_vec.as_mut_slice();
        borsh::to_writer(bytes, &stake_list).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);
    }

    /// `has_active_stake` flips to false once every entry's active stake is 0.
    #[test]
    fn validator_list_active_stake() {
        let max_validators = 10_000;
        let mut validator_list = test_validator_list(max_validators);
        assert!(validator_list.has_active_stake());
        for validator in validator_list.validators.iter_mut() {
            validator.active_stake_lamports = 0.into();
        }
        assert!(!validator_list.has_active_stake());
    }

    /// Sub-slices borrowed from the serialized list match the originals, and
    /// out-of-range requests fail.
    #[test]
    fn validator_list_deserialize_mut_slice() {
        let max_validators = 10;
        let stake_list = test_validator_list(max_validators);
        let mut serialized = borsh::to_vec(&stake_list).unwrap();
        let (header, mut big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
        let list = ValidatorListHeader::deserialize_mut_slice(
            &mut big_vec,
            0,
            stake_list.validators.len(),
        )
        .unwrap();
        assert_eq!(header.account_type, AccountType::ValidatorList);
        assert_eq!(header.max_validators, max_validators);
        assert!(list
            .iter()
            .zip(stake_list.validators.iter())
            .all(|(a, b)| a == b));

        let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 2).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[1..].iter())
            .all(|(a, b)| a == b));
        let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 2, 1).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[2..].iter())
            .all(|(a, b)| a == b));
        let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 2).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[..2].iter())
            .all(|(a, b)| a == b));

        // Requests reaching past the three serialized entries are rejected
        assert_eq!(
            ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 3).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    /// Reading all entries through `BigVec::deserialize_slice` matches the
    /// original list.
    #[test]
    fn validator_list_iter() {
        let max_validators = 10;
        let stake_list = test_validator_list(max_validators);
        let mut serialized = borsh::to_vec(&stake_list).unwrap();
        let (_, big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
        for (a, b) in big_vec
            .deserialize_slice::<ValidatorStakeInfo>(0, big_vec.len() as usize)
            .unwrap()
            .iter()
            .zip(stake_list.validators.iter())
        {
            assert_eq!(a, b);
        }
    }

    proptest! {
        #[test]
        fn stake_list_size_calculation(test_amount in 0..=100_000_u32) {
            let validators = ValidatorList::new(test_amount);
            let size = get_instance_packed_len(&validators).unwrap();
            assert_eq!(ValidatorList::calculate_max_validators(size), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(1)), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(get_packed_len::<ValidatorStakeInfo>())), (test_amount + 1)as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_sub(1)), (test_amount.saturating_sub(1)) as usize);
        }
    }

    // Yields `(numerator, denominator)` with 1 <= denominator <= u16::MAX and
    // numerator <= denominator, i.e. a fee between 0% and 100%.
    prop_compose! {
        fn fee()(denominator in 1..=u16::MAX)(
            denominator in Just(denominator),
            numerator in 0..=denominator,
        ) -> (u64, u64) {
            (numerator as u64, denominator as u64)
        }
    }

    // Yields `(stake, rewards)`. `rewards` is kept strictly below the drawn
    // total, so the remaining stake is always at least 1. `fee_calculation`
    // divides by it, and its assertions do not hold for an empty pool. The
    // zero-stake case is covered separately by `divide_by_zero_fee`.
    prop_compose! {
        fn total_stake_and_rewards()(total_lamports in 1..u64::MAX)(
            total_lamports in Just(total_lamports),
            rewards in 0..total_lamports,
        ) -> (u64, u64) {
            (total_lamports - rewards, rewards)
        }
    }

    #[test]
    fn specific_fee_calculation() {
        // 10% of 10 SOL in rewards should be 1 SOL in fees
        let epoch_fee = Fee {
            numerator: 1,
            denominator: 10,
        };
        let mut stake_pool = StakePool {
            total_lamports: 100 * LAMPORTS_PER_SOL,
            pool_token_supply: 100 * LAMPORTS_PER_SOL,
            epoch_fee,
            ..StakePool::default()
        };
        let reward_lamports = 10 * LAMPORTS_PER_SOL;
        let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

        stake_pool.total_lamports += reward_lamports;
        stake_pool.pool_token_supply += pool_token_fee;

        let fee_lamports = stake_pool
            .calc_lamports_withdraw_amount(pool_token_fee)
            .unwrap();
        // off-by-one due to truncation
        assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1);
    }

    #[test]
    fn zero_withdraw_calculation() {
        let epoch_fee = Fee {
            numerator: 0,
            denominator: 1,
        };
        let stake_pool = StakePool {
            epoch_fee,
            ..StakePool::default()
        };
        let fee_lamports = stake_pool.calc_lamports_withdraw_amount(0).unwrap();
        assert_eq!(fee_lamports, 0);
    }

    #[test]
    fn divide_by_zero_fee() {
        let stake_pool = StakePool {
            total_lamports: 0,
            epoch_fee: Fee {
                numerator: 1,
                denominator: 10,
            },
            ..StakePool::default()
        };
        let rewards = 10;
        let fee = stake_pool.calc_epoch_fee_amount(rewards).unwrap();
        assert_eq!(fee, rewards);
    }

    #[test]
    fn approximate_apr_calculation() {
        // 8% / year means roughly .044% / epoch
        let stake_pool = StakePool {
            last_epoch_total_lamports: 100_000,
            last_epoch_pool_token_supply: 100_000,
            total_lamports: 100_044,
            pool_token_supply: 100_000,
            ..StakePool::default()
        };
        let pool_token_value =
            stake_pool.total_lamports as f64 / stake_pool.pool_token_supply as f64;
        let last_epoch_pool_token_value = stake_pool.last_epoch_total_lamports as f64
            / stake_pool.last_epoch_pool_token_supply as f64;
        let epoch_rate = pool_token_value / last_epoch_pool_token_value - 1.0;
        const SECONDS_PER_EPOCH: f64 = DEFAULT_SLOTS_PER_EPOCH as f64 * DEFAULT_S_PER_SLOT;
        const EPOCHS_PER_YEAR: f64 = SECONDS_PER_DAY as f64 * 365.25 / SECONDS_PER_EPOCH;
        const EPSILON: f64 = 0.00001;
        let yearly_rate = epoch_rate * EPOCHS_PER_YEAR;
        assert!((yearly_rate - 0.080355).abs() < EPSILON);
    }

    proptest! {
        #[test]
        fn fee_calculation(
            (numerator, denominator) in fee(),
            (total_lamports, reward_lamports) in total_stake_and_rewards(),
        ) {
            let epoch_fee = Fee { denominator, numerator };
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply: total_lamports,
                epoch_fee,
                ..StakePool::default()
            };
            let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

            stake_pool.total_lamports += reward_lamports;
            stake_pool.pool_token_supply += pool_token_fee;

            let fee_lamports = stake_pool.calc_lamports_withdraw_amount(pool_token_fee).unwrap();
            let max_fee_lamports = u64::try_from((reward_lamports as u128) * (epoch_fee.numerator as u128) / (epoch_fee.denominator as u128)).unwrap();
            assert!(max_fee_lamports >= fee_lamports,
                "Max possible fee must always be greater than or equal to what is actually withdrawn, max {} actual {}",
                max_fee_lamports,
                fee_lamports);

            // since we do two "flooring" conversions, the max epsilon should be
            // correct up to 2 lamports (one for each floor division), plus a
            // correction for huge discrepancies between rewards and total stake
            // (`total_lamports` is at least 1 here, see `total_stake_and_rewards`)
            let epsilon = 2 + reward_lamports / total_lamports;
            assert!(max_fee_lamports - fee_lamports <= epsilon,
                "Max expected fee in lamports {}, actually receive {}, epsilon {}",
                max_fee_lamports, fee_lamports, epsilon);
        }
    }

    // Yields `(total_lamports, pool_token_supply, deposit_lamports)` for the
    // pool state before the deposit. The drawn total starts at 2 so that
    // `deposit_lamports in 1..total_lamports` is never an empty range, which
    // is not a valid strategy.
    prop_compose! {
        fn total_tokens_and_deposit()(total_lamports in 2..u64::MAX)(
            total_lamports in Just(total_lamports),
            pool_token_supply in 1..=total_lamports,
            deposit_lamports in 1..total_lamports,
        ) -> (u64, u64, u64) {
            (total_lamports - deposit_lamports, pool_token_supply.saturating_sub(deposit_lamports).max(1), deposit_lamports)
        }
    }

    proptest! {
        #[test]
        fn deposit_and_withdraw(
            (total_lamports, pool_token_supply, deposit_stake) in total_tokens_and_deposit()
        ) {
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply,
                ..StakePool::default()
            };
            let deposit_result = stake_pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
            prop_assume!(deposit_result > 0);
            stake_pool.total_lamports += deposit_stake;
            stake_pool.pool_token_supply += deposit_result;
            let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
            assert!(withdraw_result <= deposit_stake);

            // also test splitting the withdrawal in two operations
            if deposit_result >= 2 {
                let first_half_deposit = deposit_result / 2;
                let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
                stake_pool.total_lamports -= first_withdraw_result;
                stake_pool.pool_token_supply -= first_half_deposit;
                let second_half_deposit = deposit_result - first_half_deposit; // do the whole thing
                let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
                assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
            }
        }
    }

    #[test]
    fn specific_split_withdrawal() {
        let total_lamports = 1_100_000_000_000;
        let pool_token_supply = 1_000_000_000_000;
        let deposit_stake = 3;
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        let deposit_result = stake_pool
            .calc_pool_tokens_for_deposit(deposit_stake)
            .unwrap();
        assert!(deposit_result > 0);
        stake_pool.total_lamports += deposit_stake;
        stake_pool.pool_token_supply += deposit_result;
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(deposit_result / 2)
            .unwrap();
        assert!(withdraw_result * 2 <= deposit_stake);
    }

    #[test]
    fn withdraw_all() {
        let total_lamports = 1_100_000_000_000;
        let pool_token_supply = 1_000_000_000_000;
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        // take everything out at once
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(pool_token_supply)
            .unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);

        // take out 1, then the rest
        let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
        stake_pool.total_lamports -= withdraw_result;
        stake_pool.pool_token_supply -= 1;
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
            .unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);

        // take out all except 1, then the rest
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(pool_token_supply - 1)
            .unwrap();
        stake_pool.total_lamports -= withdraw_result;
        stake_pool.pool_token_supply = 1;
        assert_ne!(stake_pool.total_lamports, 0);

        let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);
    }
}