ridicule_macros/
expectation.rs

1use proc_macro2::{Ident, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::punctuated::Punctuated;
4use syn::token::Brace;
5use syn::{
6    AngleBracketedGenericArguments,
7    Attribute,
8    BareFnArg,
9    Field,
10    Fields,
11    FieldsNamed,
12    FnArg,
13    GenericArgument,
14    GenericParam,
15    Generics,
16    ImplItemMethod,
17    ItemStruct,
18    Lifetime,
19    Pat,
20    PatIdent,
21    PatType,
22    Path,
23    PathSegment,
24    Receiver,
25    ReturnType,
26    Token,
27    TraitBound,
28    TraitBoundModifier,
29    Type,
30    TypeBareFn,
31    TypeImplTrait,
32    TypeParamBound,
33    TypePath,
34    TypeReference,
35    Visibility,
36};
37
38use crate::syn_ext::{
39    AngleBracketedGenericArgumentsExt,
40    AttributeExt,
41    AttributeStyle,
42    BareFnArgExt,
43    GenericsExt,
44    IsMut,
45    LifetimeExt,
46    PathExt,
47    PathSegmentExt,
48    TypeBareFnExt,
49    TypePathExt,
50    TypeReferenceExt,
51    VisibilityExt,
52    WithColons,
53    WithLeadingColons,
54};
55use crate::util::{create_path, create_unit_type_tuple};
56
57pub struct Expectation
58{
59    ident: Ident,
60    method_ident: Ident,
61    method_generics: Generics,
62    generic_params: Punctuated<GenericParam, Token![,]>,
63    receiver: Option<Receiver>,
64    mock: Ident,
65    arg_types: Vec<Type>,
66    return_type: ReturnType,
67    phantom_fields: Vec<PhantomField>,
68}
69
70impl Expectation
71{
72    pub fn new(
73        mock: &Ident,
74        item_method: &ImplItemMethod,
75        generic_params: Punctuated<GenericParam, Token![,]>,
76    ) -> Self
77    {
78        let ident = create_expectation_ident(mock, &item_method.sig.ident);
79
80        let phantom_fields = Self::create_phantom_fields(
81            &item_method
82                .sig
83                .generics
84                .params
85                .clone()
86                .into_iter()
87                .chain(generic_params.clone())
88                .collect(),
89        );
90
91        let receiver =
92            item_method
93                .sig
94                .inputs
95                .first()
96                .and_then(|first_arg| match first_arg {
97                    FnArg::Receiver(receiver) => Some(receiver.clone()),
98                    FnArg::Typed(_) => None,
99                });
100
101        let arg_types = item_method
102            .sig
103            .inputs
104            .iter()
105            .filter_map(|arg| match arg {
106                FnArg::Typed(typed_arg) => Some(*typed_arg.ty.clone()),
107                FnArg::Receiver(_) => None,
108            })
109            .collect::<Vec<_>>();
110
111        let return_type = item_method.sig.output.clone();
112
113        Self {
114            ident,
115            method_ident: item_method.sig.ident.clone(),
116            method_generics: item_method.sig.generics.clone(),
117            generic_params,
118            receiver,
119            mock: mock.clone(),
120            arg_types,
121            return_type,
122            phantom_fields,
123        }
124    }
125
126    fn create_phantom_fields(
127        generic_params: &Punctuated<GenericParam, Token![,]>,
128    ) -> Vec<PhantomField>
129    {
130        generic_params
131            .iter()
132            .filter_map(|generic_param| match generic_param {
133                GenericParam::Type(type_param) => {
134                    let type_param_ident = &type_param.ident;
135
136                    let field_ident = create_phantom_field_ident(
137                        type_param_ident,
138                        &PhantomFieldKind::Type,
139                    );
140
141                    let ty = create_phantom_data_type_path([GenericArgument::Type(
142                        Type::Path(TypePath::new(Path::new(
143                            WithLeadingColons::No,
144                            [PathSegment::new(type_param_ident.clone(), None)],
145                        ))),
146                    )]);
147
148                    Some(PhantomField {
149                        field: field_ident,
150                        type_path: ty,
151                    })
152                }
153                GenericParam::Lifetime(lifetime_param) => {
154                    let lifetime = &lifetime_param.lifetime;
155
156                    let field_ident = create_phantom_field_ident(
157                        &lifetime.ident,
158                        &PhantomFieldKind::Lifetime,
159                    );
160
161                    let ty = create_phantom_data_type_path([GenericArgument::Type(
162                        Type::Reference(TypeReference::new(
163                            Some(lifetime.clone()),
164                            IsMut::No,
165                            Type::Tuple(create_unit_type_tuple()),
166                        )),
167                    )]);
168
169                    Some(PhantomField {
170                        field: field_ident,
171                        type_path: ty,
172                    })
173                }
174                GenericParam::Const(_) => None,
175            })
176            .collect()
177    }
178
179    fn create_struct(
180        ident: Ident,
181        generics: Generics,
182        phantom_fields: &[PhantomField],
183        boxed_predicate_types: &[Type],
184    ) -> ItemStruct
185    {
186        ItemStruct {
187            attrs: vec![Attribute::new(
188                AttributeStyle::Outer,
189                create_path!(allow),
190                quote! { (non_camel_case_types, non_snake_case) },
191            )],
192            vis: Visibility::new_pub_crate(),
193            struct_token: <Token![struct]>::default(),
194            ident,
195            generics: generics.strip_where_clause_and_bounds(),
196            fields: Fields::Named(FieldsNamed {
197                brace_token: Brace::default(),
198                named: [
199                    Field {
200                        attrs: vec![],
201                        vis: Visibility::Inherited,
202                        ident: Some(format_ident!("returning")),
203                        colon_token: Some(<Token![:]>::default()),
204                        ty: Type::Path(TypePath::new(Path::new(
205                            WithLeadingColons::No,
206                            [PathSegment::new(
207                                format_ident!("Option"),
208                                Some(AngleBracketedGenericArguments::new(
209                                    WithColons::No,
210                                    [GenericArgument::Type(Type::BareFn(
211                                        TypeBareFn::new([], ReturnType::Default),
212                                    ))],
213                                )),
214                            )],
215                        ))),
216                    },
217                    Field {
218                        attrs: vec![],
219                        vis: Visibility::Inherited,
220                        ident: Some(format_ident!("call_cnt")),
221                        colon_token: Some(<Token![:]>::default()),
222                        ty: Type::Path(TypePath::new(create_path!(
223                            ::std::sync::atomic::AtomicU32
224                        ))),
225                    },
226                    Field {
227                        attrs: vec![],
228                        vis: Visibility::Inherited,
229                        ident: Some(format_ident!("call_cnt_expectation")),
230                        colon_token: Some(<Token![:]>::default()),
231                        ty: Type::Path(TypePath::new(create_path!(
232                            ::ridicule::__private::CallCountExpectation
233                        ))),
234                    },
235                ]
236                .into_iter()
237                .chain(boxed_predicate_types.iter().enumerate().map(
238                    |(index, boxed_predicate_type)| Field {
239                        attrs: vec![],
240                        vis: Visibility::Inherited,
241                        ident: Some(format_ident!("predicate_{index}")),
242                        colon_token: Some(<Token![:]>::default()),
243                        ty: Type::Path(TypePath::new(Path::new(
244                            WithLeadingColons::No,
245                            [PathSegment::new(
246                                format_ident!("Option"),
247                                Some(AngleBracketedGenericArguments::new(
248                                    WithColons::No,
249                                    [GenericArgument::Type(boxed_predicate_type.clone())],
250                                )),
251                            )],
252                        ))),
253                    },
254                ))
255                .chain(phantom_fields.iter().cloned().map(Field::from))
256                .collect(),
257            }),
258            semi_token: None,
259        }
260    }
261}
262
263impl ToTokens for Expectation
264{
265    #[allow(clippy::too_many_lines)]
266    fn to_tokens(&self, tokens: &mut TokenStream)
267    {
268        let generics = {
269            let mut generics = self.method_generics.clone();
270
271            generics.params.extend(self.generic_params.clone());
272
273            generics
274        };
275
276        let generic_params = &generics.params;
277
278        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
279
280        let bogus_generics = create_bogus_generics(generic_params);
281
282        let opt_self_type = receiver_to_mock_self_type(&self.receiver, self.mock.clone());
283
284        let ident = &self.ident;
285        let phantom_fields = &self.phantom_fields;
286
287        let returning_fn = Type::BareFn(TypeBareFn::new(
288            opt_self_type
289                .iter()
290                .chain(self.arg_types.iter())
291                .map(|ty| BareFnArg::new(ty.clone())),
292            self.return_type.clone(),
293        ));
294
295        let method_ident = &self.method_ident;
296
297        let arg_types_no_refs = self
298            .arg_types
299            .iter()
300            .map(|arg_type| match arg_type {
301                Type::Reference(type_ref) => &*type_ref.elem,
302                ty => ty,
303            })
304            .collect::<Vec<_>>();
305
306        let predicate_paths = arg_types_no_refs
307            .iter()
308            .map(|arg_type| {
309                Path::new(
310                    WithLeadingColons::Yes,
311                    [
312                        PathSegment::new(format_ident!("ridicule"), None),
313                        PathSegment::new(
314                            format_ident!("Predicate"),
315                            Some(AngleBracketedGenericArguments::new(
316                                WithColons::No,
317                                [GenericArgument::Type((*arg_type).clone())],
318                            )),
319                        ),
320                    ],
321                )
322            })
323            .collect::<Vec<_>>();
324
325        let boxed_predicate_types = arg_types_no_refs
326            .iter()
327            .map(|arg_type| {
328                Type::Path(TypePath::new(Path::new(
329                    WithLeadingColons::Yes,
330                    [
331                        PathSegment::new(format_ident!("ridicule"), None),
332                        PathSegment::new(format_ident!("__private"), None),
333                        PathSegment::new(
334                            format_ident!("BoxPredicate"),
335                            Some(AngleBracketedGenericArguments::new(
336                                WithColons::No,
337                                [GenericArgument::Type((*arg_type).clone())],
338                            )),
339                        ),
340                    ],
341                )))
342            })
343            .collect::<Vec<_>>();
344
345        let expectation_struct = Self::create_struct(
346            self.ident.clone(),
347            generics.clone(),
348            phantom_fields,
349            &boxed_predicate_types,
350        );
351
352        let boundless_generics = generics.clone().strip_where_clause_and_bounds();
353
354        let (boundless_impl_generics, _, _) = boundless_generics.split_for_impl();
355
356        let do_strip_generic_params = if generic_params.is_empty() {
357            quote! { self }
358        } else {
359            quote! { unsafe { std::mem::transmute(self) } }
360        };
361
362        let with_arg_names = (0..self.arg_types.len())
363            .map(|index| format_ident!("predicate_{index}"))
364            .collect::<Vec<_>>();
365
366        let with_args =
367            predicate_paths
368                .iter()
369                .enumerate()
370                .map(|(index, predicate_path)| {
371                    FnArg::Typed(PatType {
372                        attrs: vec![],
373                        pat: Box::new(Pat::Ident(PatIdent {
374                            attrs: vec![],
375                            by_ref: None,
376                            mutability: None,
377                            ident: format_ident!("predicate_{index}"),
378                            subpat: None,
379                        })),
380                        colon_token: <Token![:]>::default(),
381                        ty: Box::new(Type::ImplTrait(TypeImplTrait {
382                            impl_token: <Token![impl]>::default(),
383                            bounds: [
384                                TypeParamBound::Trait(TraitBound {
385                                    paren_token: None,
386                                    modifier: TraitBoundModifier::None,
387                                    lifetimes: None,
388                                    path: predicate_path.clone(),
389                                }),
390                                TypeParamBound::Trait(TraitBound {
391                                    paren_token: None,
392                                    modifier: TraitBoundModifier::None,
393                                    lifetimes: None,
394                                    path: create_path!(Send),
395                                }),
396                                TypeParamBound::Trait(TraitBound {
397                                    paren_token: None,
398                                    modifier: TraitBoundModifier::None,
399                                    lifetimes: None,
400                                    path: create_path!(Sync),
401                                }),
402                                TypeParamBound::Lifetime(Lifetime::create(
403                                    format_ident!("static"),
404                                )),
405                            ]
406                            .into_iter()
407                            .collect(),
408                        })),
409                    })
410                });
411
412        let check_predicates_arg_names = (0..self.arg_types.len())
413            .map(|index| format_ident!("arg_{index}"))
414            .collect::<Vec<_>>();
415
416        let arg_types = &self.arg_types;
417
418        let predicate_field_inits = (0..boxed_predicate_types.len())
419            .map(|index| {
420                let ident = format_ident!("predicate_{index}");
421
422                quote! { #ident: None }
423            })
424            .collect::<Vec<_>>();
425
426        quote! {
427            #expectation_struct
428
429            impl #impl_generics #ident #ty_generics #where_clause
430            {
431                fn new() -> Self {
432                    Self {
433                        returning: None,
434                        call_cnt: ::std::sync::atomic::AtomicU32::new(0),
435                        call_cnt_expectation:
436                            ::ridicule::__private::CallCountExpectation::Unlimited,
437                        #(#predicate_field_inits,)*
438                        #(#phantom_fields),*
439                    }
440                }
441
442                /// Sets the function that will be called to provide the return value of
443                /// this expectation.
444                ///
445                /// # Safety
446                /// The caller must ensure that no argument or return type is outlived. They must
447                /// be treated as if they are bound to 'static.
448                #[allow(unused)]
449                pub unsafe fn returning(
450                    &mut self,
451                    func: #returning_fn
452                ) -> &mut Self
453                {
454                    self.returning =  Some(unsafe { std::mem::transmute(func) });
455
456                    self
457                }
458
459                pub fn times(&mut self, cnt: u32) -> &mut Self {
460                    self.call_cnt_expectation =
461                        ::ridicule::__private::CallCountExpectation::Times(cnt);
462
463                    self
464                }
465
466                pub fn never(&mut self) -> &mut Self {
467                    self.call_cnt_expectation =
468                        ::ridicule::__private::CallCountExpectation::Never;
469
470                    self
471                }
472
473                pub fn with(&mut self, #(#with_args),*) -> &mut Self
474                {
475                    #(
476                        self.#with_arg_names = Some(
477                            ::ridicule::__private::BoxPredicate::new(#with_arg_names)
478                        );
479                    )*
480
481                    self
482                }
483
484                fn check_predicates(&self, #(#check_predicates_arg_names: &#arg_types),*)
485                {
486                    use ::ridicule::Predicate;
487
488                    #(
489                        if let Some(predicate) = &self.#with_arg_names {
490                            if !predicate.eval(&#check_predicates_arg_names) {
491                                panic!("Predicate '{}' evaluated to false", predicate);
492                            }
493                        }
494                    )*
495                }
496
497                #[allow(unused)]
498                fn strip_generic_params(
499                    self,
500                ) -> #ident<#(#bogus_generics),*>
501                {
502                    #do_strip_generic_params
503                }
504
505                fn get_returning(&self) -> &#returning_fn
506                {
507                    let Some(returning) = &self.returning else {
508                        panic!(concat!(
509                            "Expectation for function",
510                            stringify!(#method_ident),
511                            " is missing a function to call")
512                        );
513                    };
514
515                    if matches!(
516                        self.call_cnt_expectation,
517                        ::ridicule::__private::CallCountExpectation::Never
518                    ) {
519                        panic!(
520                            "Expected function {} to never be called",
521                            stringify!(#method_ident)
522                        );
523                    }
524
525                    if let ::ridicule::__private::CallCountExpectation::Times(
526                        times
527                    ) = self.call_cnt_expectation {
528                        if times == self.call_cnt.load(
529                            ::std::sync::atomic::Ordering::Relaxed
530                        ) {
531                            panic!(
532                                concat!(
533                                    "Expected function {} to be called {} times. Was ",
534                                    "called {} times"
535                                ),
536                                stringify!(#method_ident),
537                                times,
538                                times + 1
539                            );
540                        }
541                    }
542
543                    self.call_cnt.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
544
545                    let returning_ptr: *const _ = returning;
546
547                    unsafe { &*returning_ptr.cast()}
548                }
549            }
550
551            impl #ident<#(#bogus_generics),*> {
552                #[allow(unused)]
553                fn with_generic_params<#generic_params>(
554                    &self,
555                ) -> &#ident #ty_generics
556                {
557                    // SAFETY: self is a pointer to a sane place, Rustc guarantees that
558                    // by it being a reference. The generic parameters doesn't affect
559                    // the size of self in any way, as they are only used in the function
560                    // pointer field "returning"
561                    unsafe { &*(self as *const Self).cast() }
562                }
563
564                #[allow(unused)]
565                fn with_generic_params_mut<#generic_params>(
566                    &mut self,
567                ) -> &mut #ident #ty_generics
568                {
569                    // SAFETY: self is a pointer to a sane place, Rustc guarantees that
570                    // by it being a reference. The generic parameters doesn't affect
571                    // the size of self in any way, as they are only used in the function
572                    // pointer field "returning"
573                    unsafe { &mut *(self as *mut Self).cast() }
574                }
575            }
576
577            impl #boundless_impl_generics #ident #ty_generics {
578                fn is_exhausted(&self) -> bool {
579                    if let ::ridicule::__private::CallCountExpectation::Times(times) =
580                        self.call_cnt_expectation
581                    {
582                        if times == self.call_cnt.load(
583                            ::std::sync::atomic::Ordering::Relaxed
584                        ) {
585                            return true;
586                        }
587                    }
588
589                    false
590                }
591            }
592
593            impl #boundless_impl_generics Drop for #ident #ty_generics
594            {
595                fn drop(&mut self) {
596                    let call_cnt =
597                        self.call_cnt.load(::std::sync::atomic::Ordering::Relaxed);
598
599                    if let ::ridicule::__private::CallCountExpectation::Times(
600                        times
601                    ) = self.call_cnt_expectation {
602                        if !::std::thread::panicking() && call_cnt != times {
603                            panic!(
604                                concat!(
605                                    "Expected function {} to be called {} times. Was ",
606                                    "called {} times"
607                                ),
608                                stringify!(#method_ident),
609                                times,
610                                call_cnt
611                            );
612                        }
613                    }
614                }
615            }
616        }
617        .to_tokens(tokens);
618    }
619}
620
621pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident
622{
623    format_ident!("{mock}Expectation_{method}")
624}
625
626#[derive(Clone)]
627struct PhantomField
628{
629    field: Ident,
630    type_path: TypePath,
631}
632
633impl ToTokens for PhantomField
634{
635    fn to_tokens(&self, tokens: &mut TokenStream)
636    {
637        self.field.to_tokens(tokens);
638
639        <Token![:]>::default().to_tokens(tokens);
640
641        self.type_path.to_tokens(tokens);
642    }
643}
644
645impl From<PhantomField> for Field
646{
647    fn from(phantom_field: PhantomField) -> Self
648    {
649        Self {
650            attrs: vec![],
651            vis: Visibility::Inherited,
652            ident: Some(phantom_field.field.clone()),
653            colon_token: Some(<Token![:]>::default()),
654            ty: Type::Path(phantom_field.type_path),
655        }
656    }
657}
658
659fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident
660{
661    match kind {
662        PhantomFieldKind::Type => format_ident!("{ident}_phantom"),
663        PhantomFieldKind::Lifetime => format_ident!("{ident}_lt_phantom"),
664    }
665}
666
/// The kind of generic parameter a phantom field stands for.
enum PhantomFieldKind
{
    /// The phantom field stands for a generic type parameter.
    Type,

    /// The phantom field stands for a generic lifetime parameter.
    Lifetime,
}
672
673fn create_phantom_data_type_path(
674    generic_args: impl IntoIterator<Item = GenericArgument>,
675) -> TypePath
676{
677    TypePath::new(Path::new(
678        WithLeadingColons::Yes,
679        [
680            PathSegment::new(format_ident!("std"), None),
681            PathSegment::new(format_ident!("marker"), None),
682            PathSegment::new(
683                format_ident!("PhantomData"),
684                Some(AngleBracketedGenericArguments::new(
685                    WithColons::Yes,
686                    generic_args,
687                )),
688            ),
689        ],
690    ))
691}
692
693fn create_bogus_generics(
694    generic_params: &Punctuated<GenericParam, Token![,]>,
695) -> Vec<GenericArgument>
696{
697    generic_params
698        .iter()
699        .filter_map(|generic_param| match generic_param {
700            GenericParam::Type(_) => {
701                Some(GenericArgument::Type(Type::Tuple(create_unit_type_tuple())))
702            }
703            GenericParam::Lifetime(_) => Some(GenericArgument::Lifetime(
704                Lifetime::create(format_ident!("static")),
705            )),
706            GenericParam::Const(_) => None,
707        })
708        .collect()
709}
710
711fn receiver_to_mock_self_type(receiver: &Option<Receiver>, mock: Ident) -> Option<Type>
712{
713    receiver.as_ref().map(|receiver| {
714        let self_type = Type::Path(TypePath::new(Path::new(
715            WithLeadingColons::No,
716            [PathSegment::new(mock, None)],
717        )));
718
719        if let Some((_, lifetime)) = &receiver.reference {
720            return Type::Reference(TypeReference::new(
721                lifetime.clone(),
722                receiver.mutability.into(),
723                self_type,
724            ));
725        }
726
727        self_type
728    })
729}