1#![deny(unsafe_code)]
7
8use ::quote::{format_ident, quote};
9use heck::ToSnakeCase;
10use proc_macro2::TokenStream as TokenStream2;
11
12pub mod crate_name;
13
14pub use roam_macros_parse::*;
15
16use crate_name::FoundCrate;
17
18#[derive(Debug, Clone)]
20pub struct Error {
21 pub span: proc_macro2::Span,
22 pub message: String,
23}
24
25impl Error {
26 pub fn new(span: proc_macro2::Span, message: impl Into<String>) -> Self {
27 Self {
28 span,
29 message: message.into(),
30 }
31 }
32
33 pub fn to_compile_error(&self) -> TokenStream2 {
34 let msg = &self.message;
35 let span = self.span;
36 quote::quote_spanned! {span=> compile_error!(#msg); }
37 }
38}
39
40impl From<ParseError> for Error {
41 fn from(err: ParseError) -> Self {
42 Self::new(proc_macro2::Span::call_site(), err.to_string())
43 }
44}
45
46pub fn parse(tokens: &TokenStream2) -> Result<ServiceTrait, Error> {
48 parse_trait(tokens).map_err(Error::from)
49}
50
51pub fn roam_crate() -> TokenStream2 {
53 match crate_name::crate_name("roam") {
54 Ok(FoundCrate::Itself) => quote! { crate },
55 Ok(FoundCrate::Name(name)) => {
56 let ident = format_ident!("{}", name);
57 quote! { ::#ident }
58 }
59 Err(_) => quote! { ::roam },
60 }
61}
62
63fn to_static_type_tokens(ty: &Type) -> TokenStream2 {
68 match ty {
69 Type::Reference(TypeRef { mutable, inner, .. }) => {
70 let inner = to_static_type_tokens(inner);
71 if mutable.is_some() {
72 quote! { &'static mut #inner }
73 } else {
74 quote! { &'static #inner }
75 }
76 }
77 Type::Tuple(TypeTuple(group)) => {
78 let elems: Vec<TokenStream2> = group
79 .content
80 .iter()
81 .map(|entry| to_static_type_tokens(&entry.value))
82 .collect();
83 match elems.len() {
84 0 => quote! { () },
85 1 => {
86 let t = &elems[0];
87 quote! { (#t,) }
88 }
89 _ => quote! { (#(#elems),*) },
90 }
91 }
92 Type::PathWithGenerics(PathWithGenerics { path, args, .. }) => {
93 let path = path.to_token_stream();
94 let args: Vec<TokenStream2> = args
95 .iter()
96 .map(|entry| match &entry.value {
97 GenericArgument::Lifetime(_) => quote! { 'static },
98 GenericArgument::Type(inner) => to_static_type_tokens(inner),
99 })
100 .collect();
101 quote! { #path < #(#args),* > }
102 }
103 Type::Path(path) => path.to_token_stream(),
104 }
105}
106
107pub fn generate_service(parsed: &ServiceTrait, roam: &TokenStream2) -> Result<TokenStream2, Error> {
116 for method in parsed.methods() {
119 let return_type = method.return_type();
120 if return_type.contains_channel() {
121 return Err(Error::new(
122 proc_macro2::Span::call_site(),
123 format!(
124 "method `{}` has Channel (Tx/Rx) in return type - channels are only allowed in method arguments",
125 method.name()
126 ),
127 ));
128 }
129
130 let (ok_ty, err_ty) = method_ok_and_err_types(&return_type);
131 if ok_ty.has_elided_reference_lifetime() {
132 return Err(Error::new(
133 proc_macro2::Span::call_site(),
134 format!(
135 "method `{}` return type uses an elided reference lifetime; use explicit `'roam` (for example `&'roam str`)",
136 method.name()
137 ),
138 ));
139 }
140 if ok_ty.has_non_named_lifetime("roam") {
141 return Err(Error::new(
142 proc_macro2::Span::call_site(),
143 format!(
144 "method `{}` return type may only use lifetime `'roam` for borrowed response data",
145 method.name()
146 ),
147 ));
148 }
149 if let Some(err_ty) = err_ty
150 && (err_ty.has_lifetime() || err_ty.has_elided_reference_lifetime())
151 {
152 return Err(Error::new(
153 proc_macro2::Span::call_site(),
154 format!(
155 "method `{}` error type must be owned (no lifetimes), because client errors are not wrapped in SelfRef",
156 method.name()
157 ),
158 ));
159 }
160 }
161
162 let service_descriptor_fn = generate_service_descriptor_fn(parsed, roam);
163 let service_trait = generate_service_trait(parsed, roam);
164 let dispatcher = generate_dispatcher(parsed, roam);
165 let client = generate_client(parsed, roam);
166 Ok(quote! {
167 #service_descriptor_fn
168 #service_trait
169 #dispatcher
170 #client
171 })
172}
173
174fn generate_service_descriptor_fn(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
179 let service_name = parsed.name();
180 let descriptor_fn_name = format_ident!("{}_service_descriptor", service_name.to_snake_case());
181
182 let method_descriptors: Vec<TokenStream2> = parsed
184 .methods()
185 .map(|m| {
186 let method_name_str = m.name();
187
188 let arg_types: Vec<TokenStream2> =
190 m.args().map(|arg| to_static_type_tokens(&arg.ty)).collect();
191 let args_tuple_ty = quote! { (#(#arg_types,)*) };
192 let arg_name_strs: Vec<String> = m.args().map(|arg| arg.name().to_string()).collect();
193
194 let return_type = m.return_type();
195 let return_ty_tokens = to_static_type_tokens(&return_type);
196
197 let method_doc_expr = match m.doc() {
198 Some(d) => quote! { Some(#d) },
199 None => quote! { None },
200 };
201
202 quote! {
203 #roam::hash::method_descriptor::<#args_tuple_ty, #return_ty_tokens>(
204 #service_name,
205 #method_name_str,
206 &[#(#arg_name_strs),*],
207 #method_doc_expr,
208 )
209 }
210 })
211 .collect();
212
213 let service_doc_expr = match parsed.doc() {
214 Some(d) => quote! { Some(#d) },
215 None => quote! { None },
216 };
217
218 quote! {
219 #[allow(non_snake_case, clippy::all)]
220 pub fn #descriptor_fn_name() -> &'static #roam::session::ServiceDescriptor {
221 static DESCRIPTOR: std::sync::OnceLock<&'static #roam::session::ServiceDescriptor> = std::sync::OnceLock::new();
222 DESCRIPTOR.get_or_init(|| {
223 let methods: Vec<&'static #roam::session::MethodDescriptor> = vec![
224 #(#method_descriptors),*
225 ];
226 Box::leak(Box::new(#roam::session::ServiceDescriptor {
227 service_name: #service_name,
228 methods: Box::leak(methods.into_boxed_slice()),
229 doc: #service_doc_expr,
230 }))
231 })
232 }
233 }
234}
235
236fn generate_service_trait(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
241 let trait_name = parsed.name.clone();
242 let trait_doc = parsed.doc().map(|d| quote! { #[doc = #d] });
243
244 let methods: Vec<TokenStream2> = parsed
245 .methods()
246 .map(|m| generate_trait_method(m, roam))
247 .collect();
248
249 quote! {
250 #trait_doc
251 pub trait #trait_name
252 where
253 Self: Send + Sync,
254 {
255 #(#methods)*
256 }
257 }
258}
259
260fn generate_trait_method(method: &ServiceMethod, roam: &TokenStream2) -> TokenStream2 {
261 let method_name = format_ident!("{}", method.name().to_snake_case());
262 let method_doc = method.doc().map(|d| quote! { #[doc = #d] });
263
264 let return_type = method.return_type();
265 let (ok_ty_ref, err_ty_ref) = method_ok_and_err_types(&return_type);
266 let ok_has_roam_lifetime = ok_ty_ref.has_named_lifetime("roam");
267 let method_lifetime = if ok_has_roam_lifetime {
268 quote! { <'roam> }
269 } else {
270 quote! {}
271 };
272
273 let params: Vec<TokenStream2> = method
274 .args()
275 .map(|arg| {
276 let name = format_ident!("{}", arg.name().to_snake_case());
277 let ty = arg.ty.to_token_stream();
278 quote! { #name: #ty }
279 })
280 .collect();
281
282 if ok_has_roam_lifetime {
283 let ok_ty = ok_ty_ref.to_token_stream();
284 let err_ty = err_ty_ref
285 .map(Type::to_token_stream)
286 .unwrap_or_else(|| quote! { ::core::convert::Infallible });
287 quote! {
288 #method_doc
289 fn #method_name #method_lifetime (&self, call: impl #roam::Call<'roam, #ok_ty, #err_ty>, #(#params),*) -> impl std::future::Future<Output = ()> + Send;
290 }
291 } else {
292 let output_ty = return_type.to_token_stream();
293 quote! {
294 #method_doc
295 fn #method_name (&self, #(#params),*) -> impl std::future::Future<Output = #output_ty> + Send;
296 }
297 }
298}
299
300fn generate_dispatcher(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
305 let trait_name = parsed.name.clone();
306 let dispatcher_name = format_ident!("{}Dispatcher", parsed.name());
307 let descriptor_fn_name = format_ident!("{}_service_descriptor", parsed.name().to_snake_case());
308
309 let dispatch_arms: Vec<TokenStream2> = parsed
311 .methods()
312 .enumerate()
313 .map(|(i, m)| generate_dispatch_arm(m, i, roam, &descriptor_fn_name))
314 .collect();
315
316 let no_methods = dispatch_arms.is_empty();
317
318 let dispatch_body = if no_methods {
319 quote! {
320 let _ = (call, reply);
321 }
322 } else {
323 quote! {
325 let method_id = call.method_id;
326 let args_bytes = match &call.args {
327 #roam::Payload::Incoming(bytes) => bytes,
328 _ => {
329 reply.send_error(#roam::RoamError::<::core::convert::Infallible>::InvalidPayload).await;
330 return;
331 }
332 };
333 #(#dispatch_arms)*
334 reply.send_error(#roam::RoamError::<::core::convert::Infallible>::UnknownMethod).await;
335 }
336 };
337
338 quote! {
339 #[derive(Clone)]
344 pub struct #dispatcher_name<H> {
345 handler: H,
346 }
347
348 impl<H> #dispatcher_name<H>
349 where
350 H: #trait_name + Clone + Send + Sync + 'static,
351 {
352 pub fn new(handler: H) -> Self {
354 Self { handler }
355 }
356 }
357
358 impl<H, R> #roam::Handler<R> for #dispatcher_name<H>
359 where
360 H: #trait_name + Clone + Send + Sync + 'static,
361 R: #roam::ReplySink,
362 {
363 async fn handle(&self, call: #roam::SelfRef<#roam::RequestCall<'static>>, reply: R) {
364 #dispatch_body
365 }
366 }
367 }
368}
369
370fn generate_dispatch_arm(
371 method: &ServiceMethod,
372 method_index: usize,
373 roam: &TokenStream2,
374 descriptor_fn_name: &proc_macro2::Ident,
375) -> TokenStream2 {
376 let method_fn = format_ident!("{}", method.name().to_snake_case());
377 let idx = method_index;
378
379 let arg_types: Vec<TokenStream2> = method
381 .args()
382 .map(|a| to_static_type_tokens(&a.ty))
383 .collect();
384 let args_tuple_type = match arg_types.len() {
385 0 => quote! { () },
386 1 => {
387 let t = &arg_types[0];
388 quote! { (#t,) }
389 }
390 _ => quote! { (#(#arg_types),*) },
391 };
392
393 let arg_names: Vec<proc_macro2::Ident> = method
395 .args()
396 .map(|a| format_ident!("{}", a.name().to_snake_case()))
397 .collect();
398 let destructure = match arg_names.len() {
399 0 => quote! { let () = args; },
400 1 => {
401 let n = &arg_names[0];
402 quote! { let (#n,) = args; }
403 }
404 _ => quote! { let (#(#arg_names),*) = args; },
405 };
406
407 let _ = idx;
408
409 let has_channels = method.args().any(|a| a.ty.contains_channel());
410
411 let channel_binding = if has_channels {
412 quote! {
413 #[cfg(not(target_arch = "wasm32"))]
414 {
415 if let Some(binder) = reply.channel_binder() {
416 let plan = #roam::RpcPlan::for_type::<#args_tuple_type>();
417 if !plan.channel_locations.is_empty() {
418 #[allow(unsafe_code)]
421 unsafe {
422 #roam::bind_channels_callee_args(
423 &mut args as *mut _ as *mut u8,
424 plan,
425 &call.channels,
426 binder,
427 );
428 }
429 }
430 }
431 }
432 }
433 } else {
434 quote! {}
435 };
436
437 let args_let = if has_channels {
439 quote! { let mut args: #args_tuple_type }
440 } else {
441 quote! { let args: #args_tuple_type }
442 };
443
444 let return_type = method.return_type();
445 let (ok_ty_ref, err_ty_ref) = method_ok_and_err_types(&return_type);
446 let ok_has_roam_lifetime = ok_ty_ref.has_named_lifetime("roam");
447 let is_fallible = return_type.as_result().is_some();
448 let ok_ty = ok_ty_ref.to_token_stream();
449 let err_ty = err_ty_ref
450 .map(Type::to_token_stream)
451 .unwrap_or_else(|| quote! { ::core::convert::Infallible });
452
453 let invoke_and_reply = if ok_has_roam_lifetime {
454 quote! {
455 let sink_call = #roam::SinkCall::new(reply);
456 self.handler.#method_fn(sink_call, #(#arg_names),*).await;
457 }
458 } else if is_fallible {
459 quote! {
460 let result = self.handler.#method_fn(#(#arg_names),*).await;
461 let sink_call = #roam::SinkCall::new(reply);
462 #roam::Call::<'_, #ok_ty, #err_ty>::reply(sink_call, result).await;
463 }
464 } else {
465 quote! {
466 let value = self.handler.#method_fn(#(#arg_names),*).await;
467 let sink_call = #roam::SinkCall::new(reply);
468 #roam::Call::<'_, #ok_ty, #err_ty>::ok(sink_call, value).await;
469 }
470 };
471
472 quote! {
473 if method_id == #descriptor_fn_name().methods[#idx].id {
474 #args_let = match #roam::facet_postcard::from_slice_borrowed(args_bytes) {
475 Ok(v) => v,
476 Err(_) => {
477 reply.send_error(#roam::RoamError::<::core::convert::Infallible>::InvalidPayload).await;
478 return;
479 }
480 };
481 #channel_binding
482 #destructure
483 #invoke_and_reply
484 return;
485 }
486 }
487}
488
489fn generate_client(parsed: &ServiceTrait, roam: &TokenStream2) -> TokenStream2 {
495 let client_name = format_ident!("{}Client", parsed.name());
496 let descriptor_fn_name = format_ident!("{}_service_descriptor", parsed.name().to_snake_case());
497 let service_name = parsed.name();
498
499 let client_doc = format!(
500 "Client for the `{service_name}` service.\n\n\
501 Stores a type-erased [`Caller`]({roam}::Caller) implementation.",
502 );
503
504 let client_methods: Vec<TokenStream2> = parsed
505 .methods()
506 .enumerate()
507 .map(|(i, m)| generate_client_method(m, i, &descriptor_fn_name, roam))
508 .collect();
509
510 quote! {
511 #[doc = #client_doc]
512 #[must_use = "Dropping this client may close the connection if it is the last caller."]
513 #[derive(Clone)]
514 pub struct #client_name {
515 caller: #roam::ErasedCaller,
516 }
517
518 impl #client_name {
519 pub fn new(caller: impl #roam::Caller) -> Self {
521 Self {
522 caller: #roam::ErasedCaller::new(caller),
523 }
524 }
525
526 pub async fn closed(&self) {
528 #roam::Caller::closed(&self.caller).await;
529 }
530
531 pub fn is_connected(&self) -> bool {
533 #roam::Caller::is_connected(&self.caller)
534 }
535
536 #(#client_methods)*
537 }
538
539 impl From<#roam::DriverCaller> for #client_name {
540 fn from(caller: #roam::DriverCaller) -> Self {
541 Self::new(caller)
542 }
543 }
544 }
545}
546
547fn generate_client_method(
551 method: &ServiceMethod,
552 method_index: usize,
553 descriptor_fn_name: &proc_macro2::Ident,
554 roam: &TokenStream2,
555) -> TokenStream2 {
556 let method_name = format_ident!("{}", method.name().to_snake_case());
557 let method_doc = method.doc().map(|d| quote! { #[doc = #d] });
558 let idx = method_index;
559
560 let params: Vec<TokenStream2> = method
561 .args()
562 .map(|arg| {
563 let name = format_ident!("{}", arg.name().to_snake_case());
564 let ty = arg.ty.to_token_stream();
565 quote! { #name: #ty }
566 })
567 .collect();
568 let arg_names: Vec<proc_macro2::Ident> = method
569 .args()
570 .map(|arg| format_ident!("{}", arg.name().to_snake_case()))
571 .collect();
572
573 let arg_types: Vec<TokenStream2> = method
575 .args()
576 .map(|a| to_static_type_tokens(&a.ty))
577 .collect();
578 let args_tuple_type = match arg_types.len() {
579 0 => quote! { () },
580 1 => {
581 let t = &arg_types[0];
582 quote! { (#t,) }
583 }
584 _ => quote! { (#(#arg_types),*) },
585 };
586
587 let args_tuple = match arg_names.len() {
589 0 => quote! { () },
590 1 => {
591 let n = &arg_names[0];
592 quote! { (#n,) }
593 }
594 _ => quote! { (#(#arg_names),*) },
595 };
596
597 let return_type = method.return_type();
600 let (ok_type_for_lifetimes, _) = method_ok_and_err_types(&return_type);
601 let ok_uses_roam_lifetime = ok_type_for_lifetimes.has_named_lifetime("roam");
602 let (ok_ty_decode, err_ty, client_return) = if let Some((ok, err)) = return_type.as_result() {
603 let ok_t = ok.to_token_stream();
604 let ok_t_static = to_static_type_tokens(ok);
605 let err_t = err.to_token_stream();
606 (
607 if ok_uses_roam_lifetime {
608 ok_t_static.clone()
609 } else {
610 ok_t.clone()
611 },
612 err_t.clone(),
613 if ok_uses_roam_lifetime {
614 quote! { Result<#roam::SelfRef<#ok_t_static>, #roam::RoamError<#err_t>> }
615 } else {
616 quote! { Result<#ok_t, #roam::RoamError<#err_t>> }
617 },
618 )
619 } else {
620 let t = return_type.to_token_stream();
621 let t_static = to_static_type_tokens(&return_type);
622 (
623 if ok_uses_roam_lifetime {
624 t_static.clone()
625 } else {
626 t.clone()
627 },
628 quote! { ::core::convert::Infallible },
629 if ok_uses_roam_lifetime {
630 quote! { Result<#roam::SelfRef<#t_static>, #roam::RoamError> }
631 } else {
632 quote! { Result<#t, #roam::RoamError> }
633 },
634 )
635 };
636
637 let has_channels = method.args().any(|a| a.ty.contains_channel());
638
639 let (args_binding, channel_binding) = if has_channels {
640 (
641 quote! { let mut args = #args_tuple; },
642 quote! {
643 #[cfg(not(target_arch = "wasm32"))]
644 let channels = if let Some(binder) = #roam::Caller::channel_binder(&self.caller) {
645 let plan = #roam::RpcPlan::for_type::<#args_tuple_type>();
646 #[allow(unsafe_code)]
649 unsafe {
650 #roam::bind_channels_caller_args(
651 &mut args as *mut _ as *mut u8,
652 plan,
653 binder,
654 )
655 }
656 } else {
657 vec![]
658 };
659 #[cfg(target_arch = "wasm32")]
660 let channels: Vec<#roam::ChannelId> = vec![];
661 },
662 )
663 } else {
664 (
665 quote! { let args = #args_tuple; },
666 quote! { let channels = vec![]; },
667 )
668 };
669
670 if ok_uses_roam_lifetime {
671 quote! {
672 #method_doc
673 pub async fn #method_name(&self, #(#params),*) -> #client_return {
674 let method_id = #descriptor_fn_name().methods[#idx].id;
675 #args_binding
676 #channel_binding
677 let req = #roam::RequestCall {
678 method_id,
679 args: #roam::Payload::outgoing(&args),
680 channels,
681 metadata: Default::default(),
682 };
683 let response = #roam::Caller::call(&self.caller, req).await.map_err(|e| match e {
684 #roam::RoamError::UnknownMethod => #roam::RoamError::<#err_ty>::UnknownMethod,
685 #roam::RoamError::InvalidPayload => #roam::RoamError::<#err_ty>::InvalidPayload,
686 #roam::RoamError::Cancelled => #roam::RoamError::<#err_ty>::Cancelled,
687 #roam::RoamError::User(never) => match never {},
688 })?;
689 response.try_repack(|resp, _bytes| {
690 let ret_bytes = match &resp.ret {
691 #roam::Payload::Incoming(bytes) => bytes,
692 _ => return Err(#roam::RoamError::<#err_ty>::InvalidPayload),
693 };
694 let result: Result<#ok_ty_decode, #roam::RoamError<#err_ty>> =
695 #roam::facet_postcard::from_slice_borrowed(ret_bytes)
696 .map_err(|_| #roam::RoamError::<#err_ty>::InvalidPayload)?;
697 let ret = result?;
698 Ok(ret)
699 })
700 }
701 }
702 } else {
703 quote! {
704 #method_doc
705 pub async fn #method_name(&self, #(#params),*) -> #client_return {
706 let method_id = #descriptor_fn_name().methods[#idx].id;
707 #args_binding
708 #channel_binding
709 let req = #roam::RequestCall {
710 method_id,
711 args: #roam::Payload::outgoing(&args),
712 channels,
713 metadata: Default::default(),
714 };
715 let response = #roam::Caller::call(&self.caller, req).await.map_err(|e| match e {
716 #roam::RoamError::UnknownMethod => #roam::RoamError::<#err_ty>::UnknownMethod,
717 #roam::RoamError::InvalidPayload => #roam::RoamError::<#err_ty>::InvalidPayload,
718 #roam::RoamError::Cancelled => #roam::RoamError::<#err_ty>::Cancelled,
719 #roam::RoamError::User(never) => match never {},
720 })?;
721 let ret_bytes = match &response.ret {
722 #roam::Payload::Incoming(bytes) => bytes,
723 _ => return Err(#roam::RoamError::<#err_ty>::InvalidPayload),
724 };
725 let result: Result<#ok_ty_decode, #roam::RoamError<#err_ty>> =
726 #roam::facet_postcard::from_slice(ret_bytes)
727 .map_err(|_| #roam::RoamError::<#err_ty>::InvalidPayload)?;
728 result
729 }
730 }
731 }
732}
733
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;
    use quote::quote;

    /// Formats `ts` by piping it through `rustfmt --edition 2024`.
    ///
    /// Panics if `rustfmt` cannot be spawned, exits unsuccessfully, or writes
    /// non-UTF-8 output.
    fn prettyprint(ts: proc_macro2::TokenStream) -> String {
        use std::io::Write;
        use std::process::{Command, Stdio};

        let mut rustfmt = Command::new("rustfmt")
            .args(["--edition", "2024"])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .spawn()
            .expect("failed to spawn rustfmt");

        {
            // The handle is dropped at the end of this scope, closing rustfmt's
            // stdin so it sees end of input.
            let mut stdin = rustfmt.stdin.take().unwrap();
            stdin.write_all(ts.to_string().as_bytes()).unwrap();
        }

        let output = rustfmt.wait_with_output().expect("rustfmt failed");
        assert!(
            output.status.success(),
            "rustfmt exited with {}",
            output.status
        );
        String::from_utf8(output.stdout).expect("rustfmt output not UTF-8")
    }

    /// Parses `input` as a service trait and returns the formatted generated
    /// code, using `::roam` as the crate path.
    fn generate(input: proc_macro2::TokenStream) -> String {
        let parsed = roam_macros_parse::parse_trait(&input).unwrap();
        let generated = crate::generate_service(&parsed, &quote! { ::roam }).unwrap();
        prettyprint(generated)
    }

    /// Parses `input` and returns the message `generate_service` rejects it with.
    /// Panics if generation unexpectedly succeeds.
    fn rejection_message(input: proc_macro2::TokenStream) -> String {
        let parsed = roam_macros_parse::parse_trait(&input).unwrap();
        crate::generate_service(&parsed, &quote! { ::roam })
            .unwrap_err()
            .message
    }

    #[test]
    fn adder_infallible() {
        assert_snapshot!(generate(quote! {
            pub trait Adder { async fn add(&self, a: i32, b: i32) -> i32; }
        }));
    }

    #[test]
    fn fallible() {
        assert_snapshot!(generate(quote! {
            trait Calc { async fn div(&self, a: f64, b: f64) -> Result<f64, DivError>; }
        }));
    }

    #[test]
    fn no_args() {
        assert_snapshot!(generate(quote! {
            trait Ping { async fn ping(&self) -> u64; }
        }));
    }

    #[test]
    fn unit_return() {
        assert_snapshot!(generate(quote! {
            trait Notifier { async fn notify(&self, msg: String); }
        }));
    }

    #[test]
    fn streaming_tx() {
        assert_snapshot!(generate(quote! {
            trait Streamer { async fn count_up(&self, start: i32, output: Tx<i32>) -> i32; }
        }));
    }

    #[test]
    fn rejects_channels_in_return_type() {
        let message = rejection_message(quote! {
            trait Streamer { async fn stream(&self) -> Rx<i32>; }
        });
        assert_eq!(
            message,
            "method `stream` has Channel (Tx/Rx) in return type - channels are only allowed in method arguments"
        );
    }

    #[test]
    fn rejects_non_roam_return_lifetime() {
        let message = rejection_message(quote! {
            trait Svc { async fn bad(&self) -> &'a str; }
        });
        assert_eq!(
            message,
            "method `bad` return type may only use lifetime `'roam` for borrowed response data"
        );
    }

    #[test]
    fn rejects_elided_return_lifetime() {
        let message = rejection_message(quote! {
            trait Svc { async fn bad(&self) -> &str; }
        });
        assert_eq!(
            message,
            "method `bad` return type uses an elided reference lifetime; use explicit `'roam` (for example `&'roam str`)"
        );
    }

    #[test]
    fn rejects_borrowed_error_type() {
        let message = rejection_message(quote! {
            trait Svc { async fn bad(&self) -> Result<u32, &'roam str>; }
        });
        assert_eq!(
            message,
            "method `bad` error type must be owned (no lifetimes), because client errors are not wrapped in SelfRef"
        );
    }

    #[test]
    fn borrowed_roam_return() {
        assert_snapshot!(generate(quote! {
            trait Hasher { async fn hash(&self, payload: String) -> &'roam str; }
        }));
    }

    #[test]
    fn borrowed_roam_return_call_style() {
        assert_snapshot!(generate(quote! {
            trait Hasher { async fn hash(&self, payload: String) -> &'roam str; }
        }));
    }

    #[test]
    fn borrowed_roam_cow_return() {
        assert_snapshot!(generate(quote! {
            trait TextSvc {
                async fn normalize(&self, input: String) -> ::std::borrow::Cow<'roam, str>;
            }
        }));
    }

    #[test]
    fn borrowed_return_mixed_with_borrowed_args_and_channels_compiles_to_expected_shapes() {
        assert_snapshot!(generate(quote! {
            trait WordLab {
                async fn is_short(&self, word: &str) -> bool;
                async fn classify(&self, word: String) -> &'roam str;
                async fn transform(&self, prefix: &str, input: Rx<String>, output: Tx<String>) -> u32;
            }
        }));
    }
}