1use std::iter;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2, TokenTree};
5use quote::{format_ident, quote};
6use syn::parse::{Parse, ParseStream, Parser};
7use syn::punctuated::Punctuated;
8use syn::{
9 Attribute, Error, FnArg, GenericArgument, GenericParam, Ident, ItemTrait, LitStr, Pat,
10 ReturnType, Token, TraitItem, TraitItemFn, Type, TypeParamBound, parse_macro_input,
11 parse_quote,
12};
13
14#[proc_macro_attribute]
15pub fn export_schema(attr: TokenStream, item: TokenStream) -> TokenStream {
16 let attr = proc_macro2::TokenStream::from(attr);
17 if !attr.is_empty() {
18 return Error::new(
19 Span::call_site(),
20 "#[export_schema] does not take arguments",
21 )
22 .to_compile_error()
23 .into();
24 }
25
26 let item = parse_macro_input!(item as ItemTrait);
27
28 match expand(item) {
29 Ok(tokens) => tokens.into(),
30 Err(err) => err.to_compile_error().into(),
31 }
32}
33
34fn expand(item: ItemTrait) -> syn::Result<TokenStream2> {
35 if item.generics.lifetimes().next().is_some() {
36 return Err(Error::new_spanned(
37 &item.generics,
38 "#[export_schema] does not support lifetime generics",
39 ));
40 }
41
42 let rpc_attr = find_attr(&item.attrs, "rpc").ok_or_else(|| {
43 Error::new_spanned(
44 &item.ident,
45 "#[export_schema] must be placed on a trait that also has #[rpc(...)]",
46 )
47 })?;
48 let rpc_config = RpcConfig::from_attr(rpc_attr)?;
49
50 let schema_ident = format_ident!("{}Schema", item.ident);
51 let builder_fn = format_ident!(
52 "__jsonrpsee_ts_build_{}_schema",
53 to_snake_case(&item.ident.to_string())
54 );
55
56 let used_entries = collect_entries(&item, &rpc_config)?;
57
58 let item_generics = item.generics.clone();
59 let bounded_generics = add_ts_bounds(item_generics.clone());
60 let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
61 let builder_generics = render_fn_generics(&bounded_generics);
62 let builder_where = bounded_generics.where_clause.clone();
63 let builder_turbofish = ty_generics.as_turbofish();
64 let builder_body = render_schema_builder(&used_entries);
65 let schema_generics = render_struct_generics(&item_generics);
66 let schema_marker = render_struct_marker(&item_generics);
67 let used_types = render_used_types(&used_entries);
68
69 Ok(quote! {
70 #item
71
72 #[doc(hidden)]
73 fn #builder_fn #builder_generics (cfg: &::ts_rs::Config) -> ::jsonrpsee_ts::Schema
74 #builder_where
75 {
76 #builder_body
77 }
78
79 ::jsonrpsee_ts::__jsonrpsee_ts_schema_impl! {
80 schema = #schema_ident,
81 builder = #builder_fn,
82 builder_generics = [#builder_turbofish],
83 struct_generics = [#schema_generics],
84 marker = [#schema_marker],
85 impl_generics = [#impl_generics],
86 type_generics = [#ty_generics],
87 where_clause = [#where_clause],
88 used_types = [#used_types]
89 }
90 })
91}
92
93fn render_struct_generics(generics: &syn::Generics) -> TokenStream2 {
94 let params = &generics.params;
95 if params.is_empty() {
96 TokenStream2::new()
97 } else {
98 quote!(<#params>)
99 }
100}
101
102fn render_struct_marker(generics: &syn::Generics) -> TokenStream2 {
103 let type_params = generics
104 .type_params()
105 .map(|param| param.ident.clone())
106 .collect::<Vec<_>>();
107
108 match type_params.as_slice() {
109 [] => TokenStream2::new(),
110 [single] => quote!((::std::marker::PhantomData<#single>)),
111 many => quote!((::std::marker::PhantomData<(#(#many),*)>)),
112 }
113}
114
115fn render_schema_builder(entries: &[RpcSchemaEntry]) -> TokenStream2 {
116 let requests = entries
117 .iter()
118 .filter(|entry| !entry.subscription)
119 .map(RpcSchemaEntry::builder_tokens)
120 .collect::<Vec<_>>();
121 let subscriptions = entries
122 .iter()
123 .filter(|entry| entry.subscription)
124 .map(RpcSchemaEntry::builder_tokens)
125 .collect::<Vec<_>>();
126
127 quote! {
128 ::jsonrpsee_ts::Schema::new()
129 #(.request(#requests))*
130 #(.subscription(#subscriptions))*
131 }
132}
133
134fn render_used_types(entries: &[RpcSchemaEntry]) -> TokenStream2 {
135 let used_types = entries
136 .iter()
137 .flat_map(|entry| entry.used_types.iter())
138 .collect::<Vec<_>>();
139
140 quote!(#(#used_types),*)
141}
142
143fn collect_entries(item: &ItemTrait, rpc_config: &RpcConfig) -> syn::Result<Vec<RpcSchemaEntry>> {
144 let mut entries = Vec::new();
145
146 for trait_item in &item.items {
147 let TraitItem::Fn(method) = trait_item else {
148 return Err(Error::new_spanned(
149 trait_item,
150 "#[export_schema] only supports RPC traits that contain methods",
151 ));
152 };
153
154 if let Some(attr) = find_attr(&method.attrs, "method") {
155 entries.push(RpcSchemaEntry::from_method(method, attr, rpc_config)?);
156 continue;
157 }
158
159 if let Some(attr) = find_attr(&method.attrs, "subscription") {
160 entries.push(RpcSchemaEntry::from_subscription(method, attr, rpc_config)?);
161 continue;
162 }
163
164 return Err(Error::new_spanned(
165 method,
166 "RPC trait methods must have either #[method(...)] or #[subscription(...)]",
167 ));
168 }
169
170 if entries.is_empty() {
171 return Err(Error::new_spanned(
172 &item.ident,
173 "RPC trait must contain at least one method or subscription",
174 ));
175 }
176
177 Ok(entries)
178}
179
180fn add_ts_bounds(mut generics: syn::Generics) -> syn::Generics {
181 for param in &mut generics.params {
182 if let GenericParam::Type(type_param) = param {
183 let has_ts_bound = type_param.bounds.iter().any(|bound| match bound {
184 TypeParamBound::Trait(bound) => bound.path.is_ident("TS"),
185 _ => false,
186 });
187
188 if !has_ts_bound {
189 type_param.bounds.push(parse_quote!(::ts_rs::TS));
190 }
191 }
192 }
193
194 generics
195}
196
197fn render_fn_generics(generics: &syn::Generics) -> TokenStream2 {
198 if generics.params.is_empty() {
199 TokenStream2::new()
200 } else {
201 let params = &generics.params;
202 quote!(<#params>)
203 }
204}
205
206fn find_attr<'a>(attrs: &'a [Attribute], ident: &str) -> Option<&'a Attribute> {
207 attrs.iter().find(|attr| attr.path().is_ident(ident))
208}
209
210#[derive(Clone)]
211struct RpcConfig {
212 namespace: Option<String>,
213 namespace_separator: String,
214}
215
216impl RpcConfig {
217 fn from_attr(attr: &Attribute) -> syn::Result<Self> {
218 let args = parse_arguments(attr)?;
219 let namespace = find_argument(&args, "namespace")?
220 .map(Argument::string)
221 .transpose()?;
222 let namespace_separator = find_argument(&args, "namespace_separator")?
223 .map(Argument::string)
224 .transpose()?
225 .unwrap_or_else(|| "_".to_string());
226
227 Ok(Self {
228 namespace,
229 namespace_separator,
230 })
231 }
232
233 fn rpc_method_name(&self, method: &str) -> String {
234 if let Some(namespace) = &self.namespace {
235 format!("{namespace}{}{method}", self.namespace_separator)
236 } else {
237 method.to_string()
238 }
239 }
240}
241
242struct RpcSchemaEntry {
243 subscription: bool,
244 name: String,
245 param_kind: RpcParamKind,
246 params: Vec<RpcParam>,
247 return_kind: SchemaReturn,
248 used_types: Vec<Type>,
249}
250
251impl RpcSchemaEntry {
252 fn from_method(
253 method: &TraitItemFn,
254 attr: &Attribute,
255 rpc_config: &RpcConfig,
256 ) -> syn::Result<Self> {
257 let args = parse_arguments(attr)?;
258 let name = find_argument(&args, "name")?
259 .ok_or_else(|| Error::new_spanned(attr, "#[method(...)] requires name = \"...\""))?
260 .string()?;
261 let param_kind = find_argument(&args, "param_kind")?
262 .map(Argument::param_kind)
263 .transpose()?
264 .unwrap_or(RpcParamKind::Array);
265
266 let params = collect_params(method)?;
267 let return_ty = match &method.sig.output {
268 ReturnType::Default => SchemaReturn::Void,
269 ReturnType::Type(_, ty) => SchemaReturn::Type(extract_success_type(ty.as_ref())),
270 };
271
272 let mut used_types = params
273 .iter()
274 .map(RpcParam::effective_ty)
275 .collect::<Vec<_>>();
276 if let SchemaReturn::Type(ty) = &return_ty {
277 used_types.push(ty.clone());
278 }
279
280 Ok(Self {
281 subscription: false,
282 name: rpc_config.rpc_method_name(&name),
283 param_kind,
284 params,
285 return_kind: return_ty,
286 used_types,
287 })
288 }
289
290 fn from_subscription(
291 method: &TraitItemFn,
292 attr: &Attribute,
293 rpc_config: &RpcConfig,
294 ) -> syn::Result<Self> {
295 let args = parse_arguments(attr)?;
296 let name = find_argument(&args, "name")?
297 .ok_or_else(|| {
298 Error::new_spanned(attr, "#[subscription(...)] requires name = \"...\"")
299 })?
300 .name_mapping()?;
301 let item = find_argument(&args, "item")?
302 .ok_or_else(|| Error::new_spanned(attr, "#[subscription(...)] requires item = Type"))?
303 .type_value()?;
304 let param_kind = find_argument(&args, "param_kind")?
305 .map(Argument::param_kind)
306 .transpose()?
307 .unwrap_or(RpcParamKind::Array);
308
309 let params = collect_params(method)?;
310 let mut used_types = params
311 .iter()
312 .map(RpcParam::effective_ty)
313 .collect::<Vec<_>>();
314 used_types.push(item.clone());
315
316 Ok(Self {
317 subscription: true,
318 name: rpc_config.rpc_method_name(&name.name),
319 param_kind,
320 params,
321 return_kind: SchemaReturn::Type(item),
322 used_types,
323 })
324 }
325
326 fn builder_tokens(&self) -> TokenStream2 {
327 let name = LitStr::new(&self.name, Span::call_site());
328 let param_kind = match self.param_kind {
329 RpcParamKind::Array => quote!(Array),
330 RpcParamKind::Map => quote!(Map),
331 };
332 let return_expr = match &self.return_kind {
333 SchemaReturn::Type(ty) => quote!(ty(#ty)),
334 SchemaReturn::Void => quote!(void),
335 };
336 let params = self
337 .params
338 .iter()
339 .map(RpcParam::builder_tokens)
340 .collect::<Vec<_>>();
341
342 quote! {
343 ::jsonrpsee_ts::__jsonrpsee_ts_method! {
344 cfg = cfg,
345 name = #name,
346 param_kind = #param_kind,
347 params = [#(#params),*],
348 return = #return_expr
349 }
350 }
351 }
352}
353
/// What a schema entry resolves to.
enum SchemaReturn {
    /// A concrete type: a method's success type, or a subscription's `item` type.
    /// Rendered as `ty(<Type>)` in the method macro invocation.
    Type(Type),
    /// A method declared with no return type; rendered as `void`.
    Void,
}
358
/// Value of a method's or subscription's `param_kind` argument.
#[derive(Clone, Copy)]
enum RpcParamKind {
    /// `param_kind = array`; also used when the argument is absent.
    Array,
    /// `param_kind = map`.
    Map,
}
364
365#[derive(Clone)]
366struct RpcParam {
367 name: String,
368 ty: Type,
369 optional: bool,
370}
371
372impl RpcParam {
373 fn effective_ty(&self) -> Type {
374 self.ty.clone()
375 }
376
377 fn builder_tokens(&self) -> TokenStream2 {
378 let name = LitStr::new(&self.name, Span::call_site());
379 let ty = &self.ty;
380
381 if self.optional {
382 quote!((#name, #ty, optional))
383 } else {
384 quote!((#name, #ty, required))
385 }
386 }
387}
388
389fn collect_params(method: &TraitItemFn) -> syn::Result<Vec<RpcParam>> {
390 method
391 .sig
392 .inputs
393 .iter()
394 .filter_map(|arg| match arg {
395 FnArg::Receiver(_) => None,
396 FnArg::Typed(arg) => Some(parse_param(arg)),
397 })
398 .collect()
399}
400
401fn parse_param(arg: &syn::PatType) -> syn::Result<RpcParam> {
402 let Pat::Ident(ident) = &*arg.pat else {
403 return Err(Error::new_spanned(
404 &arg.pat,
405 "RPC method parameters must be named identifiers",
406 ));
407 };
408
409 let name = parse_argument_rename(&arg.attrs)?.unwrap_or_else(|| ident.ident.to_string());
410 let (ty, optional) = unwrap_option_type(arg.ty.as_ref())
411 .map(|inner| (inner, true))
412 .unwrap_or_else(|| ((*arg.ty).clone(), false));
413
414 Ok(RpcParam { name, ty, optional })
415}
416
417fn parse_argument_rename(attrs: &[Attribute]) -> syn::Result<Option<String>> {
418 let Some(attr) = find_attr(attrs, "argument") else {
419 return Ok(None);
420 };
421
422 let args = parse_arguments(attr)?;
423 find_argument(&args, "rename")?
424 .map(Argument::string)
425 .transpose()
426}
427
428fn extract_success_type(ty: &Type) -> Type {
429 let Type::Path(type_path) = ty else {
430 return ty.clone();
431 };
432
433 let Some(segment) = type_path.path.segments.last() else {
434 return ty.clone();
435 };
436
437 if !matches!(segment.ident.to_string().as_str(), "Result" | "RpcResult") {
438 return ty.clone();
439 }
440
441 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
442 return ty.clone();
443 };
444
445 args.args
446 .iter()
447 .find_map(|arg| match arg {
448 GenericArgument::Type(ty) => Some(ty.clone()),
449 _ => None,
450 })
451 .unwrap_or_else(|| ty.clone())
452}
453
454fn unwrap_option_type(ty: &Type) -> Option<Type> {
455 let Type::Path(type_path) = ty else {
456 return None;
457 };
458 let segment = type_path.path.segments.last()?;
459 if segment.ident != "Option" {
460 return None;
461 }
462
463 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
464 return None;
465 };
466
467 args.args.iter().find_map(|arg| match arg {
468 GenericArgument::Type(ty) => Some(ty.clone()),
469 _ => None,
470 })
471}
472
473#[derive(Clone)]
474struct Argument {
475 label: Ident,
476 tokens: TokenStream2,
477}
478
479impl Argument {
480 fn string(&self) -> syn::Result<String> {
481 self.parse_value::<LitStr>().map(|lit| lit.value())
482 }
483
484 fn type_value(&self) -> syn::Result<Type> {
485 self.parse_value::<Type>()
486 }
487
488 fn name_mapping(&self) -> syn::Result<NameMapping> {
489 self.parse_value::<NameMapping>()
490 }
491
492 fn param_kind(&self) -> syn::Result<RpcParamKind> {
493 let ident = self.parse_value::<Ident>()?;
494 match ident.to_string().as_str() {
495 "array" => Ok(RpcParamKind::Array),
496 "map" => Ok(RpcParamKind::Map),
497 _ => Err(Error::new_spanned(
498 ident,
499 "param_kind must be either `array` or `map`",
500 )),
501 }
502 }
503
504 fn parse_value<T: Parse>(&self) -> syn::Result<T> {
505 fn parser<T: Parse>(stream: ParseStream) -> syn::Result<T> {
506 stream.parse::<Token![=]>()?;
507 stream.parse::<T>()
508 }
509
510 parser.parse2(self.tokens.clone())
511 }
512}
513
514fn find_argument<'a>(args: &'a [Argument], label: &str) -> syn::Result<Option<&'a Argument>> {
515 let mut matches = args.iter().filter(|arg| arg.label == label);
516 let first = matches.next();
517 if matches.next().is_some() {
518 return Err(Error::new(
519 Span::call_site(),
520 format!("duplicate `{label}` argument"),
521 ));
522 }
523 Ok(first)
524}
525
526fn parse_arguments(attr: &Attribute) -> syn::Result<Vec<Argument>> {
527 attr.parse_args_with(|input: ParseStream| {
528 let punctuated = Punctuated::<Argument, Token![,]>::parse_terminated(input)?;
529 Ok(punctuated.into_iter().collect::<Vec<_>>())
530 })
531}
532
533impl Parse for Argument {
534 fn parse(input: ParseStream) -> syn::Result<Self> {
535 let label = input.parse()?;
536 let mut scope = 0usize;
537 let tokens = iter::from_fn(|| {
538 if scope == 0 && input.peek(Token![,]) {
539 return None;
540 }
541
542 if input.peek(Token![<]) {
543 scope += 1;
544 } else if input.peek(Token![>]) {
545 scope = scope.saturating_sub(1);
546 }
547
548 input.parse::<TokenTree>().ok()
549 })
550 .collect();
551
552 Ok(Self { label, tokens })
553 }
554}
555
556struct NameMapping {
557 name: String,
558}
559
560impl Parse for NameMapping {
561 fn parse(input: ParseStream) -> syn::Result<Self> {
562 let name = input.parse::<LitStr>()?.value();
563 if input.peek(Token![=>]) {
564 input.parse::<Token![=>]>()?;
565 let _: LitStr = input.parse()?;
566 }
567
568 Ok(Self { name })
569 }
570}
571
/// Converts an UpperCamelCase identifier, such as a trait name, to snake_case
/// for use inside a generated function name.
///
/// * A leading `r#` raw-identifier prefix is dropped. Otherwise `r#Foo` would
///   yield `r#foo`, which is not a valid identifier fragment, and
///   `format_ident!` would panic.
/// * An underscore is inserted before an ASCII capital that follows a lowercase
///   letter or digit.
/// * An underscore is also inserted before a capital that ends a run of
///   capitals and starts a new word. So `HTTPServer` becomes `http_server`
///   rather than `h_t_t_p_server`.
/// * Non-ASCII characters are copied unchanged.
fn to_snake_case(input: &str) -> String {
    let input = input.strip_prefix("r#").unwrap_or(input);
    let chars: Vec<char> = input.chars().collect();
    let mut output = String::with_capacity(input.len() + 4);

    for (idx, &ch) in chars.iter().enumerate() {
        if !ch.is_ascii_uppercase() {
            output.push(ch);
            continue;
        }

        if idx > 0 {
            let prev = chars[idx - 1];
            let next_is_lower =
                matches!(chars.get(idx + 1), Some(next) if next.is_ascii_lowercase());
            let starts_word = prev.is_ascii_lowercase()
                || prev.is_ascii_digit()
                || (prev.is_ascii_uppercase() && next_is_lower);
            if starts_word {
                output.push('_');
            }
        }

        output.push(ch.to_ascii_lowercase());
    }

    output
}