Skip to main content

mill_rpc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    braced,
5    parse::{Parse, ParseStream},
6    parse_macro_input, FnArg, Ident, Pat, ReturnType, Token, TraitItemFn, Type,
7};
8
9/// Module-level macro for defining an RPC service.
10///
11/// Generates a module containing `Service` trait, `Client` struct,
12/// `server()` wrapper function, and all request/response types.
13///
14/// By default, both server and client code are generated.
15/// Use `#[server]` or `#[client]` to generate only one side.
16///
17/// # Examples
18///
19/// ```ignore
20/// // Generate both server and client (default)
21/// mill_rpc::service! {
22///     service Calculator {
23///         fn add(a: i32, b: i32) -> i32;
24///         fn divide(a: f64, b: f64) -> f64;
25///     }
26/// }
27///
28/// // Server only
29/// mill_rpc::service! {
30///     #[server]
31///     service Calculator {
32///         fn add(a: i32, b: i32) -> i32;
33///     }
34/// }
35///
36/// // Client only (e.g. in a separate client crate)
37/// mill_rpc::service! {
38///     #[client]
39///     service Calculator {
40///         fn add(a: i32, b: i32) -> i32;
41///     }
42/// }
43/// ```
44#[proc_macro]
45pub fn service(input: TokenStream) -> TokenStream {
46    let def = parse_macro_input!(input as ServiceDef);
47    match generate_service_module(def) {
48        Ok(tokens) => tokens.into(),
49        Err(err) => err.to_compile_error().into(),
50    }
51}
52
/// Selects which half (or halves) of the RPC service `service!` emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum GenerateMode {
    /// No mode attribute: emit both server and client code.
    Both,
    /// `#[server]`: emit the `Service` trait, dispatcher and `server()` only.
    ServerOnly,
    /// `#[client]`: emit the `Client` struct only.
    ClientOnly,
}
59
60struct ServiceDef {
61    mode: GenerateMode,
62    name: Ident,
63    methods: Vec<MethodDef>,
64}
65
66struct MethodDef {
67    name: Ident,
68    args: Vec<(Ident, Type)>,
69    return_type: Type,
70}
71
72impl Parse for ServiceDef {
73    fn parse(input: ParseStream) -> syn::Result<Self> {
74        // Parse optional #[server] or #[client]
75        let mode = if input.peek(Token![#]) {
76            input.parse::<Token![#]>()?;
77            let content;
78            syn::bracketed!(content in input);
79            let attr_name: Ident = content.parse()?;
80            match attr_name.to_string().as_str() {
81                "server" => GenerateMode::ServerOnly,
82                "client" => GenerateMode::ClientOnly,
83                other => {
84                    return Err(syn::Error::new_spanned(
85                        attr_name,
86                        format!(
87                            "Unknown attribute `{}`, expected `server` or `client`",
88                            other
89                        ),
90                    ))
91                }
92            }
93        } else {
94            GenerateMode::Both
95        };
96
97        // Parse `service Name { ... }`
98        let service_kw: Ident = input.parse()?;
99        if service_kw != "service" {
100            return Err(syn::Error::new_spanned(service_kw, "Expected `service`"));
101        }
102
103        let name: Ident = input.parse()?;
104
105        let content;
106        braced!(content in input);
107
108        let mut methods = Vec::new();
109        while !content.is_empty() {
110            let method: TraitItemFn = content.parse()?;
111
112            let method_name = method.sig.ident.clone();
113
114            let mut args = Vec::new();
115            for arg in &method.sig.inputs {
116                match arg {
117                    FnArg::Typed(pat_type) => {
118                        let ident = match &*pat_type.pat {
119                            Pat::Ident(pi) => pi.ident.clone(),
120                            other => {
121                                return Err(syn::Error::new_spanned(
122                                    other,
123                                    "Expected a simple identifier for argument name",
124                                ))
125                            }
126                        };
127                        args.push((ident, (*pat_type.ty).clone()));
128                    }
129                    FnArg::Receiver(_) => {
130                        return Err(syn::Error::new_spanned(
131                            arg,
132                            "Service methods should not have `self` parameter",
133                        ))
134                    }
135                }
136            }
137
138            let return_type = match &method.sig.output {
139                ReturnType::Default => syn::parse_quote!(()),
140                ReturnType::Type(_, ty) => (**ty).clone(),
141            };
142
143            methods.push(MethodDef {
144                name: method_name,
145                args,
146                return_type,
147            });
148        }
149
150        Ok(ServiceDef {
151            mode,
152            name,
153            methods,
154        })
155    }
156}
157
158fn generate_service_module(def: ServiceDef) -> syn::Result<proc_macro2::TokenStream> {
159    let mod_name = format_ident!("{}", to_snake_case(&def.name.to_string()));
160    let service_name_str = def.name.to_string();
161    let method_count = def.methods.len() as u16;
162
163    let gen_server = def.mode != GenerateMode::ClientOnly;
164    let gen_client = def.mode != GenerateMode::ServerOnly;
165
166    let method_consts: Vec<_> = def
167        .methods
168        .iter()
169        .enumerate()
170        .map(|(idx, m)| {
171            let const_name = format_ident!("{}", m.name.to_string().to_uppercase());
172            let id = idx as u16;
173            quote! { pub const #const_name: u16 = #id; }
174        })
175        .collect();
176
177    // Request / Response types
178    let type_defs: Vec<_> = def
179        .methods
180        .iter()
181        .map(|m| {
182            let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string()));
183            let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string()));
184            let ret_ty = &m.return_type;
185
186            let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
187            let field_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect();
188
189            let req_struct = if m.args.is_empty() {
190                quote! {
191                    #[derive(::serde::Serialize, ::serde::Deserialize, Debug)]
192                    pub(super) struct #req_name;
193                }
194            } else {
195                quote! {
196                    #[derive(::serde::Serialize, ::serde::Deserialize, Debug)]
197                    pub(super) struct #req_name {
198                        #( pub #field_names: #field_types, )*
199                    }
200                }
201            };
202
203            quote! {
204                #req_struct
205
206                #[derive(::serde::Serialize, ::serde::Deserialize, Debug)]
207                pub(super) struct #resp_name(pub #ret_ty);
208            }
209        })
210        .collect();
211
212    let server_trait = if gen_server {
213        let trait_methods: Vec<_> = def
214            .methods
215            .iter()
216            .map(|m| {
217                let name = &m.name;
218                let ret_ty = &m.return_type;
219                let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
220                let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect();
221                quote! {
222                    fn #name(&self, ctx: &::mill_rpc_core::RpcContext, #( #arg_names: #arg_types ),*) -> #ret_ty;
223                }
224            })
225            .collect();
226
227        let dispatch_arms: Vec<_> = def
228            .methods
229            .iter()
230            .map(|m| {
231                let name = &m.name;
232                let const_name = format_ident!("{}", m.name.to_string().to_uppercase());
233                let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string()));
234                let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string()));
235
236                let call_args = if m.args.is_empty() {
237                    quote! {}
238                } else {
239                    let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
240                    let args: Vec<_> = field_names.iter().map(|n| quote! { req.#n }).collect();
241                    quote! { , #( #args ),* }
242                };
243
244                quote! {
245                    methods::#const_name => {
246                        let req: types::#req_name = codec.deserialize(args)?;
247                        let result = svc.#name(ctx #call_args);
248                        codec.serialize(&types::#resp_name(result))
249                    }
250                }
251            })
252            .collect();
253
254        quote! {
255            /// Server trait — implement this to handle RPC calls for this service.
256            pub trait Service: Send + Sync + 'static {
257                #( #trait_methods )*
258            }
259
260            /// Internal dispatcher that bridges `Service` impl to `ServiceDispatch`.
261            struct Dispatcher<T: Service>(T);
262
263            impl<T: Service> ::mill_rpc_core::ServiceDispatch for Dispatcher<T> {
264                fn dispatch(
265                    &self,
266                    ctx: &::mill_rpc_core::RpcContext,
267                    method_id: u16,
268                    args: &[u8],
269                    codec: &::mill_rpc_core::Codec,
270                ) -> Result<Vec<u8>, ::mill_rpc_core::RpcError> {
271                    let svc = &self.0;
272                    match method_id {
273                        #( #dispatch_arms, )*
274                        _ => Err(::mill_rpc_core::RpcError::method_not_found(method_id)),
275                    }
276                }
277            }
278
279            /// Wrap a `Service` implementation for server registration.
280            ///
281            /// # Example
282            /// ```ignore
283            /// RpcServer::builder()
284            ///     .service(calculator::server(MyCalc))
285            ///     .build(&event_loop)?;
286            /// ```
287            pub fn server<T: Service>(implementation: T) -> impl ::mill_rpc_core::ServiceDispatch {
288                Dispatcher(implementation)
289            }
290        }
291    } else {
292        quote! {}
293    };
294
295    let client_code = if gen_client {
296        let client_methods: Vec<_> = def
297            .methods
298            .iter()
299            .map(|m| {
300                let name = &m.name;
301                let ret_ty = &m.return_type;
302                let const_name = format_ident!("{}", m.name.to_string().to_uppercase());
303                let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string()));
304                let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string()));
305
306                let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
307                let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect();
308
309                let req_construct = if m.args.is_empty() {
310                    quote! { types::#req_name }
311                } else {
312                    let fields: Vec<_> = arg_names.iter().map(|n| quote! { #n: #n }).collect();
313                    quote! { types::#req_name { #( #fields, )* } }
314                };
315
316                quote! {
317                    pub fn #name(&self, #( #arg_names: #arg_types ),*) -> Result<#ret_ty, ::mill_rpc_core::RpcError> {
318                        let req = #req_construct;
319                        let payload = self.codec.serialize(&req)?;
320                        let resp_bytes = self.transport.call(
321                            self.service_id,
322                            methods::#const_name,
323                            payload,
324                        )?;
325                        let resp: types::#resp_name = self.codec.deserialize(&resp_bytes)?;
326                        Ok(resp.0)
327                    }
328                }
329            })
330            .collect();
331
332        quote! {
333            /// Generated RPC client for this service.
334            pub struct Client {
335                transport: ::std::sync::Arc<dyn ::mill_rpc_core::RpcTransport>,
336                codec: ::mill_rpc_core::Codec,
337                service_id: u16,
338            }
339
340            impl Client {
341                /// Create a new client.
342                ///
343                /// - `transport`: the RPC transport (typically an `RpcClient`)
344                /// - `codec`: serialization codec (must match the server)
345                /// - `service_id`: the ID assigned to this service on the server
346                ///   (matches registration order, starting from 0)
347                pub fn new(
348                    transport: ::std::sync::Arc<dyn ::mill_rpc_core::RpcTransport>,
349                    codec: ::mill_rpc_core::Codec,
350                    service_id: u16,
351                ) -> Self {
352                    Self { transport, codec, service_id }
353                }
354
355                #( #client_methods )*
356            }
357        }
358    } else {
359        quote! {}
360    };
361
362    let output = quote! {
363        pub mod #mod_name {
364            #![allow(unused_imports)]
365            use super::*;
366
367            /// Method ID constants.
368            pub mod methods {
369                #( #method_consts )*
370            }
371
372            /// Service metadata.
373            pub const SERVICE_NAME: &str = #service_name_str;
374            pub const METHOD_COUNT: u16 = #method_count;
375
376            /// Internal request/response types (not part of the public API).
377            mod types {
378                use super::super::*;
379                #( #type_defs )*
380            }
381
382            #server_trait
383
384            #client_code
385        }
386    };
387
388    Ok(output)
389}
390
391// ---------------------------------------------------------------------------
392// Helpers
393// ---------------------------------------------------------------------------
394
/// Converts a `PascalCase` identifier to `snake_case`.
///
/// Every uppercase character except the first is preceded by `_`, and all
/// uppercase characters are lowercased. Other characters are copied
/// unchanged. Consecutive capitals are split one by one, so `HTTPServer`
/// becomes `h_t_t_p_server`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` can yield several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them rather than only the first.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
409
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`, and the first character of each segment is
/// uppercased while the rest is kept unchanged. Empty segments (from leading,
/// trailing or repeated underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}