bitrpc_macros/
lib.rs

1use heck::ToUpperCamelCase;
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::parse::{Parse, Parser};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8    parse_macro_input,
9    parse_quote,
10    spanned::Spanned,
11    Expr,
12    FnArg,
13    Ident,
14    ItemTrait,
15    LitStr,
16    Pat,
17    Path,
18    PathArguments,
19    TraitItem,
20    Type,
21    TypeParamBound,
22};
23
24#[proc_macro_attribute]
25pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
26    let parser = Punctuated::<KeyValue, Comma>::parse_terminated;
27    let args_tokens = proc_macro2::TokenStream::from(attr);
28    let args = match parser.parse2(args_tokens) {
29        Ok(value) => value,
30        Err(err) => return err.into_compile_error().into(),
31    };
32
33    let mut input = parse_macro_input!(item as ItemTrait);
34
35    match expand_service(args, &mut input) {
36        Ok(tokens) => tokens.into(),
37        Err(err) => err.to_compile_error().into(),
38    }
39}
40
41struct KeyValue {
42    key: Ident,
43    value: Expr,
44}
45
46impl Parse for KeyValue {
47    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
48        let key: Ident = input.parse()?;
49        input.parse::<syn::Token![=]>()?;
50        let value: Expr = input.parse()?;
51        Ok(Self { key, value })
52    }
53}
54
55struct ServiceOptions {
56    request_ident: Ident,
57    response_ident: Ident,
58    client_ident: Ident,
59    error_path: Path,
60}
61
62struct MethodArg {
63    ident: Ident,
64    ty: Type,
65}
66
67struct MethodInfo {
68    method_ident: Ident,
69    request_struct_ident: Ident,
70    request_fields: Vec<MethodArg>,
71    method_inputs: Vec<syn::PatType>,
72    success_ty: Type,
73    name_literal: LitStr,
74}
75
76fn expand_service(
77    args: Punctuated<KeyValue, Comma>,
78    input: &mut ItemTrait,
79) -> syn::Result<proc_macro2::TokenStream> {
80    ensure_async_trait(input)?;
81
82    let options = parse_service_options(args, &input.ident)?;
83
84    let methods = collect_methods(input)?;
85
86    if methods.is_empty() {
87        return Err(syn::Error::new(
88            input.ident.span(),
89            "RPC traits must declare at least one method",
90        ));
91    }
92
93    let trait_ident = &input.ident;
94    let vis = &input.vis;
95    let request_ident = &options.request_ident;
96    let response_ident = &options.response_ident;
97    let client_ident = &options.client_ident;
98    let error_path = &options.error_path;
99
100    let mut request_structs = Vec::new();
101    let mut request_variants = Vec::new();
102    let mut response_variants = Vec::new();
103    let mut request_variant_names = Vec::new();
104    let mut response_variant_names = Vec::new();
105    let mut dispatch_arms = Vec::new();
106    let mut client_methods = Vec::new();
107
108    // Generate 256 placeholder variants for stable encoding
109    const MAX_METHODS: usize = 256;
110    
111    if methods.len() > MAX_METHODS {
112        return Err(syn::Error::new(
113            input.ident.span(),
114            format!("RPC traits cannot have more than {} methods", MAX_METHODS),
115        ));
116    }
117    
118    // Map methods to placeholder indices based on trait definition order
119    for (method_idx, method_info) in methods.iter().enumerate() {
120        let MethodInfo {
121            method_ident,
122            request_struct_ident,
123            request_fields,
124            method_inputs,
125            success_ty,
126            name_literal,
127        } = method_info;
128        
129        // Use placeholder variant name for stable encoding
130        let placeholder_ident = format_ident!("Method{}", method_idx);
131
132        let mut struct_fields = Vec::new();
133        let mut destructure_fields = Vec::new();
134        let mut argument_idents = Vec::new();
135        let mut request_init = Vec::new();
136
137        for field in request_fields {
138            let ident = &field.ident;
139            let ty = &field.ty;
140            struct_fields.push(quote! { pub #ident: #ty });
141            destructure_fields.push(quote! { #ident });
142            argument_idents.push(quote! { #ident });
143            request_init.push(quote! { #ident });
144        }
145
146        request_structs.push(quote! {
147            #[derive(::bitrpc::bitcode::Encode, ::bitrpc::bitcode::Decode, ::core::fmt::Debug)]
148            #vis struct #request_struct_ident {
149                #( #struct_fields, )*
150            }
151        });
152
153        request_variants.push(quote! { #placeholder_ident(#request_struct_ident) });
154        response_variants.push(quote! { #placeholder_ident(#success_ty) });
155        request_variant_names.push(quote! { #request_ident::#placeholder_ident(_) => #name_literal });
156        response_variant_names.push(quote! { #response_ident::#placeholder_ident(_) => #name_literal });
157
158        dispatch_arms.push(quote! {
159            #request_ident::#placeholder_ident(payload) => {
160                let #request_struct_ident { #( #destructure_fields, )* } = payload;
161                match handler.#method_ident(#( #argument_idents ),*).await {
162                    ::core::result::Result::Ok(value) => #response_ident::#placeholder_ident(value),
163                    ::core::result::Result::Err(err) => #response_ident::Error(err),
164                }
165            }
166        });
167
168        let client_args_def = method_inputs.iter().map(|pat_type| quote! { #pat_type });
169        let request_struct_init = quote! {
170            #request_struct_ident { #( #request_init, )* }
171        };
172
173        client_methods.push(quote! {
174            pub async fn #method_ident(&mut self #( , #client_args_def )* ) -> ::bitrpc::Result<#success_ty> {
175                let request = #request_ident::#placeholder_ident(#request_struct_init);
176                let bytes = ::bitrpc::bitcode::encode(&request);
177                let response_bytes = self.transport.call(bytes).await?;
178                let response = #response_ident::decode(&response_bytes)?;
179                match response {
180                    #response_ident::#placeholder_ident(value) => ::core::result::Result::Ok(value),
181                    #response_ident::Error(err) => ::core::result::Result::Err(err),
182                    other => ::core::result::Result::Err(::bitrpc::RpcError::unexpected(#name_literal, other.variant_name())),
183                }
184            }
185        });
186    }
187    
188    // Add remaining placeholders for future expansion
189    for i in methods.len()..(MAX_METHODS - 1) { // -1 to leave room for Error variant
190        let placeholder_ident = format_ident!("Placeholder{}", i);
191        request_variants.push(quote! { #placeholder_ident });
192        response_variants.push(quote! { #placeholder_ident });
193        request_variant_names.push(quote! { 
194            #request_ident::#placeholder_ident => concat!("Placeholder", stringify!(#i))
195        });
196        response_variant_names.push(quote! { 
197            #response_ident::#placeholder_ident => concat!("Placeholder", stringify!(#i))
198        });
199    }
200
201    response_variants.push(quote! { Error(#error_path) });
202    response_variant_names.push(quote! { #response_ident::Error(_) => "Error" });
203
204    let expanded = quote! {
205        #[::bitrpc::async_trait]
206        #input
207
208        #( #request_structs )*
209
210        #[derive(::bitrpc::bitcode::Encode, ::bitrpc::bitcode::Decode, ::core::fmt::Debug)]
211        #vis enum #request_ident {
212            #( #request_variants, )*
213        }
214
215        impl #request_ident {
216            pub fn encode(&self) -> ::std::vec::Vec<u8> {
217                ::bitrpc::bitcode::encode(self)
218            }
219
220            pub fn decode(bytes: &[u8]) -> ::core::result::Result<Self, ::bitrpc::DecodeError> {
221                ::bitrpc::bitcode::decode(bytes)
222            }
223
224            pub fn variant_name(&self) -> &'static str {
225                match self {
226                    #( #request_variant_names, )*
227                }
228            }
229        }
230
231        #[derive(::bitrpc::bitcode::Encode, ::bitrpc::bitcode::Decode, ::core::fmt::Debug)]
232        #vis enum #response_ident {
233            #( #response_variants, )*
234        }
235
236        impl #response_ident {
237            pub fn encode(&self) -> ::std::vec::Vec<u8> {
238                ::bitrpc::bitcode::encode(self)
239            }
240
241            pub fn decode(bytes: &[u8]) -> ::core::result::Result<Self, ::bitrpc::DecodeError> {
242                ::bitrpc::bitcode::decode(bytes)
243            }
244
245            pub fn variant_name(&self) -> &'static str {
246                match self {
247                    #( #response_variant_names, )*
248                }
249            }
250        }
251
252        pub async fn dispatch<T>(handler: &T, request: #request_ident) -> #response_ident
253        where
254            T: #trait_ident + ?Sized,
255        {
256            match request {
257                #( #dispatch_arms, )*
258                _ => #response_ident::Error(#error_path::unknown_method()),
259            }
260        }
261
262        #vis struct #client_ident<T> {
263            transport: T,
264        }
265
266        impl<T> #client_ident<T> {
267            pub fn new(transport: T) -> Self {
268                Self { transport }
269            }
270
271            pub fn into_inner(self) -> T {
272                self.transport
273            }
274
275            pub fn transport(&self) -> &T {
276                &self.transport
277            }
278
279            pub fn transport_mut(&mut self) -> &mut T {
280                &mut self.transport
281            }
282        }
283
284        impl<T> #client_ident<T> where T: ::bitrpc::RpcTransport {
285            #( #client_methods )*
286        }
287
288        #[derive(Clone)]
289        #vis struct RpcRequestServiceWrapper<T>(pub T);
290
291        impl<T> ::bitrpc::RpcRequestService for RpcRequestServiceWrapper<T>
292        where
293            T: #trait_ident + Clone,
294        {
295            type Request = #request_ident;
296            type Response = #response_ident;
297
298            async fn dispatch(&self, request: #request_ident) -> #response_ident {
299                dispatch(&self.0, request).await
300            }
301        }
302    };
303
304    Ok(expanded)
305}
306
307fn ensure_async_trait(trait_item: &mut ItemTrait) -> syn::Result<()> {
308    let mut has_send = false;
309    let mut has_sync = false;
310
311    for bound in &trait_item.supertraits {
312        if let TypeParamBound::Trait(bound_trait) = bound {
313            if bound_trait
314                .path
315                .segments
316                .last()
317                .map(|seg| seg.ident == "Send")
318                .unwrap_or(false)
319            {
320                has_send = true;
321            }
322
323            if bound_trait
324                .path
325                .segments
326                .last()
327                .map(|seg| seg.ident == "Sync")
328                .unwrap_or(false)
329            {
330                has_sync = true;
331            }
332        }
333    }
334
335    if !has_send {
336        if !trait_item.supertraits.is_empty() {
337            trait_item.supertraits.push_punct(syn::token::Plus::default());
338        }
339        trait_item
340            .supertraits
341            .push_value(parse_quote!(::core::marker::Send));
342    }
343
344    if !has_sync {
345        if !trait_item.supertraits.is_empty() {
346            trait_item.supertraits.push_punct(syn::token::Plus::default());
347        }
348        trait_item
349            .supertraits
350            .push_value(parse_quote!(::core::marker::Sync));
351    }
352
353    Ok(())
354}
355
356fn collect_methods(trait_item: &ItemTrait) -> syn::Result<Vec<MethodInfo>> {
357    let mut methods = Vec::new();
358
359    for item in &trait_item.items {
360        match item {
361            TraitItem::Fn(method) => {
362                if method.default.is_some() {
363                    return Err(syn::Error::new(
364                        method.sig.span(),
365                        "RPC trait methods cannot have default implementations",
366                    ));
367                }
368
369                if method.sig.asyncness.is_none() {
370                    return Err(syn::Error::new(
371                        method.sig.span(),
372                        "RPC trait methods must be async",
373                    ));
374                }
375
376                let mut inputs_iter = method.sig.inputs.iter();
377                match inputs_iter.next() {
378                    Some(FnArg::Receiver(recv)) => {
379                        if recv.reference.is_none() || recv.mutability.is_some() {
380                            return Err(syn::Error::new(
381                                recv.span(),
382                                "RPC trait methods must take &self",
383                            ));
384                        }
385                    }
386                    _ => {
387                        return Err(syn::Error::new(
388                            method.sig.span(),
389                            "RPC trait methods must take &self",
390                        ));
391                    }
392                }
393
394                let mut request_fields = Vec::new();
395                let mut method_inputs = Vec::new();
396
397                for arg in method.sig.inputs.iter().skip(1) {
398                    if let FnArg::Typed(pat_type) = arg {
399                        if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
400                            let ident = pat_ident.ident.clone();
401                            let ty = (*pat_type.ty).clone();
402                            request_fields.push(MethodArg { ident, ty });
403                            method_inputs.push(pat_type.clone());
404                        } else {
405                            return Err(syn::Error::new(
406                                pat_type.pat.span(),
407                                "RPC trait method arguments must be simple identifiers",
408                            ));
409                        }
410                    } else {
411                        return Err(syn::Error::new(
412                            arg.span(),
413                            "unsupported argument type",
414                        ));
415                    }
416                }
417
418                let success_ty = extract_success_type(&method.sig)?;
419
420                let method_name = method.sig.ident.to_string();
421                let variant_base = method_name.to_upper_camel_case();
422                let request_struct_ident = format_ident!("{}Request", variant_base);
423                let name_literal = LitStr::new(method_name.as_str(), method.sig.ident.span());
424
425                methods.push(MethodInfo {
426                    method_ident: method.sig.ident.clone(),
427                    request_struct_ident,
428                    request_fields,
429                    method_inputs,
430                    success_ty,
431                    name_literal,
432                });
433            }
434            TraitItem::Type(item) => {
435                return Err(syn::Error::new(
436                    item.span(),
437                    "RPC traits cannot declare associated types",
438                ));
439            }
440            TraitItem::Const(item) => {
441                return Err(syn::Error::new(
442                    item.span(),
443                    "RPC traits cannot declare associated constants",
444                ));
445            }
446            _ => {}
447        }
448    }
449
450    Ok(methods)
451}
452
453fn extract_success_type(sig: &syn::Signature) -> syn::Result<Type> {
454    let return_type = match &sig.output {
455        syn::ReturnType::Default => {
456            return Err(syn::Error::new(
457                sig.span(),
458                "RPC trait methods must return ::bitrpc::Result<T>",
459            ))
460        }
461        syn::ReturnType::Type(_, ty) => ty,
462    };
463
464    match return_type.as_ref() {
465        Type::Path(type_path) => extract_success_type_from_path(type_path),
466        _ => Err(syn::Error::new(
467            return_type.span(),
468            "RPC trait methods must return ::bitrpc::Result<T>",
469        )),
470    }
471}
472
473fn extract_success_type_from_path(type_path: &syn::TypePath) -> syn::Result<Type> {
474    let last_segment = type_path
475        .path
476        .segments
477        .last()
478        .ok_or_else(|| syn::Error::new(type_path.span(), "invalid return type"))?;
479
480    if last_segment.ident != "Result" {
481        return Err(syn::Error::new(
482            last_segment.ident.span(),
483            "RPC trait methods must return ::bitrpc::Result<T>",
484        ));
485    }
486
487    match &last_segment.arguments {
488        PathArguments::AngleBracketed(args) => {
489            let mut iter = args.args.iter();
490            if let Some(syn::GenericArgument::Type(success_ty)) = iter.next() {
491                Ok(success_ty.clone())
492            } else {
493                Err(syn::Error::new(
494                    args.span(),
495                    "Result must specify a success type",
496                ))
497            }
498        }
499        _ => Err(syn::Error::new(
500            last_segment.arguments.span(),
501            "Result must use angle bracket generic arguments",
502        )),
503    }
504}
505
506fn parse_service_options(
507    args: Punctuated<KeyValue, Comma>,
508    trait_ident: &Ident,
509) -> syn::Result<ServiceOptions> {
510    let mut request_ident: Option<Ident> = None;
511    let mut response_ident: Option<Ident> = None;
512    let mut client_ident: Option<Ident> = None;
513    let mut error_path: Option<Path> = None;
514
515    for arg in args {
516        let key = arg.key.to_string();
517        match key.as_str() {
518            "request" => match arg.value {
519                Expr::Path(expr_path) if expr_path.path.segments.len() == 1 => {
520                    request_ident = Some(expr_path.path.segments[0].ident.clone());
521                }
522                _ => {
523                    return Err(syn::Error::new(
524                        arg.value.span(),
525                        "request must be a simple identifier",
526                    ))
527                }
528            },
529            "response" => match arg.value {
530                Expr::Path(expr_path) if expr_path.path.segments.len() == 1 => {
531                    response_ident = Some(expr_path.path.segments[0].ident.clone());
532                }
533                _ => {
534                    return Err(syn::Error::new(
535                        arg.value.span(),
536                        "response must be a simple identifier",
537                    ))
538                }
539            },
540            "client" => match arg.value {
541                Expr::Path(expr_path) if expr_path.path.segments.len() == 1 => {
542                    client_ident = Some(expr_path.path.segments[0].ident.clone());
543                }
544                _ => {
545                    return Err(syn::Error::new(
546                        arg.value.span(),
547                        "client must be a simple identifier",
548                    ))
549                }
550            },
551            "error" => match arg.value {
552                Expr::Path(expr_path) => {
553                    error_path = Some(expr_path.path.clone());
554                }
555                _ => {
556                    return Err(syn::Error::new(
557                        arg.value.span(),
558                        "error must be a path",
559                    ))
560                }
561            },
562            _ => {
563                return Err(syn::Error::new(
564                    arg.key.span(),
565                    "unsupported service option",
566                ))
567            }
568        }
569    }
570
571    let base_name = trait_ident.to_string();
572    let request_ident = request_ident.unwrap_or_else(|| format_ident!("{}Request", base_name));
573    let response_ident = response_ident.unwrap_or_else(|| format_ident!("{}Response", base_name));
574    let client_ident = client_ident.unwrap_or_else(|| format_ident!("{}Client", base_name));
575    let error_path = error_path.unwrap_or_else(|| syn::parse_quote!(::bitrpc::RpcError));
576
577    Ok(ServiceOptions {
578        request_ident,
579        response_ident,
580        client_ident,
581        error_path,
582    })
583}