crows_service/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use syn::parse::{Parse, ParseStream};
4
5extern crate proc_macro;
6extern crate proc_macro2;
7extern crate quote;
8extern crate syn;
9
10use quote::quote;
11use syn::spanned::Spanned;
12use syn::token::{Comma, Mut};
13use syn::{
14    braced, parenthesized, parse_macro_input, parse_quote, Attribute, FnArg, Ident, LitStr, Pat,
15    PatType, Result, ReturnType, Token, Type, Visibility,
16};
17
/// Folds a recoverable error into an accumulating `Result`.
///
/// If `$errors` is still `Ok`, it becomes `Err($e)`; if it already holds an error, `$e` is
/// appended to it through `Extend`. Use this only for recoverable (non-parse) errors so that
/// every problem is reported at once; fatal errors should return early instead to avoid
/// cascading diagnostics.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        if let Err(ref mut accumulated) = $errors {
            accumulated.extend($e);
        } else {
            $errors = Err($e);
        }
    };
}
29
30#[allow(dead_code)]
31#[derive(Debug)]
32struct ServiceMacroInput {
33    attrs: Vec<Attribute>,
34    vis: Visibility,
35    ident: Ident,
36    methods: Vec<Method>,
37}
38
39#[allow(dead_code)]
40#[derive(Debug)]
41struct Method {
42    attrs: Vec<Attribute>,
43    ident: Ident,
44    args: Vec<PatType>,
45    output: ReturnType,
46    receiver: bool,
47    receiver_mutability: Option<Mut>,
48}
49
50impl Parse for Method {
51    fn parse(input: ParseStream) -> syn::Result<Self> {
52        let attrs = input.call(Attribute::parse_outer)?;
53
54        input.parse::<Token![async]>()?;
55
56        input.parse::<Token![fn]>()?;
57        let ident = input.parse()?;
58        let content;
59        parenthesized!(content in input);
60        let mut args = Vec::new();
61        let mut errors = Ok(());
62        let mut receiver = false;
63        let mut receiver_mutability = None;
64        for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
65            match arg {
66                FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
67                    args.push(captured);
68                }
69                FnArg::Typed(captured) => {
70                    extend_errors!(
71                        errors,
72                        syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args")
73                    );
74                }
75                FnArg::Receiver(r) => {
76                    receiver = true;
77                    receiver_mutability = r.mutability.clone();
78                }
79            }
80        }
81        errors?;
82        let output = input.parse()?;
83        input.parse::<Token![;]>()?;
84
85        Ok(Self {
86            attrs,
87            ident,
88            args,
89            output,
90            receiver,
91            receiver_mutability,
92        })
93    }
94}
95
96impl Parse for ServiceMacroInput {
97    fn parse(input: ParseStream) -> Result<Self> {
98        let attrs = input.call(Attribute::parse_outer)?;
99        let vis: Visibility = input.parse()?;
100        input.parse::<Token![trait]>()?;
101        let ident: Ident = input.parse()?;
102
103        let mut methods = Vec::<Method>::new();
104
105        let content;
106        braced!(content in input);
107        while !content.is_empty() {
108            methods.push(content.parse()?);
109        }
110
111        Ok(Self {
112            attrs,
113            vis,
114            ident,
115            methods,
116        })
117    }
118}
119
120struct AttrsInput {
121    other_side: Ident,
122    variant: String,
123}
124
125impl Parse for AttrsInput {
126    fn parse(input: ParseStream) -> Result<Self> {
127        let mut other_side: Option<Ident> = None;
128        let mut variant: Option<String> = None;
129
130        while !input.is_empty() {
131            let ident: Ident = input.parse()?;
132            if ident == "other_side" {
133                input.parse::<Token![=]>()?;
134                other_side = Some(input.parse()?);
135            } else if ident == "variant" {
136                input.parse::<Token![=]>()?;
137                let lit: LitStr = input.parse()?;
138                variant = Some(lit.value());
139            } else {
140                return Err(syn::Error::new_spanned(ident, "Unexpected identifier"));
141            }
142
143            // Allow multiple attrs separated by comma.
144            if !input.is_empty() {
145                input.parse::<Token![,]>()?;
146            }
147        }
148
149        Ok(AttrsInput {
150            other_side: other_side.unwrap(),
151            variant: variant.unwrap(),
152        })
153    }
154}
155
156#[proc_macro_attribute]
157pub fn service(attr: TokenStream, original_input: TokenStream) -> TokenStream {
158    let attrs = parse_macro_input!(attr as AttrsInput);
159
160    let derive = quote! {
161        #[derive(Debug, utils::serde::Serialize, utils::serde::Deserialize, Clone)]
162    };
163
164    let cloned = original_input.clone();
165    let input = parse_macro_input!(cloned as ServiceMacroInput);
166    let unit_type: &Type = &parse_quote!(());
167
168    let ident = input.ident;
169    let request_ident = Ident::new(&format!("{}Request", ident), ident.span());
170    let response_ident = Ident::new(&format!("{}Response", ident), ident.span());
171    let message_ident = Ident::new(&format!("{}Message", ident), ident.span());
172    let dummy_ident = Ident::new(&format!("Dummy{}Service", ident), ident.span());
173    let client_ident = Ident::new(&format!("{}Client", ident), ident.span());
174    let mut requests_variants = Vec::new();
175    let mut requests_structs = Vec::new();
176    let mut response_variants = Vec::new();
177    let mut client_methods = Vec::new();
178    let mut service_match_arms = Vec::new();
179
180    let snake_ident = ident.to_string().to_case(Case::Snake);
181    let variant = attrs.variant;
182    #[allow(unused)]
183    let create_named_variant_ident =
184        Ident::new(&format!("create_{snake_ident}_{variant}"), ident.span());
185    // let create_variant_ident = Ident::new(&format!("create_{variant}"), ident.span());
186    let other_side = attrs.other_side;
187    let other_side_client_ident = Ident::new(
188        &format!("{}Client", other_side.to_string()),
189        other_side.span(),
190    );
191
192    let server_or_client_fn = if &variant == "server" {
193        let server_ident = Ident::new(&format!("{}Server", ident), ident.span());
194        quote! {
195            pub struct #server_ident {
196                server: utils::Server
197            }
198
199            impl #server_ident {
200                pub async fn accept<T>(&self, service: T) -> Option<<T as utils::Service<#dummy_ident>>::Client>
201                where T: utils::Service<#dummy_ident> + Clone + 'static {
202                    let (sender, receiver, close_receiver) = self.server.accept().await?;
203                    let client = utils::Client::new(sender, receiver, service, Some(close_receiver));
204                    Some(client)
205                }
206            }
207
208            pub async fn #create_named_variant_ident<A>(addr: A)
209                -> Result<#server_ident, std::io::Error>
210                where
211                    A: utils::tokio::net::ToSocketAddrs
212            {
213                let server = utils::create_server(addr).await?;
214                Ok(#server_ident { server })
215            }
216        }
217    } else {
218        let other_side_snake = other_side.to_string().to_case(Case::Snake);
219        let connect_to_ident = Ident::new(&format!("connect_to_{other_side_snake}"), ident.span());
220        quote! {
221            pub async fn #connect_to_ident<A, T>(addr: A, service: T)
222                -> Result<<T as utils::Service<#dummy_ident>>::Client, std::io::Error>
223                where
224                    A: utils::tokio::net::ToSocketAddrs,
225                    T: utils::Service<#dummy_ident> + Clone + 'static,
226            {
227                let (sender, mut receiver) = utils::create_client(addr).await?;
228                let client = utils::Client::new(sender, receiver, service, None);
229                Ok(client)
230            }
231        }
232    };
233
234    let mut trait_methods = Vec::new();
235
236    for method in input.methods {
237        let receiver = quote! { &self };
238
239        let pascal = method.ident.to_string().to_case(Case::Pascal);
240        let method_ident = method.ident.clone();
241        let method_request_ident =
242            Ident::new(&format!("{}{}Request", ident, pascal), method.ident.span());
243        let request_variant = quote! {
244            #method_request_ident(#method_request_ident)
245        };
246        requests_variants.push(request_variant);
247
248        let method_response_ident =
249            Ident::new(&format!("{}{}Response", ident, pascal), method.ident.span());
250        let return_ty = match method.output {
251            ReturnType::Default => unit_type,
252            ReturnType::Type(_, ref ty) => ty,
253        };
254        response_variants.push(quote! {
255            #method_response_ident(#return_ty)
256        });
257
258        let mut args = Vec::new();
259        let mut arg_names: Vec<Ident> = Vec::new();
260        for arg in method.args.clone() {
261            let ident = match *arg.pat {
262                Pat::Ident(ident) => ident.ident,
263                _ => unreachable!(),
264            };
265            arg_names.push(ident.clone());
266            let ty = arg.ty;
267            args.push(quote! {
268                #ident: #ty
269            });
270        }
271        requests_structs.push(quote! {
272            #derive
273            pub struct #method_request_ident {
274                #(#args),*
275            }
276        });
277
278        let args = method.args;
279        let output_ty = match method.output {
280            ReturnType::Type(_, ref t) => t,
281            ReturnType::Default => unit_type,
282        };
283
284        client_methods.push(quote! {
285            // TODO: this should not be anyhow, but rather io::error or sth along the lines
286            pub async fn #method_ident(#receiver, #(#args),*) -> anyhow::Result<#output_ty> {
287                let response = self.client
288                    .request::<#request_ident, #response_ident>(#request_ident::#method_request_ident(
289                        #method_request_ident { #(#arg_names),* },
290                    )).await?;
291
292                Ok(match response {
293                    #response_ident::#method_response_ident(r) => r,
294                    _ => unreachable!()
295                })
296            }
297        });
298
299        // TODO: I don't have too much time at the moment and I want to finish a few other things,
300        // so I made a simplification in the services code at a cost of a slightly worse experience
301        // when implementing services. Now we *always* pass a client to service handler methods as
302        // a first argument. Ideally we would check if the client is defined as an argument and
303        // pass it only when needed.
304        trait_methods.push(quote! {
305            fn #method_ident(#receiver, client: #other_side_client_ident, #(#args),*) -> impl std::future::Future<Output = #output_ty> + Send;
306        });
307
308        service_match_arms.push(quote! {
309            #request_ident::#method_request_ident(request) => #response_ident::#method_response_ident(self.#method_ident(client, #(request.#arg_names),*).await),
310        });
311    }
312
313    let mut result = quote! {
314        pub trait #ident {
315            #(#trait_methods)*
316        }
317    };
318
319    let impl_service = quote! {
320        impl<T> utils::Service<#dummy_ident> for T
321        where T: #ident + Send + Sync {
322            type Request = #request_ident;
323            type Response = #response_ident;
324            type Client = #other_side_client_ident;
325
326            fn handle_request(
327                &self,
328                client: Self::Client,
329                message: Self::Request,
330            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self::Response> + Send + '_>> {
331                Box::pin(async {
332                    match message {
333                        #(#service_match_arms)*
334                    }
335                })
336            }
337        }
338    };
339
340    let get_close_receiver = quote! {
341        pub async fn get_close_receiver(&self) -> Option<tokio::sync::oneshot::Receiver<()>> {
342                self.client.get_close_receiver().await
343            }
344    };
345
346    let generated = quote! {
347        pub struct #dummy_ident;
348
349        #derive
350        pub enum #request_ident {
351            #(#requests_variants),*
352        }
353
354        #derive
355        pub enum #response_ident {
356            #(#response_variants),*
357        }
358
359        #(#requests_structs)*
360
361        #derive
362        pub enum #message_ident {
363            Request(#request_ident),
364            Response(#response_ident),
365        }
366
367        #[derive(Clone)]
368        pub struct #client_ident {
369            // TODO: this should be prefixed
370            client: utils::Client
371        }
372
373        impl utils::ClientTrait for #client_ident {
374            fn new(client: utils::Client) -> Self {
375                Self {
376                    client
377                }
378            }
379        }
380
381        impl #client_ident {
382            pub async fn wait(&self) {
383                self.client.wait().await;
384            }
385
386            #get_close_receiver
387
388            #(#client_methods)*
389        }
390
391        #impl_service
392
393        #server_or_client_fn
394    };
395    result.extend(generated);
396    TokenStream::from(result)
397}
398
#[cfg(test)]
mod tests {
    //! Parser-level tests. These run outside a real macro expansion; `syn::parse_str` uses
    //! proc_macro2's fallback implementation there.
    use super::*;

    #[test]
    fn parses_attribute_arguments_in_any_order() {
        let attrs: AttrsInput =
            syn::parse_str(r#"variant = "server", other_side = WorkerService,"#).unwrap();
        assert_eq!(attrs.other_side, "WorkerService");
        assert_eq!(attrs.variant, "server");
    }

    #[test]
    fn rejects_unknown_attribute_argument() {
        assert!(syn::parse_str::<AttrsInput>("bogus = 1").is_err());
    }

    #[test]
    fn parses_async_method_signature() {
        let method: Method =
            syn::parse_str("async fn add(&mut self, a: u32, b: u32) -> u64;").unwrap();
        assert_eq!(method.ident, "add");
        assert_eq!(method.args.len(), 2);
        assert!(method.receiver);
        assert!(method.receiver_mutability.is_some());
        assert!(matches!(method.output, ReturnType::Type(..)));
    }

    #[test]
    fn requires_async_methods() {
        assert!(syn::parse_str::<Method>("fn ping(&self);").is_err());
    }

    #[test]
    fn rejects_destructuring_patterns_in_arguments() {
        let err = syn::parse_str::<Method>("async fn swap(&self, (a, b): (u32, u32));")
            .unwrap_err();
        assert_eq!(err.to_string(), "patterns aren't allowed in RPC args");
    }

    #[test]
    fn parses_service_trait() {
        let input: ServiceMacroInput = syn::parse_str(
            "pub trait Worker { async fn ping(&self); async fn stop(&self, force: bool) -> bool; }",
        )
        .unwrap();
        assert_eq!(input.ident, "Worker");
        let names: Vec<String> = input.methods.iter().map(|m| m.ident.to_string()).collect();
        assert_eq!(names, ["ping", "stop"]);
        assert!(matches!(input.methods[0].output, ReturnType::Default));
        assert_eq!(input.methods[1].args.len(), 1);
    }
}