1use std::collections::HashSet;
2
3use convert_case::{Case, Casing};
4use proc_macro2::TokenStream;
5use proc_macro_error::{abort, emit_error, proc_macro_error};
6use quote::{quote, quote_spanned};
7use syn::{
8 spanned::Spanned, FnArg, GenericArgument, Ident, Pat, ReturnType, Token, TraitItem,
9 TraitItemFn, Type, VisRestricted, Visibility,
10};
11
12#[proc_macro_error]
36#[proc_macro_attribute]
37pub fn service(
38 _attr: proc_macro::TokenStream,
39 item: proc_macro::TokenStream,
40) -> proc_macro::TokenStream {
41 let tokens = proc_macro2::TokenStream::from(item);
42
43 let Ok(ast) = syn::parse2::<syn::ItemTrait>(tokens) else {
62 return proc_macro::TokenStream::new();
63 };
64
65 let module_name = format!("{}", ast.ident.to_string().to_case(Case::Snake));
66 let module_name = Ident::new(&module_name, proc_macro2::Span::call_site());
67 let original_vis = &ast.vis;
68 let vis = match ast.vis.clone() {
69 vis @ (syn::Visibility::Public(_) | syn::Visibility::Restricted(_)) => vis,
70 syn::Visibility::Inherited => syn::Visibility::Restricted(VisRestricted {
71 pub_token: Token),
72 paren_token: Default::default(),
73 in_token: None,
74 path: Box::new(syn::parse2::<syn::Path>(quote!(super)).unwrap()),
75 }),
76 };
77
78 let mut statefuls = Vec::new();
79 let mut statelesses = Vec::new();
80 let mut call_binds = Vec::new();
81 let mut name_table = HashSet::new();
82
83 let (functions, attrs): (Vec<_>, Vec<_>) = ast
84 .items
85 .iter()
86 .filter_map(|x| if let TraitItem::Fn(x) = x { Some(x) } else { None })
87 .map(|x| {
88 let mut x = x.clone();
89 let attrs = method_attrs(&mut x);
90 (x, attrs)
91 })
92 .unzip();
93
94 let non_functions = ast
95 .items
96 .iter()
97 .filter_map(|x| if let TraitItem::Fn(_) = x { None } else { Some(x) })
98 .collect::<Vec<_>>();
99
100 for item in functions.iter().zip(&attrs) {
101 if let Some(loaded) = generate_loader_item(item.0, item.1, &mut name_table) {
102 match loaded {
103 LoaderOutput::Stateful(stateful) => statefuls.push(stateful),
104 LoaderOutput::Stateless(stateless) => statelesses.push(stateless),
105 }
106 }
107
108 if let Some(caller) = generate_call_stubs(item.0, item.1, &vis) {
109 call_binds.push(caller);
110 }
111 }
112
113 let trait_signatures = generate_trait_signatures(&functions, &attrs);
114
115 let output = quote!(
116 #original_vis mod #module_name {
117 #![allow(unused_parens)]
118 #![allow(unused)]
119
120 use super::*;
121
122 use rpc_it::service as __sv;
123 use rpc_it::service::macro_utils as __mc;
124 use rpc_it::serde;
125 use rpc_it::ExtractUserData;
126
127 #vis trait Service: Send + Sync + 'static + Clone {
128 #trait_signatures
129 #(#non_functions)*
130 }
131
132 #vis fn load_service_stateful_only<T: Service, R: __sv::Router>(
133 __this: T,
134 __service: &mut __sv::ServiceBuilder<R>
135 ) -> __mc::RegisterResult {
136 #(#statefuls;)*
137 Ok(())
138 }
139
140 #vis fn load_service_stateless_only<T: Service, R: __sv::Router>(
141 __service: &mut __sv::ServiceBuilder<R>
142 ) -> __mc::RegisterResult {
143 #(#statelesses;)*
144 Ok(())
145 }
146
147 #vis fn load_service<T:Service, R: __sv::Router>(
148 __this: T,
149 __service: &mut __sv::ServiceBuilder<R>
150 ) -> __mc::RegisterResult {
151 load_service_stateful_only(__this, __service)?;
152 load_service_stateless_only::<T, _>(__service)?;
153 Ok(())
154 }
155
156 #[derive(Debug, Clone)]
157 #vis struct Client(rpc_it::Sender);
158
159 impl Client {
160 #vis fn new(inner: rpc_it::Sender) -> Self {
161 Self(inner)
162 }
163
164 #vis fn into_inner(self) -> rpc_it::Sender {
165 self.0
166 }
167
168 #vis fn inner(&self) -> &rpc_it::Sender {
169 &self.0
170 }
171
172 #(#call_binds)*
173 }
174 }
175 );
176
177 output.into()
178}
179
180enum LoaderOutput {
181 Stateful(TokenStream),
182 Stateless(TokenStream),
183}
184
185fn generate_loader_item(
186 method: &TraitItemFn,
187 attrs: &MethodAttrs,
188 used_route_table: &mut HashSet<String>,
189) -> Option<LoaderOutput> {
190 if attrs.skip {
191 return None;
192 }
193
194 let mut is_self_ref = false;
195 let mut is_stateless = false;
196
197 if let Some(receiver) = method.sig.receiver() {
198 if receiver.reference.is_some() && receiver.colon_token.is_none() {
199 if receiver.mutability.is_some() {
200 emit_error!(receiver, "Only `&self` is allowed");
201 return None;
202 }
203
204 is_self_ref = true;
205 } else if receiver.colon_token.is_some() && receiver.reference.is_none() {
206 is_self_ref = matches!(&*receiver.ty, syn::Type::Reference(_));
207 }
208 } else {
209 is_stateless = true;
210 };
211
212 let is_sync_func = attrs.sync;
214 let mut routes = Vec::with_capacity(1 + attrs.aliases.len());
215 let ident = &method.sig.ident;
216 routes.push(
217 attrs
218 .route
219 .as_ref()
220 .map(syn::LitStr::value)
221 .unwrap_or_else(|| method.sig.ident.to_string()),
222 );
223
224 for route in &attrs.aliases {
225 routes.push(route.value());
226 }
227
228 let (is_ref, inputs): (Vec<_>, Vec<_>) = method
230 .sig
231 .inputs
232 .iter()
233 .skip(if is_self_ref { 1 } else { 0 })
234 .map(|input| {
235 let syn::FnArg::Typed(pat) = input else {
236 abort!(input, "unexpected argument type");
237 };
238
239 if let Type::Reference(r) = &*pat.ty {
240 let inner = &r.elem;
241 (true, Type::Verbatim(quote!(std::borrow::Cow<#inner>)))
242 } else {
243 (false, (*pat.ty).clone())
244 }
245 })
246 .unzip();
247
248 let tup_inputs = quote!((#(#inputs),*));
249 let route_paths = quote!(&[#(#routes),*]);
250 let unpack = if inputs.len() == 1 {
251 let tok_ref = is_ref[0].then(|| quote!(&));
252 quote!(#tok_ref __req)
253 } else {
254 let vals = (0..inputs.len()).map(|x| syn::Index::from(x));
255 let tok_ref = is_ref.iter().map(|x| if *x { quote!(&) } else { quote!() });
256 quote!(#( #tok_ref __req.#vals ),*)
257 };
258
259 for r in routes {
260 if !used_route_table.insert(r.clone()) {
261 emit_error!(method, "duplicated route: {}", r);
262 }
263 }
264
265 let output = OutputType::new(&method.sig.output);
266
267 let tok_this_clone = (!is_stateless).then(|| quote!(let __this_2 = __this.clone();));
268 let tok_this_param = (!is_stateless).then(|| quote!(&__this_2,));
269
270 let strm = if output.is_notify() {
271 quote!(
272 #tok_this_clone
273 __service.register_notify_handler(#route_paths, move |__src, __req: #tup_inputs| {
274 T::#ident(#tok_this_param __src, #unpack);
275 Ok(())
276 })?
277 )
278 } else {
279 let type_out = output.typed_req();
280 if is_sync_func {
281 let rval = output.handle_sync_retval_to_response(
282 Ident::new("__src", method.sig.output.span()),
283 Ident::new("__result", method.sig.output.span()),
284 );
285
286 quote!(
287 #tok_this_clone
288 __service.register_request_handler(#route_paths, move |__src: #type_out, __req: #tup_inputs| {
289 let __result = T::#ident(#tok_this_param __src.user_data_owned(), #unpack);
290 #rval;
291 Ok(())
292 })?
293 )
294 } else {
295 quote!(
296 #tok_this_clone
297 __service.register_request_handler(#route_paths, move |__src: #type_out, __req: #tup_inputs| {
298 T::#ident(#tok_this_param __src, #unpack);
299 Ok(())
300 })?
301 )
302 }
303 };
304
305 Some(if is_stateless { LoaderOutput::Stateless(strm) } else { LoaderOutput::Stateful(strm) })
306}
307
308fn generate_call_stubs(
309 method: &TraitItemFn,
310 attrs: &MethodAttrs,
311 vis: &Visibility,
312) -> Option<TokenStream> {
313 if attrs.skip {
314 return None;
315 }
316
317 let has_receiver = method.sig.receiver().is_some();
318
319 let inputs = method
320 .sig
321 .inputs
322 .iter()
323 .skip(if has_receiver { 1 } else { 0 })
324 .map(|arg| {
325 let FnArg::Typed(pat) = arg else { abort!(arg, "unexpected argument type") };
326 if !matches!(*pat.pat, Pat::Ident(_)) {
327 abort!(arg, "Function argument pattern must be named identifier.");
328 }
329 pat
330 })
331 .collect::<Vec<_>>();
332
333 let input_ref_args = inputs.iter().map(|x| *x).cloned().map(|mut x| {
334 x.ty = match *x.ty {
335 ty @ Type::Reference(_) => ty.into(),
336 other => Type::Reference(syn::TypeReference {
337 and_token: Token),
338 lifetime: None,
339 mutability: None,
340 elem: other.into(),
341 })
342 .into(),
343 };
344 x
345 });
346 let input_ref_arg_tokens = quote!(#(#input_ref_args),*);
347
348 let input_idents = inputs
349 .iter()
350 .map(|x| *x)
351 .cloned()
352 .map(|x| {
353 let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*x.pat else { unreachable!() };
354 ident.clone()
355 })
356 .collect::<Vec<_>>();
357
358 let method_ident = &method.sig.ident;
359 let output = OutputType::new(&method.sig.output);
360
361 let method_str =
362 attrs.route.as_ref().map(syn::LitStr::value).unwrap_or_else(|| method_ident.to_string());
363
364 let new_ident_suffixed =
365 |sfx: &str| syn::Ident::new(&format!("{0}_{1}", method_ident, sfx), method_ident.span());
366 let new_ident_prefixed =
367 |sfx: &str| syn::Ident::new(&format!("{1}_{0}", method_ident, sfx), method_ident.span());
368 let method_ident_deferred = new_ident_suffixed("deferred");
369
370 Some(if output.is_notify() {
371 let method_ident_with_reuse = new_ident_suffixed("with_reuse");
372 let method_ident_deferred_with_reuse = new_ident_suffixed("deferred_with_reuse");
373
374 let reuse_version = attrs.with_reuse.then(|| quote!(
375 #[doc(hidden)]
376 #vis async fn #method_ident_with_reuse(&self, buffer: &mut rpc_it::rpc::WriteBuffer, #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
377 self.0.notify_with_reuse(buffer, #method_str, &(#(#input_idents),*)).await
378 }
379
380 #[doc(hidden)]
381 #vis async fn #method_ident_deferred_with_reuse(&self, buffer: &mut rpc_it::rpc::WriteBuffer, #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
382 self.0.notify_deferred_with_reuse(buffer, #method_str, &(#(#input_idents),*))
383 }
384 ));
385
386 quote!(
387 #vis async fn #method_ident(&self, #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
388 self.0.notify(#method_str, &(#(#input_idents),*)).await
389 }
390
391
392 #vis async fn #method_ident_deferred(&self, #input_ref_arg_tokens) -> Result<(), rpc_it::SendError> {
393 self.0.notify_deferred(#method_str, &(#(#input_idents),*))
394 }
395
396 #reuse_version
397 )
398 } else {
399 let (ok_tok, err_tok) = match &output {
400 OutputType::Response(ok, err) => (quote!(#ok), quote!(#err)),
401 OutputType::ResponseNoErr(ok) => (quote!(#ok), quote!(())),
402 OutputType::Notify => unreachable!(),
403 };
404
405 let method_ident_request = new_ident_prefixed("request");
406
407 quote!(
408 #vis async fn #method_ident(&self, #input_ref_arg_tokens)
409 -> Result<#ok_tok, rpc_it::TypedCallError<#err_tok>>
410 {
411 self.0.call_with_err(#method_str, &(#(#input_idents),*)).await
412 }
413
414 #vis async fn #method_ident_request(&self, #input_ref_arg_tokens)
415 -> Result<rpc_it::TypedResponse<#ok_tok, #err_tok>, rpc_it::SendError>
416 {
417 let resp = self.0.request(#method_str, &(#(#input_idents),*)).await?;
418 Ok(rpc_it::TypedResponse::new(resp.to_owned()))
419 }
420
421 #vis async fn #method_ident_deferred(&self, #input_ref_arg_tokens)
422 -> Result<rpc_it::TypedResponse<#ok_tok, #err_tok>, rpc_it::SendError>
423 {
424 let resp = self.0.request_deferred(#method_str, &(#(#input_idents),*))?;
425 Ok(rpc_it::TypedResponse::new(resp.to_owned()))
426 }
427 )
428 })
429}
430
431fn generate_trait_signatures(items: &[TraitItemFn], attrs: &[MethodAttrs]) -> TokenStream {
432 let tokens = items.iter().zip(attrs).map(|(method, attrs)| {
433 let mut method = method.clone();
434 let out = OutputType::new(&method.sig.output);
435
436 if attrs.skip {
437 return TraitItem::Fn(method);
439 }
440
441 let req_param_ident = if let Some(body) =
442 method.default.as_ref().filter(|_| !attrs.sync && !out.is_notify())
443 {
444 let span = body.span();
445 let id_req = Ident::new("___rq", span);
446 let payload: syn::Expr = syn::parse_quote_spanned!(span => (move || #body)());
447 let response = match out {
448 OutputType::Notify => unreachable!(),
449 OutputType::ResponseNoErr(_) => {
450 quote_spanned!(span => #id_req.ok(&#payload).ok();)
451 }
452 OutputType::Response(_, _) => {
453 quote_spanned!(
454 span => match #payload {
455 Ok(x) => #id_req.ok(&x).ok(),
456 Err(e) => #id_req.err(&e).ok(),
457 }
458 )
459 }
460 };
461
462 method.default = Some(syn::parse_quote_spanned!(
463 span =>
464 {
465 #response;
466 }
467 ));
468
469 syn::Pat::Ident(syn::PatIdent {
470 attrs: Vec::new(),
471 by_ref: None,
472 mutability: None,
473 subpat: None,
474 ident: id_req,
475 })
476 } else {
477 syn::Pat::Wild(syn::PatWild {
478 attrs: Vec::new(),
479 underscore_token: Token),
480 })
481 };
482
483 {
484 let has_receiver = method.sig.receiver().is_some();
485 let insert_at = if has_receiver { 1 } else { 0 };
486
487 if out.is_notify() {
488 method.sig.inputs.insert(
489 insert_at,
490 syn::parse_quote_spanned!(method.sig.output.span() => _: rpc_it::Notify),
491 );
492 } else if !attrs.sync {
493 method.sig.inputs.insert(
494 insert_at,
495 syn::FnArg::Typed(syn::PatType {
496 attrs: Vec::new(),
497 colon_token: Default::default(),
498 pat: req_param_ident.into(),
499 ty: out.typed_req().into(),
500 }),
501 );
502
503 method.sig.output = ReturnType::Default;
504 } else {
505 method.sig.inputs.insert(
506 insert_at,
507 syn::parse_quote_spanned!(method.sig.output.span() => _: rpc_it::OwnedUserData),
508 );
509 }
510 }
511
512 TraitItem::Fn(method)
513 });
514
515 quote!(#(#tokens)*)
516}
517
518#[derive(Default)]
519struct MethodAttrs {
520 sync: bool,
521 skip: bool,
522 aliases: Vec<syn::LitStr>,
523 with_reuse: bool,
524 route: Option<syn::LitStr>,
525}
526
527fn method_attrs(method: &mut TraitItemFn) -> MethodAttrs {
528 let mut attrs = MethodAttrs::default();
529
530 for attr in std::mem::take(&mut method.attrs) {
531 match &attr.meta {
532 syn::Meta::Path(path) => {
533 if path.is_ident("sync") {
534 if matches!(method.sig.output, ReturnType::Default) {
535 emit_error!(attr, "'sync' attribute is only allowed for requests");
536 }
537
538 attrs.sync = true;
539 } else if path.is_ident("skip") {
540 attrs.skip = true;
541 } else if path.is_ident("with_reuse") {
542 attrs.with_reuse = true;
543 } else {
544 emit_error!(attr, "unexpected attribute")
545 }
546 }
547
548 syn::Meta::List(_) => {
549 emit_error!(attr, "unexpected attribute")
550 }
551
552 syn::Meta::NameValue(kv) => {
553 let Some(ident) = kv.path.get_ident() else {
554 emit_error!(attr, "unexpected attribute");
555 continue;
556 };
557 let syn::Expr::Lit(syn::ExprLit { lit, .. }) = &kv.value else {
558 emit_error!(attr, "unexpected attribute");
559 continue;
560 };
561
562 if ident == "aliases" {
563 let syn::Lit::Str(route) = lit else {
564 emit_error!(lit, "unexpected non-string literal attribute");
565 continue;
566 };
567 attrs.aliases.push(route.clone());
568 } else if ident == "route" {
569 let syn::Lit::Str(route) = lit else {
570 emit_error!(lit, "unexpected non-string literal attribute");
571 continue;
572 };
573 attrs.route = Some(route.clone());
574 } else {
575 emit_error!(attr, "unexpected attribute")
576 }
577 }
578 }
579 }
580
581 attrs
582}
583
584enum OutputType {
585 Notify,
586 ResponseNoErr(Type),
587 Response(GenericArgument, GenericArgument),
588}
589
590impl OutputType {
591 fn new(val: &syn::ReturnType) -> Self {
592 let syn::ReturnType::Type(_, ty) = val else { return Self::Notify };
593
594 let fb = || Self::ResponseNoErr((**ty).clone());
595 let Type::Path(tp) = &**ty else { return fb() };
596 let Some(first_seg) = tp.path.segments.first() else { return fb() };
597
598 if first_seg.ident != "Result" {
599 return fb();
600 }
601
602 let syn::PathArguments::AngleBracketed(ang) = &first_seg.arguments else {
603 return fb();
604 };
605
606 let mut type_iter = ang.args.iter();
607 let [Some(ok), Some(err)] = std::array::from_fn(|_| type_iter.next()) else {
608 return fb();
609 };
610
611 Self::Response(ok.clone(), err.clone())
612 }
613
614 fn is_notify(&self) -> bool {
615 matches!(self, Self::Notify)
616 }
617
618 fn typed_req(&self) -> Type {
619 match self {
620 OutputType::Notify => unimplemented!(),
621
622 OutputType::ResponseNoErr(x) => {
623 syn::parse2(quote!(rpc_it::TypedRequest<#x, ()>)).unwrap()
624 }
625
626 OutputType::Response(r, e) => {
627 syn::parse2(quote!(rpc_it::TypedRequest<#r, #e>)).unwrap()
628 }
629 }
630 }
631
632 fn handle_sync_retval_to_response(&self, req_ident: Ident, val_ident: Ident) -> TokenStream {
633 match self {
634 OutputType::Notify => unimplemented!(),
635
636 OutputType::ResponseNoErr(_) => {
637 quote!(#req_ident.ok(&#val_ident)?;)
638 }
639
640 OutputType::Response(_, _) => {
641 quote!(
642 match #val_ident {
643 Ok(x) => #req_ident.ok(&x)?,
644 Err(e) => #req_ident.err(&e)?,
645 }
646 )
647 }
648 }
649 }
650}