1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 braced,
5 parse::{Parse, ParseStream},
6 parse_macro_input, FnArg, Ident, Pat, ReturnType, Token, TraitItemFn, Type,
7};
8
9#[proc_macro]
45pub fn service(input: TokenStream) -> TokenStream {
46 let def = parse_macro_input!(input as ServiceDef);
47 match generate_service_module(def) {
48 Ok(tokens) => tokens.into(),
49 Err(err) => err.to_compile_error().into(),
50 }
51}
52
/// Which parts of the service module to emit, selected by the optional leading attribute.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum GenerateMode {
    /// No attribute: emit both server and client code.
    Both,
    /// `#[server]`: client code is not generated.
    ServerOnly,
    /// `#[client]`: server code is not generated.
    ClientOnly,
}
59
60struct ServiceDef {
61 mode: GenerateMode,
62 name: Ident,
63 methods: Vec<MethodDef>,
64}
65
66struct MethodDef {
67 name: Ident,
68 args: Vec<(Ident, Type)>,
69 return_type: Type,
70}
71
72impl Parse for ServiceDef {
73 fn parse(input: ParseStream) -> syn::Result<Self> {
74 let mode = if input.peek(Token![#]) {
76 input.parse::<Token![#]>()?;
77 let content;
78 syn::bracketed!(content in input);
79 let attr_name: Ident = content.parse()?;
80 match attr_name.to_string().as_str() {
81 "server" => GenerateMode::ServerOnly,
82 "client" => GenerateMode::ClientOnly,
83 other => {
84 return Err(syn::Error::new_spanned(
85 attr_name,
86 format!(
87 "Unknown attribute `{}`, expected `server` or `client`",
88 other
89 ),
90 ))
91 }
92 }
93 } else {
94 GenerateMode::Both
95 };
96
97 let service_kw: Ident = input.parse()?;
99 if service_kw != "service" {
100 return Err(syn::Error::new_spanned(service_kw, "Expected `service`"));
101 }
102
103 let name: Ident = input.parse()?;
104
105 let content;
106 braced!(content in input);
107
108 let mut methods = Vec::new();
109 while !content.is_empty() {
110 let method: TraitItemFn = content.parse()?;
111
112 let method_name = method.sig.ident.clone();
113
114 let mut args = Vec::new();
115 for arg in &method.sig.inputs {
116 match arg {
117 FnArg::Typed(pat_type) => {
118 let ident = match &*pat_type.pat {
119 Pat::Ident(pi) => pi.ident.clone(),
120 other => {
121 return Err(syn::Error::new_spanned(
122 other,
123 "Expected a simple identifier for argument name",
124 ))
125 }
126 };
127 args.push((ident, (*pat_type.ty).clone()));
128 }
129 FnArg::Receiver(_) => {
130 return Err(syn::Error::new_spanned(
131 arg,
132 "Service methods should not have `self` parameter",
133 ))
134 }
135 }
136 }
137
138 let return_type = match &method.sig.output {
139 ReturnType::Default => syn::parse_quote!(()),
140 ReturnType::Type(_, ty) => (**ty).clone(),
141 };
142
143 methods.push(MethodDef {
144 name: method_name,
145 args,
146 return_type,
147 });
148 }
149
150 Ok(ServiceDef {
151 mode,
152 name,
153 methods,
154 })
155 }
156}
157
158fn generate_service_module(def: ServiceDef) -> syn::Result<proc_macro2::TokenStream> {
159 let mod_name = format_ident!("{}", to_snake_case(&def.name.to_string()));
160 let service_name_str = def.name.to_string();
161 let method_count = def.methods.len() as u16;
162
163 let gen_server = def.mode != GenerateMode::ClientOnly;
164 let gen_client = def.mode != GenerateMode::ServerOnly;
165
166 let method_consts: Vec<_> = def
167 .methods
168 .iter()
169 .enumerate()
170 .map(|(idx, m)| {
171 let const_name = format_ident!("{}", m.name.to_string().to_uppercase());
172 let id = idx as u16;
173 quote! { pub const #const_name: u16 = #id; }
174 })
175 .collect();
176
177 let type_defs: Vec<_> = def
179 .methods
180 .iter()
181 .map(|m| {
182 let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string()));
183 let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string()));
184 let ret_ty = &m.return_type;
185
186 let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
187 let field_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect();
188
189 let req_struct = if m.args.is_empty() {
190 quote! {
191 #[derive(::serde::Serialize, ::serde::Deserialize, Debug)]
192 pub(super) struct #req_name;
193 }
194 } else {
195 quote! {
196 #[derive(::serde::Serialize, ::serde::Deserialize, Debug)]
197 pub(super) struct #req_name {
198 #( pub #field_names: #field_types, )*
199 }
200 }
201 };
202
203 quote! {
204 #req_struct
205
206 #[derive(::serde::Serialize, ::serde::Deserialize, Debug)]
207 pub(super) struct #resp_name(pub #ret_ty);
208 }
209 })
210 .collect();
211
212 let server_trait = if gen_server {
213 let trait_methods: Vec<_> = def
214 .methods
215 .iter()
216 .map(|m| {
217 let name = &m.name;
218 let ret_ty = &m.return_type;
219 let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
220 let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect();
221 quote! {
222 fn #name(&self, ctx: &::mill_rpc_core::RpcContext, #( #arg_names: #arg_types ),*) -> #ret_ty;
223 }
224 })
225 .collect();
226
227 let dispatch_arms: Vec<_> = def
228 .methods
229 .iter()
230 .map(|m| {
231 let name = &m.name;
232 let const_name = format_ident!("{}", m.name.to_string().to_uppercase());
233 let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string()));
234 let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string()));
235
236 let call_args = if m.args.is_empty() {
237 quote! {}
238 } else {
239 let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
240 let args: Vec<_> = field_names.iter().map(|n| quote! { req.#n }).collect();
241 quote! { , #( #args ),* }
242 };
243
244 quote! {
245 methods::#const_name => {
246 let req: types::#req_name = codec.deserialize(args)?;
247 let result = svc.#name(ctx #call_args);
248 codec.serialize(&types::#resp_name(result))
249 }
250 }
251 })
252 .collect();
253
254 quote! {
255 pub trait Service: Send + Sync + 'static {
257 #( #trait_methods )*
258 }
259
260 struct Dispatcher<T: Service>(T);
262
263 impl<T: Service> ::mill_rpc_core::ServiceDispatch for Dispatcher<T> {
264 fn dispatch(
265 &self,
266 ctx: &::mill_rpc_core::RpcContext,
267 method_id: u16,
268 args: &[u8],
269 codec: &::mill_rpc_core::Codec,
270 ) -> Result<Vec<u8>, ::mill_rpc_core::RpcError> {
271 let svc = &self.0;
272 match method_id {
273 #( #dispatch_arms, )*
274 _ => Err(::mill_rpc_core::RpcError::method_not_found(method_id)),
275 }
276 }
277 }
278
279 pub fn server<T: Service>(implementation: T) -> impl ::mill_rpc_core::ServiceDispatch {
288 Dispatcher(implementation)
289 }
290 }
291 } else {
292 quote! {}
293 };
294
295 let client_code = if gen_client {
296 let client_methods: Vec<_> = def
297 .methods
298 .iter()
299 .map(|m| {
300 let name = &m.name;
301 let ret_ty = &m.return_type;
302 let const_name = format_ident!("{}", m.name.to_string().to_uppercase());
303 let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string()));
304 let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string()));
305
306 let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect();
307 let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect();
308
309 let req_construct = if m.args.is_empty() {
310 quote! { types::#req_name }
311 } else {
312 let fields: Vec<_> = arg_names.iter().map(|n| quote! { #n: #n }).collect();
313 quote! { types::#req_name { #( #fields, )* } }
314 };
315
316 quote! {
317 pub fn #name(&self, #( #arg_names: #arg_types ),*) -> Result<#ret_ty, ::mill_rpc_core::RpcError> {
318 let req = #req_construct;
319 let payload = self.codec.serialize(&req)?;
320 let resp_bytes = self.transport.call(
321 self.service_id,
322 methods::#const_name,
323 payload,
324 )?;
325 let resp: types::#resp_name = self.codec.deserialize(&resp_bytes)?;
326 Ok(resp.0)
327 }
328 }
329 })
330 .collect();
331
332 quote! {
333 pub struct Client {
335 transport: ::std::sync::Arc<dyn ::mill_rpc_core::RpcTransport>,
336 codec: ::mill_rpc_core::Codec,
337 service_id: u16,
338 }
339
340 impl Client {
341 pub fn new(
348 transport: ::std::sync::Arc<dyn ::mill_rpc_core::RpcTransport>,
349 codec: ::mill_rpc_core::Codec,
350 service_id: u16,
351 ) -> Self {
352 Self { transport, codec, service_id }
353 }
354
355 #( #client_methods )*
356 }
357 }
358 } else {
359 quote! {}
360 };
361
362 let output = quote! {
363 pub mod #mod_name {
364 #![allow(unused_imports)]
365 use super::*;
366
367 pub mod methods {
369 #( #method_consts )*
370 }
371
372 pub const SERVICE_NAME: &str = #service_name_str;
374 pub const METHOD_COUNT: u16 = #method_count;
375
376 mod types {
378 use super::super::*;
379 #( #type_defs )*
380 }
381
382 #server_trait
383
384 #client_code
385 }
386 };
387
388 Ok(output)
389}
390
/// Converts a `PascalCase` service name into the `snake_case` module name.
///
/// An underscore is inserted before every uppercase character except one in the first
/// position, and that character is replaced by its full lowercase mapping. Consecutive
/// capitals are split one by one (`HTTPProxy` -> `h_t_t_p_proxy`); all other characters,
/// including existing underscores, are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Some characters lowercase to more than one `char` (e.g. 'İ' -> "i\u{307}");
            // keep the whole mapping rather than only its first `char`.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
409
/// Converts a `snake_case` method name into `PascalCase`.
///
/// Each `_`-separated segment has its first character replaced by its uppercase mapping.
/// The underscores and any empty segments are dropped, so `a__b_` becomes `AB`.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(first) = rest.next() {
            out.extend(first.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}