ridicule_macros/
lib.rs

1//! Macros for Ridicule, a mocking library supporting non-static generics.
2#![deny(clippy::all, clippy::pedantic, missing_docs)]
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use proc_macro_error::{proc_macro_error, ResultExt};
6use quote::{format_ident, quote};
7use syn::token::Brace;
8use syn::{
9    parse,
10    Block,
11    FnArg,
12    GenericArgument,
13    GenericParam,
14    Generics,
15    ImplItem,
16    ImplItemMethod,
17    ItemTrait,
18    Path,
19    PathArguments,
20    PathSegment,
21    ReturnType,
22    TraitItem,
23    Type,
24    TypeBareFn,
25    TypeParamBound,
26    Visibility,
27    WherePredicate,
28};
29
30use crate::expectation::Expectation;
31use crate::mock::Mock;
32use crate::mock_input::MockInput;
33use crate::syn_ext::{PathExt, PathSegmentExt, WithLeadingColons};
34use crate::util::create_path;
35
36mod expectation;
37mod mock;
38mod mock_input;
39mod syn_ext;
40mod util;
41
42/// Creates a mock.
43///
44/// # Examples
45/// ```
46/// use ridicule::mock;
47///
48/// trait Foo
49/// {
50///     fn bar<A, B>(&self, a: A) -> B;
51/// }
52///
53/// mock! {
54///     MockFoo {}
55///
56///     impl Foo for MockFoo
57///     {
58///         fn bar<A, B>(&self, a: A) -> B;
59///     }
60/// }
61///
62/// fn main()
63/// {
64///     let mut mock_foo = MockFoo::new();
65///
66///     unsafe {
67///         mock_foo
68///             .expect_bar()
69///             .returning(|foo, a: u32| format!("Hello {a}"));
70///     }
71///
72///     assert_eq!(mock_foo.bar::<u32, String>(123), "Hello 123");
73/// }
74/// ```
75#[proc_macro]
76#[proc_macro_error]
77pub fn mock(input_stream: TokenStream) -> TokenStream
78{
79    let input = parse::<MockInput>(input_stream.clone()).unwrap_or_abort();
80
81    let mock_ident = input.mock;
82
83    let mock_mod_ident = format_ident!("__{mock_ident}");
84
85    let method_items =
86        get_type_replaced_impl_item_methods(input.item_impl.items, &mock_ident);
87
88    let mock = Mock::new(
89        mock_ident.clone(),
90        input.mocked_trait,
91        &method_items,
92        input.item_impl.generics.clone(),
93    );
94
95    let expectations = method_items.iter().map(|item_method| {
96        Expectation::new(
97            &mock_ident,
98            item_method,
99            input.item_impl.generics.params.clone(),
100        )
101    });
102
103    quote! {
104        mod #mock_mod_ident {
105            use super::*;
106
107            #mock
108
109            #(#expectations)*
110        }
111
112        use #mock_mod_ident::#mock_ident;
113    }
114    .into()
115}
116
117/// Creates a mock automatically.
118#[proc_macro_attribute]
119#[proc_macro_error]
120pub fn automock(_: TokenStream, input_stream: TokenStream) -> TokenStream
121{
122    let item_trait = parse::<ItemTrait>(input_stream).unwrap_or_abort();
123
124    let mock_ident = format_ident!("Mock{}", item_trait.ident);
125
126    let mock_mod_ident = format_ident!("__{mock_ident}");
127
128    let method_items = get_type_replaced_impl_item_methods(
129        item_trait.items.iter().filter_map(|item| match item {
130            TraitItem::Method(item_method) => Some(ImplItem::Method(ImplItemMethod {
131                attrs: item_method.attrs.clone(),
132                vis: Visibility::Inherited,
133                defaultness: None,
134                sig: item_method.sig.clone(),
135                block: Block {
136                    brace_token: Brace::default(),
137                    stmts: vec![],
138                },
139            })),
140            _ => None,
141        }),
142        &mock_ident,
143    );
144
145    let mock = Mock::new(
146        mock_ident.clone(),
147        Path::new(
148            WithLeadingColons::No,
149            [PathSegment::new(item_trait.ident.clone(), None)],
150        ),
151        &method_items,
152        item_trait.generics.clone(),
153    );
154
155    let expectations = method_items.iter().map(|item_method| {
156        Expectation::new(&mock_ident, item_method, item_trait.generics.params.clone())
157    });
158
159    let visibility = &item_trait.vis;
160
161    quote! {
162        #item_trait
163
164        mod #mock_mod_ident {
165            use super::*;
166
167            #mock
168
169            #(#expectations)*
170        }
171
172        #visibility use #mock_mod_ident::#mock_ident;
173    }
174    .into()
175}
176
177fn get_type_replaced_impl_item_methods(
178    impl_items: impl IntoIterator<Item = ImplItem>,
179    mock_ident: &Ident,
180) -> Vec<ImplItemMethod>
181{
182    let target_path = create_path!(Self);
183
184    let replacement_path = Path::new(
185        WithLeadingColons::No,
186        [PathSegment::new(mock_ident.clone(), None)],
187    );
188
189    impl_items
190        .into_iter()
191        .filter_map(|item| match item {
192            ImplItem::Method(mut item_method) => {
193                item_method.sig.inputs = item_method
194                    .sig
195                    .inputs
196                    .into_iter()
197                    .map(|fn_arg| match fn_arg {
198                        FnArg::Typed(mut typed_arg) => {
199                            typed_arg.ty = Box::new(replace_path_in_type(
200                                *typed_arg.ty,
201                                &target_path,
202                                &replacement_path,
203                            ));
204
205                            FnArg::Typed(typed_arg)
206                        }
207
208                        FnArg::Receiver(receiver) => FnArg::Receiver(receiver),
209                    })
210                    .collect();
211
212                item_method.sig.output = match item_method.sig.output {
213                    ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
214                        r_arrow,
215                        Box::new(replace_path_in_type(
216                            *return_type,
217                            &target_path,
218                            &replacement_path,
219                        )),
220                    ),
221                    ReturnType::Default => ReturnType::Default,
222                };
223
224                item_method.sig.generics = replace_path_in_generics(
225                    item_method.sig.generics,
226                    &target_path,
227                    &replacement_path,
228                );
229
230                Some(item_method)
231            }
232            _ => None,
233        })
234        .collect()
235}
236
237fn replace_path_in_generics(
238    mut generics: Generics,
239    target_path: &Path,
240    replacement_path: &Path,
241) -> Generics
242{
243    generics.params = generics
244        .params
245        .into_iter()
246        .map(|generic_param| match generic_param {
247            GenericParam::Type(mut type_param) => {
248                type_param.bounds = type_param
249                    .bounds
250                    .into_iter()
251                    .map(|bound| {
252                        replace_type_param_bound_paths(
253                            bound,
254                            target_path,
255                            replacement_path,
256                        )
257                    })
258                    .collect();
259
260                GenericParam::Type(type_param)
261            }
262            generic_param => generic_param,
263        })
264        .collect();
265
266    generics.where_clause = generics.where_clause.map(|mut where_clause| {
267        where_clause.predicates = where_clause
268            .predicates
269            .into_iter()
270            .map(|predicate| match predicate {
271                WherePredicate::Type(mut predicate_type) => {
272                    predicate_type.bounded_ty = replace_path_in_type(
273                        predicate_type.bounded_ty,
274                        target_path,
275                        replacement_path,
276                    );
277
278                    predicate_type.bounds = predicate_type
279                        .bounds
280                        .into_iter()
281                        .map(|bound| {
282                            replace_type_param_bound_paths(
283                                bound,
284                                target_path,
285                                replacement_path,
286                            )
287                        })
288                        .collect();
289
290                    WherePredicate::Type(predicate_type)
291                }
292                predicate => predicate,
293            })
294            .collect();
295
296        where_clause
297    });
298
299    generics
300}
301
302fn replace_path_in_type(ty: Type, target_path: &Path, replacement_path: &Path) -> Type
303{
304    match ty {
305        Type::Ptr(mut type_ptr) => {
306            type_ptr.elem = Box::new(replace_path_in_type(
307                *type_ptr.elem,
308                target_path,
309                replacement_path,
310            ));
311
312            Type::Ptr(type_ptr)
313        }
314        Type::Path(mut type_path) => {
315            if &type_path.path == target_path {
316                type_path.path = replacement_path.clone();
317            } else {
318                type_path.path =
319                    replace_path_args(type_path.path, target_path, replacement_path);
320            }
321
322            Type::Path(type_path)
323        }
324        Type::Array(mut type_array) => {
325            type_array.elem = Box::new(replace_path_in_type(
326                *type_array.elem,
327                target_path,
328                replacement_path,
329            ));
330
331            Type::Array(type_array)
332        }
333        Type::Group(mut type_group) => {
334            type_group.elem = Box::new(replace_path_in_type(
335                *type_group.elem,
336                target_path,
337                replacement_path,
338            ));
339
340            Type::Group(type_group)
341        }
342        Type::BareFn(type_bare_fn) => Type::BareFn(replace_type_bare_fn_type_paths(
343            type_bare_fn,
344            target_path,
345            replacement_path,
346        )),
347        Type::Paren(mut type_paren) => {
348            type_paren.elem = Box::new(replace_path_in_type(
349                *type_paren.elem,
350                target_path,
351                replacement_path,
352            ));
353
354            Type::Paren(type_paren)
355        }
356        Type::Slice(mut type_slice) => {
357            type_slice.elem = Box::new(replace_path_in_type(
358                *type_slice.elem,
359                target_path,
360                replacement_path,
361            ));
362
363            Type::Slice(type_slice)
364        }
365        Type::Tuple(mut type_tuple) => {
366            type_tuple.elems = type_tuple
367                .elems
368                .into_iter()
369                .map(|elem_type| {
370                    replace_path_in_type(elem_type, target_path, replacement_path)
371                })
372                .collect();
373
374            Type::Tuple(type_tuple)
375        }
376        Type::Reference(mut type_reference) => {
377            type_reference.elem = Box::new(replace_path_in_type(
378                *type_reference.elem,
379                target_path,
380                replacement_path,
381            ));
382
383            Type::Reference(type_reference)
384        }
385        Type::TraitObject(mut type_trait_object) => {
386            type_trait_object.bounds = type_trait_object
387                .bounds
388                .into_iter()
389                .map(|bound| match bound {
390                    TypeParamBound::Trait(mut trait_bound) => {
391                        trait_bound.path = replace_path_args(
392                            trait_bound.path,
393                            target_path,
394                            replacement_path,
395                        );
396
397                        TypeParamBound::Trait(trait_bound)
398                    }
399                    TypeParamBound::Lifetime(lifetime) => {
400                        TypeParamBound::Lifetime(lifetime)
401                    }
402                })
403                .collect();
404
405            Type::TraitObject(type_trait_object)
406        }
407        other_type => other_type,
408    }
409}
410
411fn replace_path_args(mut path: Path, target_path: &Path, replacement_path: &Path)
412    -> Path
413{
414    path.segments = path
415        .segments
416        .into_iter()
417        .map(|mut segment| {
418            segment.arguments = match segment.arguments {
419                PathArguments::AngleBracketed(mut generic_args) => {
420                    generic_args.args = generic_args
421                        .args
422                        .into_iter()
423                        .map(|generic_arg| match generic_arg {
424                            GenericArgument::Type(ty) => GenericArgument::Type(
425                                replace_path_in_type(ty, target_path, replacement_path),
426                            ),
427                            GenericArgument::Binding(mut binding) => {
428                                binding.ty = replace_path_in_type(
429                                    binding.ty,
430                                    target_path,
431                                    replacement_path,
432                                );
433
434                                GenericArgument::Binding(binding)
435                            }
436                            generic_arg => generic_arg,
437                        })
438                        .collect();
439
440                    PathArguments::AngleBracketed(generic_args)
441                }
442                PathArguments::Parenthesized(mut generic_args) => {
443                    generic_args.inputs = generic_args
444                        .inputs
445                        .into_iter()
446                        .map(|input_ty| {
447                            replace_path_in_type(input_ty, target_path, replacement_path)
448                        })
449                        .collect();
450
451                    generic_args.output = match generic_args.output {
452                        ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
453                            r_arrow,
454                            Box::new(replace_path_in_type(
455                                *return_type,
456                                target_path,
457                                replacement_path,
458                            )),
459                        ),
460                        ReturnType::Default => ReturnType::Default,
461                    };
462
463                    PathArguments::Parenthesized(generic_args)
464                }
465                PathArguments::None => PathArguments::None,
466            };
467
468            segment
469        })
470        .collect();
471
472    path
473}
474
475fn replace_type_bare_fn_type_paths(
476    mut type_bare_fn: TypeBareFn,
477    target_path: &Path,
478    replacement_path: &Path,
479) -> TypeBareFn
480{
481    type_bare_fn.inputs = type_bare_fn
482        .inputs
483        .into_iter()
484        .map(|mut bare_fn_arg| {
485            bare_fn_arg.ty =
486                replace_path_in_type(bare_fn_arg.ty, target_path, replacement_path);
487
488            bare_fn_arg
489        })
490        .collect();
491
492    type_bare_fn.output = match type_bare_fn.output {
493        ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
494            r_arrow,
495            Box::new(replace_path_in_type(
496                *return_type,
497                target_path,
498                replacement_path,
499            )),
500        ),
501        ReturnType::Default => ReturnType::Default,
502    };
503
504    type_bare_fn
505}
506
507fn replace_type_param_bound_paths(
508    type_param_bound: TypeParamBound,
509    target_path: &Path,
510    replacement_path: &Path,
511) -> TypeParamBound
512{
513    match type_param_bound {
514        TypeParamBound::Trait(mut trait_bound) => {
515            if &trait_bound.path == target_path {
516                trait_bound.path = replacement_path.clone();
517            } else {
518                trait_bound.path =
519                    replace_path_args(trait_bound.path, target_path, replacement_path);
520            }
521
522            TypeParamBound::Trait(trait_bound)
523        }
524        TypeParamBound::Lifetime(lifetime) => TypeParamBound::Lifetime(lifetime),
525    }
526}