1#![recursion_limit = "512"]
9
10extern crate proc_macro;
11extern crate proc_macro2;
12extern crate quote;
13extern crate syn;
14
15use std::env;
16
17use proc_macro::TokenStream;
18use proc_macro2::TokenStream as TokenStream2;
19use quote::{format_ident, quote, ToTokens};
20use syn::ext::IdentExt;
21use syn::parse::{Parse, ParseStream};
22use syn::spanned::Spanned;
23use syn::token::Comma;
24use syn::{
25 braced, parenthesized, parse_macro_input, parse_quote, AttrStyle, Attribute, Expr, FnArg,
26 Ident, Lit, LitBool, MetaNameValue, Pat, PatType, Path, ReturnType, Token, Type, Visibility,
27};
28
/// Environment variable that, when set to exactly `1`, makes `#[service]` print its expanded
/// token stream to stdout while the macro runs (a debugging aid for inspecting generated code).
const ENV_LRCALL_MACRO_PRINT: &str = "LRCALL_MACRO_PRINT";
30
/// Folds the error `$e` into the accumulator `$errors` (a local `Result<_, E>` binding).
///
/// While `$errors` is still `Ok`, it is replaced by `Err($e)`; once it holds an error, `$e` is
/// appended to it through `Extend`, so every diagnostic can be reported together.
///
/// The `Ok` case is handled first on purpose: at call sites the accumulator starts as an
/// unannotated `Ok(..)`, and the assignment `Err($e)` is what fixes its error type before the
/// `extend` call is type-checked.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        if $errors.is_ok() {
            $errors = Err($e);
        } else if let Err(ref mut accumulated) = $errors {
            accumulated.extend($e);
        }
    };
}
42
43struct Service {
44 attrs: Vec<Attribute>,
45 vis: Visibility,
46 ident: Ident,
47 rpcs: Vec<RpcMethod>,
48}
49
50struct RpcMethod {
51 attrs: Vec<Attribute>,
52 ident: Ident,
53 args: Vec<PatType>,
54 output: ReturnType,
55}
56
57impl Parse for Service {
58 fn parse(input: ParseStream) -> syn::Result<Self> {
59 let attrs = input.call(Attribute::parse_outer)?;
60 let vis = input.parse()?;
61 input.parse::<Token![trait]>()?;
62 let ident: Ident = input.parse()?;
63 let content;
64 braced!(content in input);
65 let mut rpcs = Vec::<RpcMethod>::new();
66 while !content.is_empty() {
67 rpcs.push(content.parse()?);
68 }
69 let mut ident_errors = Ok(());
70 for rpc in &rpcs {
71 if rpc.ident == "new" {
72 extend_errors!(
73 ident_errors,
74 syn::Error::new(
75 rpc.ident.span(),
76 format!(
77 "method name conflicts with generated fn `{}Client::new`",
78 ident.unraw()
79 )
80 )
81 );
82 }
83 if rpc.ident == "serve" {
84 extend_errors!(
85 ident_errors,
86 syn::Error::new(
87 rpc.ident.span(),
88 format!("method name conflicts with generated fn `{ident}::serve`")
89 )
90 );
91 }
92 }
93 ident_errors?;
94
95 Ok(Self {
96 attrs,
97 vis,
98 ident,
99 rpcs,
100 })
101 }
102}
103
104impl Parse for RpcMethod {
105 fn parse(input: ParseStream) -> syn::Result<Self> {
106 let attrs = input.call(Attribute::parse_outer)?;
107 input.parse::<Token![async]>()?;
108 input.parse::<Token![fn]>()?;
109 let ident = input.parse()?;
110 let content;
111 parenthesized!(content in input);
112 let mut args = Vec::new();
113 let mut errors = Ok(());
114 for arg in content.parse_terminated(FnArg::parse, Comma)? {
115 match arg {
116 FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
117 args.push(captured);
118 }
119 FnArg::Typed(captured) => {
120 extend_errors!(
121 errors,
122 syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args")
123 );
124 }
125 FnArg::Receiver(_) => {
126 extend_errors!(
127 errors,
128 syn::Error::new(arg.span(), "method args cannot start with self")
129 );
130 }
131 }
132 }
133 errors?;
134 let output = input.parse()?;
135 input.parse::<Token![;]>()?;
136
137 Ok(Self {
138 attrs,
139 ident,
140 args,
141 output,
142 })
143 }
144}
145
146#[derive(Default)]
147struct DeriveMeta {
148 derive: Option<Derive>,
149 warnings: Vec<TokenStream2>,
150}
151
152impl DeriveMeta {
153 fn with_derives(mut self, new: Vec<Path>) -> Self {
154 match self.derive.as_mut() {
155 Some(Derive::Explicit(old)) => old.extend(new),
156 _ => self.derive = Some(Derive::Explicit(new)),
157 }
158
159 self
160 }
161}
162
163enum Derive {
164 Explicit(Vec<Path>),
165 Serde(bool),
166}
167
168impl Parse for DeriveMeta {
169 fn parse(input: ParseStream) -> syn::Result<Self> {
170 let mut result = Ok(DeriveMeta::default());
171
172 let mut derives = Vec::new();
173 let mut derive_serde = Vec::new();
174 let mut has_derive_serde = false;
175 let mut has_explicit_derives = false;
176
177 let meta_items = input.parse_terminated(MetaNameValue::parse, Comma)?;
178 for meta in meta_items {
179 if meta.path.segments.len() != 1 {
180 extend_errors!(
181 result,
182 syn::Error::new(
183 meta.span(),
184 "lrcall::service does not support this meta item"
185 )
186 );
187 continue;
188 }
189 let segment = meta.path.segments.first().unwrap();
190 if segment.ident == "derive" {
191 has_explicit_derives = true;
192 let Expr::Array(ref array) = meta.value else {
193 extend_errors!(
194 result,
195 syn::Error::new(
196 meta.span(),
197 "lrcall::service does not support this meta item"
198 )
199 );
200 continue;
201 };
202
203 let paths = array
204 .elems
205 .iter()
206 .filter_map(|e| {
207 if let Expr::Path(path) = e {
208 Some(path.path.clone())
209 } else {
210 extend_errors!(
211 result,
212 syn::Error::new(e.span(), "Expected Path or Type")
213 );
214 None
215 }
216 })
217 .collect::<Vec<_>>();
218
219 result = result.map(|d| d.with_derives(paths));
220 derives.push(meta);
221 } else if segment.ident == "derive_serde" {
222 has_derive_serde = true;
223 let Expr::Lit(expr_lit) = &meta.value else {
224 extend_errors!(
225 result,
226 syn::Error::new(meta.value.span(), "expected literal")
227 );
228 continue;
229 };
230 match expr_lit.lit {
231 Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
232 result = result.map(|d| DeriveMeta {
233 derive: Some(Derive::Serde(true)),
234 ..d
235 })
236 }
237 Lit::Bool(LitBool { value: true, .. }) => {
238 extend_errors!(
239 result,
240 syn::Error::new(
241 meta.span(),
242 "To enable serde, first enable the `serde1` feature of lrcall"
243 )
244 );
245 }
246 Lit::Bool(LitBool { value: false, .. }) => {
247 result = result.map(|d| DeriveMeta {
248 derive: Some(Derive::Serde(false)),
249 ..d
250 })
251 }
252 _ => extend_errors!(
253 result,
254 syn::Error::new(
255 expr_lit.lit.span(),
256 "`derive_serde` expects a value of type `bool`"
257 )
258 ),
259 }
260 derive_serde.push(meta);
261 } else {
262 extend_errors!(
263 result,
264 syn::Error::new(
265 meta.span(),
266 "lrcall::service does not support this meta item"
267 )
268 );
269 continue;
270 }
271 }
272
273 if has_derive_serde {
274 let deprecation_hack = quote! {
275 const _: () = {
276 #[deprecated(
277 note = "\nThe form `lrcall::service(derive_serde = true)` is deprecated.\
278 \nUse `lrcall::service(derive = [Serialize, Deserialize])`."
279 )]
280 const DEPRECATED_SYNTAX: () = ();
281 let _ = DEPRECATED_SYNTAX;
282 };
283 };
284
285 result = result.map(|mut d| {
286 d.warnings.push(deprecation_hack.to_token_stream());
287 d
288 });
289 }
290
291 if has_explicit_derives & has_derive_serde {
292 extend_errors!(
293 result,
294 syn::Error::new(
295 input.span(),
296 "lrcall does not support `derive_serde` and `derive` at the same time"
297 )
298 );
299 }
300
301 if derive_serde.len() > 1 {
302 for (i, derive_serde) in derive_serde.iter().enumerate() {
303 extend_errors!(
304 result,
305 syn::Error::new(
306 derive_serde.span(),
307 format!(
308 "`derive_serde` appears more than once (occurrence #{})",
309 i + 1
310 )
311 )
312 );
313 }
314 }
315
316 if derives.len() > 1 {
317 for (i, derive) in derives.iter().enumerate() {
318 extend_errors!(
319 result,
320 syn::Error::new(
321 derive.span(),
322 format!("`derive` appears more than once (occurrence #{})", i + 1)
323 )
324 );
325 }
326 }
327
328 result
329 }
330}
331
/// Attribute macro that adds serde's derives, routed through lrcall's re-exported `serde`.
///
/// The annotated item is emitted unchanged, preceded by
/// `#[derive(::lrcall::serde::Serialize, ::lrcall::serde::Deserialize)]` and
/// `#[serde(crate = "::lrcall::serde")]`. Only available with the `serde1` feature.
#[proc_macro_attribute]
#[cfg(feature = "serde1")]
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let annotated = TokenStream2::from(item);
    let expanded = quote! {
        #[derive(::lrcall::serde::Serialize, ::lrcall::serde::Deserialize)]
        #[serde(crate = "::lrcall::serde")]
        #annotated
    };
    TokenStream::from(expanded)
}
351
352fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
353 rpcs.iter()
354 .map(|rpc| {
355 rpc.attrs
356 .iter()
357 .filter(|att| {
358 att.style == AttrStyle::Outer
359 && match &att.meta {
360 syn::Meta::List(syn::MetaList { path, .. }) => {
361 path.get_ident() == Some(&Ident::new("cfg", rpc.ident.span()))
362 }
363 _ => false,
364 }
365 })
366 .collect::<Vec<_>>()
367 })
368 .collect::<Vec<_>>()
369}
370
371#[proc_macro_attribute]
417pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
418 let derive_meta = parse_macro_input!(attr as DeriveMeta);
419 let unit_type: &Type = &parse_quote!(());
420 let Service {
421 ref attrs,
422 ref vis,
423 ref ident,
424 ref rpcs,
425 } = parse_macro_input!(input as Service);
426
427 let camel_case_fn_names: &Vec<_> = &rpcs
428 .iter()
429 .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
430 .collect();
431 let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
432
433 let derives = match derive_meta.derive.as_ref() {
434 Some(Derive::Explicit(paths)) => {
435 if !paths.is_empty() {
436 Some(quote! {
437 #[derive(
438 #(
439 #paths
440 ),*
441 )]
442 })
443 } else {
444 None
445 }
446 }
447 Some(Derive::Serde(serde)) => {
448 if *serde {
449 Some(quote! {
450 #[derive(::lrcall::serde::Serialize, ::lrcall::serde::Deserialize)]
451 #[serde(crate = "::lrcall::serde")]
452 })
453 } else {
454 None
455 }
456 }
457 None => {
458 if cfg!(feature = "serde1") {
459 Some(quote! {
460 #[derive(::lrcall::serde::Serialize, ::lrcall::serde::Deserialize)]
461 #[serde(crate = "::lrcall::serde")]
462 })
463 } else {
464 None
465 }
466 }
467 };
468
469 let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
470 let request_names = methods
471 .iter()
472 .map(|m| format!("{ident}.{m}"))
473 .collect::<Vec<_>>();
474
475 let code = ServiceGenerator {
476 service_ident: ident,
477 service_unimplemented_ident: &format_ident!("Unimpl{}", ident),
478 client_stub_ident: &format_ident!("{}RpcStub", ident),
479 channel_ident: &format_ident!("{}Channel", ident),
480 server_ident: &format_ident!("Serve{}", ident),
481 client_ident: &format_ident!("{}Client", ident),
482 request_ident: &format_ident!("{}Request", ident),
483 response_ident: &format_ident!("{}Response", ident),
484 vis,
485 args,
486 method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
487 method_cfgs: &collect_cfg_attrs(rpcs),
488 method_idents: &methods,
489 request_names: &request_names,
490 attrs,
491 rpcs,
492 return_types: &rpcs
493 .iter()
494 .map(|rpc| match rpc.output {
495 ReturnType::Type(_, ref ty) => ty.as_ref(),
496 ReturnType::Default => unit_type,
497 })
498 .collect::<Vec<_>>(),
499 arg_pats: &args
500 .iter()
501 .map(|args| args.iter().map(|arg| &*arg.pat).collect())
502 .collect::<Vec<_>>(),
503 camel_case_idents: &rpcs
504 .iter()
505 .zip(camel_case_fn_names.iter())
506 .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
507 .collect::<Vec<_>>(),
508 derives: derives.as_ref(),
509 warnings: &derive_meta.warnings,
510 }
511 .into_token_stream();
512 if env::var(ENV_LRCALL_MACRO_PRINT).map_or(false, |v| v == "1") {
513 println!("{}", code.to_string());
514 }
515 code.into()
516}
517
518struct ServiceGenerator<'a> {
521 service_ident: &'a Ident,
522 service_unimplemented_ident: &'a Ident,
523 client_stub_ident: &'a Ident,
524 channel_ident: &'a Ident,
525 server_ident: &'a Ident,
526 client_ident: &'a Ident,
527 request_ident: &'a Ident,
528 response_ident: &'a Ident,
529 vis: &'a Visibility,
530 attrs: &'a [Attribute],
531 rpcs: &'a [RpcMethod],
532 camel_case_idents: &'a [Ident],
533 method_idents: &'a [&'a Ident],
534 request_names: &'a [String],
535 method_attrs: &'a [&'a [Attribute]],
536 method_cfgs: &'a [Vec<&'a Attribute>],
537 args: &'a [&'a [PatType]],
538 return_types: &'a [&'a Type],
539 arg_pats: &'a [Vec<&'a Pat>],
540 derives: Option<&'a TokenStream2>,
541 warnings: &'a [TokenStream2],
542}
543
544impl<'a> ServiceGenerator<'a> {
545 fn trait_service(&self) -> TokenStream2 {
546 let &Self {
547 service_unimplemented_ident,
548 channel_ident,
549 attrs,
550 rpcs,
551 vis,
552 return_types,
553 service_ident,
554 client_stub_ident,
555 request_ident,
556 response_ident,
557 server_ident,
558 ..
559 } = self;
560
561 let rpc_fns = rpcs.iter().zip(return_types.iter()).map(|(RpcMethod { attrs, ident, args, .. }, output)| {
562 quote! {
563 #( #attrs )*
564 async fn #ident(self, context: ::lrcall::context::Context, #( #args ),*) -> #output;
565 }
566 });
567
568 let unimplemented_rpc_fns = rpcs.iter().zip(return_types.iter()).map(|(RpcMethod { attrs, ident, args, .. }, output)| {
569 quote! {
570 #( #attrs )*
571 #[allow(unused_variables)]
572 async fn #ident(self, context: ::lrcall::context::Context, #( #args ),*) -> #output {
573 unimplemented!()
574 }
575 }
576 });
577
578 let stub_doc = format!("The stub trait for service [`{service_ident}`].");
579 let channel_doc = format!("The default {client_stub_ident} implementation.\nUsage: `{channel_ident}::spawn(config, transport)`");
580 quote! {
581 #( #attrs )*
582 #[allow(async_fn_in_trait)]
583 #vis trait #service_ident: ::core::marker::Sized {
584 #( #rpc_fns )*
585
586 fn serve(self) -> #server_ident<Self> {
589 #server_ident { service: self }
590 }
591 }
592
593 #[derive(Debug,Clone,Copy)]
594 #vis struct #service_unimplemented_ident;
595
596 impl #service_ident for #service_unimplemented_ident {
597 #( #unimplemented_rpc_fns )*
598 }
599
600 #[doc = #stub_doc]
601 #vis trait #client_stub_ident: ::lrcall::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
602 }
603
604 impl<S> #client_stub_ident for S
605 where S: ::lrcall::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
606 {
607 }
608
609 #[doc = #channel_doc]
610 #vis type #channel_ident = ::lrcall::client::Channel<#request_ident, #response_ident>;
611 }
612 }
613
614 fn struct_server(&self) -> TokenStream2 {
615 let &Self {
616 vis, server_ident, ..
617 } = self;
618
619 quote! {
620 #[derive(Clone)]
622 #vis struct #server_ident<S> {
623 service: S,
624 }
625 }
626 }
627
628 fn impl_serve_for_server(&self) -> TokenStream2 {
629 let &Self {
630 request_ident,
631 server_ident,
632 service_ident,
633 response_ident,
634 camel_case_idents,
635 arg_pats,
636 method_idents,
637 method_cfgs,
638 ..
639 } = self;
640
641 quote! {
642 impl<S> ::lrcall::server::Serve for #server_ident<S>
643 where S: #service_ident
644 {
645 type Req = #request_ident;
646 type Resp = #response_ident;
647
648
649 async fn serve(self, ctx: ::lrcall::context::Context, req: #request_ident)
650 -> ::core::result::Result<#response_ident, ::lrcall::ServerError> {
651 match req {
652 #(
653 #( #method_cfgs )*
654 #request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
655 ::core::result::Result::Ok(#response_ident::#camel_case_idents(
656 #service_ident::#method_idents(
657 self.service, ctx, #( #arg_pats ),*
658 ).await
659 ))
660 }
661 )*
662 }
663 }
664 }
665 }
666 }
667
668 fn enum_request(&self) -> TokenStream2 {
669 let &Self {
670 derives,
671 vis,
672 request_ident,
673 camel_case_idents,
674 args,
675 request_names,
676 method_cfgs,
677 ..
678 } = self;
679
680 quote! {
681 #[allow(missing_docs)]
683 #[derive(Debug)]
684 #derives
685 #vis enum #request_ident {
686 #(
687 #( #method_cfgs )*
688 #camel_case_idents{ #( #args ),* }
689 ),*
690 }
691 impl ::lrcall::RequestName for #request_ident {
692 fn name(&self) -> &'static str {
693 match self {
694 #(
695 #( #method_cfgs )*
696 #request_ident::#camel_case_idents{..} => {
697 #request_names
698 }
699 )*
700 }
701 }
702 }
703 }
704 }
705
706 fn enum_response(&self) -> TokenStream2 {
707 let &Self {
708 derives,
709 vis,
710 response_ident,
711 camel_case_idents,
712 return_types,
713 ..
714 } = self;
715
716 quote! {
717 #[allow(missing_docs)]
719 #[derive(Debug)]
720 #derives
721 #vis enum #response_ident {
722 #( #camel_case_idents(#return_types) ),*
723 }
724 }
725 }
726
727 fn struct_client(&self) -> TokenStream2 {
728 let &Self {
729 service_unimplemented_ident,
730 channel_ident,
731 vis,
732 client_ident,
733 ..
734 } = self;
735
736 quote! {
737 #[allow(unused, private_interfaces)]
738 #[derive(Clone, Debug)]
739 #vis struct #client_ident<L=#service_unimplemented_ident, R=#channel_ident> {
742 lpc_service: ::core::option::Option<L>,
743 rpc_stub: ::core::option::Option<R>,
744 }
745 }
746 }
747
748 fn impl_client_new(&self) -> TokenStream2 {
749 let &Self {
750 service_ident,
751 client_stub_ident,
752 client_ident,
753 vis,
754 ..
755 } = self;
756
757 let code = quote! {
758 impl<L, R> #client_ident<L, R>
759 where
760 L: #service_ident + ::core::clone::Clone,
761 R: #client_stub_ident,
762 {
763 #vis fn full_client(lpc_service: L, rpc_stub: R) -> Self {
765 Self {
766 lpc_service: ::core::option::Option::Some(lpc_service),
767 rpc_stub: ::core::option::Option::Some(rpc_stub),
768 }
769 }
770
771 #vis fn lpc_client(lpc_service: L) -> Self {
773 Self {
774 lpc_service: ::core::option::Option::Some(lpc_service),
775 rpc_stub: ::core::option::Option::None,
776 }
777 }
778
779 #vis fn rpc_client(rpc_stub: R) -> Self {
781 Self {
782 lpc_service: ::core::option::Option::None,
783 rpc_stub: ::core::option::Option::Some(rpc_stub),
784 }
785 }
786 }
787 };
788 code
789 }
790
791 fn impl_client_methods(&self) -> TokenStream2 {
792 let &Self {
793 service_ident,
794 client_stub_ident,
795 client_ident,
796 request_ident,
797 response_ident,
798 method_attrs,
799 vis,
800 method_idents,
801 args,
802 return_types,
803 arg_pats,
804 camel_case_idents,
805 ..
806 } = self;
807
808 let code = quote! {
809 impl<L, R> #client_ident<L, R>
810 where
811 L: #service_ident + ::core::clone::Clone,
812 R: #client_stub_ident,
813 {
814 #(
815 #[allow(unused)]
816 #( #method_attrs )*
817 #vis async fn #method_idents(&self, ctx: ::lrcall::context::Context, #( #args ),*)
818 -> ::core::result::Result<#return_types, ::lrcall::client::RpcError> {
819 match ctx.call_type {
820 ::lrcall::context::CallType::LPC => {
821 if let ::core::option::Option::Some(lpc_service) = &self.lpc_service {
822 return ::core::result::Result::Ok(lpc_service.clone().#method_idents(ctx, #( #arg_pats ),*).await);
823 }
824 },
825 ::lrcall::context::CallType::RPC => {
826 if let ::core::option::Option::Some(rpc_stub) = &self.rpc_stub {
827 let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
828 let resp = rpc_stub.call(ctx, request);
829 return match resp.await? {
830 #response_ident::#camel_case_idents(msg) => ::core::result::Result::Ok(msg),
831 _ => ::core::unreachable!(),
832 };
833 }
834 },
835 }
836 return ::core::result::Result::Err(::lrcall::client::RpcError::ClientUnconfigured(ctx.call_type));
837 }
838 )*
839 }
840 };
841 code
842 }
843
844 fn emit_warnings(&self) -> TokenStream2 {
845 self.warnings.iter().map(|w| w.to_token_stream()).collect()
846 }
847}
848
849impl<'a> ToTokens for ServiceGenerator<'a> {
850 fn to_tokens(&self, output: &mut TokenStream2) {
851 output.extend(vec![
852 self.trait_service(),
853 self.struct_server(),
854 self.impl_serve_for_server(),
855 self.enum_request(),
856 self.enum_response(),
857 self.struct_client(),
858 self.impl_client_new(),
859 self.impl_client_methods(),
860 self.emit_warnings(),
861 ]);
862 }
863}
864
/// Converts a snake_case identifier to CamelCase.
///
/// Underscores are dropped (leading, trailing, and repeated ones included). The first
/// character of each underscore-separated word is uppercased and the remaining characters are
/// lowercased, so `"aBc_dEf"` becomes `"AbcDef"`.
fn snake_to_camel(ident_str: &str) -> String {
    ident_str
        .split('_')
        .filter(|word| !word.is_empty())
        .flat_map(|word| {
            let mut chars = word.chars();
            let head = chars.next().into_iter().flat_map(char::to_uppercase);
            head.chain(chars.flat_map(char::to_lowercase))
        })
        .collect()
}
883
/// Plain two-word conversion.
#[test]
fn snake_to_camel_basic() {
    let expected = "AbcDef";
    assert_eq!(expected, snake_to_camel("abc_def"));
}

/// A trailing underscore is dropped.
#[test]
fn snake_to_camel_underscore_suffix() {
    let expected = "AbcDef";
    assert_eq!(expected, snake_to_camel("abc_def_"));
}

/// A leading underscore is dropped.
#[test]
fn snake_to_camel_underscore_prefix() {
    let expected = "AbcDef";
    assert_eq!(expected, snake_to_camel("_abc_def"));
}

/// Consecutive underscores collapse into a single word boundary.
#[test]
fn snake_to_camel_underscore_consecutive() {
    let expected = "AbcDef";
    assert_eq!(expected, snake_to_camel("abc__def"));
}

/// Capitals inside a word are lowercased; only word-initial characters are uppercased.
#[test]
fn snake_to_camel_capital_in_middle() {
    let expected = "AbcDef";
    assert_eq!(expected, snake_to_camel("aBc_dEf"));
}