moonpool_transport_derive/
lib.rs1use proc_macro::TokenStream;
24use quote::{format_ident, quote};
25use syn::{
26 Expr, ExprLit, FnArg, GenericArgument, Ident, ItemTrait, Lit, PathArguments, ReturnType,
27 TraitItem, Type, parse_macro_input,
28};
29
30#[proc_macro_attribute]
48pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
49 let attr = parse_macro_input!(attr as InterfaceAttr);
50 let item = parse_macro_input!(item as ItemTrait);
51
52 match service_impl(attr, item) {
53 Ok(tokens) => tokens.into(),
54 Err(err) => err.to_compile_error().into(),
55 }
56}
57
58fn service_impl(attr: InterfaceAttr, item: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
60 let mut has_ref = false;
61 let mut has_mut_ref = false;
62
63 for trait_item in &item.items {
64 if let TraitItem::Fn(method) = trait_item
65 && let Some(FnArg::Receiver(recv)) = method.sig.inputs.first()
66 {
67 if recv.mutability.is_some() {
68 has_mut_ref = true;
69 } else {
70 has_ref = true;
71 }
72 }
73 }
74
75 if has_ref && has_mut_ref {
76 return Err(syn::Error::new_spanned(
77 &item.ident,
78 "all methods must use `&self` receivers",
79 ));
80 }
81
82 if has_mut_ref {
83 return Err(syn::Error::new_spanned(
84 &item.ident,
85 "`&mut self` methods (virtual actor mode) have been removed. Use `&self` for RPC services.",
86 ));
87 }
88
89 interface_impl(attr, item)
90}
91
92struct MethodInfo {
94 index: u32,
95 name: Ident,
96 req_type: Type,
97 resp_type: Type,
98}
99
100fn interface_impl(attr: InterfaceAttr, item: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
101 let interface_id = attr.id;
102 let name = &item.ident;
103 let server_name = format_ident!("{}Server", name);
104 let client_name = format_ident!("{}Client", name);
105
106 let mut method_infos: Vec<MethodInfo> = Vec::new();
108 for (index, trait_item) in item.items.iter().enumerate() {
109 if let TraitItem::Fn(method) = trait_item {
110 let method_name = &method.sig.ident;
111
112 let (req_type, resp_type) = extract_method_types(&method.sig)?;
114
115 method_infos.push(MethodInfo {
117 index: (index + 1) as u32,
118 name: method_name.clone(),
119 req_type,
120 resp_type,
121 });
122 }
123 }
124
125 let method_count = method_infos.len() as u32;
126
127 let server_fields = method_infos.iter().map(|m| {
129 let name = &m.name;
130 let req_type = &m.req_type;
131 quote! { pub #name: moonpool_transport::RequestStream<#req_type, C> }
132 });
133
134 let server_inits: Vec<_> = method_infos
136 .iter()
137 .enumerate()
138 .map(|(i, m)| {
139 let name = &m.name;
140 let idx = m.index;
141 let is_last = i == method_infos.len() - 1;
142 if is_last {
143 quote! {
144 let (#name, _) = transport.register_handler_at(Self::INTERFACE_ID, #idx as u64, codec);
145 }
146 } else {
147 quote! {
148 let (#name, _) = transport.register_handler_at(Self::INTERFACE_ID, #idx as u64, codec.clone());
149 }
150 }
151 })
152 .collect();
153
154 let server_field_names: Vec<_> = method_infos.iter().map(|m| &m.name).collect();
155
156 let client_fields = method_infos.iter().map(|m| {
158 let name = &m.name;
159 let req_type = &m.req_type;
160 let resp_type = &m.resp_type;
161 quote! {
162 pub #name: moonpool_transport::ServiceEndpoint<#req_type, #resp_type, C>
165 }
166 });
167
168 let client_field_inits = method_infos.iter().map(|m| {
170 let name = &m.name;
171 let idx = m.index;
172 quote! {
173 #name: moonpool_transport::ServiceEndpoint::new(
174 moonpool_transport::Endpoint::new(
175 address.clone(),
176 moonpool_transport::UID::new(Self::INTERFACE_ID, #idx as u64),
177 ),
178 codec.clone(),
179 )
180 }
181 });
182
183 let first_field_name = &method_infos[0].name;
184
185 let trait_vis = &item.vis;
187 let trait_items = &item.items;
188 let trait_name_snake = to_snake_case(&name.to_string());
189
190 let serve_close_handles: Vec<_> = method_infos
192 .iter()
193 .map(|m| {
194 let method_name = &m.name;
195 quote! {
196 let queue = self.#method_name.queue();
197 close_fns.push(Box::new(move || queue.close()));
198 }
199 })
200 .collect();
201
202 let serve_spawn_tasks: Vec<_> = method_infos
203 .iter()
204 .map(|m| {
205 let method_name = &m.name;
206 let resp_type = &m.resp_type;
207 let task_name = format!("{}_{}", trait_name_snake, m.name);
208 quote! {
209 {
210 let stream = self.#method_name;
211 let t = transport.clone();
212 let h = handler.clone();
213 providers.task().spawn_task(#task_name, async move {
214 while let Some((req, reply)) = stream
215 .recv_with_transport::<_, #resp_type>(&t)
216 .await
217 {
218 match h.#method_name(req).await {
219 Ok(resp) => reply.send(resp),
220 Err(e) => {
221 tracing::warn!(error = %e, method = #task_name, "handler error");
222 reply.send_error(moonpool_transport::ReplyError::BrokenPromise);
223 }
224 }
225 }
226 });
227 }
228 }
229 })
230 .collect();
231
232 let expanded = quote! {
233 #[async_trait::async_trait(?Send)]
235 #trait_vis trait #name {
236 #(#trait_items)*
237 }
238
239 pub struct #server_name<C: moonpool_transport::MessageCodec> {
243 #(#server_fields,)*
244 }
245
246 impl<C: moonpool_transport::MessageCodec + Clone> #server_name<C> {
247 pub const INTERFACE_ID: u64 = #interface_id;
249
250 pub const METHOD_COUNT: u32 = #method_count;
252
253 pub fn init<P>(transport: &std::rc::Rc<moonpool_transport::NetTransport<P>>, codec: C) -> Self
258 where
259 P: moonpool_transport::Providers,
260 {
261 #(#server_inits)*
262 Self { #(#server_field_names,)* }
263 }
264
265 pub fn serve<P, H>(
279 self,
280 transport: std::rc::Rc<moonpool_transport::NetTransport<P>>,
281 handler: std::rc::Rc<H>,
282 providers: &P,
283 ) -> moonpool_transport::ServerHandle
284 where
285 P: moonpool_transport::Providers,
286 H: #name + 'static,
287 {
288 use moonpool_transport::TaskProvider as _;
289 let mut close_fns: Vec<Box<dyn Fn()>> = Vec::new();
290 #(#serve_close_handles)*
291 #(#serve_spawn_tasks)*
292 moonpool_transport::ServerHandle::new(close_fns)
293 }
294 }
295
296 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
316 #[serde(bound(
317 serialize = "",
318 deserialize = "C: moonpool_transport::MessageCodec + Default",
319 ))]
320 pub struct #client_name<C: moonpool_transport::MessageCodec> {
321 #(#client_fields,)*
322 }
323
324 impl<C: moonpool_transport::MessageCodec + Clone> #client_name<C> {
325 pub const INTERFACE_ID: u64 = #interface_id;
327
328 pub const METHOD_COUNT: u32 = #method_count;
330
331 pub fn new(address: moonpool_transport::NetworkAddress, codec: C) -> Self {
333 Self {
334 #(#client_field_inits,)*
335 }
336 }
337
338 pub fn address(&self) -> &moonpool_transport::NetworkAddress {
340 &self.#first_field_name.endpoint().address
342 }
343 }
344 };
345
346 Ok(expanded)
347}
348
349fn extract_method_types(sig: &syn::Signature) -> syn::Result<(Type, Type)> {
353 let mut inputs = sig.inputs.iter();
355
356 match inputs.next() {
358 Some(FnArg::Receiver(_)) => {}
359 _ => {
360 return Err(syn::Error::new_spanned(
361 sig,
362 "Interface method must have &self as first parameter",
363 ));
364 }
365 }
366
367 let req_type = match inputs.next() {
369 Some(FnArg::Typed(pat_type)) => (*pat_type.ty).clone(),
370 _ => {
371 return Err(syn::Error::new_spanned(
372 sig,
373 "Interface method must have a request parameter: async fn name(&self, req: ReqType) -> Result<RespType, RpcError>",
374 ));
375 }
376 };
377
378 let resp_type = match &sig.output {
380 ReturnType::Type(_, ty) => extract_result_ok_type(ty)?,
381 ReturnType::Default => {
382 return Err(syn::Error::new_spanned(
383 sig,
384 "Interface method must return Result<RespType, RpcError>",
385 ));
386 }
387 };
388
389 Ok((req_type, resp_type))
390}
391
392fn extract_result_ok_type(ty: &Type) -> syn::Result<Type> {
394 if let Type::Path(type_path) = ty
395 && let Some(segment) = type_path.path.segments.last()
396 && segment.ident == "Result"
397 && let PathArguments::AngleBracketed(args) = &segment.arguments
398 && let Some(GenericArgument::Type(ok_type)) = args.args.first()
399 {
400 return Ok(ok_type.clone());
401 }
402
403 Err(syn::Error::new_spanned(
404 ty,
405 "Interface method must return Result<RespType, RpcError>",
406 ))
407}
408
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. All other characters are copied unchanged.
/// Consecutive capitals are each split (`"HTTPServer"` -> `"h_t_t_p_server"`).
///
/// Lowercasing is Unicode-aware so that it agrees with the Unicode-aware
/// uppercase test; ASCII-only lowercasing would leave non-ASCII capitals
/// uppercase after inserting the separator.
fn to_snake_case(s: &str) -> String {
    // A few spare bytes for the inserted underscores.
    let mut result = String::with_capacity(s.len() + 4);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield more than one char for some code points.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
424
/// Parsed arguments of the service attribute, written as `id = <integer literal>`.
struct InterfaceAttr {
    /// Interface identifier; emitted as the `INTERFACE_ID` constant of the
    /// generated server and client types.
    id: u64,
}
433
434impl syn::parse::Parse for InterfaceAttr {
435 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
436 let ident: Ident = input.parse()?;
437 if ident != "id" {
438 return Err(syn::Error::new_spanned(
439 ident,
440 "expected `id` in interface attribute",
441 ));
442 }
443 let _eq: syn::Token![=] = input.parse()?;
444 let value: Expr = input.parse()?;
445
446 let id = match &value {
448 Expr::Lit(ExprLit {
449 lit: Lit::Int(lit_int),
450 ..
451 }) => lit_int.base10_parse::<u64>()?,
452 _ => {
453 return Err(syn::Error::new_spanned(
454 value,
455 "expected integer literal for interface id",
456 ));
457 }
458 };
459
460 Ok(InterfaceAttr { id })
461 }
462}