1use proc_macro2::{Ident, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::punctuated::Punctuated;
4use syn::token::Brace;
5use syn::{
6 AngleBracketedGenericArguments,
7 Attribute,
8 BareFnArg,
9 Field,
10 Fields,
11 FieldsNamed,
12 FnArg,
13 GenericArgument,
14 GenericParam,
15 Generics,
16 ImplItemMethod,
17 ItemStruct,
18 Lifetime,
19 Pat,
20 PatIdent,
21 PatType,
22 Path,
23 PathSegment,
24 Receiver,
25 ReturnType,
26 Token,
27 TraitBound,
28 TraitBoundModifier,
29 Type,
30 TypeBareFn,
31 TypeImplTrait,
32 TypeParamBound,
33 TypePath,
34 TypeReference,
35 Visibility,
36};
37
38use crate::syn_ext::{
39 AngleBracketedGenericArgumentsExt,
40 AttributeExt,
41 AttributeStyle,
42 BareFnArgExt,
43 GenericsExt,
44 IsMut,
45 LifetimeExt,
46 PathExt,
47 PathSegmentExt,
48 TypeBareFnExt,
49 TypePathExt,
50 TypeReferenceExt,
51 VisibilityExt,
52 WithColons,
53 WithLeadingColons,
54};
55use crate::util::{create_path, create_unit_type_tuple};
56
57pub struct Expectation
58{
59 ident: Ident,
60 method_ident: Ident,
61 method_generics: Generics,
62 generic_params: Punctuated<GenericParam, Token![,]>,
63 receiver: Option<Receiver>,
64 mock: Ident,
65 arg_types: Vec<Type>,
66 return_type: ReturnType,
67 phantom_fields: Vec<PhantomField>,
68}
69
70impl Expectation
71{
72 pub fn new(
73 mock: &Ident,
74 item_method: &ImplItemMethod,
75 generic_params: Punctuated<GenericParam, Token![,]>,
76 ) -> Self
77 {
78 let ident = create_expectation_ident(mock, &item_method.sig.ident);
79
80 let phantom_fields = Self::create_phantom_fields(
81 &item_method
82 .sig
83 .generics
84 .params
85 .clone()
86 .into_iter()
87 .chain(generic_params.clone())
88 .collect(),
89 );
90
91 let receiver =
92 item_method
93 .sig
94 .inputs
95 .first()
96 .and_then(|first_arg| match first_arg {
97 FnArg::Receiver(receiver) => Some(receiver.clone()),
98 FnArg::Typed(_) => None,
99 });
100
101 let arg_types = item_method
102 .sig
103 .inputs
104 .iter()
105 .filter_map(|arg| match arg {
106 FnArg::Typed(typed_arg) => Some(*typed_arg.ty.clone()),
107 FnArg::Receiver(_) => None,
108 })
109 .collect::<Vec<_>>();
110
111 let return_type = item_method.sig.output.clone();
112
113 Self {
114 ident,
115 method_ident: item_method.sig.ident.clone(),
116 method_generics: item_method.sig.generics.clone(),
117 generic_params,
118 receiver,
119 mock: mock.clone(),
120 arg_types,
121 return_type,
122 phantom_fields,
123 }
124 }
125
126 fn create_phantom_fields(
127 generic_params: &Punctuated<GenericParam, Token![,]>,
128 ) -> Vec<PhantomField>
129 {
130 generic_params
131 .iter()
132 .filter_map(|generic_param| match generic_param {
133 GenericParam::Type(type_param) => {
134 let type_param_ident = &type_param.ident;
135
136 let field_ident = create_phantom_field_ident(
137 type_param_ident,
138 &PhantomFieldKind::Type,
139 );
140
141 let ty = create_phantom_data_type_path([GenericArgument::Type(
142 Type::Path(TypePath::new(Path::new(
143 WithLeadingColons::No,
144 [PathSegment::new(type_param_ident.clone(), None)],
145 ))),
146 )]);
147
148 Some(PhantomField {
149 field: field_ident,
150 type_path: ty,
151 })
152 }
153 GenericParam::Lifetime(lifetime_param) => {
154 let lifetime = &lifetime_param.lifetime;
155
156 let field_ident = create_phantom_field_ident(
157 &lifetime.ident,
158 &PhantomFieldKind::Lifetime,
159 );
160
161 let ty = create_phantom_data_type_path([GenericArgument::Type(
162 Type::Reference(TypeReference::new(
163 Some(lifetime.clone()),
164 IsMut::No,
165 Type::Tuple(create_unit_type_tuple()),
166 )),
167 )]);
168
169 Some(PhantomField {
170 field: field_ident,
171 type_path: ty,
172 })
173 }
174 GenericParam::Const(_) => None,
175 })
176 .collect()
177 }
178
179 fn create_struct(
180 ident: Ident,
181 generics: Generics,
182 phantom_fields: &[PhantomField],
183 boxed_predicate_types: &[Type],
184 ) -> ItemStruct
185 {
186 ItemStruct {
187 attrs: vec![Attribute::new(
188 AttributeStyle::Outer,
189 create_path!(allow),
190 quote! { (non_camel_case_types, non_snake_case) },
191 )],
192 vis: Visibility::new_pub_crate(),
193 struct_token: <Token![struct]>::default(),
194 ident,
195 generics: generics.strip_where_clause_and_bounds(),
196 fields: Fields::Named(FieldsNamed {
197 brace_token: Brace::default(),
198 named: [
199 Field {
200 attrs: vec![],
201 vis: Visibility::Inherited,
202 ident: Some(format_ident!("returning")),
203 colon_token: Some(<Token![:]>::default()),
204 ty: Type::Path(TypePath::new(Path::new(
205 WithLeadingColons::No,
206 [PathSegment::new(
207 format_ident!("Option"),
208 Some(AngleBracketedGenericArguments::new(
209 WithColons::No,
210 [GenericArgument::Type(Type::BareFn(
211 TypeBareFn::new([], ReturnType::Default),
212 ))],
213 )),
214 )],
215 ))),
216 },
217 Field {
218 attrs: vec![],
219 vis: Visibility::Inherited,
220 ident: Some(format_ident!("call_cnt")),
221 colon_token: Some(<Token![:]>::default()),
222 ty: Type::Path(TypePath::new(create_path!(
223 ::std::sync::atomic::AtomicU32
224 ))),
225 },
226 Field {
227 attrs: vec![],
228 vis: Visibility::Inherited,
229 ident: Some(format_ident!("call_cnt_expectation")),
230 colon_token: Some(<Token![:]>::default()),
231 ty: Type::Path(TypePath::new(create_path!(
232 ::ridicule::__private::CallCountExpectation
233 ))),
234 },
235 ]
236 .into_iter()
237 .chain(boxed_predicate_types.iter().enumerate().map(
238 |(index, boxed_predicate_type)| Field {
239 attrs: vec![],
240 vis: Visibility::Inherited,
241 ident: Some(format_ident!("predicate_{index}")),
242 colon_token: Some(<Token![:]>::default()),
243 ty: Type::Path(TypePath::new(Path::new(
244 WithLeadingColons::No,
245 [PathSegment::new(
246 format_ident!("Option"),
247 Some(AngleBracketedGenericArguments::new(
248 WithColons::No,
249 [GenericArgument::Type(boxed_predicate_type.clone())],
250 )),
251 )],
252 ))),
253 },
254 ))
255 .chain(phantom_fields.iter().cloned().map(Field::from))
256 .collect(),
257 }),
258 semi_token: None,
259 }
260 }
261}
262
263impl ToTokens for Expectation
264{
265 #[allow(clippy::too_many_lines)]
266 fn to_tokens(&self, tokens: &mut TokenStream)
267 {
268 let generics = {
269 let mut generics = self.method_generics.clone();
270
271 generics.params.extend(self.generic_params.clone());
272
273 generics
274 };
275
276 let generic_params = &generics.params;
277
278 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
279
280 let bogus_generics = create_bogus_generics(generic_params);
281
282 let opt_self_type = receiver_to_mock_self_type(&self.receiver, self.mock.clone());
283
284 let ident = &self.ident;
285 let phantom_fields = &self.phantom_fields;
286
287 let returning_fn = Type::BareFn(TypeBareFn::new(
288 opt_self_type
289 .iter()
290 .chain(self.arg_types.iter())
291 .map(|ty| BareFnArg::new(ty.clone())),
292 self.return_type.clone(),
293 ));
294
295 let method_ident = &self.method_ident;
296
297 let arg_types_no_refs = self
298 .arg_types
299 .iter()
300 .map(|arg_type| match arg_type {
301 Type::Reference(type_ref) => &*type_ref.elem,
302 ty => ty,
303 })
304 .collect::<Vec<_>>();
305
306 let predicate_paths = arg_types_no_refs
307 .iter()
308 .map(|arg_type| {
309 Path::new(
310 WithLeadingColons::Yes,
311 [
312 PathSegment::new(format_ident!("ridicule"), None),
313 PathSegment::new(
314 format_ident!("Predicate"),
315 Some(AngleBracketedGenericArguments::new(
316 WithColons::No,
317 [GenericArgument::Type((*arg_type).clone())],
318 )),
319 ),
320 ],
321 )
322 })
323 .collect::<Vec<_>>();
324
325 let boxed_predicate_types = arg_types_no_refs
326 .iter()
327 .map(|arg_type| {
328 Type::Path(TypePath::new(Path::new(
329 WithLeadingColons::Yes,
330 [
331 PathSegment::new(format_ident!("ridicule"), None),
332 PathSegment::new(format_ident!("__private"), None),
333 PathSegment::new(
334 format_ident!("BoxPredicate"),
335 Some(AngleBracketedGenericArguments::new(
336 WithColons::No,
337 [GenericArgument::Type((*arg_type).clone())],
338 )),
339 ),
340 ],
341 )))
342 })
343 .collect::<Vec<_>>();
344
345 let expectation_struct = Self::create_struct(
346 self.ident.clone(),
347 generics.clone(),
348 phantom_fields,
349 &boxed_predicate_types,
350 );
351
352 let boundless_generics = generics.clone().strip_where_clause_and_bounds();
353
354 let (boundless_impl_generics, _, _) = boundless_generics.split_for_impl();
355
356 let do_strip_generic_params = if generic_params.is_empty() {
357 quote! { self }
358 } else {
359 quote! { unsafe { std::mem::transmute(self) } }
360 };
361
362 let with_arg_names = (0..self.arg_types.len())
363 .map(|index| format_ident!("predicate_{index}"))
364 .collect::<Vec<_>>();
365
366 let with_args =
367 predicate_paths
368 .iter()
369 .enumerate()
370 .map(|(index, predicate_path)| {
371 FnArg::Typed(PatType {
372 attrs: vec![],
373 pat: Box::new(Pat::Ident(PatIdent {
374 attrs: vec![],
375 by_ref: None,
376 mutability: None,
377 ident: format_ident!("predicate_{index}"),
378 subpat: None,
379 })),
380 colon_token: <Token![:]>::default(),
381 ty: Box::new(Type::ImplTrait(TypeImplTrait {
382 impl_token: <Token![impl]>::default(),
383 bounds: [
384 TypeParamBound::Trait(TraitBound {
385 paren_token: None,
386 modifier: TraitBoundModifier::None,
387 lifetimes: None,
388 path: predicate_path.clone(),
389 }),
390 TypeParamBound::Trait(TraitBound {
391 paren_token: None,
392 modifier: TraitBoundModifier::None,
393 lifetimes: None,
394 path: create_path!(Send),
395 }),
396 TypeParamBound::Trait(TraitBound {
397 paren_token: None,
398 modifier: TraitBoundModifier::None,
399 lifetimes: None,
400 path: create_path!(Sync),
401 }),
402 TypeParamBound::Lifetime(Lifetime::create(
403 format_ident!("static"),
404 )),
405 ]
406 .into_iter()
407 .collect(),
408 })),
409 })
410 });
411
412 let check_predicates_arg_names = (0..self.arg_types.len())
413 .map(|index| format_ident!("arg_{index}"))
414 .collect::<Vec<_>>();
415
416 let arg_types = &self.arg_types;
417
418 let predicate_field_inits = (0..boxed_predicate_types.len())
419 .map(|index| {
420 let ident = format_ident!("predicate_{index}");
421
422 quote! { #ident: None }
423 })
424 .collect::<Vec<_>>();
425
426 quote! {
427 #expectation_struct
428
429 impl #impl_generics #ident #ty_generics #where_clause
430 {
431 fn new() -> Self {
432 Self {
433 returning: None,
434 call_cnt: ::std::sync::atomic::AtomicU32::new(0),
435 call_cnt_expectation:
436 ::ridicule::__private::CallCountExpectation::Unlimited,
437 #(#predicate_field_inits,)*
438 #(#phantom_fields),*
439 }
440 }
441
442 #[allow(unused)]
449 pub unsafe fn returning(
450 &mut self,
451 func: #returning_fn
452 ) -> &mut Self
453 {
454 self.returning = Some(unsafe { std::mem::transmute(func) });
455
456 self
457 }
458
459 pub fn times(&mut self, cnt: u32) -> &mut Self {
460 self.call_cnt_expectation =
461 ::ridicule::__private::CallCountExpectation::Times(cnt);
462
463 self
464 }
465
466 pub fn never(&mut self) -> &mut Self {
467 self.call_cnt_expectation =
468 ::ridicule::__private::CallCountExpectation::Never;
469
470 self
471 }
472
473 pub fn with(&mut self, #(#with_args),*) -> &mut Self
474 {
475 #(
476 self.#with_arg_names = Some(
477 ::ridicule::__private::BoxPredicate::new(#with_arg_names)
478 );
479 )*
480
481 self
482 }
483
484 fn check_predicates(&self, #(#check_predicates_arg_names: &#arg_types),*)
485 {
486 use ::ridicule::Predicate;
487
488 #(
489 if let Some(predicate) = &self.#with_arg_names {
490 if !predicate.eval(&#check_predicates_arg_names) {
491 panic!("Predicate '{}' evaluated to false", predicate);
492 }
493 }
494 )*
495 }
496
497 #[allow(unused)]
498 fn strip_generic_params(
499 self,
500 ) -> #ident<#(#bogus_generics),*>
501 {
502 #do_strip_generic_params
503 }
504
505 fn get_returning(&self) -> &#returning_fn
506 {
507 let Some(returning) = &self.returning else {
508 panic!(concat!(
509 "Expectation for function",
510 stringify!(#method_ident),
511 " is missing a function to call")
512 );
513 };
514
515 if matches!(
516 self.call_cnt_expectation,
517 ::ridicule::__private::CallCountExpectation::Never
518 ) {
519 panic!(
520 "Expected function {} to never be called",
521 stringify!(#method_ident)
522 );
523 }
524
525 if let ::ridicule::__private::CallCountExpectation::Times(
526 times
527 ) = self.call_cnt_expectation {
528 if times == self.call_cnt.load(
529 ::std::sync::atomic::Ordering::Relaxed
530 ) {
531 panic!(
532 concat!(
533 "Expected function {} to be called {} times. Was ",
534 "called {} times"
535 ),
536 stringify!(#method_ident),
537 times,
538 times + 1
539 );
540 }
541 }
542
543 self.call_cnt.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
544
545 let returning_ptr: *const _ = returning;
546
547 unsafe { &*returning_ptr.cast()}
548 }
549 }
550
551 impl #ident<#(#bogus_generics),*> {
552 #[allow(unused)]
553 fn with_generic_params<#generic_params>(
554 &self,
555 ) -> &#ident #ty_generics
556 {
557 unsafe { &*(self as *const Self).cast() }
562 }
563
564 #[allow(unused)]
565 fn with_generic_params_mut<#generic_params>(
566 &mut self,
567 ) -> &mut #ident #ty_generics
568 {
569 unsafe { &mut *(self as *mut Self).cast() }
574 }
575 }
576
577 impl #boundless_impl_generics #ident #ty_generics {
578 fn is_exhausted(&self) -> bool {
579 if let ::ridicule::__private::CallCountExpectation::Times(times) =
580 self.call_cnt_expectation
581 {
582 if times == self.call_cnt.load(
583 ::std::sync::atomic::Ordering::Relaxed
584 ) {
585 return true;
586 }
587 }
588
589 false
590 }
591 }
592
593 impl #boundless_impl_generics Drop for #ident #ty_generics
594 {
595 fn drop(&mut self) {
596 let call_cnt =
597 self.call_cnt.load(::std::sync::atomic::Ordering::Relaxed);
598
599 if let ::ridicule::__private::CallCountExpectation::Times(
600 times
601 ) = self.call_cnt_expectation {
602 if !::std::thread::panicking() && call_cnt != times {
603 panic!(
604 concat!(
605 "Expected function {} to be called {} times. Was ",
606 "called {} times"
607 ),
608 stringify!(#method_ident),
609 times,
610 call_cnt
611 );
612 }
613 }
614 }
615 }
616 }
617 .to_tokens(tokens);
618 }
619}
620
/// Returns the identifier of the expectation struct for `method` of the mock type
/// `mock`, formatted as `{mock}Expectation_{method}`.
pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident
{
    // Inline (implicitly captured) format arguments: `format_ident!` only derives the
    // span from explicit arguments, so this identifier gets the call-site span.
    format_ident!("{mock}Expectation_{method}")
}
625
626#[derive(Clone)]
627struct PhantomField
628{
629 field: Ident,
630 type_path: TypePath,
631}
632
633impl ToTokens for PhantomField
634{
635 fn to_tokens(&self, tokens: &mut TokenStream)
636 {
637 self.field.to_tokens(tokens);
638
639 <Token![:]>::default().to_tokens(tokens);
640
641 self.type_path.to_tokens(tokens);
642 }
643}
644
645impl From<PhantomField> for Field
646{
647 fn from(phantom_field: PhantomField) -> Self
648 {
649 Self {
650 attrs: vec![],
651 vis: Visibility::Inherited,
652 ident: Some(phantom_field.field.clone()),
653 colon_token: Some(<Token![:]>::default()),
654 ty: Type::Path(phantom_field.type_path),
655 }
656 }
657}
658
659fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident
660{
661 match kind {
662 PhantomFieldKind::Type => format_ident!("{ident}_phantom"),
663 PhantomFieldKind::Lifetime => format_ident!("{ident}_lt_phantom"),
664 }
665}
666
/// The kind of generic parameter a phantom field is created for; it selects the
/// naming suffix in `create_phantom_field_ident`.
enum PhantomFieldKind
{
    /// A type parameter (field suffix `_phantom`).
    Type,
    /// A lifetime parameter (field suffix `_lt_phantom`).
    Lifetime,
}
672
673fn create_phantom_data_type_path(
674 generic_args: impl IntoIterator<Item = GenericArgument>,
675) -> TypePath
676{
677 TypePath::new(Path::new(
678 WithLeadingColons::Yes,
679 [
680 PathSegment::new(format_ident!("std"), None),
681 PathSegment::new(format_ident!("marker"), None),
682 PathSegment::new(
683 format_ident!("PhantomData"),
684 Some(AngleBracketedGenericArguments::new(
685 WithColons::Yes,
686 generic_args,
687 )),
688 ),
689 ],
690 ))
691}
692
693fn create_bogus_generics(
694 generic_params: &Punctuated<GenericParam, Token![,]>,
695) -> Vec<GenericArgument>
696{
697 generic_params
698 .iter()
699 .filter_map(|generic_param| match generic_param {
700 GenericParam::Type(_) => {
701 Some(GenericArgument::Type(Type::Tuple(create_unit_type_tuple())))
702 }
703 GenericParam::Lifetime(_) => Some(GenericArgument::Lifetime(
704 Lifetime::create(format_ident!("static")),
705 )),
706 GenericParam::Const(_) => None,
707 })
708 .collect()
709}
710
711fn receiver_to_mock_self_type(receiver: &Option<Receiver>, mock: Ident) -> Option<Type>
712{
713 receiver.as_ref().map(|receiver| {
714 let self_type = Type::Path(TypePath::new(Path::new(
715 WithLeadingColons::No,
716 [PathSegment::new(mock, None)],
717 )));
718
719 if let Some((_, lifetime)) = &receiver.reference {
720 return Type::Reference(TypeReference::new(
721 lifetime.clone(),
722 receiver.mutability.into(),
723 self_type,
724 ));
725 }
726
727 self_type
728 })
729}