jsonrpc_utils_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5    braced, meta, parenthesized,
6    parse::{Parse, ParseStream, Parser},
7    parse2, parse_macro_input, parse_quote, Attribute, FnArg, GenericArgument, Ident, ImplItemFn,
8    ItemTrait, LitStr, Pat, PathArguments, Result, ReturnType, Token, TraitItem, TraitItemFn, Type,
9    Visibility,
10};
11
12#[proc_macro_attribute]
13pub fn rpc(args: TokenStream, input: TokenStream) -> TokenStream {
14    let args = parse_macro_input!(args as RpcArgs);
15
16    let mut item_trait = parse_macro_input!(input as ItemTrait);
17    let vis = &item_trait.vis;
18    let trait_name = &item_trait.ident;
19    let trait_name_snake = to_snake_case(trait_name.to_string());
20    let add_method_name = format_ident!("add_{}_methods", trait_name_snake);
21
22    let doc_func = if args.openrpc {
23        let doc_func_name = format_ident!("{}_doc", trait_name_snake);
24        let method_defs = item_trait
25            .items
26            .iter_mut()
27            .filter_map(|m| match m {
28                TraitItem::Fn(m) => Some(m),
29                _ => None,
30            })
31            .map(|m| method_def(m))
32            .collect::<Result<Vec<_>>>();
33        let method_defs = match method_defs {
34            Ok(x) => x,
35            Err(e) => return e.to_compile_error().into(),
36        };
37        let description = gen_desc_from_attrs(&item_trait.attrs).unwrap();
38        let title = &*trait_name_snake;
39        quote!(
40            /// Generate OpenRPC document for the RPC methods.
41            #vis fn #doc_func_name() -> jsonrpc_utils::serde_json::Value {
42                #[allow(unused)]
43                use schemars::JsonSchema;
44
45                let mut gen = schemars::gen::SchemaSettings::draft07().with(|s|
46                    s.definitions_path = "#/components/schemas/".into()
47                ).into_generator();
48
49                let methods = jsonrpc_utils::serde_json::json!([#(#method_defs)*]);
50
51                jsonrpc_utils::serde_json::json!({
52                    "openrpc": "1.2.6",
53                    "info": {
54                        "title": #title,
55                        "version": "1.0.0",
56                        #description
57                    },
58                    "methods": methods,
59                    "components": {
60                        "schemas": gen.take_definitions(),
61                    }
62                })
63            }
64        )
65    } else {
66        quote!()
67    };
68
69    let add_methods = item_trait
70        .items
71        .iter_mut()
72        .filter_map(|m| match m {
73            TraitItem::Fn(m) => Some(m),
74            _ => None,
75        })
76        .map(add_method)
77        .collect::<Result<Vec<_>>>();
78    let add_methods = match add_methods {
79        Ok(x) => x,
80        Err(e) => return e.to_compile_error().into(),
81    };
82
83    let result = quote! {
84        #item_trait
85
86        /// Add RPC methods to the given `jsonrpc_utils::jsonrpc_core::MetaIoHandler`.
87        #vis fn #add_method_name(rpc: &mut jsonrpc_utils::jsonrpc_core::MetaIoHandler<Option<jsonrpc_utils::pub_sub::Session>>, rpc_impl: impl #trait_name + Clone + Send + Sync + 'static) {
88            #(#add_methods)*
89        }
90
91        #doc_func
92    };
93
94    result.into()
95}
96
97#[proc_macro_attribute]
98pub fn rpc_client(_attr: TokenStream, input: TokenStream) -> TokenStream {
99    rpc_client_impl(input.into())
100        .unwrap_or_else(|e| e.into_compile_error())
101        .into()
102}
103
104struct ImplMethods {
105    attributes: Vec<Attribute>,
106    ident: Ident,
107    items: Vec<ImplItemFn>,
108}
109
110impl Parse for ImplMethods {
111    fn parse(input: ParseStream) -> Result<Self> {
112        let attributes = input.call(Attribute::parse_outer)?;
113        input.parse::<Token![impl]>()?;
114        let ident: Ident = input.parse()?;
115        let mut items: Vec<ImplItemFn> = Vec::new();
116        let content;
117        braced!(content in input);
118        while !content.is_empty() {
119            let vis: Visibility = content.parse()?;
120            items.push({
121                let m: TraitItemFn = content.parse()?;
122                ImplItemFn {
123                    attrs: m.attrs,
124                    vis,
125                    defaultness: None,
126                    sig: m.sig,
127                    block: parse_quote!({}),
128                }
129            });
130        }
131        Ok(Self {
132            attributes,
133            ident,
134            items,
135        })
136    }
137}
138
139fn rpc_client_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream> {
140    let mut impl_block: ImplMethods = parse2(input)?;
141
142    for item in &mut impl_block.items {
143        rewrite_method(item)?;
144    }
145
146    let attributes = impl_block.attributes;
147    let ident = impl_block.ident;
148    let items = impl_block.items;
149
150    Ok(quote! {
151        #(#attributes)*
152        impl #ident {
153            #(#items)*
154        }
155    })
156}
157
158fn rewrite_method(m: &mut ImplItemFn) -> Result<()> {
159    let args = ClientMethodArgs::parse_attrs(&m.attrs)?;
160
161    m.attrs.retain(|a| !a.path().is_ident("rpc"));
162
163    let method_name = args.name.unwrap_or_else(|| m.sig.ident.to_string());
164
165    let mut ident_counter = 0u32;
166    let mut next_ident_counter = || {
167        ident_counter += 1;
168        ident_counter
169    };
170
171    let mut params_names = Vec::new();
172    for arg in &mut m.sig.inputs {
173        match arg {
174            FnArg::Receiver(_) => {}
175            FnArg::Typed(pat_type) => match &*pat_type.pat {
176                Pat::Ident(ident) => {
177                    params_names.push(ident.ident.clone());
178                }
179                _ => {
180                    let ident = format_ident!("param_{}", next_ident_counter());
181                    params_names.push(ident.clone());
182                    pat_type.pat = parse_quote!(#ident);
183                }
184            },
185        }
186    }
187    let params_names = quote!(#(#params_names ,)*);
188
189    if m.sig.asyncness.is_some() {
190        m.block = parse_quote!({
191            let result = self
192                .inner
193                .rpc(#method_name, &jsonrpc_utils::serde_json::value::to_raw_value(&(#params_names))?)
194                .await?;
195            Ok(jsonrpc_utils::serde_json::from_value(result)?)
196        });
197    } else {
198        m.block = parse_quote!({
199            let result = self
200                .inner
201                .rpc(#method_name, &jsonrpc_utils::serde_json::value::to_raw_value(&(#params_names))?)?;
202            Ok(jsonrpc_utils::serde_json::from_value(result)?)
203        });
204    }
205    Ok(())
206}
207
208fn method_def(m: &TraitItemFn) -> Result<proc_macro2::TokenStream> {
209    let description = gen_desc_from_attrs(&m.attrs)?;
210    let attrs = MethodArgs::parse_attrs(&m.attrs)?;
211    if attrs.pub_sub.is_some() {
212        return Ok(quote!());
213    }
214    let name = attrs.name.unwrap_or_else(|| m.sig.ident.to_string());
215    let params: Vec<_> = m
216        .sig
217        .inputs
218        .iter()
219        .filter_map(|input| match input {
220            FnArg::Receiver(_) => None,
221            FnArg::Typed(pat_type) => {
222                let name = match &*pat_type.pat {
223                    Pat::Ident(i) => LitStr::new(&i.ident.to_string(), i.ident.span()),
224                    _ => LitStr::new("parameter", Span::call_site()),
225                };
226                let ty = &pat_type.ty;
227                Some(quote! {
228                    {
229                        "name": #name,
230                        "schema": gen.subschema_for::<#ty>(),
231                    }
232                })
233            }
234        })
235        .collect();
236    let result_type = match &m.sig.output {
237        ReturnType::Default => todo!(),
238        ReturnType::Type(_, t) => match &**t {
239            Type::Path(tp) => {
240                if tp.qself.is_none() && tp.path.segments.len() == 1 {
241                    let seg = tp.path.segments.first().unwrap();
242                    if seg.ident == "Result" {
243                        match &seg.arguments {
244                            PathArguments::AngleBracketed(ang) => match ang.args.first() {
245                                Some(GenericArgument::Type(t)) => t,
246                                _ => t,
247                            },
248                            _ => t,
249                        }
250                    } else {
251                        t
252                    }
253                } else {
254                    t
255                }
256            }
257            _ => t,
258        },
259    };
260    // TODO: more meaningful result name.
261    let result = quote! {
262        {
263            "name": #name,
264            "schema": gen.subschema_for::<#result_type>(),
265        }
266    };
267    Ok(quote! {
268        jsonrpc_utils::serde_json::json!({
269            "name": #name,
270            #description
271            "params": [#(#params),*],
272            "result": #result,
273        }),
274    })
275}
276
277fn add_method(m: &mut TraitItemFn) -> Result<proc_macro2::TokenStream> {
278    let method_name = &m.sig.ident;
279
280    let (rpc_attributes, other_attributes) = m
281        .attrs
282        .drain(..)
283        .partition::<Vec<_>, _>(|a| a.path().is_ident("rpc"));
284    let args = MethodArgs::parse_attrs(&rpc_attributes)?;
285    m.attrs = other_attributes;
286
287    let rpc_method_name = args.name.unwrap_or_else(|| method_name.to_string());
288
289    let mut params_names = Vec::new();
290    let mut params_tys = Vec::new();
291    let mut ident_counter = 0u32;
292    let mut next_ident_counter = || {
293        ident_counter += 1;
294        ident_counter
295    };
296    for arg in &m.sig.inputs {
297        match arg {
298            FnArg::Receiver(_) => {}
299            FnArg::Typed(pat_type) => match &*pat_type.pat {
300                Pat::Ident(ident) => {
301                    params_names.push(ident.ident.clone());
302                    params_tys.push(&*pat_type.ty);
303                }
304                _ => {
305                    params_names.push(format_ident!("param_{}", next_ident_counter()));
306                    params_tys.push(&*pat_type.ty);
307                }
308            },
309        }
310    }
311    let no_params = params_names.is_empty();
312    // Number of tailing optional parameters.
313    let optional_params = params_tys
314        .iter()
315        .rev()
316        .take_while(|t| match t {
317            Type::Path(t) => t.path.segments.first().is_some_and(|s| s.ident == "Option"),
318            _ => false,
319        })
320        .count();
321    let params_names1 = quote!(#(#params_names ,)*);
322    let params_tys1 = quote!(#(#params_tys ,)*);
323    let parse_params = if optional_params > 0 {
324        let required_params = params_names.len() - optional_params;
325        let mut parse_params = quote! {
326            let mut arr = match params {
327                jsonrpc_utils::jsonrpc_core::Params::Array(arr) => arr.into_iter(),
328                jsonrpc_utils::jsonrpc_core::Params::None => Vec::new().into_iter(),
329                _ => return Err(jsonrpc_utils::jsonrpc_core::Error::invalid_params("Invalid params: invalid type map, expect an array or null")),
330            };
331        };
332        for i in 0..required_params {
333            let p = &params_names[i];
334            let p_str = LitStr::new(&p.to_string(), p.span());
335            let ty = params_tys[i];
336            parse_params.extend(quote! {
337                let #p: #ty = jsonrpc_utils::serde_json::from_value(arr.next().ok_or_else(|| jsonrpc_utils::jsonrpc_core::Error::invalid_params(format!("Missing required parameter `{}`", #p_str)))?).map_err(|e|
338                    jsonrpc_utils::jsonrpc_core::Error::invalid_params(format!("Invalid parameter for `{}`: {}", #p_str, e))
339                )?;
340            });
341        }
342        for i in required_params..params_names.len() {
343            let p = &params_names[i];
344            let p_str = LitStr::new(&p.to_string(), p.span());
345            let ty = params_tys[i];
346            parse_params.extend(quote! {
347                let #p: #ty = match arr.next() {
348                    Some(v) => jsonrpc_utils::serde_json::from_value(v).map_err(|e| jsonrpc_utils::jsonrpc_core::Error::invalid_params(format!("Invalid parameter for `{}`: {}", #p_str, e)))?,
349                    None => None,
350                };
351            });
352        }
353        parse_params
354    } else if no_params {
355        quote!(params.expect_no_params()?;)
356    } else {
357        quote!(let (#params_names1): (#params_tys1) = params.parse()?;)
358    };
359    let result = if m.sig.asyncness.is_some() {
360        quote!(rpc_impl.#method_name(#params_names1).await)
361    } else {
362        quote!(rpc_impl.#method_name(#params_names1))
363    };
364
365    Ok(if let Some(pub_sub) = args.pub_sub {
366        let notify_method_lit = &*pub_sub.notify;
367        let unsubscribe_method_lit = &*pub_sub.unsubscribe;
368        quote! {
369            jsonrpc_utils::pub_sub::add_pub_sub(rpc, #rpc_method_name, #notify_method_lit.into(), #unsubscribe_method_lit, {
370                let rpc_impl = rpc_impl.clone();
371                move |params: jsonrpc_utils::jsonrpc_core::Params| {
372                    #parse_params
373                    rpc_impl.#method_name(#params_names1).map_err(Into::into)
374                }
375            });
376        }
377    } else {
378        quote! {
379            rpc.add_method(#rpc_method_name, {
380                let rpc_impl = rpc_impl.clone();
381                move |params: jsonrpc_utils::jsonrpc_core::Params| {
382                    let rpc_impl = rpc_impl.clone();
383                    async move {
384                        #parse_params
385                        jsonrpc_utils::serde_json::to_value(#result?).map_err(|_| jsonrpc_utils::jsonrpc_core::Error::internal_error())
386                    }
387                }
388            });
389        }
390    })
391}
392
393fn gen_desc_from_attrs(attrs: &[Attribute]) -> Result<proc_macro2::TokenStream> {
394    let doc_attrs = attrs
395        .iter()
396        .filter(|a| a.path().is_ident("doc"))
397        .collect::<Vec<_>>();
398    let doc = if !doc_attrs.is_empty() {
399        let mut values = vec![];
400        for a in doc_attrs {
401            let v = &a.meta.require_name_value()?.value;
402            values.push(parse2::<LitStr>(v.to_token_stream())?);
403        }
404        Some(values)
405    } else {
406        None
407    };
408
409    let description = if let Some(lines) = doc {
410        let doc = lines
411            .iter()
412            .map(|l| l.value())
413            .collect::<Vec<_>>()
414            .join("\n");
415        let doc = LitStr::new(&doc, Span::call_site());
416        quote!( "description": #doc, )
417    } else {
418        quote!()
419    };
420    Ok(description)
421}
422
423struct RpcArgs {
424    openrpc: bool,
425}
426
427impl Parse for RpcArgs {
428    fn parse(input: ParseStream) -> Result<Self> {
429        let mut openrpc = false;
430        if !input.is_empty() {
431            let parser = meta::parser(|m| {
432                if m.path.is_ident("openrpc") {
433                    openrpc = true;
434                    Ok(())
435                } else {
436                    Err(m.error("unknown arg"))
437                }
438            });
439            parser.parse2(input.parse()?)?;
440        }
441        Ok(Self { openrpc })
442    }
443}
444
445struct ClientMethodArgs {
446    name: Option<String>,
447}
448
449impl ClientMethodArgs {
450    fn parse_attrs(attrs: &[Attribute]) -> Result<Self> {
451        let mut name: Option<LitStr> = None;
452
453        for a in attrs {
454            if a.path().is_ident("rpc") {
455                a.parse_nested_meta(|m| {
456                    if m.path.is_ident("name") {
457                        name = Some(m.value()?.parse()?);
458                    } else {
459                        return Err(m.error("unknown arg"));
460                    }
461                    Ok(())
462                })?;
463            }
464        }
465
466        let name = name.map(|n| n.value());
467        Ok(Self { name })
468    }
469}
470
471struct MethodArgs {
472    pub_sub: Option<PubSubArgs>,
473    name: Option<String>,
474}
475
476impl MethodArgs {
477    fn parse_attrs(attrs: &[Attribute]) -> Result<Self> {
478        let mut pub_sub: Option<PubSubArgs> = None;
479        let mut name: Option<LitStr> = None;
480        for a in attrs {
481            if a.path().is_ident("rpc") {
482                a.parse_nested_meta(|m| {
483                    if m.path.is_ident("pub_sub") {
484                        let content;
485                        parenthesized!(content in m.input);
486                        pub_sub = Some(content.parse()?);
487                        Ok(())
488                    } else if m.path.is_ident("name") {
489                        name = Some(m.value()?.parse()?);
490                        Ok(())
491                    } else {
492                        Err(m.error("unknown arg"))
493                    }
494                })?;
495            }
496        }
497        let name = name.map(|n| n.value());
498        Ok(Self { name, pub_sub })
499    }
500}
501
502struct PubSubArgs {
503    notify: String,
504    unsubscribe: String,
505}
506
507impl Parse for PubSubArgs {
508    fn parse(input: ParseStream) -> Result<Self> {
509        let mut notify: Option<LitStr> = None;
510        let mut unsubscribe: Option<LitStr> = None;
511
512        let parser = meta::parser(|m| {
513            if m.path.is_ident("notify") {
514                notify = Some(m.value()?.parse()?);
515            } else if m.path.is_ident("unsubscribe") {
516                unsubscribe = Some(m.value()?.parse()?);
517            } else {
518                return Err(m.error("unknown arg"));
519            }
520            Ok(())
521        });
522        parser.parse2(input.parse()?)?;
523
524        let notify = notify
525            .ok_or_else(|| input.error("missing arg notify"))?
526            .value();
527        let unsubscribe = unsubscribe
528            .ok_or_else(|| input.error("missing arg unsubscribe"))?
529            .value();
530
531        Ok(Self {
532            notify,
533            unsubscribe,
534        })
535    }
536}
537
/// Converts a CamelCase identifier to snake_case for generated function names.
///
/// Every ASCII uppercase letter is lowercased and, unless it is the first character
/// produced, preceded by `_`. Consecutive capitals are therefore split individually
/// (`HTTPServer` -> `h_t_t_p_server`). All other characters are copied unchanged.
fn to_snake_case(ident: String) -> String {
    ident
        .chars()
        .fold(String::with_capacity(ident.len()), |mut acc, ch| {
            if ch.is_ascii_uppercase() {
                if !acc.is_empty() {
                    acc.push('_');
                }
                acc.push(ch.to_ascii_lowercase());
            } else {
                acc.push(ch);
            }
            acc
        })
}
552
#[cfg(test)]
mod tests {
    use syn::{parse2, Stmt};

    use super::*;

    /// Runs `add_method` on one trait-method declaration and checks that the generated
    /// registration code parses as a single statement; panics (failing the test) otherwise.
    fn expand_to_stmt(declaration: proc_macro2::TokenStream) -> Stmt {
        let mut method: TraitItemFn = parse2(declaration).unwrap();
        let expanded = add_method(&mut method).unwrap();
        println!("output: {}", expanded);
        parse2(expanded).unwrap()
    }

    #[test]
    fn test_methods() {
        let declarations = [
            quote!(
                async fn no_param(&self) -> Result<u64>;
            ),
            quote!(
                async fn sleep(&self, x: u64) -> Result<u64>;
            ),
            quote!(
                #[rpc(name = "@sleep")]
                fn sleep(&self, a: i32, b: i32) -> Result<i32>;
            ),
            quote!(
                fn sleep2(&self, a: Option<i32>, b: Option<i32>) -> Result<i32>;
            ),
            quote!(
                #[rpc(pub_sub(notify = "subscription", unsubscribe = "unsubscribe"))]
                fn subscribe(&self, a: i32, b: i32) -> Result<S>;
            ),
        ];
        for declaration in declarations {
            expand_to_stmt(declaration);
        }
    }
}