1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::{Parse, ParseStream};
4use syn::{DeriveInput, Ident, Token, parse_macro_input};
5
6struct SignerFields {
7 fields: Vec<Ident>,
8}
9
10#[proc_macro_derive(GetSigners, attributes(signer_fields))]
11pub fn derive_get_signers(input: TokenStream) -> TokenStream {
12 let input = parse_macro_input!(input as DeriveInput);
13 let struct_name = &input.ident;
14
15 let signer_fields = input
16 .attrs
17 .iter()
18 .find_map(|attr| {
19 attr.path()
20 .is_ident("signer_fields")
21 .then(|| attr.parse_args::<SignerFields>().ok().map(|sf| sf.fields))
22 .flatten()
23 })
24 .unwrap_or_default();
25
26 let mut borrowed_field_exprs = vec![];
27 let mut owned_field_exprs = vec![];
28 if let syn::Data::Struct(data_struct) = &input.data {
29 for field_name in signer_fields {
30 let field = data_struct
31 .fields
32 .iter()
33 .find(|f| f.ident.as_ref().map(|id| id == &field_name).unwrap_or(false));
34
35 if let Some(field) = field {
36 let field_ident = field.ident.as_ref().unwrap();
37 let ty = &field.ty;
38 let ty_str = quote! { #ty }.to_string().replace(' ', "");
39
40 let (borrowed_field_expr, owned_field_expr) = match ty_str.as_str() {
41 "::prost::alloc::string::String" => (
42 quote! { .chain(std::iter::once(self.#field_ident.as_str())) },
43 quote! { .chain(std::iter::once(self.#field_ident)) },
44 ),
45 "::prost::alloc::vec::Vec<::prost::alloc::string::String>" => (
46 quote! { .chain(self.#field_ident.iter().map(|s| s.as_str())) },
47 quote! { .chain(self.#field_ident) },
48 ),
49 s if s.starts_with("::prost::alloc::vec::Vec<") => (
50 quote! { .chain(self.#field_ident.iter().flat_map(GetSigners::signers)) },
51 quote! { .chain(self.#field_ident.into_iter().flat_map(GetSigners::signers)) },
52 ),
53 _ => (
54 quote! { .chain(GetSigners::singers(&self.#field_ident)) },
55 quote! { .chain(GetSigners::singers(self.#field_ident)) },
56 ),
57 };
58
59 borrowed_field_exprs.push(borrowed_field_expr);
60 owned_field_exprs.push(owned_field_expr);
61 }
62 }
63 }
64
65 let expanded = quote! {
66 impl<'a> GetSigners for &'a #struct_name {
67 type Signer = &'a str;
68
69 fn signers(self) -> impl Iterator<Item = Self::Signer> {
70 std::iter::empty()
71 #(#borrowed_field_exprs)*
72 }
73 }
74
75
76 impl GetSigners for #struct_name {
77 type Signer = String;
78
79 fn signers(self) -> impl Iterator<Item = Self::Signer> {
80 std::iter::empty()
81 #(#owned_field_exprs)*
82 }
83 }
84 };
85
86 TokenStream::from(expanded)
87}
88
89impl Parse for SignerFields {
90 fn parse(input: ParseStream) -> syn::Result<Self> {
91 let mut fields = vec![];
92 while !input.is_empty() {
93 fields.push(input.parse()?);
94 if input.peek(Token![,]) {
95 input.parse::<Token![,]>()?;
96 }
97 }
98
99 Ok(SignerFields { fields })
100 }
101}