1#![doc = include_str!("../README.md")]
2
3use heck::{ToShoutySnakeCase, ToSnakeCase};
4use proc_macro::TokenStream;
5use proc_macro_crate::{FoundCrate, crate_name};
6use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8
9mod parser;
10
11use parser::{Error as MacroError, ParsedTrait, join_doc_lines, parse_trait};
12
/// Computes the 32-bit wire identifier for method `method_name` of service `service_name`.
///
/// The identifier is the 64-bit FNV-1a hash of the bytes of
/// `"{service_name}.{method_name}"`, folded down to 32 bits by XOR-ing the high
/// and low halves. It depends only on the two names, so every expansion of the
/// macro computes the same ID for the same method.
fn compute_method_id(service_name: &str, method_name: &str) -> u32 {
    // Standard 64-bit FNV-1a parameters.
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    let qualified_name = service_name
        .bytes()
        .chain(std::iter::once(b'.'))
        .chain(method_name.bytes());

    let digest = qualified_name.fold(FNV_OFFSET, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    // Fold the 64-bit digest into 32 bits so both halves contribute.
    (digest ^ (digest >> 32)) as u32
}
41
42#[proc_macro_attribute]
73pub fn service(_attr: TokenStream, item: TokenStream) -> TokenStream {
74 let trait_tokens = TokenStream2::from(item.clone());
75
76 let parsed_trait = match parse_trait(&trait_tokens) {
77 Ok(parsed) => parsed,
78 Err(err) => return err.to_compile_error().into(),
79 };
80
81 match generate_service(&parsed_trait) {
82 Ok(tokens) => tokens.into(),
83 Err(err) => err.to_compile_error().into(),
84 }
85}
86
87fn generate_service(input: &ParsedTrait) -> Result<TokenStream2, MacroError> {
88 let trait_name = &input.ident;
89 let trait_name_str = trait_name.to_string();
90 let trait_snake = trait_name_str.to_snake_case();
91 let trait_shouty = trait_name_str.to_shouty_snake_case();
92 let vis = &input.vis_tokens;
93
94 let rapace_crate = match crate_name("rapace") {
99 Ok(FoundCrate::Itself) => quote!(rapace),
100 Ok(FoundCrate::Name(name)) => {
101 let ident = Ident::new(&name, Span::call_site());
102 quote!(#ident)
103 }
104 Err(_) => {
105 if crate_name("rapace_core").is_ok() {
107 return Err(MacroError::new(
110 Span::call_site(),
111 "Internal crates using rapace_macros must add `rapace` as a dependency, \
112 or you can create a facade module. See rapace-testkit for an example.",
113 ));
114 } else {
115 return Err(MacroError::new(
116 Span::call_site(),
117 "rapace crate not found in dependencies. Add `rapace = \"...\"` to your Cargo.toml",
118 ));
119 }
120 }
121 };
122
123 let trait_doc = join_doc_lines(&input.doc_lines);
125
126 let trait_doc_attr = if trait_doc.is_empty() {
134 quote! {}
135 } else {
136 quote! { #[doc = #trait_doc] }
137 };
138 let rewritten_methods = input.methods.iter().map(|m| {
139 let method_name = &m.name;
140 let method_doc = join_doc_lines(&m.doc_lines);
141 let method_doc_attr = if method_doc.is_empty() {
142 quote! {}
143 } else {
144 quote! { #[doc = #method_doc] }
145 };
146 let args = m.args.iter().map(|a| {
147 let name = &a.name;
148 let ty = &a.ty;
149 quote! { #name: #ty }
150 });
151 let return_type = &m.return_type;
152 quote! {
153 #method_doc_attr
154 fn #method_name(&self, #(#args),*) -> impl ::std::future::Future<Output = #return_type> + Send + '_;
155 }
156 });
157 let trait_tokens = quote! {
158 #[allow(clippy::type_complexity)]
159 #trait_doc_attr
160 #vis trait #trait_name {
161 #(#rewritten_methods)*
162 }
163 };
164
165 let methods: Vec<MethodInfo> = input
166 .methods
167 .iter()
168 .map(MethodInfo::try_from_parsed)
169 .collect::<Result<_, _>>()?;
170
171 let client_name = format_ident!("{}Client", trait_name);
172 let server_name = format_ident!("{}Server", trait_name);
173
174 let client_methods_hardcoded = methods.iter().map(|m| {
176 let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
177 generate_client_method(m, method_id, &trait_name_str, &rapace_crate)
178 });
179
180 let client_methods_registry = methods
182 .iter()
183 .enumerate()
184 .map(|(idx, m)| generate_client_method_registry(m, idx, &trait_name_str, &rapace_crate));
185
186 let dispatch_arms = methods.iter().map(|m| {
188 let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
189 generate_dispatch_arm(m, method_id, &rapace_crate)
190 });
191
192 let streaming_dispatch_arms = methods.iter().map(|m| {
194 let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
195 generate_streaming_dispatch_arm(m, method_id, &rapace_crate)
196 });
197
198 let is_streaming_method_fn = quote! {
203 fn __is_streaming_method_id(method_id: u32) -> bool {
204 ::#rapace_crate::registry::ServiceRegistry::with_global(|reg| {
205 reg.method_by_id(::#rapace_crate::registry::MethodId(method_id))
206 .map(|m| m.is_streaming)
207 .unwrap_or(false)
208 })
209 }
210 };
211
212 let method_id_consts = methods.iter().map(|m| {
214 let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
215 let method_shouty = m.name.to_string().to_shouty_snake_case();
216 let const_name = format_ident!("{}_METHOD_ID_{}", trait_shouty, method_shouty);
217 quote! {
218 #vis const #const_name: u32 = #method_id;
219 }
220 });
221
222 let register_fn_name = format_ident!("{}_register", trait_snake);
224 let register_fn = generate_register_fn(
225 &trait_name_str,
226 &trait_doc,
227 &methods,
228 &rapace_crate,
229 ®ister_fn_name,
230 vis,
231 );
232
233 let registry_client_name = format_ident!("{}RegistryClient", trait_name);
235 let method_id_fields: Vec<_> = methods
236 .iter()
237 .map(|m| {
238 let field_name = format_ident!("{}_method_id", m.name);
239 quote! { #field_name: u32 }
240 })
241 .collect();
242 let method_id_lookups: Vec<_> = methods
243 .iter()
244 .map(|m| {
245 let field_name = format_ident!("{}_method_id", m.name);
246 let method_name = m.name.to_string();
247 quote! {
248 #field_name: registry.resolve_method_id(#trait_name_str, #method_name)
249 .expect(concat!("method ", #method_name, " not found in registry"))
250 .0
251 }
252 })
253 .collect();
254
255 let expanded = quote! {
256 #trait_tokens
258
259 #(#method_id_consts)*
260
261 #register_fn
262
263 #vis struct #client_name {
282 session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>,
283 }
284
285 impl #client_name {
286 pub fn new(session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>) -> Self {
294 Self { session }
295 }
296
297 pub fn session(&self) -> &::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession> {
299 &self.session
300 }
301
302 #(#client_methods_hardcoded)*
303 }
304
305 #vis struct #registry_client_name {
312 session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>,
313 #(pub #method_id_fields,)*
314 }
315
316 impl #registry_client_name {
317 pub fn new(session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>, registry: &::#rapace_crate::registry::ServiceRegistry) -> Self {
328 Self {
329 session,
330 #(#method_id_lookups,)*
331 }
332 }
333
334 pub fn session(&self) -> &::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession> {
336 &self.session
337 }
338
339 #(#client_methods_registry)*
340 }
341
342 #vis struct #server_name<S> {
348 service: S,
349 }
350
351 impl<S: #trait_name + Send + Sync + 'static> #server_name<S> {
352 fn __auto_register() {
357 use ::std::sync::OnceLock;
358 static REGISTERED: OnceLock<()> = OnceLock::new();
359
360 REGISTERED.get_or_init(|| {
361 ::#rapace_crate::registry::ServiceRegistry::with_global_mut(|registry| {
362 #register_fn_name(registry);
363 });
364 });
365 }
366
367 pub fn new(service: S) -> Self {
372 Self::__auto_register();
373 Self { service }
374 }
375
376 pub async fn serve(
388 self,
389 transport: ::#rapace_crate::rapace_core::Transport,
390 ) -> ::std::result::Result<(), ::#rapace_crate::rapace_core::RpcError> {
391 ::#rapace_crate::tracing::debug!("serve: entering loop, waiting for requests");
392 loop {
393 let request = match transport.recv_frame().await {
395 Ok(frame) => {
396 ::#rapace_crate::tracing::debug!(
397 method_id = frame.desc.method_id,
398 channel_id = frame.desc.channel_id,
399 flags = ?frame.desc.flags,
400 payload_len = frame.payload_bytes().len(),
401 "serve: received frame"
402 );
403 frame
404 }
405 Err(::#rapace_crate::rapace_core::TransportError::Closed) => {
406 ::#rapace_crate::tracing::debug!("serve: transport closed");
407 return Ok(());
409 }
410 Err(e) => {
411 ::#rapace_crate::tracing::error!(?e, "serve: transport error");
412 return Err(::#rapace_crate::rapace_core::RpcError::Transport(e));
413 }
414 };
415
416 if !request.desc.flags.contains(::#rapace_crate::rapace_core::FrameFlags::DATA) {
418 ::#rapace_crate::tracing::debug!("serve: skipping non-DATA frame");
419 continue;
420 }
421
422 ::#rapace_crate::tracing::debug!(
424 method_id = request.desc.method_id,
425 channel_id = request.desc.channel_id,
426 "serve: dispatching to dispatch_streaming"
427 );
428 if let Err(e) = self.dispatch_streaming(
429 request.desc.method_id,
430 request.desc.channel_id,
431 request.payload_bytes(),
432 &transport,
433 ).await {
434 ::#rapace_crate::tracing::error!(?e, "serve: dispatch_streaming returned error");
435 let mut desc = ::#rapace_crate::rapace_core::MsgDescHot::new();
437 desc.channel_id = request.desc.channel_id;
438 desc.flags = ::#rapace_crate::rapace_core::FrameFlags::ERROR | ::#rapace_crate::rapace_core::FrameFlags::EOS;
439
440 let (code, message): (u32, ::std::string::String) = match &e {
442 ::#rapace_crate::rapace_core::RpcError::Status { code, message } => (*code as u32, message.clone()),
443 ::#rapace_crate::rapace_core::RpcError::Transport(_) => (::#rapace_crate::rapace_core::ErrorCode::Internal as u32, "transport error".into()),
444 ::#rapace_crate::rapace_core::RpcError::Cancelled => (::#rapace_crate::rapace_core::ErrorCode::Cancelled as u32, "cancelled".into()),
445 ::#rapace_crate::rapace_core::RpcError::DeadlineExceeded => (::#rapace_crate::rapace_core::ErrorCode::DeadlineExceeded as u32, "deadline exceeded".into()),
446 };
447 let mut err_bytes = ::std::vec::Vec::with_capacity(8 + message.len());
448 err_bytes.extend_from_slice(&code.to_le_bytes());
449 err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
450 err_bytes.extend_from_slice(message.as_bytes());
451
452 let frame = ::#rapace_crate::rapace_core::Frame::with_payload(desc, err_bytes);
453 let _ = transport.send_frame(frame).await;
454 }
455 }
456 }
457
458 pub async fn serve_one(
463 &self,
464 transport: &#rapace_crate::rapace_core::Transport,
465 ) -> ::std::result::Result<(), #rapace_crate::rapace_core::RpcError> {
466 let request = transport.recv_frame().await
468 .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
469
470 if !request.desc.flags.contains(#rapace_crate::rapace_core::FrameFlags::DATA) {
472 return Ok(());
473 }
474
475 self.dispatch_streaming(
477 request.desc.method_id,
478 request.desc.channel_id,
479 request.payload_bytes(),
480 transport,
481 ).await
482 }
483
484 pub async fn dispatch(
489 &self,
490 method_id: u32,
491 request_payload: &[u8],
492 ) -> ::std::result::Result<#rapace_crate::rapace_core::Frame, #rapace_crate::rapace_core::RpcError> {
493 match method_id {
494 #(#dispatch_arms)*
495 _ => Err(#rapace_crate::rapace_core::RpcError::Status {
496 code: #rapace_crate::rapace_core::ErrorCode::Unimplemented,
497 message: ::std::format!("unknown method_id: {}", method_id),
498 }),
499 }
500 }
501
502 pub async fn dispatch_streaming(
506 &self,
507 method_id: u32,
508 channel_id: u32,
509 request_payload: &[u8],
510 transport: &#rapace_crate::rapace_core::Transport,
511 ) -> ::std::result::Result<(), #rapace_crate::rapace_core::RpcError> {
512 #rapace_crate::tracing::debug!(method_id, channel_id, "dispatch_streaming: entered");
513 match method_id {
514 #(#streaming_dispatch_arms)*
515 _ => Err(#rapace_crate::rapace_core::RpcError::Status {
516 code: #rapace_crate::rapace_core::ErrorCode::Unimplemented,
517 message: ::std::format!("unknown method_id: {}", method_id),
518 }),
519 }
520 }
521
522 #is_streaming_method_fn
523
524 pub fn into_session_dispatcher(
535 self,
536 transport: ::#rapace_crate::rapace_core::Transport,
537 ) -> impl Fn(
538 ::#rapace_crate::rapace_core::Frame,
539 ) -> ::std::pin::Pin<
540 Box<
541 dyn ::std::future::Future<
542 Output = ::std::result::Result<
543 ::#rapace_crate::rapace_core::Frame,
544 ::#rapace_crate::rapace_core::RpcError,
545 >,
546 > + Send
547 + 'static,
548 >,
549 > + Send
550 + Sync
551 + 'static {
552 use ::#rapace_crate::rapace_core::{ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError};
553
554 let server = ::std::sync::Arc::new(self);
555 move |frame: Frame| {
556 let server = server.clone();
557 let transport = transport.clone();
558 Box::pin(async move {
559 let method_id = frame.desc.method_id;
560 let channel_id = frame.desc.channel_id;
561 let flags = frame.desc.flags;
562 let payload = frame.payload_bytes().to_vec();
563
564 if Self::__is_streaming_method_id(method_id) {
565 if !flags.contains(FrameFlags::NO_REPLY) {
567 return Err(RpcError::Status {
568 code: ErrorCode::InvalidArgument,
569 message: "streaming request missing NO_REPLY flag".into(),
570 });
571 }
572
573 if let Err(err) = server
575 .dispatch_streaming(method_id, channel_id, &payload, &transport)
576 .await
577 {
578 let (code, message): (u32, String) = match &err {
581 RpcError::Status { code, message } => (*code as u32, message.clone()),
582 RpcError::Transport(_) => (ErrorCode::Internal as u32, "transport error".into()),
583 RpcError::Cancelled => (ErrorCode::Cancelled as u32, "cancelled".into()),
584 RpcError::DeadlineExceeded => (ErrorCode::DeadlineExceeded as u32, "deadline exceeded".into()),
585 };
586
587 let mut err_bytes = Vec::with_capacity(8 + message.len());
588 err_bytes.extend_from_slice(&code.to_le_bytes());
589 err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
590 err_bytes.extend_from_slice(message.as_bytes());
591
592 let mut desc = MsgDescHot::new();
593 desc.channel_id = channel_id;
594 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
595 let frame = Frame::with_payload(desc, err_bytes);
596 let _ = transport.send_frame(frame).await;
597 }
598
599 Ok(Frame::new(MsgDescHot::new()))
601 } else {
602 server.dispatch(method_id, &payload).await
603 }
604 })
605 }
606 }
607 }
608 };
609
610 Ok(expanded)
611}
612
613#[derive(Clone, Debug)]
615#[allow(clippy::large_enum_variant)]
616enum MethodKind {
617 Unary,
619 ServerStreaming {
621 item_type: TokenStream2,
623 },
624}
625
626struct MethodInfo {
627 name: Ident,
628 args: Vec<(Ident, TokenStream2)>, return_type: TokenStream2,
630 kind: MethodKind,
631 doc: String,
632}
633
634impl MethodInfo {
635 fn try_from_parsed(method: &parser::ParsedMethod) -> Result<Self, MacroError> {
636 let doc = join_doc_lines(&method.doc_lines);
637
638 let args = method
639 .args
640 .iter()
641 .map(|arg| (arg.name.clone(), arg.ty.clone()))
642 .collect();
643
644 let return_type = method.return_type.clone();
645 let kind = if let Some(item_type) = extract_streaming_return_type(&return_type) {
646 MethodKind::ServerStreaming { item_type }
647 } else {
648 MethodKind::Unary
649 };
650
651 Ok(Self {
652 name: method.name.clone(),
653 args,
654 return_type,
655 kind,
656 doc,
657 })
658 }
659}
660
661fn generate_client_method(
662 method: &MethodInfo,
663 method_id: u32,
664 service_name: &str,
665 rapace_crate: &TokenStream2,
666) -> TokenStream2 {
667 match &method.kind {
668 MethodKind::Unary => {
669 generate_client_method_unary(method, method_id, service_name, rapace_crate)
670 }
671 MethodKind::ServerStreaming { item_type } => generate_client_method_server_streaming(
672 method,
673 method_id,
674 service_name,
675 item_type,
676 rapace_crate,
677 ),
678 }
679}
680
681fn generate_client_method_registry(
682 method: &MethodInfo,
683 method_index: usize,
684 service_name: &str,
685 rapace_crate: &TokenStream2,
686) -> TokenStream2 {
687 match &method.kind {
688 MethodKind::Unary => {
689 generate_client_method_unary_registry(method, method_index, service_name, rapace_crate)
690 }
691 MethodKind::ServerStreaming { item_type } => {
692 generate_client_method_server_streaming_registry(
693 method,
694 method_index,
695 service_name,
696 item_type,
697 rapace_crate,
698 )
699 }
700 }
701}
702
703fn generate_client_method_unary(
704 method: &MethodInfo,
705 method_id: u32,
706 service_name: &str,
707 rapace_crate: &TokenStream2,
708) -> TokenStream2 {
709 let name = &method.name;
710 let method_name_str = name.to_string();
711 let return_type = &method.return_type;
712
713 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
714 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
715
716 let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
718 quote! { #name: #ty }
719 });
720
721 let encode_expr = if arg_names.is_empty() {
723 quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
724 } else if arg_names.len() == 1 {
725 let arg = &arg_names[0];
726 quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
727 } else {
728 quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
729 };
730
731 quote! {
732 pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#return_type, #rapace_crate::rapace_core::RpcError> {
734 use #rapace_crate::rapace_core::FrameFlags;
735
736 let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
738
739 let channel_id = self.session.next_channel_id();
741 #rapace_crate::tracing::debug!(
742 service = #service_name,
743 method = #method_name_str,
744 method_id = #method_id,
745 channel_id,
746 "RPC call start"
747 );
748 let response = self.session.call(channel_id, #method_id, request_bytes).await?;
749 #rapace_crate::tracing::debug!(
750 service = #service_name,
751 method = #method_name_str,
752 method_id = #method_id,
753 channel_id,
754 "RPC call complete"
755 );
756
757 if response.flags().contains(FrameFlags::ERROR) {
759 return Err(#rapace_crate::rapace_core::parse_error_payload(response.payload_bytes()));
760 }
761
762 let result: #return_type = #rapace_crate::facet_postcard::from_slice(response.payload_bytes())
764 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
765 code: #rapace_crate::rapace_core::ErrorCode::Internal,
766 message: ::std::format!("decode error: {:?}", e),
767 })?;
768
769 Ok(result)
770 }
771 }
772}
773
774fn extract_streaming_return_type(ty: &TokenStream2) -> Option<TokenStream2> {
775 let tokens: Vec<TokenTree> = ty.clone().into_iter().collect();
776 let mut index = 0;
777 while index < tokens.len() {
778 match &tokens[index] {
779 TokenTree::Ident(ident) if ident == "Streaming" => {
780 let mut search = index + 1;
781 while search < tokens.len() {
782 match &tokens[search] {
783 TokenTree::Punct(p) if p.as_char() == '<' => {
784 let inner = collect_generic_tokens(&tokens, search)?;
785 return select_stream_item_type(inner);
786 }
787 TokenTree::Punct(p) if p.as_char() == ':' => {
788 search += 1;
789 continue;
790 }
791 _ => break,
792 }
793 }
794 }
795 _ => {}
796 }
797 index += 1;
798 }
799 None
800}
801
802fn collect_generic_tokens(tokens: &[TokenTree], start: usize) -> Option<TokenStream2> {
803 let mut depth = 0usize;
804 let mut inner = TokenStream2::new();
805 let mut i = start;
806 while i < tokens.len() {
807 match &tokens[i] {
808 TokenTree::Punct(p) if p.as_char() == '<' => {
809 depth += 1;
810 if depth > 1 {
811 inner.extend(std::iter::once(tokens[i].clone()));
812 }
813 }
814 TokenTree::Punct(p) if p.as_char() == '>' => {
815 if depth == 0 {
816 return None;
817 }
818 depth -= 1;
819 if depth == 0 {
820 return Some(inner);
821 }
822 inner.extend(std::iter::once(tokens[i].clone()));
823 }
824 other => {
825 if depth >= 1 {
826 inner.extend(std::iter::once(other.clone()));
827 }
828 }
829 }
830 i += 1;
831 }
832 None
833}
834
835fn select_stream_item_type(inner: TokenStream2) -> Option<TokenStream2> {
836 let segments = split_top_level(inner, ',');
837 for segment in segments.into_iter().rev() {
838 let text = segment.to_string();
839 if text.trim().is_empty() {
840 continue;
841 }
842 if text.trim_start().starts_with('\'') {
843 continue;
844 }
845 return Some(segment);
846 }
847 None
848}
849
850fn split_top_level(tokens: TokenStream2, delimiter: char) -> Vec<TokenStream2> {
851 let mut parts = Vec::new();
852 let mut current = TokenStream2::new();
853 let mut angle_depth = 0usize;
854 for tt in tokens.into_iter() {
855 match &tt {
856 TokenTree::Punct(p) if p.as_char() == '<' => {
857 angle_depth += 1;
858 current.extend(std::iter::once(tt));
859 }
860 TokenTree::Punct(p) if p.as_char() == '>' => {
861 angle_depth = angle_depth.saturating_sub(1);
862 current.extend(std::iter::once(tt));
863 }
864 TokenTree::Punct(p) if p.as_char() == delimiter && angle_depth == 0 => {
865 parts.push(current);
866 current = TokenStream2::new();
867 continue;
868 }
869 _ => current.extend(std::iter::once(tt)),
870 }
871 }
872 parts.push(current);
873 parts
874}
875
876fn generate_client_method_server_streaming(
877 method: &MethodInfo,
878 method_id: u32,
879 service_name: &str,
880 item_type: &TokenStream2,
881 rapace_crate: &TokenStream2,
882) -> TokenStream2 {
883 let name = &method.name;
884 let method_name_str = name.to_string();
885
886 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
887 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
888
889 let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
891 quote! { #name: #ty }
892 });
893
894 let encode_expr = if arg_names.is_empty() {
896 quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
897 } else if arg_names.len() == 1 {
898 let arg = &arg_names[0];
899 quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
900 } else {
901 quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
902 };
903
904 quote! {
905 pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#rapace_crate::rapace_core::Streaming<#item_type>, #rapace_crate::rapace_core::RpcError> {
911 use #rapace_crate::rapace_core::{ErrorCode, RpcError};
912
913 #rapace_crate::tracing::debug!(
914 service = #service_name,
915 method = #method_name_str,
916 method_id = #method_id,
917 "RPC streaming call start"
918 );
919
920 let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
921
922 let mut rx = self.session
924 .start_streaming_call(#method_id, request_bytes)
925 .await?;
926
927 let stream = #rapace_crate::rapace_core::try_stream! {
929 while let Some(chunk) = rx.recv().await {
930 if chunk.is_error() {
932 let err = #rapace_crate::rapace_core::parse_error_payload(chunk.payload_bytes());
933 Err(err)?;
934 }
935
936 if chunk.is_eos() && chunk.payload_bytes().is_empty() {
938 break;
939 }
940
941 let item: #item_type = #rapace_crate::facet_postcard::from_slice(chunk.payload_bytes())
943 .map_err(|e| RpcError::Status {
944 code: ErrorCode::Internal,
945 message: ::std::format!("decode error: {:?}", e),
946 })?;
947
948 yield item;
949 }
950 };
951
952 Ok(::std::boxed::Box::pin(stream))
953 }
954 }
955}
956
957fn generate_dispatch_arm(
958 method: &MethodInfo,
959 method_id: u32,
960 rapace_crate: &TokenStream2,
961) -> TokenStream2 {
962 match &method.kind {
963 MethodKind::Unary => generate_dispatch_arm_unary(method, method_id, rapace_crate),
964 MethodKind::ServerStreaming { .. } => {
965 quote! {
968 #method_id => {
969 Err(#rapace_crate::rapace_core::RpcError::Status {
970 code: #rapace_crate::rapace_core::ErrorCode::Internal,
971 message: "streaming method called via unary dispatch".into(),
972 })
973 }
974 }
975 }
976 }
977}
978
979fn generate_streaming_dispatch_arm(
980 method: &MethodInfo,
981 method_id: u32,
982 rapace_crate: &TokenStream2,
983) -> TokenStream2 {
984 match &method.kind {
985 MethodKind::Unary => {
986 let name = &method.name;
988 let return_type = &method.return_type;
989 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
990 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
991
992 let decode_and_call = if arg_names.is_empty() {
993 quote! {
994 let result: #return_type = self.service.#name().await;
995 }
996 } else if arg_names.len() == 1 {
997 let arg = &arg_names[0];
998 let ty = &arg_types[0];
999 quote! {
1000 let #arg: #ty = #rapace_crate::facet_postcard::from_slice(request_payload)
1001 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1002 code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1003 message: ::std::format!("decode error: {:?}", e),
1004 })?;
1005 let result: #return_type = self.service.#name(#arg).await;
1006 }
1007 } else {
1008 let tuple_type = quote! { (#(#arg_types),*) };
1009 quote! {
1010 let (#(#arg_names),*): #tuple_type = #rapace_crate::facet_postcard::from_slice(request_payload)
1011 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1012 code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1013 message: ::std::format!("decode error: {:?}", e),
1014 })?;
1015 let result: #return_type = self.service.#name(#(#arg_names),*).await;
1016 }
1017 };
1018
1019 quote! {
1020 #method_id => {
1021 #decode_and_call
1022
1023 let response_bytes: ::std::vec::Vec<u8> = #rapace_crate::facet_postcard::to_vec(&result)
1025 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1026 code: #rapace_crate::rapace_core::ErrorCode::Internal,
1027 message: ::std::format!("encode error: {:?}", e),
1028 })?;
1029
1030 let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1031 desc.channel_id = channel_id;
1032 desc.flags = #rapace_crate::rapace_core::FrameFlags::DATA | #rapace_crate::rapace_core::FrameFlags::EOS;
1033
1034 let frame = if response_bytes.len() <= #rapace_crate::rapace_core::INLINE_PAYLOAD_SIZE {
1035 #rapace_crate::rapace_core::Frame::with_inline_payload(desc, &response_bytes)
1036 .expect("inline payload should fit")
1037 } else {
1038 #rapace_crate::rapace_core::Frame::with_payload(desc, response_bytes)
1039 };
1040
1041 transport.send_frame(frame).await
1042 .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1043 Ok(())
1044 }
1045 }
1046 }
1047 MethodKind::ServerStreaming { .. } => {
1048 generate_streaming_dispatch_arm_server_streaming(method, method_id, rapace_crate)
1049 }
1050 }
1051}
1052
1053fn generate_streaming_dispatch_arm_server_streaming(
1054 method: &MethodInfo,
1055 method_id: u32,
1056 rapace_crate: &TokenStream2,
1057) -> TokenStream2 {
1058 let name = &method.name;
1059 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1060 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1061
1062 let decode_args = if arg_names.is_empty() {
1063 quote! {}
1064 } else if arg_names.len() == 1 {
1065 let arg = &arg_names[0];
1066 let ty = &arg_types[0];
1067 quote! {
1068 let #arg: #ty = #rapace_crate::facet_postcard::from_slice(request_payload)
1069 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1070 code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1071 message: ::std::format!("decode error: {:?}", e),
1072 })?;
1073 }
1074 } else {
1075 let tuple_type = quote! { (#(#arg_types),*) };
1076 quote! {
1077 let (#(#arg_names),*): #tuple_type = #rapace_crate::facet_postcard::from_slice(request_payload)
1078 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1079 code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1080 message: ::std::format!("decode error: {:?}", e),
1081 })?;
1082 }
1083 };
1084
1085 let call_args = if arg_names.is_empty() {
1086 quote! {}
1087 } else {
1088 quote! { #(#arg_names),* }
1089 };
1090
1091 quote! {
1092 #method_id => {
1093 #decode_args
1094
1095 let mut stream = self.service.#name(#call_args).await;
1097
1098 use #rapace_crate::futures::stream::StreamExt;
1100 #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: starting to iterate stream");
1101
1102 loop {
1103 #rapace_crate::tracing::trace!(channel_id, "streaming dispatch: waiting for next item");
1104 match stream.next().await {
1105 Some(Ok(item)) => {
1106 #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: got item, encoding");
1107 let item_bytes: ::std::vec::Vec<u8> = #rapace_crate::facet_postcard::to_vec(&item)
1109 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1110 code: #rapace_crate::rapace_core::ErrorCode::Internal,
1111 message: ::std::format!("encode error: {:?}", e),
1112 })?;
1113
1114 let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1116 desc.channel_id = channel_id;
1117 desc.flags = #rapace_crate::rapace_core::FrameFlags::DATA;
1118
1119 let frame = if item_bytes.len() <= #rapace_crate::rapace_core::INLINE_PAYLOAD_SIZE {
1120 #rapace_crate::rapace_core::Frame::with_inline_payload(desc, &item_bytes)
1121 .expect("inline payload should fit")
1122 } else {
1123 #rapace_crate::rapace_core::Frame::with_payload(desc, item_bytes)
1124 };
1125
1126 #rapace_crate::tracing::debug!(channel_id, payload_len = frame.payload_bytes().len(), "streaming dispatch: sending DATA frame");
1127 transport.send_frame(frame).await
1128 .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1129 #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: DATA frame sent");
1130 }
1131 Some(Err(err)) => {
1132 #rapace_crate::tracing::warn!(channel_id, ?err, "streaming dispatch: got error from stream");
1133 let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1135 desc.channel_id = channel_id;
1136 desc.flags = #rapace_crate::rapace_core::FrameFlags::ERROR | #rapace_crate::rapace_core::FrameFlags::EOS;
1137
1138 let (code, message): (u32, &str) = match &err {
1140 #rapace_crate::rapace_core::RpcError::Status { code, message } => (*code as u32, message.as_str()),
1141 #rapace_crate::rapace_core::RpcError::Transport(_) => (#rapace_crate::rapace_core::ErrorCode::Internal as u32, "transport error"),
1142 #rapace_crate::rapace_core::RpcError::Cancelled => (#rapace_crate::rapace_core::ErrorCode::Cancelled as u32, "cancelled"),
1143 #rapace_crate::rapace_core::RpcError::DeadlineExceeded => (#rapace_crate::rapace_core::ErrorCode::DeadlineExceeded as u32, "deadline exceeded"),
1144 };
1145 let mut err_bytes = Vec::with_capacity(8 + message.len());
1146 err_bytes.extend_from_slice(&code.to_le_bytes());
1147 err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
1148 err_bytes.extend_from_slice(message.as_bytes());
1149
1150 let frame = #rapace_crate::rapace_core::Frame::with_payload(desc, err_bytes);
1151 transport.send_frame(frame).await
1152 .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1153 return Ok(());
1154 }
1155 None => {
1156 #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: stream ended, sending EOS");
1157 let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1159 desc.channel_id = channel_id;
1160 desc.flags = #rapace_crate::rapace_core::FrameFlags::EOS;
1161 let frame = #rapace_crate::rapace_core::Frame::new(desc);
1162 transport.send_frame(frame).await
1163 .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1164 #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: EOS sent, returning");
1165 return Ok(());
1166 }
1167 }
1168 }
1169 }
1170 }
1171}
1172
1173fn generate_dispatch_arm_unary(
1174 method: &MethodInfo,
1175 method_id: u32,
1176 rapace_crate: &TokenStream2,
1177) -> TokenStream2 {
1178 let name = &method.name;
1179 let return_type = &method.return_type;
1180 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1181 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1182
1183 let decode_and_call = if arg_names.is_empty() {
1185 quote! {
1186 let result: #return_type = self.service.#name().await;
1188 }
1189 } else if arg_names.len() == 1 {
1190 let arg = &arg_names[0];
1191 let ty = &arg_types[0];
1192 quote! {
1193 let #arg: #ty = #rapace_crate::facet_postcard::from_slice(request_payload)
1194 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1195 code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1196 message: ::std::format!("decode error: {:?}", e),
1197 })?;
1198 let result: #return_type = self.service.#name(#arg).await;
1199 }
1200 } else {
1201 let tuple_type = quote! { (#(#arg_types),*) };
1203 quote! {
1204 let (#(#arg_names),*): #tuple_type = #rapace_crate::facet_postcard::from_slice(request_payload)
1205 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1206 code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1207 message: ::std::format!("decode error: {:?}", e),
1208 })?;
1209 let result: #return_type = self.service.#name(#(#arg_names),*).await;
1210 }
1211 };
1212
1213 quote! {
1214 #method_id => {
1215 #decode_and_call
1216
1217 let response_bytes: ::std::vec::Vec<u8> = #rapace_crate::facet_postcard::to_vec(&result)
1219 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1220 code: #rapace_crate::rapace_core::ErrorCode::Internal,
1221 message: ::std::format!("encode error: {:?}", e),
1222 })?;
1223
1224 let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1226 desc.flags = #rapace_crate::rapace_core::FrameFlags::DATA | #rapace_crate::rapace_core::FrameFlags::EOS;
1227
1228 let frame = if response_bytes.len() <= #rapace_crate::rapace_core::INLINE_PAYLOAD_SIZE {
1229 #rapace_crate::rapace_core::Frame::with_inline_payload(desc, &response_bytes)
1230 .expect("inline payload should fit")
1231 } else {
1232 #rapace_crate::rapace_core::Frame::with_payload(desc, response_bytes)
1233 };
1234
1235 Ok(frame)
1236 }
1237 }
1238}
1239
1240fn generate_register_fn(
1245 service_name: &str,
1246 service_doc: &str,
1247 methods: &[MethodInfo],
1248 rapace_crate: &TokenStream2,
1249 register_fn_name: &Ident,
1250 vis: &TokenStream2,
1251) -> TokenStream2 {
1252 let method_registrations: Vec<TokenStream2> = methods
1253 .iter()
1254 .map(|m| {
1255 let method_name = m.name.to_string();
1256 let method_doc = &m.doc;
1257 let arg_types: Vec<_> = m.args.iter().map(|(_, ty)| ty).collect();
1258
1259 let arg_infos: Vec<TokenStream2> = m
1261 .args
1262 .iter()
1263 .map(|(name, ty)| {
1264 let name_str = name.to_string();
1265 let type_str = quote!(#ty).to_string();
1266 quote! {
1267 #rapace_crate::registry::ArgInfo {
1268 name: #name_str,
1269 type_name: #type_str,
1270 }
1271 }
1272 })
1273 .collect();
1274
1275 let request_shape_expr = if arg_types.is_empty() {
1277 quote! { <() as #rapace_crate::facet_core::Facet>::SHAPE }
1278 } else if arg_types.len() == 1 {
1279 let ty = &arg_types[0];
1280 quote! { <#ty as #rapace_crate::facet_core::Facet>::SHAPE }
1281 } else {
1282 quote! { <(#(#arg_types),*) as #rapace_crate::facet_core::Facet>::SHAPE }
1283 };
1284
1285 let response_shape_expr = match &m.kind {
1287 MethodKind::Unary => {
1288 let return_type = &m.return_type;
1289 quote! { <#return_type as #rapace_crate::facet_core::Facet>::SHAPE }
1290 }
1291 MethodKind::ServerStreaming { item_type } => {
1292 quote! { <#item_type as #rapace_crate::facet_core::Facet>::SHAPE }
1293 }
1294 };
1295
1296 let is_streaming = matches!(m.kind, MethodKind::ServerStreaming { .. });
1298
1299 if is_streaming {
1300 quote! {
1301 builder.add_streaming_method(
1302 #method_name,
1303 #method_doc,
1304 vec![#(#arg_infos),*],
1305 #request_shape_expr,
1306 #response_shape_expr,
1307 );
1308 }
1309 } else {
1310 quote! {
1311 builder.add_method(
1312 #method_name,
1313 #method_doc,
1314 vec![#(#arg_infos),*],
1315 #request_shape_expr,
1316 #response_shape_expr,
1317 );
1318 }
1319 }
1320 })
1321 .collect();
1322
1323 quote! {
1324 #vis fn #register_fn_name(registry: &mut #rapace_crate::registry::ServiceRegistry) {
1329 let mut builder = registry.register_service(#service_name, #service_doc);
1330 #(#method_registrations)*
1331 builder.finish();
1332 }
1333 }
1334}
1335
1336fn generate_client_method_unary_registry(
1338 method: &MethodInfo,
1339 _method_index: usize,
1340 service_name: &str,
1341 rapace_crate: &TokenStream2,
1342) -> TokenStream2 {
1343 let name = &method.name;
1344 let method_name_str = name.to_string();
1345 let return_type = &method.return_type;
1346 let method_id_field = format_ident!("{}_method_id", name);
1347
1348 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1349 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1350
1351 let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
1352 quote! { #name: #ty }
1353 });
1354
1355 let encode_expr = if arg_names.is_empty() {
1356 quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
1357 } else if arg_names.len() == 1 {
1358 let arg = &arg_names[0];
1359 quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
1360 } else {
1361 quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
1362 };
1363
1364 quote! {
1365 pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#return_type, #rapace_crate::rapace_core::RpcError> {
1367 use #rapace_crate::rapace_core::FrameFlags;
1368
1369 let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
1370
1371 let channel_id = self.session.next_channel_id();
1373 #rapace_crate::tracing::debug!(
1374 service = #service_name,
1375 method = #method_name_str,
1376 method_id = self.#method_id_field,
1377 channel_id,
1378 "RPC call start"
1379 );
1380 let response = self.session.call(channel_id, self.#method_id_field, request_bytes).await?;
1381 #rapace_crate::tracing::debug!(
1382 service = #service_name,
1383 method = #method_name_str,
1384 method_id = self.#method_id_field,
1385 channel_id,
1386 "RPC call complete"
1387 );
1388
1389 if response.flags().contains(FrameFlags::ERROR) {
1390 return Err(#rapace_crate::rapace_core::parse_error_payload(response.payload_bytes()));
1391 }
1392
1393 let result: #return_type = #rapace_crate::facet_postcard::from_slice(response.payload_bytes())
1394 .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1395 code: #rapace_crate::rapace_core::ErrorCode::Internal,
1396 message: ::std::format!("decode error: {:?}", e),
1397 })?;
1398
1399 Ok(result)
1400 }
1401 }
1402}
1403
1404fn generate_client_method_server_streaming_registry(
1406 method: &MethodInfo,
1407 _method_index: usize,
1408 service_name: &str,
1409 item_type: &TokenStream2,
1410 rapace_crate: &TokenStream2,
1411) -> TokenStream2 {
1412 let name = &method.name;
1413 let method_name_str = name.to_string();
1414 let method_id_field = format_ident!("{}_method_id", name);
1415
1416 let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1417 let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1418
1419 let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
1420 quote! { #name: #ty }
1421 });
1422
1423 let encode_expr = if arg_names.is_empty() {
1425 quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
1426 } else if arg_names.len() == 1 {
1427 let arg = &arg_names[0];
1428 quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
1429 } else {
1430 quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
1431 };
1432
1433 quote! {
1434 pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#rapace_crate::rapace_core::Streaming<#item_type>, #rapace_crate::rapace_core::RpcError> {
1440 use #rapace_crate::rapace_core::{ErrorCode, RpcError};
1441
1442 #rapace_crate::tracing::debug!(
1443 service = #service_name,
1444 method = #method_name_str,
1445 method_id = self.#method_id_field,
1446 "RPC streaming call start"
1447 );
1448
1449 let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
1450
1451 let mut rx = self.session
1453 .start_streaming_call(self.#method_id_field, request_bytes)
1454 .await?;
1455
1456 let stream = #rapace_crate::rapace_core::try_stream! {
1458 while let Some(chunk) = rx.recv().await {
1459 if chunk.is_error() {
1461 let err = #rapace_crate::rapace_core::parse_error_payload(chunk.payload_bytes());
1462 Err(err)?;
1463 }
1464
1465 if chunk.is_eos() && chunk.payload_bytes().is_empty() {
1467 break;
1468 }
1469
1470 let item: #item_type = #rapace_crate::facet_postcard::from_slice(chunk.payload_bytes())
1472 .map_err(|e| RpcError::Status {
1473 code: ErrorCode::Internal,
1474 message: ::std::format!("decode error: {:?}", e),
1475 })?;
1476
1477 yield item;
1478 }
1479 };
1480
1481 Ok(::std::boxed::Box::pin(stream))
1482 }
1483 }
1484}