mockall_derive/
lib.rs

1// vim: tw=80
2//! Proc Macros for use with Mockall
3//!
4//! You probably don't want to use this crate directly.  Instead, you should use
5//! its reexports via the [`mockall`](https://docs.rs/mockall/latest/mockall)
6//! crate.
7
8#![cfg_attr(feature = "nightly_derive", feature(proc_macro_diagnostic))]
9#![cfg_attr(test, deny(warnings))]
10
11use cfg_if::cfg_if;
12use proc_macro2::{Span, TokenStream};
13use quote::{ToTokens, format_ident, quote};
14use std::{
15    env,
16    hash::BuildHasherDefault
17};
18use syn::{
19    *,
20    punctuated::Punctuated,
21    spanned::Spanned
22};
23
24mod automock;
25mod mock_function;
26mod mock_item;
27mod mock_item_struct;
28mod mock_trait;
29mod mockable_item;
30mod mockable_struct;
31use crate::automock::Attrs;
32use crate::mockable_struct::MockableStruct;
33use crate::mock_item::MockItem;
34use crate::mock_item_struct::MockItemStruct;
35use crate::mockable_item::MockableItem;
36
// Define deterministic aliases for these common types.
//
// `BuildHasherDefault<DefaultHasher>` creates every hasher through
// `DefaultHasher::default()`, which is not randomly seeded (unlike the std
// default `RandomState`).  Hashing, and therefore iteration order, is thus the
// same on every run.
// NOTE(review): presumably this keeps the generated token streams
// reproducible between builds — confirm.
type HashMap<K, V> = std::collections::HashMap<K, V, BuildHasherDefault<std::collections::hash_map::DefaultHasher>>;
type HashSet<K> = std::collections::HashSet<K, BuildHasherDefault<std::collections::hash_map::DefaultHasher>>;
40
cfg_if! {
    // proc-macro2's Span::unstable method requires the nightly feature, and it
    // doesn't work in test mode.
    // https://github.com/alexcrichton/proc-macro2/issues/159
    if #[cfg(all(feature = "nightly_derive", not(test)))] {
        // Report `msg` as a compiler error attached to `span`.
        //
        // The diagnostic is emitted and this function RETURNS, so every caller
        // must be prepared to keep generating (possibly placeholder) code
        // after calling it.
        fn compile_error(span: Span, msg: &str) {
            span.unstable()
                .error(msg)
                .emit();
        }
    } else {
        // Report `msg` as a compiler error.
        //
        // Without the nightly diagnostic API the span cannot be used, so this
        // panics; rustc reports a panic inside a proc macro as an error.  Code
        // after a call to this function is therefore never reached in this
        // configuration.
        fn compile_error(_span: Span, msg: &str) {
            panic!("{msg}.  More information may be available when mockall is built with the \"nightly\" feature.");
        }
    }
}
57
58/// Does this Attribute represent Mockall's "concretize" pseudo-attribute?
59fn is_concretize(attr: &Attribute) -> bool {
60    if attr.path().segments.last().unwrap().ident == "concretize" {
61        true
62    } else if attr.path().is_ident("cfg_attr") {
63        match &attr.meta {
64            Meta::List(ml) => {
65                ml.tokens.to_string().contains("concretize")
66            },
67            // cfg_attr should always contain a list
68            _ => false,
69        }
70    } else {
71        false
72    }
73}
74
/// Replace generic arguments with concrete trait object arguments.
///
/// Every generic type parameter `T` with at least one bound (declared inline
/// or in a simple `where T: ...` predicate) is mapped to a trait object built
/// from those bounds:
///
/// * `T` and `&T` become `&(dyn Bounds)`, or `&mut (dyn Bounds)` when the
///   first bound is a path starting with `FnMut`;
/// * `&mut T` becomes `&mut (dyn Bounds)`;
/// * `&[T]` becomes `&[&(dyn Bounds)]`.
///
/// # Return
///
/// * An empty Generics object: every generic parameter (lifetimes included)
///   and the where clause are dropped.
/// * The transformed argument list (receiver included, attributes and `mut`
///   bindings removed), suitable for matchers and returners.
/// * An array of expressions that should be passed to the `call` function.
///   The `self` receiver is omitted.
/// * A copy of `sig` in which every argument whose type was mapped from an
///   `FnMut`-bounded parameter has its binding made `mut`.
fn concretize_args(gen: &Generics, sig: &Signature) ->
    (Generics, Punctuated<FnArg, Token![,]>, Vec<TokenStream>, Signature)
{
    let args = &sig.inputs;
    // Substitutable argument type -> (trait-object type, bounds).  The bounds
    // are only stored for `&[T]` keys, where the call expression needs them.
    let mut hm = HashMap::default();
    // Same keys as `hm`: whether the parameter's first bound is `FnMut`.
    let mut needs_muts = HashMap::default();

    // Record every substitution for one parameter.  A later call for the same
    // ident overwrites the earlier entries (HashMap::insert), so where-clause
    // bounds replace inline bounds rather than merging with them.
    let mut save_types = |ident: &Ident, tpb: &Punctuated<TypeParamBound, Token![+]>| {
        // A parameter with no bounds cannot become a trait object.
        if !tpb.is_empty() {
            let mut pat = quote!(&(dyn #tpb));
            let mut needs_mut = false;
            if let Some(TypeParamBound::Trait(t)) = tpb.first() {
                if t.path.segments.first().map(|seg| &seg.ident == "FnMut")
                    .unwrap_or(false)
                {
                    // For FnMut arguments, the rfunc needs a mutable reference
                    pat = quote!(&mut (dyn #tpb));
                    needs_mut = true;
                }
            }
            if let Ok(newty) = parse2::<Type>(pat) {
                // substitute T arguments
                let subst_ty: Type = parse2(quote!(#ident)).unwrap();
                needs_muts.insert(subst_ty.clone(), needs_mut);
                hm.insert(subst_ty, (newty.clone(), None));

                // substitute &T arguments
                let subst_ty: Type = parse2(quote!(&#ident)).unwrap();
                needs_muts.insert(subst_ty.clone(), needs_mut);
                hm.insert(subst_ty, (newty, None));
            } else {
                compile_error(tpb.span(),
                    "Type cannot be made into a trait object");
            }

            if let Ok(newty) = parse2::<Type>(quote!(&mut (dyn #tpb))) {
                // substitute &mut T arguments
                let subst_ty: Type = parse2(quote!(&mut #ident)).unwrap();
                needs_muts.insert(subst_ty.clone(), needs_mut);
                hm.insert(subst_ty, (newty, None));
            } else {
                compile_error(tpb.span(),
                    "Type cannot be made into a trait object");
            }

            // substitute &[T] arguments.  A `&[T]` cannot be reinterpreted in
            // place as `&[&dyn T]`, so the call expression built below
            // collects a temporary Vec of trait-object references instead;
            // that is why the bounds are saved alongside the new type here.
            if let Ok(newty) = parse2::<Type>(quote!(&[&(dyn #tpb)])) {
                let subst_ty: Type = parse2(quote!(&[#ident])).unwrap();
                needs_muts.insert(subst_ty.clone(), needs_mut);
                hm.insert(subst_ty, (newty, Some(tpb.clone())));
            } else {
                compile_error(tpb.span(),
                    "Type cannot be made into a trait object");
            }
        }
    };

    // Bounds declared inline on the generic parameters...
    for g in gen.params.iter() {
        if let GenericParam::Type(tp) = g {
            save_types(&tp.ident, &tp.bounds);
            // else there had better be a where clause
        }
    }
    // ...and on simple `where T: ...` predicates.
    if let Some(wc) = &gen.where_clause {
        for pred in wc.predicates.iter() {
            if let WherePredicate::Type(pt) = pred {
                let bounded_ty = &pt.bounded_ty;
                if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
                    save_types(&ident, &pt.bounds);
                } else {
                    // We can't yet handle where clauses this complicated
                }
            }
        }
    }

    // The concretized signature is entirely non-generic.
    let outg = Generics {
        lt_token: None,
        gt_token: None,
        params: Punctuated::new(),
        where_clause: None
    };
    // Rewrite each typed argument: drop its attributes and `mut` binding, and
    // swap in the trait-object type when one was recorded.  The receiver, if
    // any, is copied unchanged.
    let outargs = args.iter().map(|arg| {
        if let FnArg::Typed(pt) = arg {
            let mut call_pt = pt.clone();
            demutify_arg(&mut call_pt);
            if let Some((newty, _)) = hm.get(&pt.ty) {
                FnArg::Typed(PatType {
                    attrs: Vec::default(),
                    pat: call_pt.pat,
                    colon_token: pt.colon_token,
                    ty: Box::new(newty.clone())
                })
            } else {
                FnArg::Typed(PatType {
                    attrs: Vec::default(),
                    pat: call_pt.pat,
                    colon_token: pt.colon_token,
                    ty: pt.ty.clone()
                })
            }
        } else {
            arg.clone()
        }
    }).collect();

    // Finally, Reference any concretizing arguments
    // use filter_map to remove the &self argument
    let call_exprs = args.iter().filter_map(|arg| {
        match arg {
            FnArg::Typed(pt) => {
                let mut pt2 = pt.clone();
                demutify_arg(&mut pt2);
                let pat = &pt2.pat;
                if pat_is_self(pat) {
                    None
                } else if let Some((_, newbound)) = hm.get(&pt.ty) {
                    if let Type::Reference(tr) = &*pt.ty {
                        if let Type::Slice(_ts) = &*tr.elem {
                            // Assume _ts is the generic type or we wouldn't be
                            // here
                            Some(quote!(
                                &(0..#pat.len())
                                .map(|__mockall_i| &#pat[__mockall_i] as &(dyn #newbound))
                                .collect::<Vec<_>>()
                            ))
                        } else {
                            // `&T` / `&mut T`: already a reference, passed
                            // through as-is.
                            Some(quote!(#pat))
                        }
                    } else if needs_muts.get(&pt.ty).cloned().unwrap_or(false) {
                        // By-value FnMut argument: lend it mutably.
                        Some(quote!(&mut #pat))
                    } else {
                        // By-value argument: lend it immutably.
                        Some(quote!(&#pat))
                    }
                } else {
                    Some(quote!(#pat))
                }
            },
            FnArg::Receiver(_) => None,
        }
    }).collect();

    // Add any necessary "mut" qualifiers to the Signature, so that the
    // `&mut #pat` call expressions above borrow a mutable binding.
    let mut altsig = sig.clone();
    for arg in altsig.inputs.iter_mut() {
        if let FnArg::Typed(pt) = arg {
            if needs_muts.get(&pt.ty).cloned().unwrap_or(false) {
                if let Pat::Ident(pi) = &mut *pt.pat {
                    pi.mutability = Some(Token![mut](pi.mutability.span()));
                } else {
                    compile_error(pt.pat.span(),
                                    "This Pat type is not yet supported by Mockall when used as an argument to a concretized function.")
                }
            }
        }
    }

    (outg, outargs, call_exprs, altsig)
}
243
244fn deanonymize_lifetime(lt: &mut Lifetime) {
245    if lt.ident == "_" {
246        lt.ident = format_ident!("static");
247    }
248}
249
250fn deanonymize_path(path: &mut Path) {
251    for seg in path.segments.iter_mut() {
252        match &mut seg.arguments {
253            PathArguments::None => (),
254            PathArguments::AngleBracketed(abga) => {
255                for ga in abga.args.iter_mut() {
256                    if let GenericArgument::Lifetime(lt) = ga {
257                        deanonymize_lifetime(lt)
258                    }
259                }
260            },
261            _ => compile_error(seg.arguments.span(),
262                "Methods returning functions are TODO"),
263        }
264    }
265}
266
267/// Replace any references to the anonymous lifetime `'_` with `'static`.
268fn deanonymize(literal_type: &mut Type) {
269    match literal_type {
270        Type::Array(ta) => deanonymize(ta.elem.as_mut()),
271        Type::BareFn(tbf) => {
272            if let ReturnType::Type(_, ref mut bt) = tbf.output {
273                deanonymize(bt.as_mut());
274            }
275            for input in tbf.inputs.iter_mut() {
276                deanonymize(&mut input.ty);
277            }
278        },
279        Type::Group(tg) => deanonymize(tg.elem.as_mut()),
280        Type::Infer(_) => (),
281        Type::Never(_) => (),
282        Type::Paren(tp) => deanonymize(tp.elem.as_mut()),
283        Type::Path(tp) => {
284            if let Some(ref mut qself) = tp.qself {
285                deanonymize(qself.ty.as_mut());
286            }
287            deanonymize_path(&mut tp.path);
288        },
289        Type::Ptr(tptr) => deanonymize(tptr.elem.as_mut()),
290        Type::Reference(tr) => {
291            if let Some(lt) = tr.lifetime.as_mut() {
292                deanonymize_lifetime(lt)
293            }
294            deanonymize(tr.elem.as_mut());
295        },
296        Type::Slice(s) => deanonymize(s.elem.as_mut()),
297        Type::TraitObject(tto) => {
298            for tpb in tto.bounds.iter_mut() {
299                match tpb {
300                    TypeParamBound::Trait(tb) => deanonymize_path(&mut tb.path),
301                    TypeParamBound::Lifetime(lt) => deanonymize_lifetime(lt),
302                    _ => ()
303                }
304            }
305        },
306        Type::Tuple(tt) => {
307            for ty in tt.elems.iter_mut() {
308                deanonymize(ty)
309            }
310        }
311        x => compile_error(x.span(), "Unimplemented type for deanonymize")
312    }
313}
314
// If there are any closures in the argument list, turn them into boxed
// functions.
//
// A generic type parameter is treated as a closure type when any of its
// bounds (inline, or in a simple `where F: ...` predicate) has a last path
// segment of `Fn`, `FnMut`, or `FnOnce`.  Such parameters are removed from
// the generics (params and where predicates alike).  Arguments whose type is
// exactly that parameter become `Box<dyn Bounds>`, built from the parameter's
// whole bound list, and their call expressions are wrapped in `Box::new`.
//
// Returns the reduced Generics, the rewritten argument list, and the call
// expressions (the `self` receiver is omitted).
//
// Panics if one parameter is given Fn-style bounds more than once, e.g.
// `F: Fn() + FnMut()`, or Fn bounds both inline and in the where clause.
fn declosurefy(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
    (Generics, Punctuated<FnArg, Token![,]>, Vec<TokenStream>)
{
    // Closure parameter type (e.g. `F`) -> its boxed replacement type.
    let mut hm = HashMap::default();

    let mut save_fn_types = |ident: &Ident, bounds: &Punctuated<TypeParamBound, Token![+]>|
    {
        for tpb in bounds.iter() {
            if let TypeParamBound::Trait(tb) = tpb {
                let fident = &tb.path.segments.last().unwrap().ident;
                if ["Fn", "FnMut", "FnOnce"].iter().any(|s| fident == *s) {
                    let newty: Type = parse2(quote!(Box<dyn #bounds>)).unwrap();
                    let subst_ty: Type = parse2(quote!(#ident)).unwrap();
                    assert!(hm.insert(subst_ty, newty).is_none(),
                        "A generic parameter had two Fn bounds?");
                }
            }
        }
    };

    // First, build a HashMap of all Fn generic types
    for g in gen.params.iter() {
        if let GenericParam::Type(tp) = g {
            save_fn_types(&tp.ident, &tp.bounds);
        }
    }
    if let Some(wc) = &gen.where_clause {
        for pred in wc.predicates.iter() {
            if let WherePredicate::Type(pt) = pred {
                let bounded_ty = &pt.bounded_ty;
                if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
                    save_fn_types(&ident, &pt.bounds);
                } else {
                    // We can't yet handle where clauses this complicated
                }
            }
        }
    }

    // Then remove those types from both the Generics' params and where clause
    let should_remove = |ident: &Ident| {
            let ty: Type = parse2(quote!(#ident)).unwrap();
            hm.contains_key(&ty)
    };
    let params = gen.params.iter()
        .filter(|g| {
            if let GenericParam::Type(tp) = g {
                !should_remove(&tp.ident)
            } else {
                true
            }
        }).cloned()
        .collect::<Punctuated<_, _>>();
    let mut wc2 = gen.where_clause.clone();
    if let Some(wc) = &mut wc2 {
        wc.predicates = wc.predicates.iter()
            .filter(|wp| {
                if let WherePredicate::Type(pt) = wp {
                    let bounded_ty = &pt.bounded_ty;
                    if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
                        !should_remove(&ident)
                    } else {
                        // We can't yet handle where clauses this complicated
                        true
                    }
                } else {
                    true
                }
            }).cloned()
            .collect::<Punctuated<_, _>>();
        // Drop the where clause entirely rather than emit an empty one.
        if wc.predicates.is_empty() {
            wc2 = None;
        }
    }
    // Keep the angle brackets only if some parameter survived.
    let outg = Generics {
        lt_token: if params.is_empty() { None } else { gen.lt_token },
        gt_token: if params.is_empty() { None } else { gen.gt_token },
        params,
        where_clause: wc2
    };

    // Next substitute Box<Fn> into the arguments
    let outargs = args.iter().map(|arg| {
        if let FnArg::Typed(pt) = arg {
            let mut immutable_pt = pt.clone();
            demutify_arg(&mut immutable_pt);
            if let Some(newty) = hm.get(&pt.ty) {
                FnArg::Typed(PatType {
                    attrs: Vec::default(),
                    pat: immutable_pt.pat,
                    colon_token: pt.colon_token,
                    ty: Box::new(newty.clone())
                })
            } else {
                FnArg::Typed(PatType {
                    attrs: Vec::default(),
                    pat: immutable_pt.pat,
                    colon_token: pt.colon_token,
                    ty: pt.ty.clone()
                })
            }
        } else {
            arg.clone()
        }
    }).collect();

    // Finally, Box any closure arguments
    // use filter_map to remove the &self argument
    let callargs = args.iter().filter_map(|arg| {
        match arg {
            FnArg::Typed(pt) => {
                let mut pt2 = pt.clone();
                demutify_arg(&mut pt2);
                let pat = &pt2.pat;
                if pat_is_self(pat) {
                    None
                } else if hm.contains_key(&pt.ty) {
                    Some(quote!(Box::new(#pat)))
                } else {
                    Some(quote!(#pat))
                }
            },
            FnArg::Receiver(_) => None,
        }
    }).collect();
    (outg, outargs, callargs)
}
444
445/// Replace any "impl trait" types with "Box<dyn trait>" or equivalent.
446fn deimplify(rt: &mut ReturnType) {
447    if let ReturnType::Type(_, ty) = rt {
448        if let Type::ImplTrait(ref tit) = &**ty {
449            let needs_pin = tit.bounds
450                .iter()
451                .any(|tpb| {
452                    if let TypeParamBound::Trait(tb) = tpb {
453                        if let Some(seg) = tb.path.segments.last() {
454                            seg.ident == "Future" || seg.ident == "Stream"
455                        } else {
456                            // It might still be a Future, but we can't guess
457                            // what names it might be imported under.  Too bad.
458                            false
459                        }
460                    } else {
461                        false
462                    }
463                });
464            let bounds = &tit.bounds;
465            if needs_pin {
466                *ty = parse2(quote!(::std::pin::Pin<Box<dyn #bounds>>)).unwrap();
467            } else {
468                *ty = parse2(quote!(Box<dyn #bounds>)).unwrap();
469            }
470        }
471    }
472}
473
474/// Remove any generics that place constraints on Self.
475fn dewhereselfify(generics: &mut Generics) {
476    if let Some(ref mut wc) = &mut generics.where_clause {
477        let new_predicates = wc.predicates.iter()
478            .filter(|wp| match wp {
479                WherePredicate::Type(pt) => {
480                    pt.bounded_ty != parse2(quote!(Self)).unwrap()
481                },
482                _ => true
483            }).cloned()
484            .collect::<Punctuated<WherePredicate, Token![,]>>();
485        wc.predicates = new_predicates;
486    }
487    if generics.where_clause.as_ref()
488        .map(|wc| wc.predicates.is_empty())
489        .unwrap_or(false)
490    {
491        generics.where_clause = None;
492    }
493}
494
495/// Remove any mutability qualifiers from a method's argument list
496fn demutify(inputs: &mut Punctuated<FnArg, token::Comma>) {
497    for arg in inputs.iter_mut() {
498        match arg {
499            FnArg::Receiver(r) => if r.reference.is_none() {
500                r.mutability = None
501            },
502            FnArg::Typed(pt) => demutify_arg(pt),
503        }
504    }
505}
506
507/// Remove any "mut" from a method argument's binding.
508fn demutify_arg(arg: &mut PatType) {
509    match *arg.pat {
510        Pat::Wild(_) => {
511            compile_error(arg.span(),
512                "Mocked methods must have named arguments");
513        },
514        Pat::Ident(ref mut pat_ident) => {
515            if let Some(r) = &pat_ident.by_ref {
516                compile_error(r.span(),
517                    "Mockall does not support by-reference argument bindings");
518            }
519            if let Some((_at, subpat)) = &pat_ident.subpat {
520                compile_error(subpat.span(),
521                    "Mockall does not support subpattern bindings");
522            }
523            pat_ident.mutability = None;
524        },
525        _ => {
526            compile_error(arg.span(), "Unsupported argument type");
527        }
528    };
529}
530
531fn deselfify_path(path: &mut Path, actual: &Ident, generics: &Generics) {
532    for seg in path.segments.iter_mut() {
533        if seg.ident == "Self" {
534            seg.ident = actual.clone();
535            if let PathArguments::None = seg.arguments {
536                if !generics.params.is_empty() {
537                    let args = generics.params.iter()
538                        .map(|gp| {
539                            match gp {
540                                GenericParam::Type(tp) => {
541                                    let ident = tp.ident.clone();
542                                    GenericArgument::Type(
543                                        Type::Path(
544                                            TypePath {
545                                                qself: None,
546                                                path: Path::from(ident)
547                                            }
548                                        )
549                                    )
550                                },
551                                GenericParam::Lifetime(ld) =>{
552                                    GenericArgument::Lifetime(
553                                        ld.lifetime.clone()
554                                    )
555                                }
556                                _ => unimplemented!(),
557                            }
558                        }).collect::<Punctuated<_, _>>();
559                    seg.arguments = PathArguments::AngleBracketed(
560                        AngleBracketedGenericArguments {
561                            colon2_token: None,
562                            lt_token: generics.lt_token.unwrap(),
563                            args,
564                            gt_token: generics.gt_token.unwrap(),
565                        }
566                    );
567                }
568            } else {
569                compile_error(seg.arguments.span(),
570                    "Type arguments after Self are unexpected");
571            }
572        }
573        if let PathArguments::AngleBracketed(abga) = &mut seg.arguments
574        {
575            for arg in abga.args.iter_mut() {
576                match arg {
577                    GenericArgument::Type(ty) =>
578                        deselfify(ty, actual, generics),
579                    GenericArgument::AssocType(at) =>
580                        deselfify(&mut at.ty, actual, generics),
581                    _ => /* Nothing to do */(),
582                }
583            }
584        }
585    }
586}
587
588/// Replace any references to `Self` in `literal_type` with `actual`.
589/// `generics` is the Generics field of the parent struct.  Useful for
590/// constructor methods.
591fn deselfify(literal_type: &mut Type, actual: &Ident, generics: &Generics) {
592    match literal_type {
593        Type::Slice(s) => {
594            deselfify(s.elem.as_mut(), actual, generics);
595        },
596        Type::Array(a) => {
597            deselfify(a.elem.as_mut(), actual, generics);
598        },
599        Type::Ptr(p) => {
600            deselfify(p.elem.as_mut(), actual, generics);
601        },
602        Type::Reference(r) => {
603            deselfify(r.elem.as_mut(), actual, generics);
604        },
605        Type::Tuple(tuple) => {
606            for elem in tuple.elems.iter_mut() {
607                deselfify(elem, actual, generics);
608            }
609        }
610        Type::Path(type_path) => {
611            if let Some(ref mut qself) = type_path.qself {
612                deselfify(qself.ty.as_mut(), actual, generics);
613            }
614            deselfify_path(&mut type_path.path, actual, generics);
615        },
616        Type::Paren(p) => {
617            deselfify(p.elem.as_mut(), actual, generics);
618        },
619        Type::Group(g) => {
620            deselfify(g.elem.as_mut(), actual, generics);
621        },
622        Type::Macro(_) | Type::Verbatim(_) => {
623            compile_error(literal_type.span(),
624                "mockall_derive does not support this type as a return argument");
625        },
626        Type::TraitObject(tto) => {
627            // Change types like `dyn Self` into `dyn MockXXX`.
628            for bound in tto.bounds.iter_mut() {
629                if let TypeParamBound::Trait(t) = bound {
630                    deselfify_path(&mut t.path, actual, generics);
631                }
632            }
633        },
634        Type::ImplTrait(_) => {
635            /* Should've already been flagged as a compile_error */
636        },
637        Type::BareFn(_) => {
638            /* Bare functions can't have Self arguments.  Nothing to do */
639        },
640        Type::Infer(_) | Type::Never(_) =>
641        {
642            /* Nothing to do */
643        },
644        _ => compile_error(literal_type.span(), "Unsupported type"),
645    }
646}
647
648/// Change any `Self` in a method's arguments' types with `actual`.
649/// `generics` is the Generics field of the parent struct.
650fn deselfify_args(
651    args: &mut Punctuated<FnArg, Token![,]>,
652    actual: &Ident,
653    generics: &Generics)
654{
655    for arg in args.iter_mut() {
656        match arg {
657            FnArg::Receiver(r) => {
658                if r.colon_token.is_some() {
659                    deselfify(r.ty.as_mut(), actual, generics)
660                }
661            },
662            FnArg::Typed(pt) => deselfify(pt.ty.as_mut(), actual, generics)
663        }
664    }
665}
666
667fn find_ident_from_path(path: &Path) -> (Ident, PathArguments) {
668    if path.segments.len() != 1 {
669        compile_error(path.span(),
670            "mockall_derive only supports structs defined in the current module");
671        return (Ident::new("", path.span()), PathArguments::None);
672    }
673    let last_seg = path.segments.last().unwrap();
674    (last_seg.ident.clone(), last_seg.arguments.clone())
675}
676
677fn find_lifetimes_in_tpb(bound: &TypeParamBound) -> HashSet<Lifetime> {
678    let mut ret = HashSet::default();
679    match bound {
680        TypeParamBound::Lifetime(lt) => {
681            ret.insert(lt.clone());
682        },
683        TypeParamBound::Trait(tb) => {
684            ret.extend(find_lifetimes_in_path(&tb.path));
685        },
686        _ => ()
687    };
688    ret
689}
690
691fn find_lifetimes_in_path(path: &Path) -> HashSet<Lifetime> {
692    let mut ret = HashSet::default();
693    for seg in path.segments.iter() {
694        if let PathArguments::AngleBracketed(abga) = &seg.arguments {
695            for arg in abga.args.iter() {
696                match arg {
697                    GenericArgument::Lifetime(lt) => {
698                        ret.insert(lt.clone());
699                    },
700                    GenericArgument::Type(ty) => {
701                        ret.extend(find_lifetimes(ty));
702                    },
703                    GenericArgument::AssocType(at) => {
704                        ret.extend(find_lifetimes(&at.ty));
705                    },
706                    GenericArgument::Constraint(c) => {
707                        for bound in c.bounds.iter() {
708                            ret.extend(find_lifetimes_in_tpb(bound));
709                        }
710                    },
711                    GenericArgument::Const(_) => (),
712                    _ => ()
713                }
714            }
715        }
716    }
717    ret
718}
719
/// Collect the lifetimes that appear in `ty`.
///
/// Recurses through element types, path arguments, qualified-self types,
/// reference lifetimes, and trait-object / `impl Trait` bounds.  Any other
/// kind of type (e.g. a bare function pointer) is reported with
/// `compile_error` and contributes no lifetimes.
fn find_lifetimes(ty: &Type) -> HashSet<Lifetime> {
    match ty {
        Type::Array(ta) => find_lifetimes(ta.elem.as_ref()),
        Type::Group(tg) => find_lifetimes(tg.elem.as_ref()),
        Type::Infer(_ti) => HashSet::default(),
        Type::Never(_tn) => HashSet::default(),
        Type::Paren(tp) => find_lifetimes(tp.elem.as_ref()),
        Type::Path(tp) => {
            let mut ret = find_lifetimes_in_path(&tp.path);
            // `<T as Trait>::X`: the qualified-self type may hold lifetimes too.
            if let Some(qs) = &tp.qself {
                ret.extend(find_lifetimes(qs.ty.as_ref()));
            }
            ret
        },
        Type::Ptr(tp) => find_lifetimes(tp.elem.as_ref()),
        Type::Reference(tr) => {
            let mut ret = find_lifetimes(tr.elem.as_ref());
            // An elided reference lifetime (`&T`) is simply not recorded.
            if let Some(lt) = &tr.lifetime {
                ret.insert(lt.clone());
            }
            ret
        },
        Type::Slice(ts) => find_lifetimes(ts.elem.as_ref()),
        Type::TraitObject(tto) => {
            let mut ret = HashSet::default();
            for bound in tto.bounds.iter() {
                ret.extend(find_lifetimes_in_tpb(bound));
            }
            ret
        }
        Type::Tuple(tt) => {
            let mut ret = HashSet::default();
            for ty in tt.elems.iter() {
                ret.extend(find_lifetimes(ty));
            }
            ret
        },
        Type::ImplTrait(tit) => {
            let mut ret = HashSet::default();
            for tpb in tit.bounds.iter() {
                ret.extend(find_lifetimes_in_tpb(tpb));
            }
            ret
        },
        _ => {
            compile_error(ty.span(), "unsupported type in this context");
            HashSet::default()
        }
    }
}
770
771struct AttrFormatter<'a>{
772    attrs: &'a [Attribute],
773    async_trait: bool,
774    doc: bool,
775    must_use: bool,
776}
777
778impl<'a> AttrFormatter<'a> {
779    fn new(attrs: &'a [Attribute]) -> AttrFormatter<'a> {
780        Self {
781            attrs,
782            async_trait: true,
783            doc: true,
784            must_use: false,
785        }
786    }
787
788    fn async_trait(&mut self, allowed: bool) -> &mut Self {
789        self.async_trait = allowed;
790        self
791    }
792
793    fn doc(&mut self, allowed: bool) -> &mut Self {
794        self.doc = allowed;
795        self
796    }
797
798    fn must_use(&mut self, allowed: bool) -> &mut Self {
799        self.must_use = allowed;
800        self
801    }
802
803    // XXX This logic requires that attributes are imported with their
804    // standard names.
805    #[allow(clippy::needless_bool)]
806    #[allow(clippy::if_same_then_else)]
807    fn format(&mut self) -> Vec<Attribute> {
808        self.attrs.iter()
809            .filter(|attr| {
810                let i = attr.path().segments.last().map(|ps| &ps.ident);
811                if is_concretize(attr) {
812                    // Internally used attribute.  Never emit.
813                    false
814                } else if i.is_none() {
815                    false
816                } else if *i.as_ref().unwrap() == "derive" {
817                    // We can't usefully derive any traits.  Ignore them
818                    false
819                } else if *i.as_ref().unwrap() == "doc" {
820                    self.doc
821                } else if *i.as_ref().unwrap() == "async_trait" {
822                    self.async_trait
823                } else if *i.as_ref().unwrap() == "expect" {
824                    // This probably means that there's a lint that needs to be
825                    // surpressed for the real code, but not for the mock code.
826                    // Skip it.
827                    false
828                } else if *i.as_ref().unwrap() == "inline" {
829                    // No need to inline mock functions.
830                    false
831                } else if *i.as_ref().unwrap() == "cold" {
832                    // No need for such hints on mock functions.
833                    false
834                } else if *i.as_ref().unwrap() == "instrument" {
835                    // We can't usefully instrument the mock method, so just
836                    // ignore this attribute.
837                    // https://docs.rs/tracing/0.1.23/tracing/attr.instrument.html
838                    false
839                } else if *i.as_ref().unwrap() == "link_name" {
840                    // This shows up sometimes when mocking ffi functions.  We
841                    // must not emit it on anything that isn't an ffi definition
842                    false
843                } else if *i.as_ref().unwrap() == "must_use" {
844                    self.must_use
845                } else if *i.as_ref().unwrap() == "auto_enum" {
846                    // Ignore auto_enum, because we transform the return value
847                    // into a trait object.
848                    false
849                } else {
850                    true
851                }
852            }).cloned()
853            .collect()
854    }
855}
856
857/// Determine if this Pat is any kind of `self` binding
858fn pat_is_self(pat: &Pat) -> bool {
859    if let Pat::Ident(pi) = pat {
860        pi.ident == "self"
861    } else {
862        false
863    }
864}
865
866/// Add `levels` `super::` to the path.  Return the number of levels added.
867fn supersuperfy_path(path: &mut Path, levels: usize) -> usize {
868    if let Some(t) = path.segments.last_mut() {
869        match &mut t.arguments {
870            PathArguments::None => (),
871            PathArguments::AngleBracketed(ref mut abga) => {
872                for arg in abga.args.iter_mut() {
873                    match arg {
874                        GenericArgument::Type(ref mut ty) => {
875                            *ty = supersuperfy(ty, levels);
876                        },
877                        GenericArgument::AssocType(ref mut at) => {
878                            at.ty = supersuperfy(&at.ty, levels);
879                        },
880                        GenericArgument::Constraint(ref mut constraint) => {
881                            supersuperfy_bounds(&mut constraint.bounds, levels);
882                        },
883                        _ => (),
884                    }
885                }
886            },
887            PathArguments::Parenthesized(ref mut pga) => {
888                for input in pga.inputs.iter_mut() {
889                    *input = supersuperfy(input, levels);
890                }
891                if let ReturnType::Type(_, ref mut ty) = pga.output {
892                    *ty = Box::new(supersuperfy(ty, levels));
893                }
894            },
895        }
896    }
897    if let Some(t) = path.segments.first() {
898        if t.ident == "super" {
899            let mut ident = format_ident!("super");
900            ident.set_span(path.segments.span());
901            let ps = PathSegment {
902                ident,
903                arguments: PathArguments::None
904            };
905            for _ in 0..levels {
906                path.segments.insert(0, ps.clone());
907            }
908            levels
909        } else {
910            0
911        }
912    } else {
913        0
914    }
915}
916
917/// Replace any references to `super::X` in `original` with `super::super::X`.
918fn supersuperfy(original: &Type, levels: usize) -> Type {
919    let mut output = original.clone();
920    fn recurse(t: &mut Type, levels: usize) {
921        match t {
922            Type::Slice(s) => {
923                recurse(s.elem.as_mut(), levels);
924            },
925            Type::Array(a) => {
926                recurse(a.elem.as_mut(), levels);
927            },
928            Type::Ptr(p) => {
929                recurse(p.elem.as_mut(), levels);
930            },
931            Type::Reference(r) => {
932                recurse(r.elem.as_mut(), levels);
933            },
934            Type::BareFn(bfn) => {
935                if let ReturnType::Type(_, ref mut bt) = bfn.output {
936                    recurse(bt.as_mut(), levels);
937                }
938                for input in bfn.inputs.iter_mut() {
939                    recurse(&mut input.ty, levels);
940                }
941            },
942            Type::Tuple(tuple) => {
943                for elem in tuple.elems.iter_mut() {
944                    recurse(elem, levels);
945                }
946            }
947            Type::Path(type_path) => {
948                let added = supersuperfy_path(&mut type_path.path, levels);
949                if let Some(ref mut qself) = type_path.qself {
950                    recurse(qself.ty.as_mut(), levels);
951                    qself.position += added;
952                }
953            },
954            Type::Paren(p) => {
955                recurse(p.elem.as_mut(), levels);
956            },
957            Type::Group(g) => {
958                recurse(g.elem.as_mut(), levels);
959            },
960            Type::Macro(_) | Type::Verbatim(_) => {
961                compile_error(t.span(),
962                    "mockall_derive does not support this type in this position");
963            },
964            Type::TraitObject(tto) => {
965                for bound in tto.bounds.iter_mut() {
966                    if let TypeParamBound::Trait(tb) = bound {
967                        supersuperfy_path(&mut tb.path, levels);
968                    }
969                }
970            },
971            Type::ImplTrait(_) => {
972                /* Should've already been flagged as a compile error */
973            },
974            Type::Infer(_) | Type::Never(_) =>
975            {
976                /* Nothing to do */
977            },
978            _ => compile_error(t.span(), "Unsupported type"),
979        }
980    }
981    recurse(&mut output, levels);
982    output
983}
984
985fn supersuperfy_generics(generics: &mut Generics, levels: usize) {
986    for param in generics.params.iter_mut() {
987        if let GenericParam::Type(tp) = param {
988            supersuperfy_bounds(&mut tp.bounds, levels);
989            if let Some(ty) = tp.default.as_mut() {
990                *ty = supersuperfy(ty, levels);
991            }
992        }
993    }
994    if let Some(wc) = generics.where_clause.as_mut() {
995        for wp in wc.predicates.iter_mut() {
996            if let WherePredicate::Type(pt) = wp {
997                pt.bounded_ty = supersuperfy(&pt.bounded_ty, levels);
998                supersuperfy_bounds(&mut pt.bounds, levels);
999            }
1000        }
1001    }
1002}
1003
1004fn supersuperfy_bounds(
1005    bounds: &mut Punctuated<TypeParamBound, Token![+]>,
1006    levels: usize)
1007{
1008    for bound in bounds.iter_mut() {
1009        if let TypeParamBound::Trait(tb) = bound {
1010            supersuperfy_path(&mut tb.path, levels);
1011        }
1012    }
1013}
1014
1015/// Generate a suitable mockall::Key generic paramter from any Generics
1016fn gen_keyid(g: &Generics) -> impl ToTokens {
1017    match g.params.len() {
1018        0 => quote!(<()>),
1019        1 => {
1020            let (_, tg, _) = g.split_for_impl();
1021            quote!(#tg)
1022        },
1023        _ => {
1024            // Rust doesn't support variadic Generics, so mockall::Key must
1025            // always have exactly one generic type.  We need to add parentheses
1026            // around whatever type generics the caller passes.
1027            let tps = g.type_params()
1028            .map(|tp| tp.ident.clone())
1029            .collect::<Punctuated::<Ident, Token![,]>>();
1030            quote!(<(#tps)>)
1031        }
1032    }
1033}
1034
1035/// Generate a mock identifier from the regular one: eg "Foo" => "MockFoo"
1036fn gen_mock_ident(ident: &Ident) -> Ident {
1037    format_ident!("Mock{}", ident)
1038}
1039
1040/// Generate an identifier for the mock struct's private module: eg "Foo" =>
1041/// "__mock_Foo"
1042fn gen_mod_ident(struct_: &Ident, trait_: Option<&Ident>) -> Ident {
1043    if let Some(t) = trait_ {
1044        format_ident!("__mock_{struct_}_{}", t)
1045    } else {
1046        format_ident!("__mock_{struct_}")
1047    }
1048}
1049
1050/// Combine two Generics structs, producing a new one that has the union of
1051/// their parameters.
1052fn merge_generics(x: &Generics, y: &Generics) -> Generics {
1053    /// Compare only the identifiers of two GenericParams
1054    fn cmp_gp_idents(x: &GenericParam, y: &GenericParam) -> bool {
1055        use GenericParam::*;
1056
1057        match (x, y) {
1058            (Type(xtp), Type(ytp)) => xtp.ident == ytp.ident,
1059            (Lifetime(xld), Lifetime(yld)) => xld.lifetime == yld.lifetime,
1060            (Const(xc), Const(yc)) => xc.ident == yc.ident,
1061            _ => false
1062        }
1063    }
1064
1065    /// Compare only the identifiers of two WherePredicates
1066    fn cmp_wp_idents(x: &WherePredicate, y: &WherePredicate) -> bool {
1067        use WherePredicate::*;
1068
1069        match (x, y) {
1070            (Type(xpt), Type(ypt)) => xpt.bounded_ty == ypt.bounded_ty,
1071            (Lifetime(xpl), Lifetime(ypl)) => xpl.lifetime == ypl.lifetime,
1072            _ => false
1073        }
1074    }
1075
1076    let mut out = if x.lt_token.is_none() && x.where_clause.is_none() {
1077        y.clone()
1078    } else if y.lt_token.is_none() && y.where_clause.is_none() {
1079        x.clone()
1080    } else {
1081        let mut out = x.clone();
1082        // First merge the params
1083        'outer_param: for yparam in y.params.iter() {
1084            // XXX: O(n^2) loop
1085            for outparam in out.params.iter_mut() {
1086                if cmp_gp_idents(outparam, yparam) {
1087                    if let (GenericParam::Type(ref mut ot),
1088                            GenericParam::Type(yt)) = (outparam, yparam)
1089                    {
1090                        ot.attrs.extend(yt.attrs.iter().cloned());
1091                        ot.colon_token = ot.colon_token.or(yt.colon_token);
1092                        ot.eq_token = ot.eq_token.or(yt.eq_token);
1093                        if ot.default.is_none() {
1094                            ot.default.clone_from(&yt.default);
1095                        }
1096                        // XXX this might result in duplicate bounds
1097                        if ot.bounds != yt.bounds {
1098                            ot.bounds.extend(yt.bounds.iter().cloned());
1099                        }
1100                    }
1101                    continue 'outer_param;
1102                }
1103            }
1104            out.params.push(yparam.clone());
1105        }
1106        out
1107    };
1108    // Then merge the where clauses
1109    match (&mut out.where_clause, &y.where_clause) {
1110        (_, None) => (),
1111        (None, Some(wc)) => out.where_clause = Some(wc.clone()),
1112        (Some(out_wc), Some(y_wc)) => {
1113            'outer_wc: for ypred in y_wc.predicates.iter() {
1114                // XXX: O(n^2) loop
1115                for outpred in out_wc.predicates.iter_mut() {
1116                    if cmp_wp_idents(outpred, ypred) {
1117                        if let (WherePredicate::Type(ref mut ot),
1118                                WherePredicate::Type(yt)) = (outpred, ypred)
1119                        {
1120                            match (&mut ot.lifetimes, &yt.lifetimes) {
1121                                (_, None) => (),
1122                                (None, Some(bl)) =>
1123                                    ot.lifetimes = Some(bl.clone()),
1124                                (Some(obl), Some(ybl)) =>
1125                                    // XXX: might result in duplicates
1126                                    obl.lifetimes.extend(
1127                                        ybl.lifetimes.iter().cloned()),
1128                            };
1129                            // XXX: might result in duplicate bounds
1130                            if ot.bounds != yt.bounds {
1131                                ot.bounds.extend(yt.bounds.iter().cloned())
1132                            }
1133                        }
1134                        continue 'outer_wc;
1135                    }
1136                }
1137                out_wc.predicates.push(ypred.clone());
1138            }
1139        }
1140    }
1141    out
1142}
1143
1144fn lifetimes_to_generic_params(lv: &Punctuated<LifetimeParam, Token![,]>)
1145    -> Punctuated<GenericParam, Token![,]>
1146{
1147    lv.iter()
1148        .map(|lt| GenericParam::Lifetime(lt.clone()))
1149        .collect()
1150}
1151
1152/// Transform a Vec of lifetimes into a Generics
1153fn lifetimes_to_generics(lv: &Punctuated<LifetimeParam, Token![,]>)-> Generics {
1154    if lv.is_empty() {
1155            Generics::default()
1156    } else {
1157        let params = lifetimes_to_generic_params(lv);
1158        Generics {
1159            lt_token: Some(Token![<](lv[0].span())),
1160            gt_token: Some(Token![>](lv[0].span())),
1161            params,
1162            where_clause: None
1163        }
1164    }
1165}
1166
1167/// Split a generics list into three: one for type generics and where predicates
1168/// that relate to the signature, one for lifetimes that relate to the arguments
1169/// only, and one for lifetimes that relate to the return type only.
1170fn split_lifetimes(
1171    generics: Generics,
1172    args: &Punctuated<FnArg, Token![,]>,
1173    rt: &ReturnType)
1174    -> (Generics,
1175        Punctuated<LifetimeParam, token::Comma>,
1176        Punctuated<LifetimeParam, token::Comma>)
1177{
1178    if generics.lt_token.is_none() {
1179        return (generics, Default::default(), Default::default());
1180    }
1181
1182    // Check which types and lifetimes are referenced by the arguments
1183    let mut alts = HashSet::<Lifetime>::default();
1184    let mut rlts = HashSet::<Lifetime>::default();
1185    for arg in args {
1186        match arg {
1187            FnArg::Receiver(r) => {
1188                if let Some((_, Some(lt))) = &r.reference {
1189                    alts.insert(lt.clone());
1190                }
1191            },
1192            FnArg::Typed(pt) => {
1193                alts.extend(find_lifetimes(pt.ty.as_ref()));
1194            },
1195        };
1196    };
1197
1198    if let ReturnType::Type(_, ty) = rt {
1199        rlts.extend(find_lifetimes(ty));
1200    }
1201
1202    let mut tv = Punctuated::new();
1203    let mut alv = Punctuated::new();
1204    let mut rlv = Punctuated::new();
1205    for p in generics.params.into_iter() {
1206        match p {
1207            GenericParam::Lifetime(ltd) if rlts.contains(&ltd.lifetime) =>
1208                rlv.push(ltd),
1209            GenericParam::Lifetime(ltd) if alts.contains(&ltd.lifetime) =>
1210                alv.push(ltd),
1211            GenericParam::Lifetime(_) => {
1212                // Probably a lifetime parameter from the impl block that isn't
1213                // used by this particular method
1214            },
1215            GenericParam::Type(_) => tv.push(p),
1216            _ => (),
1217        }
1218    }
1219
1220    let tg = if tv.is_empty() {
1221        Generics::default()
1222    } else {
1223        Generics {
1224            lt_token: generics.lt_token,
1225            gt_token: generics.gt_token,
1226            params: tv,
1227            where_clause: generics.where_clause
1228        }
1229    };
1230
1231    (tg, alv, rlv)
1232}
1233
1234/// Return the visibility that should be used for expectation!, given the
1235/// original method's visibility.
1236///
1237/// # Arguments
1238/// - `vis`:    Original visibility of the item
1239/// - `levels`: How many modules will the mock item be nested in?
1240fn expectation_visibility(vis: &Visibility, levels: usize)
1241    -> Visibility
1242{
1243    if levels == 0 {
1244        return vis.clone();
1245    }
1246
1247    let in_token = Token![in](vis.span());
1248    let super_token = Token![super](vis.span());
1249    match vis {
1250        Visibility::Inherited => {
1251            // Private items need pub(in super::[...]) for each level
1252            let mut path = Path::from(super_token);
1253            for _ in 1..levels {
1254                path.segments.push(super_token.into());
1255            }
1256            Visibility::Restricted(VisRestricted{
1257                pub_token: Token![pub](vis.span()),
1258                paren_token: token::Paren::default(),
1259                in_token: Some(in_token),
1260                path: Box::new(path)
1261            })
1262        },
1263        Visibility::Restricted(vr) => {
1264            // crate => don't change
1265            // in crate::* => don't change
1266            // super => in super::super::super
1267            // self => in super::super
1268            // in anything_else => super::super::anything_else
1269            if vr.path.segments.first().unwrap().ident == "crate" {
1270                Visibility::Restricted(vr.clone())
1271            } else {
1272                let mut out = vr.clone();
1273                out.in_token = Some(in_token);
1274                for _ in 0..levels {
1275                    out.path.segments.insert(0, super_token.into());
1276                }
1277                Visibility::Restricted(out)
1278            }
1279        },
1280        _ => vis.clone()
1281    }
1282}
1283
1284fn staticize(generics: &Generics) -> Generics {
1285    let mut ret = generics.clone();
1286    for lt in ret.lifetimes_mut() {
1287        lt.lifetime = Lifetime::new("'static", Span::call_site());
1288    };
1289    ret
1290}
1291
1292fn mock_it<M: Into<MockableItem>>(inputs: M) -> TokenStream
1293{
1294    let mockable: MockableItem = inputs.into();
1295    let mock = MockItem::from(mockable);
1296    let ts = mock.into_token_stream();
1297    if env::var("MOCKALL_DEBUG").is_ok() {
1298        println!("{ts}");
1299    }
1300    ts
1301}
1302
1303fn do_mock_once(input: TokenStream) -> TokenStream
1304{
1305    let item: MockableStruct = match syn::parse2(input) {
1306        Ok(mock) => mock,
1307        Err(err) => {
1308            return err.to_compile_error();
1309        }
1310    };
1311    mock_it(item)
1312}
1313
1314fn do_mock(input: TokenStream) -> TokenStream
1315{
1316    cfg_if! {
1317        if #[cfg(reprocheck)] {
1318            let ts_a = do_mock_once(input.clone());
1319            let ts_b = do_mock_once(input.clone());
1320            assert_eq!(ts_a.to_string(), ts_b.to_string());
1321        }
1322    }
1323    do_mock_once(input)
1324}
1325
1326#[proc_macro_attribute]
1327pub fn concretize(
1328    _attrs: proc_macro::TokenStream,
1329    input: proc_macro::TokenStream) -> proc_macro::TokenStream
1330{
1331    // Do nothing.  This "attribute" is processed as text by the real proc
1332    // macros.
1333    input
1334}
1335
1336#[proc_macro]
1337pub fn mock(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1338    do_mock(input.into()).into()
1339}
1340
1341#[proc_macro_attribute]
1342pub fn automock(attrs: proc_macro::TokenStream, input: proc_macro::TokenStream)
1343    -> proc_macro::TokenStream
1344{
1345    let attrs: proc_macro2::TokenStream = attrs.into();
1346    let input: proc_macro2::TokenStream = input.into();
1347    do_automock(attrs, input).into()
1348}
1349
1350fn do_automock_once(attrs: TokenStream, input: TokenStream) -> TokenStream {
1351    let mut output = input.clone();
1352    let attrs: Attrs = match parse2(attrs) {
1353        Ok(a) => a,
1354        Err(err) => {
1355            return err.to_compile_error();
1356        }
1357    };
1358    let item: Item = match parse2(input) {
1359        Ok(item) => item,
1360        Err(err) => {
1361            return err.to_compile_error();
1362        }
1363    };
1364    output.extend(mock_it((attrs, item)));
1365    output
1366}
1367
1368fn do_automock(attrs: TokenStream, input: TokenStream) -> TokenStream {
1369    cfg_if! {
1370        if #[cfg(reprocheck)] {
1371            let ts_a = do_automock_once(attrs.clone(), input.clone());
1372            let ts_b = do_automock_once(attrs.clone(), input.clone());
1373            assert_eq!(ts_a.to_string(), ts_b.to_string());
1374        }
1375    }
1376    do_automock_once(attrs, input)
1377}
1378
1379#[cfg(test)]
1380mod t {
1381    use super::*;
1382
1383fn assert_contains(output: &str, tokens: TokenStream) {
1384    let s = tokens.to_string();
1385    assert!(output.contains(&s), "output does not contain {:?}", &s);
1386}
1387
1388fn assert_not_contains(output: &str, tokens: TokenStream) {
1389    let s = tokens.to_string();
1390    assert!(!output.contains(&s), "output contains {:?}", &s);
1391}
1392
/// Various tests for overall code generation that are hard or impossible to
/// write as integration tests
///
/// These tests drive `mock!` (via `do_mock`) and search its stringified output.
mod mock {
    use std::str::FromStr;
    use super::super::*;
    use super::*;

    /// Each method's visibility must be copied to both the mock method and its
    /// `expect_` method; a private method stays private on both.
    #[test]
    fn inherent_method_visibility() {
        let code = "
            Foo {
                fn foo(&self);
                pub fn bar(&self);
                pub(crate) fn baz(&self);
                pub(super) fn bean(&self);
                pub(in crate::outer) fn boom(&self);
            }
        ";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let output = do_mock(ts).to_string();
        assert_not_contains(&output, quote!(pub fn foo));
        // Also rule out any `pub(...)` form on the private method.
        assert!(!output.contains(") fn foo"));
        assert_contains(&output, quote!(pub fn bar));
        assert_contains(&output, quote!(pub(crate) fn baz));
        assert_contains(&output, quote!(pub(super) fn bean));
        assert_contains(&output, quote!(pub(in crate::outer) fn boom));

        assert_not_contains(&output, quote!(pub fn expect_foo));
        // NOTE(review): this repeats the check above using a raw string.
        assert!(!output.contains("pub fn expect_foo"));
        assert!(!output.contains(") fn expect_foo"));
        assert_contains(&output, quote!(pub fn expect_bar));
        assert_contains(&output, quote!(pub(crate) fn expect_baz));
        assert_contains(&output, quote!(pub(super) fn expect_bean));
        assert_contains(&output, quote!(pub(in crate::outer) fn expect_boom));
    }

    /// `#[must_use]` on the mocked struct is kept on the mock struct.
    #[test]
    fn must_use_struct() {
        let code = "
            #[must_use]
            pub Foo {}
        ";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let output = do_mock(ts).to_string();
        assert_contains(&output, quote!(#[must_use] pub struct MockFoo));
    }

    /// Two impls of the same trait for different concrete instantiations of a
    /// generic struct must each be generated with their own concrete types.
    #[test]
    fn specific_impl() {
        let code = "
            pub Foo<T: 'static> {}
            impl Bar for Foo<u32> {
                fn bar(&self);
            }
            impl Bar for Foo<i32> {
                fn bar(&self);
            }
        ";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let output = do_mock(ts).to_string();
        assert_contains(&output, quote!(impl Bar for MockFoo<u32>));
        assert_contains(&output, quote!(impl Bar for MockFoo<i32>));
        // Ensure we don't duplicate the checkpoint function
        assert_not_contains(&output, quote!(
            self.Bar_expectations.checkpoint();
            self.Bar_expectations.checkpoint();
        ));
        // The expect methods should return specific types, not generic ones
        assert_contains(&output, quote!(
            pub fn expect_bar(&mut self) -> &mut __mock_MockFoo_Bar::__bar::Expectation<u32>
        ));
        assert_contains(&output, quote!(
            pub fn expect_bar(&mut self) -> &mut __mock_MockFoo_Bar::__bar::Expectation<i32>
        ));
    }
}
1469
/// Various tests for overall code generation that are hard or impossible to
/// write as integration tests
///
/// These tests drive `#[automock]` (via `do_automock`, always with empty
/// attribute arguments) and search its stringified output.
mod automock {
    use std::str::FromStr;
    use super::super::*;
    use super::*;

    /// Doc comments on a mocked module's functions are kept on the mock.
    #[test]
    fn doc_comments() {
        let code = "
            mod foo {
                /// Function docs
                pub fn bar() { unimplemented!() }
            }
        ";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        let output = do_automock(attrs_ts, ts).to_string();
        assert_contains(&output, quote!(#[doc=" Function docs"] pub fn bar));
    }

    /// Each method's visibility must be copied to both the mock method and its
    /// `expect_` method; a private method stays private on both.
    #[test]
    fn method_visibility() {
        let code = "
        impl Foo {
            fn foo(&self) {}
            pub fn bar(&self) {}
            pub(super) fn baz(&self) {}
            pub(crate) fn bang(&self) {}
            pub(in super::x) fn bean(&self) {}
        }";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        let output = do_automock(attrs_ts, ts).to_string();
        assert_not_contains(&output, quote!(pub fn foo));
        assert!(!output.contains(") fn foo"));
        assert_not_contains(&output, quote!(pub fn expect_foo));
        assert!(!output.contains(") fn expect_foo"));
        assert_contains(&output, quote!(pub fn bar));
        assert_contains(&output, quote!(pub fn expect_bar));
        assert_contains(&output, quote!(pub(super) fn baz));
        assert_contains(&output, quote!(pub(super) fn expect_baz));
        assert_contains(&output, quote!(pub ( crate ) fn bang));
        assert_contains(&output, quote!(pub ( crate ) fn expect_bang));
        assert_contains(&output, quote!(pub ( in super :: x ) fn bean));
        assert_contains(&output, quote!(pub ( in super :: x ) fn expect_bean));
    }

    /// `#[must_use]` on a method is kept on the mock method, but not copied
    /// onto its `expect_` method.
    #[test]
    fn must_use_method() {
        let code = "
        impl Foo {
            #[must_use]
            fn foo(&self) -> i32 {42}
        }";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        let output = do_automock(attrs_ts, ts).to_string();
        assert_not_contains(&output, quote!(#[must_use] fn expect_foo));
        assert_contains(&output, quote!(#[must_use] #[allow(dead_code)] fn foo));
    }

    /// Likewise for a static method: `#[must_use]` stays on the mock method,
    /// but not on its `expect` or `foo_context` functions.
    #[test]
    fn must_use_static_method() {
        let code = "
        impl Foo {
            #[must_use]
            fn foo() -> i32 {42}
        }";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        let output = do_automock(attrs_ts, ts).to_string();
        assert_not_contains(&output, quote!(#[must_use] fn expect));
        assert_not_contains(&output, quote!(#[must_use] fn foo_context));
        assert_contains(&output, quote!(#[must_use] #[allow(dead_code)] fn foo));
    }

    /// `#[must_use]` on a trait must not be copied onto the mock struct.
    #[test]
    fn must_use_trait() {
        let code = "
        #[must_use]
        trait Foo {}
        ";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        let output = do_automock(attrs_ts, ts).to_string();
        assert_not_contains(&output, quote!(#[must_use] struct MockFoo));
    }

    /// An impl block for a struct with an elided lifetime is rejected.
    #[test]
    #[should_panic(expected = "automock does not currently support structs with elided lifetimes")]
    fn elided_lifetimes() {
        let code = "impl X<'_> {}";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        do_automock(attrs_ts, ts).to_string();
    }

    /// A module declared in another file (`mod foo;`) is rejected.
    #[test]
    #[should_panic(expected = "can only mock inline modules")]
    fn external_module() {
        let code = "mod foo;";
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        do_automock(attrs_ts, ts).to_string();
    }

    /// A trait's visibility is copied onto the mock struct.
    #[test]
    fn trait_visibility() {
        let code = "
        pub(super) trait Foo {}
        ";
        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
        let output = do_automock(attrs_ts, ts).to_string();
        assert_contains(&output, quote!(pub ( super ) struct MockFoo));
    }
}
1588
1589mod concretize_args {
1590    use super::*;
1591
1592    #[allow(clippy::needless_range_loop)] // Clippy's suggestion is worse
1593    fn check_concretize(
1594        sig: TokenStream,
1595        expected_inputs: &[TokenStream],
1596        expected_call_exprs: &[TokenStream],
1597        expected_sig_inputs: &[TokenStream])
1598    {
1599        let f: Signature = parse2(sig).unwrap();
1600        let (generics, inputs, call_exprs, altsig) = concretize_args(&f.generics, &f);
1601        assert!(generics.params.is_empty());
1602        assert_eq!(inputs.len(), expected_inputs.len());
1603        assert_eq!(call_exprs.len(), expected_call_exprs.len());
1604        for i in 0..inputs.len() {
1605            let actual = &inputs[i];
1606            let exp = &expected_inputs[i];
1607            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1608        }
1609        for i in 0..call_exprs.len() {
1610            let actual = &call_exprs[i];
1611            let exp = &expected_call_exprs[i];
1612            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1613        }
1614        for i in 0..altsig.inputs.len() {
1615            let actual = &altsig.inputs[i];
1616            let exp = &expected_sig_inputs[i];
1617            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1618        }
1619    }
1620
    /// Only the generic argument `p` is concretized (to a `&dyn` reference,
    /// passed as `&p`); the non-generic arguments `x` and `y` are unchanged.
    #[test]
    fn bystanders() {
        check_concretize(
            quote!(fn foo<P: AsRef<Path>>(x: i32, p: P, y: &f64)),
            &[quote!(x: i32), quote!(p: &(dyn AsRef<Path>)), quote!(y: &f64)],
            &[quote!(x), quote!(&p), quote!(y)],
            &[quote!(x: i32), quote!(p: P), quote!(y: &f64)]
        );
    }
1630
    /// Closure-typed generics become `dyn Fn*` trait objects.  An `FnMut`
    /// argument is passed as `&mut`, so the alternate signature marks it `mut`;
    /// extra bounds such as `+ Send` are kept on the trait object.
    #[test]
    fn function_args() {
        check_concretize(
            quote!(fn foo<F1: Fn(u32) -> u32,
                          F2: FnMut(&mut u32) -> u32,
                          F3: FnOnce(u32) -> u32,
                          F4: Fn() + Send>(f1: F1, f2: F2, f3: F3, f4: F4)),
            &[quote!(f1: &(dyn Fn(u32) -> u32)),
              quote!(f2: &mut(dyn FnMut(&mut u32) -> u32)),
              quote!(f3: &(dyn FnOnce(u32) -> u32)),
              quote!(f4: &(dyn Fn() + Send))],
            &[quote!(&f1), quote!(&mut f2), quote!(&f3), quote!(&f4)],
            &[quote!(f1: F1), quote!(mut f2: F2), quote!(f3: F3), quote!(f4: F4)]
        );
    }
1646
    /// Several inline bounds combine into a single trait object.
    #[test]
    fn multi_bounds() {
        check_concretize(
            quote!(fn foo<P: AsRef<String> + AsMut<String>>(p: P)),
            &[quote!(p: &(dyn AsRef<String> + AsMut<String>))],
            &[quote!(&p)],
            &[quote!(p: P)],
        );
    }
1656
    /// `&mut P` becomes `&mut (dyn ...)` and the argument is passed through
    /// as-is.
    #[test]
    fn mutable_reference_arg() {
        check_concretize(
            quote!(fn foo<P: AsMut<Path>>(p: &mut P)),
            &[quote!(p: &mut (dyn AsMut<Path>))],
            &[quote!(p)],
            &[quote!(p: &mut P)],
        );
    }
1666
1667    #[test]
1668    fn mutable_reference_multi_bounds() {
1669        check_concretize(
1670            quote!(fn foo<P: AsRef<String> + AsMut<String>>(p: &mut P)),
1671            &[quote!(p: &mut (dyn AsRef<String> + AsMut<String>))],
1672            &[quote!(p)],
1673            &[quote!(p: &mut P)]
1674        );
1675    }
1676
1677    #[test]
1678    fn reference_arg() {
1679        check_concretize(
1680            quote!(fn foo<P: AsRef<Path>>(p: &P)),
1681            &[quote!(p: &(dyn AsRef<Path>))],
1682            &[quote!(p)],
1683            &[quote!(p: &P)]
1684        );
1685    }
1686
1687    #[test]
1688    fn simple() {
1689        check_concretize(
1690            quote!(fn foo<P: AsRef<Path>>(p: P)),
1691            &[quote!(p: &(dyn AsRef<Path>))],
1692            &[quote!(&p)],
1693            &[quote!(p: P)],
1694        );
1695    }
1696
1697    #[test]
1698    fn slice() {
1699        check_concretize(
1700            quote!(fn foo<P: AsRef<Path>>(p: &[P])),
1701            &[quote!(p: &[&(dyn AsRef<Path>)])],
1702            &[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path>)).collect::<Vec<_>>())],
1703            &[quote!(p: &[P])]
1704        );
1705    }
1706
1707    #[test]
1708    fn slice_with_multi_bounds() {
1709        check_concretize(
1710            quote!(fn foo<P: AsRef<Path> + AsMut<String>>(p: &[P])),
1711            &[quote!(p: &[&(dyn AsRef<Path> + AsMut<String>)])],
1712            &[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path> + AsMut<String>)).collect::<Vec<_>>())],
1713            &[quote!(p: &[P])]
1714        );
1715    }
1716
1717    #[test]
1718    fn where_clause() {
1719        check_concretize(
1720            quote!(fn foo<P>(p: P) where P: AsRef<Path>),
1721            &[quote!(p: &(dyn AsRef<Path>))],
1722            &[quote!(&p)],
1723            &[quote!(p: P)]
1724        );
1725    }
1726}
1727
1728mod declosurefy {
1729    use super::*;
1730
1731    fn check_declosurefy(
1732        sig: TokenStream,
1733        expected_inputs: &[TokenStream],
1734        expected_call_exprs: &[TokenStream])
1735    {
1736        let f: Signature = parse2(sig).unwrap();
1737        let (generics, inputs, call_exprs) =
1738            declosurefy(&f.generics, &f.inputs);
1739        assert!(generics.params.is_empty());
1740        assert_eq!(inputs.len(), expected_inputs.len());
1741        assert_eq!(call_exprs.len(), expected_call_exprs.len());
1742        for i in 0..inputs.len() {
1743            let actual = &inputs[i];
1744            let exp = &expected_inputs[i];
1745            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1746        }
1747        for i in 0..call_exprs.len() {
1748            let actual = &call_exprs[i];
1749            let exp = &expected_call_exprs[i];
1750            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1751        }
1752    }
1753
1754    #[test]
1755    fn bounds() {
1756        check_declosurefy(
1757            quote!(fn foo<F: Fn(u32) -> u32 + Send>(f: F)),
1758            &[quote!(f: Box<dyn Fn(u32) -> u32 + Send>)],
1759            &[quote!(Box::new(f))]
1760        );
1761    }
1762
1763    #[test]
1764    fn r#fn() {
1765        check_declosurefy(
1766            quote!(fn foo<F: Fn(u32) -> u32>(f: F)),
1767            &[quote!(f: Box<dyn Fn(u32) -> u32>)],
1768            &[quote!(Box::new(f))]
1769        );
1770    }
1771
1772    #[test]
1773    fn fn_mut() {
1774        check_declosurefy(
1775            quote!(fn foo<F: FnMut(u32) -> u32>(f: F)),
1776            &[quote!(f: Box<dyn FnMut(u32) -> u32>)],
1777            &[quote!(Box::new(f))]
1778        );
1779    }
1780
1781    #[test]
1782    fn fn_once() {
1783        check_declosurefy(
1784            quote!(fn foo<F: FnOnce(u32) -> u32>(f: F)),
1785            &[quote!(f: Box<dyn FnOnce(u32) -> u32>)],
1786            &[quote!(Box::new(f))]
1787        );
1788    }
1789
1790    #[test]
1791    fn mutable_pattern() {
1792        check_declosurefy(
1793            quote!(fn foo<F: FnMut(u32) -> u32>(mut f: F)),
1794            &[quote!(f: Box<dyn FnMut(u32) -> u32>)],
1795            &[quote!(Box::new(f))]
1796        );
1797    }
1798
1799    #[test]
1800    fn where_clause() {
1801        check_declosurefy(
1802            quote!(fn foo<F>(f: F) where F: Fn(u32) -> u32),
1803            &[quote!(f: Box<dyn Fn(u32) -> u32>)],
1804            &[quote!(Box::new(f))]
1805        );
1806    }
1807
1808    #[test]
1809    fn where_clause_with_bounds() {
1810        check_declosurefy(
1811            quote!(fn foo<F>(f: F) where F: Fn(u32) -> u32 + Send),
1812            &[quote!(f: Box<dyn Fn(u32) -> u32 + Send>)],
1813            &[quote!(Box::new(f))]
1814        );
1815    }
1816}
1817
1818mod deimplify {
1819    use super::*;
1820
1821    fn check_deimplify(orig_ts: TokenStream, expected_ts: TokenStream) {
1822        let mut orig: ReturnType = parse2(orig_ts).unwrap();
1823        let expected: ReturnType = parse2(expected_ts).unwrap();
1824        deimplify(&mut orig);
1825        assert_eq!(quote!(#orig).to_string(), quote!(#expected).to_string());
1826    }
1827
1828    // Future is a special case
1829    #[test]
1830    fn impl_future() {
1831        check_deimplify(
1832            quote!(-> impl Future<Output=i32>),
1833            quote!(-> ::std::pin::Pin<Box<dyn Future<Output=i32>>>)
1834        );
1835    }
1836
1837    // Future is a special case, wherever it appears
1838    #[test]
1839    fn impl_future_reverse() {
1840        check_deimplify(
1841            quote!(-> impl Send + Future<Output=i32>),
1842            quote!(-> ::std::pin::Pin<Box<dyn Send + Future<Output=i32>>>)
1843        );
1844    }
1845
1846    // Stream is a special case
1847    #[test]
1848    fn impl_stream() {
1849        check_deimplify(
1850            quote!(-> impl Stream<Item=i32>),
1851            quote!(-> ::std::pin::Pin<Box<dyn Stream<Item=i32>>>)
1852        );
1853    }
1854
1855    #[test]
1856    fn impl_trait() {
1857        check_deimplify(
1858            quote!(-> impl Foo),
1859            quote!(-> Box<dyn Foo>)
1860        );
1861    }
1862
1863    // With extra bounds
1864    #[test]
1865    fn impl_trait2() {
1866        check_deimplify(
1867            quote!(-> impl Foo + Send),
1868            quote!(-> Box<dyn Foo + Send>)
1869        );
1870    }
1871}
1872
1873mod deselfify {
1874    use super::*;
1875
1876    fn check_deselfify(
1877        orig_ts: TokenStream,
1878        actual_ts: TokenStream,
1879        generics_ts: TokenStream,
1880        expected_ts: TokenStream)
1881    {
1882        let mut ty: Type = parse2(orig_ts).unwrap();
1883        let actual: Ident = parse2(actual_ts).unwrap();
1884        let generics: Generics = parse2(generics_ts).unwrap();
1885        let expected: Type = parse2(expected_ts).unwrap();
1886        deselfify(&mut ty, &actual, &generics);
1887        assert_eq!(quote!(#ty).to_string(),
1888                   quote!(#expected).to_string());
1889    }
1890
1891    #[test]
1892    fn arc() {
1893        check_deselfify(
1894            quote!(Arc<Self>),
1895            quote!(Foo),
1896            quote!(),
1897            quote!(Arc<Foo>)
1898        );
1899    }
1900    #[test]
1901    fn future() {
1902        check_deselfify(
1903            quote!(Box<dyn Future<Output=Self>>),
1904            quote!(Foo),
1905            quote!(),
1906            quote!(Box<dyn Future<Output=Foo>>)
1907        );
1908    }
1909
1910    #[test]
1911    fn qself() {
1912        check_deselfify(
1913            quote!(<Self as Self>::Self),
1914            quote!(Foo),
1915            quote!(),
1916            quote!(<Foo as Foo>::Foo)
1917        );
1918    }
1919
1920    #[test]
1921    fn trait_object() {
1922        check_deselfify(
1923            quote!(Box<dyn Self>),
1924            quote!(Foo),
1925            quote!(),
1926            quote!(Box<dyn Foo>)
1927        );
1928    }
1929
1930    // A trait object with multiple bounds
1931    #[test]
1932    fn trait_object2() {
1933        check_deselfify(
1934            quote!(Box<dyn Self + Send>),
1935            quote!(Foo),
1936            quote!(),
1937            quote!(Box<dyn Foo + Send>)
1938        );
1939    }
1940}
1941
1942mod dewhereselfify {
1943    use super::*;
1944
1945    #[test]
1946    fn lifetime() {
1947        let mut meth: ImplItemFn = parse2(quote!(
1948                fn foo<'a>(&self) where 'a: 'static, Self: Sized {}
1949        )).unwrap();
1950        let expected: ImplItemFn = parse2(quote!(
1951                fn foo<'a>(&self) where 'a: 'static {}
1952        )).unwrap();
1953        dewhereselfify(&mut meth.sig.generics);
1954        assert_eq!(meth, expected);
1955    }
1956
1957    #[test]
1958    fn normal_method() {
1959        let mut meth: ImplItemFn = parse2(quote!(
1960                fn foo(&self) where Self: Sized {}
1961        )).unwrap();
1962        let expected: ImplItemFn = parse2(quote!(
1963                fn foo(&self) {}
1964        )).unwrap();
1965        dewhereselfify(&mut meth.sig.generics);
1966        assert_eq!(meth, expected);
1967    }
1968
1969    #[test]
1970    fn with_real_generics() {
1971        let mut meth: ImplItemFn = parse2(quote!(
1972                fn foo<T>(&self, t: T) where Self: Sized, T: Copy {}
1973        )).unwrap();
1974        let expected: ImplItemFn = parse2(quote!(
1975                fn foo<T>(&self, t: T) where T: Copy {}
1976        )).unwrap();
1977        dewhereselfify(&mut meth.sig.generics);
1978        assert_eq!(meth, expected);
1979    }
1980}
1981
1982mod gen_keyid {
1983    use super::*;
1984
1985    fn check_gen_keyid(orig: TokenStream, expected: TokenStream) {
1986        let g: Generics = parse2(orig).unwrap();
1987        let keyid = gen_keyid(&g);
1988        assert_eq!(quote!(#keyid).to_string(), quote!(#expected).to_string());
1989    }
1990
1991    #[test]
1992    fn empty() {
1993        check_gen_keyid(quote!(), quote!(<()>));
1994    }
1995
1996    #[test]
1997    fn onetype() {
1998        check_gen_keyid(quote!(<T>), quote!(<T>));
1999    }
2000
2001    #[test]
2002    fn twotypes() {
2003        check_gen_keyid(quote!(<T, V>), quote!(<(T, V)>));
2004    }
2005}
2006
2007mod merge_generics {
2008    use super::*;
2009
2010    #[test]
2011    fn both() {
2012        let mut g1: Generics = parse2(quote!(<T: 'static, V: Copy> )).unwrap();
2013        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2014        g1.where_clause = Some(wc1);
2015
2016        let mut g2: Generics = parse2(quote!(<Q: Send, V: Clone>)).unwrap();
2017        let wc2: WhereClause = parse2(quote!(where T: Sync, Q: Debug)).unwrap();
2018        g2.where_clause = Some(wc2);
2019
2020        let gm = super::merge_generics(&g1, &g2);
2021        let gm_wc = &gm.where_clause;
2022
2023        let ge: Generics = parse2(quote!(
2024                <T: 'static, V: Copy + Clone, Q: Send>
2025        )).unwrap();
2026        let wce: WhereClause = parse2(quote!(
2027            where T: Default + Sync, Q: Debug
2028        )).unwrap();
2029
2030        assert_eq!(quote!(#ge #wce).to_string(),
2031                   quote!(#gm #gm_wc).to_string());
2032    }
2033
2034    #[test]
2035    fn eq() {
2036        let mut g1: Generics = parse2(quote!(<T: 'static, V: Copy> )).unwrap();
2037        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2038        g1.where_clause = Some(wc1.clone());
2039
2040        let gm = super::merge_generics(&g1, &g1);
2041        let gm_wc = &gm.where_clause;
2042
2043        assert_eq!(quote!(#g1 #wc1).to_string(),
2044                   quote!(#gm #gm_wc).to_string());
2045    }
2046
2047    #[test]
2048    fn lhs_only() {
2049        let mut g1: Generics = parse2(quote!(<T: 'static, V: Copy> )).unwrap();
2050        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2051        g1.where_clause = Some(wc1.clone());
2052
2053        let g2 = Generics::default();
2054
2055        let gm = super::merge_generics(&g1, &g2);
2056        let gm_wc = &gm.where_clause;
2057
2058        assert_eq!(quote!(#g1 #wc1).to_string(),
2059                   quote!(#gm #gm_wc).to_string());
2060    }
2061
2062    #[test]
2063    fn lhs_wc_only() {
2064        let mut g1 = Generics::default();
2065        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2066        g1.where_clause = Some(wc1.clone());
2067
2068        let g2 = Generics::default();
2069
2070        let gm = super::merge_generics(&g1, &g2);
2071        let gm_wc = &gm.where_clause;
2072
2073        assert_eq!(quote!(#g1 #wc1).to_string(),
2074                   quote!(#gm #gm_wc).to_string());
2075    }
2076
2077    #[test]
2078    fn rhs_only() {
2079        let g1 = Generics::default();
2080        let mut g2: Generics = parse2(quote!(<Q: Send, V: Clone>)).unwrap();
2081        let wc2: WhereClause = parse2(quote!(where T: Sync, Q: Debug)).unwrap();
2082        g2.where_clause = Some(wc2.clone());
2083
2084        let gm = super::merge_generics(&g1, &g2);
2085        let gm_wc = &gm.where_clause;
2086
2087        assert_eq!(quote!(#g2 #wc2).to_string(),
2088                   quote!(#gm #gm_wc).to_string());
2089    }
2090}
2091
2092mod supersuperfy {
2093    use super::*;
2094
2095    fn check_supersuperfy(orig: TokenStream, expected: TokenStream) {
2096        let orig_ty: Type = parse2(orig).unwrap();
2097        let expected_ty: Type = parse2(expected).unwrap();
2098        let output = supersuperfy(&orig_ty, 1);
2099        assert_eq!(quote!(#output).to_string(),
2100                   quote!(#expected_ty).to_string());
2101    }
2102
2103    #[test]
2104    fn array() {
2105        check_supersuperfy(
2106            quote!([super::X; n]),
2107            quote!([super::super::X; n])
2108        );
2109    }
2110
2111    #[test]
2112    fn barefn() {
2113        check_supersuperfy(
2114            quote!(fn(super::A) -> super::B),
2115            quote!(fn(super::super::A) -> super::super::B)
2116        );
2117    }
2118
2119    #[test]
2120    fn group() {
2121        let orig = TypeGroup {
2122            group_token: token::Group::default(),
2123            elem: Box::new(parse2(quote!(super::T)).unwrap())
2124        };
2125        let expected = TypeGroup {
2126            group_token: token::Group::default(),
2127            elem: Box::new(parse2(quote!(super::super::T)).unwrap())
2128        };
2129        let output = supersuperfy(&Type::Group(orig), 1);
2130        assert_eq!(quote!(#output).to_string(),
2131                   quote!(#expected).to_string());
2132    }
2133
2134    // Just check that it doesn't panic
2135    #[test]
2136    fn infer() {
2137        check_supersuperfy( quote!(_), quote!(_));
2138    }
2139
2140    // Just check that it doesn't panic
2141    #[test]
2142    fn never() {
2143        check_supersuperfy( quote!(!), quote!(!));
2144    }
2145
2146    #[test]
2147    fn paren() {
2148        check_supersuperfy(
2149            quote!((super::X)),
2150            quote!((super::super::X))
2151        );
2152    }
2153
2154    #[test]
2155    fn path() {
2156        check_supersuperfy(
2157            quote!(::super::SuperT<u32>),
2158            quote!(::super::super::SuperT<u32>)
2159        );
2160    }
2161
2162    #[test]
2163    fn path_with_qself() {
2164        check_supersuperfy(
2165            quote!(<super::X as super::Y>::Foo<u32>),
2166            quote!(<super::super::X as super::super::Y>::Foo<u32>),
2167        );
2168    }
2169
2170    #[test]
2171    fn angle_bracketed_generic_arguments() {
2172        check_supersuperfy(
2173            quote!(mod_::T<super::X>),
2174            quote!(mod_::T<super::super::X>)
2175        );
2176    }
2177
2178    #[test]
2179    fn ptr() {
2180        check_supersuperfy(
2181            quote!(*const super::X),
2182            quote!(*const super::super::X)
2183        );
2184    }
2185
2186    #[test]
2187    fn reference() {
2188        check_supersuperfy(
2189            quote!(&'a mut super::X),
2190            quote!(&'a mut super::super::X)
2191        );
2192    }
2193
2194    #[test]
2195    fn slice() {
2196        check_supersuperfy(
2197            quote!([super::X]),
2198            quote!([super::super::X])
2199        );
2200    }
2201
2202    #[test]
2203    fn trait_object() {
2204        check_supersuperfy(
2205            quote!(dyn super::X + super::Y),
2206            quote!(dyn super::super::X + super::super::Y)
2207        );
2208    }
2209
2210    #[test]
2211    fn tuple() {
2212        check_supersuperfy(
2213            quote!((super::A, super::B)),
2214            quote!((super::super::A, super::super::B))
2215        );
2216    }
2217}
2218
2219mod supersuperfy_generics {
2220    use super::*;
2221
2222    fn check_supersuperfy_generics(
2223        orig: TokenStream,
2224        orig_wc: TokenStream,
2225        expected: TokenStream,
2226        expected_wc: TokenStream)
2227    {
2228        let mut orig_g: Generics = parse2(orig).unwrap();
2229        orig_g.where_clause = parse2(orig_wc).unwrap();
2230        let mut expected_g: Generics = parse2(expected).unwrap();
2231        expected_g.where_clause = parse2(expected_wc).unwrap();
2232        let mut output: Generics = orig_g;
2233        supersuperfy_generics(&mut output, 1);
2234        let (o_ig, o_tg, o_wc) = output.split_for_impl();
2235        let (e_ig, e_tg, e_wc) = expected_g.split_for_impl();
2236        assert_eq!(quote!(#o_ig).to_string(), quote!(#e_ig).to_string());
2237        assert_eq!(quote!(#o_tg).to_string(), quote!(#e_tg).to_string());
2238        assert_eq!(quote!(#o_wc).to_string(), quote!(#e_wc).to_string());
2239    }
2240
2241    #[test]
2242    fn default() {
2243        check_supersuperfy_generics(
2244            quote!(<T: X = super::Y>), quote!(),
2245            quote!(<T: X = super::super::Y>), quote!(),
2246        );
2247    }
2248
2249    #[test]
2250    fn empty() {
2251        check_supersuperfy_generics(quote!(), quote!(), quote!(), quote!());
2252    }
2253
2254    #[test]
2255    fn everything() {
2256        check_supersuperfy_generics(
2257            quote!(<T: super::A = super::B>),
2258            quote!(where super::C: super::D),
2259            quote!(<T: super::super::A = super::super::B>),
2260            quote!(where super::super::C: super::super::D),
2261        );
2262    }
2263
2264    #[test]
2265    fn bound() {
2266        check_supersuperfy_generics(
2267            quote!(<T: super::A>), quote!(),
2268            quote!(<T: super::super::A>), quote!(),
2269        );
2270    }
2271
2272    #[test]
2273    fn closure() {
2274        check_supersuperfy_generics(
2275            quote!(<F: Fn(u32) -> super::SuperT>), quote!(),
2276            quote!(<F: Fn(u32) -> super::super::SuperT>), quote!(),
2277        );
2278    }
2279
2280    #[test]
2281    fn wc_bounded_ty() {
2282        check_supersuperfy_generics(
2283            quote!(), quote!(where super::T: X),
2284            quote!(), quote!(where super::super::T: X),
2285        );
2286    }
2287
2288    #[test]
2289    fn wc_bounds() {
2290        check_supersuperfy_generics(
2291            quote!(), quote!(where T: super::X),
2292            quote!(), quote!(where T: super::super::X),
2293        );
2294    }
2295}
2296}