async_blocking_bridger/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Group;
3use proc_macro2::Ident;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{quote, ToTokens};
6use syn::Stmt;
7use syn::Token;
8
9
10fn replace_ident(from: &str, to: &str, stream: TokenStream2) -> TokenStream2 {
11    stream.into_iter()
12        .map(|token| {
13            if let proc_macro2::TokenTree::Ident(ident) = &token {
14                if ident.to_string() == from {
15                    return proc_macro2::TokenTree::Ident(
16                        syn::Ident::new(&to, ident.span())
17                    );
18                }
19            } else if let proc_macro2::TokenTree::Group(group) = &token {
20                return proc_macro2::TokenTree::Group(
21                    Group::new(
22                        group.delimiter(),
23                        replace_ident(from, to, group.stream())
24                    )
25                );
26            }
27            token
28        })
29        .collect::<TokenStream2>()
30}
31
32
33/// Uses an inner object and outer handle paradigm to allow for async and blocking implementations
34/// without writing the same code twice. There will be a Blocking____Handle and Async____Handle that
35/// have an Arc<Inner____>. The macro will take an impl block for _____.
36/// The macro will then create an impl block for both handles that simply call the
37/// inner object's methods.
38///
39/// The blocking handle will have a field called rt thats a tokio runtime that can be blocked on.
40#[proc_macro]
41pub fn async_inner(input: TokenStream) -> TokenStream {
42
43    let input: TokenStream2 = input.into();
44
45    let input = replace_ident("inner", "self", input);
46
47    let mut impl_block = match syn::parse2::<syn::ItemImpl>(input.clone()) {
48        Ok(impl_block) => impl_block,
49        _ => return input.into(),
50    };
51
52    let name_str = impl_block.self_ty.to_token_stream().to_string();
53
54    let inner_name: Ident =
55        syn::parse_str(&format!("Inner{}", &name_str)).expect("Failed to parse inner name");
56    impl_block.self_ty = Box::new(
57        syn::parse2::<syn::Type>(inner_name.to_token_stream()).expect("Failed to parse inner name"),
58    );
59
60
61
62    for item in impl_block.items.iter() {
63        if let syn::ImplItem::Fn(method) = item {
64            if method.sig.asyncness.is_none() {
65                panic!("All methods in impl block must be async")
66            }
67        }
68    }
69
70    let mut async_impl_block = impl_block.clone();
71
72    let async_name: Ident = syn::parse_str(&format!("Async{}Handle", &name_str))
73        .expect("Failed to parse async handle name");
74    async_impl_block.self_ty = Box::new(
75        syn::parse2::<syn::Type>(async_name.to_token_stream())
76            .expect("Failed to parse async handle name"),
77    );
78
79    for item in async_impl_block.items.iter_mut() {
80        if let syn::ImplItem::Fn(method) = item {
81            //make the method call self.inner.method then await it
82            let method_name = &method.sig.ident;
83            let arg_names = method.sig.inputs.iter().filter_map(|arg| {
84                if let syn::FnArg::Typed(pat_type) = arg {
85                    if let syn::Pat::Ident(ident) = &*pat_type.pat {
86                        return Some(&ident.ident);
87                    }
88                }
89                None
90            });
91
92            let new_block = quote! {
93                return self.inner.#method_name(#(#arg_names),*).await;
94            };
95
96            method.block.stmts = vec![syn::parse2::<Stmt>(new_block).unwrap()];
97
98            method.vis = syn::Visibility::Public(Token![pub](proc_macro2::Span::call_site()));
99        }
100    }
101
102    let mut blocking_impl_block = impl_block.clone();
103
104    let blocking_name: Ident = syn::parse_str(&format!("Blocking{}Handle", &name_str))
105        .expect("Failed to parse blocking handle name");
106    blocking_impl_block.self_ty = Box::new(
107        syn::parse2::<syn::Type>(blocking_name.to_token_stream())
108            .expect("Failed to parse blocking handle name"),
109    );
110
111    for item in blocking_impl_block.items.iter_mut() {
112        if let syn::ImplItem::Fn(method) = item {
113            //make the method call self.inner.method then await it
114            let method_name = &method.sig.ident;
115            let arg_names = method.sig.inputs.iter().filter_map(|arg| {
116                if let syn::FnArg::Typed(pat_type) = arg {
117                    if let syn::Pat::Ident(ident) = &*pat_type.pat {
118                        return Some(&ident.ident);
119                    }
120                }
121                None
122            });
123
124            let new_block = quote! {
125                return self.rt.block_on(self.inner.#method_name(#(#arg_names),*));
126            };
127
128            method.block.stmts = vec![syn::parse2::<Stmt>(new_block).unwrap()];
129
130            method.sig.asyncness = None;
131
132            method.vis = syn::Visibility::Public(Token![pub](proc_macro2::Span::call_site()));
133        }
134    }
135
136    let tokens = quote! {
137        #impl_block
138
139        #async_impl_block
140
141        #blocking_impl_block
142    };
143
144    tokens.into()
145}