use quote::quote;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Expr, Ident, ItemFn, LitStr, Result, Token,
};
use super::visitors::FieldExtractor;
/// Builds a `syn::Error` spanned on `$span` (via `syn::Error::new_spanned`) whose message
/// is followed by a `"\n --> macro location: FILE:LINE"` suffix produced by `file!()` and
/// `line!()`.
///
/// Forms:
/// - `macro_error!(span, msg)`: `msg` is formatted with `{}`, so it must implement `Display`.
/// - `macro_error!(span, "format {}", args...)`: the format string must be a string literal,
///   because it is spliced into `concat!`; the remaining tokens are forwarded as format
///   arguments. NOTE(review): a trailing comma after the last argument expands to `,,` and
///   fails to compile.
macro_rules! macro_error {
// Plain message, no format arguments.
($span:expr, $msg:expr) => {
syn::Error::new_spanned(
$span,
format!(
"{}\n --> macro location: {}:{}",
$msg,
file!(),
line!()
)
)
};
// Format string plus arguments: the location suffix is appended to the literal at compile
// time, so the caller's placeholders come first and the `{}:{}` pair consumes the last two
// arguments (`file!()`, `line!()`).
($span:expr, $fmt:expr, $($arg:tt)*) => {
syn::Error::new_spanned(
$span,
format!(
concat!($fmt, "\n --> macro location: {}:{}"),
$($arg)*,
file!(),
line!()
)
)
};
}
// Makes the macro importable by path (`use ...::macro_error;`) elsewhere in this crate.
pub(crate) use macro_error;
/// Coarse classification of an instruction.
///
/// NOTE(review): the variant names suggest which kinds of accounts the instruction
/// initializes (PDAs only, token accounts only, a mix, mints only, associated token
/// accounts only). Nothing in this file constructs or matches on these variants, so
/// confirm the exact meaning against the callers.
#[derive(Debug, Clone, Copy)]
pub enum InstructionVariant {
PdaOnly,
TokenOnly,
Mixed,
MintOnly,
AtaOnly,
}
/// One parsed seed specification, written in the macro input as `Variant = ( ... )`
/// (see the `Parse` impl below for the accepted keywords).
#[derive(Clone)]
pub struct TokenSeedSpec {
/// Name on the left-hand side of `=`.
pub variant: Ident,
/// The `=` token between the name and the parenthesized body.
pub _eq: Token![=],
/// `Some(true)` after the `is_token` keyword, `Some(false)` after `is_pda`, `None` when
/// no such keyword was given.
pub is_token: Option<bool>,
/// Elements of `seeds = (...)`; the parser rejects a missing or empty list.
pub seeds: Punctuated<SeedElement, Token![,]>,
/// Elements of `owner_seeds = ...`, when that keyword was given.
pub owner_seeds: Option<Vec<SeedElement>>,
/// Always `None` when produced by the parser.
/// NOTE(review): presumably filled in later with the account's inner type — confirm at
/// the assignment site.
pub inner_type: Option<syn::Type>,
/// Always `false` when produced by the parser.
/// NOTE(review): presumably marks zero-copy accounts — confirm at the assignment site.
pub is_zero_copy: bool,
}
impl Parse for TokenSeedSpec {
fn parse(input: ParseStream) -> Result<Self> {
let variant: Ident = input.parse()?;
let _eq: Token![=] = input.parse()?;
let content;
syn::parenthesized!(content in input);
let mut is_token = None;
let mut seeds = Punctuated::new();
let mut owner_seeds = None;
while !content.is_empty() {
if content.peek(Ident) {
let ident: Ident = content.parse()?;
let ident_str = ident.to_string();
match ident_str.as_str() {
"is_token" | "true" => {
is_token = Some(true);
}
"is_pda" | "false" => {
is_token = Some(false);
}
"seeds" => {
let _eq: Token![=] = content.parse()?;
let seeds_content;
syn::parenthesized!(seeds_content in content);
seeds = parse_seed_elements(&seeds_content)?;
}
"owner_seeds" => {
let _eq: Token![=] = content.parse()?;
owner_seeds = Some(parse_owner_seeds(&content)?);
}
_ => {
return Err(syn::Error::new_spanned(
&ident,
format!(
"Unknown keyword '{}'. Expected: is_token, seeds, or owner_seeds.\n\
Use explicit syntax: TypeName = (seeds = (\"seed\", ctx.account, ...))\n\
For tokens: TypeName = (is_token, seeds = (...), owner_seeds = (...))",
ident_str
),
));
}
}
} else {
return Err(syn::Error::new(
content.span(),
"Expected keyword (is_token, seeds, or owner_seeds). Use explicit syntax:\n\
- PDA: TypeName = (seeds = (\"seed\", ctx.account, ...))\n\
- Token: TypeName = (is_token, seeds = (...), owner_seeds = (...))",
));
}
if content.peek(Token![,]) {
let _comma: Token![,] = content.parse()?;
} else {
break;
}
}
if seeds.is_empty() {
return Err(syn::Error::new_spanned(
&variant,
format!(
"Missing seeds for '{}'. Use: {} = (seeds = (\"seed\", ctx.account, ...))",
variant, variant
),
));
}
Ok(TokenSeedSpec {
variant,
_eq,
is_token,
seeds,
owner_seeds,
inner_type: None, is_zero_copy: false, })
}
}
/// Parses a comma-separated list of [`SeedElement`]s until `content` is exhausted.
///
/// A single trailing comma is accepted. Parsing stops at the first element that is not
/// followed by a comma; any tokens after it are left unconsumed here.
fn parse_seed_elements(content: ParseStream) -> Result<Punctuated<SeedElement, Token![,]>> {
    let mut elements = Punctuated::new();
    loop {
        if content.is_empty() {
            return Ok(elements);
        }
        elements.push(content.parse::<SeedElement>()?);
        if !content.peek(Token![,]) {
            return Ok(elements);
        }
        content.parse::<Token![,]>()?;
    }
}
fn parse_owner_seeds(content: ParseStream) -> Result<Vec<SeedElement>> {
if content.peek(syn::token::Paren) {
let auth_content;
syn::parenthesized!(auth_content in content);
let mut auth_seeds = Vec::new();
while !auth_content.is_empty() {
auth_seeds.push(auth_content.parse::<SeedElement>()?);
if auth_content.peek(Token![,]) {
let _: Token![,] = auth_content.parse()?;
} else {
break;
}
}
Ok(auth_seeds)
} else {
Ok(vec![content.parse::<SeedElement>()?])
}
}
/// One element of a seed list.
#[derive(Clone, Debug)]
pub enum SeedElement {
/// A string literal seed, e.g. `"vault"`.
Literal(LitStr),
/// Any other seed expression, e.g. `ctx.authority` or `data.owner`.
Expression(Box<Expr>),
}
impl Parse for SeedElement {
fn parse(input: ParseStream) -> Result<Self> {
if input.peek(LitStr) {
Ok(SeedElement::Literal(input.parse()?))
} else {
Ok(SeedElement::Expression(input.parse()?))
}
}
}
/// A `field_name = Type` pair from the macro input.
/// NOTE(review): presumably describes one instruction-data field — confirm with callers.
pub struct InstructionDataSpec {
/// Identifier before `=`.
pub field_name: Ident,
/// Type after `=`.
pub field_type: syn::Type,
}
impl Parse for InstructionDataSpec {
    /// Parses `field_name = Type`.
    fn parse(input: ParseStream) -> Result<Self> {
        let field_name = input.parse::<Ident>()?;
        input.parse::<Token![=]>()?;
        let field_type = input.parse::<syn::Type>()?;
        Ok(Self {
            field_name,
            field_type,
        })
    }
}
/// Collects the `ctx` field identifiers referenced by the expression seeds in `seeds`.
///
/// Literal seeds are skipped. Each expression is scanned with
/// `FieldExtractor::ctx_fields(&[])`; results are de-duplicated by their string form,
/// keeping the first occurrence and preserving encounter order.
pub fn extract_ctx_seed_fields(
    seeds: &syn::punctuated::Punctuated<SeedElement, Token![,]>,
) -> Vec<Ident> {
    let mut already_seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut unique_fields = Vec::new();
    let seed_exprs = seeds.iter().filter_map(|seed| match seed {
        SeedElement::Expression(expr) => Some(expr),
        SeedElement::Literal(_) => None,
    });
    for expr in seed_exprs {
        for field in FieldExtractor::ctx_fields(&[]).extract(expr) {
            if already_seen.insert(field.to_string()) {
                unique_fields.push(field);
            }
        }
    }
    unique_fields
}
/// Collects the `data` field identifiers referenced by the expression seeds in `seeds`.
///
/// Literal seeds are skipped. Each expression is scanned with
/// `FieldExtractor::data_fields()`; results are de-duplicated by their string form,
/// keeping the first occurrence and preserving encounter order.
pub fn extract_data_seed_fields(
    seeds: &syn::punctuated::Punctuated<SeedElement, Token![,]>,
) -> Vec<Ident> {
    let mut seen_names = std::collections::HashSet::new();
    seeds
        .iter()
        .filter_map(|seed| match seed {
            SeedElement::Expression(expr) => Some(expr),
            SeedElement::Literal(_) => None,
        })
        .flat_map(|expr| FieldExtractor::data_fields().extract(expr))
        .filter(|field| seen_names.insert(field.to_string()))
        .collect()
}
/// Converts classified seeds into [`SeedElement`]s expressed in terms of the generated
/// code's `ctx` / `data` bindings.
///
/// Per seed kind:
/// - `Literal`: UTF-8 bytes become a string literal, any other bytes a `&[..]` byte-array
///   expression.
/// - `Constant`: a single-segment constant is qualified with the module `crate_ctx` reports
///   for it when that module is public, otherwise with `crate`; multi-segment paths are kept.
/// - `CtxRooted`: `ctx.<account>`.
/// - `DataRooted`: `data.<field>` (followed by the conversion method call, if any) when
///   `extract_data_field_info` recognises the expression; otherwise the expression unchanged.
/// - `FunctionCall`: the call rewritten by `rewrite_fn_call_for_scope`, with `.as_ref()`
///   appended when `has_as_ref` is set.
/// - `Passthrough`: the expression unchanged.
pub fn convert_classified_to_seed_elements(
    seeds: &[crate::light_pdas::seeds::ClassifiedSeed],
    module_path: &str,
    crate_ctx: &crate::light_pdas::parsing::CrateContext,
) -> Punctuated<SeedElement, Token![,]> {
    use crate::light_pdas::seeds::{extract_data_field_info, ClassifiedSeed};

    seeds
        .iter()
        .map(|seed| match seed {
            ClassifiedSeed::Literal(bytes) => match std::str::from_utf8(bytes) {
                Ok(text) => SeedElement::Literal(syn::LitStr::new(
                    text,
                    proc_macro2::Span::call_site(),
                )),
                Err(_) => {
                    let byte_tokens: Vec<_> = bytes.iter().map(|byte| quote!(#byte)).collect();
                    let array_expr: Expr = syn::parse_quote!(&[#(#byte_tokens),*]);
                    SeedElement::Expression(Box::new(array_expr))
                }
            },
            ClassifiedSeed::Constant { path, expr } => {
                let qualified: Expr = if path.segments.len() != 1 {
                    (**expr).clone()
                } else {
                    let const_name = path.segments[0].ident.to_string();
                    let owning_module = crate_ctx
                        .find_const_module_path(&const_name)
                        .filter(|candidate| crate_ctx.is_module_path_public(candidate))
                        .unwrap_or("crate");
                    let owning_path: syn::Path = syn::parse_str(owning_module)
                        .unwrap_or_else(|_| syn::parse_quote!(crate));
                    qualify_constant_in_expr(expr, &owning_path, path)
                };
                SeedElement::Expression(Box::new(qualified))
            }
            ClassifiedSeed::CtxRooted { account, .. } => {
                let ctx_expr: Expr = syn::parse_quote!(ctx.#account);
                SeedElement::Expression(Box::new(ctx_expr))
            }
            ClassifiedSeed::DataRooted { expr, .. } => match extract_data_field_info(expr) {
                Some((field_name, conversion)) => {
                    let data_expr: Expr = match conversion {
                        Some(method) => syn::parse_quote!(data.#field_name.#method()),
                        None => syn::parse_quote!(data.#field_name),
                    };
                    SeedElement::Expression(Box::new(data_expr))
                }
                None => SeedElement::Expression(expr.clone()),
            },
            ClassifiedSeed::FunctionCall {
                func_expr,
                args: fn_args,
                has_as_ref,
            } => {
                let call = rewrite_fn_call_for_scope(func_expr, fn_args, module_path, crate_ctx);
                let call_expr: Expr = match *has_as_ref {
                    true => syn::parse_quote!(#call.as_ref()),
                    false => call,
                };
                SeedElement::Expression(Box::new(call_expr))
            }
            ClassifiedSeed::Passthrough(expr) => SeedElement::Expression(expr.clone()),
        })
        .collect()
}
/// Rewrites `expr` so that the constant it is built on is referenced as
/// `mod_path::const_path`.
///
/// Wrapper/postfix expressions around the constant are preserved and the rewrite recurses
/// into their operand:
/// - method calls: `CONST.as_bytes()` -> `mod_path::CONST.as_bytes()`;
/// - references: `&CONST` -> `&mod_path::CONST`;
/// - parentheses, field access, indexing and casts: `(CONST)`, `CONST.0`, `CONST[..]`,
///   `CONST as T`.
///
/// Any other expression (including the bare path itself) is replaced by
/// `mod_path::const_path`.
fn qualify_constant_in_expr(expr: &Expr, mod_path: &syn::Path, const_path: &syn::Path) -> Expr {
    match expr {
        Expr::MethodCall(method_call) => {
            let mut method_call = method_call.clone();
            method_call.receiver = Box::new(qualify_constant_in_expr(
                &method_call.receiver,
                mod_path,
                const_path,
            ));
            Expr::MethodCall(method_call)
        }
        // Previously these wrappers were replaced wholesale, silently dropping the `&`,
        // the parentheses, the field/index projection or the cast.
        Expr::Reference(reference) => {
            let mut reference = reference.clone();
            reference.expr = Box::new(qualify_constant_in_expr(
                &reference.expr,
                mod_path,
                const_path,
            ));
            Expr::Reference(reference)
        }
        Expr::Paren(paren) => {
            let mut paren = paren.clone();
            paren.expr = Box::new(qualify_constant_in_expr(&paren.expr, mod_path, const_path));
            Expr::Paren(paren)
        }
        Expr::Field(field) => {
            let mut field = field.clone();
            field.base = Box::new(qualify_constant_in_expr(&field.base, mod_path, const_path));
            Expr::Field(field)
        }
        Expr::Index(index) => {
            let mut index = index.clone();
            index.expr = Box::new(qualify_constant_in_expr(&index.expr, mod_path, const_path));
            Expr::Index(index)
        }
        Expr::Cast(cast) => {
            let mut cast = cast.clone();
            cast.expr = Box::new(qualify_constant_in_expr(&cast.expr, mod_path, const_path));
            Expr::Cast(cast)
        }
        _ => syn::parse_quote!(#mod_path::#const_path),
    }
}
fn rewrite_fn_call_for_scope(
func_expr: &Expr,
fn_args: &[crate::light_pdas::seeds::ClassifiedFnArg],
module_path: &str,
crate_ctx: &crate::light_pdas::parsing::CrateContext,
) -> Expr {
use quote::quote;
use crate::light_pdas::seeds::FnArgKind;
if let Expr::Call(call) = func_expr {
let func_path: Expr = if let Expr::Path(path_expr) = &*call.func {
if path_expr.path.segments.len() == 1 {
let fn_name = path_expr.path.segments[0].ident.to_string();
let resolved = crate_ctx
.find_fn_module_path(&fn_name)
.filter(|p| crate_ctx.is_module_path_public(p))
.unwrap_or(module_path);
let mod_path: syn::Path =
syn::parse_str(resolved).unwrap_or_else(|_| syn::parse_quote!(crate));
let ident = &path_expr.path.segments[0].ident;
syn::parse_quote!(#mod_path::#ident)
} else {
Expr::Path(path_expr.clone())
}
} else {
(*call.func).clone()
};
let rewritten_args: Vec<Expr> = call
.args
.iter()
.map(|arg| {
let arg_str = quote!(#arg).to_string();
for classified in fn_args {
let field = &classified.field_name;
let field_str = field.to_string();
if arg_str.contains(&field_str) {
return match classified.kind {
FnArgKind::CtxAccount => syn::parse_quote!(&ctx.#field),
FnArgKind::DataField => syn::parse_quote!(&data.#field),
};
}
}
arg.clone()
})
.collect();
syn::parse_quote!(#func_path(#(#rewritten_args),*))
} else {
func_expr.clone()
}
}
/// Same as [`convert_classified_to_seed_elements`], but returns the elements as a `Vec`
/// (in the same order) instead of a comma-punctuated sequence.
pub fn convert_classified_to_seed_elements_vec(
    seeds: &[crate::light_pdas::seeds::ClassifiedSeed],
    module_path: &str,
    crate_ctx: &crate::light_pdas::parsing::CrateContext,
) -> Vec<SeedElement> {
    let punctuated = convert_classified_to_seed_elements(seeds, module_path, crate_ctx);
    let mut elements = Vec::with_capacity(punctuated.len());
    elements.extend(punctuated);
    elements
}
/// Outcome of `extract_context_and_params`.
pub enum ExtractResult {
/// A `Context<...>` argument and exactly one candidate params argument were found.
Success {
/// Last path-segment name of the accounts type inside `Context<...>`.
context_type: String,
/// Name of the single candidate params argument.
params_ident: Ident,
/// Name of the `Context<...>` argument.
ctx_ident: Ident,
},
/// A `Context<...>` argument was found, but more than one candidate params argument.
MultipleParams {
/// Last path-segment name of the accounts type inside `Context<...>`.
context_type: String,
/// Names of all candidate params arguments, in declaration order.
param_names: Vec<String>,
},
/// No `Context<...>` argument with a recognisable accounts type, or no candidate params
/// argument.
None,
}
/// Locates the `Context<...>` argument and the params argument of an instruction handler.
///
/// Only typed arguments bound to a plain identifier are considered. An argument whose type
/// path ends in `Context` provides `ctx_ident` and, from its generic arguments, the
/// accounts type name. Every other such argument whose name contains neither `signer` nor
/// `bump` is a params candidate.
///
/// Returns `Success` for exactly one candidate, `MultipleParams` for several, and `None`
/// when there is no candidate or the context/accounts type could not be determined.
pub fn extract_context_and_params(fn_item: &ItemFn) -> ExtractResult {
    // Scans `Context`'s generic arguments from the last one backwards and returns the final
    // path-segment name of the first path type encountered.
    fn accounts_type_name(arguments: &syn::PathArguments) -> Option<String> {
        match arguments {
            syn::PathArguments::AngleBracketed(generics) => {
                generics.args.iter().rev().find_map(|arg| match arg {
                    syn::GenericArgument::Type(syn::Type::Path(inner_path)) => inner_path
                        .path
                        .segments
                        .last()
                        .map(|segment| segment.ident.to_string()),
                    _ => None,
                })
            }
            _ => None,
        }
    }

    let mut context_type: Option<String> = None;
    let mut ctx_ident: Option<Ident> = None;
    let mut params_candidates: Vec<Ident> = Vec::new();

    let named_args = fn_item.sig.inputs.iter().filter_map(|input| match input {
        syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
            syn::Pat::Ident(pat_ident) => Some((&pat_ident.ident, &*pat_type.ty)),
            _ => None,
        },
        syn::FnArg::Receiver(_) => None,
    });

    for (arg_ident, arg_ty) in named_args {
        let context_segment = match arg_ty {
            syn::Type::Path(type_path) => type_path
                .path
                .segments
                .last()
                .filter(|segment| segment.ident == "Context"),
            _ => None,
        };
        if let Some(segment) = context_segment {
            ctx_ident = Some(arg_ident.clone());
            if let Some(name) = accounts_type_name(&segment.arguments) {
                context_type = Some(name);
            }
            continue;
        }
        let arg_name = arg_ident.to_string();
        let is_excluded = arg_name.contains("signer") || arg_name.contains("bump");
        if !is_excluded {
            params_candidates.push(arg_ident.clone());
        }
    }

    match (context_type, ctx_ident) {
        (Some(context_type), Some(ctx_ident)) => match params_candidates.len() {
            0 => ExtractResult::None,
            1 => ExtractResult::Success {
                context_type,
                params_ident: params_candidates.swap_remove(0),
                ctx_ident,
            },
            _ => ExtractResult::MultipleParams {
                context_type,
                param_names: params_candidates
                    .iter()
                    .map(|ident| ident.to_string())
                    .collect(),
            },
        },
        _ => ExtractResult::None,
    }
}
/// Returns `true` when `block` consists of exactly one expression statement (with or
/// without a trailing `;`) that is a function call or method call passing `ctx_name`
/// as an argument, as judged by [`call_has_ctx_arg`].
fn is_delegation_body(block: &syn::Block, ctx_name: &str) -> bool {
    match block.stmts.as_slice() {
        [syn::Stmt::Expr(syn::Expr::Call(call), _)] => call_has_ctx_arg(&call.args, ctx_name),
        [syn::Stmt::Expr(syn::Expr::MethodCall(call), _)] => {
            call_has_ctx_arg(&call.args, ctx_name)
        }
        _ => false,
    }
}
/// Returns `true` if any argument in `args` passes the context named `ctx_name`.
///
/// Recognised shapes: the bare path `ctx`, a reference `&ctx` / `&mut ctx`, and a method
/// call on it (`ctx.method(..)`). Other shapes, such as field accesses like `ctx.accounts`,
/// are not recognised.
pub(crate) fn call_has_ctx_arg(
    args: &syn::punctuated::Punctuated<syn::Expr, syn::token::Comma>,
    ctx_name: &str,
) -> bool {
    let names_ctx = |candidate: &syn::Expr| {
        matches!(candidate, syn::Expr::Path(path) if path.path.is_ident(ctx_name))
    };
    args.iter().any(|arg| match arg {
        syn::Expr::Reference(ref_expr) => names_ctx(&ref_expr.expr),
        syn::Expr::MethodCall(method_call) => names_ctx(&method_call.receiver),
        other => names_ctx(other),
    })
}
/// Wraps an instruction handler so that the `light_pre_init` / `light_finalize` hooks run
/// around the user's body.
///
/// Two shapes are generated:
/// - Delegation body (a single call or method-call statement that receives `ctx_name`;
///   see `is_delegation_body`): only `light_pre_init` runs before the original body. Its
///   result is discarded and the body's value is returned as-is.
/// - Any other body: `light_pre_init` runs first, then the original body is evaluated as an
///   `anchor_lang::Result<()>`. On success `light_finalize` runs with the pre-init result,
///   and the wrapper returns `Ok(())`.
///
/// In the second shape the body is evaluated inside an immediately-invoked closure. A
/// `return` in user code (e.g. an early `return Ok(());`) therefore only leaves the user
/// body and cannot skip `light_finalize`.
///
/// Errors from either hook are converted through `ProgramError` into
/// `anchor_lang::error::Error` and propagated with `?`.
///
/// NOTE(review): the non-delegation shape assumes the handler returns `Result<()>`.
pub fn wrap_function_with_light(
    fn_item: &ItemFn,
    params_ident: &Ident,
    ctx_name: &Ident,
) -> ItemFn {
    let fn_vis = &fn_item.vis;
    let fn_sig = &fn_item.sig;
    let fn_block = &fn_item.block;
    let fn_attrs = &fn_item.attrs;
    let ctx_name_str = ctx_name.to_string();
    let is_delegation = is_delegation_body(fn_block, &ctx_name_str);
    if is_delegation {
        syn::parse_quote! {
            #(#fn_attrs)*
            #fn_vis #fn_sig {
                use light_account::{LightPreInit, LightFinalize};
                let _ = #ctx_name.accounts.light_pre_init(#ctx_name.remaining_accounts, &#params_ident)
                    .map_err(|e| anchor_lang::error::Error::from(solana_program_error::ProgramError::from(e)))?;
                #fn_block
            }
        }
    } else {
        syn::parse_quote! {
            #(#fn_attrs)*
            #fn_vis #fn_sig {
                use light_account::{LightPreInit, LightFinalize};
                let __has_pre_init = #ctx_name.accounts.light_pre_init(#ctx_name.remaining_accounts, &#params_ident)
                    .map_err(|e| anchor_lang::error::Error::from(solana_program_error::ProgramError::from(e)))?;
                // The user body runs in its own closure, so an early `return` inside it exits
                // only the body and `light_finalize` below still runs on every success path.
                #[allow(clippy::redundant_closure_call)]
                let __user_result: anchor_lang::Result<()> =
                    (|| -> anchor_lang::Result<()> #fn_block)();
                __user_result?;
                #ctx_name.accounts.light_finalize(#ctx_name.remaining_accounts, &#params_ident, __has_pre_init)
                    .map_err(|e| anchor_lang::error::Error::from(solana_program_error::ProgramError::from(e)))?;
                Ok(())
            }
        }
    }
}