1use {
4 crate::{
5 big_vec::BigVec, error::StakePoolError, MAX_WITHDRAWAL_FEE_INCREASE,
6 WITHDRAWAL_BASELINE_FEE,
7 },
8 borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
9 bytemuck::{Pod, Zeroable},
10 num_derive::{FromPrimitive, ToPrimitive},
11 num_traits::{FromPrimitive, ToPrimitive},
12 solana_program::{
13 account_info::AccountInfo,
14 borsh1::get_instance_packed_len,
15 msg,
16 program_error::ProgramError,
17 program_memory::sol_memcmp,
18 program_pack::{Pack, Sealed},
19 pubkey::{Pubkey, PUBKEY_BYTES},
20 stake::state::Lockup,
21 },
22 spl_pod::primitives::{PodU32, PodU64},
23 spl_token_2022::{
24 extension::{BaseStateWithExtensions, ExtensionType, StateWithExtensions},
25 state::{Account, AccountState, Mint},
26 },
27 std::{borrow::Borrow, convert::TryFrom, fmt, matches},
28};
29
/// Discriminator stored as the first field of both `StakePool` and
/// `ValidatorListHeader`, identifying what an account's data holds.
///
/// The variant order determines the serialized Borsh tag, so variants must
/// not be reordered.
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum AccountType {
    /// Account has not been initialized yet (the `Default`).
    #[default]
    Uninitialized,
    /// Account holds a `StakePool` (checked by `StakePool::is_valid`).
    StakePool,
    /// Account holds a validator list (checked by `ValidatorListHeader::is_valid`).
    ValidatorList,
}
41
/// Top-level state of a stake pool account.
///
/// Serialized with Borsh: field order and types define the on-chain layout
/// and must not be changed.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct StakePool {
    /// Account discriminator; `AccountType::StakePool` once initialized
    /// (see `is_valid` / `is_uninitialized`).
    pub account_type: AccountType,

    /// Manager key; `check_manager` requires a matching signer.
    pub manager: Pubkey,

    /// Staker key; `check_staker` requires a matching signer.
    pub staker: Pubkey,

    /// Authority expected for stake deposits. `check_stake_deposit_authority`
    /// compares keys only; it does not check for a signature.
    pub stake_deposit_authority: Pubkey,

    /// Bump seed of the withdraw-authority PDA, combined with the
    /// `AUTHORITY_WITHDRAW` seed in `check_authority_withdraw`.
    pub stake_withdraw_bump_seed: u8,

    /// Address of the validator list account (see `check_validator_list`).
    pub validator_list: Pubkey,

    /// Address of the reserve stake account (see `check_reserve_stake`).
    pub reserve_stake: Pubkey,

    /// Pool token mint (see `check_mint` and `check_manager_fee_info`).
    pub pool_mint: Pubkey,

    /// Token account for manager fees. NOTE(review): not read in this file;
    /// meaning inferred from the name.
    pub manager_fee_account: Pubkey,

    /// Token program that must own the manager fee account
    /// (checked in `check_manager_fee_info`).
    pub token_program_id: Pubkey,

    /// Total lamports attributed to the pool; used with `pool_token_supply`
    /// in every token <-> lamport conversion.
    pub total_lamports: u64,

    /// Outstanding pool token supply.
    pub pool_token_supply: u64,

    /// NOTE(review): not read in this file; presumably the last epoch the
    /// pool balances were updated. Confirm against the processor.
    pub last_update_epoch: u64,

    /// Stake lockup. NOTE(review): not read in this file; presumably applied
    /// to the pool's stake accounts.
    pub lockup: Lockup,

    /// Fee taken on epoch rewards (see `calc_epoch_fee_amount`).
    pub epoch_fee: Fee,

    /// Pending `epoch_fee` change, scheduled by `update_fee`.
    pub next_epoch_fee: FutureEpoch<Fee>,

    /// Optional preferred validator for deposits. NOTE(review): not read in
    /// this file.
    pub preferred_deposit_validator_vote_address: Option<Pubkey>,

    /// Optional preferred validator for withdrawals. NOTE(review): not read
    /// in this file.
    pub preferred_withdraw_validator_vote_address: Option<Pubkey>,

    /// Fee on stake deposits (see `calc_pool_tokens_stake_deposit_fee`).
    pub stake_deposit_fee: Fee,

    /// Fee on stake withdrawals (see `calc_pool_tokens_stake_withdrawal_fee`).
    pub stake_withdrawal_fee: Fee,

    /// Pending `stake_withdrawal_fee` change. `update_fee` validates it
    /// against the current fee with `Fee::check_withdrawal`.
    pub next_stake_withdrawal_fee: FutureEpoch<Fee>,

    /// Percentage of the stake deposit fee paid as a referral
    /// (`FeeType::check_too_high` rejects values above 100).
    pub stake_referral_fee: u8,

    /// If set, SOL deposits require this key as a signer
    /// (see `check_sol_deposit_authority`).
    pub sol_deposit_authority: Option<Pubkey>,

    /// Fee on SOL deposits (see `calc_pool_tokens_sol_deposit_fee`).
    pub sol_deposit_fee: Fee,

    /// Percentage of the SOL deposit fee paid as a referral
    /// (`FeeType::check_too_high` rejects values above 100).
    pub sol_referral_fee: u8,

    /// If set, SOL withdrawals require this key as a signer
    /// (see `check_sol_withdraw_authority`).
    pub sol_withdraw_authority: Option<Pubkey>,

    /// Fee on SOL withdrawals (see `calc_pool_tokens_sol_withdrawal_fee`).
    pub sol_withdrawal_fee: Fee,

    /// Pending `sol_withdrawal_fee` change. `update_fee` validates it against
    /// the current fee with `Fee::check_withdrawal`.
    pub next_sol_withdrawal_fee: FutureEpoch<Fee>,

    /// Pool token supply snapshot from the previous epoch. Together with
    /// `last_epoch_total_lamports`, this allows estimating per-epoch token
    /// value growth, as the APR unit test does.
    pub last_epoch_pool_token_supply: u64,

    /// Total lamports snapshot from the previous epoch (see above).
    pub last_epoch_total_lamports: u64,
}
160impl StakePool {
161 #[inline]
164 pub fn calc_pool_tokens_for_deposit(&self, stake_lamports: u64) -> Option<u64> {
165 if self.total_lamports == 0 || self.pool_token_supply == 0 {
166 return Some(stake_lamports);
167 }
168 u64::try_from(
169 (stake_lamports as u128)
170 .checked_mul(self.pool_token_supply as u128)?
171 .checked_div(self.total_lamports as u128)?,
172 )
173 .ok()
174 }
175
176 #[inline]
178 pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
179 let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
183 let denominator = self.pool_token_supply as u128;
184 if numerator < denominator || denominator == 0 {
185 Some(0)
186 } else {
187 u64::try_from(numerator.checked_div(denominator)?).ok()
188 }
189 }
190
191 #[inline]
193 pub fn calc_pool_tokens_stake_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
194 u64::try_from(self.stake_withdrawal_fee.apply(pool_tokens)?).ok()
195 }
196
197 #[inline]
199 pub fn calc_pool_tokens_sol_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
200 u64::try_from(self.sol_withdrawal_fee.apply(pool_tokens)?).ok()
201 }
202
203 #[inline]
205 pub fn calc_pool_tokens_stake_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
206 u64::try_from(self.stake_deposit_fee.apply(pool_tokens_minted)?).ok()
207 }
208
209 #[inline]
211 pub fn calc_pool_tokens_stake_referral_fee(&self, stake_deposit_fee: u64) -> Option<u64> {
212 u64::try_from(
213 (stake_deposit_fee as u128)
214 .checked_mul(self.stake_referral_fee as u128)?
215 .checked_div(100u128)?,
216 )
217 .ok()
218 }
219
220 #[inline]
222 pub fn calc_pool_tokens_sol_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
223 u64::try_from(self.sol_deposit_fee.apply(pool_tokens_minted)?).ok()
224 }
225
226 #[inline]
229 pub fn calc_pool_tokens_sol_referral_fee(&self, sol_deposit_fee: u64) -> Option<u64> {
230 u64::try_from(
231 (sol_deposit_fee as u128)
232 .checked_mul(self.sol_referral_fee as u128)?
233 .checked_div(100u128)?,
234 )
235 .ok()
236 }
237
238 #[inline]
243 pub fn calc_epoch_fee_amount(&self, reward_lamports: u64) -> Option<u64> {
244 if reward_lamports == 0 {
245 return Some(0);
246 }
247 let total_lamports = (self.total_lamports as u128).checked_add(reward_lamports as u128)?;
248 let fee_lamports = self.epoch_fee.apply(reward_lamports)?;
249 if total_lamports == fee_lamports || self.pool_token_supply == 0 {
250 Some(reward_lamports)
251 } else {
252 u64::try_from(
253 (self.pool_token_supply as u128)
254 .checked_mul(fee_lamports)?
255 .checked_div(total_lamports.checked_sub(fee_lamports)?)?,
256 )
257 .ok()
258 }
259 }
260
261 #[inline]
263 pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
264 self.total_lamports
265 .checked_add(self.pool_token_supply)?
266 .checked_sub(1)?
267 .checked_div(self.pool_token_supply)
268 }
269
270 fn check_program_derived_authority(
272 authority_address: &Pubkey,
273 program_id: &Pubkey,
274 stake_pool_address: &Pubkey,
275 authority_seed: &[u8],
276 bump_seed: u8,
277 ) -> Result<(), ProgramError> {
278 let expected_address = Pubkey::create_program_address(
279 &[stake_pool_address.as_ref(), authority_seed, &[bump_seed]],
280 program_id,
281 )?;
282
283 if *authority_address == expected_address {
284 Ok(())
285 } else {
286 msg!(
287 "Incorrect authority provided, expected {}, received {}",
288 expected_address,
289 authority_address
290 );
291 Err(StakePoolError::InvalidProgramAddress.into())
292 }
293 }
294
295 pub(crate) fn check_manager_fee_info(
298 &self,
299 manager_fee_info: &AccountInfo,
300 ) -> Result<(), ProgramError> {
301 let account_data = manager_fee_info.try_borrow_data()?;
302 let token_account = StateWithExtensions::<Account>::unpack(&account_data)?;
303 if manager_fee_info.owner != &self.token_program_id
304 || token_account.base.state != AccountState::Initialized
305 || token_account.base.mint != self.pool_mint
306 {
307 msg!("Manager fee account is not owned by token program, is not initialized, or does not match stake pool's mint");
308 return Err(StakePoolError::InvalidFeeAccount.into());
309 }
310 let extensions = token_account.get_extension_types()?;
311 if extensions
312 .iter()
313 .any(|x| !is_extension_supported_for_fee_account(x))
314 {
315 return Err(StakePoolError::UnsupportedFeeAccountExtension.into());
316 }
317 Ok(())
318 }
319
320 #[inline]
322 pub(crate) fn check_authority_withdraw(
323 &self,
324 withdraw_authority: &Pubkey,
325 program_id: &Pubkey,
326 stake_pool_address: &Pubkey,
327 ) -> Result<(), ProgramError> {
328 Self::check_program_derived_authority(
329 withdraw_authority,
330 program_id,
331 stake_pool_address,
332 crate::AUTHORITY_WITHDRAW,
333 self.stake_withdraw_bump_seed,
334 )
335 }
336 #[inline]
338 pub(crate) fn check_stake_deposit_authority(
339 &self,
340 stake_deposit_authority: &Pubkey,
341 ) -> Result<(), ProgramError> {
342 if self.stake_deposit_authority == *stake_deposit_authority {
343 Ok(())
344 } else {
345 Err(StakePoolError::InvalidStakeDepositAuthority.into())
346 }
347 }
348
349 #[inline]
352 pub(crate) fn check_sol_deposit_authority(
353 &self,
354 maybe_sol_deposit_authority: Result<&AccountInfo, ProgramError>,
355 ) -> Result<(), ProgramError> {
356 if let Some(auth) = self.sol_deposit_authority {
357 let sol_deposit_authority = maybe_sol_deposit_authority?;
358 if auth != *sol_deposit_authority.key {
359 msg!("Expected {}, received {}", auth, sol_deposit_authority.key);
360 return Err(StakePoolError::InvalidSolDepositAuthority.into());
361 }
362 if !sol_deposit_authority.is_signer {
363 msg!("SOL Deposit authority signature missing");
364 return Err(StakePoolError::SignatureMissing.into());
365 }
366 }
367 Ok(())
368 }
369
370 #[inline]
373 pub(crate) fn check_sol_withdraw_authority(
374 &self,
375 maybe_sol_withdraw_authority: Result<&AccountInfo, ProgramError>,
376 ) -> Result<(), ProgramError> {
377 if let Some(auth) = self.sol_withdraw_authority {
378 let sol_withdraw_authority = maybe_sol_withdraw_authority?;
379 if auth != *sol_withdraw_authority.key {
380 return Err(StakePoolError::InvalidSolWithdrawAuthority.into());
381 }
382 if !sol_withdraw_authority.is_signer {
383 msg!("SOL withdraw authority signature missing");
384 return Err(StakePoolError::SignatureMissing.into());
385 }
386 }
387 Ok(())
388 }
389
390 #[inline]
392 pub(crate) fn check_mint(&self, mint_info: &AccountInfo) -> Result<u8, ProgramError> {
393 if *mint_info.key != self.pool_mint {
394 Err(StakePoolError::WrongPoolMint.into())
395 } else {
396 let mint_data = mint_info.try_borrow_data()?;
397 let mint = StateWithExtensions::<Mint>::unpack(&mint_data)?;
398 Ok(mint.base.decimals)
399 }
400 }
401
402 pub(crate) fn check_manager(&self, manager_info: &AccountInfo) -> Result<(), ProgramError> {
404 if *manager_info.key != self.manager {
405 msg!(
406 "Incorrect manager provided, expected {}, received {}",
407 self.manager,
408 manager_info.key
409 );
410 return Err(StakePoolError::WrongManager.into());
411 }
412 if !manager_info.is_signer {
413 msg!("Manager signature missing");
414 return Err(StakePoolError::SignatureMissing.into());
415 }
416 Ok(())
417 }
418
419 pub(crate) fn check_staker(&self, staker_info: &AccountInfo) -> Result<(), ProgramError> {
421 if *staker_info.key != self.staker {
422 msg!(
423 "Incorrect staker provided, expected {}, received {}",
424 self.staker,
425 staker_info.key
426 );
427 return Err(StakePoolError::WrongStaker.into());
428 }
429 if !staker_info.is_signer {
430 msg!("Staker signature missing");
431 return Err(StakePoolError::SignatureMissing.into());
432 }
433 Ok(())
434 }
435
436 pub fn check_validator_list(
438 &self,
439 validator_list_info: &AccountInfo,
440 ) -> Result<(), ProgramError> {
441 if *validator_list_info.key != self.validator_list {
442 msg!(
443 "Invalid validator list provided, expected {}, received {}",
444 self.validator_list,
445 validator_list_info.key
446 );
447 Err(StakePoolError::InvalidValidatorStakeList.into())
448 } else {
449 Ok(())
450 }
451 }
452
453 pub fn check_reserve_stake(
455 &self,
456 reserve_stake_info: &AccountInfo,
457 ) -> Result<(), ProgramError> {
458 if *reserve_stake_info.key != self.reserve_stake {
459 msg!(
460 "Invalid reserve stake provided, expected {}, received {}",
461 self.reserve_stake,
462 reserve_stake_info.key
463 );
464 Err(StakePoolError::InvalidProgramAddress.into())
465 } else {
466 Ok(())
467 }
468 }
469
470 pub fn is_valid(&self) -> bool {
472 self.account_type == AccountType::StakePool
473 }
474
475 pub fn is_uninitialized(&self) -> bool {
477 self.account_type == AccountType::Uninitialized
478 }
479
480 pub fn update_fee(&mut self, fee: &FeeType) -> Result<(), StakePoolError> {
482 match fee {
483 FeeType::SolReferral(new_fee) => self.sol_referral_fee = *new_fee,
484 FeeType::StakeReferral(new_fee) => self.stake_referral_fee = *new_fee,
485 FeeType::Epoch(new_fee) => self.next_epoch_fee = FutureEpoch::new(*new_fee),
486 FeeType::StakeWithdrawal(new_fee) => {
487 new_fee.check_withdrawal(&self.stake_withdrawal_fee)?;
488 self.next_stake_withdrawal_fee = FutureEpoch::new(*new_fee)
489 }
490 FeeType::SolWithdrawal(new_fee) => {
491 new_fee.check_withdrawal(&self.sol_withdrawal_fee)?;
492 self.next_sol_withdrawal_fee = FutureEpoch::new(*new_fee)
493 }
494 FeeType::SolDeposit(new_fee) => self.sol_deposit_fee = *new_fee,
495 FeeType::StakeDeposit(new_fee) => self.stake_deposit_fee = *new_fee,
496 };
497 Ok(())
498 }
499}
500
501pub fn is_extension_supported_for_mint(extension_type: &ExtensionType) -> bool {
503 const SUPPORTED_EXTENSIONS: [ExtensionType; 8] = [
504 ExtensionType::Uninitialized,
505 ExtensionType::TransferFeeConfig,
506 ExtensionType::ConfidentialTransferMint,
507 ExtensionType::ConfidentialTransferFeeConfig,
508 ExtensionType::DefaultAccountState, ExtensionType::InterestBearingConfig,
510 ExtensionType::MetadataPointer,
511 ExtensionType::TokenMetadata,
512 ];
513 if !SUPPORTED_EXTENSIONS.contains(extension_type) {
514 msg!(
515 "Stake pool mint account cannot have the {:?} extension",
516 extension_type
517 );
518 false
519 } else {
520 true
521 }
522}
523
524pub fn is_extension_supported_for_fee_account(extension_type: &ExtensionType) -> bool {
526 const SUPPORTED_EXTENSIONS: [ExtensionType; 4] = [
530 ExtensionType::Uninitialized,
531 ExtensionType::TransferFeeAmount,
532 ExtensionType::ImmutableOwner,
533 ExtensionType::CpiGuard,
534 ];
535 if !SUPPORTED_EXTENSIONS.contains(extension_type) {
536 msg!("Fee account cannot have the {:?} extension", extension_type);
537 false
538 } else {
539 true
540 }
541}
542
/// Deserialized validator list account: a header followed by the
/// Borsh-encoded vector of per-validator entries.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorList {
    /// Account tag and capacity.
    pub header: ValidatorListHeader,

    /// One entry per validator, each packed as a `ValidatorStakeInfo`.
    pub validators: Vec<ValidatorStakeInfo>,
}
554
/// Leading fields of a validator list account (5 bytes when packed; see
/// `ValidatorListHeader::LEN`).
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorListHeader {
    /// Account discriminator; `AccountType::ValidatorList` once initialized.
    pub account_type: AccountType,

    /// Capacity of the list; `ValidatorList::new` allocates this many entries.
    pub max_validators: u32,
}
565
/// Removal state of a validator entry.
///
/// The discriminant is persisted as a single byte (`PodStakeStatus`, byte 40
/// of a packed `ValidatorStakeInfo`), so variants must not be reordered.
/// Transitions are driven by `PodStakeStatus::remove_validator_stake` and
/// `PodStakeStatus::remove_transient_stake`.
#[derive(
    ToPrimitive,
    FromPrimitive,
    Copy,
    Clone,
    Debug,
    Default,
    PartialEq,
    BorshDeserialize,
    BorshSerialize,
    BorshSchema,
)]
pub enum StakeStatus {
    /// Default state; neither removal helper changes it.
    #[default]
    Active,
    /// Only the transient stake is still pending; removing it yields
    /// `ReadyForRemoval`.
    DeactivatingTransient,
    /// Terminal state; `ValidatorStakeInfo::is_removed` additionally requires
    /// zero active and transient lamports.
    ReadyForRemoval,
    /// Only the validator stake is still pending; removing it yields
    /// `ReadyForRemoval`.
    DeactivatingValidator,
    /// Both validator and transient stake are pending; removing either moves
    /// to the state where only the other is pending.
    DeactivatingAll,
}
596
/// `StakeStatus` stored as a raw `u8`, so that `ValidatorStakeInfo` can derive
/// `Pod`/`Zeroable`.
///
/// Decode with `StakeStatus::try_from`; bytes that are not a valid variant
/// yield `ProgramError::InvalidAccountData`.
#[repr(transparent)]
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Pod,
    Zeroable,
    BorshDeserialize,
    BorshSerialize,
    BorshSchema,
)]
pub struct PodStakeStatus(u8);
613impl PodStakeStatus {
614 pub fn remove_validator_stake(&mut self) -> Result<(), ProgramError> {
617 let status = StakeStatus::try_from(*self)?;
618 let new_self = match status {
619 StakeStatus::Active
620 | StakeStatus::DeactivatingTransient
621 | StakeStatus::ReadyForRemoval => status,
622 StakeStatus::DeactivatingAll => StakeStatus::DeactivatingTransient,
623 StakeStatus::DeactivatingValidator => StakeStatus::ReadyForRemoval,
624 };
625 *self = new_self.into();
626 Ok(())
627 }
628 pub fn remove_transient_stake(&mut self) -> Result<(), ProgramError> {
631 let status = StakeStatus::try_from(*self)?;
632 let new_self = match status {
633 StakeStatus::Active
634 | StakeStatus::DeactivatingValidator
635 | StakeStatus::ReadyForRemoval => status,
636 StakeStatus::DeactivatingAll => StakeStatus::DeactivatingValidator,
637 StakeStatus::DeactivatingTransient => StakeStatus::ReadyForRemoval,
638 };
639 *self = new_self.into();
640 Ok(())
641 }
642}
643impl TryFrom<PodStakeStatus> for StakeStatus {
644 type Error = ProgramError;
645 fn try_from(pod: PodStakeStatus) -> Result<Self, Self::Error> {
646 FromPrimitive::from_u8(pod.0).ok_or(ProgramError::InvalidAccountData)
647 }
648}
649impl From<StakeStatus> for PodStakeStatus {
650 fn from(status: StakeStatus) -> Self {
651 PodStakeStatus(status.to_u8().unwrap())
654 }
655}
656
/// Which stake a withdrawal is drawn from.
///
/// NOTE(review): this enum is not used in this file; the variant meanings
/// below are inferred from their names. Confirm against the processor.
#[derive(Debug, PartialEq)]
pub(crate) enum StakeWithdrawSource {
    /// Presumably the validator's active stake account.
    Active,
    /// Presumably the validator's transient stake account.
    Transient,
    /// Presumably a withdrawal that also removes the validator.
    ValidatorRemoval,
}
667
/// Per-validator entry in the validator list.
///
/// The packed encoding is 73 bytes (`Pack::LEN`). The raw-byte helpers on
/// `ValidatorStakeInfo` read fields at fixed offsets (0..8, 8..16, 40, and
/// 41..73), so field order and types must not change.
#[repr(C)]
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Pod,
    Zeroable,
    BorshDeserialize,
    BorshSerialize,
    BorshSchema,
)]
pub struct ValidatorStakeInfo {
    /// Active stake lamports for this validator (packed bytes 0..8).
    pub active_stake_lamports: PodU64,

    /// Transient stake lamports for this validator (packed bytes 8..16).
    pub transient_stake_lamports: PodU64,

    /// NOTE(review): not read in this file; presumably the epoch this entry
    /// was last updated.
    pub last_update_epoch: PodU64,

    /// NOTE(review): not read in this file; presumably a seed suffix for the
    /// transient stake account address.
    pub transient_seed_suffix: PodU64,

    /// Unused field kept for layout.
    pub unused: PodU32,

    /// NOTE(review): not read in this file; presumably a seed suffix for the
    /// validator stake account address.
    pub validator_seed_suffix: PodU32,

    /// Removal state (packed byte 40); see `StakeStatus`.
    pub status: PodStakeStatus,

    /// Vote account address (packed bytes 41..73). `ValidatorList` lookups use
    /// this as the key.
    pub vote_account_address: Pubkey,
}
721
722impl ValidatorStakeInfo {
723 pub fn stake_lamports(&self) -> Result<u64, StakePoolError> {
725 u64::from(self.active_stake_lamports)
726 .checked_add(self.transient_stake_lamports.into())
727 .ok_or(StakePoolError::CalculationFailure)
728 }
729
730 pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
733 sol_memcmp(
734 &data[41..41_usize.saturating_add(PUBKEY_BYTES)],
735 vote_address.as_ref(),
736 PUBKEY_BYTES,
737 ) == 0
738 }
739
740 pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
743 u64::try_from_slice(&data[0..8]).unwrap() > *lamports
745 }
746
747 pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
750 u64::try_from_slice(&data[8..16]).unwrap() > *lamports
752 }
753
754 pub fn is_removed(data: &[u8]) -> bool {
756 FromPrimitive::from_u8(data[40]) == Some(StakeStatus::ReadyForRemoval)
757 && data[0..16] == [0; 16] }
759
760 pub fn is_active(data: &[u8]) -> bool {
762 FromPrimitive::from_u8(data[40]) == Some(StakeStatus::Active)
763 }
764}
765
766impl Sealed for ValidatorStakeInfo {}
767
768impl Pack for ValidatorStakeInfo {
769 const LEN: usize = 73;
770 fn pack_into_slice(&self, data: &mut [u8]) {
771 borsh::to_writer(data, self).unwrap();
774 }
775 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
776 let unpacked = Self::try_from_slice(src)?;
777 Ok(unpacked)
778 }
779}
780
781impl ValidatorList {
782 pub fn new(max_validators: u32) -> Self {
785 Self {
786 header: ValidatorListHeader {
787 account_type: AccountType::ValidatorList,
788 max_validators,
789 },
790 validators: vec![ValidatorStakeInfo::default(); max_validators as usize],
791 }
792 }
793
794 pub fn calculate_max_validators(buffer_length: usize) -> usize {
797 let header_size = ValidatorListHeader::LEN.saturating_add(4);
798 buffer_length
799 .saturating_sub(header_size)
800 .saturating_div(ValidatorStakeInfo::LEN)
801 }
802
803 pub fn contains(&self, vote_account_address: &Pubkey) -> bool {
805 self.validators
806 .iter()
807 .any(|x| x.vote_account_address == *vote_account_address)
808 }
809
810 pub fn find_mut(&mut self, vote_account_address: &Pubkey) -> Option<&mut ValidatorStakeInfo> {
812 self.validators
813 .iter_mut()
814 .find(|x| x.vote_account_address == *vote_account_address)
815 }
816 pub fn find(&self, vote_account_address: &Pubkey) -> Option<&ValidatorStakeInfo> {
818 self.validators
819 .iter()
820 .find(|x| x.vote_account_address == *vote_account_address)
821 }
822
823 pub fn has_active_stake(&self) -> bool {
825 self.validators
826 .iter()
827 .any(|x| u64::from(x.active_stake_lamports) > 0)
828 }
829}
830
831impl ValidatorListHeader {
832 const LEN: usize = 1 + 4;
833
834 pub fn is_valid(&self) -> bool {
837 self.account_type == AccountType::ValidatorList
838 }
839
840 pub fn is_uninitialized(&self) -> bool {
842 self.account_type == AccountType::Uninitialized
843 }
844
845 pub fn deserialize_mut_slice<'a>(
848 big_vec: &'a mut BigVec,
849 skip: usize,
850 len: usize,
851 ) -> Result<&'a mut [ValidatorStakeInfo], ProgramError> {
852 big_vec.deserialize_mut_slice::<ValidatorStakeInfo>(skip, len)
853 }
854
855 pub fn deserialize_vec(data: &mut [u8]) -> Result<(Self, BigVec<'_>), ProgramError> {
857 let mut data_mut = data.borrow();
858 let header = ValidatorListHeader::deserialize(&mut data_mut)?;
859 let length = get_instance_packed_len(&header)?;
860
861 let big_vec = BigVec {
862 data: &mut data[length..],
863 };
864 Ok((header, big_vec))
865 }
866}
867
/// A value scheduled to take effect in a later epoch.
///
/// `FutureEpoch::new` creates `Two`. Each `update_epoch` call advances the
/// state `Two -> One -> None`. `get` returns the value only in the `One`
/// state.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub enum FutureEpoch<T> {
    /// Nothing scheduled (the `Default`).
    #[default]
    None,
    /// One epoch remains; `get` exposes the value in this state.
    One(T),
    /// Two epochs remain; the state set by `FutureEpoch::new`.
    Two(T),
}
881impl<T> FutureEpoch<T> {
882 pub fn new(value: T) -> Self {
884 Self::Two(value)
885 }
886}
887impl<T: Clone> FutureEpoch<T> {
888 pub fn update_epoch(&mut self) {
890 match self {
891 Self::None => {}
892 Self::One(_) => {
893 *self = Self::None;
895 }
896 Self::Two(v) => {
898 *self = Self::One(v.clone());
899 }
900 }
901 }
902
903 pub fn get(&self) -> Option<&T> {
905 match self {
906 Self::None | Self::Two(_) => None,
907 Self::One(v) => Some(v),
908 }
909 }
910}
911impl<T> From<FutureEpoch<T>> for Option<T> {
912 fn from(v: FutureEpoch<T>) -> Option<T> {
913 match v {
914 FutureEpoch::None => None,
915 FutureEpoch::One(inner) | FutureEpoch::Two(inner) => Some(inner),
916 }
917 }
918}
919
/// A fee expressed as the ratio `numerator / denominator`.
///
/// A zero denominator means "no fee" (`Fee::apply` returns 0). Field order
/// (denominator first) is the Borsh layout and must not change.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub struct Fee {
    /// Denominator of the fee ratio.
    pub denominator: u64,
    /// Numerator of the fee ratio.
    pub numerator: u64,
}
932
933impl Fee {
934 #[inline]
939 pub fn apply(&self, amt: u64) -> Option<u128> {
940 if self.denominator == 0 {
941 return Some(0);
942 }
943 let numerator = (amt as u128).checked_mul(self.numerator as u128)?;
944 let denominator = self.denominator as u128;
946 numerator
947 .checked_add(denominator)?
948 .checked_sub(1)?
949 .checked_div(denominator)
950 }
951
952 pub fn check_withdrawal(&self, old_withdrawal_fee: &Fee) -> Result<(), StakePoolError> {
955 let (old_num, old_denom) =
958 if old_withdrawal_fee.denominator == 0 || old_withdrawal_fee.numerator == 0 {
959 (
960 WITHDRAWAL_BASELINE_FEE.numerator,
961 WITHDRAWAL_BASELINE_FEE.denominator,
962 )
963 } else {
964 (old_withdrawal_fee.numerator, old_withdrawal_fee.denominator)
965 };
966
967 if (old_num as u128)
971 .checked_mul(self.denominator as u128)
972 .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.numerator as u128))
973 .ok_or(StakePoolError::CalculationFailure)?
974 < (self.numerator as u128)
975 .checked_mul(old_denom as u128)
976 .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.denominator as u128))
977 .ok_or(StakePoolError::CalculationFailure)?
978 {
979 msg!(
980 "Fee increase exceeds maximum allowed, proposed increase factor ({} / {})",
981 self.numerator.saturating_mul(old_denom),
982 old_num.saturating_mul(self.denominator),
983 );
984 return Err(StakePoolError::FeeIncreaseTooHigh);
985 }
986 Ok(())
987 }
988}
989
990impl fmt::Display for Fee {
991 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
992 if self.numerator > 0 && self.denominator > 0 {
993 write!(f, "{}/{}", self.numerator, self.denominator)
994 } else {
995 write!(f, "none")
996 }
997 }
998}
999
/// A fee update request, applied by `StakePool::update_fee`.
///
/// Variant order is the Borsh tag and must not change.
#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum FeeType {
    /// New `sol_referral_fee` percentage (applied immediately).
    SolReferral(u8),
    /// New `stake_referral_fee` percentage (applied immediately).
    StakeReferral(u8),
    /// New epoch fee, scheduled into `next_epoch_fee`.
    Epoch(Fee),
    /// New stake withdrawal fee, capped by `Fee::check_withdrawal` and
    /// scheduled into `next_stake_withdrawal_fee`.
    StakeWithdrawal(Fee),
    /// New `sol_deposit_fee` (applied immediately).
    SolDeposit(Fee),
    /// New `stake_deposit_fee` (applied immediately).
    StakeDeposit(Fee),
    /// New SOL withdrawal fee, capped by `Fee::check_withdrawal` and scheduled
    /// into `next_sol_withdrawal_fee`.
    SolWithdrawal(Fee),
}
1018
1019impl FeeType {
1020 pub fn check_too_high(&self) -> Result<(), StakePoolError> {
1022 let too_high = match self {
1023 Self::SolReferral(pct) => *pct > 100u8,
1024 Self::StakeReferral(pct) => *pct > 100u8,
1025 Self::Epoch(fee) => fee.numerator > fee.denominator,
1026 Self::StakeWithdrawal(fee) => fee.numerator > fee.denominator,
1027 Self::SolWithdrawal(fee) => fee.numerator > fee.denominator,
1028 Self::SolDeposit(fee) => fee.numerator > fee.denominator,
1029 Self::StakeDeposit(fee) => fee.numerator > fee.denominator,
1030 };
1031 if too_high {
1032 msg!("Fee greater than 100%: {:?}", self);
1033 return Err(StakePoolError::FeeTooHigh);
1034 }
1035 Ok(())
1036 }
1037
1038 #[inline]
1041 pub fn can_only_change_next_epoch(&self) -> bool {
1042 matches!(
1043 self,
1044 Self::StakeWithdrawal(_) | Self::SolWithdrawal(_) | Self::Epoch(_)
1045 )
1046 }
1047}
1048
1049#[cfg(test)]
1050mod test {
1051 #![allow(clippy::arithmetic_side_effects)]
1052 use {
1053 super::*,
1054 proptest::prelude::*,
1055 solana_program::{
1056 borsh1::{get_packed_len, try_from_slice_unchecked},
1057 clock::{DEFAULT_SLOTS_PER_EPOCH, DEFAULT_S_PER_SLOT, SECONDS_PER_DAY},
1058 native_token::LAMPORTS_PER_SOL,
1059 },
1060 };
1061
1062 fn uninitialized_validator_list() -> ValidatorList {
1063 ValidatorList {
1064 header: ValidatorListHeader {
1065 account_type: AccountType::Uninitialized,
1066 max_validators: 0,
1067 },
1068 validators: vec![],
1069 }
1070 }
1071
1072 fn test_validator_list(max_validators: u32) -> ValidatorList {
1073 ValidatorList {
1074 header: ValidatorListHeader {
1075 account_type: AccountType::ValidatorList,
1076 max_validators,
1077 },
1078 validators: vec![
1079 ValidatorStakeInfo {
1080 status: StakeStatus::Active.into(),
1081 vote_account_address: Pubkey::new_from_array([1; 32]),
1082 active_stake_lamports: u64::from_le_bytes([255; 8]).into(),
1083 transient_stake_lamports: u64::from_le_bytes([128; 8]).into(),
1084 last_update_epoch: u64::from_le_bytes([64; 8]).into(),
1085 transient_seed_suffix: 0.into(),
1086 unused: 0.into(),
1087 validator_seed_suffix: 0.into(),
1088 },
1089 ValidatorStakeInfo {
1090 status: StakeStatus::DeactivatingTransient.into(),
1091 vote_account_address: Pubkey::new_from_array([2; 32]),
1092 active_stake_lamports: 998877665544.into(),
1093 transient_stake_lamports: 222222222.into(),
1094 last_update_epoch: 11223445566.into(),
1095 transient_seed_suffix: 0.into(),
1096 unused: 0.into(),
1097 validator_seed_suffix: 0.into(),
1098 },
1099 ValidatorStakeInfo {
1100 status: StakeStatus::ReadyForRemoval.into(),
1101 vote_account_address: Pubkey::new_from_array([3; 32]),
1102 active_stake_lamports: 0.into(),
1103 transient_stake_lamports: 0.into(),
1104 last_update_epoch: 999999999999999.into(),
1105 transient_seed_suffix: 0.into(),
1106 unused: 0.into(),
1107 validator_seed_suffix: 0.into(),
1108 },
1109 ],
1110 }
1111 }
1112
1113 #[test]
1114 fn state_packing() {
1115 let max_validators = 10_000;
1116 let size = get_instance_packed_len(&ValidatorList::new(max_validators)).unwrap();
1117 let stake_list = uninitialized_validator_list();
1118 let mut byte_vec = vec![0u8; size];
1119 let bytes = byte_vec.as_mut_slice();
1120 borsh::to_writer(bytes, &stake_list).unwrap();
1121 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1122 assert_eq!(stake_list_unpacked, stake_list);
1123
1124 let stake_list = ValidatorList {
1126 header: ValidatorListHeader {
1127 account_type: AccountType::ValidatorList,
1128 max_validators: 0,
1129 },
1130 validators: vec![],
1131 };
1132 let mut byte_vec = vec![0u8; size];
1133 let bytes = byte_vec.as_mut_slice();
1134 borsh::to_writer(bytes, &stake_list).unwrap();
1135 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1136 assert_eq!(stake_list_unpacked, stake_list);
1137
1138 let stake_list = test_validator_list(max_validators);
1140 let mut byte_vec = vec![0u8; size];
1141 let bytes = byte_vec.as_mut_slice();
1142 borsh::to_writer(bytes, &stake_list).unwrap();
1143 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1144 assert_eq!(stake_list_unpacked, stake_list);
1145 }
1146
1147 #[test]
1148 fn validator_list_active_stake() {
1149 let max_validators = 10_000;
1150 let mut validator_list = test_validator_list(max_validators);
1151 assert!(validator_list.has_active_stake());
1152 for validator in validator_list.validators.iter_mut() {
1153 validator.active_stake_lamports = 0.into();
1154 }
1155 assert!(!validator_list.has_active_stake());
1156 }
1157
1158 #[test]
1159 fn validator_list_deserialize_mut_slice() {
1160 let max_validators = 10;
1161 let stake_list = test_validator_list(max_validators);
1162 let mut serialized = borsh::to_vec(&stake_list).unwrap();
1163 let (header, mut big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
1164 let list = ValidatorListHeader::deserialize_mut_slice(
1165 &mut big_vec,
1166 0,
1167 stake_list.validators.len(),
1168 )
1169 .unwrap();
1170 assert_eq!(header.account_type, AccountType::ValidatorList);
1171 assert_eq!(header.max_validators, max_validators);
1172 assert!(list
1173 .iter()
1174 .zip(stake_list.validators.iter())
1175 .all(|(a, b)| a == b));
1176
1177 let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 2).unwrap();
1178 assert!(list
1179 .iter()
1180 .zip(stake_list.validators[1..].iter())
1181 .all(|(a, b)| a == b));
1182 let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 2, 1).unwrap();
1183 assert!(list
1184 .iter()
1185 .zip(stake_list.validators[2..].iter())
1186 .all(|(a, b)| a == b));
1187 let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 2).unwrap();
1188 assert!(list
1189 .iter()
1190 .zip(stake_list.validators[..2].iter())
1191 .all(|(a, b)| a == b));
1192
1193 assert_eq!(
1194 ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 4).unwrap_err(),
1195 ProgramError::AccountDataTooSmall
1196 );
1197 assert_eq!(
1198 ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 3).unwrap_err(),
1199 ProgramError::AccountDataTooSmall
1200 );
1201 }
1202
1203 #[test]
1204 fn validator_list_iter() {
1205 let max_validators = 10;
1206 let stake_list = test_validator_list(max_validators);
1207 let mut serialized = borsh::to_vec(&stake_list).unwrap();
1208 let (_, big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
1209 for (a, b) in big_vec
1210 .deserialize_slice::<ValidatorStakeInfo>(0, big_vec.len() as usize)
1211 .unwrap()
1212 .iter()
1213 .zip(stake_list.validators.iter())
1214 {
1215 assert_eq!(a, b);
1216 }
1217 }
1218
1219 proptest! {
1220 #[test]
1221 fn stake_list_size_calculation(test_amount in 0..=100_000_u32) {
1222 let validators = ValidatorList::new(test_amount);
1223 let size = get_instance_packed_len(&validators).unwrap();
1224 assert_eq!(ValidatorList::calculate_max_validators(size), test_amount as usize);
1225 assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(1)), test_amount as usize);
1226 assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(get_packed_len::<ValidatorStakeInfo>())), (test_amount + 1)as usize);
1227 assert_eq!(ValidatorList::calculate_max_validators(size.saturating_sub(1)), (test_amount.saturating_sub(1)) as usize);
1228 }
1229 }
1230
1231 prop_compose! {
1232 fn fee()(denominator in 1..=u16::MAX)(
1233 denominator in Just(denominator),
1234 numerator in 0..=denominator,
1235 ) -> (u64, u64) {
1236 (numerator as u64, denominator as u64)
1237 }
1238 }
1239
1240 prop_compose! {
1241 fn total_stake_and_rewards()(total_lamports in 1..u64::MAX)(
1242 total_lamports in Just(total_lamports),
1243 rewards in 0..=total_lamports,
1244 ) -> (u64, u64) {
1245 (total_lamports - rewards, rewards)
1246 }
1247 }
1248
1249 #[test]
1250 fn specific_fee_calculation() {
1251 let epoch_fee = Fee {
1253 numerator: 1,
1254 denominator: 10,
1255 };
1256 let mut stake_pool = StakePool {
1257 total_lamports: 100 * LAMPORTS_PER_SOL,
1258 pool_token_supply: 100 * LAMPORTS_PER_SOL,
1259 epoch_fee,
1260 ..StakePool::default()
1261 };
1262 let reward_lamports = 10 * LAMPORTS_PER_SOL;
1263 let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();
1264
1265 stake_pool.total_lamports += reward_lamports;
1266 stake_pool.pool_token_supply += pool_token_fee;
1267
1268 let fee_lamports = stake_pool
1269 .calc_lamports_withdraw_amount(pool_token_fee)
1270 .unwrap();
1271 assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); }
1274
1275 #[test]
1276 fn zero_withdraw_calculation() {
1277 let epoch_fee = Fee {
1278 numerator: 0,
1279 denominator: 1,
1280 };
1281 let stake_pool = StakePool {
1282 epoch_fee,
1283 ..StakePool::default()
1284 };
1285 let fee_lamports = stake_pool.calc_lamports_withdraw_amount(0).unwrap();
1286 assert_eq!(fee_lamports, 0);
1287 }
1288
1289 #[test]
1290 fn divide_by_zero_fee() {
1291 let stake_pool = StakePool {
1292 total_lamports: 0,
1293 epoch_fee: Fee {
1294 numerator: 1,
1295 denominator: 10,
1296 },
1297 ..StakePool::default()
1298 };
1299 let rewards = 10;
1300 let fee = stake_pool.calc_epoch_fee_amount(rewards).unwrap();
1301 assert_eq!(fee, rewards);
1302 }
1303
/// A pool whose token value grows by 0.044% over one epoch should annualize
/// to roughly 8.0355%, using the default slots-per-epoch and
/// seconds-per-slot constants.
#[test]
fn approximate_apr_calculation() {
    const SECONDS_PER_EPOCH: f64 = DEFAULT_SLOTS_PER_EPOCH as f64 * DEFAULT_S_PER_SLOT;
    const EPOCHS_PER_YEAR: f64 = SECONDS_PER_DAY as f64 * 365.25 / SECONDS_PER_EPOCH;
    const EPSILON: f64 = 0.00001;

    let pool = StakePool {
        last_epoch_total_lamports: 100_000,
        last_epoch_pool_token_supply: 100_000,
        total_lamports: 100_044,
        pool_token_supply: 100_000,
        ..StakePool::default()
    };

    // Lamports backing one pool token, now and as of the previous epoch.
    let current_value = pool.total_lamports as f64 / pool.pool_token_supply as f64;
    let previous_value =
        pool.last_epoch_total_lamports as f64 / pool.last_epoch_pool_token_supply as f64;

    let yearly_rate = (current_value / previous_value - 1.0) * EPOCHS_PER_YEAR;
    assert!((yearly_rate - 0.080355).abs() < EPSILON);
}
1325
1326 proptest! {
1327 #[test]
1328 fn fee_calculation(
1329 (numerator, denominator) in fee(),
1330 (total_lamports, reward_lamports) in total_stake_and_rewards(),
1331 ) {
1332 let epoch_fee = Fee { denominator, numerator };
1333 let mut stake_pool = StakePool {
1334 total_lamports,
1335 pool_token_supply: total_lamports,
1336 epoch_fee,
1337 ..StakePool::default()
1338 };
1339 let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();
1340
1341 stake_pool.total_lamports += reward_lamports;
1342 stake_pool.pool_token_supply += pool_token_fee;
1343
1344 let fee_lamports = stake_pool.calc_lamports_withdraw_amount(pool_token_fee).unwrap();
1345 let max_fee_lamports = u64::try_from((reward_lamports as u128) * (epoch_fee.numerator as u128) / (epoch_fee.denominator as u128)).unwrap();
1346 assert!(max_fee_lamports >= fee_lamports,
1347 "Max possible fee must always be greater than or equal to what is actually withdrawn, max {} actual {}",
1348 max_fee_lamports,
1349 fee_lamports);
1350
1351 let epsilon = 2 + reward_lamports / total_lamports;
1355 assert!(max_fee_lamports - fee_lamports <= epsilon,
1356 "Max expected fee in lamports {}, actually receive {}, epsilon {}",
1357 max_fee_lamports, fee_lamports, epsilon);
1358 }
1359 }
1360
1361 prop_compose! {
1362 fn total_tokens_and_deposit()(total_lamports in 1..u64::MAX)(
1363 total_lamports in Just(total_lamports),
1364 pool_token_supply in 1..=total_lamports,
1365 deposit_lamports in 1..total_lamports,
1366 ) -> (u64, u64, u64) {
1367 (total_lamports - deposit_lamports, pool_token_supply.saturating_sub(deposit_lamports).max(1), deposit_lamports)
1368 }
1369 }
1370
1371 proptest! {
1372 #[test]
1373 fn deposit_and_withdraw(
1374 (total_lamports, pool_token_supply, deposit_stake) in total_tokens_and_deposit()
1375 ) {
1376 let mut stake_pool = StakePool {
1377 total_lamports,
1378 pool_token_supply,
1379 ..StakePool::default()
1380 };
1381 let deposit_result = stake_pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
1382 prop_assume!(deposit_result > 0);
1383 stake_pool.total_lamports += deposit_stake;
1384 stake_pool.pool_token_supply += deposit_result;
1385 let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
1386 assert!(withdraw_result <= deposit_stake);
1387
1388 if deposit_result >= 2 {
1390 let first_half_deposit = deposit_result / 2;
1391 let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
1392 stake_pool.total_lamports -= first_withdraw_result;
1393 stake_pool.pool_token_supply -= first_half_deposit;
1394 let second_half_deposit = deposit_result - first_half_deposit; let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
1396 assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
1397 }
1398 }
1399 }
1400
/// Deposit 3 lamports into a pool priced at 1.1 lamports per token. Redeeming
/// half of the minted tokens must return at most half of the deposit.
#[test]
fn specific_split_withdrawal() {
    let deposit_stake = 3;
    let mut stake_pool = StakePool {
        total_lamports: 1_100_000_000_000,
        pool_token_supply: 1_000_000_000_000,
        ..StakePool::default()
    };

    let minted = stake_pool
        .calc_pool_tokens_for_deposit(deposit_stake)
        .unwrap();
    assert!(minted > 0);
    stake_pool.total_lamports += deposit_stake;
    stake_pool.pool_token_supply += minted;

    let half_redemption = stake_pool
        .calc_lamports_withdraw_amount(minted / 2)
        .unwrap();
    assert!(half_redemption * 2 <= deposit_stake);
}
1422
/// Redeeming the entire pool token supply drains exactly the pool's remaining
/// lamports. Three splits are checked:
/// - all at once;
/// - one token followed by the rest;
/// - all but one token followed by the last one.
#[test]
fn withdraw_all() {
    let total_lamports = 1_100_000_000_000;
    let pool_token_supply = 1_000_000_000_000;
    let fresh_pool = || StakePool {
        total_lamports,
        pool_token_supply,
        ..StakePool::default()
    };

    // Everything in a single redemption (not applied to the pool).
    let mut stake_pool = fresh_pool();
    let everything = stake_pool
        .calc_lamports_withdraw_amount(pool_token_supply)
        .unwrap();
    assert_eq!(stake_pool.total_lamports, everything);

    // One token first, then whatever supply is left.
    let one_token = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
    stake_pool.total_lamports -= one_token;
    stake_pool.pool_token_supply -= 1;
    let rest = stake_pool
        .calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
        .unwrap();
    assert_eq!(stake_pool.total_lamports, rest);

    // All but one token first; some lamports must remain for the last token.
    let mut stake_pool = fresh_pool();
    let all_but_one = stake_pool
        .calc_lamports_withdraw_amount(pool_token_supply - 1)
        .unwrap();
    stake_pool.total_lamports -= all_but_one;
    stake_pool.pool_token_supply = 1;
    assert_ne!(stake_pool.total_lamports, 0);

    let last_token = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
    assert_eq!(stake_pool.total_lamports, last_token);
}
1463}