use syn::{Expr, Ident};
use super::{
instruction_args::InstructionArgSet,
types::{ClassifiedFnArg, ClassifiedSeed, FnArgKind},
};
use crate::light_pdas::shared_utils::is_constant_identifier;
pub fn classify_seed_expr(
expr: &Expr,
instruction_args: &InstructionArgSet,
) -> syn::Result<ClassifiedSeed> {
if let Some(bytes) = extract_byte_literal(expr) {
return Ok(ClassifiedSeed::Literal(bytes));
}
if let Some(path) = extract_constant_path(expr) {
return Ok(ClassifiedSeed::Constant {
path,
expr: Box::new(expr.clone()),
});
}
if let Some(root) = get_instruction_arg_root(expr, instruction_args) {
if let Some(terminal_field) = get_terminal_field_name(expr) {
let terminal_str = terminal_field.to_string();
if terminal_field == root && is_bare_identifier(expr) {
return Err(syn::Error::new_spanned(
expr,
format!(
"Ambiguous seed: '{}' matches an instruction argument but could also be \
a context account. Use explicit field access (e.g., `params.{}`) for \
instruction data, or use `{}.key().as_ref()` for a context account.",
root, root, root
),
));
}
if terminal_field != root && instruction_args.contains(&terminal_str) {
return Err(syn::Error::new_spanned(
expr,
format!(
"Ambiguous seed: '{}' is both a field on '{}' and a separate instruction \
argument. Use the bare instruction argument '{}' directly, or rename to \
avoid collision.",
terminal_field, root, terminal_field
),
));
}
}
return Ok(ClassifiedSeed::DataRooted {
root,
expr: Box::new(expr.clone()),
});
}
if let Some(account) = get_ctx_account_root(expr) {
return Ok(ClassifiedSeed::CtxRooted { account });
}
if let Some(fc) = classify_function_call(expr, instruction_args) {
return Ok(fc);
}
Ok(ClassifiedSeed::Passthrough(Box::new(expr.clone())))
}
fn classify_function_call(
expr: &Expr,
instruction_args: &InstructionArgSet,
) -> Option<ClassifiedSeed> {
let (call_expr, has_as_ref) = strip_trailing_as_ref(expr);
let call = match call_expr {
Expr::Call(c) => c,
_ => return None,
};
let mut classified_args = Vec::new();
let mut has_dynamic = false;
for arg in &call.args {
let inner = unwrap_references(arg);
if let Some(root) = get_instruction_arg_root(inner, instruction_args) {
let field_name = extract_terminal_field_name(inner).unwrap_or(root);
classified_args.push(ClassifiedFnArg {
field_name,
kind: FnArgKind::DataField,
});
has_dynamic = true;
continue;
}
if let Some(account) = get_ctx_account_root(inner) {
classified_args.push(ClassifiedFnArg {
field_name: account,
kind: FnArgKind::CtxAccount,
});
has_dynamic = true;
continue;
}
}
if !has_dynamic {
return None;
}
Some(ClassifiedSeed::FunctionCall {
func_expr: Box::new(Expr::Call(call.clone())),
args: classified_args,
has_as_ref,
})
}
/// Splits off one trailing zero-argument `.as_ref()` or `.as_bytes()` method call.
///
/// Returns the call's receiver together with `true` when `expr` is such a call; otherwise
/// returns `expr` unchanged together with `false`. Only the outermost call is inspected, so
/// `f(x).as_ref().as_ref()` yields `f(x).as_ref()`.
fn strip_trailing_as_ref(expr: &Expr) -> (&Expr, bool) {
    match expr {
        Expr::MethodCall(call)
            if call.args.is_empty() && (call.method == "as_ref" || call.method == "as_bytes") =>
        {
            (&*call.receiver, true)
        }
        _ => (expr, false),
    }
}
/// Strips any number of leading borrow operators (`&x`, `&&x`, `&mut x`) from `expr`.
fn unwrap_references(expr: &Expr) -> &Expr {
    let mut current = expr;
    while let Expr::Reference(reference) = current {
        current = &*reference.expr;
    }
    current
}
/// Returns the name recorded for a data-rooted function-call argument.
///
/// Walks through method-call receivers and `&` borrows to the innermost expression:
/// - a field access yields its named member (`params.key_a.as_ref()` -> `key_a`);
/// - a single-identifier path yields that identifier (`params` -> `params`).
///
/// Anything else yields `None`, including tuple-index fields (`params.0`), parenthesised and
/// indexed expressions, and multi-segment paths.
fn extract_terminal_field_name(expr: &Expr) -> Option<Ident> {
    match expr {
        Expr::MethodCall(call) => extract_terminal_field_name(&call.receiver),
        Expr::Reference(reference) => extract_terminal_field_name(&reference.expr),
        Expr::Path(expr_path) => expr_path.path.get_ident().cloned(),
        Expr::Field(field) => match &field.member {
            syn::Member::Named(name) => Some(name.clone()),
            syn::Member::Unnamed(_) => None,
        },
        _ => None,
    }
}
/// Extracts the raw bytes of a literal seed, looking through any leading `&` borrows.
///
/// Recognised forms:
/// - byte-string literals: `b"seed"`;
/// - string literals, converted to their UTF-8 bytes: `"seed"`;
/// - a full-range slice of a byte-string literal: `b"seed"[..]`.
///
/// Returns `None` for anything else (other literal kinds, partial slices, method calls such as
/// `b"seed".as_ref()`, slices of string literals, ...).
fn extract_byte_literal(expr: &Expr) -> Option<Vec<u8>> {
    match expr {
        Expr::Reference(reference) => extract_byte_literal(&reference.expr),
        Expr::Lit(expr_lit) => match &expr_lit.lit {
            syn::Lit::ByteStr(byte_str) => Some(byte_str.value()),
            syn::Lit::Str(string) => Some(string.value().into_bytes()),
            _ => None,
        },
        Expr::Index(index) => {
            // Only an unbounded `[..]` slice leaves the literal's bytes unchanged.
            let is_full_slice = matches!(
                &*index.index,
                Expr::Range(range) if range.start.is_none() && range.end.is_none()
            );
            match &*index.expr {
                Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::ByteStr(byte_str),
                    ..
                }) if is_full_slice => Some(byte_str.value()),
                _ => None,
            }
        }
        _ => None,
    }
}
/// Returns the path of a constant seed such as `SEED`, `crate::state::SEED` or
/// `Holder::SEED`, looking through `&` borrows and method-call receivers
/// (`SEED.as_bytes()` -> `SEED`).
///
/// A path qualifies when its last segment's identifier is accepted by
/// `is_constant_identifier`. Qualified-self paths (`<T as Trait>::CONST`) are never treated
/// as constants and yield `None`.
fn extract_constant_path(expr: &Expr) -> Option<syn::Path> {
    match expr {
        Expr::Path(expr_path) if expr_path.qself.is_none() => expr_path
            .path
            .segments
            .last()
            .filter(|segment| is_constant_identifier(&segment.ident.to_string()))
            .map(|_| expr_path.path.clone()),
        Expr::Reference(reference) => extract_constant_path(&reference.expr),
        Expr::MethodCall(call) => extract_constant_path(&call.receiver),
        // Includes `Expr::Path` carrying a qualified self type.
        _ => None,
    }
}
/// Reports whether `expr` is a single identifier, optionally wrapped in any chain of method
/// calls, `&` borrows and parentheses: `owner`, `owner.as_ref()`, `&(amount).to_le_bytes()`.
///
/// Field accesses (`params.owner`), indexing and multi-segment paths return `false`.
fn is_bare_identifier(expr: &Expr) -> bool {
    let mut node = expr;
    loop {
        node = match node {
            Expr::MethodCall(call) => &*call.receiver,
            Expr::Reference(reference) => &*reference.expr,
            Expr::Paren(paren) => &*paren.expr,
            Expr::Path(expr_path) => return expr_path.path.get_ident().is_some(),
            _ => return false,
        };
    }
}
/// Returns the last named component of `expr`, used for seed ambiguity checks.
///
/// Method calls, `&` borrows, parentheses and indexing are looked through. The walk stops at:
/// - a field access, yielding its named member (`params.owner.as_ref()` -> `owner`), or
///   `None` for a tuple-index field;
/// - a single-identifier path, yielding that identifier (`owner.as_ref()` -> `owner`).
///
/// Any other expression yields `None`.
fn get_terminal_field_name(expr: &Expr) -> Option<Ident> {
    let mut current = expr;
    loop {
        current = match current {
            Expr::MethodCall(call) => &*call.receiver,
            Expr::Reference(reference) => &*reference.expr,
            Expr::Paren(paren) => &*paren.expr,
            Expr::Index(index) => &*index.expr,
            Expr::Path(expr_path) => return expr_path.path.get_ident().cloned(),
            Expr::Field(field) => {
                return match &field.member {
                    syn::Member::Named(name) => Some(name.clone()),
                    syn::Member::Unnamed(_) => None,
                };
            }
            _ => return None,
        };
    }
}
/// Returns the instruction argument that `expr` is rooted at, if any.
///
/// Descends through field accesses, method-call receivers, index bases, `&` borrows and
/// parentheses until it reaches a path. The path's identifier is returned only when the path
/// is a single identifier that `instruction_args` contains and that `is_constant_identifier`
/// does not accept.
///
/// Parentheses are transparent, so `(params.owner).as_ref()` is rooted at `params` just like
/// `params.owner.as_ref()`. `is_bare_identifier` and `get_terminal_field_name` already treat
/// `Expr::Paren` this way. Their results only matter once this function has found a root, so
/// without this arm a parenthesised seed bypassed both the data-rooted classification and the
/// bare-identifier ambiguity check.
fn get_instruction_arg_root(expr: &Expr, instruction_args: &InstructionArgSet) -> Option<Ident> {
    match expr {
        Expr::Path(path) => {
            let ident = path.path.get_ident()?;
            let name = ident.to_string();
            let is_instruction_arg =
                !is_constant_identifier(&name) && instruction_args.contains(&name);
            is_instruction_arg.then(|| ident.clone())
        }
        Expr::Field(field) => get_instruction_arg_root(&field.base, instruction_args),
        Expr::MethodCall(mc) => get_instruction_arg_root(&mc.receiver, instruction_args),
        Expr::Index(idx) => get_instruction_arg_root(&idx.expr, instruction_args),
        Expr::Reference(r) => get_instruction_arg_root(&r.expr, instruction_args),
        Expr::Paren(p) => get_instruction_arg_root(&p.expr, instruction_args),
        _ => None,
    }
}
/// Resolves the context account a seed expression refers to.
///
/// - A single-identifier path that `is_constant_identifier` does not accept is itself the
///   account (`authority`). Instruction arguments are not consulted here.
/// - A named field access whose base is a path or another field access yields the field name
///   (`ctx.accounts.authority` -> `authority`).
/// - A named field access on any other base resolves through that base.
/// - Tuple-index fields yield `None`.
/// - Method-call receivers and `&` borrows are looked through
///   (`ctx.accounts.authority.key().as_ref()` -> `authority`).
///
/// Anything else yields `None`.
fn get_ctx_account_root(expr: &Expr) -> Option<Ident> {
    match expr {
        Expr::Path(expr_path) => expr_path
            .path
            .get_ident()
            .filter(|ident| !is_constant_identifier(&ident.to_string()))
            .cloned(),
        Expr::Field(field) => match (&field.member, &*field.base) {
            (syn::Member::Unnamed(_), _) => None,
            (syn::Member::Named(name), Expr::Path(_) | Expr::Field(_)) => Some(name.clone()),
            (syn::Member::Named(_), base) => get_ctx_account_root(base),
        },
        Expr::MethodCall(call) => get_ctx_account_root(&call.receiver),
        Expr::Reference(reference) => get_ctx_account_root(&reference.expr),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for seed classification. Seed expressions are built with `parse_quote!`,
    //! so every seed written here must be valid Rust token syntax.
    use syn::parse_quote;
    use super::*;

    /// Builds an [`InstructionArgSet`] from plain argument names.
    fn make_instruction_args(names: &[&str]) -> InstructionArgSet {
        InstructionArgSet::from_names(names.iter().map(|s| s.to_string()))
    }

    #[test]
    fn test_bare_pubkey_instruction_arg() {
        let args = make_instruction_args(&["owner", "amount"]);
        let expr: syn::Expr = parse_quote!(owner);
        let result = classify_seed_expr(&expr, &args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Ambiguous seed"));
    }

    #[test]
    fn test_bare_primitive_with_to_le_bytes() {
        let args = make_instruction_args(&["amount"]);
        let expr: syn::Expr = parse_quote!(amount.to_le_bytes().as_ref());
        let result = classify_seed_expr(&expr, &args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Ambiguous seed"));
    }

    #[test]
    fn test_custom_struct_param_name() {
        let args = make_instruction_args(&["input"]);
        let expr: syn::Expr = parse_quote!(input.owner.as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::DataRooted { root, .. } if root == "input"));
    }

    #[test]
    fn test_nested_field_access() {
        let args = make_instruction_args(&["data"]);
        let expr: syn::Expr = parse_quote!(data.inner.key.as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::DataRooted { root, .. } if root == "data"));
    }

    #[test]
    fn test_context_account_not_confused_with_arg() {
        let args = make_instruction_args(&["owner"]);
        let expr: syn::Expr = parse_quote!(authority.key().as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(
            result,
            ClassifiedSeed::CtxRooted { account, .. } if account == "authority"
        ));
    }

    #[test]
    fn test_empty_instruction_args() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(owner);
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::CtxRooted { account, .. } if account == "owner"));
    }

    #[test]
    fn test_literal_seed() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(b"seed");
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::Literal(bytes) if bytes == b"seed"));
    }

    #[test]
    fn test_constant_seed() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(SEED_PREFIX);
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::Constant { .. }));
    }

    #[test]
    fn test_standard_params_field_access() {
        let args = make_instruction_args(&["params"]);
        let expr: syn::Expr = parse_quote!(params.owner.as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::DataRooted { root, .. } if root == "params"));
    }

    #[test]
    fn test_args_naming_format() {
        let args = make_instruction_args(&["args"]);
        let expr: syn::Expr = parse_quote!(args.key.as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::DataRooted { root, .. } if root == "args"));
    }

    #[test]
    fn test_data_naming_format() {
        let args = make_instruction_args(&["data"]);
        let expr: syn::Expr = parse_quote!(data.value.to_le_bytes().as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(
            result,
            ClassifiedSeed::DataRooted { root, .. } if root == "data"
        ));
    }

    #[test]
    fn test_format2_multiple_params() {
        let args = make_instruction_args(&["owner", "amount"]);
        let expr1: syn::Expr = parse_quote!(owner.as_ref());
        let result1 = classify_seed_expr(&expr1, &args);
        assert!(result1.is_err());
        assert!(result1.unwrap_err().to_string().contains("Ambiguous seed"));
        let expr2: syn::Expr = parse_quote!(amount.to_le_bytes().as_ref());
        let result2 = classify_seed_expr(&expr2, &args);
        assert!(result2.is_err());
        assert!(result2.unwrap_err().to_string().contains("Ambiguous seed"));
    }

    #[test]
    fn test_passthrough_for_complex_expressions() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(<Type as Trait>::CONST);
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::Passthrough(_)));
    }

    #[test]
    fn test_passthrough_for_generic_function_call() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(identity_seed::<12>(b"seed"));
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::Passthrough(_)));
    }

    #[test]
    fn test_function_call_with_data_args() {
        let args = make_instruction_args(&["params"]);
        let expr: syn::Expr = parse_quote!(crate::max_key(&params.key_a, &params.key_b).as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        match result {
            ClassifiedSeed::FunctionCall {
                args: fn_args,
                has_as_ref,
                ..
            } => {
                assert!(has_as_ref, "Should detect trailing .as_ref()");
                assert_eq!(fn_args.len(), 2, "Should have 2 classified args");
                assert_eq!(fn_args[0].field_name.to_string(), "key_a");
                assert_eq!(fn_args[0].kind, FnArgKind::DataField);
                assert_eq!(fn_args[1].field_name.to_string(), "key_b");
                assert_eq!(fn_args[1].kind, FnArgKind::DataField);
            }
            other => panic!("Expected FunctionCall, got {:?}", other),
        }
    }

    #[test]
    fn test_function_call_with_ctx_args() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(some_func(&fee_payer, &authority).as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        match result {
            ClassifiedSeed::FunctionCall {
                args: fn_args,
                has_as_ref,
                ..
            } => {
                assert!(has_as_ref);
                assert_eq!(fn_args.len(), 2);
                assert_eq!(fn_args[0].kind, FnArgKind::CtxAccount);
                assert_eq!(fn_args[1].kind, FnArgKind::CtxAccount);
            }
            other => panic!("Expected FunctionCall, got {:?}", other),
        }
    }

    #[test]
    fn test_function_call_no_dynamic_args_becomes_passthrough() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(crate::id().as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(
            matches!(result, ClassifiedSeed::Passthrough(_)),
            "No-arg function call should be Passthrough, got {:?}",
            result
        );
    }

    #[test]
    fn test_constant_method_call_not_function_call() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(SeedHolder::NAMESPACE.as_bytes());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(
            matches!(result, ClassifiedSeed::Constant { .. }),
            "Method call on constant should be Constant, got {:?}",
            result
        );
    }

    #[test]
    fn test_function_call_mixed_args() {
        let args = make_instruction_args(&["params"]);
        let expr: syn::Expr = parse_quote!(func(&params.key_a, &authority).as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        match result {
            ClassifiedSeed::FunctionCall {
                args: fn_args,
                has_as_ref,
                ..
            } => {
                assert!(has_as_ref);
                assert_eq!(fn_args.len(), 2);
                assert_eq!(fn_args[0].field_name.to_string(), "key_a");
                assert_eq!(fn_args[0].kind, FnArgKind::DataField);
                assert_eq!(fn_args[1].field_name.to_string(), "authority");
                assert_eq!(fn_args[1].kind, FnArgKind::CtxAccount);
            }
            other => panic!("Expected FunctionCall, got {:?}", other),
        }
    }

    #[test]
    fn test_literal_sliced() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(b"literal"[..]);
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(matches!(result, ClassifiedSeed::Literal(bytes) if bytes == b"literal"));
    }

    #[test]
    fn test_constant_qualified() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(crate::state::SEED_CONSTANT);
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(
            matches!(result, ClassifiedSeed::Constant { path, .. } if path.segments.last().unwrap().ident == "SEED_CONSTANT")
        );
    }

    #[test]
    fn test_ctx_account_nested() {
        let args = InstructionArgSet::empty();
        let expr: syn::Expr = parse_quote!(ctx.accounts.authority.key().as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(
            matches!(result, ClassifiedSeed::CtxRooted { account, .. } if account == "authority")
        );
    }

    #[test]
    fn test_ctx_account_root_terminal_extraction() {
        let args = InstructionArgSet::empty();
        let expr1: syn::Expr = parse_quote!(ctx.accounts.authority.key().as_ref());
        let result1 = get_ctx_account_root(&expr1);
        assert_eq!(
            result1.as_ref().map(|i| i.to_string()).as_deref(),
            Some("authority")
        );
        let expr2: syn::Expr = parse_quote!(authority.key().as_ref());
        let result2 = get_ctx_account_root(&expr2);
        assert_eq!(
            result2.as_ref().map(|i| i.to_string()).as_deref(),
            Some("authority")
        );
        let expr3: syn::Expr = parse_quote!(ctx.accounts.authority);
        let result3 = get_ctx_account_root(&expr3);
        assert_eq!(
            result3.as_ref().map(|i| i.to_string()).as_deref(),
            Some("authority")
        );
        let expr4: syn::Expr = parse_quote!(authority.key().as_ref());
        let classified = classify_seed_expr(&expr4, &args).unwrap();
        assert!(
            matches!(classified, ClassifiedSeed::CtxRooted { account, .. } if account == "authority")
        );
    }

    #[test]
    fn test_bare_identifier_collision_error() {
        let args = make_instruction_args(&["authority"]);
        let expr: syn::Expr = parse_quote!(authority.as_ref());
        let result = classify_seed_expr(&expr, &args);
        assert!(
            result.is_err(),
            "Expected error for ambiguous bare identifier"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Ambiguous seed"),
            "Error should mention ambiguity: {}",
            err
        );
    }

    #[test]
    fn test_field_access_no_collision() {
        let args = make_instruction_args(&["params"]);
        let expr: syn::Expr = parse_quote!(params.authority.as_ref());
        let result = classify_seed_expr(&expr, &args).unwrap();
        assert!(
            matches!(result, ClassifiedSeed::DataRooted { root, .. } if root == "params"),
            "Field access should be DataRooted without error"
        );
    }

    #[test]
    fn test_is_bare_identifier() {
        let expr1: syn::Expr = parse_quote!(authority);
        assert!(is_bare_identifier(&expr1));
        let expr2: syn::Expr = parse_quote!(authority.as_ref());
        assert!(is_bare_identifier(&expr2));
        let expr3: syn::Expr = parse_quote!(params.authority);
        assert!(!is_bare_identifier(&expr3));
        let expr4: syn::Expr = parse_quote!(params.inner.authority.as_ref());
        assert!(!is_bare_identifier(&expr4));
    }

    #[test]
    fn test_terminal_field_collision_with_instruction_arg() {
        let args = make_instruction_args(&["params", "authority"]);
        let expr: syn::Expr = parse_quote!(params.authority.as_ref());
        let result = classify_seed_expr(&expr, &args);
        assert!(
            result.is_err(),
            "Expected error for terminal field matching instruction arg"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Ambiguous seed"),
            "Error should mention ambiguity: {}",
            err
        );
    }
}