arrpc_derive/
lib.rs

1#[cfg(feature = "obake")]
2mod obake;
3mod util;
4
5use convert_case::{Case, Casing};
6use itertools::Itertools;
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro_error::proc_macro_error;
10use quote::quote;
11use syn::{
12    parse_macro_input, parse_quote, Arm, FnArg, Ident, ItemEnum, ItemImpl, ItemTrait, Meta,
13    ReturnType, TraitItem, TraitItemFn, Type, Variant,
14};
15
16type FlagProcessor = fn(ArrpcImpls) -> ArrpcImpls;
17
18const PROC_VAR: &str = "proc";
19
20#[proc_macro_error]
21#[proc_macro_attribute]
22pub fn arrpc_service(attr: TokenStream, item: TokenStream) -> TokenStream {
23    let flag_processors = {
24        let mut processors: Vec<FlagProcessor> = Vec::new();
25
26        #[cfg(feature = "obake")]
27        processors.push(obake::processor);
28
29        processors
30    };
31
32    let original_trait = parse_macro_input!(item as ItemTrait);
33    let mut svc_trait = original_trait.clone();
34    let svc_name = &svc_trait.ident;
35
36    let proc_name = Ident::new(format!("{svc_name}Proc").as_str(), Span::call_site());
37
38    let mut proc_variants = Vec::new();
39
40    for item in svc_trait.items.iter_mut() {
41        if let TraitItem::Fn(trait_fn) = item {
42            trait_fn.sig.output = wrap_with_arrpc_result(&trait_fn.sig.output);
43
44            let proc = create_proc_variant(trait_fn);
45
46            let proc_match = match_for_proc_variant(&proc, &proc_name, &trait_fn.sig.ident);
47
48            let impl_fn = create_client_impl(&proc, &trait_fn, &proc_name);
49
50            let proc_variant = ProcVariant {
51                variant: proc,
52                svc_match_stmt: proc_match,
53                client_impl: impl_fn,
54            };
55
56            proc_variants.push(proc_variant);
57        }
58    }
59
60    let procs = proc_variants.iter().map(|proc| &proc.variant).collect_vec();
61    let proc_enum: ItemEnum = parse_quote! {
62        #[derive(serde::Serialize, serde::Deserialize)]
63        enum #proc_name {
64            #(#procs),*
65        }
66    };
67
68    let proc_var = proc_var_ident();
69
70    // Create arrpc_service impl
71    let svc_impl = parse_macro_input!(attr as Type);
72    let proc_matches = proc_variants
73        .iter()
74        .map(|proc| &proc.svc_match_stmt)
75        .collect_vec();
76    let arrpc_svc_impl: ItemImpl = parse_quote! {
77        #[async_trait::async_trait]
78        impl arrpc_core::Service for #svc_impl {
79            async fn accept<R>(&self, req: R) -> Result<R::Response>
80            where
81                R: arrpc_core::Request + Send + Sync,
82            {
83                let #proc_var: #proc_name = req.proc()?;
84                match #proc_var {
85                    #(#proc_matches),*
86                }
87            }
88        }
89    };
90
91    // Impl user svc for UniversalClient
92    let async_attr = svc_trait.attrs.iter().find(|attr| match &attr.meta {
93        Meta::Path(path) => path
94            .segments
95            .iter()
96            .any(|segment| segment.ident == "async_trait"),
97        _ => false,
98    });
99
100    let fn_impls = proc_variants
101        .iter()
102        .map(|proc| &proc.client_impl)
103        .collect_vec();
104    let unv_client_impl: ItemImpl = parse_quote! {
105        #async_attr
106        impl<T> #svc_name for arrpc_core::UniversalClient<T>
107            where T: arrpc_core::ClientContract + Send + Sync
108        {
109            #(#fn_impls)*
110        }
111    };
112
113    let mut impls = ArrpcImpls {
114        updated_trait: svc_trait,
115        proc_enum,
116        svc_impl: arrpc_svc_impl,
117        client_impl: unv_client_impl,
118        extras: Vec::new(),
119    };
120
121    for processor in flag_processors {
122        impls = processor(impls);
123    }
124
125    impls.into()
126}
127
128fn proc_var_ident() -> Ident {
129    Ident::new(PROC_VAR, Span::call_site())
130}
131
132fn wrap_with_arrpc_result(old: &ReturnType) -> ReturnType {
133    let ret_type = match old {
134        ReturnType::Default => quote!(()),
135        ReturnType::Type(_, ret_type) => quote!(#ret_type),
136    };
137
138    let replacement_ret: ReturnType = parse_quote! {
139        -> arrpc_core::Result<#ret_type>
140    };
141
142    replacement_ret
143}
144
145fn create_proc_variant(trait_fn: &TraitItemFn) -> Variant {
146    let fn_name = &trait_fn.sig.ident;
147    let name = proc_name_for_fn(fn_name.to_string().as_str());
148    let name: Ident = Ident::new(name.as_str(), Span::call_site());
149    let args = trait_fn
150        .sig
151        .inputs
152        .iter()
153        .filter(|input| matches!(**input, FnArg::Typed(_)));
154    let proc: Variant = parse_quote! {
155        #name {
156             #(#args),*
157        }
158    };
159
160    proc
161}
162
163fn proc_name_for_fn(fn_name: &str) -> String {
164    fn_name.from_case(Case::Snake).to_case(Case::Pascal)
165}
166
167fn match_for_proc_variant(proc_variant: &Variant, proc_name: &Ident, fn_name: &Ident) -> Arm {
168    let args = proc_variant
169        .fields
170        .iter()
171        .filter_map(|field| field.ident.as_ref())
172        .collect_vec();
173    let name = &proc_variant.ident;
174
175    parse_quote!(#proc_name::#name{#(#args),*} => req.respond(self.#fn_name(#(#args),*).await?))
176}
177
178fn create_client_impl(
179    proc_variant: &Variant,
180    trait_fn: &TraitItemFn,
181    proc_name: &Ident,
182) -> TraitItemFn {
183    let TraitItemFn { sig, .. } = trait_fn;
184    let name = &proc_variant.ident;
185    let args = proc_variant
186        .fields
187        .iter()
188        .filter_map(|field| field.ident.as_ref())
189        .collect_vec();
190    let proc_var = proc_var_ident();
191    parse_quote! {
192        #sig {
193            let #proc_var = #proc_name::#name{#(#args),*};
194            self.0
195                .send(#proc_var)
196                .await
197        }
198    }
199}
200
201struct ArrpcImpls {
202    pub updated_trait: ItemTrait,
203    pub proc_enum: ItemEnum,
204    pub svc_impl: ItemImpl,
205    pub client_impl: ItemImpl,
206    pub extras: Vec<proc_macro2::TokenStream>,
207}
208
209impl From<ArrpcImpls> for TokenStream {
210    fn from(value: ArrpcImpls) -> Self {
211        let ArrpcImpls {
212            updated_trait,
213            proc_enum,
214            svc_impl,
215            client_impl,
216            extras,
217            ..
218        } = value;
219
220        quote! {
221            #updated_trait
222
223            #proc_enum
224
225            #svc_impl
226
227            #client_impl
228
229            #(#extras)*
230        }
231        .into()
232    }
233}
234
235struct ProcVariant {
236    variant: Variant,
237    svc_match_stmt: Arm,
238    client_impl: TraitItemFn,
239}