async_mock/
lib.rs
1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::parse_macro_input;
4
5#[allow(dead_code)]
6fn print_tokens(tokens: &dyn ToTokens) {
7 println!("{}", tokens.to_token_stream());
8}
9
10#[allow(dead_code)]
11fn print_tokens_dbg(tokens: &dyn ToTokens) {
12 println!("{:?}", tokens.to_token_stream());
13}
14
15fn contains_impl(token: &syn::Type) -> bool {
16 match token {
17 syn::Type::ImplTrait(_) => true,
18 syn::Type::Group(group) => contains_impl(group.elem.as_ref()),
19 syn::Type::Paren(paren) => contains_impl(paren.elem.as_ref()),
20 syn::Type::Reference(reference) => contains_impl(reference.elem.as_ref()),
21 _ => false,
22 }
23}
24
25fn convert_impl_to_dyn(token: &syn::Type) -> syn::Type {
26 match &token {
27 syn::Type::ImplTrait(impl_trait) => syn::Type::TraitObject(syn::TypeTraitObject {
28 dyn_token: Some(syn::token::Dyn::default()),
29 bounds: impl_trait.bounds.clone(),
30 }),
31 syn::Type::Group(group) => syn::Type::Group(syn::TypeGroup {
32 group_token: group.group_token,
33 elem: Box::new(convert_impl_to_dyn(group.elem.as_ref())),
34 }),
35 syn::Type::Paren(paren) => syn::Type::Paren(syn::TypeParen {
36 paren_token: paren.paren_token,
37 elem: Box::new(convert_impl_to_dyn(paren.elem.as_ref())),
38 }),
39 syn::Type::Reference(reference) => syn::Type::Reference(syn::TypeReference {
40 and_token: reference.and_token,
41 lifetime: reference.lifetime.clone(),
42 mutability: reference.mutability,
43 elem: Box::new(convert_impl_to_dyn(reference.elem.as_ref())),
44 }),
45 _ => token.clone(),
46 }
47}
48
49#[proc_macro_attribute]
50pub fn async_mock(_attr: TokenStream, item: TokenStream) -> TokenStream {
51 let input = parse_macro_input!(item as syn::ItemTrait);
52 let trait_name = input.ident.clone();
53 let mock_name = format_ident!("Mock{trait_name}");
54 let mut objects = Vec::new();
55 let mut expectations = Vec::new();
56 let mut expectation_validation = Vec::new();
57 let mut functions = Vec::new();
58 let mut impls = Vec::new();
59 let mut counter = 0;
60
61 for item in input.items.iter() {
62 if let syn::TraitItem::Fn(f) = item {
63 let mut fn_arg_types = Vec::new();
64 let mut fn_arg_types_dyn = Vec::new();
65 let mut fn_arg_names = Vec::new();
66 let mut has_impl_ref = false;
67
68 for arg in f.sig.inputs.iter() {
69 if let syn::FnArg::Typed(pat) = arg {
70 if let syn::Pat::Ident(ident) = pat.pat.as_ref() {
71 fn_arg_names.push(ident.ident.clone());
72 }
73
74 has_impl_ref |= contains_impl(pat.ty.as_ref());
75 fn_arg_types.push(pat.ty.clone());
76 fn_arg_types_dyn.push(convert_impl_to_dyn(pat.ty.as_ref()));
77 }
78 }
79
80 let function_name = format_ident!("{}", f.sig.ident);
81 let expect_name = format_ident!("expect_{function_name}");
82 let expectation_name = format_ident!("{function_name}_expectation");
83 let expectation_struct_name = format_ident!("__{mock_name}Expectation{counter}");
84 let expectation_struct_name_inner =
85 format_ident!("__{mock_name}ExpectationInner{counter}");
86 let fn_rt = f.sig.output.clone();
87 let function_signature = f.sig.clone();
88
89 let fn_storage_type = if has_impl_ref {
90 quote! { Box<dyn Fn(#(#fn_arg_types_dyn),*) #fn_rt + Send + Sync> }
91 } else {
92 quote! { fn(#(#fn_arg_types_dyn),*) #fn_rt }
93 };
94
95 objects.push(quote! {
96 #expectation_name: #expectation_struct_name
97 });
98
99 let returning_fn_name = if has_impl_ref {
100 format_ident!("returning_dyn")
101 } else {
102 format_ident!("returning")
103 };
104
105 expectations.push(quote! {
106 #[cfg(test)]
107 #[derive(Default)]
108 pub struct #expectation_struct_name {
109 inner: std::sync::Mutex<#expectation_struct_name_inner>,
110 }
111
112 #[cfg(test)]
113 #[derive(Default)]
114 pub struct #expectation_struct_name_inner {
115 expecting: u32,
116 called: u32,
117 returning: Option<#fn_storage_type>,
118 }
119
120 #[cfg(test)]
121 impl #expectation_struct_name {
122 pub fn once(&mut self) -> &mut Self {
123 self.inner.lock().unwrap().expecting = 1;
124 self
125 }
126
127 pub fn never(&mut self) -> &mut Self {
128 self.inner.lock().unwrap().expecting = 0;
129 self
130 }
131
132 pub fn times(&mut self, count: u32) -> &mut Self {
133 self.inner.lock().unwrap().expecting = count;
134 self
135 }
136
137 pub fn #returning_fn_name(
138 &mut self,
139 func: #fn_storage_type,
140 ) -> &mut Self {
141 self.inner.lock().unwrap().returning = Some(func);
142 self
143 }
144 }
145 });
146
147 let get_mutex_expectation = quote! {
148 let expectation = self.#expectation_name.inner.lock();
149 assert!(expectation.is_ok(), "Poisoned inner mocking state for `{}`.", stringify!(#function_name));
150 let mut expectation = expectation.unwrap();
151 };
152
153 let func_call_with_drop = if has_impl_ref {
154 quote! {
155 let func = expectation.returning.as_ref();
156
157 if let Some(func) = func {
158 func(#(#fn_arg_names),*)
159 } else {
160 drop(expectation);
161 panic!("Missing returning function for `{}`", stringify!(#function_name));
162 }
163 }
164 } else {
165 quote! {
166 let func = expectation.returning;
167
168 if let Some(func) = &func {
169 func(#(#fn_arg_names),*)
170 } else {
171 drop(expectation);
172 panic!("Missing returning function for `{}`", stringify!(#function_name));
173 }
174 }
175 };
176
177 impls.push(quote! {
178 #function_signature {
179 #get_mutex_expectation
180
181 expectation.called += 1;
182
183 #func_call_with_drop
184 }
185 });
186
187 expectation_validation.push(quote! {
188 {
189 #get_mutex_expectation
190
191 let expecting = expectation.expecting;
192 let called = expectation.called;
193
194 drop(expectation);
195
196 if !std::thread::panicking() {
197 assert_eq!(
198 expecting,
199 called,
200 "Failed expectation for `{}`. Called {} times but expecting {}.",
201 stringify!(#function_name),
202 called,
203 expecting
204 );
205 }
206 }
207 });
208
209 functions.push(quote! {
210 pub fn #expect_name(&mut self) -> &mut #expectation_struct_name {
211 &mut self.#expectation_name
212 }
213 });
214
215 counter += 1;
216 };
217 }
218
219 let code = quote! {
220 #input
221
222 #[cfg(test)]
223 #[derive(Default)]
224 #[allow(dead_code)]
225 pub struct #mock_name {
226 #(#objects),*
227 }
228
229 #[cfg(test)]
230 impl Drop for #mock_name {
231 fn drop(&mut self) {
232 #(#expectation_validation)*
233 }
234 }
235
236 #(#expectations)*
237
238 #[cfg(test)]
239 #[allow(dead_code)]
240 impl #mock_name {
241 #(#functions) *
242
243 pub fn new() -> Self {
244 Self::default()
245 }
246 }
247
248 #[cfg(test)]
249 #[async_trait::async_trait] impl #trait_name for #mock_name {
251 #(#impls) *
252 }
253 };
254
255 code.into()
258}