Skip to main content

jsonrpsee_ts_macros/
lib.rs

1use std::iter;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2, TokenTree};
5use quote::{format_ident, quote};
6use syn::parse::{Parse, ParseStream, Parser};
7use syn::punctuated::Punctuated;
8use syn::{
9    Attribute, Error, FnArg, GenericArgument, GenericParam, Ident, ItemTrait, LitStr, Pat,
10    ReturnType, Token, TraitItem, TraitItemFn, Type, TypeParamBound, parse_macro_input,
11    parse_quote,
12};
13
14#[proc_macro_attribute]
15pub fn export_schema(attr: TokenStream, item: TokenStream) -> TokenStream {
16    let attr = proc_macro2::TokenStream::from(attr);
17    if !attr.is_empty() {
18        return Error::new(
19            Span::call_site(),
20            "#[export_schema] does not take arguments",
21        )
22        .to_compile_error()
23        .into();
24    }
25
26    let item = parse_macro_input!(item as ItemTrait);
27
28    match expand(item) {
29        Ok(tokens) => tokens.into(),
30        Err(err) => err.to_compile_error().into(),
31    }
32}
33
34fn expand(item: ItemTrait) -> syn::Result<TokenStream2> {
35    if item.generics.lifetimes().next().is_some() {
36        return Err(Error::new_spanned(
37            &item.generics,
38            "#[export_schema] does not support lifetime generics",
39        ));
40    }
41
42    let rpc_attr = find_attr(&item.attrs, "rpc").ok_or_else(|| {
43        Error::new_spanned(
44            &item.ident,
45            "#[export_schema] must be placed on a trait that also has #[rpc(...)]",
46        )
47    })?;
48    let rpc_config = RpcConfig::from_attr(rpc_attr)?;
49
50    let schema_ident = format_ident!("{}Schema", item.ident);
51    let builder_fn = format_ident!(
52        "__jsonrpsee_ts_build_{}_schema",
53        to_snake_case(&item.ident.to_string())
54    );
55
56    let used_entries = collect_entries(&item, &rpc_config)?;
57
58    let item_generics = item.generics.clone();
59    let bounded_generics = add_ts_bounds(item_generics.clone());
60    let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
61    let builder_generics = render_fn_generics(&bounded_generics);
62    let builder_where = bounded_generics.where_clause.clone();
63    let builder_turbofish = ty_generics.as_turbofish();
64    let builder_body = render_schema_builder(&used_entries);
65    let schema_generics = render_struct_generics(&item_generics);
66    let schema_marker = render_struct_marker(&item_generics);
67    let used_types = render_used_types(&used_entries);
68
69    Ok(quote! {
70        #item
71
72        #[doc(hidden)]
73        fn #builder_fn #builder_generics (cfg: &::ts_rs::Config) -> ::jsonrpsee_ts::Schema
74        #builder_where
75        {
76            #builder_body
77        }
78
79        ::jsonrpsee_ts::__jsonrpsee_ts_schema_impl! {
80            schema = #schema_ident,
81            builder = #builder_fn,
82            builder_generics = [#builder_turbofish],
83            struct_generics = [#schema_generics],
84            marker = [#schema_marker],
85            impl_generics = [#impl_generics],
86            type_generics = [#ty_generics],
87            where_clause = [#where_clause],
88            used_types = [#used_types]
89        }
90    })
91}
92
93fn render_struct_generics(generics: &syn::Generics) -> TokenStream2 {
94    let params = &generics.params;
95    if params.is_empty() {
96        TokenStream2::new()
97    } else {
98        quote!(<#params>)
99    }
100}
101
102fn render_struct_marker(generics: &syn::Generics) -> TokenStream2 {
103    let type_params = generics
104        .type_params()
105        .map(|param| param.ident.clone())
106        .collect::<Vec<_>>();
107
108    match type_params.as_slice() {
109        [] => TokenStream2::new(),
110        [single] => quote!((::std::marker::PhantomData<#single>)),
111        many => quote!((::std::marker::PhantomData<(#(#many),*)>)),
112    }
113}
114
115fn render_schema_builder(entries: &[RpcSchemaEntry]) -> TokenStream2 {
116    let requests = entries
117        .iter()
118        .filter(|entry| !entry.subscription)
119        .map(RpcSchemaEntry::builder_tokens)
120        .collect::<Vec<_>>();
121    let subscriptions = entries
122        .iter()
123        .filter(|entry| entry.subscription)
124        .map(RpcSchemaEntry::builder_tokens)
125        .collect::<Vec<_>>();
126
127    quote! {
128        ::jsonrpsee_ts::Schema::new()
129            #(.request(#requests))*
130            #(.subscription(#subscriptions))*
131    }
132}
133
134fn render_used_types(entries: &[RpcSchemaEntry]) -> TokenStream2 {
135    let used_types = entries
136        .iter()
137        .flat_map(|entry| entry.used_types.iter())
138        .collect::<Vec<_>>();
139
140    quote!(#(#used_types),*)
141}
142
143fn collect_entries(item: &ItemTrait, rpc_config: &RpcConfig) -> syn::Result<Vec<RpcSchemaEntry>> {
144    let mut entries = Vec::new();
145
146    for trait_item in &item.items {
147        let TraitItem::Fn(method) = trait_item else {
148            return Err(Error::new_spanned(
149                trait_item,
150                "#[export_schema] only supports RPC traits that contain methods",
151            ));
152        };
153
154        if let Some(attr) = find_attr(&method.attrs, "method") {
155            entries.push(RpcSchemaEntry::from_method(method, attr, rpc_config)?);
156            continue;
157        }
158
159        if let Some(attr) = find_attr(&method.attrs, "subscription") {
160            entries.push(RpcSchemaEntry::from_subscription(method, attr, rpc_config)?);
161            continue;
162        }
163
164        return Err(Error::new_spanned(
165            method,
166            "RPC trait methods must have either #[method(...)] or #[subscription(...)]",
167        ));
168    }
169
170    if entries.is_empty() {
171        return Err(Error::new_spanned(
172            &item.ident,
173            "RPC trait must contain at least one method or subscription",
174        ));
175    }
176
177    Ok(entries)
178}
179
180fn add_ts_bounds(mut generics: syn::Generics) -> syn::Generics {
181    for param in &mut generics.params {
182        if let GenericParam::Type(type_param) = param {
183            let has_ts_bound = type_param.bounds.iter().any(|bound| match bound {
184                TypeParamBound::Trait(bound) => bound.path.is_ident("TS"),
185                _ => false,
186            });
187
188            if !has_ts_bound {
189                type_param.bounds.push(parse_quote!(::ts_rs::TS));
190            }
191        }
192    }
193
194    generics
195}
196
197fn render_fn_generics(generics: &syn::Generics) -> TokenStream2 {
198    if generics.params.is_empty() {
199        TokenStream2::new()
200    } else {
201        let params = &generics.params;
202        quote!(<#params>)
203    }
204}
205
206fn find_attr<'a>(attrs: &'a [Attribute], ident: &str) -> Option<&'a Attribute> {
207    attrs.iter().find(|attr| attr.path().is_ident(ident))
208}
209
210#[derive(Clone)]
211struct RpcConfig {
212    namespace: Option<String>,
213    namespace_separator: String,
214}
215
216impl RpcConfig {
217    fn from_attr(attr: &Attribute) -> syn::Result<Self> {
218        let args = parse_arguments(attr)?;
219        let namespace = find_argument(&args, "namespace")?
220            .map(Argument::string)
221            .transpose()?;
222        let namespace_separator = find_argument(&args, "namespace_separator")?
223            .map(Argument::string)
224            .transpose()?
225            .unwrap_or_else(|| "_".to_string());
226
227        Ok(Self {
228            namespace,
229            namespace_separator,
230        })
231    }
232
233    fn rpc_method_name(&self, method: &str) -> String {
234        if let Some(namespace) = &self.namespace {
235            format!("{namespace}{}{method}", self.namespace_separator)
236        } else {
237            method.to_string()
238        }
239    }
240}
241
242struct RpcSchemaEntry {
243    subscription: bool,
244    name: String,
245    param_kind: RpcParamKind,
246    params: Vec<RpcParam>,
247    return_kind: SchemaReturn,
248    used_types: Vec<Type>,
249}
250
251impl RpcSchemaEntry {
252    fn from_method(
253        method: &TraitItemFn,
254        attr: &Attribute,
255        rpc_config: &RpcConfig,
256    ) -> syn::Result<Self> {
257        let args = parse_arguments(attr)?;
258        let name = find_argument(&args, "name")?
259            .ok_or_else(|| Error::new_spanned(attr, "#[method(...)] requires name = \"...\""))?
260            .string()?;
261        let param_kind = find_argument(&args, "param_kind")?
262            .map(Argument::param_kind)
263            .transpose()?
264            .unwrap_or(RpcParamKind::Array);
265
266        let params = collect_params(method)?;
267        let return_ty = match &method.sig.output {
268            ReturnType::Default => SchemaReturn::Void,
269            ReturnType::Type(_, ty) => SchemaReturn::Type(extract_success_type(ty.as_ref())),
270        };
271
272        let mut used_types = params
273            .iter()
274            .map(RpcParam::effective_ty)
275            .collect::<Vec<_>>();
276        if let SchemaReturn::Type(ty) = &return_ty {
277            used_types.push(ty.clone());
278        }
279
280        Ok(Self {
281            subscription: false,
282            name: rpc_config.rpc_method_name(&name),
283            param_kind,
284            params,
285            return_kind: return_ty,
286            used_types,
287        })
288    }
289
290    fn from_subscription(
291        method: &TraitItemFn,
292        attr: &Attribute,
293        rpc_config: &RpcConfig,
294    ) -> syn::Result<Self> {
295        let args = parse_arguments(attr)?;
296        let name = find_argument(&args, "name")?
297            .ok_or_else(|| {
298                Error::new_spanned(attr, "#[subscription(...)] requires name = \"...\"")
299            })?
300            .name_mapping()?;
301        let item = find_argument(&args, "item")?
302            .ok_or_else(|| Error::new_spanned(attr, "#[subscription(...)] requires item = Type"))?
303            .type_value()?;
304        let param_kind = find_argument(&args, "param_kind")?
305            .map(Argument::param_kind)
306            .transpose()?
307            .unwrap_or(RpcParamKind::Array);
308
309        let params = collect_params(method)?;
310        let mut used_types = params
311            .iter()
312            .map(RpcParam::effective_ty)
313            .collect::<Vec<_>>();
314        used_types.push(item.clone());
315
316        Ok(Self {
317            subscription: true,
318            name: rpc_config.rpc_method_name(&name.name),
319            param_kind,
320            params,
321            return_kind: SchemaReturn::Type(item),
322            used_types,
323        })
324    }
325
326    fn builder_tokens(&self) -> TokenStream2 {
327        let name = LitStr::new(&self.name, Span::call_site());
328        let param_kind = match self.param_kind {
329            RpcParamKind::Array => quote!(Array),
330            RpcParamKind::Map => quote!(Map),
331        };
332        let return_expr = match &self.return_kind {
333            SchemaReturn::Type(ty) => quote!(ty(#ty)),
334            SchemaReturn::Void => quote!(void),
335        };
336        let params = self
337            .params
338            .iter()
339            .map(RpcParam::builder_tokens)
340            .collect::<Vec<_>>();
341
342        quote! {
343            ::jsonrpsee_ts::__jsonrpsee_ts_method! {
344                cfg = cfg,
345                name = #name,
346                param_kind = #param_kind,
347                params = [#(#params),*],
348                return = #return_expr
349            }
350        }
351    }
352}
353
354enum SchemaReturn {
355    Type(Type),
356    Void,
357}
358
/// Parameter encoding selected by an entry's `param_kind` argument.
#[derive(Copy, Clone)]
enum RpcParamKind {
    /// `param_kind = array` (also used when the argument is omitted).
    Array,
    /// `param_kind = map`.
    Map,
}
364
365#[derive(Clone)]
366struct RpcParam {
367    name: String,
368    ty: Type,
369    optional: bool,
370}
371
372impl RpcParam {
373    fn effective_ty(&self) -> Type {
374        self.ty.clone()
375    }
376
377    fn builder_tokens(&self) -> TokenStream2 {
378        let name = LitStr::new(&self.name, Span::call_site());
379        let ty = &self.ty;
380
381        if self.optional {
382            quote!((#name, #ty, optional))
383        } else {
384            quote!((#name, #ty, required))
385        }
386    }
387}
388
389fn collect_params(method: &TraitItemFn) -> syn::Result<Vec<RpcParam>> {
390    method
391        .sig
392        .inputs
393        .iter()
394        .filter_map(|arg| match arg {
395            FnArg::Receiver(_) => None,
396            FnArg::Typed(arg) => Some(parse_param(arg)),
397        })
398        .collect()
399}
400
401fn parse_param(arg: &syn::PatType) -> syn::Result<RpcParam> {
402    let Pat::Ident(ident) = &*arg.pat else {
403        return Err(Error::new_spanned(
404            &arg.pat,
405            "RPC method parameters must be named identifiers",
406        ));
407    };
408
409    let name = parse_argument_rename(&arg.attrs)?.unwrap_or_else(|| ident.ident.to_string());
410    let (ty, optional) = unwrap_option_type(arg.ty.as_ref())
411        .map(|inner| (inner, true))
412        .unwrap_or_else(|| ((*arg.ty).clone(), false));
413
414    Ok(RpcParam { name, ty, optional })
415}
416
417fn parse_argument_rename(attrs: &[Attribute]) -> syn::Result<Option<String>> {
418    let Some(attr) = find_attr(attrs, "argument") else {
419        return Ok(None);
420    };
421
422    let args = parse_arguments(attr)?;
423    find_argument(&args, "rename")?
424        .map(Argument::string)
425        .transpose()
426}
427
428fn extract_success_type(ty: &Type) -> Type {
429    let Type::Path(type_path) = ty else {
430        return ty.clone();
431    };
432
433    let Some(segment) = type_path.path.segments.last() else {
434        return ty.clone();
435    };
436
437    if !matches!(segment.ident.to_string().as_str(), "Result" | "RpcResult") {
438        return ty.clone();
439    }
440
441    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
442        return ty.clone();
443    };
444
445    args.args
446        .iter()
447        .find_map(|arg| match arg {
448            GenericArgument::Type(ty) => Some(ty.clone()),
449            _ => None,
450        })
451        .unwrap_or_else(|| ty.clone())
452}
453
454fn unwrap_option_type(ty: &Type) -> Option<Type> {
455    let Type::Path(type_path) = ty else {
456        return None;
457    };
458    let segment = type_path.path.segments.last()?;
459    if segment.ident != "Option" {
460        return None;
461    }
462
463    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
464        return None;
465    };
466
467    args.args.iter().find_map(|arg| match arg {
468        GenericArgument::Type(ty) => Some(ty.clone()),
469        _ => None,
470    })
471}
472
473#[derive(Clone)]
474struct Argument {
475    label: Ident,
476    tokens: TokenStream2,
477}
478
479impl Argument {
480    fn string(&self) -> syn::Result<String> {
481        self.parse_value::<LitStr>().map(|lit| lit.value())
482    }
483
484    fn type_value(&self) -> syn::Result<Type> {
485        self.parse_value::<Type>()
486    }
487
488    fn name_mapping(&self) -> syn::Result<NameMapping> {
489        self.parse_value::<NameMapping>()
490    }
491
492    fn param_kind(&self) -> syn::Result<RpcParamKind> {
493        let ident = self.parse_value::<Ident>()?;
494        match ident.to_string().as_str() {
495            "array" => Ok(RpcParamKind::Array),
496            "map" => Ok(RpcParamKind::Map),
497            _ => Err(Error::new_spanned(
498                ident,
499                "param_kind must be either `array` or `map`",
500            )),
501        }
502    }
503
504    fn parse_value<T: Parse>(&self) -> syn::Result<T> {
505        fn parser<T: Parse>(stream: ParseStream) -> syn::Result<T> {
506            stream.parse::<Token![=]>()?;
507            stream.parse::<T>()
508        }
509
510        parser.parse2(self.tokens.clone())
511    }
512}
513
514fn find_argument<'a>(args: &'a [Argument], label: &str) -> syn::Result<Option<&'a Argument>> {
515    let mut matches = args.iter().filter(|arg| arg.label == label);
516    let first = matches.next();
517    if matches.next().is_some() {
518        return Err(Error::new(
519            Span::call_site(),
520            format!("duplicate `{label}` argument"),
521        ));
522    }
523    Ok(first)
524}
525
526fn parse_arguments(attr: &Attribute) -> syn::Result<Vec<Argument>> {
527    attr.parse_args_with(|input: ParseStream| {
528        let punctuated = Punctuated::<Argument, Token![,]>::parse_terminated(input)?;
529        Ok(punctuated.into_iter().collect::<Vec<_>>())
530    })
531}
532
533impl Parse for Argument {
534    fn parse(input: ParseStream) -> syn::Result<Self> {
535        let label = input.parse()?;
536        let mut scope = 0usize;
537        let tokens = iter::from_fn(|| {
538            if scope == 0 && input.peek(Token![,]) {
539                return None;
540            }
541
542            if input.peek(Token![<]) {
543                scope += 1;
544            } else if input.peek(Token![>]) {
545                scope = scope.saturating_sub(1);
546            }
547
548            input.parse::<TokenTree>().ok()
549        })
550        .collect();
551
552        Ok(Self { label, tokens })
553    }
554}
555
556struct NameMapping {
557    name: String,
558}
559
560impl Parse for NameMapping {
561    fn parse(input: ParseStream) -> syn::Result<Self> {
562        let name = input.parse::<LitStr>()?.value();
563        if input.peek(Token![=>]) {
564            input.parse::<Token![=>]>()?;
565            let _: LitStr = input.parse()?;
566        }
567
568        Ok(Self { name })
569    }
570}
571
/// Converts an `UpperCamelCase` identifier to `snake_case`. Every ASCII uppercase letter
/// after the first character is preceded by `_` and lowercased, so `MyRpc` becomes `my_rpc`
/// and `RPC` becomes `r_p_c`.
///
/// A leading raw-identifier prefix (`r#`) is removed first. A raw trait name such as
/// `r#match` stringifies with that prefix, and splicing `r#` into the middle of the generated
/// builder-fn name would make `format_ident!` panic on an invalid identifier.
fn to_snake_case(input: &str) -> String {
    let input = input.strip_prefix("r#").unwrap_or(input);
    let mut output = String::with_capacity(input.len());

    for (idx, ch) in input.chars().enumerate() {
        if ch.is_ascii_uppercase() {
            if idx != 0 {
                output.push('_');
            }
            output.push(ch.to_ascii_lowercase());
        } else {
            output.push(ch);
        }
    }

    output
}