Skip to main content

imbibe_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::{Parse, ParseStream};
4use syn::{DeriveInput, Ident, Token, parse_macro_input};
5
/// Identifiers parsed from the arguments of a `#[signer_fields(...)]` attribute.
struct SignerFields {
	/// Field names, in the order they were written inside the attribute.
	fields: Vec<Ident>,
}
9
10#[proc_macro_derive(GetSigners, attributes(signer_fields))]
11pub fn derive_get_signers(input: TokenStream) -> TokenStream {
12	let input = parse_macro_input!(input as DeriveInput);
13	let struct_name = &input.ident;
14
15	let signer_fields = input
16		.attrs
17		.iter()
18		.find_map(|attr| {
19			attr.path()
20				.is_ident("signer_fields")
21				.then(|| attr.parse_args::<SignerFields>().ok().map(|sf| sf.fields))
22				.flatten()
23		})
24		.unwrap_or_default();
25
26	let mut borrowed_field_exprs = vec![];
27	let mut owned_field_exprs = vec![];
28	if let syn::Data::Struct(data_struct) = &input.data {
29		for field_name in signer_fields {
30			let field = data_struct
31				.fields
32				.iter()
33				.find(|f| f.ident.as_ref().map(|id| id == &field_name).unwrap_or(false));
34
35			if let Some(field) = field {
36				let field_ident = field.ident.as_ref().unwrap();
37				let ty = &field.ty;
38				let ty_str = quote! { #ty }.to_string().replace(' ', "");
39
40				let (borrowed_field_expr, owned_field_expr) = match ty_str.as_str() {
41					"::prost::alloc::string::String" => (
42						quote! { .chain(std::iter::once(self.#field_ident.as_str())) },
43						quote! { .chain(std::iter::once(self.#field_ident)) },
44					),
45					"::prost::alloc::vec::Vec<::prost::alloc::string::String>" => (
46						quote! { .chain(self.#field_ident.iter().map(|s| s.as_str())) },
47						quote! { .chain(self.#field_ident) },
48					),
49					s if s.starts_with("::prost::alloc::vec::Vec<") => (
50						quote! { .chain(self.#field_ident.iter().flat_map(GetSigners::signers)) },
51						quote! { .chain(self.#field_ident.into_iter().flat_map(GetSigners::signers)) },
52					),
53					_ => (
54						quote! { .chain(GetSigners::singers(&self.#field_ident)) },
55						quote! { .chain(GetSigners::singers(self.#field_ident)) },
56					),
57				};
58
59				borrowed_field_exprs.push(borrowed_field_expr);
60				owned_field_exprs.push(owned_field_expr);
61			}
62		}
63	}
64
65	let expanded = quote! {
66		impl<'a> GetSigners for &'a #struct_name {
67			type Signer = &'a str;
68
69			fn signers(self) -> impl Iterator<Item = Self::Signer> {
70				std::iter::empty()
71					#(#borrowed_field_exprs)*
72			}
73		}
74
75
76		impl GetSigners for #struct_name {
77			type Signer = String;
78
79			fn signers(self) -> impl Iterator<Item = Self::Signer> {
80				std::iter::empty()
81					#(#owned_field_exprs)*
82			}
83		}
84	};
85
86	TokenStream::from(expanded)
87}
88
89impl Parse for SignerFields {
90	fn parse(input: ParseStream) -> syn::Result<Self> {
91		let mut fields = vec![];
92		while !input.is_empty() {
93			fields.push(input.parse()?);
94			if input.peek(Token![,]) {
95				input.parse::<Token![,]>()?;
96			}
97		}
98
99		Ok(SignerFields { fields })
100	}
101}