1#[cfg(feature = "obake")]
2mod obake;
3mod util;
4
5use convert_case::{Case, Casing};
6use itertools::Itertools;
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro_error::proc_macro_error;
10use quote::quote;
11use syn::{
12 parse_macro_input, parse_quote, Arm, FnArg, Ident, ItemEnum, ItemImpl, ItemTrait, Meta,
13 ReturnType, TraitItem, TraitItemFn, Type, Variant,
14};
15
/// A post-processing pass over the generated items. Each feature-gated flag (e.g. `obake`)
/// may register one, and `arrpc_service` threads the generated items through all of them
/// in registration order.
type FlagProcessor = fn(ArrpcImpls) -> ArrpcImpls;

/// Name of the local variable that holds a proc value in generated code.
const PROC_VAR: &str = "proc";
19
20#[proc_macro_error]
21#[proc_macro_attribute]
22pub fn arrpc_service(attr: TokenStream, item: TokenStream) -> TokenStream {
23 let flag_processors = {
24 let mut processors: Vec<FlagProcessor> = Vec::new();
25
26 #[cfg(feature = "obake")]
27 processors.push(obake::processor);
28
29 processors
30 };
31
32 let original_trait = parse_macro_input!(item as ItemTrait);
33 let mut svc_trait = original_trait.clone();
34 let svc_name = &svc_trait.ident;
35
36 let proc_name = Ident::new(format!("{svc_name}Proc").as_str(), Span::call_site());
37
38 let mut proc_variants = Vec::new();
39
40 for item in svc_trait.items.iter_mut() {
41 if let TraitItem::Fn(trait_fn) = item {
42 trait_fn.sig.output = wrap_with_arrpc_result(&trait_fn.sig.output);
43
44 let proc = create_proc_variant(trait_fn);
45
46 let proc_match = match_for_proc_variant(&proc, &proc_name, &trait_fn.sig.ident);
47
48 let impl_fn = create_client_impl(&proc, &trait_fn, &proc_name);
49
50 let proc_variant = ProcVariant {
51 variant: proc,
52 svc_match_stmt: proc_match,
53 client_impl: impl_fn,
54 };
55
56 proc_variants.push(proc_variant);
57 }
58 }
59
60 let procs = proc_variants.iter().map(|proc| &proc.variant).collect_vec();
61 let proc_enum: ItemEnum = parse_quote! {
62 #[derive(serde::Serialize, serde::Deserialize)]
63 enum #proc_name {
64 #(#procs),*
65 }
66 };
67
68 let proc_var = proc_var_ident();
69
70 let svc_impl = parse_macro_input!(attr as Type);
72 let proc_matches = proc_variants
73 .iter()
74 .map(|proc| &proc.svc_match_stmt)
75 .collect_vec();
76 let arrpc_svc_impl: ItemImpl = parse_quote! {
77 #[async_trait::async_trait]
78 impl arrpc_core::Service for #svc_impl {
79 async fn accept<R>(&self, req: R) -> Result<R::Response>
80 where
81 R: arrpc_core::Request + Send + Sync,
82 {
83 let #proc_var: #proc_name = req.proc()?;
84 match #proc_var {
85 #(#proc_matches),*
86 }
87 }
88 }
89 };
90
91 let async_attr = svc_trait.attrs.iter().find(|attr| match &attr.meta {
93 Meta::Path(path) => path
94 .segments
95 .iter()
96 .any(|segment| segment.ident == "async_trait"),
97 _ => false,
98 });
99
100 let fn_impls = proc_variants
101 .iter()
102 .map(|proc| &proc.client_impl)
103 .collect_vec();
104 let unv_client_impl: ItemImpl = parse_quote! {
105 #async_attr
106 impl<T> #svc_name for arrpc_core::UniversalClient<T>
107 where T: arrpc_core::ClientContract + Send + Sync
108 {
109 #(#fn_impls)*
110 }
111 };
112
113 let mut impls = ArrpcImpls {
114 updated_trait: svc_trait,
115 proc_enum,
116 svc_impl: arrpc_svc_impl,
117 client_impl: unv_client_impl,
118 extras: Vec::new(),
119 };
120
121 for processor in flag_processors {
122 impls = processor(impls);
123 }
124
125 impls.into()
126}
127
128fn proc_var_ident() -> Ident {
129 Ident::new(PROC_VAR, Span::call_site())
130}
131
132fn wrap_with_arrpc_result(old: &ReturnType) -> ReturnType {
133 let ret_type = match old {
134 ReturnType::Default => quote!(()),
135 ReturnType::Type(_, ret_type) => quote!(#ret_type),
136 };
137
138 let replacement_ret: ReturnType = parse_quote! {
139 -> arrpc_core::Result<#ret_type>
140 };
141
142 replacement_ret
143}
144
145fn create_proc_variant(trait_fn: &TraitItemFn) -> Variant {
146 let fn_name = &trait_fn.sig.ident;
147 let name = proc_name_for_fn(fn_name.to_string().as_str());
148 let name: Ident = Ident::new(name.as_str(), Span::call_site());
149 let args = trait_fn
150 .sig
151 .inputs
152 .iter()
153 .filter(|input| matches!(**input, FnArg::Typed(_)));
154 let proc: Variant = parse_quote! {
155 #name {
156 #(#args),*
157 }
158 };
159
160 proc
161}
162
163fn proc_name_for_fn(fn_name: &str) -> String {
164 fn_name.from_case(Case::Snake).to_case(Case::Pascal)
165}
166
167fn match_for_proc_variant(proc_variant: &Variant, proc_name: &Ident, fn_name: &Ident) -> Arm {
168 let args = proc_variant
169 .fields
170 .iter()
171 .filter_map(|field| field.ident.as_ref())
172 .collect_vec();
173 let name = &proc_variant.ident;
174
175 parse_quote!(#proc_name::#name{#(#args),*} => req.respond(self.#fn_name(#(#args),*).await?))
176}
177
178fn create_client_impl(
179 proc_variant: &Variant,
180 trait_fn: &TraitItemFn,
181 proc_name: &Ident,
182) -> TraitItemFn {
183 let TraitItemFn { sig, .. } = trait_fn;
184 let name = &proc_variant.ident;
185 let args = proc_variant
186 .fields
187 .iter()
188 .filter_map(|field| field.ident.as_ref())
189 .collect_vec();
190 let proc_var = proc_var_ident();
191 parse_quote! {
192 #sig {
193 let #proc_var = #proc_name::#name{#(#args),*};
194 self.0
195 .send(#proc_var)
196 .await
197 }
198 }
199}
200
/// Every item generated by `arrpc_service`, kept as separate syntax trees so that flag
/// processors (`FlagProcessor`) can take and rewrite them before they are emitted.
struct ArrpcImpls {
    /// The user's trait, with each method's return type wrapped by `wrap_with_arrpc_result`.
    pub updated_trait: ItemTrait,
    /// The `<Trait>Proc` enum, with one variant per trait method.
    pub proc_enum: ItemEnum,
    /// The `arrpc_core::Service` impl for the server type named in the attribute.
    pub svc_impl: ItemImpl,
    /// The trait impl for `arrpc_core::UniversalClient<T>`.
    pub client_impl: ItemImpl,
    /// Extra token streams emitted after the items above. `arrpc_service` creates this
    /// empty; presumably flag processors fill it (TODO confirm in the processor modules).
    pub extras: Vec<proc_macro2::TokenStream>,
}
208
209impl From<ArrpcImpls> for TokenStream {
210 fn from(value: ArrpcImpls) -> Self {
211 let ArrpcImpls {
212 updated_trait,
213 proc_enum,
214 svc_impl,
215 client_impl,
216 extras,
217 ..
218 } = value;
219
220 quote! {
221 #updated_trait
222
223 #proc_enum
224
225 #svc_impl
226
227 #client_impl
228
229 #(#extras)*
230 }
231 .into()
232 }
233}
234
/// The generated pieces for a single trait method. `arrpc_service` collects one of these
/// per method, then splices each piece into the enum, the service impl or the client impl.
struct ProcVariant {
    /// The `<Trait>Proc` enum variant that carries the method's arguments.
    variant: Variant,
    /// The `match` arm used by the generated `Service::accept` to dispatch this variant.
    svc_match_stmt: Arm,
    /// The method's implementation on `arrpc_core::UniversalClient<T>`.
    client_impl: TraitItemFn,
}