rpc_it_macros/
lib.rs

1use std::collections::HashSet;
2
3use convert_case::{Case, Casing};
4use proc_macro2::TokenStream;
5use proc_macro_error::{abort, emit_error, proc_macro_error};
6use quote::{quote, quote_spanned};
7use syn::{
8    spanned::Spanned, FnArg, GenericArgument, Ident, Pat, ReturnType, Token, TraitItem,
9    TraitItemFn, Type, VisRestricted, Visibility,
10};
11
12/// Defines new RPC service
13///
14/// All parameter types must implement both of [`serde::Serialize`] and [`serde::Deserialize`].
15///
16/// The trait name will be converted to snake_case module, and all related definitions will be
17/// defined inside the module.
18///
19/// # Available Attributes for Traits
20///
21/// - `no_service`: Do not generate service-related code (TODO)
22/// - `no_client`: Do not generate client-related code (TODO)
23///
24/// # Available Attributes for Methods
25///
26/// - `sync`: Force generation of synchronous functions
27/// - `aliases = "..."`: Additional routes for the method. This is useful when you want to have
28///   multiple routes for the same method.
29/// - `with_reuse`: Generate `*_with_reuse` series of methods. This is useful when you want to
30///   optimize buffer allocation over multiple consecutive calls.
31/// - `skip`: Do not generate any code from this. This is useful when you need just a trait method,
32///   which can be used another default implementations.
33/// - `route`: Rename routing for caller
34///
35#[proc_macro_error]
36#[proc_macro_attribute]
37pub fn service(
38    _attr: proc_macro::TokenStream,
39    item: proc_macro::TokenStream,
40) -> proc_macro::TokenStream {
41    let tokens = proc_macro2::TokenStream::from(item);
42
43    /*
44        From comment example, this macro automatically implements:
45            struct MyServiceLoader;
46            impl MyServiceLoader {
47                fn load(this: Arc<MyService>, service: &mut ServiceBuilder) {
48                    service.register ...
49                }
50            }
51
52        For client, this struct is defined ...
53            struct MyServiceStub<'a>(Cow<'a, Transceiver>);
54            impl<'a> MyServiceStub<'a> {
55                pub fn new(trans: impl Into<Cow<'a, Transceiver>>) {
56
57                }
58            }
59    */
60
61    let Ok(ast) = syn::parse2::<syn::ItemTrait>(tokens) else {
62        return proc_macro::TokenStream::new();
63    };
64
65    let module_name = format!("{}", ast.ident.to_string().to_case(Case::Snake));
66    let module_name = Ident::new(&module_name, proc_macro2::Span::call_site());
67    let original_vis = &ast.vis;
68    let vis = match ast.vis.clone() {
69        vis @ (syn::Visibility::Public(_) | syn::Visibility::Restricted(_)) => vis,
70        syn::Visibility::Inherited => syn::Visibility::Restricted(VisRestricted {
71            pub_token: Token![pub](proc_macro2::Span::call_site()),
72            paren_token: Default::default(),
73            in_token: None,
74            path: Box::new(syn::parse2::<syn::Path>(quote!(super)).unwrap()),
75        }),
76    };
77
78    let mut statefuls = Vec::new();
79    let mut statelesses = Vec::new();
80    let mut call_binds = Vec::new();
81    let mut name_table = HashSet::new();
82
83    let (functions, attrs): (Vec<_>, Vec<_>) = ast
84        .items
85        .iter()
86        .filter_map(|x| if let TraitItem::Fn(x) = x { Some(x) } else { None })
87        .map(|x| {
88            let mut x = x.clone();
89            let attrs = method_attrs(&mut x);
90            (x, attrs)
91        })
92        .unzip();
93
94    let non_functions = ast
95        .items
96        .iter()
97        .filter_map(|x| if let TraitItem::Fn(_) = x { None } else { Some(x) })
98        .collect::<Vec<_>>();
99
100    for item in functions.iter().zip(&attrs) {
101        if let Some(loaded) = generate_loader_item(item.0, item.1, &mut name_table) {
102            match loaded {
103                LoaderOutput::Stateful(stateful) => statefuls.push(stateful),
104                LoaderOutput::Stateless(stateless) => statelesses.push(stateless),
105            }
106        }
107
108        if let Some(caller) = generate_call_stubs(item.0, item.1, &vis) {
109            call_binds.push(caller);
110        }
111    }
112
113    let trait_signatures = generate_trait_signatures(&functions, &attrs);
114
115    let output = quote!(
116        #original_vis mod #module_name {
117            #![allow(unused_parens)]
118            #![allow(unused)]
119
120            use super::*;
121
122            use rpc_it::service as __sv;
123            use rpc_it::service::macro_utils as __mc;
124            use rpc_it::serde;
125            use rpc_it::ExtractUserData;
126
127            #vis trait Service: Send + Sync + 'static + Clone {
128                #trait_signatures
129                #(#non_functions)*
130            }
131
132            #vis fn load_service_stateful_only<T: Service, R: __sv::Router>(
133                __this: T,
134                __service: &mut __sv::ServiceBuilder<R>
135            ) -> __mc::RegisterResult {
136                #(#statefuls;)*
137                Ok(())
138            }
139
140            #vis fn load_service_stateless_only<T: Service, R: __sv::Router>(
141                __service: &mut __sv::ServiceBuilder<R>
142            ) -> __mc::RegisterResult {
143                #(#statelesses;)*
144                Ok(())
145            }
146
147            #vis fn load_service<T:Service, R: __sv::Router>(
148                __this: T,
149                __service: &mut __sv::ServiceBuilder<R>
150            ) -> __mc::RegisterResult {
151                load_service_stateful_only(__this, __service)?;
152                load_service_stateless_only::<T, _>(__service)?;
153                Ok(())
154            }
155
156            #[derive(Debug, Clone)]
157            #vis struct Client(rpc_it::Sender);
158
159            impl Client {
160                #vis fn new(inner: rpc_it::Sender) -> Self {
161                    Self(inner)
162                }
163
164                #vis fn into_inner(self) -> rpc_it::Sender {
165                    self.0
166                }
167
168                #vis fn inner(&self) -> &rpc_it::Sender {
169                    &self.0
170                }
171
172                #(#call_binds)*
173            }
174        }
175    );
176
177    output.into()
178}
179
180enum LoaderOutput {
181    Stateful(TokenStream),
182    Stateless(TokenStream),
183}
184
185fn generate_loader_item(
186    method: &TraitItemFn,
187    attrs: &MethodAttrs,
188    used_route_table: &mut HashSet<String>,
189) -> Option<LoaderOutput> {
190    if attrs.skip {
191        return None;
192    }
193
194    let mut is_self_ref = false;
195    let mut is_stateless = false;
196
197    if let Some(receiver) = method.sig.receiver() {
198        if receiver.reference.is_some() && receiver.colon_token.is_none() {
199            if receiver.mutability.is_some() {
200                emit_error!(receiver, "Only `&self` is allowed");
201                return None;
202            }
203
204            is_self_ref = true;
205        } else if receiver.colon_token.is_some() && receiver.reference.is_none() {
206            is_self_ref = matches!(&*receiver.ty, syn::Type::Reference(_));
207        }
208    } else {
209        is_stateless = true;
210    };
211
212    // Additional routes
213    let is_sync_func = attrs.sync;
214    let mut routes = Vec::with_capacity(1 + attrs.aliases.len());
215    let ident = &method.sig.ident;
216    routes.push(
217        attrs
218            .route
219            .as_ref()
220            .map(syn::LitStr::value)
221            .unwrap_or_else(|| method.sig.ident.to_string()),
222    );
223
224    for route in &attrs.aliases {
225        routes.push(route.value());
226    }
227
228    // Pairs of (is_ref, req-type)
229    let (is_ref, inputs): (Vec<_>, Vec<_>) = method
230        .sig
231        .inputs
232        .iter()
233        .skip(if is_self_ref { 1 } else { 0 })
234        .map(|input| {
235            let syn::FnArg::Typed(pat) = input else {
236                abort!(input, "unexpected argument type");
237            };
238
239            if let Type::Reference(r) = &*pat.ty {
240                let inner = &r.elem;
241                (true, Type::Verbatim(quote!(std::borrow::Cow<#inner>)))
242            } else {
243                (false, (*pat.ty).clone())
244            }
245        })
246        .unzip();
247
248    let tup_inputs = quote!((#(#inputs),*));
249    let route_paths = quote!(&[#(#routes),*]);
250    let unpack = if inputs.len() == 1 {
251        let tok_ref = is_ref[0].then(|| quote!(&));
252        quote!(#tok_ref __req)
253    } else {
254        let vals = (0..inputs.len()).map(|x| syn::Index::from(x));
255        let tok_ref = is_ref.iter().map(|x| if *x { quote!(&) } else { quote!() });
256        quote!(#( #tok_ref __req.#vals ),*)
257    };
258
259    for r in routes {
260        if !used_route_table.insert(r.clone()) {
261            emit_error!(method, "duplicated route: {}", r);
262        }
263    }
264
265    let output = OutputType::new(&method.sig.output);
266
267    let tok_this_clone = (!is_stateless).then(|| quote!(let __this_2 = __this.clone();));
268    let tok_this_param = (!is_stateless).then(|| quote!(&__this_2,));
269
270    let strm = if output.is_notify() {
271        quote!(
272            #tok_this_clone
273            __service.register_notify_handler(#route_paths, move |__src, __req: #tup_inputs| {
274                T::#ident(#tok_this_param __src, #unpack);
275                Ok(())
276            })?
277        )
278    } else {
279        let type_out = output.typed_req();
280        if is_sync_func {
281            let rval = output.handle_sync_retval_to_response(
282                Ident::new("__src", method.sig.output.span()),
283                Ident::new("__result", method.sig.output.span()),
284            );
285
286            quote!(
287                #tok_this_clone
288                __service.register_request_handler(#route_paths, move |__src: #type_out, __req: #tup_inputs| {
289                    let __result = T::#ident(#tok_this_param __src.user_data_owned(), #unpack);
290                    #rval;
291                    Ok(())
292                })?
293            )
294        } else {
295            quote!(
296                #tok_this_clone
297                __service.register_request_handler(#route_paths, move |__src: #type_out, __req: #tup_inputs| {
298                    T::#ident(#tok_this_param __src, #unpack);
299                    Ok(())
300                })?
301            )
302        }
303    };
304
305    Some(if is_stateless { LoaderOutput::Stateless(strm) } else { LoaderOutput::Stateful(strm) })
306}
307
308fn generate_call_stubs(
309    method: &TraitItemFn,
310    attrs: &MethodAttrs,
311    vis: &Visibility,
312) -> Option<TokenStream> {
313    if attrs.skip {
314        return None;
315    }
316
317    let has_receiver = method.sig.receiver().is_some();
318
319    let inputs = method
320        .sig
321        .inputs
322        .iter()
323        .skip(if has_receiver { 1 } else { 0 })
324        .map(|arg| {
325            let FnArg::Typed(pat) = arg else { abort!(arg, "unexpected argument type") };
326            if !matches!(*pat.pat, Pat::Ident(_)) {
327                abort!(arg, "Function argument pattern must be named identifier.");
328            }
329            pat
330        })
331        .collect::<Vec<_>>();
332
333    let input_ref_args = inputs.iter().map(|x| *x).cloned().map(|mut x| {
334        x.ty = match *x.ty {
335            ty @ Type::Reference(_) => ty.into(),
336            other => Type::Reference(syn::TypeReference {
337                and_token: Token![&](other.span()),
338                lifetime: None,
339                mutability: None,
340                elem: other.into(),
341            })
342            .into(),
343        };
344        x
345    });
346    let input_ref_arg_tokens = quote!(#(#input_ref_args),*);
347
348    let input_idents = inputs
349        .iter()
350        .map(|x| *x)
351        .cloned()
352        .map(|x| {
353            let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*x.pat else { unreachable!() };
354            ident.clone()
355        })
356        .collect::<Vec<_>>();
357
358    let method_ident = &method.sig.ident;
359    let output = OutputType::new(&method.sig.output);
360
361    let method_str =
362        attrs.route.as_ref().map(syn::LitStr::value).unwrap_or_else(|| method_ident.to_string());
363
364    let new_ident_suffixed =
365        |sfx: &str| syn::Ident::new(&format!("{0}_{1}", method_ident, sfx), method_ident.span());
366    let new_ident_prefixed =
367        |sfx: &str| syn::Ident::new(&format!("{1}_{0}", method_ident, sfx), method_ident.span());
368    let method_ident_deferred = new_ident_suffixed("deferred");
369
370    Some(if output.is_notify() {
371        let method_ident_with_reuse = new_ident_suffixed("with_reuse");
372        let method_ident_deferred_with_reuse = new_ident_suffixed("deferred_with_reuse");
373
374        let reuse_version = attrs.with_reuse.then(|| quote!(
375            #[doc(hidden)]
376            #vis async fn #method_ident_with_reuse(&self, buffer: &mut rpc_it::rpc::WriteBuffer,  #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
377                self.0.notify_with_reuse(buffer, #method_str, &(#(#input_idents),*)).await
378            }
379
380            #[doc(hidden)]
381            #vis async fn #method_ident_deferred_with_reuse(&self, buffer: &mut rpc_it::rpc::WriteBuffer,  #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
382                self.0.notify_deferred_with_reuse(buffer, #method_str, &(#(#input_idents),*))
383            }
384        ));
385
386        quote!(
387            #vis async fn #method_ident(&self, #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
388                self.0.notify(#method_str, &(#(#input_idents),*)).await
389            }
390
391
392            #vis async fn #method_ident_deferred(&self, #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
393                self.0.notify_deferred(#method_str, &(#(#input_idents),*))
394            }
395
396            #reuse_version
397        )
398    } else {
399        let (ok_tok, err_tok) = match &output {
400            OutputType::Response(ok, err) => (quote!(#ok), quote!(#err)),
401            OutputType::ResponseNoErr(ok) => (quote!(#ok), quote!(())),
402            OutputType::Notify => unreachable!(),
403        };
404
405        let method_ident_request = new_ident_prefixed("request");
406
407        quote!(
408            #vis async fn #method_ident(&self, #input_ref_arg_tokens)
409                -> Result<#ok_tok, rpc_it::TypedCallError<#err_tok>>
410            {
411                self.0.call_with_err(#method_str, &(#(#input_idents),*)).await
412            }
413
414            #vis async fn #method_ident_request(&self, #input_ref_arg_tokens)
415                -> Result<rpc_it::TypedResponse<#ok_tok, #err_tok>, rpc_it::SendError>
416            {
417                let resp = self.0.request(#method_str, &(#(#input_idents),*)).await?;
418                Ok(rpc_it::TypedResponse::new(resp.to_owned()))
419            }
420
421            #vis async fn #method_ident_deferred(&self, #input_ref_arg_tokens)
422                -> Result<rpc_it::TypedResponse<#ok_tok, #err_tok>, rpc_it::SendError>
423            {
424                let resp = self.0.request_deferred(#method_str, &(#(#input_idents),*))?;
425                Ok(rpc_it::TypedResponse::new(resp.to_owned()))
426            }
427        )
428    })
429}
430
431fn generate_trait_signatures(items: &[TraitItemFn], attrs: &[MethodAttrs]) -> TokenStream {
432    let tokens = items.iter().zip(attrs).map(|(method, attrs)| {
433        let mut method = method.clone();
434        let out = OutputType::new(&method.sig.output);
435
436        if attrs.skip {
437            // Use as-is
438            return TraitItem::Fn(method);
439        }
440
441        let req_param_ident = if let Some(body) =
442            method.default.as_ref().filter(|_| !attrs.sync && !out.is_notify())
443        {
444            let span = body.span();
445            let id_req = Ident::new("___rq", span);
446            let payload: syn::Expr = syn::parse_quote_spanned!(span => (move || #body)());
447            let response = match out {
448                OutputType::Notify => unreachable!(),
449                OutputType::ResponseNoErr(_) => {
450                    quote_spanned!(span => #id_req.ok(&#payload).ok();)
451                }
452                OutputType::Response(_, _) => {
453                    quote_spanned!(
454                        span => match #payload {
455                            Ok(x) => #id_req.ok(&x).ok(),
456                            Err(e) => #id_req.err(&e).ok(),
457                        }
458                    )
459                }
460            };
461
462            method.default = Some(syn::parse_quote_spanned!(
463                span =>
464                {
465                    #response;
466                }
467            ));
468
469            syn::Pat::Ident(syn::PatIdent {
470                attrs: Vec::new(),
471                by_ref: None,
472                mutability: None,
473                subpat: None,
474                ident: id_req,
475            })
476        } else {
477            syn::Pat::Wild(syn::PatWild {
478                attrs: Vec::new(),
479                underscore_token: Token![_](method.sig.output.span()),
480            })
481        };
482
483        {
484            let has_receiver = method.sig.receiver().is_some();
485            let insert_at = if has_receiver { 1 } else { 0 };
486
487            if out.is_notify() {
488                method.sig.inputs.insert(
489                    insert_at,
490                    syn::parse_quote_spanned!(method.sig.output.span() => _: rpc_it::Notify),
491                );
492            } else if !attrs.sync {
493                method.sig.inputs.insert(
494                    insert_at,
495                    syn::FnArg::Typed(syn::PatType {
496                        attrs: Vec::new(),
497                        colon_token: Default::default(),
498                        pat: req_param_ident.into(),
499                        ty: out.typed_req().into(),
500                    }),
501                );
502
503                method.sig.output = ReturnType::Default;
504            } else {
505                method.sig.inputs.insert(
506                    insert_at,
507                    syn::parse_quote_spanned!(method.sig.output.span() => _: rpc_it::OwnedUserData),
508                );
509            }
510        }
511
512        TraitItem::Fn(method)
513    });
514
515    quote!(#(#tokens)*)
516}
517
518#[derive(Default)]
519struct MethodAttrs {
520    sync: bool,
521    skip: bool,
522    aliases: Vec<syn::LitStr>,
523    with_reuse: bool,
524    route: Option<syn::LitStr>,
525}
526
527fn method_attrs(method: &mut TraitItemFn) -> MethodAttrs {
528    let mut attrs = MethodAttrs::default();
529
530    for attr in std::mem::take(&mut method.attrs) {
531        match &attr.meta {
532            syn::Meta::Path(path) => {
533                if path.is_ident("sync") {
534                    if matches!(method.sig.output, ReturnType::Default) {
535                        emit_error!(attr, "'sync' attribute is only allowed for requests");
536                    }
537
538                    attrs.sync = true;
539                } else if path.is_ident("skip") {
540                    attrs.skip = true;
541                } else if path.is_ident("with_reuse") {
542                    attrs.with_reuse = true;
543                } else {
544                    emit_error!(attr, "unexpected attribute")
545                }
546            }
547
548            syn::Meta::List(_) => {
549                emit_error!(attr, "unexpected attribute")
550            }
551
552            syn::Meta::NameValue(kv) => {
553                let Some(ident) = kv.path.get_ident() else {
554                    emit_error!(attr, "unexpected attribute");
555                    continue;
556                };
557                let syn::Expr::Lit(syn::ExprLit { lit, .. }) = &kv.value else {
558                    emit_error!(attr, "unexpected attribute");
559                    continue;
560                };
561
562                if ident == "aliases" {
563                    let syn::Lit::Str(route) = lit else {
564                        emit_error!(lit, "unexpected non-string literal attribute");
565                        continue;
566                    };
567                    attrs.aliases.push(route.clone());
568                } else if ident == "route" {
569                    let syn::Lit::Str(route) = lit else {
570                        emit_error!(lit, "unexpected non-string literal attribute");
571                        continue;
572                    };
573                    attrs.route = Some(route.clone());
574                } else {
575                    emit_error!(attr, "unexpected attribute")
576                }
577            }
578        }
579    }
580
581    attrs
582}
583
584enum OutputType {
585    Notify,
586    ResponseNoErr(Type),
587    Response(GenericArgument, GenericArgument),
588}
589
590impl OutputType {
591    fn new(val: &syn::ReturnType) -> Self {
592        let syn::ReturnType::Type(_, ty) = val else { return Self::Notify };
593
594        let fb = || Self::ResponseNoErr((**ty).clone());
595        let Type::Path(tp) = &**ty else { return fb() };
596        let Some(first_seg) = tp.path.segments.first() else { return fb() };
597
598        if first_seg.ident != "Result" {
599            return fb();
600        }
601
602        let syn::PathArguments::AngleBracketed(ang) = &first_seg.arguments else {
603            return fb();
604        };
605
606        let mut type_iter = ang.args.iter();
607        let [Some(ok), Some(err)] = std::array::from_fn(|_| type_iter.next()) else {
608            return fb();
609        };
610
611        Self::Response(ok.clone(), err.clone())
612    }
613
614    fn is_notify(&self) -> bool {
615        matches!(self, Self::Notify)
616    }
617
618    fn typed_req(&self) -> Type {
619        match self {
620            OutputType::Notify => unimplemented!(),
621
622            OutputType::ResponseNoErr(x) => {
623                syn::parse2(quote!(rpc_it::TypedRequest<#x, ()>)).unwrap()
624            }
625
626            OutputType::Response(r, e) => {
627                syn::parse2(quote!(rpc_it::TypedRequest<#r, #e>)).unwrap()
628            }
629        }
630    }
631
632    fn handle_sync_retval_to_response(&self, req_ident: Ident, val_ident: Ident) -> TokenStream {
633        match self {
634            OutputType::Notify => unimplemented!(),
635
636            OutputType::ResponseNoErr(_) => {
637                quote!(#req_ident.ok(&#val_ident)?;)
638            }
639
640            OutputType::Response(_, _) => {
641                quote!(
642                    match #val_ident {
643                        Ok(x) => #req_ident.ok(&x)?,
644                        Err(e) => #req_ident.err(&e)?,
645                    }
646                )
647            }
648        }
649    }
650}