1#![deny(clippy::all, clippy::pedantic, missing_docs)]
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use proc_macro_error::{proc_macro_error, ResultExt};
6use quote::{format_ident, quote};
7use syn::token::Brace;
8use syn::{
9 parse,
10 Block,
11 FnArg,
12 GenericArgument,
13 GenericParam,
14 Generics,
15 ImplItem,
16 ImplItemMethod,
17 ItemTrait,
18 Path,
19 PathArguments,
20 PathSegment,
21 ReturnType,
22 TraitItem,
23 Type,
24 TypeBareFn,
25 TypeParamBound,
26 Visibility,
27 WherePredicate,
28};
29
30use crate::expectation::Expectation;
31use crate::mock::Mock;
32use crate::mock_input::MockInput;
33use crate::syn_ext::{PathExt, PathSegmentExt, WithLeadingColons};
34use crate::util::create_path;
35
36mod expectation;
37mod mock;
38mod mock_input;
39mod syn_ext;
40mod util;
41
42#[proc_macro]
76#[proc_macro_error]
77pub fn mock(input_stream: TokenStream) -> TokenStream
78{
79 let input = parse::<MockInput>(input_stream.clone()).unwrap_or_abort();
80
81 let mock_ident = input.mock;
82
83 let mock_mod_ident = format_ident!("__{mock_ident}");
84
85 let method_items =
86 get_type_replaced_impl_item_methods(input.item_impl.items, &mock_ident);
87
88 let mock = Mock::new(
89 mock_ident.clone(),
90 input.mocked_trait,
91 &method_items,
92 input.item_impl.generics.clone(),
93 );
94
95 let expectations = method_items.iter().map(|item_method| {
96 Expectation::new(
97 &mock_ident,
98 item_method,
99 input.item_impl.generics.params.clone(),
100 )
101 });
102
103 quote! {
104 mod #mock_mod_ident {
105 use super::*;
106
107 #mock
108
109 #(#expectations)*
110 }
111
112 use #mock_mod_ident::#mock_ident;
113 }
114 .into()
115}
116
117#[proc_macro_attribute]
119#[proc_macro_error]
120pub fn automock(_: TokenStream, input_stream: TokenStream) -> TokenStream
121{
122 let item_trait = parse::<ItemTrait>(input_stream).unwrap_or_abort();
123
124 let mock_ident = format_ident!("Mock{}", item_trait.ident);
125
126 let mock_mod_ident = format_ident!("__{mock_ident}");
127
128 let method_items = get_type_replaced_impl_item_methods(
129 item_trait.items.iter().filter_map(|item| match item {
130 TraitItem::Method(item_method) => Some(ImplItem::Method(ImplItemMethod {
131 attrs: item_method.attrs.clone(),
132 vis: Visibility::Inherited,
133 defaultness: None,
134 sig: item_method.sig.clone(),
135 block: Block {
136 brace_token: Brace::default(),
137 stmts: vec![],
138 },
139 })),
140 _ => None,
141 }),
142 &mock_ident,
143 );
144
145 let mock = Mock::new(
146 mock_ident.clone(),
147 Path::new(
148 WithLeadingColons::No,
149 [PathSegment::new(item_trait.ident.clone(), None)],
150 ),
151 &method_items,
152 item_trait.generics.clone(),
153 );
154
155 let expectations = method_items.iter().map(|item_method| {
156 Expectation::new(&mock_ident, item_method, item_trait.generics.params.clone())
157 });
158
159 let visibility = &item_trait.vis;
160
161 quote! {
162 #item_trait
163
164 mod #mock_mod_ident {
165 use super::*;
166
167 #mock
168
169 #(#expectations)*
170 }
171
172 #visibility use #mock_mod_ident::#mock_ident;
173 }
174 .into()
175}
176
177fn get_type_replaced_impl_item_methods(
178 impl_items: impl IntoIterator<Item = ImplItem>,
179 mock_ident: &Ident,
180) -> Vec<ImplItemMethod>
181{
182 let target_path = create_path!(Self);
183
184 let replacement_path = Path::new(
185 WithLeadingColons::No,
186 [PathSegment::new(mock_ident.clone(), None)],
187 );
188
189 impl_items
190 .into_iter()
191 .filter_map(|item| match item {
192 ImplItem::Method(mut item_method) => {
193 item_method.sig.inputs = item_method
194 .sig
195 .inputs
196 .into_iter()
197 .map(|fn_arg| match fn_arg {
198 FnArg::Typed(mut typed_arg) => {
199 typed_arg.ty = Box::new(replace_path_in_type(
200 *typed_arg.ty,
201 &target_path,
202 &replacement_path,
203 ));
204
205 FnArg::Typed(typed_arg)
206 }
207
208 FnArg::Receiver(receiver) => FnArg::Receiver(receiver),
209 })
210 .collect();
211
212 item_method.sig.output = match item_method.sig.output {
213 ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
214 r_arrow,
215 Box::new(replace_path_in_type(
216 *return_type,
217 &target_path,
218 &replacement_path,
219 )),
220 ),
221 ReturnType::Default => ReturnType::Default,
222 };
223
224 item_method.sig.generics = replace_path_in_generics(
225 item_method.sig.generics,
226 &target_path,
227 &replacement_path,
228 );
229
230 Some(item_method)
231 }
232 _ => None,
233 })
234 .collect()
235}
236
237fn replace_path_in_generics(
238 mut generics: Generics,
239 target_path: &Path,
240 replacement_path: &Path,
241) -> Generics
242{
243 generics.params = generics
244 .params
245 .into_iter()
246 .map(|generic_param| match generic_param {
247 GenericParam::Type(mut type_param) => {
248 type_param.bounds = type_param
249 .bounds
250 .into_iter()
251 .map(|bound| {
252 replace_type_param_bound_paths(
253 bound,
254 target_path,
255 replacement_path,
256 )
257 })
258 .collect();
259
260 GenericParam::Type(type_param)
261 }
262 generic_param => generic_param,
263 })
264 .collect();
265
266 generics.where_clause = generics.where_clause.map(|mut where_clause| {
267 where_clause.predicates = where_clause
268 .predicates
269 .into_iter()
270 .map(|predicate| match predicate {
271 WherePredicate::Type(mut predicate_type) => {
272 predicate_type.bounded_ty = replace_path_in_type(
273 predicate_type.bounded_ty,
274 target_path,
275 replacement_path,
276 );
277
278 predicate_type.bounds = predicate_type
279 .bounds
280 .into_iter()
281 .map(|bound| {
282 replace_type_param_bound_paths(
283 bound,
284 target_path,
285 replacement_path,
286 )
287 })
288 .collect();
289
290 WherePredicate::Type(predicate_type)
291 }
292 predicate => predicate,
293 })
294 .collect();
295
296 where_clause
297 });
298
299 generics
300}
301
302fn replace_path_in_type(ty: Type, target_path: &Path, replacement_path: &Path) -> Type
303{
304 match ty {
305 Type::Ptr(mut type_ptr) => {
306 type_ptr.elem = Box::new(replace_path_in_type(
307 *type_ptr.elem,
308 target_path,
309 replacement_path,
310 ));
311
312 Type::Ptr(type_ptr)
313 }
314 Type::Path(mut type_path) => {
315 if &type_path.path == target_path {
316 type_path.path = replacement_path.clone();
317 } else {
318 type_path.path =
319 replace_path_args(type_path.path, target_path, replacement_path);
320 }
321
322 Type::Path(type_path)
323 }
324 Type::Array(mut type_array) => {
325 type_array.elem = Box::new(replace_path_in_type(
326 *type_array.elem,
327 target_path,
328 replacement_path,
329 ));
330
331 Type::Array(type_array)
332 }
333 Type::Group(mut type_group) => {
334 type_group.elem = Box::new(replace_path_in_type(
335 *type_group.elem,
336 target_path,
337 replacement_path,
338 ));
339
340 Type::Group(type_group)
341 }
342 Type::BareFn(type_bare_fn) => Type::BareFn(replace_type_bare_fn_type_paths(
343 type_bare_fn,
344 target_path,
345 replacement_path,
346 )),
347 Type::Paren(mut type_paren) => {
348 type_paren.elem = Box::new(replace_path_in_type(
349 *type_paren.elem,
350 target_path,
351 replacement_path,
352 ));
353
354 Type::Paren(type_paren)
355 }
356 Type::Slice(mut type_slice) => {
357 type_slice.elem = Box::new(replace_path_in_type(
358 *type_slice.elem,
359 target_path,
360 replacement_path,
361 ));
362
363 Type::Slice(type_slice)
364 }
365 Type::Tuple(mut type_tuple) => {
366 type_tuple.elems = type_tuple
367 .elems
368 .into_iter()
369 .map(|elem_type| {
370 replace_path_in_type(elem_type, target_path, replacement_path)
371 })
372 .collect();
373
374 Type::Tuple(type_tuple)
375 }
376 Type::Reference(mut type_reference) => {
377 type_reference.elem = Box::new(replace_path_in_type(
378 *type_reference.elem,
379 target_path,
380 replacement_path,
381 ));
382
383 Type::Reference(type_reference)
384 }
385 Type::TraitObject(mut type_trait_object) => {
386 type_trait_object.bounds = type_trait_object
387 .bounds
388 .into_iter()
389 .map(|bound| match bound {
390 TypeParamBound::Trait(mut trait_bound) => {
391 trait_bound.path = replace_path_args(
392 trait_bound.path,
393 target_path,
394 replacement_path,
395 );
396
397 TypeParamBound::Trait(trait_bound)
398 }
399 TypeParamBound::Lifetime(lifetime) => {
400 TypeParamBound::Lifetime(lifetime)
401 }
402 })
403 .collect();
404
405 Type::TraitObject(type_trait_object)
406 }
407 other_type => other_type,
408 }
409}
410
411fn replace_path_args(mut path: Path, target_path: &Path, replacement_path: &Path)
412 -> Path
413{
414 path.segments = path
415 .segments
416 .into_iter()
417 .map(|mut segment| {
418 segment.arguments = match segment.arguments {
419 PathArguments::AngleBracketed(mut generic_args) => {
420 generic_args.args = generic_args
421 .args
422 .into_iter()
423 .map(|generic_arg| match generic_arg {
424 GenericArgument::Type(ty) => GenericArgument::Type(
425 replace_path_in_type(ty, target_path, replacement_path),
426 ),
427 GenericArgument::Binding(mut binding) => {
428 binding.ty = replace_path_in_type(
429 binding.ty,
430 target_path,
431 replacement_path,
432 );
433
434 GenericArgument::Binding(binding)
435 }
436 generic_arg => generic_arg,
437 })
438 .collect();
439
440 PathArguments::AngleBracketed(generic_args)
441 }
442 PathArguments::Parenthesized(mut generic_args) => {
443 generic_args.inputs = generic_args
444 .inputs
445 .into_iter()
446 .map(|input_ty| {
447 replace_path_in_type(input_ty, target_path, replacement_path)
448 })
449 .collect();
450
451 generic_args.output = match generic_args.output {
452 ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
453 r_arrow,
454 Box::new(replace_path_in_type(
455 *return_type,
456 target_path,
457 replacement_path,
458 )),
459 ),
460 ReturnType::Default => ReturnType::Default,
461 };
462
463 PathArguments::Parenthesized(generic_args)
464 }
465 PathArguments::None => PathArguments::None,
466 };
467
468 segment
469 })
470 .collect();
471
472 path
473}
474
475fn replace_type_bare_fn_type_paths(
476 mut type_bare_fn: TypeBareFn,
477 target_path: &Path,
478 replacement_path: &Path,
479) -> TypeBareFn
480{
481 type_bare_fn.inputs = type_bare_fn
482 .inputs
483 .into_iter()
484 .map(|mut bare_fn_arg| {
485 bare_fn_arg.ty =
486 replace_path_in_type(bare_fn_arg.ty, target_path, replacement_path);
487
488 bare_fn_arg
489 })
490 .collect();
491
492 type_bare_fn.output = match type_bare_fn.output {
493 ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
494 r_arrow,
495 Box::new(replace_path_in_type(
496 *return_type,
497 target_path,
498 replacement_path,
499 )),
500 ),
501 ReturnType::Default => ReturnType::Default,
502 };
503
504 type_bare_fn
505}
506
507fn replace_type_param_bound_paths(
508 type_param_bound: TypeParamBound,
509 target_path: &Path,
510 replacement_path: &Path,
511) -> TypeParamBound
512{
513 match type_param_bound {
514 TypeParamBound::Trait(mut trait_bound) => {
515 if &trait_bound.path == target_path {
516 trait_bound.path = replacement_path.clone();
517 } else {
518 trait_bound.path =
519 replace_path_args(trait_bound.path, target_path, replacement_path);
520 }
521
522 TypeParamBound::Trait(trait_bound)
523 }
524 TypeParamBound::Lifetime(lifetime) => TypeParamBound::Lifetime(lifetime),
525 }
526}