client_handle_core/
lib.rs1use proc_macro2::{TokenStream};
2use proc_macro_error::{abort};
3use quote::{quote, format_ident, ToTokens};
4use syn::{self, FnArg, TraitItemMethod, Ident, ReturnType, Pat, parse2};
5use convert_case::{Case, Casing};
6
7pub fn client_handle_core(_attr: TokenStream, item: TokenStream) -> TokenStream {
8 let ast = parse2(item).unwrap();
9
10 if let syn::Item::Trait(trayt) = &ast {
11 handle_trait(&Ast { trayt })
12 } else {
13 abort!(ast, "The `async_tokio_handle` macro only works on traits");
14 }
15}
16
17struct Ast<'a> {
20 trayt: &'a syn::ItemTrait,
21}
22
23struct Method <'a> {
27 sig: &'a syn::Signature,
28}
29
30impl<'a> Ast<'a> {
31 fn trait_name(&self) -> &Ident {
33 &self.trayt.ident
34 }
35
36 fn enum_name(&self) -> Ident {
39 format_ident!("Async{}Message", self.trait_name())
40 }
41
42 fn handle_name(&self) -> Ident {
45 format_ident!("Async{}Handle", self.trait_name())
46 }
47
48 fn original_trait(&self) -> TokenStream {
51 self.trayt.to_token_stream().into()
52 }
53
54 fn methods(&self) -> Vec<Method> {
57 self.get_trait_methods()
58 .iter()
59 .map(|m| Method { sig: &m.sig })
60 .collect()
61 }
62
63 fn get_trait_methods(&'a self) -> Vec<&'a TraitItemMethod> {
65 let mut methods = Vec::new();
66 for item in &self.trayt.items {
67 match item {
68 syn::TraitItem::Method(method) => {
69 if let Some(FnArg::Receiver(_)) = method.sig.inputs.first() {
70 methods.push(method);
71 }
72 },
73 _ => { panic!("Can only handle trait methods") }
74 }
75 }
76 methods
77 }
78}
79
80
81impl<'a> Method<'a> {
82 fn name(&self) -> &Ident {
83 return &self.sig.ident
84 }
85
86 fn name_pascal_case(&self) -> Ident {
87 format_ident!("{}", self.sig.ident.to_string().to_case(Case::Pascal))
88 }
89
90 fn typed_parameter_names_only(&self) -> Vec<&Pat> {
91 let mut result = Vec::new();
92 for input in &self.sig.inputs {
93 match input {
94 FnArg::Receiver(_) => {},
95 FnArg::Typed(typed) => {
96 result.push(&*typed.pat)
97 },
98 }
99 }
100 result
101 }
102
103 fn typed_parameters(&self) -> Vec<&FnArg> {
104 self.sig.inputs
105 .iter()
106 .filter(|arg| {
107 match arg {
108 FnArg::Receiver(_) => false,
109 FnArg::Typed(_) => true,
110 }
111 })
112 .collect()
113 }
114
115 fn return_value_type(&self) -> proc_macro2::TokenStream {
116 match &self.sig.output {
117 ReturnType::Default => quote!{ () },
118 ReturnType::Type(_, tipe) => quote! { #tipe },
119 }
120 }
121}
122
123
124fn handle_trait(ast: &Ast) -> TokenStream {
125 let message_enum = generate_message_enum(ast);
126
127 let output = vec![
128 ast.original_trait(),
129 generate_struct(&ast),
130 message_enum,
131 ];
132
133 let mut gen: TokenStream = TokenStream::new();
134 gen.extend(output.into_iter());
135
136 gen
137}
138
139fn generate_struct(ast: &Ast) -> TokenStream {
203 let trait_name = &ast.trait_name();
204 let struct_name = &ast.handle_name();
205 let message_enum_name = &ast.enum_name();
206
207 let mut async_result = Vec::new();
208 let mut sync_result = Vec::new();
209 for method in ast.methods() {
210 let msg_name = method.name_pascal_case();
211 let parameters = method.typed_parameters();
212 let parameter_names = method.typed_parameter_names_only();
213 let method_name = method.name();
214 let return_type = method.return_value_type();
215
216 let create_enum_call = quote! {
217 #message_enum_name::#msg_name { return_value, #(#parameter_names),* }
218 };
219
220 async_result.push(quote! {
221 async fn #method_name (&self, #(#parameters),*) -> #return_type {
222 let (return_value, response) = tokio::sync::oneshot::channel();
223 self.handle.send(#create_enum_call).await.expect("Error when sending message to the sync code");
224 response.await.expect("Error receiving the response")
225 }
226 });
227
228 sync_result.push(quote! {
229 #message_enum_name::#msg_name { return_value, #(#parameter_names),* } => {
230 let result = sync.#method_name(#(#parameter_names),*);
231 return_value.send(result).expect("Error calling function");
232 }
233 });
234 }
235
236 quote! {
237 #[derive(Debug)]
238 struct #struct_name {
239 handle: tokio::sync::mpsc::Sender<#message_enum_name>,
240 }
241
242 trait ToAsyncHandle {
243 fn to_async_handle(self, depth: usize) -> #struct_name;
244 }
245
246 impl<T> ToAsyncHandle for T
247 where
248 T: #trait_name + Sync + Send + 'static
249 {
250 fn to_async_handle(self: T, depth: usize) -> #struct_name {
251 #struct_name::spawn(self, depth)
252 }
253 }
254
255 impl #struct_name {
256 pub fn new(handle: tokio::sync::mpsc::Sender<#message_enum_name>) -> Self {
257 Self { handle }
258 }
259
260 pub fn spawn<T>(mut sync: T, depth: usize) -> Self
261 where
262 T: #trait_name + Sync + Send + 'static
263 {
264 let (tx, mut rx) = tokio::sync::mpsc::channel(depth);
265 tokio::spawn(async move {
266 while let Some(msg) = rx.recv().await {
267 match msg {
268 #(#sync_result)*
269 }
270 }
271 });
272 Self { handle: tx }
273 }
274
275 #(#async_result)*
276 }
277 }.into()
278
279}
280
281fn generate_message_enum(ast: &Ast) -> TokenStream {
295 let enum_name = ast.enum_name();
296
297 let mut enum_variants = Vec::new();
298 for method in ast.methods() {
299 let name = method.name_pascal_case();
300 let parameters = method.typed_parameters();
301 let return_type = method.return_value_type();
302
303 enum_variants.push(quote! {
304 #name {return_value: tokio::sync::oneshot::Sender<#return_type>, #(#parameters),* }
305 });
306 }
307
308 quote!(
309 #[derive(Debug)]
310 enum #enum_name {
311 #(#enum_variants),*
312 }
313 ).into()
314}
315
#[cfg(test)]
mod test {
    use super::*;

    // Compares two token streams by their `to_string()` rendering. On a
    // mismatch it prints a colored diff (via the `colored_diff` crate) followed
    // by both raw strings, then panics.
    fn assert_tokens_eq(expected: &TokenStream, actual: &TokenStream) {
        let expected = expected.to_string();
        let actual = actual.to_string();

        if expected != actual {
            println!(
                "{}",
                colored_diff::PrettyDifference {
                    expected: &expected,
                    actual: &actual,
                }
            );
            println!("expected: {}", &expected);
            println!("actual : {}", &actual);
            panic!("expected != actual");
        }
    }

    // Full-expansion snapshot test. The input trait mixes:
    // - associated functions without a receiver, which must be re-emitted in
    //   the trait but left out of the handle and the message enum;
    // - `&self` methods with and without parameters and return values.
    // The expected literal must match the generator's tokens exactly, so keep
    // it in sync with `generate_struct` and `generate_message_enum`.
    #[test]
    fn test_tokio_handle() {
        let before = quote! {
            trait MyTrait {
                fn ignored_associated_function();
                fn ignored_associated_function_args(input: u64) -> u64;
                fn simple(&self);
                fn echo(&self, input: u64) -> u64;
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn ignored_associated_function();
                fn ignored_associated_function_args(input: u64) -> u64;
                fn simple(&self);
                fn echo(&self, input: u64) -> u64;
            }

            #[derive(Debug)]
            struct AsyncMyTraitHandle { handle: tokio::sync::mpsc::Sender<AsyncMyTraitMessage>, }

            trait ToAsyncHandle { fn to_async_handle (self, depth: usize) -> AsyncMyTraitHandle ; }

            impl<T> ToAsyncHandle for T
            where
                T: MyTrait + Sync + Send + 'static
            {
                fn to_async_handle(self: T, depth: usize) -> AsyncMyTraitHandle {
                    AsyncMyTraitHandle::spawn(self, depth)
                }
            }

            impl AsyncMyTraitHandle {
                pub fn new(handle: tokio::sync::mpsc::Sender<AsyncMyTraitMessage>) -> Self {
                    Self { handle }
                }

                pub fn spawn<T>(mut sync: T, depth: usize) -> Self
                where
                    T: MyTrait + Sync + Send + 'static
                {
                    let (tx, mut rx) = tokio::sync::mpsc::channel(depth);
                    tokio::spawn(async move {
                        while let Some(msg) = rx.recv().await {
                            match msg {
                                AsyncMyTraitMessage::Simple { return_value, } => {
                                    let result = sync.simple();
                                    return_value.send(result).expect("Error calling function");
                                }
                                AsyncMyTraitMessage::Echo { return_value, input } => {
                                    let result = sync.echo(input);
                                    return_value.send(result).expect("Error calling function");
                                }
                            }
                        }
                    });
                    Self { handle: tx }
                }

                async fn simple(&self, ) -> () {
                    let (return_value, response) = tokio::sync::oneshot::channel();
                    self.handle.send(AsyncMyTraitMessage::Simple{ return_value, }).await.expect("Error when sending message to the sync code");
                    response.await.expect("Error receiving the response")
                }

                async fn echo(&self, input: u64) -> u64 {
                    let (return_value, response) = tokio::sync::oneshot::channel();
                    self.handle.send(AsyncMyTraitMessage::Echo{ return_value, input }).await.expect("Error when sending message to the sync code");
                    response.await.expect("Error receiving the response")
                }
            }

            #[derive(Debug)]
            enum AsyncMyTraitMessage {
                Simple { return_value: tokio::sync::oneshot::Sender<()>, },
                Echo { return_value: tokio::sync::oneshot::Sender<u64>, input: u64 }
            }


        };
        // Attribute arguments are ignored by the macro, so pass an empty stream.
        let after = client_handle_core(quote!(), before);
        assert_tokens_eq(&expected, &after);
    }
}