use {
super::{ConstraintGenerator, GeneratorResult},
crate::{
accounts::Account,
constraints::{ConstraintBump, ConstraintSeeded, ConstraintSeeds},
context::Context,
extractor::InnerTyExtractor,
visitor::ContextVisitor,
},
quote::{format_ident, quote},
syn::{parse_quote, punctuated::Punctuated, visit::Visit, Expr, Ident, PathSegment, Token},
};
#[derive(Default)]
pub struct BumpsGenerator {
context_name: Option<String>,
account: Option<(Ident, PathSegment)>,
bump: Option<Expr>,
is_seeded: bool,
seeds: Option<Punctuated<Expr, Token![,]>>,
result: GeneratorResult,
struct_fields: Vec<Ident>,
}
impl BumpsGenerator {
pub fn new() -> Self {
BumpsGenerator::default()
}
pub fn is_pda(&self) -> bool {
self.is_seeded || self.seeds.is_some()
}
fn extend_checks(&mut self) -> Result<(), syn::Error> {
let (name, ty) = self.account.as_ref().unwrap();
let pda_key = format_ident!("{}_key", name);
let pda_bump = format_ident!("{}_bump", name);
if let Some(bump) = &self.bump {
let (seeds_token, bump_token) = if self.is_seeded {
(
quote!(#name.data()?.seeds_with_bump(&[#pda_bump])),
quote!(let #pda_bump = { #bump };),
)
} else {
let seeds = self.seeds.as_ref().ok_or(syn::Error::new(
name.span(),
"Seeds constraint is not specified.",
))?;
(
quote!([#seeds, &[#pda_bump]]),
quote!(let #pda_bump = { #bump };),
)
};
self.result.after_init.extend(quote! {
#bump_token
let #pda_key = create_program_address(&#seeds_token, &crate::ID)?;
if #name.key() != &#pda_key {
return Err(ProgramError::InvalidSeeds);
}
});
} else {
let keys = self.seeds.as_ref().ok_or(syn::Error::new(
name.span(),
"Seeds constraint is not specified.",
))?;
let seeds = if self.is_seeded {
let mut inner_ty_extractor = InnerTyExtractor::new();
inner_ty_extractor.visit_path_segment(ty);
let inner_ty_str = inner_ty_extractor
.ty
.ok_or(syn::Error::new(name.span(), "Cannot find the inner type."))?;
let inner_ty = format_ident!("{inner_ty_str}");
quote!(#inner_ty::derive(#keys))
} else {
quote!([#keys])
};
self.result.at_init.extend(quote! {
let (#pda_key, #pda_bump) = find_program_address(&#seeds, &crate::ID);
if #name.key() != &#pda_key {
return Err(ProgramError::InvalidSeeds);
}
});
}
Ok(())
}
}
impl ConstraintGenerator for BumpsGenerator {
fn generate(&self) -> Result<GeneratorResult, syn::Error> {
let mut result = self.result.clone();
if !self.struct_fields.is_empty() {
let struct_name = format_ident!("{}Bumps", self.context_name.as_ref().unwrap());
let struct_fields = &self.struct_fields;
let bumps_struct = quote! {
#[derive(Debug, PartialEq)]
pub struct #struct_name {
#(pub #struct_fields: u8,)*
}
};
result.global_outside = bumps_struct;
let assign_fields = self.struct_fields.iter().map(|n| {
let bump_ident = format_ident!("{}_bump", n);
quote!(#n: #bump_ident)
});
result.at_init.extend(quote! {
let bumps = #struct_name {
#(#assign_fields),*
};
});
result.new_fields.push(parse_quote! {
pub bumps: #struct_name
});
}
Ok(result)
}
}
impl ContextVisitor for BumpsGenerator {
fn visit_context(&mut self, context: &Context) -> Result<(), syn::Error> {
self.context_name = Some(context.item_struct.ident.to_string());
self.visit_accounts(&context.accounts)?;
if let Some(args) = &context.args {
self.visit_arguments(args)?;
}
Ok(())
}
fn visit_account(&mut self, account: &Account) -> Result<(), syn::Error> {
self.account = Some((account.name.clone(), account.ty.clone()));
self.bump = None;
self.is_seeded = false;
self.seeds = None;
self.visit_constraints(&account.constraints)?;
if self.is_pda() {
self.extend_checks()?;
}
Ok(())
}
fn visit_bump(&mut self, constraint: &ConstraintBump) -> Result<(), syn::Error> {
self.bump = constraint.0.clone();
if self.bump.is_none() {
self.struct_fields
.push(self.account.as_ref().unwrap().0.clone());
}
Ok(())
}
fn visit_seeded(&mut self, constraint: &ConstraintSeeded) -> Result<(), syn::Error> {
if !self.is_seeded && self.seeds.is_some() {
return Err(syn::Error::new(
self.account.as_ref().unwrap().0.span(),
"Cannot specified keys and seeds at the same time.",
));
}
self.is_seeded = true;
self.seeds = constraint.0.clone();
Ok(())
}
fn visit_seeds(&mut self, constraint: &ConstraintSeeds) -> Result<(), syn::Error> {
if self.is_seeded {
return Err(syn::Error::new(
self.account.as_ref().unwrap().0.span(),
"Cannot specified keys and seeds at the same time.",
));
}
self.seeds = Some(constraint.seeds.clone());
Ok(())
}
}