use syn::{Ident, ItemStruct, Type};
use super::{
anchor_extraction::extract_anchor_seeds,
classification::classify_seed_expr,
instruction_args::InstructionArgSet,
types::{ClassifiedSeed, ExtractedAccountsInfo, ExtractedSeedSpec, ExtractedTokenSpec},
};
use crate::{
light_pdas::{
account::validation::{type_name, AccountTypeError},
light_account_keywords::{
is_standalone_keyword, unknown_key_error, valid_keys_for_namespace,
},
},
utils::snake_to_camel_case,
};
/// Extracts the inner account type from an account wrapper type.
///
/// The last path segment of `ty` decides how it is handled:
///
/// * `Account<..>`, `AccountLoader<..>`, `InterfaceAccount<..>`: the first
///   generic *type* argument that is a path whose last segment is not named
///   `info` is returned as `(false, inner)`.
/// * `Box<T>`: `T` is resolved recursively and returned as `(true, inner)`.
///
/// # Errors
///
/// * [`AccountTypeError::WrongType`] if `ty` is not a path type, has no
///   segments, or its last segment is none of the names above.
/// * [`AccountTypeError::ExtractionFailed`] if a recognised wrapper carries no
///   usable generic type argument.
/// * [`AccountTypeError::NestedBox`] if the first argument of a `Box` is
///   itself a `Box`.
pub fn extract_account_inner_type(ty: &Type) -> Result<(bool, Type), AccountTypeError> {
    let wrong_type = || AccountTypeError::WrongType { got: type_name(ty) };

    let outer_path = match ty {
        Type::Path(type_path) => type_path,
        _ => return Err(wrong_type()),
    };
    let wrapper = outer_path.path.segments.last().ok_or_else(wrong_type)?;

    // Generic arguments of the wrapper, if it was written with `<...>`.
    let generics = match &wrapper.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => Some(&bracketed.args),
        _ => None,
    };

    let wrapper_name = wrapper.ident.to_string();
    match wrapper_name.as_str() {
        "Account" | "AccountLoader" | "InterfaceAccount" => generics
            .into_iter()
            .flatten()
            .find_map(|arg| {
                // Only type arguments are candidates; lifetimes etc. are skipped.
                let candidate = match arg {
                    syn::GenericArgument::Type(candidate) => candidate,
                    _ => return None,
                };
                let last_segment = match candidate {
                    Type::Path(candidate_path) => candidate_path.path.segments.last()?,
                    _ => return None,
                };
                if last_segment.ident != "info" {
                    Some((false, candidate.clone()))
                } else {
                    None
                }
            })
            .ok_or(AccountTypeError::ExtractionFailed),
        "Box" => {
            // Only the first generic argument of `Box` is inspected.
            let boxed = match generics.and_then(|args| args.first()) {
                Some(syn::GenericArgument::Type(boxed)) => boxed,
                _ => return Err(AccountTypeError::ExtractionFailed),
            };
            let is_nested_box = match boxed {
                Type::Path(boxed_path) => boxed_path
                    .path
                    .segments
                    .last()
                    .map_or(false, |segment| segment.ident == "Box"),
                _ => false,
            };
            if is_nested_box {
                return Err(AccountTypeError::NestedBox);
            }
            // Whatever the inner wrapper reported, the overall type is boxed.
            extract_account_inner_type(boxed).map(|(_, inner_type)| (true, inner_type))
        }
        _ => Err(wrong_type()),
    }
}
fn check_light_account_type(attrs: &[syn::Attribute]) -> (bool, bool, bool, bool) {
for attr in attrs {
if attr.path().is_ident("light_account") {
let tokens = match &attr.meta {
syn::Meta::List(list) => list.tokens.clone(),
_ => continue,
};
let token_vec: Vec<_> = tokens.clone().into_iter().collect();
let has_namespace_prefix = |namespace: &str| {
token_vec.windows(2).any(|window| {
matches!(
(&window[0], &window[1]),
(
proc_macro2::TokenTree::Ident(ident),
proc_macro2::TokenTree::Punct(punct)
) if ident == namespace && punct.as_char() == ':'
)
})
};
let has_mint_namespace = has_namespace_prefix("mint");
let has_token_namespace = has_namespace_prefix("token");
let has_ata_namespace = has_namespace_prefix("associated_token");
let has_init = token_vec
.iter()
.any(|t| matches!(t, proc_macro2::TokenTree::Ident(ident) if ident == "init"));
let has_zero_copy = token_vec
.iter()
.any(|t| matches!(t, proc_macro2::TokenTree::Ident(ident) if ident == "zero_copy"));
if has_init {
if has_mint_namespace {
return (false, true, false, false);
}
if has_ata_namespace {
return (false, false, true, false);
}
if has_token_namespace {
return (false, false, false, false);
}
return (true, false, false, has_zero_copy);
}
}
}
(false, false, false, false)
}
/// Data gathered from the `token` namespace of a `#[light_account(...)]`
/// attribute.
struct LightTokenAttr {
/// Optional explicit variant name. `parse_light_token_list` always leaves
/// this as `None`; `extract_from_accounts_struct` then falls back to a name
/// derived from the field identifier.
variant_name: Option<Ident>,
/// Classified seeds from `owner_seeds = [...]`, or `None` when that key was
/// not present in the attribute.
owner_seeds: Option<Vec<ClassifiedSeed>>,
}
fn extract_light_token_attr(
attrs: &[syn::Attribute],
instruction_args: &InstructionArgSet,
) -> syn::Result<Option<LightTokenAttr>> {
for attr in attrs {
if attr.path().is_ident("light_account") {
let tokens = match &attr.meta {
syn::Meta::List(list) => list.tokens.clone(),
_ => continue,
};
let token_vec: Vec<_> = tokens.clone().into_iter().collect();
let has_token_namespace = token_vec.windows(2).any(|window| {
matches!(
(&window[0], &window[1]),
(
proc_macro2::TokenTree::Ident(ident),
proc_macro2::TokenTree::Punct(punct)
) if ident == "token" && punct.as_char() == ':'
)
});
if has_token_namespace {
let parsed = parse_light_token_list(&tokens, instruction_args, "token")?;
return Ok(Some(parsed));
}
}
}
Ok(None)
}
/// Parses the token list of a `#[light_account(...)]` attribute for the
/// namespace `account_type` and collects its `owner_seeds`.
///
/// Entries are handled as follows:
///
/// * `<ident>:<key>` / `<ident>::<key>` where `<ident>` is a different
///   namespace: the key and an optional `= <expr>` are consumed and ignored.
/// * `<account_type>::<key>`: the key must be one of
///   `valid_keys_for_namespace(account_type)`. For `owner_seeds = [...]` each
///   array element is classified with `classify_seed_expr`; any other
///   `= <expr>` value is consumed and ignored.
/// * A bare identifier is accepted only if `is_standalone_keyword` allows it.
///
/// The returned `LightTokenAttr` always has `variant_name: None`.
///
/// # Errors
///
/// Returns a `syn::Error` for an unknown key in the `account_type` namespace,
/// an unknown bare identifier, an entry that does not start with an
/// identifier, an `owner_seeds` value that is not a bracketed array, a seed
/// that fails classification, or owner seeds that are not constants (see
/// `validate_owner_seeds_are_constants`).
fn parse_light_token_list(
tokens: &proc_macro2::TokenStream,
instruction_args: &InstructionArgSet,
account_type: &str,
) -> syn::Result<LightTokenAttr> {
use syn::parse::Parser;
// The parser closure below is `move`, so it is given owned copies.
let instruction_args = instruction_args.clone();
let account_type_owned = account_type.to_string();
let valid_keys = valid_keys_for_namespace(account_type);
let parser = move |input: syn::parse::ParseStream| -> syn::Result<LightTokenAttr> {
let mut owner_seeds = None;
while !input.is_empty() {
if input.peek(Ident) {
let ident: Ident = input.parse()?;
let ident_str = ident.to_string();
if input.peek(syn::Token![:]) {
// Namespaced entry: consume one `:` and, if another follows, that one too.
input.parse::<syn::Token![:]>()?;
if input.peek(syn::Token![:]) {
input.parse::<syn::Token![:]>()?;
}
let key: Ident = input.parse()?;
let key_str = key.to_string();
if ident_str != account_type_owned {
// Entry of another namespace: skip its optional value without validation.
if input.peek(syn::Token![=]) {
input.parse::<syn::Token![=]>()?;
let _expr: syn::Expr = input.parse()?;
}
} else {
if !valid_keys.contains(&key_str.as_str()) {
return Err(syn::Error::new_spanned(
&key,
unknown_key_error(&account_type_owned, &key_str),
));
}
if input.peek(syn::Token![=]) {
input.parse::<syn::Token![=]>()?;
if key_str == "owner_seeds" {
// Take the whole `[...]` group as one unit so its contents can be
// parsed separately as a comma-separated expression list.
let array_content = input.step(|cursor| {
if let Some((group, _span, rest)) =
cursor.group(proc_macro2::Delimiter::Bracket)
{
Ok((group.token_stream(), rest))
} else {
Err(cursor.error("expected bracketed array"))
}
})?;
let elems: syn::punctuated::Punctuated<syn::Expr, syn::Token![,]> =
syn::parse::Parser::parse2(
syn::punctuated::Punctuated::parse_terminated,
array_content,
)?;
let mut seeds = Vec::new();
for elem in &elems {
// Re-anchor classification failures on the offending element.
let seed = classify_seed_expr(elem, &instruction_args)
.map_err(|e| {
syn::Error::new_spanned(
elem,
format!("invalid owner seed: {}", e),
)
})?;
seeds.push(seed);
}
owner_seeds = Some(seeds);
} else {
// Other keys of this namespace: the value is parsed and discarded here.
let _expr: syn::Expr = input.parse()?;
}
}
}
} else if is_standalone_keyword(&ident_str) {
// Standalone keywords carry no data relevant to this parser.
} else {
return Err(syn::Error::new_spanned(
&ident,
format!(
"Unknown keyword `{}` in #[light_account(...)]. \
Use namespaced syntax: `{}::owner_seeds = [...]`, `{}::mint`, etc.",
ident_str, account_type_owned, account_type_owned
),
));
}
} else {
// Every entry must start with an identifier.
let valid_kw_str = valid_keys.join(", ");
return Err(syn::Error::new(
input.span(),
format!(
"Expected keyword in #[light_account(...)]. \
Valid namespaced keys: {}::{{{}}}, or standalone: init",
account_type_owned, valid_kw_str
),
));
}
// A separating comma is consumed when present but is not required.
if input.peek(syn::Token![,]) {
input.parse::<syn::Token![,]>()?;
}
}
// Owner seeds are only accepted if every one of them is a constant.
if let Some(ref seeds) = owner_seeds {
validate_owner_seeds_are_constants(seeds)?;
}
Ok(LightTokenAttr {
variant_name: None, owner_seeds,
})
};
parser.parse2(tokens.clone())
}
fn validate_owner_seeds_are_constants(seeds: &[ClassifiedSeed]) -> syn::Result<()> {
for seed in seeds {
match seed {
ClassifiedSeed::Literal(_) | ClassifiedSeed::Constant { .. } => {
continue;
}
ClassifiedSeed::CtxRooted { account } => {
return Err(syn::Error::new(
account.span(),
"owner_seeds must be constants only. \
Dynamic ctx account references like `authority.key()` are not allowed. \
Use only byte literals (b\"seed\") or const references (SEED.as_bytes()).",
));
}
ClassifiedSeed::DataRooted { root, .. } => {
return Err(syn::Error::new(
root.span(),
"owner_seeds must be constants only. \
Instruction data references like `params.owner` are not allowed. \
Use only byte literals (b\"seed\") or const references (SEED.as_bytes()).",
));
}
ClassifiedSeed::FunctionCall { func_expr, .. } => {
return Err(syn::Error::new_spanned(
func_expr,
"owner_seeds must be constants only. \
Dynamic function calls are not allowed. \
Use only byte literals (b\"seed\") or const references (SEED.as_bytes()).",
));
}
ClassifiedSeed::Passthrough(expr) => {
return Err(syn::Error::new_spanned(
expr,
"owner_seeds must be constants only. \
This expression type is not recognized as a constant. \
Use only byte literals (b\"seed\") or const references (SEED.as_bytes()).",
));
}
}
}
Ok(())
}
/// Collects light-account information from the fields of an accounts struct.
///
/// Each named field is classified with `check_light_account_type` and its
/// `token` namespace (if any) is parsed with `extract_light_token_attr`:
///
/// * PDA fields yield an [`ExtractedSeedSpec`] carrying the inner account
///   type, the seeds from `extract_anchor_seeds` and a variant name derived
///   from the field name via `snake_to_camel_case`.
/// * Otherwise, fields with a `token` namespace yield an
///   [`ExtractedTokenSpec`].
/// * Mint and associated-token fields only set the corresponding flags.
///
/// Returns `Ok(None)` for structs without named fields and for structs in
/// which none of the above was found.
///
/// # Errors
///
/// Propagates errors from attribute parsing, inner-type extraction and seed
/// extraction, and fails if any collected token field lacks `owner_seeds`.
pub fn extract_from_accounts_struct(
    item: &ItemStruct,
    instruction_args: &InstructionArgSet,
    module_path: &str,
) -> syn::Result<Option<ExtractedAccountsInfo>> {
    let named_fields = match &item.fields {
        syn::Fields::Named(named) => &named.named,
        _ => return Ok(None),
    };

    let mut pda_fields = Vec::new();
    let mut token_fields = Vec::new();
    let mut has_light_mint_fields = false;
    let mut has_light_ata_fields = false;

    for field in named_fields {
        let field_ident = match field.ident.as_ref() {
            Some(ident) => ident.clone(),
            None => continue,
        };
        // Default variant name: the field name converted to camel case.
        let default_variant_name = || {
            Ident::new(
                &snake_to_camel_case(&field_ident.to_string()),
                field_ident.span(),
            )
        };

        let (is_pda, is_mint, is_ata, is_zero_copy) = check_light_account_type(&field.attrs);
        has_light_mint_fields |= is_mint;
        has_light_ata_fields |= is_ata;

        // Parsed for every field, so a malformed `token` namespace is reported
        // even when the field is otherwise classified as a PDA.
        let token_attr = extract_light_token_attr(&field.attrs, instruction_args)?;

        if is_pda {
            let (_, inner_type) = extract_account_inner_type(&field.ty)
                .map_err(|err| err.into_syn_error(&field.ty))?;
            let seeds = extract_anchor_seeds(&field.attrs, instruction_args)?;
            pda_fields.push(ExtractedSeedSpec {
                variant_name: default_variant_name(),
                inner_type,
                seeds,
                is_zero_copy,
                struct_name: item.ident.to_string(),
                module_path: module_path.to_string(),
            });
        } else if let Some(token_attr) = token_attr {
            let seeds = extract_anchor_seeds(&field.attrs, instruction_args)?;
            let variant_name = token_attr
                .variant_name
                .unwrap_or_else(default_variant_name);
            token_fields.push(ExtractedTokenSpec {
                field_name: field_ident,
                variant_name,
                seeds,
                owner_seeds: token_attr.owner_seeds,
                module_path: module_path.to_string(),
            });
        }
    }

    let found_anything = !pda_fields.is_empty()
        || !token_fields.is_empty()
        || has_light_mint_fields
        || has_light_ata_fields;
    if !found_anything {
        return Ok(None);
    }

    // Report the first token field that was declared without owner seeds.
    if let Some(missing) = token_fields
        .iter()
        .find(|token| token.owner_seeds.is_none())
    {
        return Err(syn::Error::new(
            missing.field_name.span(),
            format!(
                "Token account field '{}' requires owner_seeds. \
                The owner must be a PDA derived from constant seeds for decompression.\n\
                Add `token::owner_seeds = [b\"seed\", CONSTANT.as_bytes()]` to the #[light_account(...)] attribute.",
                missing.field_name,
            ),
        ));
    }

    Ok(Some(ExtractedAccountsInfo {
        struct_name: item.ident.clone(),
        pda_fields,
        token_fields,
        has_light_mint_fields,
        has_light_ata_fields,
    }))
}
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::{
        super::instruction_args::InstructionArgSet, check_light_account_type,
        extract_account_inner_type, extract_from_accounts_struct, AccountTypeError,
    };

    /// Returns the identifier of the last path segment of `ty`, panicking if
    /// `ty` is not a path type.
    fn last_segment_name(ty: &syn::Type) -> String {
        match ty {
            syn::Type::Path(path) => path.path.segments.last().unwrap().ident.to_string(),
            _ => panic!("Expected path type"),
        }
    }

    /// A plain `Account<'info, T>` yields `T` and is not reported as boxed.
    #[test]
    fn test_extract_account_inner_type() {
        let ty: syn::Type = parse_quote!(Account<'info, UserRecord>);

        let (is_boxed, inner) =
            extract_account_inner_type(&ty).expect("Should extract Account inner type");

        assert!(!is_boxed);
        assert_eq!(last_segment_name(&inner), "UserRecord");
    }

    /// `Box<Account<'info, T>>` yields `T` and is reported as boxed.
    #[test]
    fn test_extract_account_inner_type_boxed() {
        let ty: syn::Type = parse_quote!(Box<Account<'info, UserRecord>>);

        let (is_boxed, inner) =
            extract_account_inner_type(&ty).expect("Should extract Box<Account> inner type");

        assert!(is_boxed);
        assert_eq!(last_segment_name(&inner), "UserRecord");
    }

    /// A `Box` directly inside a `Box` is rejected.
    #[test]
    fn test_extract_account_inner_type_nested_box_fails() {
        let ty: syn::Type = parse_quote!(Box<Box<Account<'info, UserRecord>>>);

        let outcome = extract_account_inner_type(&ty);

        assert!(
            matches!(outcome, Err(AccountTypeError::NestedBox)),
            "Nested Box should return NestedBox error"
        );
    }

    /// A type that is not an account wrapper is rejected.
    #[test]
    fn test_extract_account_inner_type_wrong_type_fails() {
        let ty: syn::Type = parse_quote!(String);

        let outcome = extract_account_inner_type(&ty);

        assert!(
            matches!(outcome, Err(AccountTypeError::WrongType { .. })),
            "Wrong type should return WrongType error"
        );
    }

    /// `init` combined with `mint::` keys is classified as a mint only.
    #[test]
    fn test_check_light_account_type_mint_namespace() {
        let attrs: Vec<syn::Attribute> = vec![parse_quote!(
            #[light_account(init,
                mint::signer = mint_signer,
                mint::authority = fee_payer,
                mint::decimals = 6
            )]
        )];

        let (has_pda, has_mint, has_ata, has_zero_copy) = check_light_account_type(&attrs);

        assert!(!has_pda, "Should NOT be detected as PDA");
        assert!(has_mint, "Should be detected as mint");
        assert!(!has_ata, "Should NOT be detected as ATA");
        assert!(!has_zero_copy, "Should NOT be detected as zero_copy");
    }

    /// A bare `init` is classified as a PDA only.
    #[test]
    fn test_check_light_account_type_pda_only() {
        let attrs: Vec<syn::Attribute> = vec![parse_quote!(
            #[light_account(init)]
        )];

        let (has_pda, has_mint, has_ata, has_zero_copy) = check_light_account_type(&attrs);

        assert!(has_pda, "Should be detected as PDA");
        assert!(!has_mint, "Should NOT be detected as mint");
        assert!(!has_ata, "Should NOT be detected as ATA");
        assert!(!has_zero_copy, "Should NOT be detected as zero_copy");
    }

    /// A struct with one `#[light_account(init)]` field yields one PDA spec
    /// with both of its seeds and no mint or ATA flags.
    #[test]
    fn test_extract_from_accounts_struct() {
        let item: syn::ItemStruct = parse_quote!(
            #[derive(Accounts)]
            #[instruction(params: CreateParams)]
            pub struct Create<'info> {
                #[account(mut)]
                pub fee_payer: Signer<'info>,
                #[account(
                    init,
                    payer = fee_payer,
                    space = 100,
                    seeds = [b"user", authority.key().as_ref()],
                    bump
                )]
                #[light_account(init)]
                pub user_record: Account<'info, UserRecord>,
                pub authority: Signer<'info>,
            }
        );
        let instruction_args = InstructionArgSet::from_names(["params".to_string()]);

        let extracted =
            extract_from_accounts_struct(&item, &instruction_args, "crate::instructions")
                .expect("should extract");
        assert!(extracted.is_some());
        let info = extracted.unwrap();

        assert_eq!(info.struct_name.to_string(), "Create");
        assert_eq!(info.pda_fields.len(), 1);
        let pda = &info.pda_fields[0];
        assert_eq!(pda.variant_name.to_string(), "UserRecord");
        assert_eq!(pda.seeds.len(), 2);
        assert!(!info.has_light_mint_fields);
        assert!(!info.has_light_ata_fields);
    }
}