use crate::{
account_set::{
modifiers::{CanInitAccount, CanInitSeeds, HasSeeds, SignedAccount},
AccountSetValidate,
},
prelude::*,
ErrorCode,
};
use bytemuck::bytes_of;
use derive_more::{Deref, DerefMut};
use std::marker::PhantomData;
pub use star_frame_proc::GetSeeds;
/// Produces the ordered list of seed slices from which a program-derived address is computed.
///
/// `#[derive(GetSeeds)]` output and the blanket impl for [`Seed`] types end the list with an
/// empty slice. [`SeedsWithBump::seeds_with_bump`] overwrites that trailing empty slice with the
/// bump, or appends the bump when no such slot is present.
pub trait GetSeeds: Debug {
    /// Returns the seeds, borrowed from `self`.
    fn seeds(&self) -> Vec<&[u8]>;
}
/// A lone [`Seed`] value is a complete seed list: the value itself followed by the empty slot
/// reserved for the bump.
impl<T: Seed + Debug> GetSeeds for T {
    fn seeds(&self) -> Vec<&[u8]> {
        // Placeholder that `SeedsWithBump::seeds_with_bump` replaces with the bump byte.
        const BUMP_SLOT: &[u8] = &[];
        vec![self.seed(), BUMP_SLOT]
    }
}
/// A single seed value that can be viewed as a byte slice.
///
/// Every `NoUninit` type implements this through its raw bytes (`bytemuck::bytes_of`).
pub trait Seed {
    /// Returns the bytes used as this value's seed.
    fn seed(&self) -> &[u8];
}
/// Any `NoUninit` value is a seed made of its in-memory bytes.
impl<T: NoUninit> Seed for T {
    fn seed(&self) -> &[u8] {
        bytes_of::<T>(self)
    }
}
/// Seeds paired with an already-known PDA bump.
///
/// Validating with this type checks the address via `Pubkey::create_program_address` instead
/// of searching for the bump.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash, PartialOrd, Ord)]
pub struct SeedsWithBump<T: GetSeeds> {
    /// The seeds, without the bump.
    pub seeds: T,
    /// The bump byte appended as the final seed.
    pub bump: u8,
}
impl<T> SeedsWithBump<T>
where
    T: GetSeeds,
{
    /// Returns the seeds from `T::seeds` with the bump as the final seed.
    ///
    /// If the list already ends in an empty slice (the bump slot reserved by `GetSeeds`
    /// implementations), that slot is taken over by the bump. Otherwise the bump is appended.
    pub fn seeds_with_bump(&self) -> Vec<&[u8]> {
        let bump_bytes = bytes_of(&self.bump);
        let mut all_seeds = self.seeds.seeds();
        // Dropping the empty slot and then pushing the bump is the same as overwriting it.
        if matches!(all_seeds.last(), Some(tail) if tail.is_empty()) {
            all_seeds.pop();
        }
        all_seeds.push(bump_bytes);
        all_seeds
    }
}
/// Seeds supplied without a bump.
///
/// When validating a [`Seeded`] account with this type, the bump is found with
/// `Pubkey::find_program_address`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct Seeds<T>(pub T);
/// [`SeedProgram`] marker that derives addresses under the currently executing program, as
/// reported by `Context::current_program_id`.
#[derive(Debug, Clone, Copy)]
pub struct CurrentProgram;
/// Chooses the program id under which a [`Seeded`] account's address is derived.
pub trait SeedProgram {
    /// Returns the program id to derive under; implementations may consult `ctx`.
    fn id(ctx: &Context) -> Result<Pubkey>;
    /// Program id recorded alongside the seeds in the IDL, if any.
    #[cfg(all(feature = "idl", not(target_os = "solana")))]
    fn idl_program() -> Option<Pubkey>;
}
impl SeedProgram for CurrentProgram {
    /// Resolves to whichever program is currently executing, as reported by `ctx`.
    fn id(ctx: &Context) -> Result<Pubkey> {
        let program_id = *ctx.current_program_id();
        Ok(program_id)
    }

    /// No fixed program id is recorded in the IDL for the current program.
    #[cfg(all(feature = "idl", not(target_os = "solana")))]
    fn idl_program() -> Option<Pubkey> {
        Option::None
    }
}
/// Any star_frame program can serve as the seed program, using its associated `ID`.
impl<P: StarFrameProgram> SeedProgram for P {
    /// Returns `P::ID`; the context is not consulted.
    fn id(_ctx: &Context) -> Result<Pubkey> {
        let program_id: Pubkey = P::ID;
        Ok(program_id)
    }

    /// Records `P::ID` as the program in the IDL.
    #[cfg(all(feature = "idl", not(target_os = "solana")))]
    fn idl_program() -> Option<Pubkey> {
        let program_id: Pubkey = P::ID;
        Some(program_id)
    }
}
/// Account set wrapper that checks the wrapped account's address is the program-derived
/// address for a set of seeds under the program chosen by `P`, then records the seeds and bump.
///
/// The seeds come from the validate argument, in one of four forms:
/// - `Seeds<S>`: the bump is searched for with `Pubkey::find_program_address`.
/// - `SeedsWithBump<S>`: the supplied bump is checked with `Pubkey::create_program_address`.
/// - `(Seeds<S>, A)` / `(SeedsWithBump<S>, A)`: as above, with `A` forwarded as the wrapped
///   account's validate argument.
///
/// The seed check is wired in through `before_validation` and does nothing if seeds have already
/// been recorded, e.g. by `CanInitSeeds::init_seeds`. `S` defaults to the wrapped account's
/// `HasSeeds::Seeds`, and `P` defaults to [`CurrentProgram`].
#[derive(AccountSet, Deref, DerefMut, derive_where::DeriveWhere)]
// `Debug`/`Clone` are bounded on `T` and `SeedsWithBump<S>` only, not on `P`.
#[derive_where(Debug, Clone; T, SeedsWithBump<S>)]
// The IDL impls are hand-written in `idl_impl`; validation comes from the variants below.
#[account_set(skip_default_idl, skip_default_validate)]
// `Seeds<S>` alone: requires `T: AccountSetValidate<()>`.
#[validate(
    id = "seeds",
    generics = [where T: AccountSetValidate<()> + SingleAccountSet],
    arg = Seeds<S>,
    before_validation = self.validate_and_set_seeds(&arg, ctx)
)]
// `(Seeds<S>, A)`: `arg.1` goes to the wrapped account (see the field attributes).
#[validate(
    id = "seeds_generic",
    arg = (Seeds<S>, A),
    before_validation = self.validate_and_set_seeds(&arg.0, ctx)
)]
// `SeedsWithBump<S>` alone: requires `T: AccountSetValidate<()>`.
#[validate(
    id = "seeds_with_bump",
    generics = [where T: AccountSetValidate<()> + SingleAccountSet],
    arg = SeedsWithBump<S>,
    before_validation = self.validate_and_set_seeds_with_bump(&arg, ctx)
)]
// `(SeedsWithBump<S>, A)`: `arg.1` goes to the wrapped account.
#[validate(
    id = "seeds_with_bump_generic",
    arg = (SeedsWithBump<S>, A),
    before_validation = self.validate_and_set_seeds_with_bump(&arg.0, ctx)
)]
pub struct Seeded<T, S = <T as HasSeeds>::Seeds, P = CurrentProgram>
where
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    // The seed-related modifier traits are skipped because `Seeded` implements
    // `SignedAccount`, `HasSeeds`, `CanInitSeeds` and `CanInitAccount` itself.
    #[single_account_set(
        skip_signed_account,
        skip_has_seeds,
        skip_can_init_seeds,
        skip_can_init_account
    )]
    #[validate(id = "seeds_generic", arg = arg.1)]
    #[validate(id = "seeds_with_bump_generic", arg = arg.1)]
    #[deref]
    #[deref_mut]
    pub(crate) account: T,
    // Seeds and bump recorded by `validate_and_set_seeds*`; `None` until then. Not an account:
    // `skip = None` presumably provides the value used when the set is built — confirm.
    #[account_set(skip = None)]
    pub(crate) seeds: Option<SeedsWithBump<S>>,
    // Carries the `SeedProgram` type parameter; holds no data.
    #[account_set(skip = PhantomData)]
    phantom_p: PhantomData<P>,
}
impl<T, S, P, A> CanInitSeeds<(Seeds<S>, A)> for Seeded<T, S, P>
where
    T: SingleAccountSet + AccountSetValidate<A>,
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    /// Validates and records the seeds half of the argument; the `A` half is ignored here.
    fn init_seeds(&mut self, arg: &(Seeds<S>, A), ctx: &Context) -> Result<()> {
        let (seeds, _) = arg;
        self.validate_and_set_seeds(seeds, ctx)
    }
}
impl<T, S, P> CanInitSeeds<Seeds<S>> for Seeded<T, S, P>
where
    T: SingleAccountSet + AccountSetValidate<()>,
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    /// Validates the address against `arg` (bump searched for) and records the seeds.
    fn init_seeds(&mut self, arg: &Seeds<S>, ctx: &Context) -> Result<()> {
        Self::validate_and_set_seeds(self, arg, ctx)
    }
}
impl<T, S, P, A> CanInitSeeds<(SeedsWithBump<S>, A)> for Seeded<T, S, P>
where
    T: SingleAccountSet + AccountSetValidate<A>,
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    /// Validates and records the seeds-with-bump half of the argument; the `A` half is ignored.
    fn init_seeds(&mut self, arg: &(SeedsWithBump<S>, A), ctx: &Context) -> Result<()> {
        let (seeds_with_bump, _) = arg;
        self.validate_and_set_seeds_with_bump(seeds_with_bump, ctx)
    }
}
impl<T, S, P> CanInitSeeds<SeedsWithBump<S>> for Seeded<T, S, P>
where
    T: SingleAccountSet + AccountSetValidate<()>,
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    /// Validates the address against `arg` using its supplied bump and records the seeds.
    fn init_seeds(&mut self, arg: &SeedsWithBump<S>, ctx: &Context) -> Result<()> {
        Self::validate_and_set_seeds_with_bump(self, arg, ctx)
    }
}
impl<T, S, P> Seeded<T, S, P>
where
    T: SingleAccountSet,
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    /// Checks that the wrapped account's address is the PDA for `seeds` under `P::id(ctx)`,
    /// searching for the bump, then records the seeds together with the bump that was found.
    ///
    /// Does nothing if seeds have already been recorded on `self`.
    ///
    /// # Errors
    /// Returns `ErrorCode::AddressMismatch` if the derived address differs from the account's
    /// address, and propagates any error from `P::id`.
    ///
    /// # Panics
    /// `Pubkey::find_program_address` panics if it cannot find a valid bump (see its docs).
    fn validate_and_set_seeds(&mut self, seeds: &Seeds<S>, ctx: &Context) -> Result<()> {
        if self.seeds.is_some() {
            return Ok(());
        }
        let seeds = &seeds.0;
        let mut seed_slices = seeds.seeds();
        // `GetSeeds` impls may end with an empty placeholder slot for the bump (the slot that
        // `SeedsWithBump::seeds_with_bump` overwrites). `find_program_address` appends its own
        // bump, so forwarding the placeholder would use up one of the runtime's `MAX_SEEDS`
        // slots. Seed lists that `validate_and_set_seeds_with_bump` accepts could then fail here.
        // An empty seed adds no bytes to the PDA hash, so dropping it leaves the derived
        // address and bump unchanged.
        if matches!(seed_slices.last(), Some(last) if last.is_empty()) {
            seed_slices.pop();
        }
        let (address, bump) = Pubkey::find_program_address(&seed_slices, &P::id(ctx)?);
        let expected = self.account.account_info().pubkey();
        ensure!(
            address.fast_eq(expected),
            ErrorCode::AddressMismatch,
            "Seeds: {seeds:?} result in address `{address}` and bump `{bump}`, expected `{expected}`"
        );
        // Clone only once the seeds are known to be valid.
        self.seeds = Some(SeedsWithBump {
            seeds: seeds.clone(),
            bump,
        });
        Ok(())
    }

    /// Checks that the wrapped account's address equals the PDA created from `seeds` (including
    /// its supplied bump) under `P::id(ctx)`, then records a copy of `seeds`.
    ///
    /// Does nothing if seeds have already been recorded on `self`.
    ///
    /// # Errors
    /// Propagates errors from `P::id` and `Pubkey::create_program_address` (e.g. invalid
    /// seeds), and returns `ErrorCode::AddressMismatch` if the addresses differ.
    fn validate_and_set_seeds_with_bump(
        &mut self,
        seeds: &SeedsWithBump<S>,
        ctx: &Context,
    ) -> Result<()> {
        if self.seeds.is_some() {
            return Ok(());
        }
        let arg_seeds = seeds.seeds_with_bump();
        let address = Pubkey::create_program_address(&arg_seeds, &P::id(ctx)?)?;
        let expected = self.account.account_info().pubkey();
        ensure!(
            address.fast_eq(expected),
            ErrorCode::AddressMismatch,
            "Seeds `{seeds:?}` result in address `{address}`, expected `{expected}`"
        );
        self.seeds = Some(seeds.clone());
        Ok(())
    }
}
impl<T, S, P> Seeded<T, S, P>
where
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    /// Returns the seeds and bump recorded during validation or `init_seeds`.
    ///
    /// # Panics
    /// Panics with `Seeds not set!` if no seeds have been recorded yet.
    pub fn access_seeds(&self) -> &SeedsWithBump<S> {
        match &self.seeds {
            Some(recorded) => recorded,
            None => panic!("Seeds not set!"),
        }
    }
}
/// Signer seeds are offered only when `P = CurrentProgram` (presumably because a program can
/// only sign for addresses derived under its own id — confirm).
impl<T, S> SignedAccount for Seeded<T, S, CurrentProgram>
where
    T: SingleAccountSet,
    S: GetSeeds + Clone,
{
    /// Returns the recorded seeds with the bump as the final seed.
    ///
    /// Panics via `access_seeds` if no seeds have been recorded.
    fn signer_seeds(&self) -> Option<Vec<&[u8]>> {
        let recorded = self.access_seeds();
        Some(recorded.seeds_with_bump())
    }
}
/// The seeds type of a `Seeded` account is its `S` parameter.
impl<T, S, P> HasSeeds for Seeded<T, S, P>
where
    T: SingleAccountSet,
    S: GetSeeds + Clone,
    P: SeedProgram,
{
    type Seeds = S;
}
impl<T, S, A> CanInitAccount<A> for Seeded<T, S, CurrentProgram>
where
    T: CanInitAccount<A>,
    S: GetSeeds + Clone,
{
    /// Initializes the wrapped account. The seeds recorded on `self` (with the bump) are passed
    /// as the wrapped account's `account_seeds`.
    ///
    /// # Errors
    /// - `ErrorCode::ConflictingAccountSeeds` if the caller also supplies `account_seeds`.
    /// - `ErrorCode::SeedsNotSet` if no seeds have been recorded yet.
    /// - Any error from the wrapped account's `init_account`.
    fn init_account<const IF_NEEDED: bool>(
        &mut self,
        arg: A,
        account_seeds: Option<&[&[u8]]>,
        ctx: &Context,
    ) -> Result<bool> {
        // `Seeded` supplies its own seeds, so externally supplied ones are rejected.
        if account_seeds.is_some() {
            bail!(
                ErrorCode::ConflictingAccountSeeds,
                "Conflicting account seeds during init."
            );
        }
        let recorded = self.seeds.as_ref().ok_or_else(|| {
            error!(
                ErrorCode::SeedsNotSet,
                "Seeds not set for `Seeded` during init."
            )
        })?;
        let signer_seeds = recorded.seeds_with_bump();
        self.account
            .init_account::<IF_NEEDED>(arg, Some(&signer_seeds), ctx)
    }
}
#[cfg(all(feature = "idl", not(target_os = "solana")))]
mod idl_impl {
    use crate::idl::FindIdlSeeds;
    use super::*;
    use star_frame_idl::{account_set::IdlAccountSetDef, seeds::IdlFindSeeds, IdlDefinition};

    /// Describes `Seeded` as the wrapped account's single-account IDL, annotated with the seeds
    /// found from `F` and the program from `P::idl_program()`.
    impl<T, A, S, P, F> AccountSetToIdl<(Seeds<F>, A)> for Seeded<T, S, P>
    where
        T: AccountSetToIdl<A> + SingleAccountSet,
        S: GetSeeds + Clone,
        P: SeedProgram,
        F: FindIdlSeeds,
    {
        fn account_set_to_idl(
            idl_definition: &mut IdlDefinition,
            arg: (Seeds<F>, A),
        ) -> crate::IdlResult<IdlAccountSetDef> {
            let (Seeds(find_arg), inner_arg) = arg;
            let mut set = T::account_set_to_idl(idl_definition, inner_arg)?;
            let single = set.single()?;
            // Seeds may only be attached once.
            if single.seeds.is_some() {
                return Err(star_frame_idl::Error::Custom(format!(
                    "Seeds already set for `Seeded`. Got: {single:?}"
                )));
            }
            // Init accounts must wrap `Seeded`, not the other way around.
            if single.is_init {
                return Err(star_frame_idl::Error::Custom(format!(
                    "`Seeded` should not wrap an init account. Wrap `Seeded` with `Init` instead. Got: {single:?}"
                )));
            }
            single.seeds = Some(IdlFindSeeds {
                seeds: F::find_seeds(&find_arg)?,
                program: P::idl_program(),
            });
            Ok(set)
        }
    }

    /// Shorthand for `(Seeds<F>, ())`.
    impl<T, S, P, F> AccountSetToIdl<Seeds<F>> for Seeded<T, S, P>
    where
        T: AccountSetToIdl<()> + SingleAccountSet,
        S: GetSeeds + Clone,
        P: SeedProgram,
        F: FindIdlSeeds,
    {
        fn account_set_to_idl(
            idl_definition: &mut IdlDefinition,
            arg: Seeds<F>,
        ) -> crate::IdlResult<IdlAccountSetDef> {
            let paired = (arg, ());
            <Self as AccountSetToIdl<(Seeds<F>, ())>>::account_set_to_idl(idl_definition, paired)
        }
    }

    /// Without a seeds argument, `Seeded` is described exactly like its wrapped single account.
    impl<T, S, P> AccountSetToIdl<()> for Seeded<T, S, P>
    where
        T: AccountSetToIdl<()> + SingleAccountSet,
        S: GetSeeds + Clone,
        P: SeedProgram,
    {
        fn account_set_to_idl(
            idl_definition: &mut IdlDefinition,
            arg: (),
        ) -> crate::IdlResult<IdlAccountSetDef> {
            let inner = T::account_set_to_idl(idl_definition, arg)?;
            inner.assert_single()
        }
    }
}
// NOTE(review): empty placeholder. The name suggests it once anchored a compile-fail check
// (e.g. a doc test) that `#[derive(GetSeeds)]` rejects unnamed/tuple structs, but no such check
// is visible here — confirm its purpose or remove it.
fn _unnamed_seed_structs_fail() {}
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use solana_pubkey::Pubkey;

    // No fields and no constant: only the trailing bump placeholder is produced.
    #[derive(Debug, GetSeeds, Clone)]
    pub struct UnitSeeds {}

    #[test]
    fn test_unit_struct() {
        let expected: &[&[u8]] = &[&[]];
        let actual = <UnitSeeds as crate::prelude::GetSeeds>::seeds(&UnitSeeds {});
        assert_eq!(actual, expected);
    }

    #[derive(Debug, GetSeeds, Clone)]
    pub struct SingleKey {
        key: Pubkey,
    }

    #[test]
    fn test_single_key() {
        let subject = SingleKey {
            key: Pubkey::new_unique(),
        };
        assert_eq!(subject.seeds(), vec![subject.key.seed(), &[]]);
    }

    #[derive(Debug, GetSeeds, Clone)]
    pub struct TwoKeys {
        key1: Pubkey,
        key2: Pubkey,
    }

    #[test]
    fn test_two_keys() {
        let subject = TwoKeys {
            key1: Pubkey::new_unique(),
            key2: Pubkey::new_unique(),
        };
        // Fields appear in declaration order, followed by the bump slot.
        assert_eq!(
            subject.seeds(),
            vec![subject.key1.seed(), subject.key2.seed(), &[]]
        );
    }

    #[derive(Debug, GetSeeds, Clone)]
    pub struct KeyAndNumber {
        key: Pubkey,
        number: u64,
    }

    #[test]
    fn test_key_and_number() {
        let subject = KeyAndNumber {
            key: Pubkey::new_unique(),
            number: 42,
        };
        assert_eq!(
            subject.seeds(),
            vec![subject.key.seed(), subject.number.seed(), &[]]
        );
    }

    #[derive(Debug, GetSeeds, Clone)]
    #[get_seeds(seed_const = b"TEST_CONST")]
    pub struct OnlyConstSeed {}

    #[test]
    fn test_unit_with_const_seed() {
        let subject = OnlyConstSeed {};
        assert_eq!(subject.seeds(), vec![b"TEST_CONST".as_ref(), &[]]);
    }

    #[derive(Debug, GetSeeds, Clone)]
    #[get_seeds(seed_const = b"TEST_CONST")]
    pub struct OneKeyConstSeed {
        key: Pubkey,
    }

    #[test]
    fn test_one_key_with_const_seed() {
        let subject = OneKeyConstSeed {
            key: Pubkey::new_unique(),
        };
        // The constant seed comes first, ahead of the field seeds.
        assert_eq!(
            subject.seeds(),
            vec![b"TEST_CONST".as_ref(), subject.key.seed(), &[]]
        );
    }

    pub struct Cool {}
    impl Cool {
        const DISC: &'static [u8] = b"TEST_CONST";
    }

    // `seed_const` may also be a path to an associated constant.
    #[derive(Debug, GetSeeds, Clone)]
    #[get_seeds(seed_const = Cool::DISC)]
    pub struct SeedPath {}

    #[test]
    fn test_path_seed() {
        let subject = SeedPath {};
        assert_eq!(subject.seeds(), vec![b"TEST_CONST".as_ref(), &[]]);
    }
}