use syn::{Expr, Ident};
use super::types::{ClassifiedSeed, FnArgKind};
use crate::light_pdas::shared_utils::is_base_path;
/// Collects the distinct data fields referenced by a list of classified seeds.
///
/// - `DataRooted` seeds contribute the field name (and byte-conversion method,
///   if any) resolved by [`extract_data_field_info`].
/// - `FunctionCall` seeds contribute every argument whose kind is
///   [`FnArgKind::DataField`], with no conversion.
/// - All other seed kinds are ignored.
///
/// The result keeps first-appearance order. Duplicates are detected by field
/// name only, so when a name appears more than once the first entry (and its
/// conversion) wins.
pub fn get_data_fields(seeds: &[ClassifiedSeed]) -> Vec<(Ident, Option<Ident>)> {
    let mut collected: Vec<(Ident, Option<Ident>)> = Vec::new();
    let mut push_unique = |name: Ident, conversion: Option<Ident>| {
        let already_present = collected.iter().any(|(existing, _)| *existing == name);
        if !already_present {
            collected.push((name, conversion));
        }
    };

    for seed in seeds {
        match seed {
            ClassifiedSeed::DataRooted { expr, .. } => {
                if let Some((name, conversion)) = extract_data_field_info(expr) {
                    push_unique(name, conversion);
                }
            }
            ClassifiedSeed::FunctionCall { args, .. } => {
                args.iter()
                    .filter(|arg| matches!(arg.kind, FnArgKind::DataField))
                    .for_each(|arg| push_unique(arg.field_name.clone(), None));
            }
            _ => {}
        }
    }

    collected
}
/// Resolves a seed expression to the name of the field it ultimately reads.
///
/// Returns `(field, conversion)` where `field` is:
/// - the identifier itself for a single-segment path (`owner` -> `owner`), or
/// - the terminal named member of a field access (`params.inner.authority` -> `authority`).
///
/// `conversion` is `Some(method)` only when the outermost byte conversion on the
/// chain is `to_le_bytes()` / `to_be_bytes()`. The following are looked through
/// transparently:
/// - `&x`
/// - `.as_ref()`, `.as_bytes()`, `.as_slice()`
/// - indexing (`x[..]`)
///
/// Any other shape, including tuple-index members like `x.0`, yields `None`.
pub fn extract_data_field_info(expr: &Expr) -> Option<(Ident, Option<Ident>)> {
    match expr {
        Expr::Path(path) => path.path.get_ident().map(|ident| (ident.clone(), None)),
        Expr::Field(field) => match &field.member {
            syn::Member::Named(name) => Some((name.clone(), None)),
            _ => None,
        },
        Expr::MethodCall(mc) => match mc.method.to_string().as_str() {
            // Byte conversion: keep the field, record (or replace) the conversion method.
            "to_le_bytes" | "to_be_bytes" => extract_data_field_info(&mc.receiver)
                .map(|(name, _)| (name, Some(mc.method.clone()))),
            // Borrowing views: transparent, keep whatever the receiver resolved to.
            "as_ref" | "as_bytes" | "as_slice" => extract_data_field_info(&mc.receiver),
            _ => None,
        },
        Expr::Index(idx) => extract_data_field_info(&idx.expr),
        Expr::Reference(r) => extract_data_field_info(&r.expr),
        _ => None,
    }
}
/// Collects the `data.<field>` references in `spec`'s expression seeds that are
/// not state fields (the "params-only" seed fields).
///
/// Each `SeedElement::Expression` is inspected in two ways:
/// - as a direct `data.<field>` reference, via `extract_data_field_from_expr`;
/// - for references nested inside function-call arguments, via
///   `extract_data_fields_from_nested_calls`.
///
/// Names listed in `state_field_names` are skipped, and each name is reported
/// once. Entries are `(name, type, has_conversion)`. The type is `u64` when a
/// `to_le_bytes`/`to_be_bytes` conversion was applied, otherwise `Pubkey`.
pub fn get_params_only_seed_fields_from_spec(
    spec: &crate::light_pdas::program::instructions::TokenSeedSpec,
    state_field_names: &std::collections::HashSet<String>,
) -> Vec<(Ident, syn::Type, bool)> {
    use crate::light_pdas::program::instructions::SeedElement;

    let mut out = Vec::new();
    for seed in &spec.seeds {
        let expr = match seed {
            SeedElement::Expression(expr) => expr,
            _ => continue,
        };

        // The whole seed may itself be a `data.<field>` reference...
        if let Some((name, has_conversion)) = extract_data_field_from_expr(expr) {
            add_params_only_field(&name, has_conversion, state_field_names, &mut out);
        }
        // ...and it may also pass such references as call arguments.
        extract_data_fields_from_nested_calls(expr, state_field_names, &mut out);
    }
    out
}
/// Appends `field_name` to `fields` as a params-only seed field.
///
/// Nothing is added if the name is already in `state_field_names` or already
/// present in `fields`. The recorded type is `u64` when `has_conversion` is set
/// (the seed applied `to_le_bytes`/`to_be_bytes`), otherwise `Pubkey`. The flag
/// itself is stored as the third tuple element.
fn add_params_only_field(
    field_name: &Ident,
    has_conversion: bool,
    state_field_names: &std::collections::HashSet<String>,
    fields: &mut Vec<(Ident, syn::Type, bool)>,
) {
    if state_field_names.contains(&field_name.to_string()) {
        return;
    }
    if fields.iter().any(|(existing, _, _)| existing == field_name) {
        return;
    }

    let ty: syn::Type = match has_conversion {
        true => syn::parse_quote!(u64),
        false => syn::parse_quote!(Pubkey),
    };
    fields.push((field_name.clone(), ty, has_conversion));
}
/// Recursively searches `expr` for `data.<field>` references passed as
/// arguments to function calls, registering each via [`add_params_only_field`].
///
/// Traversal:
/// - function-call arguments are each tested as a direct data field and then
///   descended into;
/// - method-call receivers and arguments are only descended into (not tested
///   directly);
/// - `&expr` is looked through.
///
/// Every other expression kind ends the search.
fn extract_data_fields_from_nested_calls(
    expr: &syn::Expr,
    state_field_names: &std::collections::HashSet<String>,
    fields: &mut Vec<(Ident, syn::Type, bool)>,
) {
    match expr {
        syn::Expr::Call(call) => {
            for arg in call.args.iter() {
                if let Some((name, converted)) = extract_data_field_from_expr(arg) {
                    add_params_only_field(&name, converted, state_field_names, fields);
                }
                extract_data_fields_from_nested_calls(arg, state_field_names, fields);
            }
        }
        syn::Expr::MethodCall(method_call) => {
            // Receiver first, then arguments in source order.
            let receiver: &syn::Expr = &method_call.receiver;
            for child in std::iter::once(receiver).chain(method_call.args.iter()) {
                extract_data_fields_from_nested_calls(child, state_field_names, fields);
            }
        }
        syn::Expr::Reference(reference) => {
            extract_data_fields_from_nested_calls(&reference.expr, state_field_names, fields);
        }
        _ => {}
    }
}
/// Returns the terminal field name referenced by a seed expression, discarding
/// any byte conversion.
///
/// This is [`extract_data_field_info`] with the conversion dropped. That
/// resolver accepts field accesses on any base, not only `data.`. So every
/// expression the stricter `data.`-only matcher (`extract_data_field_from_expr`)
/// accepts is already accepted here. A fallback to that matcher could never
/// succeed and is intentionally absent.
pub fn extract_data_field_name_from_expr(expr: &syn::Expr) -> Option<Ident> {
    extract_data_field_info(expr).map(|(field, _conversion)| field)
}
/// Strict matcher for `data.<field>` seed expressions.
///
/// Matches a named field accessed on a base that `is_base_path(base, "data")`
/// accepts. The access may optionally be wrapped in any chain of:
/// - `&`
/// - `.as_ref()` or `.as_bytes()`
/// - `.to_le_bytes()` or `.to_be_bytes()`
///
/// On a match it returns `(field, has_conversion)`, where `has_conversion` is
/// `true` iff a `to_le_bytes`/`to_be_bytes` call appears somewhere on that
/// wrapper chain. Other shapes, including `.as_slice()` and indexing, yield
/// `None`.
fn extract_data_field_from_expr(expr: &syn::Expr) -> Option<(Ident, bool)> {
    match expr {
        syn::Expr::Reference(reference) => extract_data_field_from_expr(&reference.expr),
        syn::Expr::MethodCall(call) => {
            let method = call.method.to_string();
            match method.as_str() {
                "to_le_bytes" | "to_be_bytes" => {
                    extract_data_field_from_expr(&call.receiver).map(|(name, _)| (name, true))
                }
                "as_ref" | "as_bytes" => extract_data_field_from_expr(&call.receiver),
                _ => None,
            }
        }
        syn::Expr::Field(access) => match &access.member {
            syn::Member::Named(name) if is_base_path(&access.base, "data") => {
                Some((name.clone(), false))
            }
            _ => None,
        },
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;
    use crate::light_pdas::seeds::types::ClassifiedFnArg;

    /// Builds an identifier with a call-site span for use in seed fixtures.
    fn make_ident(s: &str) -> Ident {
        Ident::new(s, proc_macro2::Span::call_site())
    }

    #[test]
    fn test_get_data_fields_simple() {
        let seeds = vec![
            ClassifiedSeed::Literal(b"seed".to_vec()),
            ClassifiedSeed::DataRooted {
                root: make_ident("params"),
                expr: Box::new(parse_quote!(params.owner.as_ref())),
            },
        ];
        let fields = get_data_fields(&seeds);
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].0.to_string(), "owner");
        // `as_ref` is transparent, so no conversion is recorded.
        assert!(fields[0].1.is_none());
    }

    #[test]
    fn test_get_data_fields_with_conversion() {
        let seeds = vec![ClassifiedSeed::DataRooted {
            root: make_ident("params"),
            expr: Box::new(parse_quote!(params.amount.to_le_bytes().as_ref())),
        }];
        let fields = get_data_fields(&seeds);
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].0.to_string(), "amount");
        // The byte conversion method is preserved through the trailing `as_ref`.
        assert!(fields[0].1.is_some());
        assert_eq!(fields[0].1.as_ref().unwrap().to_string(), "to_le_bytes");
    }

    #[test]
    fn test_get_data_fields_from_function_call() {
        let seeds = vec![ClassifiedSeed::FunctionCall {
            func_expr: Box::new(parse_quote!(crate::max_key(&params.key_a, &params.key_b))),
            args: vec![
                ClassifiedFnArg {
                    field_name: make_ident("key_a"),
                    kind: FnArgKind::DataField,
                },
                ClassifiedFnArg {
                    field_name: make_ident("key_b"),
                    kind: FnArgKind::DataField,
                },
            ],
            has_as_ref: true,
        }];
        let fields = get_data_fields(&seeds);
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].0.to_string(), "key_a");
        assert_eq!(fields[1].0.to_string(), "key_b");
    }

    #[test]
    fn test_get_data_fields_deduplicates() {
        let seeds = vec![
            ClassifiedSeed::DataRooted {
                root: make_ident("params"),
                expr: Box::new(parse_quote!(params.owner.as_ref())),
            },
            ClassifiedSeed::DataRooted {
                root: make_ident("params"),
                expr: Box::new(parse_quote!(params.owner.as_ref())),
            },
        ];
        let fields = get_data_fields(&seeds);
        assert_eq!(fields.len(), 1);
    }

    #[test]
    fn test_extract_data_field_info_bare_ident() {
        let expr: syn::Expr = parse_quote!(owner);
        let result = extract_data_field_info(&expr);
        assert!(result.is_some());
        let (field, conversion) = result.unwrap();
        assert_eq!(field.to_string(), "owner");
        assert!(conversion.is_none());
    }

    #[test]
    fn test_extract_data_field_info_field_access() {
        let expr: syn::Expr = parse_quote!(params.owner);
        let result = extract_data_field_info(&expr);
        assert!(result.is_some());
        let (field, conversion) = result.unwrap();
        assert_eq!(field.to_string(), "owner");
        assert!(conversion.is_none());
    }

    #[test]
    fn test_extract_data_field_info_with_as_ref() {
        let expr: syn::Expr = parse_quote!(params.owner.as_ref());
        let result = extract_data_field_info(&expr);
        assert!(result.is_some());
        let (field, conversion) = result.unwrap();
        assert_eq!(field.to_string(), "owner");
        assert!(conversion.is_none());
    }

    #[test]
    fn test_extract_data_field_info_with_to_le_bytes() {
        let expr: syn::Expr = parse_quote!(params.amount.to_le_bytes());
        let result = extract_data_field_info(&expr);
        assert!(result.is_some());
        let (field, conversion) = result.unwrap();
        assert_eq!(field.to_string(), "amount");
        assert!(conversion.is_some());
        assert_eq!(conversion.unwrap().to_string(), "to_le_bytes");
    }

    #[test]
    fn test_extract_data_field_info_with_to_be_bytes() {
        let expr: syn::Expr = parse_quote!(params.amount.to_be_bytes().as_ref());
        let (field, conversion) =
            extract_data_field_info(&expr).expect("to_be_bytes chain should resolve");
        assert_eq!(field.to_string(), "amount");
        assert_eq!(conversion.unwrap().to_string(), "to_be_bytes");
    }

    #[test]
    fn test_extract_data_field_info_unnamed_member_is_none() {
        // Tuple-index members have no field name to extract.
        let expr: syn::Expr = parse_quote!(params.0);
        assert!(extract_data_field_info(&expr).is_none());
    }

    #[test]
    fn test_extract_data_field_name_from_expr() {
        let expr: syn::Expr = parse_quote!(params.owner.as_ref());
        let result = extract_data_field_name_from_expr(&expr);
        assert!(result.is_some());
        assert_eq!(result.unwrap().to_string(), "owner");
    }

    #[test]
    fn test_extract_data_field_info_nested() {
        let expr: syn::Expr = parse_quote!(params.inner.authority.as_ref());
        let result = extract_data_field_info(&expr);
        assert!(result.is_some());
        let (field, conversion) = result.unwrap();
        assert_eq!(
            field.to_string(),
            "authority",
            "Should extract terminal field 'authority', not 'inner'"
        );
        assert!(conversion.is_none());
    }

    #[test]
    fn test_get_data_fields_nested() {
        let seeds = vec![ClassifiedSeed::DataRooted {
            root: make_ident("params"),
            expr: Box::new(parse_quote!(params.inner.authority.as_ref())),
        }];
        let fields = get_data_fields(&seeds);
        assert_eq!(fields.len(), 1);
        assert_eq!(
            fields[0].0.to_string(),
            "authority",
            "Seed struct should use terminal field 'authority'"
        );
    }

    #[test]
    fn test_get_data_fields_deeply_nested() {
        let seeds = vec![ClassifiedSeed::DataRooted {
            root: make_ident("data"),
            expr: Box::new(parse_quote!(data.level1.level2.level3.key.as_ref())),
        }];
        let fields = get_data_fields(&seeds);
        assert_eq!(fields.len(), 1);
        assert_eq!(
            fields[0].0.to_string(),
            "key",
            "Should extract deepest terminal field 'key'"
        );
    }
}