bma_jrpc_derive/
lib.rs

1use darling::FromAttributes;
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::{quote, ToTokens};
5use syn::{FnArg, Ident};
6
7fn parse_method_arg(arg: &FnArg) -> (Ident, Ident, bool) {
8    if let syn::FnArg::Typed(a) = arg {
9        match a.ty.as_ref() {
10            syn::Type::Path(p) => {
11                if let syn::Pat::Ident(i) = a.pat.as_ref() {
12                    let attr_name = i.ident.clone();
13                    let attr_type = p.path.segments[0].ident.clone();
14                    return (attr_name, attr_type, false);
15                }
16            }
17            syn::Type::Reference(p) => {
18                if let syn::Pat::Ident(i) = a.pat.as_ref() {
19                    if let syn::Type::Path(p) = p.elem.as_ref() {
20                        let attr_name = i.ident.clone();
21                        let attr_type = p.path.segments[0].ident.clone();
22                        return (attr_name, attr_type, true);
23                    }
24                }
25            }
26            _ => {}
27        };
28    }
29    panic!("unsupported function argument");
30}
31
32#[derive(Debug, FromAttributes)]
33#[darling(attributes(rpc))]
34struct MethodAttrs {
35    #[darling()]
36    name: Option<String>,
37    #[darling()]
38    result_field: Option<String>,
39}
40
41#[allow(clippy::too_many_lines)]
42#[proc_macro_attribute]
43/// # Panics
44///
45/// Will panic on invalid or unsupported
46pub fn rpc_client(_args: TokenStream, input: TokenStream) -> TokenStream {
47    let item: syn::Item = syn::parse(input).expect("invalid input");
48    if let syn::Item::Trait(trait_item) = item {
49        let struct_name = Ident::new(&format!("{}Client", trait_item.ident), Span::call_site());
50        let name = trait_item.ident;
51        let mut methods = Vec::new();
52        for item in trait_item.items {
53            if let syn::TraitItem::Method(method) = item {
54                assert!(
55                    (method.sig.ident != "get_rpc_client"),
56                    "get_rpc_client is a reserved name"
57                );
58                let attrs = match MethodAttrs::from_attributes(&method.attrs) {
59                    Ok(v) => v,
60                    Err(e) => return TokenStream::from(e.write_errors()),
61                };
62                let method_name = method.sig.ident.clone();
63                let rpc_method_name = if let Some(name) = attrs.name {
64                    Ident::new(&name, Span::call_site())
65                } else {
66                    method.sig.ident
67                };
68                let ty: Option<syn::Type> = match method.sig.output {
69                    syn::ReturnType::Type(_, ty) => Some(*ty),
70                    syn::ReturnType::Default => None,
71                };
72                let mut refs_found = false;
73                let mut input_struct_names = Vec::new();
74                let mut input_struct_args = Vec::new();
75                let inputs = method.sig.inputs;
76                let ret = if let Some(t) = ty {
77                    if let syn::Type::Path(tpath) = t {
78                        let r = tpath.path.segments[0].ident.clone();
79                        quote! { #r }
80                    } else {
81                        panic!("unsupported return type");
82                    }
83                } else {
84                    quote! { () }
85                };
86                for arg in inputs.iter().skip(1) {
87                    let (name, tp, is_ref) = parse_method_arg(arg);
88                    input_struct_names.push(quote! {
89                        #name,
90                    });
91                    if is_ref {
92                        refs_found = true;
93                        input_struct_args.push(quote! {
94                            #name: &'a #tp,
95                        });
96                    } else {
97                        input_struct_args.push(quote! {
98                            #name: #tp,
99                        });
100                    }
101                }
102                let (input_struct, payload) = if input_struct_args.is_empty() {
103                    let p = quote! {
104                        ()
105                    };
106                    (None, p)
107                } else {
108                    let lifetime = if refs_found {
109                        Some(quote! { <'a> })
110                    } else {
111                        None
112                    };
113                    let s = Some(quote! {
114                        #[derive(serde::Serialize)]
115                        struct InputPayload #lifetime {
116                            #(#input_struct_args)*
117                        }
118                    });
119                    let p = quote! {
120                        InputPayload {
121                            #(#input_struct_names)*
122                        }
123                    };
124                    (s, p)
125                };
126                let (response_tp, out, output_struct) =
127                    if let Some(result_field) = attrs.result_field {
128                        let field = Ident::new(&result_field, Span::call_site());
129                        let output_type = Ident::new("OutputPayload", Span::call_site());
130                        (
131                            Some(quote! {
132                                #output_type
133                            }),
134                            quote! {
135                                Ok(response.#field)
136                            },
137                            Some(quote! {
138                                #[derive(serde::Deserialize)]
139                                struct OutputPayload {
140                                    #field: #ret
141                                }
142                            }),
143                        )
144                    } else {
145                        (
146                            Some(ret.clone()),
147                            quote! {
148                                Ok(response)
149                            },
150                            None,
151                        )
152                    };
153                let f = quote! {
154                    fn #method_name(#inputs) -> Result<#ret, ::bma_jrpc::Error> {
155                        #input_struct
156                        #output_struct
157                        let response: #response_tp = self.get_rpc_client().call(
158                            stringify!(#rpc_method_name), #payload)?;
159                        #out
160                    }
161                };
162                methods.push(f);
163            }
164        }
165        let f = quote! {
166            trait #name<X: ::bma_jrpc::Rpc> {
167                #(#methods)*
168                fn get_rpc_client(&self) -> &X;
169            }
170            struct #struct_name<X: ::bma_jrpc::Rpc> {
171                client: X
172            }
173            impl<X: ::bma_jrpc::Rpc> #struct_name<X> {
174                fn new(client: X) -> Self {
175                    Self { client }
176                }
177            }
178            impl<X: ::bma_jrpc::Rpc> #name<X> for #struct_name<X> {
179                fn get_rpc_client(&self) -> &X {
180                    &self.client
181                }
182            }
183        };
184        f.into_token_stream().into()
185    } else {
186        panic!("the attribute must be placed on a trait");
187    }
188}