star_frame 0.30.0

A high-performance Solana framework for building fast, scalable, and secure smart contracts.
Documentation
//! A [`ProgramAccount`] that contains an [`UnsizedType`].

use crate::{
    account_set::{
        modifiers::{
            CanInitAccount, HasInnerType, HasOwnerProgram, HasSeeds, OwnerProgramDiscriminant,
        },
        CanAddLamports, CanCloseAccount as _, CanFundRent, CanModifyRent as _,
        CanSystemCreateAccount as _,
    },
    errors::ErrorCode,
    prelude::*,
    unsize::{init::UnsizedInit, wrapper::SharedWrapper},
};
use advancer::Advance;
use bytemuck::bytes_of;
use std::marker::PhantomData;

/// Cleanup argument that moves the account's rent up or down to exactly the required minimum,
/// via [`CanModifyRent::normalize_rent`](crate::account_set::CanModifyRent::normalize_rent).
///
/// Wrap a funder reference (`NormalizeRent(&funder)`) to pay any shortfall explicitly, or pass
/// `NormalizeRent(())` to use the funder cached in the context.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NormalizeRent<T>(pub T);

/// Cleanup argument that lowers the account's rent to the required minimum, sending the excess
/// out via [`CanModifyRent::refund_rent`](crate::account_set::CanModifyRent::refund_rent).
///
/// Wrap a recipient reference (`RefundRent(&recipient)`) or pass `RefundRent(())` to use the
/// recipient cached in the context.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RefundRent<T>(pub T);

/// Cleanup argument that tops the account's rent up to at least the required minimum, via
/// [`CanModifyRent::receive_rent`](crate::account_set::CanModifyRent::receive_rent).
///
/// Wrap a funder reference (`ReceiveRent(&funder)`) or pass `ReceiveRent(())` to use the funder
/// cached in the context.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ReceiveRent<T>(pub T);

/// Cleanup argument that closes the account via
/// [`CanCloseAccount::close_account`](crate::account_set::CanCloseAccount::close_account).
///
/// Wrap a recipient reference (`CloseAccount(&recipient)`) or pass `CloseAccount(())` to use the
/// recipient cached in the context.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CloseAccount<T>(pub T);

/// A [`ProgramAccount`] that contains an [`UnsizedType`].
///
/// Calls [`ProgramAccount::validate_account_info`] during validation to ensure the owner and discriminant match.
///
/// # Cleanup
///
/// Besides the default cleanup, this set accepts [`NormalizeRent`], [`ReceiveRent`],
/// [`RefundRent`], and [`CloseAccount`] as cleanup arguments. Each one either wraps a reference to
/// the funder/recipient to use, or wraps `()` to take the funder/recipient cached in the context
/// (failing with `EmptyFunderCache` / `EmptyRecipientCache` when nothing is cached).
#[derive(AccountSet, derive_where::DeriveWhere)]
#[derive_where(Clone, Debug, Copy)]
#[account_set(skip_default_idl, skip_default_cleanup)]
// Both validate variants run `T::validate_account_info` on the wrapped info; the
// `aggressive_inline` feature only adds `inline_always` to the generated validation.
#[cfg_attr(feature = "aggressive_inline",
    validate(inline_always, extra_validation = T::validate_account_info(self.info))
)]
#[cfg_attr(not(feature = "aggressive_inline"),
    validate(extra_validation = T::validate_account_info(self.info))
)]
// Default cleanup (no argument): runs `self.check_cleanup(ctx)`.
#[cleanup(
    generics = [],
    extra_cleanup = self.check_cleanup(ctx),
)]
// `NormalizeRent(&funder)`: normalize rent using an explicitly provided funder.
#[cleanup(
    id = "normalize_rent",
    generics = [<'a, Funder> where Funder: CanFundRent],
    arg = NormalizeRent<&'a Funder>,
    extra_cleanup = self.normalize_rent(arg.0, ctx)
)]
// `NormalizeRent(())`: normalize rent using the funder cached in `ctx`.
#[cleanup(
    id = "normalize_rent_cached",
    arg = NormalizeRent<()>,
    generics = [],
    extra_cleanup = {
        let funder = ctx.get_funder().ok_or_else(|| error!(ErrorCode::EmptyFunderCache, "Missing `funder` in cache for `NormalizeRent`"))?;
        self.normalize_rent(funder, ctx)
    },
)]
// `ReceiveRent(&funder)`: top rent up using an explicitly provided funder.
#[cleanup(
    id = "receive_rent",
    generics = [<'a, Funder> where Funder: CanFundRent],
    arg = ReceiveRent<&'a Funder>,
    extra_cleanup = self.receive_rent(arg.0, ctx)
)]
// `ReceiveRent(())`: top rent up using the funder cached in `ctx`.
#[cleanup(
    id = "receive_rent_cached",
    arg = ReceiveRent<()>,
    generics = [],
    extra_cleanup = {
        let funder = ctx.get_funder().ok_or_else(|| error!(ErrorCode::EmptyFunderCache, "Missing `funder` in cache for `ReceiveRent`"))?;
        self.receive_rent(funder, ctx)
    }
)]
// `RefundRent(&recipient)`: refund excess rent to an explicitly provided recipient.
#[cleanup(
    id = "refund_rent",
    generics = [<'a, Recipient> where Recipient: CanAddLamports],
    arg = RefundRent<&'a Recipient>,
    extra_cleanup = self.refund_rent(arg.0, ctx)
)]
// `RefundRent(())`: refund excess rent to the recipient cached in `ctx`.
#[cleanup(
    id = "refund_rent_cached",
    arg = RefundRent<()>,
    generics = [],
    extra_cleanup = {
        let recipient = ctx.get_recipient().ok_or_else(|| error!(ErrorCode::EmptyRecipientCache, "Missing `recipient` in cache for `RefundRent`"))?;
        self.refund_rent(recipient, ctx)
    }
)]
// `CloseAccount(&recipient)`: close the account in favor of an explicitly provided recipient.
#[cleanup(
    id = "close_account",
    generics = [<'a, Recipient> where Recipient: CanAddLamports],
    arg = CloseAccount<&'a Recipient>,
    extra_cleanup = self.close_account(arg.0)
)]
// `CloseAccount(())`: close the account in favor of the recipient cached in `ctx`.
#[cleanup(
    id = "close_account_cached",
    arg = CloseAccount<()>,
    generics = [],
    extra_cleanup = {
        let recipient = ctx.get_recipient().ok_or_else(|| error!(ErrorCode::EmptyRecipientCache, "Missing `recipient` in cache for `CloseAccount`"))?;
        self.close_account(recipient)
    }
)]
pub struct Account<T: ProgramAccount + UnsizedType + ?Sized> {
    // The underlying account info. The derive's generated inner-type, init, seeds, and
    // owner-program impls are skipped; this file implements those traits by hand for `Account<T>`.
    #[single_account_set(
        skip_has_inner_type,
        skip_can_init_account,
        skip_has_seeds,
        skip_has_owner_program
    )]
    info: AccountInfo,
    // Type marker for `T`; skipped by the account-set machinery and filled with `PhantomData`.
    #[account_set(skip = PhantomData)]
    phantom_t: PhantomData<T>,
}

impl<T> Account<T>
where
    T: ProgramAccount + UnsizedType + ?Sized,
{
    /// Borrows the account data immutably, as a view of `T` past the account discriminant.
    ///
    /// A writable account is re-validated with `T::validate_account_info` on every call. Its
    /// contents may have changed since `AccountSetValidate` ran. Read-only accounts rely on the
    /// validation already performed.
    ///
    /// # Errors
    /// Returns an error if re-validation fails or the shared wrapper cannot be built.
    #[inline]
    pub fn data(&self) -> Result<SharedWrapper<'_, T::Ptr>> {
        let may_have_changed = self.is_writable();
        if may_have_changed {
            T::validate_account_info(self.info)?;
        }
        SharedWrapper::new::<AccountDiscriminant<T>>(&self.info)
    }

    /// Borrows the account data mutably, through an exclusive wrapper over `AccountDiscriminant<T>`.
    ///
    /// # Errors
    /// - `ProgramError::AccountBorrowFailed` if the account is not writable.
    /// - Any error from re-validating the account with `T::validate_account_info`.
    /// - Any error from constructing the exclusive wrapper.
    #[inline]
    pub fn data_mut(&self) -> Result<ExclusiveWrapperTop<'_, AccountDiscriminant<T>, AccountInfo>> {
        if !self.is_writable() {
            // TODO: Perhaps put this behind a debug flag?
            bail!(
                ProgramError::AccountBorrowFailed,
                "Tried to borrow mutably from Account `{}` which is not writable",
                self.pubkey()
            );
        }
        // The account is writable, so its contents may have been modified after
        // AccountSetValidate ran; validate again before handing out mutable access.
        T::validate_account_info(self.info)?;
        ExclusiveWrapper::new(&self.info)
    }
}

/// Support for viewing an account's data as `T` prefixed by `T`'s owner-program discriminant.
pub mod discriminant {
    use crate::{
        account_set::modifiers::OwnerProgramDiscriminant,
        unsize::{init::UnsizedInit, FromOwned, RawSliceAdvance},
    };

    use super::*;
    /// An [`UnsizedType`] whose byte layout is `T`'s [`OwnerProgramDiscriminant`] followed by
    /// `T`'s own data.
    ///
    /// Pointer and owned types are `T`'s. Reads skip the discriminant bytes, while
    /// initialization ([`FromOwned`], [`UnsizedInit`]) writes `T::DISCRIMINANT` first.
    #[derive(Debug)]
    pub struct AccountDiscriminant<T: UnsizedType + ProgramAccount + ?Sized>(T);

    unsafe impl<T> UnsizedType for AccountDiscriminant<T>
    where
        T: ProgramAccount + UnsizedType + ?Sized,
    {
        type Ptr = T::Ptr;
        type Owned = T::Owned;
        // Forwarded unchanged from `T`.
        // NOTE(review): the discriminant prefix always occupies bytes; confirm that forwarding
        // `T::ZST_STATUS` matches the intended meaning of this constant.
        const ZST_STATUS: bool = T::ZST_STATUS;

        /// Advances `data` past the discriminant, then builds `T`'s pointer from the remainder.
        #[inline]
        unsafe fn get_ptr(data: &mut *mut [u8]) -> Result<Self::Ptr> {
            data.try_advance(size_of::<OwnerProgramDiscriminant<T>>())
                .with_ctx(|| {
                    format!(
                        "Failed to advance past discriminant of size {}",
                        size_of::<OwnerProgramDiscriminant<T>>()
                    )
                })?;
            // SAFETY: this forwards the caller's `get_ptr` contract to `T::get_ptr`. `data` has
            // only been advanced (bounds-checked by `try_advance`) past the discriminant.
            unsafe { T::get_ptr(data) }
        }

        /// Delegates to `T::data_len`. The pointer is `T`'s own, already positioned past the
        /// discriminant.
        #[inline]
        fn data_len(m: &Self::Ptr) -> usize {
            T::data_len(m)
        }

        /// Delegates to `T::start_ptr`.
        #[inline]
        fn start_ptr(m: &Self::Ptr) -> *mut () {
            T::start_ptr(m)
        }

        /// Skips the discriminant in `data` and reads an owned `T` from the remaining bytes.
        fn owned(mut data: &[u8]) -> Result<Self::Owned> {
            data.try_advance(size_of::<OwnerProgramDiscriminant<T>>())
                .with_ctx(|| {
                    format!(
                        "Failed to advance past discriminant of size {}",
                        size_of::<OwnerProgramDiscriminant<T>>()
                    )
                })?;
            T::owned(data)
        }

        /// Delegates to `T::owned_from_ptr`. The pointer already excludes the discriminant.
        fn owned_from_ptr(r: &Self::Ptr) -> Result<Self::Owned> {
            T::owned_from_ptr(r)
        }

        /// Forwards resize notifications to `T`.
        #[inline]
        unsafe fn resize_notification(
            self_mut: &mut Self::Ptr,
            source_ptr: *const (),
            change: isize,
        ) -> Result<()> {
            // SAFETY: the caller's contract for `resize_notification` is forwarded unchanged
            // to `T`, which owns the pointer type.
            unsafe { T::resize_notification(self_mut, source_ptr, change) }
        }
    }

    impl<T> FromOwned for AccountDiscriminant<T>
    where
        T: ProgramAccount + UnsizedType + FromOwned + ?Sized,
    {
        /// Size of `T`'s serialized form plus the discriminant prefix.
        fn byte_size(owned: &T::Owned) -> usize {
            T::byte_size(owned) + size_of::<OwnerProgramDiscriminant<T>>()
        }

        /// Writes `T::DISCRIMINANT`, then `owned` via `T::from_owned`. Returns the total number
        /// of bytes written, including the discriminant.
        fn from_owned(owned: T::Owned, bytes: &mut &mut [u8]) -> Result<usize> {
            bytes
                .try_advance(size_of::<OwnerProgramDiscriminant<T>>())
                .with_ctx(|| {
                    format!(
                        "Failed to advance past discriminant during initialization of {}",
                        std::any::type_name::<T>()
                    )
                })?
                .copy_from_slice(bytes_of(&T::DISCRIMINANT));
            T::from_owned(owned, bytes).map(|size| size + size_of::<OwnerProgramDiscriminant<T>>())
        }
    }
    impl<T, I> UnsizedInit<I> for AccountDiscriminant<T>
    where
        T: UnsizedType + ?Sized + ProgramAccount + UnsizedInit<I>,
    {
        // `T`'s init bytes plus the discriminant prefix.
        const INIT_BYTES: usize = T::INIT_BYTES + size_of::<OwnerProgramDiscriminant<T>>();

        /// Writes `T::DISCRIMINANT`, then initializes the remaining bytes with `T::init(arg)`.
        #[inline]
        fn init(bytes: &mut &mut [u8], arg: I) -> Result<()> {
            bytes
                .try_advance(size_of::<OwnerProgramDiscriminant<T>>())
                .with_ctx(|| {
                    format!(
                        "Failed to advance past discriminant during initialization of {}",
                        std::any::type_name::<T>()
                    )
                })?
                .copy_from_slice(bytes_of(&T::DISCRIMINANT));
            T::init(bytes, arg)
        }
    }
}
use discriminant::AccountDiscriminant;

// The inner type of `Account<T>` is the wrapped program account type `T` itself.
impl<T> HasInnerType for Account<T>
where
    T: ProgramAccount + UnsizedType + ?Sized,
{
    type Inner = T;
}

// The owning program is forwarded from `T`'s declared `OwnerProgram`.
impl<T> HasOwnerProgram for Account<T>
where
    T: ProgramAccount + UnsizedType + ?Sized,
{
    type OwnerProgram = T::OwnerProgram;
}

// When `T` declares PDA seeds, `Account<T>` reuses the same seed type.
impl<T> HasSeeds for Account<T>
where
    T: ProgramAccount + UnsizedType + HasSeeds + ?Sized,
{
    type Seeds = T::Seeds;
}

// Initialization with no argument: uses `DefaultInit` and the funder cached in the context.
impl<T> CanInitAccount<()> for Account<T>
where
    T: ProgramAccount + UnsizedType + UnsizedInit<DefaultInit> + ?Sized,
{
    /// Initializes `T` with `DefaultInit`. Dispatches to the init-closure implementation,
    /// which pulls the funder from the context cache.
    #[inline]
    fn init_account<const IF_NEEDED: bool>(
        &mut self,
        _arg: (),
        account_seeds: Option<&[&[u8]]>,
        ctx: &Context,
    ) -> Result<bool> {
        let make_default = || DefaultInit;
        self.init_account::<IF_NEEDED>(make_default, account_seeds, ctx)
    }
}

// Initialization with only an explicit funder: uses `DefaultInit` for the account contents.
impl<T, Funder> CanInitAccount<(&Funder,)> for Account<T>
where
    T: ProgramAccount + UnsizedType + UnsizedInit<DefaultInit> + ?Sized,
    Funder: CanFundRent + ?Sized,
{
    /// Initializes `T` with `DefaultInit`, paying rent from the given funder.
    #[inline]
    fn init_account<const IF_NEEDED: bool>(
        &mut self,
        arg: (&Funder,),
        account_seeds: Option<&[&[u8]]>,
        ctx: &Context,
    ) -> Result<bool> {
        let (funder,) = arg;
        let make_default = || DefaultInit;
        self.init_account::<IF_NEEDED>((make_default, funder), account_seeds, ctx)
    }
}

// Initialization with only an init closure: the funder comes from the context cache.
impl<T, InitArg, InitFn> CanInitAccount<InitFn> for Account<T>
where
    T: ProgramAccount + UnsizedType + UnsizedInit<InitArg> + ?Sized,
    InitFn: FnOnce() -> InitArg,
{
    /// Initializes `T` from the value produced by `arg`, paying rent with the cached funder.
    ///
    /// # Errors
    /// Returns `ErrorCode::EmptyFunderCache` if no funder is cached in `ctx`, or any error from
    /// the underlying initialization.
    #[inline]
    fn init_account<const IF_NEEDED: bool>(
        &mut self,
        arg: InitFn,
        account_seeds: Option<&[&[u8]]>,
        ctx: &Context,
    ) -> Result<bool> {
        let missing_funder = || {
            error!(
                ErrorCode::EmptyFunderCache,
                "Missing tagged `funder` for Account `init_account`"
            )
        };
        let cached_funder = ctx.get_funder().ok_or_else(missing_funder)?;
        self.init_account::<IF_NEEDED>((arg, cached_funder), account_seeds, ctx)
    }
}

impl<T: ProgramAccount + UnsizedType + ?Sized, InitArg, Funder, InitFn>
    CanInitAccount<(InitFn, &Funder)> for Account<T>
where
    T: UnsizedInit<InitArg>,
    InitFn: FnOnce() -> InitArg,
    Funder: CanFundRent + ?Sized,
{
    /// Creates the account through the system program, then writes `T`'s discriminant and
    /// initializes the remaining data from the value produced by the init closure.
    ///
    /// When `IF_NEEDED` is `true`, initialization is skipped (returning `Ok(false)`) unless the
    /// account is owned by the system program or its discriminant bytes are all zero. An account
    /// whose data is shorter than the discriminant cannot hold a valid `T`, so it is treated as
    /// needing initialization rather than causing an out-of-bounds panic.
    ///
    /// Returns `Ok(true)` when the account was created and initialized.
    ///
    /// # Errors
    /// Fails if the account is not writable, if `system_create_account` fails, or if borrowing
    /// or initializing the account data fails.
    #[inline]
    fn init_account<const IF_NEEDED: bool>(
        &mut self,
        arg: (InitFn, &Funder),
        account_seeds: Option<&[&[u8]]>,
        ctx: &Context,
    ) -> Result<bool> {
        if IF_NEEDED {
            let discriminant_len = size_of::<OwnerProgramDiscriminant<T>>();
            // Use `get` instead of slicing: slicing would panic when the account data is
            // shorter than the discriminant.
            let needs_init = self.info.owner().fast_eq(&System::ID)
                || match self.account_data()?.get(..discriminant_len) {
                    Some(discriminant) => discriminant.iter().all(|byte| *byte == 0),
                    None => true,
                };
            if !needs_init {
                return Ok(false);
            }
        }
        self.check_writable()?;
        let (arg, funder) = arg;
        self.system_create_account(
            funder,
            T::OwnerProgram::ID,
            <AccountDiscriminant<T>>::INIT_BYTES,
            account_seeds,
            ctx,
        )
        .ctx("system_create_account failed")?;
        let mut data_bytes = self.account_data_mut()?;
        let mut data_bytes = &mut *data_bytes;
        <AccountDiscriminant<T>>::init(&mut data_bytes, arg())?;
        Ok(true)
    }
}

#[cfg(all(feature = "idl", not(target_os = "solana")))]
mod idl_impl {

    use super::*;
    use star_frame::idl::AccountSetToIdl;
    use star_frame_idl::{account_set::IdlAccountSetDef, IdlDefinition};

    // IDL for `Account<T>`: the plain `AccountInfo` account-set definition, with `T`'s
    // program-account definition attached to its single account entry.
    impl<T: ProgramAccount + UnsizedType + ?Sized, A> AccountSetToIdl<A> for Account<T>
    where
        AccountInfo: AccountSetToIdl<A>,
        T: AccountToIdl,
    {
        fn account_set_to_idl(
            idl_definition: &mut IdlDefinition,
            arg: A,
        ) -> crate::IdlResult<IdlAccountSetDef> {
            let mut account_set =
                <AccountInfo as AccountSetToIdl<A>>::account_set_to_idl(idl_definition, arg)?;
            let single_account = account_set.single()?;
            single_account
                .program_accounts
                .push(T::account_to_idl(idl_definition)?);
            Ok(account_set)
        }
    }
}