client_handle_core/
lib.rs

1use proc_macro2::{TokenStream};
2use proc_macro_error::{abort};
3use quote::{quote, format_ident, ToTokens};
4use syn::{self, FnArg, TraitItemMethod, Ident, ReturnType, Pat, parse2};
5use convert_case::{Case, Casing};
6
7pub fn client_handle_core(_attr: TokenStream, item: TokenStream) -> TokenStream {
8    let ast = parse2(item).unwrap();
9
10    if let syn::Item::Trait(trayt) = &ast {
11        handle_trait(&Ast { trayt })
12    } else {
13        abort!(ast, "The `async_tokio_handle` macro only works on traits");
14    }
15}
16
17// Top Level struct that holds the AST for the trait
18// that we are trying to create an async handle for
19struct Ast<'a> {
20    trayt: &'a syn::ItemTrait,
21}
22
23// A wrapper struct for a mthod signature
24// Provides helpers to get the parameters in various
25// formats
26struct Method <'a> {
27    sig: &'a syn::Signature,
28}
29
30impl<'a> Ast<'a> {
31    // Get the name of the trait that is being wrapped
32    fn trait_name(&self) -> &Ident {
33        &self.trayt.ident
34    }
35
36    // Generate the name of the enum that will be used to send and
37    // receive messages
38    fn enum_name(&self) -> Ident {
39        format_ident!("Async{}Message", self.trait_name())
40    }
41
42    // Generate the name of the struct that will be used to represent
43    // the async handle
44    fn handle_name(&self) -> Ident {
45        format_ident!("Async{}Handle", self.trait_name())
46    }
47
48    // Gets the original trait that is being wrapped
49    // So that it can be output unmodified
50    fn original_trait(&self) -> TokenStream {
51        self.trayt.to_token_stream().into()
52    }
53
54    // Generate a vec of all the methods in the struct
55    // Returning a wrapped handle
56    fn methods(&self) -> Vec<Method> {
57        self.get_trait_methods()
58            .iter()
59            .map(|m| Method { sig: &m.sig })
60            .collect()
61    }
62
63    // Get all of the trait methods that must be wrapped
64    fn get_trait_methods(&'a self) -> Vec<&'a TraitItemMethod> {
65        let mut methods = Vec::new();
66        for item in &self.trayt.items {
67            match item {
68                syn::TraitItem::Method(method) => {
69                    if let Some(FnArg::Receiver(_)) = method.sig.inputs.first() {
70                        methods.push(method);
71                    }
72                },
73                _ => { panic!("Can only handle trait methods") }
74            }
75        }
76        methods
77    }
78}
79
80
81impl<'a> Method<'a> {
82    fn name(&self) -> &Ident {
83        return &self.sig.ident
84    }
85
86    fn name_pascal_case(&self) -> Ident {
87        format_ident!("{}", self.sig.ident.to_string().to_case(Case::Pascal))
88    }
89
90    fn typed_parameter_names_only(&self) -> Vec<&Pat> {
91        let mut result = Vec::new();
92        for input in &self.sig.inputs {
93            match input {
94                FnArg::Receiver(_) => {},
95                FnArg::Typed(typed) => {
96                    result.push(&*typed.pat)
97                },
98            }
99        }
100        result
101    }
102
103    fn typed_parameters(&self) -> Vec<&FnArg> {
104        self.sig.inputs
105            .iter()
106            .filter(|arg| {
107                match arg {
108                    FnArg::Receiver(_) => false,
109                    FnArg::Typed(_) => true,
110                }
111            })
112            .collect()
113    }
114
115    fn return_value_type(&self) -> proc_macro2::TokenStream {
116        match &self.sig.output {
117            ReturnType::Default => quote!{ () },
118            ReturnType::Type(_, tipe) => quote! { #tipe },
119        }
120    }
121}
122
123
124fn handle_trait(ast: &Ast) -> TokenStream {
125    let message_enum = generate_message_enum(ast);
126
127    let output = vec![
128        ast.original_trait(),
129        generate_struct(&ast),
130        message_enum,
131    ];
132
133    let mut gen: TokenStream = TokenStream::new();
134    gen.extend(output.into_iter());
135
136    gen
137}
138
139// Generates the async handle struct and supporting code.
140// 
141// A trait like:
142//
143//  ```
144//      trait MyTrait {
145//         fn do_something(param: u64) -> u64;
146//      }
147//  ```
148//
149// will be turned into:
150//
151//  ```
152//      // The struct for holding the handle
153//      #[derive(Debug)]
154//      struct AsyncMyTraitHandle {
155//          handle: tokio::sync::mpsc::Sender<AsyncMyTraitMessage>,
156//      }
157//
158//
159//       trait ToHandle {
160//           fn to_handle(self) -> AsyncMyTraitHandle;
161//       }
162//      
163//       impl<T> ToHandle for T
164//       where
165//           T: MyTrait + Sync + Send + 'static
166//       {
167//           fn to_handle(self: T) -> AsyncMyTraitHandle {
168//               AsyncMyTraitHandle::spawn(self)
169//           }
170//       }
171//      
172//       impl AsyncMyTraitHandle {
173//           pub fn new(handle: tokio::sync::mpsc::Sender<AsyncMyTraitMessage>) -> Self {
174//               Self { handle }
175//           }
176//      
177//           pub fn spawn<T>(mut sync: T) -> Self
178//           where
179//               T: MyTrait + Sync + Send + 'static
180//           {
181//               let (tx, mut rx) = tokio::sync::mpsc::channel(1024);
182//               tokio::spawn(async move {
183//                   while let Some(msg) = rx.recv().await {
184//                       match msg {
185//                           AsyncMyTraitMessage::DoSomething { return_value, param } => {
186//                               let result = sync.do_something(param);
187//                               return_value.send(result).expect("Error calling function");
188//                           }
189//                       }
190//                   }
191//               });
192//               Self { handle: tx }
193//           }
194//      
195//           async fn do_something(&self, param: u64) -> u64 {
196//               let (return_value, response) = tokio::sync::oneshot::channel();
197//               self.handle.send(return_value, param).await.expect("Error when sending message to the sync code");
198//               response.await.expect("Error receiving the response")
199//           }
200//       }
201//  ```
202fn generate_struct(ast: &Ast) -> TokenStream {
203    let trait_name = &ast.trait_name();
204    let struct_name = &ast.handle_name();
205    let message_enum_name = &ast.enum_name();
206
207    let mut async_result = Vec::new();
208    let mut sync_result = Vec::new();
209    for method in ast.methods() {
210        let msg_name = method.name_pascal_case();
211        let parameters = method.typed_parameters();
212        let parameter_names = method.typed_parameter_names_only();
213        let method_name = method.name();
214        let return_type = method.return_value_type();
215
216        let create_enum_call = quote! {
217            #message_enum_name::#msg_name { return_value, #(#parameter_names),* }
218        };
219
220        async_result.push(quote! {
221            async fn #method_name (&self, #(#parameters),*) -> #return_type {
222                let (return_value, response) = tokio::sync::oneshot::channel();
223                self.handle.send(#create_enum_call).await.expect("Error when sending message to the sync code");
224                response.await.expect("Error receiving the response")
225            }
226        });
227
228        sync_result.push(quote! {
229            #message_enum_name::#msg_name { return_value, #(#parameter_names),* } => {
230                let result = sync.#method_name(#(#parameter_names),*);
231                return_value.send(result).expect("Error calling function");
232            }
233        });
234    }
235
236    quote! {
237        #[derive(Debug)]
238        struct #struct_name {
239            handle: tokio::sync::mpsc::Sender<#message_enum_name>,
240        }
241
242        trait ToAsyncHandle {
243            fn to_async_handle(self, depth: usize) -> #struct_name;
244        }
245
246        impl<T> ToAsyncHandle for T
247        where
248            T: #trait_name + Sync + Send + 'static
249        {
250            fn to_async_handle(self: T, depth: usize) -> #struct_name {
251                #struct_name::spawn(self, depth)
252            }
253        }
254
255        impl #struct_name {
256            pub fn new(handle: tokio::sync::mpsc::Sender<#message_enum_name>) -> Self {
257                Self { handle }
258            }
259
260            pub fn spawn<T>(mut sync: T, depth: usize) -> Self
261            where
262                T: #trait_name + Sync + Send + 'static
263            {
264                let (tx, mut rx) = tokio::sync::mpsc::channel(depth);
265                tokio::spawn(async move {
266                    while let Some(msg) = rx.recv().await {
267                        match msg {
268                            #(#sync_result)*
269                        }
270                    }
271                });
272                Self { handle: tx }
273            }
274
275            #(#async_result)*
276        }
277    }.into()
278
279}
280
281// Generates the enum that is responsible for [de]serialising the
282// messages between the sync and async code
283//
284// given a function like `fn do_something(param: u64) -> u64` it should
285// generate:
286//
287//  ```
288//      #[derive(Debug)]        // required by tokio mpsc
289//      enum AsyncHandleMessage {
290//          DoSomething { return_value: oneshot::Sender<u64>, param: u64 },
291//      }
292//  ```
293//
294fn generate_message_enum(ast: &Ast) -> TokenStream {
295    let enum_name = ast.enum_name();
296
297    let mut enum_variants = Vec::new();
298    for method in ast.methods() {
299        let name = method.name_pascal_case();
300        let parameters = method.typed_parameters();
301        let return_type = method.return_value_type();
302
303        enum_variants.push(quote! {
304            #name {return_value: tokio::sync::oneshot::Sender<#return_type>, #(#parameters),* }
305        });
306    }
307
308    quote!(
309        #[derive(Debug)]
310        enum #enum_name {
311            #(#enum_variants),*
312        }
313    ).into()
314}
315
#[cfg(test)]
mod test {
    use super::*;

    // Compares two token streams by their string rendering. On mismatch it
    // prints a coloured diff plus both renderings, then panics.
    fn assert_tokens_eq(expected: &TokenStream, actual: &TokenStream) {
        let expected_text = expected.to_string();
        let actual_text = actual.to_string();

        if expected_text == actual_text {
            return;
        }

        println!(
            "{}",
            colored_diff::PrettyDifference {
                expected: &expected_text,
                actual: &actual_text,
            }
        );
        println!("expected: {}", &expected_text);
        println!("actual  : {}", &actual_text);
        panic!("expected != actual");
    }

    // End-to-end check of the expansion: associated functions without a
    // receiver are left out of the handle, `&self` methods are wrapped.
    #[test]
    fn test_tokio_handle() {
        let input = quote! {
            trait MyTrait {
                fn ignored_associated_function();
                fn ignored_associated_function_args(input: u64) -> u64;
                fn simple(&self);
                fn echo(&self, input: u64) -> u64;
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn ignored_associated_function();
                fn ignored_associated_function_args(input: u64) -> u64;
                fn simple(&self);
                fn echo(&self, input: u64) -> u64;
            }

            #[derive(Debug)]
            struct AsyncMyTraitHandle { handle: tokio::sync::mpsc::Sender<AsyncMyTraitMessage>, }

            trait ToAsyncHandle { fn to_async_handle (self, depth: usize) -> AsyncMyTraitHandle ; }

            impl<T> ToAsyncHandle for T
            where
                T: MyTrait + Sync + Send + 'static
            {
                fn to_async_handle(self: T, depth: usize) -> AsyncMyTraitHandle {
                    AsyncMyTraitHandle::spawn(self, depth)
                }
            }

            impl AsyncMyTraitHandle {
                pub fn new(handle: tokio::sync::mpsc::Sender<AsyncMyTraitMessage>) -> Self {
                    Self { handle }
                }

                pub fn spawn<T>(mut sync: T, depth: usize) -> Self
                where
                    T: MyTrait + Sync + Send + 'static
                {
                    let (tx, mut rx) = tokio::sync::mpsc::channel(depth);
                    tokio::spawn(async move {
                        while let Some(msg) = rx.recv().await {
                            match msg {
                                AsyncMyTraitMessage::Simple { return_value, } => {
                                    let result = sync.simple();
                                    return_value.send(result).expect("Error calling function");
                                }
                                AsyncMyTraitMessage::Echo { return_value, input } => {
                                    let result = sync.echo(input);
                                    return_value.send(result).expect("Error calling function");
                                }
                            }
                        }
                    });
                    Self { handle: tx }
                }

                async fn simple(&self, ) -> () {
                    let (return_value, response) = tokio::sync::oneshot::channel();
                    self.handle.send(AsyncMyTraitMessage::Simple{ return_value, }).await.expect("Error when sending message to the sync code");
                    response.await.expect("Error receiving the response")
                }

                async fn echo(&self, input: u64) -> u64 {
                    let (return_value, response) = tokio::sync::oneshot::channel();
                    self.handle.send(AsyncMyTraitMessage::Echo{ return_value, input }).await.expect("Error when sending message to the sync code");
                    response.await.expect("Error receiving the response")
                }
            }

            #[derive(Debug)]
            enum AsyncMyTraitMessage {
                Simple { return_value: tokio::sync::oneshot::Sender<()>, },
                Echo { return_value: tokio::sync::oneshot::Sender<u64>, input: u64 }
            }


        };
        let actual = client_handle_core(quote!(), input);
        assert_tokens_eq(&expected, &actual);
    }
}