rapace_macros/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use heck::{ToShoutySnakeCase, ToSnakeCase};
4use proc_macro::TokenStream;
5use proc_macro_crate::{FoundCrate, crate_name};
6use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8
9mod parser;
10
11use parser::{Error as MacroError, ParsedTrait, join_doc_lines, parse_trait};
12
/// Derive the 32-bit on-wire method ID for a service method.
///
/// The ID is the 64-bit FNV-1a hash of the bytes of `"ServiceName.method_name"`
/// (service name, a literal `.`, then the method name), folded down to 32 bits.
/// Because the ID is a pure function of the two names, no sequential assignment
/// or central registry is needed to agree on it.
fn compute_method_id(service_name: &str, method_name: &str) -> u32 {
    // FNV-1a parameters for the 64-bit variant.
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    // Feed "service", ".", "method" through one FNV-1a pass:
    // XOR the byte in first, then multiply by the prime.
    let digest = service_name
        .bytes()
        .chain(::std::iter::once(b'.'))
        .chain(method_name.bytes())
        .fold(FNV_OFFSET, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });

    // XOR-fold the upper half into the lower half so both halves contribute
    // to the truncated result.
    ((digest >> 32) ^ digest) as u32
}
41
42/// Generates RPC client and server from a trait definition.
43///
44/// # Example
45///
46/// ```ignore
47/// #[rapace::service]
48/// trait Calculator {
49///     async fn add(&self, a: i32, b: i32) -> i32;
50/// }
51///
52/// // Generated:
53/// // - CalculatorClient<T: Transport> with async fn add(&self, a: i32, b: i32) -> Result<i32, RpcError>
54/// // - CalculatorServer<S: Calculator> with dispatch method
55/// ```
56///
57/// # Streaming RPCs
58///
59/// For server-streaming, return `Streaming<T>`:
60///
61/// ```ignore
62/// use rapace_core::Streaming;
63///
64/// #[rapace::service]
65/// trait RangeService {
66///     async fn range(&self, n: u32) -> Streaming<u32>;
67/// }
68/// ```
69///
70/// The client method becomes:
71/// `async fn range(&self, n: u32) -> Result<Streaming<u32>, RpcError>`
72#[proc_macro_attribute]
73pub fn service(_attr: TokenStream, item: TokenStream) -> TokenStream {
74    let trait_tokens = TokenStream2::from(item.clone());
75
76    let parsed_trait = match parse_trait(&trait_tokens) {
77        Ok(parsed) => parsed,
78        Err(err) => return err.to_compile_error().into(),
79    };
80
81    match generate_service(&parsed_trait) {
82        Ok(tokens) => tokens.into(),
83        Err(err) => err.to_compile_error().into(),
84    }
85}
86
87fn generate_service(input: &ParsedTrait) -> Result<TokenStream2, MacroError> {
88    let trait_name = &input.ident;
89    let trait_name_str = trait_name.to_string();
90    let trait_snake = trait_name_str.to_snake_case();
91    let trait_shouty = trait_name_str.to_shouty_snake_case();
92    let vis = &input.vis_tokens;
93
94    // Detect the rapace crate name
95    // Try to find `rapace` first (the facade crate).
96    // If not found, check if we're IN the rapace crate itself.
97    // If neither, check if rapace_core is available (for internal crates).
98    let rapace_crate = match crate_name("rapace") {
99        Ok(FoundCrate::Itself) => quote!(rapace),
100        Ok(FoundCrate::Name(name)) => {
101            let ident = Ident::new(&name, Span::call_site());
102            quote!(#ident)
103        }
104        Err(_) => {
105            // rapace not found - check if this is an internal crate with direct dependencies
106            if crate_name("rapace_core").is_ok() {
107                // We have rapace_core - this is likely an internal crate
108                // Create a local rapace module that re-exports what we need
109                return Err(MacroError::new(
110                    Span::call_site(),
111                    "Internal crates using rapace_macros must add `rapace` as a dependency, \
112                     or you can create a facade module. See rapace-testkit for an example.",
113                ));
114            } else {
115                return Err(MacroError::new(
116                    Span::call_site(),
117                    "rapace crate not found in dependencies. Add `rapace = \"...\"` to your Cargo.toml",
118                ));
119            }
120        }
121    };
122
123    // Capture trait doc comments
124    let trait_doc = join_doc_lines(&input.doc_lines);
125
126    // Rewrite the user-authored `async fn` trait into an RPITIT trait that guarantees
127    // `Send` futures.
128    //
129    // Why: `RpcSession` dispatch uses `tokio::spawn`, which requires the dispatcher
130    // future to be `Send`. Boxing would satisfy this but is ergonomically awful for
131    // implementors. RPITIT lets implementors keep writing `async fn` while we encode
132    // the `Send` requirement in the trait contract.
133    let trait_doc_attr = if trait_doc.is_empty() {
134        quote! {}
135    } else {
136        quote! { #[doc = #trait_doc] }
137    };
138    let rewritten_methods = input.methods.iter().map(|m| {
139        let method_name = &m.name;
140        let method_doc = join_doc_lines(&m.doc_lines);
141        let method_doc_attr = if method_doc.is_empty() {
142            quote! {}
143        } else {
144            quote! { #[doc = #method_doc] }
145        };
146        let args = m.args.iter().map(|a| {
147            let name = &a.name;
148            let ty = &a.ty;
149            quote! { #name: #ty }
150        });
151        let return_type = &m.return_type;
152        quote! {
153            #method_doc_attr
154            fn #method_name(&self, #(#args),*) -> impl ::std::future::Future<Output = #return_type> + Send + '_;
155        }
156    });
157    let trait_tokens = quote! {
158        #[allow(clippy::type_complexity)]
159        #trait_doc_attr
160        #vis trait #trait_name {
161            #(#rewritten_methods)*
162        }
163    };
164
165    let methods: Vec<MethodInfo> = input
166        .methods
167        .iter()
168        .map(MethodInfo::try_from_parsed)
169        .collect::<Result<_, _>>()?;
170
171    let client_name = format_ident!("{}Client", trait_name);
172    let server_name = format_ident!("{}Server", trait_name);
173
174    // Generate client methods with hashed method IDs
175    let client_methods_hardcoded = methods.iter().map(|m| {
176        let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
177        generate_client_method(m, method_id, &trait_name_str, &rapace_crate)
178    });
179
180    // Generate client methods that use stored method IDs from registry
181    let client_methods_registry = methods
182        .iter()
183        .enumerate()
184        .map(|(idx, m)| generate_client_method_registry(m, idx, &trait_name_str, &rapace_crate));
185
186    // Generate server dispatch arms (for unary and error fallback)
187    let dispatch_arms = methods.iter().map(|m| {
188        let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
189        generate_dispatch_arm(m, method_id, &rapace_crate)
190    });
191
192    // Generate streaming dispatch arms
193    let streaming_dispatch_arms = methods.iter().map(|m| {
194        let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
195        generate_streaming_dispatch_arm(m, method_id, &rapace_crate)
196    });
197
198    // Generate a helper that identifies streaming method IDs by consulting the registry.
199    //
200    // This avoids maintaining a separate "streaming method id set" in generated code and
201    // keeps the answer consistent with the on-wire method id space.
202    let is_streaming_method_fn = quote! {
203        fn __is_streaming_method_id(method_id: u32) -> bool {
204            ::#rapace_crate::registry::ServiceRegistry::with_global(|reg| {
205                reg.method_by_id(::#rapace_crate::registry::MethodId(method_id))
206                    .map(|m| m.is_streaming)
207                    .unwrap_or(false)
208            })
209        }
210    };
211
212    // Generate method ID constants
213    let method_id_consts = methods.iter().map(|m| {
214        let method_id = compute_method_id(&trait_name_str, &m.name.to_string());
215        let method_shouty = m.name.to_string().to_shouty_snake_case();
216        let const_name = format_ident!("{}_METHOD_ID_{}", trait_shouty, method_shouty);
217        quote! {
218            #vis const #const_name: u32 = #method_id;
219        }
220    });
221
222    // Generate registry registration code
223    let register_fn_name = format_ident!("{}_register", trait_snake);
224    let register_fn = generate_register_fn(
225        &trait_name_str,
226        &trait_doc,
227        &methods,
228        &rapace_crate,
229        &register_fn_name,
230        vis,
231    );
232
233    // Generate registry-aware client struct and constructor
234    let registry_client_name = format_ident!("{}RegistryClient", trait_name);
235    let method_id_fields: Vec<_> = methods
236        .iter()
237        .map(|m| {
238            let field_name = format_ident!("{}_method_id", m.name);
239            quote! { #field_name: u32 }
240        })
241        .collect();
242    let method_id_lookups: Vec<_> = methods
243        .iter()
244        .map(|m| {
245            let field_name = format_ident!("{}_method_id", m.name);
246            let method_name = m.name.to_string();
247            quote! {
248                #field_name: registry.resolve_method_id(#trait_name_str, #method_name)
249                    .expect(concat!("method ", #method_name, " not found in registry"))
250                    .0
251            }
252        })
253        .collect();
254
255    let expanded = quote! {
256        // Rewritten Send-future trait
257        #trait_tokens
258
259        #(#method_id_consts)*
260
261        #register_fn
262
263        /// Client stub for the #trait_name service.
264        ///
265        /// This client uses hardcoded method IDs (1, 2, ...) and expects an
266        /// [`Arc<RpcSession>`](::std::sync::Arc) whose
267        /// [`run`](::#rapace_crate::rapace_core::RpcSession::run) task is already
268        /// running. Construct sessions with [`RpcSession::with_channel_start`](::#rapace_crate::rapace_core::RpcSession::with_channel_start) to
269        /// coordinate odd/even channel IDs when both peers initiate RPCs.
270        /// For multi-service scenarios where method IDs must be globally unique,
271        /// use [`#registry_client_name`] instead.
272        ///
273        /// # Usage
274        ///
275        /// ```ignore
276        /// let session = Arc::new(RpcSession::new(transport));
277        /// tokio::spawn(session.clone().run()); // Start the demux loop
278        /// let client = FooClient::new(session);
279        /// let result = client.some_method(args).await?;
280        /// ```
281        #vis struct #client_name {
282            session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>,
283        }
284
285        impl #client_name {
286            /// Create a new client with the given RPC session.
287            ///
288            /// Uses compile-time, on-wire method IDs (hashed `Service.method`).
289            /// For registry-resolved method IDs, use [`#registry_client_name::new`].
290            ///
291            /// The provided session must be shared (`Arc::clone`) with the call site
292            /// and have its demux loop (`tokio::spawn(session.clone().run())`) running.
293            pub fn new(session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>) -> Self {
294                Self { session }
295            }
296
297            /// Get a reference to the underlying session.
298            pub fn session(&self) -> &::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession> {
299                &self.session
300            }
301
302            #(#client_methods_hardcoded)*
303        }
304
305        /// Registry-aware client stub for the #trait_name service.
306        ///
307        /// This client resolves method IDs from a [`ServiceRegistry`] at construction time.
308        /// This can be useful when you want to validate the service/methods are registered
309        /// (or when building tooling around introspection).
310        /// It has the same [`RpcSession`](::#rapace_crate::rapace_core::RpcSession) requirements as [`#client_name`].
311        #vis struct #registry_client_name {
312            session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>,
313            #(pub #method_id_fields,)*
314        }
315
316        impl #registry_client_name {
317            /// Create a new registry-aware client.
318            ///
319            /// Looks up method IDs from the registry. The service must be registered
320            /// in the registry before calling this constructor.
321            ///
322            /// The session's demux loop (`session.run()`) must be running for RPC calls to work.
323            ///
324            /// # Panics
325            ///
326            /// Panics if the service or any of its methods are not found in the registry.
327            pub fn new(session: ::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession>, registry: &::#rapace_crate::registry::ServiceRegistry) -> Self {
328                Self {
329                    session,
330                    #(#method_id_lookups,)*
331                }
332            }
333
334            /// Get a reference to the underlying session.
335            pub fn session(&self) -> &::std::sync::Arc<::#rapace_crate::rapace_core::RpcSession> {
336                &self.session
337            }
338
339            #(#client_methods_registry)*
340        }
341
342        /// Server dispatcher for the #trait_name service.
343        ///
344        /// Integrate this with an [`RpcSession`](::#rapace_crate::rapace_core::RpcSession)
345        /// by calling [`RpcSession::set_dispatcher`](::#rapace_crate::rapace_core::RpcSession::set_dispatcher)
346        /// and forwarding `method_id`/`payload` into [`dispatch`] or [`dispatch_streaming`].
347        #vis struct #server_name<S> {
348            service: S,
349        }
350
351        impl<S: #trait_name + Send + Sync + 'static> #server_name<S> {
352            /// Auto-register this service in the global registry.
353            ///
354            /// Called automatically from `new()`. Uses `OnceCell` to ensure registration
355            /// happens exactly once, even if multiple server instances are created.
356            fn __auto_register() {
357                use ::std::sync::OnceLock;
358                static REGISTERED: OnceLock<()> = OnceLock::new();
359
360                REGISTERED.get_or_init(|| {
361                    ::#rapace_crate::registry::ServiceRegistry::with_global_mut(|registry| {
362                        #register_fn_name(registry);
363                    });
364                });
365            }
366
367            /// Create a new server with the given service implementation.
368            ///
369            /// This automatically registers the service in the global registry
370            /// on first invocation (subsequent calls are no-ops).
371            pub fn new(service: S) -> Self {
372                Self::__auto_register();
373                Self { service }
374            }
375
376            /// Serve requests from the transport until the connection closes.
377            ///
378            /// This is the main server loop. It reads frames from the transport,
379            /// dispatches them to the appropriate method, and sends responses.
380            ///
381            /// # Example
382            ///
383            /// ```ignore
384            /// let server = CalculatorServer::new(CalculatorImpl);
385            /// server.serve(transport).await?;
386            /// ```
387            pub async fn serve(
388                self,
389                transport: ::#rapace_crate::rapace_core::Transport,
390            ) -> ::std::result::Result<(), ::#rapace_crate::rapace_core::RpcError> {
391                ::#rapace_crate::tracing::debug!("serve: entering loop, waiting for requests");
392                loop {
393                    // Receive next request frame
394                    let request = match transport.recv_frame().await {
395                        Ok(frame) => {
396                            ::#rapace_crate::tracing::debug!(
397                                method_id = frame.desc.method_id,
398                                channel_id = frame.desc.channel_id,
399                                flags = ?frame.desc.flags,
400                                payload_len = frame.payload_bytes().len(),
401                                "serve: received frame"
402                            );
403                            frame
404                        }
405                        Err(::#rapace_crate::rapace_core::TransportError::Closed) => {
406                            ::#rapace_crate::tracing::debug!("serve: transport closed");
407                            // Connection closed gracefully
408                            return Ok(());
409                        }
410                        Err(e) => {
411                            ::#rapace_crate::tracing::error!(?e, "serve: transport error");
412                            return Err(::#rapace_crate::rapace_core::RpcError::Transport(e));
413                        }
414                    };
415
416                    // Skip non-data frames (control frames, etc.)
417                    if !request.desc.flags.contains(::#rapace_crate::rapace_core::FrameFlags::DATA) {
418                        ::#rapace_crate::tracing::debug!("serve: skipping non-DATA frame");
419                        continue;
420                    }
421
422                    // Dispatch the request
423                    ::#rapace_crate::tracing::debug!(
424                        method_id = request.desc.method_id,
425                        channel_id = request.desc.channel_id,
426                        "serve: dispatching to dispatch_streaming"
427                    );
428                    if let Err(e) = self.dispatch_streaming(
429                        request.desc.method_id,
430                        request.desc.channel_id,
431                        request.payload_bytes(),
432                        &transport,
433                    ).await {
434                        ::#rapace_crate::tracing::error!(?e, "serve: dispatch_streaming returned error");
435                        // Send error response
436                        let mut desc = ::#rapace_crate::rapace_core::MsgDescHot::new();
437                        desc.channel_id = request.desc.channel_id;
438                        desc.flags = ::#rapace_crate::rapace_core::FrameFlags::ERROR | ::#rapace_crate::rapace_core::FrameFlags::EOS;
439
440                        // Encode error: [code: u32 LE][message_len: u32 LE][message bytes]
441                        let (code, message): (u32, ::std::string::String) = match &e {
442                            ::#rapace_crate::rapace_core::RpcError::Status { code, message } => (*code as u32, message.clone()),
443                            ::#rapace_crate::rapace_core::RpcError::Transport(_) => (::#rapace_crate::rapace_core::ErrorCode::Internal as u32, "transport error".into()),
444                            ::#rapace_crate::rapace_core::RpcError::Cancelled => (::#rapace_crate::rapace_core::ErrorCode::Cancelled as u32, "cancelled".into()),
445                            ::#rapace_crate::rapace_core::RpcError::DeadlineExceeded => (::#rapace_crate::rapace_core::ErrorCode::DeadlineExceeded as u32, "deadline exceeded".into()),
446                        };
447                        let mut err_bytes = ::std::vec::Vec::with_capacity(8 + message.len());
448                        err_bytes.extend_from_slice(&code.to_le_bytes());
449                        err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
450                        err_bytes.extend_from_slice(message.as_bytes());
451
452                        let frame = ::#rapace_crate::rapace_core::Frame::with_payload(desc, err_bytes);
453                        let _ = transport.send_frame(frame).await;
454                    }
455                }
456            }
457
458            /// Serve a single request from the transport.
459            ///
460            /// This is useful for testing or when you want to handle each request
461            /// individually.
462            pub async fn serve_one(
463                &self,
464                transport: &#rapace_crate::rapace_core::Transport,
465            ) -> ::std::result::Result<(), #rapace_crate::rapace_core::RpcError> {
466                // Receive next request frame
467                let request = transport.recv_frame().await
468                    .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
469
470                // Skip non-data frames
471                if !request.desc.flags.contains(#rapace_crate::rapace_core::FrameFlags::DATA) {
472                    return Ok(());
473                }
474
475                // Dispatch the request
476                self.dispatch_streaming(
477                    request.desc.method_id,
478                    request.desc.channel_id,
479                    request.payload_bytes(),
480                    transport,
481                ).await
482            }
483
484            /// Dispatch a request frame to the appropriate method.
485            ///
486            /// Returns a response frame on success for unary methods.
487            /// For streaming methods, use `dispatch_streaming` instead.
488            pub async fn dispatch(
489                &self,
490                method_id: u32,
491                request_payload: &[u8],
492            ) -> ::std::result::Result<#rapace_crate::rapace_core::Frame, #rapace_crate::rapace_core::RpcError> {
493                match method_id {
494                    #(#dispatch_arms)*
495                    _ => Err(#rapace_crate::rapace_core::RpcError::Status {
496                        code: #rapace_crate::rapace_core::ErrorCode::Unimplemented,
497                        message: ::std::format!("unknown method_id: {}", method_id),
498                    }),
499                }
500            }
501
502            /// Dispatch a streaming request to the appropriate method.
503            ///
504            /// The method sends frames via the provided transport.
505            pub async fn dispatch_streaming(
506                &self,
507                method_id: u32,
508                channel_id: u32,
509                request_payload: &[u8],
510                transport: &#rapace_crate::rapace_core::Transport,
511            ) -> ::std::result::Result<(), #rapace_crate::rapace_core::RpcError> {
512                #rapace_crate::tracing::debug!(method_id, channel_id, "dispatch_streaming: entered");
513                match method_id {
514                    #(#streaming_dispatch_arms)*
515                    _ => Err(#rapace_crate::rapace_core::RpcError::Status {
516                        code: #rapace_crate::rapace_core::ErrorCode::Unimplemented,
517                        message: ::std::format!("unknown method_id: {}", method_id),
518                    }),
519                }
520            }
521
522            #is_streaming_method_fn
523
524            /// Create a dispatcher closure suitable for `RpcSession::set_dispatcher`.
525            ///
526            /// This handles both unary and server-streaming methods:
527            /// - Unary methods return a single response `Frame`.
528            /// - Streaming methods require the request to be flagged `NO_REPLY` and are
529            ///   served by calling `dispatch_streaming` to emit DATA/EOS frames on the
530            ///   provided transport.
531            ///
532            /// To ensure streaming calls work correctly, use rapace's generated streaming
533            /// clients (they set `NO_REPLY` automatically).
534            pub fn into_session_dispatcher(
535                self,
536                transport: ::#rapace_crate::rapace_core::Transport,
537            ) -> impl Fn(
538                ::#rapace_crate::rapace_core::Frame,
539            ) -> ::std::pin::Pin<
540                Box<
541                    dyn ::std::future::Future<
542                            Output = ::std::result::Result<
543                                ::#rapace_crate::rapace_core::Frame,
544                                ::#rapace_crate::rapace_core::RpcError,
545                            >,
546                        > + Send
547                        + 'static,
548                >,
549            > + Send
550              + Sync
551              + 'static {
552                use ::#rapace_crate::rapace_core::{ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError};
553
554                let server = ::std::sync::Arc::new(self);
555                move |frame: Frame| {
556                    let server = server.clone();
557                    let transport = transport.clone();
558                    Box::pin(async move {
559                        let method_id = frame.desc.method_id;
560                        let channel_id = frame.desc.channel_id;
561                        let flags = frame.desc.flags;
562                        let payload = frame.payload_bytes().to_vec();
563
564                        if Self::__is_streaming_method_id(method_id) {
565                            // Enforce NO_REPLY: streaming methods do not produce a unary response frame.
566                            if !flags.contains(FrameFlags::NO_REPLY) {
567                                return Err(RpcError::Status {
568                                    code: ErrorCode::InvalidArgument,
569                                    message: "streaming request missing NO_REPLY flag".into(),
570                                });
571                            }
572
573                            // Serve the streaming method by sending DATA/EOS frames on the transport.
574                            if let Err(err) = server
575                                .dispatch_streaming(method_id, channel_id, &payload, &transport)
576                                .await
577                            {
578                                // If dispatch_streaming fails before it could send an ERROR frame,
579                                // send one here so streaming clients don't hang.
580                                let (code, message): (u32, String) = match &err {
581                                    RpcError::Status { code, message } => (*code as u32, message.clone()),
582                                    RpcError::Transport(_) => (ErrorCode::Internal as u32, "transport error".into()),
583                                    RpcError::Cancelled => (ErrorCode::Cancelled as u32, "cancelled".into()),
584                                    RpcError::DeadlineExceeded => (ErrorCode::DeadlineExceeded as u32, "deadline exceeded".into()),
585                                };
586
587                                let mut err_bytes = Vec::with_capacity(8 + message.len());
588                                err_bytes.extend_from_slice(&code.to_le_bytes());
589                                err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
590                                err_bytes.extend_from_slice(message.as_bytes());
591
592                                let mut desc = MsgDescHot::new();
593                                desc.channel_id = channel_id;
594                                desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
595                                let frame = Frame::with_payload(desc, err_bytes);
596                                let _ = transport.send_frame(frame).await;
597                            }
598
599                            // The session will ignore this because the request had NO_REPLY set.
600                            Ok(Frame::new(MsgDescHot::new()))
601                        } else {
602                            server.dispatch(method_id, &payload).await
603                        }
604                    })
605                }
606            }
607        }
608    };
609
610    Ok(expanded)
611}
612
613/// Method kind: unary or server-streaming.
614#[derive(Clone, Debug)]
615#[allow(clippy::large_enum_variant)]
616enum MethodKind {
617    /// Unary RPC: single request, single response.
618    Unary,
619    /// Server-streaming: single request, returns `Streaming<T>`.
620    ServerStreaming {
621        /// The type T in `Streaming<T>`.
622        item_type: TokenStream2,
623    },
624}
625
626struct MethodInfo {
627    name: Ident,
628    args: Vec<(Ident, TokenStream2)>, // (name, type) pairs, excluding &self
629    return_type: TokenStream2,
630    kind: MethodKind,
631    doc: String,
632}
633
634impl MethodInfo {
635    fn try_from_parsed(method: &parser::ParsedMethod) -> Result<Self, MacroError> {
636        let doc = join_doc_lines(&method.doc_lines);
637
638        let args = method
639            .args
640            .iter()
641            .map(|arg| (arg.name.clone(), arg.ty.clone()))
642            .collect();
643
644        let return_type = method.return_type.clone();
645        let kind = if let Some(item_type) = extract_streaming_return_type(&return_type) {
646            MethodKind::ServerStreaming { item_type }
647        } else {
648            MethodKind::Unary
649        };
650
651        Ok(Self {
652            name: method.name.clone(),
653            args,
654            return_type,
655            kind,
656            doc,
657        })
658    }
659}
660
661fn generate_client_method(
662    method: &MethodInfo,
663    method_id: u32,
664    service_name: &str,
665    rapace_crate: &TokenStream2,
666) -> TokenStream2 {
667    match &method.kind {
668        MethodKind::Unary => {
669            generate_client_method_unary(method, method_id, service_name, rapace_crate)
670        }
671        MethodKind::ServerStreaming { item_type } => generate_client_method_server_streaming(
672            method,
673            method_id,
674            service_name,
675            item_type,
676            rapace_crate,
677        ),
678    }
679}
680
681fn generate_client_method_registry(
682    method: &MethodInfo,
683    method_index: usize,
684    service_name: &str,
685    rapace_crate: &TokenStream2,
686) -> TokenStream2 {
687    match &method.kind {
688        MethodKind::Unary => {
689            generate_client_method_unary_registry(method, method_index, service_name, rapace_crate)
690        }
691        MethodKind::ServerStreaming { item_type } => {
692            generate_client_method_server_streaming_registry(
693                method,
694                method_index,
695                service_name,
696                item_type,
697                rapace_crate,
698            )
699        }
700    }
701}
702
703fn generate_client_method_unary(
704    method: &MethodInfo,
705    method_id: u32,
706    service_name: &str,
707    rapace_crate: &TokenStream2,
708) -> TokenStream2 {
709    let name = &method.name;
710    let method_name_str = name.to_string();
711    let return_type = &method.return_type;
712
713    let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
714    let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
715
716    // Generate the argument list for the function signature
717    let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
718        quote! { #name: #ty }
719    });
720
721    // For encoding, serialize args as a tuple using facet_postcard
722    let encode_expr = if arg_names.is_empty() {
723        quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
724    } else if arg_names.len() == 1 {
725        let arg = &arg_names[0];
726        quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
727    } else {
728        quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
729    };
730
731    quote! {
732        /// Call the #name method on the remote service.
733        pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#return_type, #rapace_crate::rapace_core::RpcError> {
734            use #rapace_crate::rapace_core::FrameFlags;
735
736            // Encode request using facet_postcard
737            let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
738
739            // Call via session
740            let channel_id = self.session.next_channel_id();
741            #rapace_crate::tracing::debug!(
742                service = #service_name,
743                method = #method_name_str,
744                method_id = #method_id,
745                channel_id,
746                "RPC call start"
747            );
748            let response = self.session.call(channel_id, #method_id, request_bytes).await?;
749            #rapace_crate::tracing::debug!(
750                service = #service_name,
751                method = #method_name_str,
752                method_id = #method_id,
753                channel_id,
754                "RPC call complete"
755            );
756
757            // Check for error flag
758            if response.flags().contains(FrameFlags::ERROR) {
759                return Err(#rapace_crate::rapace_core::parse_error_payload(response.payload_bytes()));
760            }
761
762            // Decode response using facet_postcard
763            let result: #return_type = #rapace_crate::facet_postcard::from_slice(response.payload_bytes())
764                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
765                    code: #rapace_crate::rapace_core::ErrorCode::Internal,
766                    message: ::std::format!("decode error: {:?}", e),
767                })?;
768
769            Ok(result)
770        }
771    }
772}
773
774fn extract_streaming_return_type(ty: &TokenStream2) -> Option<TokenStream2> {
775    let tokens: Vec<TokenTree> = ty.clone().into_iter().collect();
776    let mut index = 0;
777    while index < tokens.len() {
778        match &tokens[index] {
779            TokenTree::Ident(ident) if ident == "Streaming" => {
780                let mut search = index + 1;
781                while search < tokens.len() {
782                    match &tokens[search] {
783                        TokenTree::Punct(p) if p.as_char() == '<' => {
784                            let inner = collect_generic_tokens(&tokens, search)?;
785                            return select_stream_item_type(inner);
786                        }
787                        TokenTree::Punct(p) if p.as_char() == ':' => {
788                            search += 1;
789                            continue;
790                        }
791                        _ => break,
792                    }
793                }
794            }
795            _ => {}
796        }
797        index += 1;
798    }
799    None
800}
801
802fn collect_generic_tokens(tokens: &[TokenTree], start: usize) -> Option<TokenStream2> {
803    let mut depth = 0usize;
804    let mut inner = TokenStream2::new();
805    let mut i = start;
806    while i < tokens.len() {
807        match &tokens[i] {
808            TokenTree::Punct(p) if p.as_char() == '<' => {
809                depth += 1;
810                if depth > 1 {
811                    inner.extend(std::iter::once(tokens[i].clone()));
812                }
813            }
814            TokenTree::Punct(p) if p.as_char() == '>' => {
815                if depth == 0 {
816                    return None;
817                }
818                depth -= 1;
819                if depth == 0 {
820                    return Some(inner);
821                }
822                inner.extend(std::iter::once(tokens[i].clone()));
823            }
824            other => {
825                if depth >= 1 {
826                    inner.extend(std::iter::once(other.clone()));
827                }
828            }
829        }
830        i += 1;
831    }
832    None
833}
834
835fn select_stream_item_type(inner: TokenStream2) -> Option<TokenStream2> {
836    let segments = split_top_level(inner, ',');
837    for segment in segments.into_iter().rev() {
838        let text = segment.to_string();
839        if text.trim().is_empty() {
840            continue;
841        }
842        if text.trim_start().starts_with('\'') {
843            continue;
844        }
845        return Some(segment);
846    }
847    None
848}
849
850fn split_top_level(tokens: TokenStream2, delimiter: char) -> Vec<TokenStream2> {
851    let mut parts = Vec::new();
852    let mut current = TokenStream2::new();
853    let mut angle_depth = 0usize;
854    for tt in tokens.into_iter() {
855        match &tt {
856            TokenTree::Punct(p) if p.as_char() == '<' => {
857                angle_depth += 1;
858                current.extend(std::iter::once(tt));
859            }
860            TokenTree::Punct(p) if p.as_char() == '>' => {
861                angle_depth = angle_depth.saturating_sub(1);
862                current.extend(std::iter::once(tt));
863            }
864            TokenTree::Punct(p) if p.as_char() == delimiter && angle_depth == 0 => {
865                parts.push(current);
866                current = TokenStream2::new();
867                continue;
868            }
869            _ => current.extend(std::iter::once(tt)),
870        }
871    }
872    parts.push(current);
873    parts
874}
875
876fn generate_client_method_server_streaming(
877    method: &MethodInfo,
878    method_id: u32,
879    service_name: &str,
880    item_type: &TokenStream2,
881    rapace_crate: &TokenStream2,
882) -> TokenStream2 {
883    let name = &method.name;
884    let method_name_str = name.to_string();
885
886    let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
887    let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
888
889    // Generate the argument list for the function signature
890    let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
891        quote! { #name: #ty }
892    });
893
894    // For encoding, serialize args as a tuple using facet_postcard
895    let encode_expr = if arg_names.is_empty() {
896        quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
897    } else if arg_names.len() == 1 {
898        let arg = &arg_names[0];
899        quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
900    } else {
901        quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
902    };
903
904    quote! {
905        /// Call the #name server-streaming method on the remote service.
906        ///
907        /// Returns a stream that yields items as they arrive from the server.
908        /// The stream ends when the server sends EOS, or yields an error if
909        /// the server sends an ERROR frame.
910        pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#rapace_crate::rapace_core::Streaming<#item_type>, #rapace_crate::rapace_core::RpcError> {
911            use #rapace_crate::rapace_core::{ErrorCode, RpcError};
912
913            #rapace_crate::tracing::debug!(
914                service = #service_name,
915                method = #method_name_str,
916                method_id = #method_id,
917                "RPC streaming call start"
918            );
919
920            let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
921
922            // Start the streaming call - this registers a tunnel and sends the request
923            let mut rx = self.session
924                .start_streaming_call(#method_id, request_bytes)
925                .await?;
926
927            // Build a Stream<Item = Result<#item_type, RpcError>> with explicit termination on EOS
928            let stream = #rapace_crate::rapace_core::try_stream! {
929                while let Some(chunk) = rx.recv().await {
930                    // Error chunk - parse and return as error
931                    if chunk.is_error() {
932                        let err = #rapace_crate::rapace_core::parse_error_payload(chunk.payload_bytes());
933                        Err(err)?;
934                    }
935
936                    // Empty EOS chunk - stream is done
937                    if chunk.is_eos() && chunk.payload_bytes().is_empty() {
938                        break;
939                    }
940
941                    // DATA chunk (possibly with EOS flag for final item) - deserialize
942                    let item: #item_type = #rapace_crate::facet_postcard::from_slice(chunk.payload_bytes())
943                        .map_err(|e| RpcError::Status {
944                            code: ErrorCode::Internal,
945                            message: ::std::format!("decode error: {:?}", e),
946                        })?;
947
948                    yield item;
949                }
950            };
951
952            Ok(::std::boxed::Box::pin(stream))
953        }
954    }
955}
956
957fn generate_dispatch_arm(
958    method: &MethodInfo,
959    method_id: u32,
960    rapace_crate: &TokenStream2,
961) -> TokenStream2 {
962    match &method.kind {
963        MethodKind::Unary => generate_dispatch_arm_unary(method, method_id, rapace_crate),
964        MethodKind::ServerStreaming { .. } => {
965            // Streaming methods are handled by dispatch_streaming, not dispatch
966            // For the dispatch() method, return error for streaming methods
967            quote! {
968                #method_id => {
969                    Err(#rapace_crate::rapace_core::RpcError::Status {
970                        code: #rapace_crate::rapace_core::ErrorCode::Internal,
971                        message: "streaming method called via unary dispatch".into(),
972                    })
973                }
974            }
975        }
976    }
977}
978
979fn generate_streaming_dispatch_arm(
980    method: &MethodInfo,
981    method_id: u32,
982    rapace_crate: &TokenStream2,
983) -> TokenStream2 {
984    match &method.kind {
985        MethodKind::Unary => {
986            // For unary methods in streaming dispatch, call the regular dispatch and send the frame
987            let name = &method.name;
988            let return_type = &method.return_type;
989            let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
990            let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
991
992            let decode_and_call = if arg_names.is_empty() {
993                quote! {
994                    let result: #return_type = self.service.#name().await;
995                }
996            } else if arg_names.len() == 1 {
997                let arg = &arg_names[0];
998                let ty = &arg_types[0];
999                quote! {
1000                    let #arg: #ty = #rapace_crate::facet_postcard::from_slice(request_payload)
1001                        .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1002                            code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1003                            message: ::std::format!("decode error: {:?}", e),
1004                        })?;
1005                    let result: #return_type = self.service.#name(#arg).await;
1006                }
1007            } else {
1008                let tuple_type = quote! { (#(#arg_types),*) };
1009                quote! {
1010                    let (#(#arg_names),*): #tuple_type = #rapace_crate::facet_postcard::from_slice(request_payload)
1011                        .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1012                            code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1013                            message: ::std::format!("decode error: {:?}", e),
1014                        })?;
1015                    let result: #return_type = self.service.#name(#(#arg_names),*).await;
1016                }
1017            };
1018
1019            quote! {
1020                #method_id => {
1021                    #decode_and_call
1022
1023                    // Encode and send response frame
1024                    let response_bytes: ::std::vec::Vec<u8> = #rapace_crate::facet_postcard::to_vec(&result)
1025                        .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1026                            code: #rapace_crate::rapace_core::ErrorCode::Internal,
1027                            message: ::std::format!("encode error: {:?}", e),
1028                        })?;
1029
1030                    let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1031                    desc.channel_id = channel_id;
1032                    desc.flags = #rapace_crate::rapace_core::FrameFlags::DATA | #rapace_crate::rapace_core::FrameFlags::EOS;
1033
1034                    let frame = if response_bytes.len() <= #rapace_crate::rapace_core::INLINE_PAYLOAD_SIZE {
1035                        #rapace_crate::rapace_core::Frame::with_inline_payload(desc, &response_bytes)
1036                            .expect("inline payload should fit")
1037                    } else {
1038                        #rapace_crate::rapace_core::Frame::with_payload(desc, response_bytes)
1039                    };
1040
1041                    transport.send_frame(frame).await
1042                        .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1043                    Ok(())
1044                }
1045            }
1046        }
1047        MethodKind::ServerStreaming { .. } => {
1048            generate_streaming_dispatch_arm_server_streaming(method, method_id, rapace_crate)
1049        }
1050    }
1051}
1052
1053fn generate_streaming_dispatch_arm_server_streaming(
1054    method: &MethodInfo,
1055    method_id: u32,
1056    rapace_crate: &TokenStream2,
1057) -> TokenStream2 {
1058    let name = &method.name;
1059    let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1060    let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1061
1062    let decode_args = if arg_names.is_empty() {
1063        quote! {}
1064    } else if arg_names.len() == 1 {
1065        let arg = &arg_names[0];
1066        let ty = &arg_types[0];
1067        quote! {
1068            let #arg: #ty = #rapace_crate::facet_postcard::from_slice(request_payload)
1069                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1070                    code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1071                    message: ::std::format!("decode error: {:?}", e),
1072                })?;
1073        }
1074    } else {
1075        let tuple_type = quote! { (#(#arg_types),*) };
1076        quote! {
1077            let (#(#arg_names),*): #tuple_type = #rapace_crate::facet_postcard::from_slice(request_payload)
1078                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1079                    code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1080                    message: ::std::format!("decode error: {:?}", e),
1081                })?;
1082        }
1083    };
1084
1085    let call_args = if arg_names.is_empty() {
1086        quote! {}
1087    } else {
1088        quote! { #(#arg_names),* }
1089    };
1090
1091    quote! {
1092        #method_id => {
1093            #decode_args
1094
1095            // Call the service method to get a stream
1096            let mut stream = self.service.#name(#call_args).await;
1097
1098            // Iterate over the stream and send frames
1099            use #rapace_crate::futures::stream::StreamExt;
1100            #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: starting to iterate stream");
1101
1102            loop {
1103                #rapace_crate::tracing::trace!(channel_id, "streaming dispatch: waiting for next item");
1104                match stream.next().await {
1105                    Some(Ok(item)) => {
1106                        #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: got item, encoding");
1107                        // Encode item
1108                        let item_bytes: ::std::vec::Vec<u8> = #rapace_crate::facet_postcard::to_vec(&item)
1109                            .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1110                                code: #rapace_crate::rapace_core::ErrorCode::Internal,
1111                                message: ::std::format!("encode error: {:?}", e),
1112                            })?;
1113
1114                        // Send DATA frame (not EOS yet)
1115                        let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1116                        desc.channel_id = channel_id;
1117                        desc.flags = #rapace_crate::rapace_core::FrameFlags::DATA;
1118
1119                        let frame = if item_bytes.len() <= #rapace_crate::rapace_core::INLINE_PAYLOAD_SIZE {
1120                            #rapace_crate::rapace_core::Frame::with_inline_payload(desc, &item_bytes)
1121                                .expect("inline payload should fit")
1122                        } else {
1123                            #rapace_crate::rapace_core::Frame::with_payload(desc, item_bytes)
1124                        };
1125
1126                        #rapace_crate::tracing::debug!(channel_id, payload_len = frame.payload_bytes().len(), "streaming dispatch: sending DATA frame");
1127                        transport.send_frame(frame).await
1128                            .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1129                        #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: DATA frame sent");
1130                    }
1131                    Some(Err(err)) => {
1132                        #rapace_crate::tracing::warn!(channel_id, ?err, "streaming dispatch: got error from stream");
1133                        // Send ERROR frame and break
1134                        let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1135                        desc.channel_id = channel_id;
1136                        desc.flags = #rapace_crate::rapace_core::FrameFlags::ERROR | #rapace_crate::rapace_core::FrameFlags::EOS;
1137
1138                        // Encode error: [code: u32 LE][message_len: u32 LE][message bytes]
1139                        let (code, message): (u32, &str) = match &err {
1140                            #rapace_crate::rapace_core::RpcError::Status { code, message } => (*code as u32, message.as_str()),
1141                            #rapace_crate::rapace_core::RpcError::Transport(_) => (#rapace_crate::rapace_core::ErrorCode::Internal as u32, "transport error"),
1142                            #rapace_crate::rapace_core::RpcError::Cancelled => (#rapace_crate::rapace_core::ErrorCode::Cancelled as u32, "cancelled"),
1143                            #rapace_crate::rapace_core::RpcError::DeadlineExceeded => (#rapace_crate::rapace_core::ErrorCode::DeadlineExceeded as u32, "deadline exceeded"),
1144                        };
1145                        let mut err_bytes = Vec::with_capacity(8 + message.len());
1146                        err_bytes.extend_from_slice(&code.to_le_bytes());
1147                        err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
1148                        err_bytes.extend_from_slice(message.as_bytes());
1149
1150                        let frame = #rapace_crate::rapace_core::Frame::with_payload(desc, err_bytes);
1151                        transport.send_frame(frame).await
1152                            .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1153                        return Ok(());
1154                    }
1155                    None => {
1156                        #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: stream ended, sending EOS");
1157                        // Stream is complete: send EOS frame
1158                        let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1159                        desc.channel_id = channel_id;
1160                        desc.flags = #rapace_crate::rapace_core::FrameFlags::EOS;
1161                        let frame = #rapace_crate::rapace_core::Frame::new(desc);
1162                        transport.send_frame(frame).await
1163                            .map_err(#rapace_crate::rapace_core::RpcError::Transport)?;
1164                        #rapace_crate::tracing::debug!(channel_id, "streaming dispatch: EOS sent, returning");
1165                        return Ok(());
1166                    }
1167                }
1168            }
1169        }
1170    }
1171}
1172
1173fn generate_dispatch_arm_unary(
1174    method: &MethodInfo,
1175    method_id: u32,
1176    rapace_crate: &TokenStream2,
1177) -> TokenStream2 {
1178    let name = &method.name;
1179    let return_type = &method.return_type;
1180    let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1181    let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1182
1183    // Generate decode expression for args
1184    let decode_and_call = if arg_names.is_empty() {
1185        quote! {
1186            // No arguments to decode
1187            let result: #return_type = self.service.#name().await;
1188        }
1189    } else if arg_names.len() == 1 {
1190        let arg = &arg_names[0];
1191        let ty = &arg_types[0];
1192        quote! {
1193            let #arg: #ty = #rapace_crate::facet_postcard::from_slice(request_payload)
1194                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1195                    code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1196                    message: ::std::format!("decode error: {:?}", e),
1197                })?;
1198            let result: #return_type = self.service.#name(#arg).await;
1199        }
1200    } else {
1201        // Multiple args - decode as tuple
1202        let tuple_type = quote! { (#(#arg_types),*) };
1203        quote! {
1204            let (#(#arg_names),*): #tuple_type = #rapace_crate::facet_postcard::from_slice(request_payload)
1205                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1206                    code: #rapace_crate::rapace_core::ErrorCode::InvalidArgument,
1207                    message: ::std::format!("decode error: {:?}", e),
1208                })?;
1209            let result: #return_type = self.service.#name(#(#arg_names),*).await;
1210        }
1211    };
1212
1213    quote! {
1214        #method_id => {
1215            #decode_and_call
1216
1217            // Encode response using facet_postcard
1218            let response_bytes: ::std::vec::Vec<u8> = #rapace_crate::facet_postcard::to_vec(&result)
1219                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1220                    code: #rapace_crate::rapace_core::ErrorCode::Internal,
1221                    message: ::std::format!("encode error: {:?}", e),
1222                })?;
1223
1224            // Build response frame
1225            let mut desc = #rapace_crate::rapace_core::MsgDescHot::new();
1226            desc.flags = #rapace_crate::rapace_core::FrameFlags::DATA | #rapace_crate::rapace_core::FrameFlags::EOS;
1227
1228            let frame = if response_bytes.len() <= #rapace_crate::rapace_core::INLINE_PAYLOAD_SIZE {
1229                #rapace_crate::rapace_core::Frame::with_inline_payload(desc, &response_bytes)
1230                    .expect("inline payload should fit")
1231            } else {
1232                #rapace_crate::rapace_core::Frame::with_payload(desc, response_bytes)
1233            };
1234
1235            Ok(frame)
1236        }
1237    }
1238}
1239
1240/// Generate the `register` function for service registration.
1241///
1242/// This generates a function that registers the service and its methods
1243/// with a `ServiceRegistry`, capturing request/response schemas via facet.
1244fn generate_register_fn(
1245    service_name: &str,
1246    service_doc: &str,
1247    methods: &[MethodInfo],
1248    rapace_crate: &TokenStream2,
1249    register_fn_name: &Ident,
1250    vis: &TokenStream2,
1251) -> TokenStream2 {
1252    let method_registrations: Vec<TokenStream2> = methods
1253        .iter()
1254        .map(|m| {
1255            let method_name = m.name.to_string();
1256            let method_doc = &m.doc;
1257            let arg_types: Vec<_> = m.args.iter().map(|(_, ty)| ty).collect();
1258
1259            // Generate argument info
1260            let arg_infos: Vec<TokenStream2> = m
1261                .args
1262                .iter()
1263                .map(|(name, ty)| {
1264                    let name_str = name.to_string();
1265                    let type_str = quote!(#ty).to_string();
1266                    quote! {
1267                        #rapace_crate::registry::ArgInfo {
1268                            name: #name_str,
1269                            type_name: #type_str,
1270                        }
1271                    }
1272                })
1273                .collect();
1274
1275            // Request shape: tuple of arg types, or () if no args, or single type if one arg
1276            let request_shape_expr = if arg_types.is_empty() {
1277                quote! { <() as #rapace_crate::facet_core::Facet>::SHAPE }
1278            } else if arg_types.len() == 1 {
1279                let ty = &arg_types[0];
1280                quote! { <#ty as #rapace_crate::facet_core::Facet>::SHAPE }
1281            } else {
1282                quote! { <(#(#arg_types),*) as #rapace_crate::facet_core::Facet>::SHAPE }
1283            };
1284
1285            // Response shape: the return type (or inner type for streaming)
1286            let response_shape_expr = match &m.kind {
1287                MethodKind::Unary => {
1288                    let return_type = &m.return_type;
1289                    quote! { <#return_type as #rapace_crate::facet_core::Facet>::SHAPE }
1290                }
1291                MethodKind::ServerStreaming { item_type } => {
1292                    quote! { <#item_type as #rapace_crate::facet_core::Facet>::SHAPE }
1293                }
1294            };
1295
1296            // Is this a streaming method?
1297            let is_streaming = matches!(m.kind, MethodKind::ServerStreaming { .. });
1298
1299            if is_streaming {
1300                quote! {
1301                    builder.add_streaming_method(
1302                        #method_name,
1303                        #method_doc,
1304                        vec![#(#arg_infos),*],
1305                        #request_shape_expr,
1306                        #response_shape_expr,
1307                    );
1308                }
1309            } else {
1310                quote! {
1311                    builder.add_method(
1312                        #method_name,
1313                        #method_doc,
1314                        vec![#(#arg_infos),*],
1315                        #request_shape_expr,
1316                        #response_shape_expr,
1317                    );
1318                }
1319            }
1320        })
1321        .collect();
1322
1323    quote! {
1324        /// Register this service with a registry.
1325        ///
1326        /// This function registers the service and all its methods,
1327        /// capturing request/response schemas and documentation via facet.
1328        #vis fn #register_fn_name(registry: &mut #rapace_crate::registry::ServiceRegistry) {
1329            let mut builder = registry.register_service(#service_name, #service_doc);
1330            #(#method_registrations)*
1331            builder.finish();
1332        }
1333    }
1334}
1335
1336/// Generate a unary client method that uses a stored method ID from the registry.
1337fn generate_client_method_unary_registry(
1338    method: &MethodInfo,
1339    _method_index: usize,
1340    service_name: &str,
1341    rapace_crate: &TokenStream2,
1342) -> TokenStream2 {
1343    let name = &method.name;
1344    let method_name_str = name.to_string();
1345    let return_type = &method.return_type;
1346    let method_id_field = format_ident!("{}_method_id", name);
1347
1348    let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1349    let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1350
1351    let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
1352        quote! { #name: #ty }
1353    });
1354
1355    let encode_expr = if arg_names.is_empty() {
1356        quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
1357    } else if arg_names.len() == 1 {
1358        let arg = &arg_names[0];
1359        quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
1360    } else {
1361        quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
1362    };
1363
1364    quote! {
1365        /// Call the #name method on the remote service.
1366        pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#return_type, #rapace_crate::rapace_core::RpcError> {
1367            use #rapace_crate::rapace_core::FrameFlags;
1368
1369            let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
1370
1371            // Call via session with registry-assigned method ID
1372            let channel_id = self.session.next_channel_id();
1373            #rapace_crate::tracing::debug!(
1374                service = #service_name,
1375                method = #method_name_str,
1376                method_id = self.#method_id_field,
1377                channel_id,
1378                "RPC call start"
1379            );
1380            let response = self.session.call(channel_id, self.#method_id_field, request_bytes).await?;
1381            #rapace_crate::tracing::debug!(
1382                service = #service_name,
1383                method = #method_name_str,
1384                method_id = self.#method_id_field,
1385                channel_id,
1386                "RPC call complete"
1387            );
1388
1389            if response.flags().contains(FrameFlags::ERROR) {
1390                return Err(#rapace_crate::rapace_core::parse_error_payload(response.payload_bytes()));
1391            }
1392
1393            let result: #return_type = #rapace_crate::facet_postcard::from_slice(response.payload_bytes())
1394                .map_err(|e| #rapace_crate::rapace_core::RpcError::Status {
1395                    code: #rapace_crate::rapace_core::ErrorCode::Internal,
1396                    message: ::std::format!("decode error: {:?}", e),
1397                })?;
1398
1399            Ok(result)
1400        }
1401    }
1402}
1403
1404/// Generate a server-streaming client method that uses a stored method ID from the registry.
1405fn generate_client_method_server_streaming_registry(
1406    method: &MethodInfo,
1407    _method_index: usize,
1408    service_name: &str,
1409    item_type: &TokenStream2,
1410    rapace_crate: &TokenStream2,
1411) -> TokenStream2 {
1412    let name = &method.name;
1413    let method_name_str = name.to_string();
1414    let method_id_field = format_ident!("{}_method_id", name);
1415
1416    let arg_names: Vec<_> = method.args.iter().map(|(name, _)| name).collect();
1417    let arg_types: Vec<_> = method.args.iter().map(|(_, ty)| ty).collect();
1418
1419    let fn_args = arg_names.iter().zip(arg_types.iter()).map(|(name, ty)| {
1420        quote! { #name: #ty }
1421    });
1422
1423    // For encoding, serialize args as a tuple using facet_postcard
1424    let encode_expr = if arg_names.is_empty() {
1425        quote! { #rapace_crate::facet_postcard::to_vec(&()).unwrap() }
1426    } else if arg_names.len() == 1 {
1427        let arg = &arg_names[0];
1428        quote! { #rapace_crate::facet_postcard::to_vec(&#arg).unwrap() }
1429    } else {
1430        quote! { #rapace_crate::facet_postcard::to_vec(&(#(#arg_names.clone()),*)).unwrap() }
1431    };
1432
1433    quote! {
1434        /// Call the #name server-streaming method on the remote service.
1435        ///
1436        /// Returns a stream that yields items as they arrive from the server.
1437        /// The stream ends when the server sends EOS, or yields an error if
1438        /// the server sends an ERROR frame.
1439        pub async fn #name(&self, #(#fn_args),*) -> ::std::result::Result<#rapace_crate::rapace_core::Streaming<#item_type>, #rapace_crate::rapace_core::RpcError> {
1440            use #rapace_crate::rapace_core::{ErrorCode, RpcError};
1441
1442            #rapace_crate::tracing::debug!(
1443                service = #service_name,
1444                method = #method_name_str,
1445                method_id = self.#method_id_field,
1446                "RPC streaming call start"
1447            );
1448
1449            let request_bytes: ::std::vec::Vec<u8> = #encode_expr;
1450
1451            // Start the streaming call with registry-assigned method ID
1452            let mut rx = self.session
1453                .start_streaming_call(self.#method_id_field, request_bytes)
1454                .await?;
1455
1456            // Build a Stream<Item = Result<#item_type, RpcError>> with explicit termination on EOS
1457            let stream = #rapace_crate::rapace_core::try_stream! {
1458                while let Some(chunk) = rx.recv().await {
1459                    // Error chunk - parse and return as error
1460                    if chunk.is_error() {
1461                        let err = #rapace_crate::rapace_core::parse_error_payload(chunk.payload_bytes());
1462                        Err(err)?;
1463                    }
1464
1465                    // Empty EOS chunk - stream is done
1466                    if chunk.is_eos() && chunk.payload_bytes().is_empty() {
1467                        break;
1468                    }
1469
1470                    // DATA chunk (possibly with EOS flag for final item) - deserialize
1471                    let item: #item_type = #rapace_crate::facet_postcard::from_slice(chunk.payload_bytes())
1472                        .map_err(|e| RpcError::Status {
1473                            code: ErrorCode::Internal,
1474                            message: ::std::format!("decode error: {:?}", e),
1475                        })?;
1476
1477                    yield item;
1478                }
1479            };
1480
1481            Ok(::std::boxed::Box::pin(stream))
1482        }
1483    }
1484}