Skip to main content

roam_macros_core/
lib.rs

1//! Code generation core for roam RPC service macros.
2//!
3//! This crate contains all the code generation logic for the `#[service]` proc macro,
4//! extracted into a regular library so it can be unit-tested with insta snapshots.
5
6#![deny(unsafe_code)]
7
8use ::quote::{format_ident, quote};
9use heck::ToSnakeCase;
10use proc_macro2::TokenStream as TokenStream2;
11
12pub mod crate_name;
13
14pub use roam_macros_parse::*;
15
16use crate_name::FoundCrate;
17
18/// Error type for validation/codegen errors.
19#[derive(Debug, Clone)]
20pub struct Error {
21    pub span: proc_macro2::Span,
22    pub message: String,
23}
24
25impl Error {
26    pub fn new(span: proc_macro2::Span, message: impl Into<String>) -> Self {
27        Self {
28            span,
29            message: message.into(),
30        }
31    }
32
33    pub fn to_compile_error(&self) -> TokenStream2 {
34        let msg = &self.message;
35        let span = self.span;
36        quote::quote_spanned! {span=> compile_error!(#msg); }
37    }
38}
39
40impl From<ParseError> for Error {
41    fn from(err: ParseError) -> Self {
42        Self::new(proc_macro2::Span::call_site(), err.to_string())
43    }
44}
45
46/// Parse a trait definition from a token stream, returning a codegen-friendly error.
47pub fn parse(tokens: &TokenStream2) -> Result<ServiceTrait, Error> {
48    parse_trait(tokens).map_err(Error::from)
49}
50
51/// Returns the token stream for accessing the `roam` crate.
52pub fn roam_crate() -> TokenStream2 {
53    match crate_name::crate_name("roam") {
54        Ok(FoundCrate::Itself) => quote! { crate },
55        Ok(FoundCrate::Name(name)) => {
56            let ident = format_ident!("{}", name);
57            quote! { ::#ident }
58        }
59        Err(_) => quote! { ::roam },
60    }
61}
62
63/// Convert a parsed type into a token stream where every borrowed lifetime is `'static`.
64///
65/// This is used for descriptor hashing and client borrowed-return decode paths, where
66/// we need a concrete `'static` shape type independent of method-local lifetimes.
67fn to_static_type_tokens(ty: &Type) -> TokenStream2 {
68    match ty {
69        Type::Reference(TypeRef { mutable, inner, .. }) => {
70            let inner = to_static_type_tokens(inner);
71            if mutable.is_some() {
72                quote! { &'static mut #inner }
73            } else {
74                quote! { &'static #inner }
75            }
76        }
77        Type::Tuple(TypeTuple(group)) => {
78            let elems: Vec<TokenStream2> = group
79                .content
80                .iter()
81                .map(|entry| to_static_type_tokens(&entry.value))
82                .collect();
83            match elems.len() {
84                0 => quote! { () },
85                1 => {
86                    let t = &elems[0];
87                    quote! { (#t,) }
88                }
89                _ => quote! { (#(#elems),*) },
90            }
91        }
92        Type::PathWithGenerics(PathWithGenerics { path, args, .. }) => {
93            let path = path.to_token_stream();
94            let args: Vec<TokenStream2> = args
95                .iter()
96                .map(|entry| match &entry.value {
97                    GenericArgument::Lifetime(_) => quote! { 'static },
98                    GenericArgument::Type(inner) => to_static_type_tokens(inner),
99                })
100                .collect();
101            quote! { #path < #(#args),* > }
102        }
103        Type::Path(path) => path.to_token_stream(),
104    }
105}
106
107// r[service-macro.is-source-of-truth]
108// r[impl rpc]
109// r[impl rpc.service]
110// r[impl rpc.service.methods]
111/// Generate all service code for a parsed trait.
112///
113/// Takes a `roam` token stream (the path to the roam crate) so that this function
114/// can be called from tests with a fixed path like `::roam`.
115pub fn generate_service(parsed: &ServiceTrait, roam: &TokenStream2) -> Result<TokenStream2, Error> {
116    // r[impl rpc.channel.placement]
117    // Validate: channels are only allowed in method args.
118    for method in parsed.methods() {
119        let return_type = method.return_type();
120        if return_type.contains_channel() {
121            return Err(Error::new(
122                proc_macro2::Span::call_site(),
123                format!(
124                    "method `{}` has Channel (Tx/Rx) in return type - channels are only allowed in method arguments",
125                    method.name()
126                ),
127            ));
128        }
129
130        let (ok_ty, err_ty) = method_ok_and_err_types(&return_type);
131        if ok_ty.has_elided_reference_lifetime() {
132            return Err(Error::new(
133                proc_macro2::Span::call_site(),
134                format!(
135                    "method `{}` return type uses an elided reference lifetime; use explicit `'roam` (for example `&'roam str`)",
136                    method.name()
137                ),
138            ));
139        }
140        if ok_ty.has_non_named_lifetime("roam") {
141            return Err(Error::new(
142                proc_macro2::Span::call_site(),
143                format!(
144                    "method `{}` return type may only use lifetime `'roam` for borrowed response data",
145                    method.name()
146                ),
147            ));
148        }
149        if let Some(err_ty) = err_ty
150            && (err_ty.has_lifetime() || err_ty.has_elided_reference_lifetime())
151        {
152            return Err(Error::new(
153                proc_macro2::Span::call_site(),
154                format!(
155                    "method `{}` error type must be owned (no lifetimes), because client errors are not wrapped in SelfRef",
156                    method.name()
157                ),
158            ));
159        }
160    }
161
162    let service_descriptor_fn = generate_service_descriptor_fn(parsed, roam);
163    let service_trait = generate_service_trait(parsed, roam);
164    let dispatcher = generate_dispatcher(parsed, roam);
165    let client = generate_client(parsed, roam);
166    Ok(quote! {
167        #service_descriptor_fn
168        #service_trait
169        #dispatcher
170        #client
171    })
172}
173
174// ============================================================================
175// Service Descriptor Generation
176// ============================================================================
177
178fn generate_service_descriptor_fn(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
179    let service_name = parsed.name();
180    let descriptor_fn_name = format_ident!("{}_service_descriptor", service_name.to_snake_case());
181
182    // Build method descriptor expressions
183    let method_descriptors: Vec<TokenStream2> = parsed
184        .methods()
185        .map(|m| {
186            let method_name_str = m.name();
187
188            // Build args tuple type and return type
189            let arg_types: Vec<TokenStream2> =
190                m.args().map(|arg| to_static_type_tokens(&arg.ty)).collect();
191            let args_tuple_ty = quote! { (#(#arg_types,)*) };
192            let arg_name_strs: Vec<String> = m.args().map(|arg| arg.name().to_string()).collect();
193
194            let return_type = m.return_type();
195            let return_ty_tokens = to_static_type_tokens(&return_type);
196
197            let method_doc_expr = match m.doc() {
198                Some(d) => quote! { Some(#d) },
199                None => quote! { None },
200            };
201
202            quote! {
203                #roam::hash::method_descriptor::<#args_tuple_ty, #return_ty_tokens>(
204                    #service_name,
205                    #method_name_str,
206                    &[#(#arg_name_strs),*],
207                    #method_doc_expr,
208                )
209            }
210        })
211        .collect();
212
213    let service_doc_expr = match parsed.doc() {
214        Some(d) => quote! { Some(#d) },
215        None => quote! { None },
216    };
217
218    quote! {
219        #[allow(non_snake_case, clippy::all)]
220        pub fn #descriptor_fn_name() -> &'static #roam::session::ServiceDescriptor {
221            static DESCRIPTOR: std::sync::OnceLock<&'static #roam::session::ServiceDescriptor> = std::sync::OnceLock::new();
222            DESCRIPTOR.get_or_init(|| {
223                let methods: Vec<&'static #roam::session::MethodDescriptor> = vec![
224                    #(#method_descriptors),*
225                ];
226                Box::leak(Box::new(#roam::session::ServiceDescriptor {
227                    service_name: #service_name,
228                    methods: Box::leak(methods.into_boxed_slice()),
229                    doc: #service_doc_expr,
230                }))
231            })
232        }
233    }
234}
235
236// ============================================================================
237// Service Trait Generation
238// ============================================================================
239
240fn generate_service_trait(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
241    let trait_name = parsed.name.clone();
242    let trait_doc = parsed.doc().map(|d| quote! { #[doc = #d] });
243
244    let methods: Vec<TokenStream2> = parsed
245        .methods()
246        .map(|m| generate_trait_method(m, roam))
247        .collect();
248
249    quote! {
250        #trait_doc
251        pub trait #trait_name
252        where
253            Self: Send + Sync,
254        {
255            #(#methods)*
256        }
257    }
258}
259
260fn generate_trait_method(method: &ServiceMethod, roam: &TokenStream2) -> TokenStream2 {
261    let method_name = format_ident!("{}", method.name().to_snake_case());
262    let method_doc = method.doc().map(|d| quote! { #[doc = #d] });
263
264    let return_type = method.return_type();
265    let (ok_ty_ref, err_ty_ref) = method_ok_and_err_types(&return_type);
266    let ok_has_roam_lifetime = ok_ty_ref.has_named_lifetime("roam");
267    let method_lifetime = if ok_has_roam_lifetime {
268        quote! { <'roam> }
269    } else {
270        quote! {}
271    };
272
273    let params: Vec<TokenStream2> = method
274        .args()
275        .map(|arg| {
276            let name = format_ident!("{}", arg.name().to_snake_case());
277            let ty = arg.ty.to_token_stream();
278            quote! { #name: #ty }
279        })
280        .collect();
281
282    if ok_has_roam_lifetime {
283        let ok_ty = ok_ty_ref.to_token_stream();
284        let err_ty = err_ty_ref
285            .map(Type::to_token_stream)
286            .unwrap_or_else(|| quote! { ::core::convert::Infallible });
287        quote! {
288            #method_doc
289            fn #method_name #method_lifetime (&self, call: impl #roam::Call<'roam, #ok_ty, #err_ty>, #(#params),*) -> impl std::future::Future<Output = ()> + Send;
290        }
291    } else {
292        let output_ty = return_type.to_token_stream();
293        quote! {
294            #method_doc
295            fn #method_name (&self, #(#params),*) -> impl std::future::Future<Output = #output_ty> + Send;
296        }
297    }
298}
299
300// ============================================================================
301// Dispatcher Generation
302// ============================================================================
303
304fn generate_dispatcher(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
305    let trait_name = parsed.name.clone();
306    let dispatcher_name = format_ident!("{}Dispatcher", parsed.name());
307    let descriptor_fn_name = format_ident!("{}_service_descriptor", parsed.name().to_snake_case());
308
309    // Generate the if-else dispatch arms inside handle()
310    let dispatch_arms: Vec<TokenStream2> = parsed
311        .methods()
312        .enumerate()
313        .map(|(i, m)| generate_dispatch_arm(m, i, roam, &descriptor_fn_name))
314        .collect();
315
316    let no_methods = dispatch_arms.is_empty();
317
318    let dispatch_body = if no_methods {
319        quote! {
320            let _ = (call, reply);
321        }
322    } else {
323        // r[impl rpc.unknown-method]
324        quote! {
325            let method_id = call.method_id;
326            let args_bytes = match &call.args {
327                #roam::Payload::Incoming(bytes) => bytes,
328                _ => {
329                    reply.send_error(#roam::RoamError::<::core::convert::Infallible>::InvalidPayload).await;
330                    return;
331                }
332            };
333            #(#dispatch_arms)*
334            reply.send_error(#roam::RoamError::<::core::convert::Infallible>::UnknownMethod).await;
335        }
336    };
337
338    quote! {
339        /// Dispatcher for this service.
340        ///
341        /// Wraps a handler and implements [`#roam::Handler`] by routing incoming
342        /// calls to the appropriate trait method by method ID.
343        #[derive(Clone)]
344        pub struct #dispatcher_name<H> {
345            handler: H,
346        }
347
348        impl<H> #dispatcher_name<H>
349        where
350            H: #trait_name + Clone + Send + Sync + 'static,
351        {
352            /// Create a new dispatcher wrapping the given handler.
353            pub fn new(handler: H) -> Self {
354                Self { handler }
355            }
356        }
357
358        impl<H, R> #roam::Handler<R> for #dispatcher_name<H>
359        where
360            H: #trait_name + Clone + Send + Sync + 'static,
361            R: #roam::ReplySink,
362        {
363            async fn handle(&self, call: #roam::SelfRef<#roam::RequestCall<'static>>, reply: R) {
364                #dispatch_body
365            }
366        }
367    }
368}
369
370fn generate_dispatch_arm(
371    method: &ServiceMethod,
372    method_index: usize,
373    roam: &TokenStream2,
374    descriptor_fn_name: &proc_macro2::Ident,
375) -> TokenStream2 {
376    let method_fn = format_ident!("{}", method.name().to_snake_case());
377    let idx = method_index;
378
379    // Build args tuple type for deserialization
380    let arg_types: Vec<TokenStream2> = method
381        .args()
382        .map(|a| to_static_type_tokens(&a.ty))
383        .collect();
384    let args_tuple_type = match arg_types.len() {
385        0 => quote! { () },
386        1 => {
387            let t = &arg_types[0];
388            quote! { (#t,) }
389        }
390        _ => quote! { (#(#arg_types),*) },
391    };
392
393    // Destructure args tuple into named bindings
394    let arg_names: Vec<proc_macro2::Ident> = method
395        .args()
396        .map(|a| format_ident!("{}", a.name().to_snake_case()))
397        .collect();
398    let destructure = match arg_names.len() {
399        0 => quote! { let () = args; },
400        1 => {
401            let n = &arg_names[0];
402            quote! { let (#n,) = args; }
403        }
404        _ => quote! { let (#(#arg_names),*) = args; },
405    };
406
407    let _ = idx;
408
409    let has_channels = method.args().any(|a| a.ty.contains_channel());
410
411    let channel_binding = if has_channels {
412        quote! {
413            #[cfg(not(target_arch = "wasm32"))]
414            {
415                if let Some(binder) = reply.channel_binder() {
416                    let plan = #roam::RpcPlan::for_type::<#args_tuple_type>();
417                    if !plan.channel_locations.is_empty() {
418                        // SAFETY: args is a valid, initialized value of type #args_tuple_type
419                        // and we have exclusive access to it via &mut.
420                        #[allow(unsafe_code)]
421                        unsafe {
422                            #roam::bind_channels_callee_args(
423                                &mut args as *mut _ as *mut u8,
424                                plan,
425                                &call.channels,
426                                binder,
427                            );
428                        }
429                    }
430                }
431            }
432        }
433    } else {
434        quote! {}
435    };
436
437    // When there are channels, args must be mut for binding
438    let args_let = if has_channels {
439        quote! { let mut args: #args_tuple_type }
440    } else {
441        quote! { let args: #args_tuple_type }
442    };
443
444    let return_type = method.return_type();
445    let (ok_ty_ref, err_ty_ref) = method_ok_and_err_types(&return_type);
446    let ok_has_roam_lifetime = ok_ty_ref.has_named_lifetime("roam");
447    let is_fallible = return_type.as_result().is_some();
448    let ok_ty = ok_ty_ref.to_token_stream();
449    let err_ty = err_ty_ref
450        .map(Type::to_token_stream)
451        .unwrap_or_else(|| quote! { ::core::convert::Infallible });
452
453    let invoke_and_reply = if ok_has_roam_lifetime {
454        quote! {
455            let sink_call = #roam::SinkCall::new(reply);
456            self.handler.#method_fn(sink_call, #(#arg_names),*).await;
457        }
458    } else if is_fallible {
459        quote! {
460            let result = self.handler.#method_fn(#(#arg_names),*).await;
461            let sink_call = #roam::SinkCall::new(reply);
462            #roam::Call::<'_, #ok_ty, #err_ty>::reply(sink_call, result).await;
463        }
464    } else {
465        quote! {
466            let value = self.handler.#method_fn(#(#arg_names),*).await;
467            let sink_call = #roam::SinkCall::new(reply);
468            #roam::Call::<'_, #ok_ty, #err_ty>::ok(sink_call, value).await;
469        }
470    };
471
472    quote! {
473        if method_id == #descriptor_fn_name().methods[#idx].id {
474            #args_let = match #roam::facet_postcard::from_slice_borrowed(args_bytes) {
475                Ok(v) => v,
476                Err(_) => {
477                    reply.send_error(#roam::RoamError::<::core::convert::Infallible>::InvalidPayload).await;
478                    return;
479                }
480            };
481            #channel_binding
482            #destructure
483            #invoke_and_reply
484            return;
485        }
486    }
487}
488
489// ============================================================================
490// Client Generation
491// ============================================================================
492
493// r[impl rpc.caller]
494fn generate_client(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
495    let client_name = format_ident!("{}Client", parsed.name());
496    let descriptor_fn_name = format_ident!("{}_service_descriptor", parsed.name().to_snake_case());
497    let service_name = parsed.name();
498
499    let client_doc = format!(
500        "Client for the `{service_name}` service.\n\n\
501        Stores a type-erased [`Caller`]({roam}::Caller) implementation.",
502    );
503
504    let client_methods: Vec<TokenStream2> = parsed
505        .methods()
506        .enumerate()
507        .map(|(i, m)| generate_client_method(m, i, &descriptor_fn_name, roam))
508        .collect();
509
510    quote! {
511        #[doc = #client_doc]
512        #[must_use = "Dropping this client may close the connection if it is the last caller."]
513        #[derive(Clone)]
514        pub struct #client_name {
515            caller: #roam::ErasedCaller,
516        }
517
518        impl #client_name {
519            /// Create a new client wrapping the given caller.
520            pub fn new(caller: impl #roam::Caller) -> Self {
521                Self {
522                    caller: #roam::ErasedCaller::new(caller),
523                }
524            }
525
526            /// Resolve when the underlying connection closes.
527            pub async fn closed(&self) {
528                #roam::Caller::closed(&self.caller).await;
529            }
530
531            /// Return whether the underlying connection is still considered connected.
532            pub fn is_connected(&self) -> bool {
533                #roam::Caller::is_connected(&self.caller)
534            }
535
536            #(#client_methods)*
537        }
538
539        impl From<#roam::DriverCaller> for #client_name {
540            fn from(caller: #roam::DriverCaller) -> Self {
541                Self::new(caller)
542            }
543        }
544    }
545}
546
547// r[impl zerocopy.send.borrowed]
548// r[impl zerocopy.send.borrowed-in-struct]
549// r[impl zerocopy.send.lifetime]
550fn generate_client_method(
551    method: &ServiceMethod,
552    method_index: usize,
553    descriptor_fn_name: &proc_macro2::Ident,
554    roam: &TokenStream2,
555) -> TokenStream2 {
556    let method_name = format_ident!("{}", method.name().to_snake_case());
557    let method_doc = method.doc().map(|d| quote! { #[doc = #d] });
558    let idx = method_index;
559
560    let params: Vec<TokenStream2> = method
561        .args()
562        .map(|arg| {
563            let name = format_ident!("{}", arg.name().to_snake_case());
564            let ty = arg.ty.to_token_stream();
565            quote! { #name: #ty }
566        })
567        .collect();
568    let arg_names: Vec<proc_macro2::Ident> = method
569        .args()
570        .map(|arg| format_ident!("{}", arg.name().to_snake_case()))
571        .collect();
572
573    // Args tuple type (for RpcPlan::for_type)
574    let arg_types: Vec<TokenStream2> = method
575        .args()
576        .map(|a| to_static_type_tokens(&a.ty))
577        .collect();
578    let args_tuple_type = match arg_types.len() {
579        0 => quote! { () },
580        1 => {
581            let t = &arg_types[0];
582            quote! { (#t,) }
583        }
584        _ => quote! { (#(#arg_types),*) },
585    };
586
587    // Args tuple value (for serialization)
588    let args_tuple = match arg_names.len() {
589        0 => quote! { () },
590        1 => {
591            let n = &arg_names[0];
592            quote! { (#n,) }
593        }
594        _ => quote! { (#(#arg_names),*) },
595    };
596
597    // r[impl rpc.fallible]
598    // r[impl rpc.fallible.caller-signature]
599    let return_type = method.return_type();
600    let (ok_type_for_lifetimes, _) = method_ok_and_err_types(&return_type);
601    let ok_uses_roam_lifetime = ok_type_for_lifetimes.has_named_lifetime("roam");
602    let (ok_ty_decode, err_ty, client_return) = if let Some((ok, err)) = return_type.as_result() {
603        let ok_t = ok.to_token_stream();
604        let ok_t_static = to_static_type_tokens(ok);
605        let err_t = err.to_token_stream();
606        (
607            if ok_uses_roam_lifetime {
608                ok_t_static.clone()
609            } else {
610                ok_t.clone()
611            },
612            err_t.clone(),
613            if ok_uses_roam_lifetime {
614                quote! { Result<#roam::SelfRef<#ok_t_static>, #roam::RoamError<#err_t>> }
615            } else {
616                quote! { Result<#ok_t, #roam::RoamError<#err_t>> }
617            },
618        )
619    } else {
620        let t = return_type.to_token_stream();
621        let t_static = to_static_type_tokens(&return_type);
622        (
623            if ok_uses_roam_lifetime {
624                t_static.clone()
625            } else {
626                t.clone()
627            },
628            quote! { ::core::convert::Infallible },
629            if ok_uses_roam_lifetime {
630                quote! { Result<#roam::SelfRef<#t_static>, #roam::RoamError> }
631            } else {
632                quote! { Result<#t, #roam::RoamError> }
633            },
634        )
635    };
636
637    let has_channels = method.args().any(|a| a.ty.contains_channel());
638
639    let (args_binding, channel_binding) = if has_channels {
640        (
641            quote! { let mut args = #args_tuple; },
642            quote! {
643                #[cfg(not(target_arch = "wasm32"))]
644                let channels = if let Some(binder) = #roam::Caller::channel_binder(&self.caller) {
645                    let plan = #roam::RpcPlan::for_type::<#args_tuple_type>();
646                    // SAFETY: args is a valid, initialized value of the args tuple type
647                    // and we have exclusive access to it via &mut.
648                    #[allow(unsafe_code)]
649                    unsafe {
650                        #roam::bind_channels_caller_args(
651                            &mut args as *mut _ as *mut u8,
652                            plan,
653                            binder,
654                        )
655                    }
656                } else {
657                    vec![]
658                };
659                #[cfg(target_arch = "wasm32")]
660                let channels: Vec<#roam::ChannelId> = vec![];
661            },
662        )
663    } else {
664        (
665            quote! { let args = #args_tuple; },
666            quote! { let channels = vec![]; },
667        )
668    };
669
670    if ok_uses_roam_lifetime {
671        quote! {
672            #method_doc
673            pub async fn #method_name(&self, #(#params),*) -> #client_return {
674                let method_id = #descriptor_fn_name().methods[#idx].id;
675                #args_binding
676                #channel_binding
677                let req = #roam::RequestCall {
678                    method_id,
679                    args: #roam::Payload::outgoing(&args),
680                    channels,
681                    metadata: Default::default(),
682                };
683                let response = #roam::Caller::call(&self.caller, req).await.map_err(|e| match e {
684                    #roam::RoamError::UnknownMethod => #roam::RoamError::<#err_ty>::UnknownMethod,
685                    #roam::RoamError::InvalidPayload => #roam::RoamError::<#err_ty>::InvalidPayload,
686                    #roam::RoamError::Cancelled => #roam::RoamError::<#err_ty>::Cancelled,
687                    #roam::RoamError::User(never) => match never {},
688                })?;
689                response.try_repack(|resp, _bytes| {
690                    let ret_bytes = match &resp.ret {
691                        #roam::Payload::Incoming(bytes) => bytes,
692                        _ => return Err(#roam::RoamError::<#err_ty>::InvalidPayload),
693                    };
694                    let result: Result<#ok_ty_decode, #roam::RoamError<#err_ty>> =
695                        #roam::facet_postcard::from_slice_borrowed(ret_bytes)
696                            .map_err(|_| #roam::RoamError::<#err_ty>::InvalidPayload)?;
697                    let ret = result?;
698                    Ok(ret)
699                })
700            }
701        }
702    } else {
703        quote! {
704            #method_doc
705            pub async fn #method_name(&self, #(#params),*) -> #client_return {
706                let method_id = #descriptor_fn_name().methods[#idx].id;
707                #args_binding
708                #channel_binding
709                let req = #roam::RequestCall {
710                    method_id,
711                    args: #roam::Payload::outgoing(&args),
712                    channels,
713                    metadata: Default::default(),
714                };
715                let response = #roam::Caller::call(&self.caller, req).await.map_err(|e| match e {
716                    #roam::RoamError::UnknownMethod => #roam::RoamError::<#err_ty>::UnknownMethod,
717                    #roam::RoamError::InvalidPayload => #roam::RoamError::<#err_ty>::InvalidPayload,
718                    #roam::RoamError::Cancelled => #roam::RoamError::<#err_ty>::Cancelled,
719                    #roam::RoamError::User(never) => match never {},
720                })?;
721                let ret_bytes = match &response.ret {
722                    #roam::Payload::Incoming(bytes) => bytes,
723                    _ => return Err(#roam::RoamError::<#err_ty>::InvalidPayload),
724                };
725                let result: Result<#ok_ty_decode, #roam::RoamError<#err_ty>> =
726                    #roam::facet_postcard::from_slice(ret_bytes)
727                        .map_err(|_| #roam::RoamError::<#err_ty>::InvalidPayload)?;
728                result
729            }
730        }
731    }
732}
733
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;
    use quote::quote;

    /// Format generated tokens with `rustfmt` (edition 2024) so snapshots are readable.
    ///
    /// Panics if `rustfmt` cannot be spawned, exits unsuccessfully, or emits non-UTF-8.
    fn prettyprint(ts: proc_macro2::TokenStream) -> String {
        use std::io::Write;
        use std::process::{Command, Stdio};

        let source = ts.to_string();
        let mut rustfmt = Command::new("rustfmt")
            .args(["--edition", "2024"])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .spawn()
            .expect("failed to spawn rustfmt");

        // Write the whole input, then drop the handle so rustfmt sees EOF before we wait.
        let mut stdin = rustfmt.stdin.take().unwrap();
        stdin.write_all(source.as_bytes()).unwrap();
        drop(stdin);

        let output = rustfmt.wait_with_output().expect("rustfmt failed");
        assert!(
            output.status.success(),
            "rustfmt exited with {}",
            output.status
        );
        String::from_utf8(output.stdout).expect("rustfmt output not UTF-8")
    }

    /// Parse `input`, generate the service against `::roam`, and pretty-print the result.
    fn generate(input: proc_macro2::TokenStream) -> String {
        let roam = quote! { ::roam };
        let parsed = roam_macros_parse::parse_trait(&input).unwrap();
        prettyprint(crate::generate_service(&parsed, &roam).unwrap())
    }

    /// Parse `input`, expect generation to fail validation, and return the error message.
    fn generate_err(input: proc_macro2::TokenStream) -> String {
        let roam = quote! { ::roam };
        let parsed = roam_macros_parse::parse_trait(&input).unwrap();
        crate::generate_service(&parsed, &roam).unwrap_err().message
    }

    #[test]
    fn adder_infallible() {
        assert_snapshot!(generate(quote! {
            pub trait Adder { async fn add(&self, a: i32, b: i32) -> i32; }
        }));
    }

    #[test]
    fn fallible() {
        assert_snapshot!(generate(quote! {
            trait Calc { async fn div(&self, a: f64, b: f64) -> Result<f64, DivError>; }
        }));
    }

    #[test]
    fn no_args() {
        assert_snapshot!(generate(quote! {
            trait Ping { async fn ping(&self) -> u64; }
        }));
    }

    #[test]
    fn unit_return() {
        assert_snapshot!(generate(quote! {
            trait Notifier { async fn notify(&self, msg: String); }
        }));
    }

    #[test]
    fn streaming_tx() {
        assert_snapshot!(generate(quote! {
            trait Streamer { async fn count_up(&self, start: i32, output: Tx<i32>) -> i32; }
        }));
    }

    #[test]
    fn rejects_channels_in_return_type() {
        let message = generate_err(quote! {
            trait Streamer { async fn stream(&self) -> Rx<i32>; }
        });
        assert_eq!(
            message,
            "method `stream` has Channel (Tx/Rx) in return type - channels are only allowed in method arguments"
        );
    }

    #[test]
    fn rejects_non_roam_return_lifetime() {
        let message = generate_err(quote! {
            trait Svc { async fn bad(&self) -> &'a str; }
        });
        assert_eq!(
            message,
            "method `bad` return type may only use lifetime `'roam` for borrowed response data"
        );
    }

    #[test]
    fn rejects_elided_return_lifetime() {
        let message = generate_err(quote! {
            trait Svc { async fn bad(&self) -> &str; }
        });
        assert_eq!(
            message,
            "method `bad` return type uses an elided reference lifetime; use explicit `'roam` (for example `&'roam str`)"
        );
    }

    #[test]
    fn rejects_borrowed_error_type() {
        let message = generate_err(quote! {
            trait Svc { async fn bad(&self) -> Result<u32, &'roam str>; }
        });
        assert_eq!(
            message,
            "method `bad` error type must be owned (no lifetimes), because client errors are not wrapped in SelfRef"
        );
    }

    #[test]
    fn borrowed_roam_return() {
        assert_snapshot!(generate(quote! {
            trait Hasher { async fn hash(&self, payload: String) -> &'roam str; }
        }));
    }

    // NOTE(review): this input is identical to `borrowed_roam_return`, so the two snapshots
    // cover the same code path. Consider giving it a distinct input or removing it (along
    // with its snapshot file).
    #[test]
    fn borrowed_roam_return_call_style() {
        assert_snapshot!(generate(quote! {
            trait Hasher { async fn hash(&self, payload: String) -> &'roam str; }
        }));
    }

    #[test]
    fn borrowed_roam_cow_return() {
        assert_snapshot!(generate(quote! {
            trait TextSvc {
                async fn normalize(&self, input: String) -> ::std::borrow::Cow<'roam, str>;
            }
        }));
    }

    #[test]
    fn borrowed_return_mixed_with_borrowed_args_and_channels_compiles_to_expected_shapes() {
        assert_snapshot!(generate(quote! {
            trait WordLab {
                async fn is_short(&self, word: &str) -> bool;
                async fn classify(&self, word: String) -> &'roam str;
                async fn transform(&self, prefix: &str, input: Rx<String>, output: Tx<String>) -> u32;
            }
        }));
    }
}