1use heck::ToUpperCamelCase;
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::parse::{Parse, Parser};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8 parse_macro_input,
9 parse_quote,
10 spanned::Spanned,
11 Expr,
12 FnArg,
13 Ident,
14 ItemTrait,
15 LitStr,
16 Pat,
17 Path,
18 PathArguments,
19 TraitItem,
20 Type,
21 TypeParamBound,
22};
23
24#[proc_macro_attribute]
25pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
26 let parser = Punctuated::<KeyValue, Comma>::parse_terminated;
27 let args_tokens = proc_macro2::TokenStream::from(attr);
28 let args = match parser.parse2(args_tokens) {
29 Ok(value) => value,
30 Err(err) => return err.into_compile_error().into(),
31 };
32
33 let mut input = parse_macro_input!(item as ItemTrait);
34
35 match expand_service(args, &mut input) {
36 Ok(tokens) => tokens.into(),
37 Err(err) => err.to_compile_error().into(),
38 }
39}
40
41struct KeyValue {
42 key: Ident,
43 value: Expr,
44}
45
46impl Parse for KeyValue {
47 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
48 let key: Ident = input.parse()?;
49 input.parse::<syn::Token![=]>()?;
50 let value: Expr = input.parse()?;
51 Ok(Self { key, value })
52 }
53}
54
55struct ServiceOptions {
56 request_ident: Ident,
57 response_ident: Ident,
58 client_ident: Ident,
59 error_path: Path,
60}
61
62struct MethodArg {
63 ident: Ident,
64 ty: Type,
65}
66
67struct MethodInfo {
68 method_ident: Ident,
69 request_struct_ident: Ident,
70 request_fields: Vec<MethodArg>,
71 method_inputs: Vec<syn::PatType>,
72 success_ty: Type,
73 name_literal: LitStr,
74}
75
76fn expand_service(
77 args: Punctuated<KeyValue, Comma>,
78 input: &mut ItemTrait,
79) -> syn::Result<proc_macro2::TokenStream> {
80 ensure_async_trait(input)?;
81
82 let options = parse_service_options(args, &input.ident)?;
83
84 let methods = collect_methods(input)?;
85
86 if methods.is_empty() {
87 return Err(syn::Error::new(
88 input.ident.span(),
89 "RPC traits must declare at least one method",
90 ));
91 }
92
93 let trait_ident = &input.ident;
94 let vis = &input.vis;
95 let request_ident = &options.request_ident;
96 let response_ident = &options.response_ident;
97 let client_ident = &options.client_ident;
98 let error_path = &options.error_path;
99
100 let mut request_structs = Vec::new();
101 let mut request_variants = Vec::new();
102 let mut response_variants = Vec::new();
103 let mut request_variant_names = Vec::new();
104 let mut response_variant_names = Vec::new();
105 let mut dispatch_arms = Vec::new();
106 let mut client_methods = Vec::new();
107
108 const MAX_METHODS: usize = 256;
110
111 if methods.len() > MAX_METHODS {
112 return Err(syn::Error::new(
113 input.ident.span(),
114 format!("RPC traits cannot have more than {} methods", MAX_METHODS),
115 ));
116 }
117
118 for (method_idx, method_info) in methods.iter().enumerate() {
120 let MethodInfo {
121 method_ident,
122 request_struct_ident,
123 request_fields,
124 method_inputs,
125 success_ty,
126 name_literal,
127 } = method_info;
128
129 let placeholder_ident = format_ident!("Method{}", method_idx);
131
132 let mut struct_fields = Vec::new();
133 let mut destructure_fields = Vec::new();
134 let mut argument_idents = Vec::new();
135 let mut request_init = Vec::new();
136
137 for field in request_fields {
138 let ident = &field.ident;
139 let ty = &field.ty;
140 struct_fields.push(quote! { pub #ident: #ty });
141 destructure_fields.push(quote! { #ident });
142 argument_idents.push(quote! { #ident });
143 request_init.push(quote! { #ident });
144 }
145
146 request_structs.push(quote! {
147 #[derive(::bitrpc::bitcode::Encode, ::bitrpc::bitcode::Decode, ::core::fmt::Debug)]
148 #vis struct #request_struct_ident {
149 #( #struct_fields, )*
150 }
151 });
152
153 request_variants.push(quote! { #placeholder_ident(#request_struct_ident) });
154 response_variants.push(quote! { #placeholder_ident(#success_ty) });
155 request_variant_names.push(quote! { #request_ident::#placeholder_ident(_) => #name_literal });
156 response_variant_names.push(quote! { #response_ident::#placeholder_ident(_) => #name_literal });
157
158 dispatch_arms.push(quote! {
159 #request_ident::#placeholder_ident(payload) => {
160 let #request_struct_ident { #( #destructure_fields, )* } = payload;
161 match handler.#method_ident(#( #argument_idents ),*).await {
162 ::core::result::Result::Ok(value) => #response_ident::#placeholder_ident(value),
163 ::core::result::Result::Err(err) => #response_ident::Error(err),
164 }
165 }
166 });
167
168 let client_args_def = method_inputs.iter().map(|pat_type| quote! { #pat_type });
169 let request_struct_init = quote! {
170 #request_struct_ident { #( #request_init, )* }
171 };
172
173 client_methods.push(quote! {
174 pub async fn #method_ident(&mut self #( , #client_args_def )* ) -> ::bitrpc::Result<#success_ty> {
175 let request = #request_ident::#placeholder_ident(#request_struct_init);
176 let bytes = ::bitrpc::bitcode::encode(&request);
177 let response_bytes = self.transport.call(bytes).await?;
178 let response = #response_ident::decode(&response_bytes)?;
179 match response {
180 #response_ident::#placeholder_ident(value) => ::core::result::Result::Ok(value),
181 #response_ident::Error(err) => ::core::result::Result::Err(err),
182 other => ::core::result::Result::Err(::bitrpc::RpcError::unexpected(#name_literal, other.variant_name())),
183 }
184 }
185 });
186 }
187
188 for i in methods.len()..(MAX_METHODS - 1) { let placeholder_ident = format_ident!("Placeholder{}", i);
191 request_variants.push(quote! { #placeholder_ident });
192 response_variants.push(quote! { #placeholder_ident });
193 request_variant_names.push(quote! {
194 #request_ident::#placeholder_ident => concat!("Placeholder", stringify!(#i))
195 });
196 response_variant_names.push(quote! {
197 #response_ident::#placeholder_ident => concat!("Placeholder", stringify!(#i))
198 });
199 }
200
201 response_variants.push(quote! { Error(#error_path) });
202 response_variant_names.push(quote! { #response_ident::Error(_) => "Error" });
203
204 let expanded = quote! {
205 #[::bitrpc::async_trait]
206 #input
207
208 #( #request_structs )*
209
210 #[derive(::bitrpc::bitcode::Encode, ::bitrpc::bitcode::Decode, ::core::fmt::Debug)]
211 #vis enum #request_ident {
212 #( #request_variants, )*
213 }
214
215 impl #request_ident {
216 pub fn encode(&self) -> ::std::vec::Vec<u8> {
217 ::bitrpc::bitcode::encode(self)
218 }
219
220 pub fn decode(bytes: &[u8]) -> ::core::result::Result<Self, ::bitrpc::DecodeError> {
221 ::bitrpc::bitcode::decode(bytes)
222 }
223
224 pub fn variant_name(&self) -> &'static str {
225 match self {
226 #( #request_variant_names, )*
227 }
228 }
229 }
230
231 #[derive(::bitrpc::bitcode::Encode, ::bitrpc::bitcode::Decode, ::core::fmt::Debug)]
232 #vis enum #response_ident {
233 #( #response_variants, )*
234 }
235
236 impl #response_ident {
237 pub fn encode(&self) -> ::std::vec::Vec<u8> {
238 ::bitrpc::bitcode::encode(self)
239 }
240
241 pub fn decode(bytes: &[u8]) -> ::core::result::Result<Self, ::bitrpc::DecodeError> {
242 ::bitrpc::bitcode::decode(bytes)
243 }
244
245 pub fn variant_name(&self) -> &'static str {
246 match self {
247 #( #response_variant_names, )*
248 }
249 }
250 }
251
252 pub async fn dispatch<T>(handler: &T, request: #request_ident) -> #response_ident
253 where
254 T: #trait_ident + ?Sized,
255 {
256 match request {
257 #( #dispatch_arms, )*
258 _ => #response_ident::Error(#error_path::unknown_method()),
259 }
260 }
261
262 #vis struct #client_ident<T> {
263 transport: T,
264 }
265
266 impl<T> #client_ident<T> {
267 pub fn new(transport: T) -> Self {
268 Self { transport }
269 }
270
271 pub fn into_inner(self) -> T {
272 self.transport
273 }
274
275 pub fn transport(&self) -> &T {
276 &self.transport
277 }
278
279 pub fn transport_mut(&mut self) -> &mut T {
280 &mut self.transport
281 }
282 }
283
284 impl<T> #client_ident<T> where T: ::bitrpc::RpcTransport {
285 #( #client_methods )*
286 }
287
288 #[derive(Clone)]
289 #vis struct RpcRequestServiceWrapper<T>(pub T);
290
291 impl<T> ::bitrpc::RpcRequestService for RpcRequestServiceWrapper<T>
292 where
293 T: #trait_ident + Clone,
294 {
295 type Request = #request_ident;
296 type Response = #response_ident;
297
298 async fn dispatch(&self, request: #request_ident) -> #response_ident {
299 dispatch(&self.0, request).await
300 }
301 }
302 };
303
304 Ok(expanded)
305}
306
307fn ensure_async_trait(trait_item: &mut ItemTrait) -> syn::Result<()> {
308 let mut has_send = false;
309 let mut has_sync = false;
310
311 for bound in &trait_item.supertraits {
312 if let TypeParamBound::Trait(bound_trait) = bound {
313 if bound_trait
314 .path
315 .segments
316 .last()
317 .map(|seg| seg.ident == "Send")
318 .unwrap_or(false)
319 {
320 has_send = true;
321 }
322
323 if bound_trait
324 .path
325 .segments
326 .last()
327 .map(|seg| seg.ident == "Sync")
328 .unwrap_or(false)
329 {
330 has_sync = true;
331 }
332 }
333 }
334
335 if !has_send {
336 if !trait_item.supertraits.is_empty() {
337 trait_item.supertraits.push_punct(syn::token::Plus::default());
338 }
339 trait_item
340 .supertraits
341 .push_value(parse_quote!(::core::marker::Send));
342 }
343
344 if !has_sync {
345 if !trait_item.supertraits.is_empty() {
346 trait_item.supertraits.push_punct(syn::token::Plus::default());
347 }
348 trait_item
349 .supertraits
350 .push_value(parse_quote!(::core::marker::Sync));
351 }
352
353 Ok(())
354}
355
356fn collect_methods(trait_item: &ItemTrait) -> syn::Result<Vec<MethodInfo>> {
357 let mut methods = Vec::new();
358
359 for item in &trait_item.items {
360 match item {
361 TraitItem::Fn(method) => {
362 if method.default.is_some() {
363 return Err(syn::Error::new(
364 method.sig.span(),
365 "RPC trait methods cannot have default implementations",
366 ));
367 }
368
369 if method.sig.asyncness.is_none() {
370 return Err(syn::Error::new(
371 method.sig.span(),
372 "RPC trait methods must be async",
373 ));
374 }
375
376 let mut inputs_iter = method.sig.inputs.iter();
377 match inputs_iter.next() {
378 Some(FnArg::Receiver(recv)) => {
379 if recv.reference.is_none() || recv.mutability.is_some() {
380 return Err(syn::Error::new(
381 recv.span(),
382 "RPC trait methods must take &self",
383 ));
384 }
385 }
386 _ => {
387 return Err(syn::Error::new(
388 method.sig.span(),
389 "RPC trait methods must take &self",
390 ));
391 }
392 }
393
394 let mut request_fields = Vec::new();
395 let mut method_inputs = Vec::new();
396
397 for arg in method.sig.inputs.iter().skip(1) {
398 if let FnArg::Typed(pat_type) = arg {
399 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
400 let ident = pat_ident.ident.clone();
401 let ty = (*pat_type.ty).clone();
402 request_fields.push(MethodArg { ident, ty });
403 method_inputs.push(pat_type.clone());
404 } else {
405 return Err(syn::Error::new(
406 pat_type.pat.span(),
407 "RPC trait method arguments must be simple identifiers",
408 ));
409 }
410 } else {
411 return Err(syn::Error::new(
412 arg.span(),
413 "unsupported argument type",
414 ));
415 }
416 }
417
418 let success_ty = extract_success_type(&method.sig)?;
419
420 let method_name = method.sig.ident.to_string();
421 let variant_base = method_name.to_upper_camel_case();
422 let request_struct_ident = format_ident!("{}Request", variant_base);
423 let name_literal = LitStr::new(method_name.as_str(), method.sig.ident.span());
424
425 methods.push(MethodInfo {
426 method_ident: method.sig.ident.clone(),
427 request_struct_ident,
428 request_fields,
429 method_inputs,
430 success_ty,
431 name_literal,
432 });
433 }
434 TraitItem::Type(item) => {
435 return Err(syn::Error::new(
436 item.span(),
437 "RPC traits cannot declare associated types",
438 ));
439 }
440 TraitItem::Const(item) => {
441 return Err(syn::Error::new(
442 item.span(),
443 "RPC traits cannot declare associated constants",
444 ));
445 }
446 _ => {}
447 }
448 }
449
450 Ok(methods)
451}
452
453fn extract_success_type(sig: &syn::Signature) -> syn::Result<Type> {
454 let return_type = match &sig.output {
455 syn::ReturnType::Default => {
456 return Err(syn::Error::new(
457 sig.span(),
458 "RPC trait methods must return ::bitrpc::Result<T>",
459 ))
460 }
461 syn::ReturnType::Type(_, ty) => ty,
462 };
463
464 match return_type.as_ref() {
465 Type::Path(type_path) => extract_success_type_from_path(type_path),
466 _ => Err(syn::Error::new(
467 return_type.span(),
468 "RPC trait methods must return ::bitrpc::Result<T>",
469 )),
470 }
471}
472
473fn extract_success_type_from_path(type_path: &syn::TypePath) -> syn::Result<Type> {
474 let last_segment = type_path
475 .path
476 .segments
477 .last()
478 .ok_or_else(|| syn::Error::new(type_path.span(), "invalid return type"))?;
479
480 if last_segment.ident != "Result" {
481 return Err(syn::Error::new(
482 last_segment.ident.span(),
483 "RPC trait methods must return ::bitrpc::Result<T>",
484 ));
485 }
486
487 match &last_segment.arguments {
488 PathArguments::AngleBracketed(args) => {
489 let mut iter = args.args.iter();
490 if let Some(syn::GenericArgument::Type(success_ty)) = iter.next() {
491 Ok(success_ty.clone())
492 } else {
493 Err(syn::Error::new(
494 args.span(),
495 "Result must specify a success type",
496 ))
497 }
498 }
499 _ => Err(syn::Error::new(
500 last_segment.arguments.span(),
501 "Result must use angle bracket generic arguments",
502 )),
503 }
504}
505
506fn parse_service_options(
507 args: Punctuated<KeyValue, Comma>,
508 trait_ident: &Ident,
509) -> syn::Result<ServiceOptions> {
510 let mut request_ident: Option<Ident> = None;
511 let mut response_ident: Option<Ident> = None;
512 let mut client_ident: Option<Ident> = None;
513 let mut error_path: Option<Path> = None;
514
515 for arg in args {
516 let key = arg.key.to_string();
517 match key.as_str() {
518 "request" => match arg.value {
519 Expr::Path(expr_path) if expr_path.path.segments.len() == 1 => {
520 request_ident = Some(expr_path.path.segments[0].ident.clone());
521 }
522 _ => {
523 return Err(syn::Error::new(
524 arg.value.span(),
525 "request must be a simple identifier",
526 ))
527 }
528 },
529 "response" => match arg.value {
530 Expr::Path(expr_path) if expr_path.path.segments.len() == 1 => {
531 response_ident = Some(expr_path.path.segments[0].ident.clone());
532 }
533 _ => {
534 return Err(syn::Error::new(
535 arg.value.span(),
536 "response must be a simple identifier",
537 ))
538 }
539 },
540 "client" => match arg.value {
541 Expr::Path(expr_path) if expr_path.path.segments.len() == 1 => {
542 client_ident = Some(expr_path.path.segments[0].ident.clone());
543 }
544 _ => {
545 return Err(syn::Error::new(
546 arg.value.span(),
547 "client must be a simple identifier",
548 ))
549 }
550 },
551 "error" => match arg.value {
552 Expr::Path(expr_path) => {
553 error_path = Some(expr_path.path.clone());
554 }
555 _ => {
556 return Err(syn::Error::new(
557 arg.value.span(),
558 "error must be a path",
559 ))
560 }
561 },
562 _ => {
563 return Err(syn::Error::new(
564 arg.key.span(),
565 "unsupported service option",
566 ))
567 }
568 }
569 }
570
571 let base_name = trait_ident.to_string();
572 let request_ident = request_ident.unwrap_or_else(|| format_ident!("{}Request", base_name));
573 let response_ident = response_ident.unwrap_or_else(|| format_ident!("{}Response", base_name));
574 let client_ident = client_ident.unwrap_or_else(|| format_ident!("{}Client", base_name));
575 let error_path = error_path.unwrap_or_else(|| syn::parse_quote!(::bitrpc::RpcError));
576
577 Ok(ServiceOptions {
578 request_ident,
579 response_ident,
580 client_ident,
581 error_path,
582 })
583}