use std::collections::HashSet;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
visit::{self, Visit},
Expr, Ident, Member,
};
use super::instructions::{InstructionDataSpec, SeedElement};
use crate::light_pdas::{account::utils::is_pubkey_type, shared_utils::is_constant_identifier};
/// AST visitor that gathers the distinct field names an expression reads through the
/// context binding (`<ctx>.accounts.<field>` / `<ctx>.<field>`) and/or through
/// `data.<field>`, depending on which extraction flags are enabled.
///
/// Names listed in `excluded` are skipped, and every name is collected at most once,
/// in the order it is first encountered.
pub struct FieldExtractor<'ast, 'cfg> {
    // --- configuration ---
    /// Identifier of the context binding (for example `ctx`).
    ctx_name: &'cfg str,
    /// Field names that must never be collected.
    excluded: &'cfg [&'cfg str],
    /// Collect fields reached through the context binding.
    extract_ctx: bool,
    /// Collect fields reached through `data`.
    extract_data: bool,

    // --- accumulated state ---
    /// String forms of the fields collected so far, used for de-duplication.
    seen: HashSet<String>,
    /// Collected fields, in first-seen order.
    fields: Vec<&'ast Ident>,
}
impl<'ast, 'cfg> FieldExtractor<'ast, 'cfg> {
    /// Builds an extractor for context fields reached through a binding named `ctx`.
    ///
    /// Field names contained in `excluded` are never collected.
    pub fn ctx_fields(excluded: &'cfg [&'cfg str]) -> Self {
        Self::ctx_fields_with_name(excluded, "ctx")
    }

    /// Builds an extractor for context fields reached through the binding `ctx_name`,
    /// i.e. `<ctx_name>.accounts.<field>` or `<ctx_name>.<field>`.
    ///
    /// Field names contained in `excluded` are never collected.
    pub fn ctx_fields_with_name(excluded: &'cfg [&'cfg str], ctx_name: &'cfg str) -> Self {
        Self {
            ctx_name,
            excluded,
            extract_ctx: true,
            extract_data: false,
            seen: HashSet::new(),
            fields: Vec::new(),
        }
    }

    /// Builds an extractor that collects only `data.<field>` accesses; nothing is excluded.
    pub fn data_fields() -> Self {
        Self {
            ctx_name: "ctx",
            excluded: &[],
            extract_ctx: false,
            extract_data: true,
            seen: HashSet::new(),
            fields: Vec::new(),
        }
    }

    /// Walks `expr` and returns the distinct matching field identifiers, in the order
    /// they were first encountered.
    pub fn extract(mut self, expr: &'ast Expr) -> Vec<Ident> {
        self.visit_expr(expr);
        self.fields.iter().map(|&field| field.clone()).collect()
    }

    /// Records `field` unless its name is excluded or has already been recorded.
    fn try_add(&mut self, field: &'ast Ident) {
        let name = field.to_string();
        if self.excluded.iter().any(|excluded_name| *excluded_name == name) {
            return;
        }
        if self.seen.insert(name) {
            self.fields.push(field);
        }
    }

    /// Returns `true` when `base` is exactly `ctx.accounts`.
    pub fn is_ctx_accounts(base: &Expr) -> bool {
        Self::is_ctx_accounts_with_name(base, "ctx")
    }

    /// Returns `true` when `base` is exactly `<ctx_name>.accounts`.
    pub fn is_ctx_accounts_with_name(base: &Expr, ctx_name: &str) -> bool {
        match base {
            Expr::Field(nested) => {
                matches!(&nested.member, Member::Named(member) if member == "accounts")
                    && Self::is_path_ident(&nested.base, ctx_name)
            }
            _ => false,
        }
    }

    /// If `base` is `<ident>.accounts` where `<ident>` is a single-identifier path,
    /// returns that identifier as a string; otherwise `None`.
    pub fn is_any_ctx_accounts(base: &Expr) -> Option<String> {
        let nested = match base {
            Expr::Field(nested) => nested,
            _ => return None,
        };
        match (&nested.member, &*nested.base) {
            (Member::Named(member), Expr::Path(path)) if member == "accounts" => {
                path.path.get_ident().map(ToString::to_string)
            }
            _ => None,
        }
    }

    /// Returns `true` when `expr` is a path consisting solely of the identifier `ident`.
    pub fn is_path_ident(expr: &Expr, ident: &str) -> bool {
        if let Expr::Path(p) = expr {
            p.path.is_ident(ident)
        } else {
            false
        }
    }
}
impl<'ast, 'cfg> Visit<'ast> for FieldExtractor<'ast, 'cfg> {
    /// Records a named field access when its base is one the extractor is configured for.
    /// With `extract_ctx` set, that means `<ctx>.accounts.<field>` or `<ctx>.<field>`.
    /// With `extract_data` set, it means `data.<field>`.
    ///
    /// After recording a field, traversal does not descend into that access. Every other
    /// field expression is visited normally.
    fn visit_expr_field(&mut self, node: &'ast syn::ExprField) {
        if let Member::Named(field_name) = &node.member {
            let base = &*node.base;
            let through_ctx = self.extract_ctx
                && (Self::is_ctx_accounts_with_name(base, self.ctx_name)
                    || Self::is_path_ident(base, self.ctx_name));
            let through_data = self.extract_data && Self::is_path_ident(base, "data");
            if through_ctx || through_data {
                self.try_add(field_name);
                return;
            }
        }
        visit::visit_expr_field(self, node);
    }
}
/// Classification of a single seed, used to decide what client-side code to emit for it.
#[derive(Debug, Clone)]
pub enum ClientSeedInfo {
    /// A seed literal, stored as its string value.
    Literal(String),
    /// A byte-string literal (`b"..."`), stored as its bytes.
    ByteLiteral(Vec<u8>),
    /// A path whose last segment satisfies `is_constant_identifier`.
    Constant {
        /// The full path as written in the seed.
        path: syn::Path,
        /// `true` when the last path segment is `LIGHT_CPI_SIGNER`.
        is_cpi_signer: bool,
    },
    /// A field accessed via `<ident>.accounts.<field>`, or on any other path base whose
    /// first segment is not `data`.
    CtxField {
        /// The accessed field.
        field: Ident,
        /// Method called on the field, if the seed was `<...>.<field>.<method>()`.
        method: Option<Ident>,
    },
    /// A field accessed via `data.<field>`.
    DataField {
        /// The accessed field.
        field: Ident,
        /// Method called on the field, if the seed was `data.<field>.<method>()`.
        method: Option<Ident>,
    },
    /// A function-call seed, kept whole.
    FunctionCall(Box<syn::ExprCall>),
    /// A bare identifier.
    Identifier(Ident),
    /// Any other expression, kept verbatim.
    RawExpr(Box<syn::Expr>),
}
/// Classifies one parsed seed element for client code generation.
///
/// Literal seeds map directly to [`ClientSeedInfo::Literal`]; expression seeds are
/// analysed by `classify_seed_expr`.
pub fn classify_seed(seed: &SeedElement) -> syn::Result<ClientSeedInfo> {
    let expr = match seed {
        SeedElement::Literal(lit) => return Ok(ClientSeedInfo::Literal(lit.value())),
        SeedElement::Expression(expr) => expr,
    };
    classify_seed_expr(expr)
}
/// Classifies a seed expression by its syntactic shape.
///
/// Any number of leading `&` references is stripped first. The remaining expression is
/// then dispatched by kind: field, method call, literal, path or call. Other kinds are
/// kept verbatim as [`ClientSeedInfo::RawExpr`].
fn classify_seed_expr(expr: &syn::Expr) -> syn::Result<ClientSeedInfo> {
    let mut inner = expr;
    while let syn::Expr::Reference(reference) = inner {
        inner = &reference.expr;
    }
    match inner {
        syn::Expr::Lit(lit_expr) => classify_lit_expr(lit_expr),
        syn::Expr::Path(path_expr) => classify_path_expr(path_expr),
        syn::Expr::Call(call_expr) => classify_call_expr(call_expr),
        syn::Expr::Field(field_expr) => classify_field_expr(field_expr),
        syn::Expr::MethodCall(method_call) => classify_method_call(method_call),
        other => Ok(ClientSeedInfo::RawExpr(Box::new(other.clone()))),
    }
}
/// Classifies a named field access used as a seed.
///
/// - `<ident>.accounts.<field>` becomes a context field.
/// - `data.<field>` (path base whose first segment is `data`) becomes a data field.
/// - A field on any other path base becomes a context field.
/// - Tuple-index accesses and non-path bases are kept verbatim.
fn classify_field_expr(field_expr: &syn::ExprField) -> syn::Result<ClientSeedInfo> {
    let verbatim = || ClientSeedInfo::RawExpr(Box::new(syn::Expr::Field(field_expr.clone())));

    let field = match &field_expr.member {
        Member::Named(name) => name.clone(),
        Member::Unnamed(_) => return Ok(verbatim()),
    };

    if FieldExtractor::is_any_ctx_accounts(&field_expr.base).is_some() {
        return Ok(ClientSeedInfo::CtxField { field, method: None });
    }

    let first_segment = match &*field_expr.base {
        syn::Expr::Path(path) => path.path.segments.first(),
        _ => None,
    };
    Ok(match first_segment {
        Some(segment) if segment.ident == "data" => {
            ClientSeedInfo::DataField { field, method: None }
        }
        Some(_) => ClientSeedInfo::CtxField { field, method: None },
        None => verbatim(),
    })
}
fn classify_method_call(method_call: &syn::ExprMethodCall) -> syn::Result<ClientSeedInfo> {
if let syn::Expr::Field(field_expr) = &*method_call.receiver {
if let Member::Named(field_name) = &field_expr.member {
if FieldExtractor::is_any_ctx_accounts(&field_expr.base).is_some() {
return Ok(ClientSeedInfo::CtxField {
field: field_name.clone(),
method: Some(method_call.method.clone()),
});
}
if let syn::Expr::Path(path) = &*field_expr.base {
if let Some(segment) = path.path.segments.first() {
if segment.ident == "data" {
return Ok(ClientSeedInfo::DataField {
field: field_name.clone(),
method: Some(method_call.method.clone()),
});
}
return Ok(ClientSeedInfo::CtxField {
field: field_name.clone(),
method: Some(method_call.method.clone()),
});
}
}
}
}
if let syn::Expr::Call(call_expr) = &*method_call.receiver {
return classify_call_expr(call_expr);
}
if let syn::Expr::Path(path_expr) = &*method_call.receiver {
if let Some(ident) = path_expr.path.get_ident() {
return Ok(ClientSeedInfo::Identifier(ident.clone()));
}
}
Ok(ClientSeedInfo::RawExpr(Box::new(syn::Expr::MethodCall(
method_call.clone(),
))))
}
/// Classifies a literal seed: byte strings become [`ClientSeedInfo::ByteLiteral`],
/// every other literal is kept verbatim.
fn classify_lit_expr(lit_expr: &syn::ExprLit) -> syn::Result<ClientSeedInfo> {
    let info = match &lit_expr.lit {
        syn::Lit::ByteStr(byte_str) => ClientSeedInfo::ByteLiteral(byte_str.value()),
        _ => ClientSeedInfo::RawExpr(Box::new(syn::Expr::Lit(lit_expr.clone()))),
    };
    Ok(info)
}
/// Classifies a path seed.
///
/// Qualified-self paths are kept verbatim. Otherwise the rules are, in order:
/// - A path whose last segment passes `is_constant_identifier` becomes a constant.
///   It is flagged as the CPI signer when that segment is `LIGHT_CPI_SIGNER`.
/// - A single identifier becomes an identifier seed.
/// - Anything else is kept verbatim.
fn classify_path_expr(path_expr: &syn::ExprPath) -> syn::Result<ClientSeedInfo> {
    let path = &path_expr.path;
    if path_expr.qself.is_none() {
        let constant_name = path
            .segments
            .last()
            .map(|segment| segment.ident.to_string())
            .filter(|name| is_constant_identifier(name));
        if let Some(name) = constant_name {
            return Ok(ClientSeedInfo::Constant {
                path: path.clone(),
                is_cpi_signer: name == "LIGHT_CPI_SIGNER",
            });
        }
        if let Some(ident) = path.get_ident() {
            return Ok(ClientSeedInfo::Identifier(ident.clone()));
        }
    }
    Ok(ClientSeedInfo::RawExpr(Box::new(syn::Expr::Path(
        path_expr.clone(),
    ))))
}
/// Classifies a function-call seed; the call is kept whole and its arguments are
/// rewritten later during code generation.
fn classify_call_expr(call_expr: &syn::ExprCall) -> syn::Result<ClientSeedInfo> {
    let owned_call = Box::new(call_expr.clone());
    Ok(ClientSeedInfo::FunctionCall(owned_call))
}
/// Rewrites one argument of a seed function call so it can be evaluated in generated
/// client code.
///
/// - `&expr` is rewritten recursively and keeps its leading `&`.
/// - `<ident>.accounts.<field>` and `ctx.<field>` register a `<field>: &Pubkey`
///   parameter and are replaced by the bare `<field>`.
/// - `data.<field>` registers a parameter typed from the matching `InstructionDataSpec`.
///   The parameter is taken by reference when that type is a pubkey type. When no spec
///   names the field, a `&Pubkey` parameter is registered instead. The access is
///   replaced by `<field>`.
/// - A bare identifier other than `ctx`, `data` or a constant registers a `&Pubkey`
///   parameter and is emitted unchanged.
/// - Method calls and function calls are rebuilt with their receiver and arguments
///   rewritten recursively. Explicit method generics (`::<T>`) are preserved.
/// - Anything else is emitted unchanged.
///
/// Parameters are appended to `parameters` at most once per name, as tracked by
/// `seen_params`. `is_pinocchio` selects which `Pubkey` path the parameter types use.
fn map_call_arg(
    arg: &syn::Expr,
    instruction_data: &[InstructionDataSpec],
    seen_params: &mut HashSet<String>,
    parameters: &mut Vec<TokenStream>,
    is_pinocchio: bool,
) -> syn::Result<TokenStream> {
    let pubkey_param = if is_pinocchio {
        quote! { light_account_pinocchio::solana_pubkey::Pubkey }
    } else {
        quote! { solana_pubkey::Pubkey }
    };
    match arg {
        syn::Expr::Reference(ref_expr) => {
            let inner = map_call_arg(
                &ref_expr.expr,
                instruction_data,
                seen_params,
                parameters,
                is_pinocchio,
            )?;
            Ok(quote! { &#inner })
        }
        syn::Expr::Field(field_expr) => {
            if let syn::Member::Named(field_name) = &field_expr.member {
                // Accept `<any ident>.accounts.<field>`, the same rule `classify_field_expr`
                // uses. Matching only `ctx` would emit e.g. `context.accounts.foo` verbatim
                // into client code, where no such binding exists.
                if FieldExtractor::is_any_ctx_accounts(&field_expr.base).is_some() {
                    if seen_params.insert(field_name.to_string()) {
                        parameters.push(quote! { #field_name: &#pubkey_param });
                    }
                    return Ok(quote! { #field_name });
                }
                if let syn::Expr::Path(path) = &*field_expr.base {
                    if let Some(segment) = path.path.segments.first() {
                        if segment.ident == "data" {
                            if let Some(data_spec) = instruction_data
                                .iter()
                                .find(|d| d.field_name == *field_name)
                            {
                                if seen_params.insert(field_name.to_string()) {
                                    let param_type = &data_spec.field_type;
                                    let param_with_ref = if is_pubkey_type(param_type) {
                                        quote! { #field_name: &#param_type }
                                    } else {
                                        quote! { #field_name: #param_type }
                                    };
                                    parameters.push(param_with_ref);
                                }
                                return Ok(quote! { #field_name });
                            }
                            // No declared type for this data field: fall back to a pubkey.
                            if seen_params.insert(field_name.to_string()) {
                                parameters.push(quote! { #field_name: &#pubkey_param });
                            }
                            return Ok(quote! { #field_name });
                        } else if segment.ident == "ctx" {
                            if seen_params.insert(field_name.to_string()) {
                                parameters.push(quote! { #field_name: &#pubkey_param });
                            }
                            return Ok(quote! { #field_name });
                        }
                    }
                }
            }
            Ok(quote! { #field_expr })
        }
        syn::Expr::MethodCall(method_call) => {
            let receiver = map_call_arg(
                &method_call.receiver,
                instruction_data,
                seen_params,
                parameters,
                is_pinocchio,
            )?;
            let method = &method_call.method;
            // Keep explicit generic arguments (`x.parse::<u64>()`). Dropping them changes
            // type inference and can make the generated code fail to compile.
            let turbofish = &method_call.turbofish;
            let args: Vec<TokenStream> = method_call
                .args
                .iter()
                .map(|a| map_call_arg(a, instruction_data, seen_params, parameters, is_pinocchio))
                .collect::<syn::Result<_>>()?;
            Ok(quote! { (#receiver).#method #turbofish(#(#args),*) })
        }
        syn::Expr::Call(nested_call) => {
            let func = &nested_call.func;
            let args: Vec<TokenStream> = nested_call
                .args
                .iter()
                .map(|a| map_call_arg(a, instruction_data, seen_params, parameters, is_pinocchio))
                .collect::<syn::Result<_>>()?;
            Ok(quote! { (#func)(#(#args),*) })
        }
        syn::Expr::Path(path_expr) => {
            if let Some(ident) = path_expr.path.get_ident() {
                let name = ident.to_string();
                // `ctx`/`data` are bindings of the on-chain handler, and constants are
                // referenced directly, so neither becomes a client parameter.
                if name != "ctx"
                    && name != "data"
                    && !is_constant_identifier(&name)
                    && seen_params.insert(name)
                {
                    parameters.push(quote! { #ident: &#pubkey_param });
                }
            }
            Ok(quote! { #path_expr })
        }
        _ => Ok(quote! { #arg }),
    }
}
/// Emits the client-side code for one classified seed.
///
/// Appends one byte-slice seed expression to `expressions`. Any parameters the
/// generated client function needs for this seed are appended to `parameters`; names
/// already present in `seen_params` are not added again. `is_pinocchio` selects which
/// `Pubkey` path is used in parameter types.
///
/// # Errors
///
/// - Returns an error, spanned at the field, when a `data.<field>` seed has no
///   matching entry in `instruction_data`.
/// - Propagates any error from rewriting a function-call seed's arguments.
pub fn generate_client_seed_code(
    info: &ClientSeedInfo,
    instruction_data: &[InstructionDataSpec],
    seen_params: &mut HashSet<String>,
    parameters: &mut Vec<TokenStream>,
    expressions: &mut Vec<TokenStream>,
    is_pinocchio: bool,
) -> syn::Result<()> {
    let pubkey_param = if is_pinocchio {
        quote! { light_account_pinocchio::solana_pubkey::Pubkey }
    } else {
        quote! { solana_pubkey::Pubkey }
    };
    match info {
        ClientSeedInfo::Literal(s) => {
            expressions.push(quote! { #s.as_bytes() });
        }
        ClientSeedInfo::ByteLiteral(bytes) => {
            // Pin the type to `&[u8]`, as the `Constant`/`RawExpr` arms do. A bare
            // `&[..]` is `&[u8; N]`, and for `b""` it is an untyped `&[]` that may not
            // infer. The literal array is const-promoted, so it may escape the block.
            expressions.push(quote! { { let __seed: &[u8] = &[#(#bytes),*]; __seed } });
        }
        ClientSeedInfo::Constant {
            path,
            is_cpi_signer,
        } => {
            let expr = if *is_cpi_signer {
                quote! { #path.cpi_signer.as_ref() }
            } else {
                quote! { { let __seed: &[u8] = #path.as_ref(); __seed } }
            };
            expressions.push(expr);
        }
        ClientSeedInfo::CtxField { field, method } => {
            if seen_params.insert(field.to_string()) {
                parameters.push(quote! { #field: &#pubkey_param });
            }
            let expr = match method {
                Some(m) => quote! { #field.#m().as_ref() },
                None => quote! { #field.as_ref() },
            };
            expressions.push(expr);
        }
        ClientSeedInfo::DataField { field, method } => {
            let data_spec = instruction_data
                .iter()
                .find(|d| d.field_name == *field)
                .ok_or_else(|| {
                    syn::Error::new(
                        field.span(),
                        format!("data.{} used in seeds but no type specified", field),
                    )
                })?;
            if seen_params.insert(field.to_string()) {
                let param_type = &data_spec.field_type;
                let param_with_ref = if is_pubkey_type(param_type) {
                    quote! { #field: &#param_type }
                } else {
                    quote! { #field: #param_type }
                };
                parameters.push(param_with_ref);
            }
            let expr = match method {
                Some(m) => quote! { #field.#m().as_ref() },
                None => quote! { #field.as_ref() },
            };
            expressions.push(expr);
        }
        ClientSeedInfo::Identifier(ident) => {
            if seen_params.insert(ident.to_string()) {
                parameters.push(quote! { #ident: &#pubkey_param });
            }
            expressions.push(quote! { #ident.as_ref() });
        }
        ClientSeedInfo::FunctionCall(call_expr) => {
            let mapped_args = call_expr
                .args
                .iter()
                .map(|arg| map_call_arg(arg, instruction_data, seen_params, parameters, is_pinocchio))
                .collect::<syn::Result<Vec<TokenStream>>>()?;
            let func = &call_expr.func;
            expressions.push(quote! { (#func)(#(#mapped_args),*).as_ref() });
        }
        ClientSeedInfo::RawExpr(expr) => {
            expressions.push(quote! { { let __seed: &[u8] = (#expr).as_ref(); __seed } });
        }
    }
    Ok(())
}