arbiter_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parse_macro_input, DeriveInput, Fields, FieldsNamed, GenericArgument, Ident, PathArguments,
7    Type, TypePath,
8};
9
10#[proc_macro_derive(Deploy)]
11pub fn derive(input: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(input as DeriveInput);
13
14    let struct_name = input.ident;
15    let new_struct_name = Ident::new(&format!("{}Deployed", struct_name), struct_name.span());
16
17    let fields = match input.data {
18        syn::Data::Struct(syn::DataStruct {
19            fields: Fields::Named(FieldsNamed { named, .. }),
20            ..
21        }) => named,
22        _ => panic!("Only named fields are supported"),
23    };
24
25    let field_names: Vec<_> = fields.iter().map(|f| &f.ident).collect();
26    let new_field_types: Vec<_> = fields
27        .iter()
28        .map(|f| match &f.ty {
29            Type::Path(TypePath { path, .. }) => {
30                if let Some(segment) = path.segments.last() {
31                    if let PathArguments::AngleBracketed(angle_bracketed_args) = &segment.arguments
32                    {
33                        if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
34                            angle_bracketed_args.args.last()
35                        {
36                            return quote! { #path };
37                        }
38                    }
39                }
40                quote! { #f }
41            }
42            _ => quote! { #f },
43        })
44        .collect();
45    let middleware_type = if let Type::Path(tp) = &fields.iter().next().unwrap().ty {
46        if let Some(segment) = tp.path.segments.first() {
47            if let PathArguments::AngleBracketed(angle_bracketed_args) = &segment.arguments {
48                if let Some(GenericArgument::Type(Type::Path(type_path))) =
49                    angle_bracketed_args.args.iter().nth(1)
50                {
51                    Some(&type_path.path)
52                } else {
53                    None
54                }
55            } else {
56                None
57            }
58        } else {
59            None
60        }
61    } else {
62        None
63    };
64
65    // Generate the code for new struct and impl
66    let expanded = quote! {
67        #[derive(Clone, Debug)]
68        pub struct #new_struct_name {
69            #( pub #field_names: #new_field_types ),*
70        }
71
72        impl #struct_name {
73            pub async fn deploy(self) -> Result<#new_struct_name, ethers::contract::ContractError<#middleware_type>> {
74                Ok(#new_struct_name {
75                    #(
76                        #field_names: self.#field_names.send().await.unwrap(),
77                    )*
78                })
79            }
80        }
81    };
82
83    TokenStream::from(expanded)
84}