use {
super::{ConstraintGenerator, GeneratorResult},
crate::{
accounts::Account,
constraints::{
ConstraintAssociatedToken, ConstraintInit, ConstraintInitIfNeeded, ConstraintMint,
ConstraintPayer, ConstraintSeeded, ConstraintSeeds, ConstraintSpace, ConstraintToken,
},
visitor::ContextVisitor,
},
proc_macro2::{Span, TokenStream},
quote::{format_ident, quote},
syn::{parse_quote, punctuated::Punctuated, Expr, Ident, PathSegment, Token},
};
enum InitAccountGeneratorTy {
TokenAccount {
is_ata: bool,
mint: Option<Ident>,
authority: Option<Expr>,
},
Mint {
decimals: Option<Expr>,
authority: Option<Expr>,
freeze_authority: Box<Option<Expr>>,
},
Other {
space: Option<Expr>,
},
}
struct InitAccountGenerator<'a> {
name: &'a Ident,
inner_account_ty: String,
payer: Option<Ident>,
is_seeded: bool,
keys: Option<Punctuated<Expr, Token![,]>>,
init_if_needed: bool,
ty: InitAccountGeneratorTy,
}
impl<'a> InitAccountGenerator<'a> {
pub fn new(account: &'a Account) -> Self {
let ty = match account.inner_ty.as_str() {
"Mint" => InitAccountGeneratorTy::Mint {
authority: None,
decimals: None,
freeze_authority: Box::new(None),
},
"TokenAccount" => InitAccountGeneratorTy::TokenAccount {
is_ata: false,
mint: None,
authority: None,
},
_ => InitAccountGeneratorTy::Other { space: None },
};
InitAccountGenerator {
name: &account.name,
inner_account_ty: account.inner_ty.clone(),
payer: None,
is_seeded: false,
keys: None,
init_if_needed: false,
ty,
}
}
fn get_seeds(&self) -> Option<TokenStream> {
let punctuated_keys = self.keys.as_ref()?;
let seeds = if self.is_seeded {
let account_ty = format_ident!("{}", self.inner_account_ty);
quote!(#account_ty::derive_with_bump(#punctuated_keys, &bump))
} else {
quote!(seeds!(#punctuated_keys, &bump))
};
Some(seeds)
}
fn generate_token(&self, account_ty: &PathSegment) -> Result<TokenStream, syn::Error> {
let name = self.name;
let Some(ref payer) = self.payer else {
return Err(syn::Error::new(
self.name.span(),
"A payer is needed for the `init` constraint",
));
};
let maybe_signer = {
self.get_seeds().map(|seeds| {
quote! {
let bump = [bumps.#name];
let seeds = #seeds;
let signer = instruction::CpiSigner::from(&seeds);
}
})
};
let signers = if maybe_signer.is_some() {
quote!(Some(&[signer]))
} else {
quote!(None)
};
let token = match &self.ty {
InitAccountGeneratorTy::Mint {
decimals,
authority,
freeze_authority,
} => {
let default_decimals = parse_quote!(9);
let decimals = decimals.as_ref().unwrap_or(&default_decimals);
let Some(ref authority) = authority else {
return Err(syn::Error::new(
self.name.span(),
"A `authority` need to be specified for the `init` constraint",
));
};
let f_auth_token = if let Some(auth) = freeze_authority.as_ref() {
quote!(Some(#auth))
} else {
quote!(None)
};
quote!(SPLCreate::create_mint(#name, &rent, &#payer, &#authority, #decimals, #f_auth_token, #signers)?)
}
InitAccountGeneratorTy::TokenAccount {
authority,
mint,
is_ata,
} => {
let Some(ref authority) = authority else {
return Err(syn::Error::new(
self.name.span(),
"A `authority` need to be specified for the `init` constraint",
));
};
let Some(ref mint) = mint else {
return Err(syn::Error::new(
self.name.span(),
"A `mint` need to be specified for the `init` constraint",
));
};
if *is_ata {
quote!(SPLCreate::create_associated_token_account(#name, &#payer, &#mint, &#authority, &system_program, &token_program)?)
} else {
quote!(SPLCreate::create_token_account(#name, &rent, &#payer, &#mint, &#authority, #signers)?)
}
}
InitAccountGeneratorTy::Other { space } => {
let Some(ref space) = space else {
return Err(syn::Error::new(
self.name.span(),
"A space need to be specified for the `init` constraint",
));
};
quote!(SystemCpi::create_account(#name, &rent, &#payer, &crate::ID, #space, #signers)?)
}
};
if self.init_if_needed {
Ok(quote! {
let #name = if <Mut<UncheckedAccount> as ChecksExt>::is_initialized(&#name) {
<#account_ty as FromAccountInfo>::try_from_info(#name.into())?
}else {
#maybe_signer
#token
};
})
} else {
Ok(quote! {
let #name: #account_ty = {
#maybe_signer
#token
};
})
}
}
}
impl ContextVisitor for InitAccountGenerator<'_> {
fn visit_payer(&mut self, contraint: &ConstraintPayer) -> Result<(), syn::Error> {
self.payer = Some(contraint.target.clone());
Ok(())
}
fn visit_space(&mut self, contraint: &ConstraintSpace) -> Result<(), syn::Error> {
match &mut self.ty {
InitAccountGeneratorTy::Other { space } => *space = Some(contraint.space.clone()),
_ => {
return Err(syn::Error::new(
self.name.span(),
"Cannot use `space` constraint with `Mint` or `TokenAccount`.",
))
}
}
Ok(())
}
fn visit_seeded(&mut self, constraint: &ConstraintSeeded) -> Result<(), syn::Error> {
self.keys = constraint.0.clone();
self.is_seeded = true;
Ok(())
}
fn visit_seeds(&mut self, contraint: &ConstraintSeeds) -> Result<(), syn::Error> {
self.keys = Some(contraint.seeds.clone());
Ok(())
}
fn visit_token(&mut self, constraint: &ConstraintToken) -> Result<(), syn::Error> {
match &mut self.ty {
InitAccountGeneratorTy::TokenAccount {
mint, authority, ..
} => match constraint {
ConstraintToken::Mint(ident) => *mint = Some(ident.clone()),
ConstraintToken::Authority(expr) => *authority = Some(expr.clone()),
},
_ => {
return Err(syn::Error::new(
self.name.span(),
"Cannot use `token` constraint on non `TokenAccount` account.",
))
}
}
Ok(())
}
fn visit_associated_token(
&mut self,
constraint: &ConstraintAssociatedToken,
) -> Result<(), syn::Error> {
match &mut self.ty {
InitAccountGeneratorTy::TokenAccount {
mint,
authority,
is_ata,
} => {
match constraint {
ConstraintAssociatedToken::Mint(ident) => *mint = Some(ident.clone()),
ConstraintAssociatedToken::Authority(ident) => {
*authority = Some(parse_quote!(#ident))
}
}
*is_ata = true
}
_ => {
return Err(syn::Error::new(
self.name.span(),
"Cannot use `associated_token` constraint on non `TokenAccount` account.",
))
}
}
Ok(())
}
fn visit_mint(&mut self, constraint: &ConstraintMint) -> Result<(), syn::Error> {
match &mut self.ty {
InitAccountGeneratorTy::Mint {
authority,
decimals,
freeze_authority,
} => match constraint {
ConstraintMint::Authority(expr) => *authority = Some(expr.clone()),
ConstraintMint::Decimals(expr) => *decimals = Some(expr.clone()),
ConstraintMint::FreezeAuthority(expr) => {
*freeze_authority.as_mut() = Some(expr.clone())
}
},
_ => {
return Err(syn::Error::new(
self.name.span(),
"Cannot use `mint` constraint on non `Mint` account.",
))
}
}
Ok(())
}
fn visit_init_if_needed(
&mut self,
_constraint: &ConstraintInitIfNeeded,
) -> Result<(), syn::Error> {
self.init_if_needed = true;
Ok(())
}
}
#[derive(Default)]
pub struct InitializationGenerator {
need_check_system: bool,
need_check_token: bool,
need_check_ata: bool,
has_init: bool,
result: GeneratorResult,
}
impl InitializationGenerator {
pub fn new() -> Self {
Self::default()
}
fn check_prerequisite(&self, context: &[Account], program: &str) -> Result<(), syn::Error> {
let has_system_program = context
.iter()
.any(|acc| acc.ty.ident == "Program" && acc.inner_ty == program);
if !has_system_program {
return Err(syn::Error::new(
Span::call_site(),
format!(
"Using `init` requires including the `Program<{}>` account",
program
),
));
}
Ok(())
}
}
impl ConstraintGenerator for InitializationGenerator {
fn generate(&self) -> Result<GeneratorResult, syn::Error> {
Ok(self.result.clone())
}
}
impl ContextVisitor for InitializationGenerator {
fn visit_init(&mut self, _constraint: &ConstraintInit) -> Result<(), syn::Error> {
self.has_init = true;
self.need_check_system = true;
Ok(())
}
fn visit_init_if_needed(
&mut self,
_constraint: &ConstraintInitIfNeeded,
) -> Result<(), syn::Error> {
self.has_init = true;
self.need_check_system = true;
Ok(())
}
fn visit_associated_token(
&mut self,
_constraint: &ConstraintAssociatedToken,
) -> Result<(), syn::Error> {
self.need_check_ata = true;
Ok(())
}
fn visit_accounts(&mut self, accounts: &Vec<Account>) -> Result<(), syn::Error> {
for account in accounts {
self.visit_account(account)?;
}
if self.need_check_system {
self.check_prerequisite(accounts, "System")?;
if self.need_check_token {
self.check_prerequisite(accounts, "TokenProgram")?;
}
if self.need_check_ata {
self.check_prerequisite(accounts, "AtaTokenProgram")?;
}
}
Ok(())
}
fn visit_account(&mut self, account: &Account) -> Result<(), syn::Error> {
self.visit_constraints(&account.constraints)?;
if self.has_init {
if account.inner_ty == "Mint" || account.inner_ty == "TokenAccount" {
self.need_check_token = true;
}
let mut account_generator = InitAccountGenerator::new(account);
account_generator.visit_account(account)?;
self.result
.at_init
.extend(account_generator.generate_token(&account.ty)?);
self.has_init = false;
}
Ok(())
}
}