Skip to main content

moonpool_transport_derive/
lib.rs

1//! Proc-macros for moonpool RPC interfaces.
2//!
3//! This crate provides the `#[service]` attribute macro for generating
4//! RPC server/client boilerplate from a trait definition.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use moonpool_transport::{service, RpcError};
10//!
11//! #[service(id = 0xCA1C_0000)]
12//! trait Calculator {
13//!     async fn add(&self, req: AddRequest) -> Result<AddResponse, RpcError>;
14//!     async fn sub(&self, req: SubRequest) -> Result<SubResponse, RpcError>;
15//! }
16//! ```
17//!
18//! This generates:
19//! - `CalculatorServer<C>` with `RequestStream` fields, `init()`, and `serve()`
20//! - `CalculatorClient<C>` with `ServiceEndpoint` fields for each method
21//! - The trait itself with `#[async_trait(?Send)]`
22
23use proc_macro::TokenStream;
24use quote::{format_ident, quote};
25use syn::{
26    Expr, ExprLit, FnArg, GenericArgument, Ident, ItemTrait, Lit, PathArguments, ReturnType,
27    TraitItem, Type, parse_macro_input,
28};
29
30/// Attribute macro for defining RPC service interfaces.
31///
32/// Generates server and client types from a trait definition.
33/// All methods must use `&self` receivers.
34///
35/// # Attributes
36///
37/// - `#[service(id = 0x...)]` - Required. Sets the interface ID (u64).
38///
39/// # Example
40///
41/// ```rust,ignore
42/// #[service(id = 0x5049_4E47)]
43/// trait PingPong {
44///     async fn ping(&self, req: PingRequest) -> Result<PingResponse, RpcError>;
45/// }
46/// ```
47#[proc_macro_attribute]
48pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
49    let attr = parse_macro_input!(attr as InterfaceAttr);
50    let item = parse_macro_input!(item as ItemTrait);
51
52    match service_impl(attr, item) {
53        Ok(tokens) => tokens.into(),
54        Err(err) => err.to_compile_error().into(),
55    }
56}
57
58/// Auto-detect mode from method receivers and delegate.
59fn service_impl(attr: InterfaceAttr, item: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
60    let mut has_ref = false;
61    let mut has_mut_ref = false;
62
63    for trait_item in &item.items {
64        if let TraitItem::Fn(method) = trait_item
65            && let Some(FnArg::Receiver(recv)) = method.sig.inputs.first()
66        {
67            if recv.mutability.is_some() {
68                has_mut_ref = true;
69            } else {
70                has_ref = true;
71            }
72        }
73    }
74
75    if has_ref && has_mut_ref {
76        return Err(syn::Error::new_spanned(
77            &item.ident,
78            "all methods must use `&self` receivers",
79        ));
80    }
81
82    if has_mut_ref {
83        return Err(syn::Error::new_spanned(
84            &item.ident,
85            "`&mut self` methods (virtual actor mode) have been removed. Use `&self` for RPC services.",
86        ));
87    }
88
89    interface_impl(attr, item)
90}
91
92/// Method info extracted from trait methods.
93struct MethodInfo {
94    index: u32,
95    name: Ident,
96    req_type: Type,
97    resp_type: Type,
98}
99
100fn interface_impl(attr: InterfaceAttr, item: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
101    let interface_id = attr.id;
102    let name = &item.ident;
103    let server_name = format_ident!("{}Server", name);
104    let client_name = format_ident!("{}Client", name);
105
106    // Parse trait methods
107    let mut method_infos: Vec<MethodInfo> = Vec::new();
108    for (index, trait_item) in item.items.iter().enumerate() {
109        if let TraitItem::Fn(method) = trait_item {
110            let method_name = &method.sig.ident;
111
112            // Extract request and response types from method signature
113            let (req_type, resp_type) = extract_method_types(&method.sig)?;
114
115            // Method indices start at 1; index 0 is reserved.
116            method_infos.push(MethodInfo {
117                index: (index + 1) as u32,
118                name: method_name.clone(),
119                req_type,
120                resp_type,
121            });
122        }
123    }
124
125    let method_count = method_infos.len() as u32;
126
127    // Generate server fields
128    let server_fields = method_infos.iter().map(|m| {
129        let name = &m.name;
130        let req_type = &m.req_type;
131        quote! { pub #name: moonpool_transport::RequestStream<#req_type, C> }
132    });
133
134    // Generate server init - clone codec for all but the last field
135    let server_inits: Vec<_> = method_infos
136        .iter()
137        .enumerate()
138        .map(|(i, m)| {
139            let name = &m.name;
140            let idx = m.index;
141            let is_last = i == method_infos.len() - 1;
142            if is_last {
143                quote! {
144                    let (#name, _) = transport.register_handler_at(Self::INTERFACE_ID, #idx as u64, codec);
145                }
146            } else {
147                quote! {
148                    let (#name, _) = transport.register_handler_at(Self::INTERFACE_ID, #idx as u64, codec.clone());
149                }
150            }
151        })
152        .collect();
153
154    let server_field_names: Vec<_> = method_infos.iter().map(|m| &m.name).collect();
155
156    // Generate client fields — typed ServiceEndpoint per method
157    let client_fields = method_infos.iter().map(|m| {
158        let name = &m.name;
159        let req_type = &m.req_type;
160        let resp_type = &m.resp_type;
161        quote! {
162            /// Typed endpoint for this method. Call delivery methods directly:
163            /// `.get_reply()`, `.try_get_reply()`, `.send()`, `.get_reply_unless_failed_for()`.
164            pub #name: moonpool_transport::ServiceEndpoint<#req_type, #resp_type, C>
165        }
166    });
167
168    // Generate client field constructors
169    let client_field_inits = method_infos.iter().map(|m| {
170        let name = &m.name;
171        let idx = m.index;
172        quote! {
173            #name: moonpool_transport::ServiceEndpoint::new(
174                moonpool_transport::Endpoint::new(
175                    address.clone(),
176                    moonpool_transport::UID::new(Self::INTERFACE_ID, #idx as u64),
177                ),
178                codec.clone(),
179            )
180        }
181    });
182
183    let first_field_name = &method_infos[0].name;
184
185    // Generate the trait with async_trait attribute
186    let trait_vis = &item.vis;
187    let trait_items = &item.items;
188    let trait_name_snake = to_snake_case(&name.to_string());
189
190    // Generate serve() method blocks — one close handle + one spawned task per method
191    let serve_close_handles: Vec<_> = method_infos
192        .iter()
193        .map(|m| {
194            let method_name = &m.name;
195            quote! {
196                let queue = self.#method_name.queue();
197                close_fns.push(Box::new(move || queue.close()));
198            }
199        })
200        .collect();
201
202    let serve_spawn_tasks: Vec<_> = method_infos
203        .iter()
204        .map(|m| {
205            let method_name = &m.name;
206            let resp_type = &m.resp_type;
207            let task_name = format!("{}_{}", trait_name_snake, m.name);
208            quote! {
209                {
210                    let stream = self.#method_name;
211                    let t = transport.clone();
212                    let h = handler.clone();
213                    providers.task().spawn_task(#task_name, async move {
214                        while let Some((req, reply)) = stream
215                            .recv_with_transport::<_, #resp_type>(&t)
216                            .await
217                        {
218                            match h.#method_name(req).await {
219                                Ok(resp) => reply.send(resp),
220                                Err(e) => {
221                                    tracing::warn!(error = %e, method = #task_name, "handler error");
222                                    reply.send_error(moonpool_transport::ReplyError::BrokenPromise);
223                                }
224                            }
225                        }
226                    });
227                }
228            }
229        })
230        .collect();
231
232    let expanded = quote! {
233        // Emit the original trait with async_trait(?Send)
234        #[async_trait::async_trait(?Send)]
235        #trait_vis trait #name {
236            #(#trait_items)*
237        }
238
239        /// Server-side interface with RequestStreams.
240        ///
241        /// Generated by `#[service]`.
242        pub struct #server_name<C: moonpool_transport::MessageCodec> {
243            #(#server_fields,)*
244        }
245
246        impl<C: moonpool_transport::MessageCodec + Clone> #server_name<C> {
247            /// Interface identifier.
248            pub const INTERFACE_ID: u64 = #interface_id;
249
250            /// Number of methods in this interface.
251            pub const METHOD_COUNT: u32 = #method_count;
252
253            /// Initialize the server interface, registering all handlers.
254            ///
255            /// Returns the server with individual `RequestStream` fields for
256            /// manual control. For a simpler pattern, use [`serve()`](Self::serve).
257            pub fn init<P>(transport: &std::rc::Rc<moonpool_transport::NetTransport<P>>, codec: C) -> Self
258            where
259                P: moonpool_transport::Providers,
260            {
261                #(#server_inits)*
262                Self { #(#server_field_names,)* }
263            }
264
265            /// Consume this server and spawn handler tasks for all methods.
266            ///
267            /// Each method gets its own task that loops on `recv_with_transport`
268            /// and dispatches to the handler. Returns a [`ServerHandle`](moonpool_transport::ServerHandle)
269            /// that stops all tasks when dropped.
270            ///
271            /// # Example
272            ///
273            /// ```rust,ignore
274            /// let server = MyServer::init(&transport, JsonCodec);
275            /// let handle = server.serve(transport.clone(), Rc::new(handler), &providers);
276            /// // Tasks run until handle is dropped or stop() is called
277            /// ```
278            pub fn serve<P, H>(
279                self,
280                transport: std::rc::Rc<moonpool_transport::NetTransport<P>>,
281                handler: std::rc::Rc<H>,
282                providers: &P,
283            ) -> moonpool_transport::ServerHandle
284            where
285                P: moonpool_transport::Providers,
286                H: #name + 'static,
287            {
288                use moonpool_transport::TaskProvider as _;
289                let mut close_fns: Vec<Box<dyn Fn()>> = Vec::new();
290                #(#serve_close_handles)*
291                #(#serve_spawn_tasks)*
292                moonpool_transport::ServerHandle::new(close_fns)
293            }
294        }
295
296        /// Client-side interface with typed [`ServiceEndpoint`](moonpool_transport::ServiceEndpoint)
297        /// fields.
298        ///
299        /// Generated by `#[service]`. Each field provides delivery mode methods
300        /// directly: `.get_reply()`, `.try_get_reply()`, `.send()`,
301        /// `.get_reply_unless_failed_for()`.
302        ///
303        /// FDB equivalent: interface structs like `StorageServerInterface`.
304        ///
305        /// # Example
306        ///
307        /// ```rust,ignore
308        /// let calc = CalculatorClient::new(server_addr, JsonCodec);
309        ///
310        /// // Choose delivery mode at call site:
311        /// let resp = calc.add.get_reply(&transport, req).await?;
312        /// let resp = calc.add.try_get_reply(&transport, req).await?;
313        /// calc.add.send(&transport, req)?;
314        /// ```
315        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
316        #[serde(bound(
317            serialize = "",
318            deserialize = "C: moonpool_transport::MessageCodec + Default",
319        ))]
320        pub struct #client_name<C: moonpool_transport::MessageCodec> {
321            #(#client_fields,)*
322        }
323
324        impl<C: moonpool_transport::MessageCodec + Clone> #client_name<C> {
325            /// Interface identifier.
326            pub const INTERFACE_ID: u64 = #interface_id;
327
328            /// Number of methods in this interface.
329            pub const METHOD_COUNT: u32 = #method_count;
330
331            /// Create a new client interface from a network address and codec.
332            pub fn new(address: moonpool_transport::NetworkAddress, codec: C) -> Self {
333                Self {
334                    #(#client_field_inits,)*
335                }
336            }
337
338            /// Get the address this client points to.
339            pub fn address(&self) -> &moonpool_transport::NetworkAddress {
340                // All fields share the same address; use the first one.
341                &self.#first_field_name.endpoint().address
342            }
343        }
344    };
345
346    Ok(expanded)
347}
348
349/// Extract request and response types from method signature.
350///
351/// Expected signature: `async fn name(&self, req: ReqType) -> Result<RespType, RpcError>`
352fn extract_method_types(sig: &syn::Signature) -> syn::Result<(Type, Type)> {
353    // Skip &self, get the second argument
354    let mut inputs = sig.inputs.iter();
355
356    // First should be &self
357    match inputs.next() {
358        Some(FnArg::Receiver(_)) => {}
359        _ => {
360            return Err(syn::Error::new_spanned(
361                sig,
362                "Interface method must have &self as first parameter",
363            ));
364        }
365    }
366
367    // Second should be the request parameter
368    let req_type = match inputs.next() {
369        Some(FnArg::Typed(pat_type)) => (*pat_type.ty).clone(),
370        _ => {
371            return Err(syn::Error::new_spanned(
372                sig,
373                "Interface method must have a request parameter: async fn name(&self, req: ReqType) -> Result<RespType, RpcError>",
374            ));
375        }
376    };
377
378    // Extract response type from return type: Result<RespType, RpcError>
379    let resp_type = match &sig.output {
380        ReturnType::Type(_, ty) => extract_result_ok_type(ty)?,
381        ReturnType::Default => {
382            return Err(syn::Error::new_spanned(
383                sig,
384                "Interface method must return Result<RespType, RpcError>",
385            ));
386        }
387    };
388
389    Ok((req_type, resp_type))
390}
391
392/// Extract the Ok type from `Result<T, E>`.
393fn extract_result_ok_type(ty: &Type) -> syn::Result<Type> {
394    if let Type::Path(type_path) = ty
395        && let Some(segment) = type_path.path.segments.last()
396        && segment.ident == "Result"
397        && let PathArguments::AngleBracketed(args) = &segment.arguments
398        && let Some(GenericArgument::Type(ok_type)) = args.args.first()
399    {
400        return Ok(ok_type.clone());
401    }
402
403    Err(syn::Error::new_spanned(
404        ty,
405        "Interface method must return Result<RespType, RpcError>",
406    ))
407}
408
/// Convert a PascalCase (or camelCase) identifier to snake_case.
///
/// Word boundaries are handled as follows:
/// - A run of capitals is one word, so `HTTPServer` becomes `http_server`
///   rather than `h_t_t_p_server`.
/// - Existing underscores are kept and never doubled (`Foo_Bar` -> `foo_bar`).
/// - Non-ASCII uppercase letters are lowercased too.
///
/// Used to build the task names (`{trait}_{method}`) spawned by `serve()`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + chars.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if i > 0 && c.is_uppercase() {
            let prev = chars[i - 1];
            let next_is_lower = chars.get(i + 1).is_some_and(|n| n.is_lowercase());
            // A capital starts a new word unless it follows an underscore, or
            // it sits inside an acronym. The last capital of an acronym that is
            // followed by a lowercase letter does start a word
            // ("HTTPServer" -> "http" + "server").
            let starts_word = if prev == '_' {
                false
            } else if prev.is_uppercase() {
                next_is_lower
            } else {
                true
            };
            if starts_word {
                result.push('_');
            }
        }
        result.extend(c.to_lowercase());
    }

    result
}
424
425// ============================================================================
426// Shared Attribute Parsing
427// ============================================================================
428
/// Arguments of `#[service(...)]`, parsed from `id = <integer literal>`.
struct InterfaceAttr {
    /// Interface identifier, emitted as `INTERFACE_ID` on the generated types.
    id: u64,
}
433
434impl syn::parse::Parse for InterfaceAttr {
435    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
436        let ident: Ident = input.parse()?;
437        if ident != "id" {
438            return Err(syn::Error::new_spanned(
439                ident,
440                "expected `id` in interface attribute",
441            ));
442        }
443        let _eq: syn::Token![=] = input.parse()?;
444        let value: Expr = input.parse()?;
445
446        // Extract the numeric value
447        let id = match &value {
448            Expr::Lit(ExprLit {
449                lit: Lit::Int(lit_int),
450                ..
451            }) => lit_int.base10_parse::<u64>()?,
452            _ => {
453                return Err(syn::Error::new_spanned(
454                    value,
455                    "expected integer literal for interface id",
456                ));
457            }
458        };
459
460        Ok(InterfaceAttr { id })
461    }
462}