async_mock/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::parse_macro_input;
4
5#[allow(dead_code)]
6fn print_tokens(tokens: &dyn ToTokens) {
7    println!("{}", tokens.to_token_stream());
8}
9
10#[allow(dead_code)]
11fn print_tokens_dbg(tokens: &dyn ToTokens) {
12    println!("{:?}", tokens.to_token_stream());
13}
14
15fn contains_impl(token: &syn::Type) -> bool {
16    match token {
17        syn::Type::ImplTrait(_) => true,
18        syn::Type::Group(group) => contains_impl(group.elem.as_ref()),
19        syn::Type::Paren(paren) => contains_impl(paren.elem.as_ref()),
20        syn::Type::Reference(reference) => contains_impl(reference.elem.as_ref()),
21        _ => false,
22    }
23}
24
25fn convert_impl_to_dyn(token: &syn::Type) -> syn::Type {
26    match &token {
27        syn::Type::ImplTrait(impl_trait) => syn::Type::TraitObject(syn::TypeTraitObject {
28            dyn_token: Some(syn::token::Dyn::default()),
29            bounds: impl_trait.bounds.clone(),
30        }),
31        syn::Type::Group(group) => syn::Type::Group(syn::TypeGroup {
32            group_token: group.group_token,
33            elem: Box::new(convert_impl_to_dyn(group.elem.as_ref())),
34        }),
35        syn::Type::Paren(paren) => syn::Type::Paren(syn::TypeParen {
36            paren_token: paren.paren_token,
37            elem: Box::new(convert_impl_to_dyn(paren.elem.as_ref())),
38        }),
39        syn::Type::Reference(reference) => syn::Type::Reference(syn::TypeReference {
40            and_token: reference.and_token,
41            lifetime: reference.lifetime.clone(),
42            mutability: reference.mutability,
43            elem: Box::new(convert_impl_to_dyn(reference.elem.as_ref())),
44        }),
45        _ => token.clone(),
46    }
47}
48
49#[proc_macro_attribute]
50pub fn async_mock(_attr: TokenStream, item: TokenStream) -> TokenStream {
51    let input = parse_macro_input!(item as syn::ItemTrait);
52    let trait_name = input.ident.clone();
53    let mock_name = format_ident!("Mock{trait_name}");
54    let mut objects = Vec::new();
55    let mut expectations = Vec::new();
56    let mut expectation_validation = Vec::new();
57    let mut functions = Vec::new();
58    let mut impls = Vec::new();
59    let mut counter = 0;
60
61    for item in input.items.iter() {
62        if let syn::TraitItem::Fn(f) = item {
63            let mut fn_arg_types = Vec::new();
64            let mut fn_arg_types_dyn = Vec::new();
65            let mut fn_arg_names = Vec::new();
66            let mut has_impl_ref = false;
67
68            for arg in f.sig.inputs.iter() {
69                if let syn::FnArg::Typed(pat) = arg {
70                    if let syn::Pat::Ident(ident) = pat.pat.as_ref() {
71                        fn_arg_names.push(ident.ident.clone());
72                    }
73
74                    has_impl_ref |= contains_impl(pat.ty.as_ref());
75                    fn_arg_types.push(pat.ty.clone());
76                    fn_arg_types_dyn.push(convert_impl_to_dyn(pat.ty.as_ref()));
77                }
78            }
79
80            let function_name = format_ident!("{}", f.sig.ident);
81            let expect_name = format_ident!("expect_{function_name}");
82            let expectation_name = format_ident!("{function_name}_expectation");
83            let expectation_struct_name = format_ident!("__{mock_name}Expectation{counter}");
84            let expectation_struct_name_inner =
85                format_ident!("__{mock_name}ExpectationInner{counter}");
86            let fn_rt = f.sig.output.clone();
87            let function_signature = f.sig.clone();
88
89            let fn_storage_type = if has_impl_ref {
90                quote! { Box<dyn Fn(#(#fn_arg_types_dyn),*) #fn_rt + Send + Sync> }
91            } else {
92                quote! { fn(#(#fn_arg_types_dyn),*) #fn_rt }
93            };
94
95            objects.push(quote! {
96                #expectation_name: #expectation_struct_name
97            });
98
99            let returning_fn_name = if has_impl_ref {
100                format_ident!("returning_dyn")
101            } else {
102                format_ident!("returning")
103            };
104
105            expectations.push(quote! {
106                #[cfg(test)]
107                #[derive(Default)]
108                pub struct #expectation_struct_name {
109                    inner: std::sync::Mutex<#expectation_struct_name_inner>,
110                }
111
112                #[cfg(test)]
113                #[derive(Default)]
114                pub struct #expectation_struct_name_inner {
115                    expecting: u32,
116                    called: u32,
117                    returning: Option<#fn_storage_type>,
118                }
119
120                #[cfg(test)]
121                impl #expectation_struct_name {
122                    pub fn once(&mut self) -> &mut Self {
123                        self.inner.lock().unwrap().expecting = 1;
124                        self
125                    }
126
127                    pub fn never(&mut self) -> &mut Self {
128                        self.inner.lock().unwrap().expecting = 0;
129                        self
130                    }
131
132                    pub fn times(&mut self, count: u32) -> &mut Self {
133                        self.inner.lock().unwrap().expecting = count;
134                        self
135                    }
136
137                    pub fn #returning_fn_name(
138                        &mut self,
139                        func: #fn_storage_type,
140                    ) -> &mut Self {
141                        self.inner.lock().unwrap().returning = Some(func);
142                        self
143                    }
144                }
145            });
146
147            let get_mutex_expectation = quote! {
148                let expectation = self.#expectation_name.inner.lock();
149                assert!(expectation.is_ok(), "Poisoned inner mocking state for `{}`.", stringify!(#function_name));
150                let mut expectation = expectation.unwrap();
151            };
152
153            let func_call_with_drop = if has_impl_ref {
154                quote! {
155                    let func = expectation.returning.as_ref();
156
157                    if let Some(func) = func {
158                        func(#(#fn_arg_names),*)
159                    } else {
160                        drop(expectation);
161                        panic!("Missing returning function for `{}`", stringify!(#function_name));
162                    }
163                }
164            } else {
165                quote! {
166                    let func = expectation.returning;
167
168                    if let Some(func) = &func {
169                        func(#(#fn_arg_names),*)
170                    } else {
171                        drop(expectation);
172                        panic!("Missing returning function for `{}`", stringify!(#function_name));
173                    }
174                }
175            };
176
177            impls.push(quote! {
178                #function_signature {
179                    #get_mutex_expectation
180
181                    expectation.called += 1;
182
183                    #func_call_with_drop
184                }
185            });
186
187            expectation_validation.push(quote! {
188                {
189                    #get_mutex_expectation
190
191                    let expecting = expectation.expecting;
192                    let called = expectation.called;
193
194                    drop(expectation);
195
196                    if !std::thread::panicking() {
197                        assert_eq!(
198                            expecting,
199                            called,
200                            "Failed expectation for `{}`. Called {} times but expecting {}.",
201                            stringify!(#function_name),
202                            called,
203                            expecting
204                        );
205                    }
206                }
207            });
208
209            functions.push(quote! {
210                pub fn #expect_name(&mut self) -> &mut #expectation_struct_name {
211                    &mut self.#expectation_name
212                }
213            });
214
215            counter += 1;
216        };
217    }
218
219    let code = quote! {
220        #input
221
222        #[cfg(test)]
223        #[derive(Default)]
224        #[allow(dead_code)]
225        pub struct #mock_name {
226            #(#objects),*
227        }
228
229        #[cfg(test)]
230        impl Drop for #mock_name {
231            fn drop(&mut self) {
232                #(#expectation_validation)*
233            }
234        }
235
236        #(#expectations)*
237
238        #[cfg(test)]
239        #[allow(dead_code)]
240        impl #mock_name {
241            #(#functions) *
242
243            pub fn new() -> Self {
244                Self::default()
245            }
246        }
247
248        #[cfg(test)]
249        #[async_trait::async_trait] // TODO: Only add this if it was used on the trait
250        impl #trait_name for #mock_name {
251            #(#impls) *
252        }
253    };
254
255    // print_tokens(&code);
256
257    code.into()
258}