use syn::{DeriveInput, Error, Expr, Ident, ItemStruct, Type};
use super::{
infra::{InfraFieldClassifier, InfraFields},
instruction_arg::{args_to_set, parse_instruction_attr, InstructionArg, InstructionArgSet},
};
use crate::light_pdas::seeds::ClassifiedSeed;
/// Local alias for the associated-token-account descriptor produced by the
/// `light_account` attribute parser (`LightAccountField::AssociatedToken`).
type ParsedAtaField = crate::light_pdas::accounts::light_account::AtaField;
/// Local alias for the token-account descriptor produced by the
/// `light_account` attribute parser (`LightAccountField::TokenAccount`).
type ParsedTokenField = crate::light_pdas::accounts::light_account::TokenAccountField;
/// Local alias for the mint descriptor defined in the `mint` module
/// (carried by `LightAccountField::Mint`).
type ParsedMintField = crate::light_pdas::accounts::mint::LightMintField;
/// Result of parsing an accounts struct with `parse_derive_input`.
///
/// Every field carrying a `#[light_account(...)]` attribute lands in exactly one
/// of the `*_fields` vectors, according to the kind reported by the attribute
/// parser. Infrastructure accounts are recognised separately, by field name.
#[derive(Debug)]
pub struct ParsedAccountsStruct {
/// Identifier of the struct the derive was applied to.
pub struct_name: Ident,
/// Generics of the struct (e.g. the `'info` lifetime), cloned verbatim.
pub generics: syn::Generics,
/// Fields whose `#[light_account]` attribute parsed as a PDA.
pub pda_fields: Vec<ParsedPdaField>,
/// Fields whose `#[light_account]` attribute parsed as a mint.
pub mint_fields: Vec<ParsedMintField>,
/// Fields whose `#[light_account]` attribute parsed as a token account.
pub token_fields: Vec<ParsedTokenField>,
/// Fields whose `#[light_account]` attribute parsed as an associated token account.
pub ata_fields: Vec<ParsedAtaField>,
/// Arguments from the struct's `#[instruction(...)]` attribute; `None` when the
/// attribute is absent (which is rejected if any `#[light_account]` field exists).
pub instruction_args: Option<Vec<InstructionArg>>,
/// Fields recognised by name via `InfraFieldClassifier`
/// (for example `fee_payer` and `compression_config`).
pub infra_fields: InfraFields,
/// Name of the single instruction argument whose type's last path segment is
/// `CreateAccountsProof`, if there is one.
pub direct_proof_arg: Option<Ident>,
}
/// A field whose `#[light_account]` attribute was parsed as a PDA
/// (`LightAccountField::Pda`).
#[derive(Debug, Clone)]
pub struct ParsedPdaField {
/// Name of the struct field.
pub field_name: Ident,
/// Copied from the parsed PDA attribute data; presumably set when the
/// account type is wrapped in a `Box` — TODO confirm in the attribute parser.
pub is_boxed: bool,
/// Copied from the parsed PDA attribute data (zero-copy account flag).
pub is_zero_copy: bool,
/// Address tree info expression; always `Some` when built by this module's parser.
pub address_tree_info: Option<Expr>,
/// Output tree expression; always `Some` when built by this module's parser.
pub output_tree: Option<Expr>,
}
/// Walks the named fields of `input` and classifies them for code generation.
///
/// For each field, in declaration order:
/// * the field name is checked with `InfraFieldClassifier`. A recognised name is
///   recorded in [`InfraFields`], and any error from `InfraFields::set` is
///   propagated.
/// * a `#[light_account(...)]` attribute, if present, is parsed. The field is then
///   appended to the PDA, mint, token-account or associated-token bucket.
///
/// For PDA fields, the Anchor seeds are additionally extracted against the
/// `#[instruction(...)]` argument set. The extracted seeds are discarded, so the
/// call presumably exists to surface seed errors early.
///
/// `direct_proof_arg` is forwarded to the `#[light_account]` parser and stored
/// unchanged on the result.
///
/// # Errors
/// * the `#[instruction(...)]` attribute fails to parse;
/// * `input` is a tuple or unit struct;
/// * any infra-field, `#[light_account]`, or seed-extraction step fails;
/// * at least one `#[light_account]` field exists but the struct has no
///   `#[instruction(...)]` attribute.
fn parse_accounts_struct_impl(
    input: &ItemStruct,
    direct_proof_arg: Option<Ident>,
) -> Result<ParsedAccountsStruct, Error> {
    use crate::light_pdas::accounts::light_account::{parse_light_account_attr, LightAccountField};
    use crate::light_pdas::seeds::anchor_extraction::extract_anchor_seeds;

    let instruction_args = parse_instruction_attr(&input.attrs)?;
    let instruction_arg_set = if let Some(args) = &instruction_args {
        args_to_set(args)
    } else {
        InstructionArgSet::empty()
    };

    let syn::Fields::Named(named_fields) = &input.fields else {
        return Err(Error::new_spanned(
            input,
            "expected struct with named fields",
        ));
    };

    let mut infra_fields = InfraFields::default();
    let (mut pda_fields, mut mint_fields, mut token_fields, mut ata_fields) =
        (Vec::new(), Vec::new(), Vec::new(), Vec::new());

    for field in &named_fields.named {
        let Some(ident) = field.ident.clone() else {
            return Err(Error::new_spanned(
                field,
                "expected named field with identifier",
            ));
        };

        // Name-based infra detection runs whether or not the field also
        // carries `#[light_account]`.
        if let Some(infra_kind) = InfraFieldClassifier::classify(&ident.to_string()) {
            infra_fields.set(infra_kind, ident.clone())?;
        }

        let Some(light_field) = parse_light_account_attr(field, &ident, &direct_proof_arg)? else {
            continue;
        };

        match light_field {
            LightAccountField::Pda(pda) => {
                // The value is unused; only a possible error matters here.
                let _validated_seeds: Vec<ClassifiedSeed> =
                    extract_anchor_seeds(&field.attrs, &instruction_arg_set)?;
                pda_fields.push(ParsedPdaField {
                    field_name: ident,
                    is_boxed: pda.is_boxed,
                    is_zero_copy: pda.is_zero_copy,
                    address_tree_info: Some(pda.address_tree_info),
                    output_tree: Some(pda.output_tree),
                });
            }
            LightAccountField::Mint(mint) => mint_fields.push(*mint),
            LightAccountField::TokenAccount(token) => token_fields.push(*token),
            LightAccountField::AssociatedToken(ata) => ata_fields.push(*ata),
        }
    }

    let no_light_fields = pda_fields.is_empty()
        && mint_fields.is_empty()
        && token_fields.is_empty()
        && ata_fields.is_empty();
    if !no_light_fields && instruction_args.is_none() {
        return Err(Error::new_spanned(
            input,
            "#[light_account] fields require #[instruction(params: YourParamsType)] \
             attribute on the struct",
        ));
    }

    Ok(ParsedAccountsStruct {
        struct_name: input.ident.clone(),
        generics: input.generics.clone(),
        pda_fields,
        mint_fields,
        token_fields,
        ata_fields,
        instruction_args,
        infra_fields,
        direct_proof_arg,
    })
}
pub fn parse_derive_input(input: &DeriveInput) -> Result<ParsedAccountsStruct, Error> {
let instruction_args = parse_instruction_attr(&input.attrs)?;
let direct_proof_arg = find_direct_proof_arg(&instruction_args)?;
match &input.data {
syn::Data::Struct(data) => {
let item_struct = ItemStruct {
attrs: input.attrs.clone(),
vis: input.vis.clone(),
struct_token: data.struct_token,
ident: input.ident.clone(),
generics: input.generics.clone(),
fields: data.fields.clone(),
semi_token: data.semi_token,
};
parse_accounts_struct_impl(&item_struct, direct_proof_arg)
}
_ => Err(Error::new_spanned(input, "expected struct")),
}
}
/// Returns `true` when `ty` names the `CreateAccountsProof` type.
///
/// Only the last path segment is compared. A bare name and a module-qualified
/// path both match, and generic arguments on that segment are ignored.
///
/// Two purely syntactic wrappers are peeled recursively before the check:
/// * `Type::Group` — invisible delimiters, e.g. from a `macro_rules!` `$t:ty`
///   fragment;
/// * `Type::Paren` — redundant parentheses.
///
/// Without this, a proof type forwarded through a declarative macro would
/// silently not be recognised. Every other type form (references, tuples,
/// arrays, ...) returns `false`.
fn is_create_accounts_proof_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => match type_path.path.segments.last() {
            Some(segment) => segment.ident == "CreateAccountsProof",
            None => false,
        },
        Type::Group(group) => is_create_accounts_proof_type(&group.elem),
        Type::Paren(paren) => is_create_accounts_proof_type(&paren.elem),
        _ => false,
    }
}
/// Locates the instruction argument typed as `CreateAccountsProof`.
///
/// Returns `Ok(None)` in two cases: there is no `#[instruction(...)]` attribute,
/// or no argument has that type. Returns the argument's name when exactly one
/// argument matches.
///
/// # Errors
/// Fails when two or more arguments match. The error lists every matching name
/// and is spanned on the second match.
fn find_direct_proof_arg(
    instruction_args: &Option<Vec<InstructionArg>>,
) -> Result<Option<Ident>, Error> {
    let candidates: Vec<&InstructionArg> = match instruction_args {
        Some(args) => args
            .iter()
            .filter(|arg| is_create_accounts_proof_type(&arg.ty))
            .collect(),
        None => return Ok(None),
    };

    match candidates.as_slice() {
        [] => Ok(None),
        [only] => Ok(Some(only.name.clone())),
        [_, second, ..] => {
            let listed = candidates
                .iter()
                .map(|arg| arg.name.to_string())
                .collect::<Vec<_>>()
                .join(", ");
            Err(Error::new_spanned(
                &second.name,
                format!(
                    "Multiple CreateAccountsProof arguments found: [{}]. \
                     Only one CreateAccountsProof argument is allowed per instruction.",
                    listed
                ),
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Runs the parser and fails the test, showing the parser's message, on error.
    fn parse_ok(input: &DeriveInput) -> ParsedAccountsStruct {
        match parse_derive_input(input) {
            Ok(parsed) => parsed,
            Err(err) => panic!("expected successful parse, got error: {err}"),
        }
    }

    /// A struct with no `#[light_account]` fields parses and yields no PDAs.
    #[test]
    fn test_parse_empty_struct() {
        let input: DeriveInput = parse_quote! {
            #[derive(Accounts)]
            pub struct Empty<'info> {
                pub fee_payer: Signer<'info>,
            }
        };
        let parsed = parse_ok(&input);
        assert!(parsed.pda_fields.is_empty());
    }

    /// A single `#[light_account(init)]` field is recorded as one PDA.
    #[test]
    fn test_parse_with_pda_field() {
        let input: DeriveInput = parse_quote! {
            #[derive(Accounts)]
            #[instruction(params: CreateParams)]
            pub struct Create<'info> {
                #[account(mut)]
                pub fee_payer: Signer<'info>,
                #[account(init, payer = fee_payer, space = 100, seeds = [b"user"], bump)]
                #[light_account(init)]
                pub user_record: Account<'info, UserRecord>,
            }
        };
        let parsed = parse_ok(&input);
        let [pda] = parsed.pda_fields.as_slice() else {
            panic!("expected exactly one PDA field, got {}", parsed.pda_fields.len());
        };
        assert_eq!(pda.field_name.to_string(), "user_record");
    }

    /// Infrastructure accounts are picked up by field name alone.
    #[test]
    fn test_parse_infra_fields() {
        let input: DeriveInput = parse_quote! {
            #[derive(Accounts)]
            pub struct Test<'info> {
                #[account(mut)]
                pub fee_payer: Signer<'info>,
                pub compression_config: AccountInfo<'info>,
            }
        };
        let parsed = parse_ok(&input);
        assert!(parsed.infra_fields.fee_payer.is_some());
        assert!(parsed.infra_fields.compression_config.is_some());
    }

    /// `#[light_account]` without `#[instruction(...)]` is rejected with a hint.
    #[test]
    fn test_light_account_without_instruction_fails() {
        let input: DeriveInput = parse_quote! {
            #[derive(Accounts)]
            pub struct NoInstruction<'info> {
                #[account(init, payer = fee_payer, space = 100, seeds = [b"user"], bump)]
                #[light_account(init)]
                pub user_record: Account<'info, UserRecord>,
            }
        };
        let message = match parse_derive_input(&input) {
            Ok(_) => panic!("expected an error for a light account without an instruction attribute"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("#[instruction"));
    }
}