mockall_derive/
lib.rs

1// vim: tw=80
2//! Proc Macros for use with Mockall
3//!
4//! You probably don't want to use this crate directly.  Instead, you should use
5//! its reexports via the [`mockall`](https://docs.rs/mockall/latest/mockall)
6//! crate.
7
8#![cfg_attr(feature = "nightly_derive", feature(proc_macro_diagnostic))]
9#![cfg_attr(test, deny(warnings))]
10// This lint is unhelpful.  See
11// https://github.com/rust-lang/rust-clippy/discussions/14256
12#![allow(clippy::doc_overindented_list_items)]
13
14use cfg_if::cfg_if;
15use proc_macro2::{Span, TokenStream};
16use quote::{ToTokens, format_ident, quote};
17use std::{
18    env,
19    hash::BuildHasherDefault
20};
21use syn::{
22    *,
23    punctuated::Punctuated,
24    spanned::Spanned
25};
26
27mod automock;
28mod mock_function;
29mod mock_item;
30mod mock_item_struct;
31mod mock_trait;
32mod mockable_item;
33mod mockable_struct;
34use crate::automock::Attrs;
35use crate::mockable_struct::MockableStruct;
36use crate::mock_item::MockItem;
37use crate::mock_item_struct::MockItemStruct;
38use crate::mockable_item::MockableItem;
39
// Define deterministic aliases for these common types.
//
// std's default hasher, `RandomState`, is seeded with random keys, so the
// iteration order of a plain `HashMap`/`HashSet` can differ from run to run.
// `BuildHasherDefault<DefaultHasher>` instead creates every hasher through
// `DefaultHasher::default()`, which std documents as producing identical
// hashers, so iteration order is reproducible (presumably so that the code
// generated by these macros does not vary between builds).
type HashMap<K, V> = std::collections::HashMap<K, V, BuildHasherDefault<std::collections::hash_map::DefaultHasher>>;
type HashSet<K> = std::collections::HashSet<K, BuildHasherDefault<std::collections::hash_map::DefaultHasher>>;
43
cfg_if! {
    // proc-macro2's Span::unstable method requires the nightly feature, and it
    // doesn't work in test mode.
    // https://github.com/alexcrichton/proc-macro2/issues/159
    if #[cfg(all(feature = "nightly_derive", not(test)))] {
        /// Report `msg` as a compiler error attached to `span`.
        ///
        /// This version emits a diagnostic (via the nightly-only
        /// `proc_macro_diagnostic` feature enabled at the top of this file)
        /// and then returns normally.  Callers must therefore be prepared to
        /// keep going, typically by producing some placeholder value.
        fn compile_error(span: Span, msg: &str) {
            span.unstable()
                .error(msg)
                .emit();
        }
    } else {
        /// Report `msg` as an error by panicking.
        ///
        /// Without span diagnostics the `span` argument is unused, and this
        /// version never returns.
        fn compile_error(_span: Span, msg: &str) {
            panic!("{msg}.  More information may be available when mockall is built with the \"nightly\" feature.");
        }
    }
}
60
61/// Does this Attribute represent Mockall's "concretize" pseudo-attribute?
62fn is_concretize(attr: &Attribute) -> bool {
63    if attr.path().segments.last().unwrap().ident == "concretize" {
64        true
65    } else if attr.path().is_ident("cfg_attr") {
66        match &attr.meta {
67            Meta::List(ml) => {
68                ml.tokens.to_string().contains("concretize")
69            },
70            // cfg_attr should always contain a list
71            _ => false,
72        }
73    } else {
74        false
75    }
76}
77
78/// replace generic arguments with concrete trait object arguments
79///
80/// # Return
81///
82/// * A Generics object with the concretized types removed
83/// * An array of transformed argument types, suitable for matchers and
84///   returners
85/// * An array of expressions that should be passed to the `call` function.
86fn concretize_args(gen: &Generics, sig: &Signature) ->
87    (Generics, Punctuated<FnArg, Token![,]>, Vec<TokenStream>, Signature)
88{
89    let args = &sig.inputs;
90    let mut hm = HashMap::default();
91    let mut needs_muts = HashMap::default();
92
93    let mut save_types = |ident: &Ident, tpb: &Punctuated<TypeParamBound, Token![+]>| {
94        if !tpb.is_empty() {
95            let mut pat = quote!(&(dyn #tpb));
96            let mut needs_mut = false;
97            if let Some(TypeParamBound::Trait(t)) = tpb.first() {
98                if t.path.segments.first().map(|seg| &seg.ident == "FnMut")
99                    .unwrap_or(false)
100                {
101                    // For FnMut arguments, the rfunc needs a mutable reference
102                    pat = quote!(&mut (dyn #tpb));
103                    needs_mut = true;
104                }
105            }
106            if let Ok(newty) = parse2::<Type>(pat) {
107                // substitute T arguments
108                let subst_ty: Type = parse2(quote!(#ident)).unwrap();
109                needs_muts.insert(subst_ty.clone(), needs_mut);
110                hm.insert(subst_ty, (newty.clone(), None));
111
112                // substitute &T arguments
113                let subst_ty: Type = parse2(quote!(&#ident)).unwrap();
114                needs_muts.insert(subst_ty.clone(), needs_mut);
115                hm.insert(subst_ty, (newty, None));
116            } else {
117                compile_error(tpb.span(),
118                    "Type cannot be made into a trait object");
119            }
120
121            if let Ok(newty) = parse2::<Type>(quote!(&mut (dyn #tpb))) {
122                // substitute &mut T arguments
123                let subst_ty: Type = parse2(quote!(&mut #ident)).unwrap();
124                needs_muts.insert(subst_ty.clone(), needs_mut);
125                hm.insert(subst_ty, (newty, None));
126            } else {
127                compile_error(tpb.span(),
128                    "Type cannot be made into a trait object");
129            }
130
131            // I wish we could substitute &[T] arguments.  But there's no way
132            // for the mock method to turn &[T] into &[&dyn T].
133            if let Ok(newty) = parse2::<Type>(quote!(&[&(dyn #tpb)])) {
134                let subst_ty: Type = parse2(quote!(&[#ident])).unwrap();
135                needs_muts.insert(subst_ty.clone(), needs_mut);
136                hm.insert(subst_ty, (newty, Some(tpb.clone())));
137            } else {
138                compile_error(tpb.span(),
139                    "Type cannot be made into a trait object");
140            }
141        }
142    };
143
144    for g in gen.params.iter() {
145        if let GenericParam::Type(tp) = g {
146            save_types(&tp.ident, &tp.bounds);
147            // else there had better be a where clause
148        }
149    }
150    if let Some(wc) = &gen.where_clause {
151        for pred in wc.predicates.iter() {
152            if let WherePredicate::Type(pt) = pred {
153                let bounded_ty = &pt.bounded_ty;
154                if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
155                    save_types(&ident, &pt.bounds);
156                } else {
157                    // We can't yet handle where clauses this complicated
158                }
159            }
160        }
161    }
162
163    let outg = Generics {
164        lt_token: None,
165        gt_token: None,
166        params: Punctuated::new(),
167        where_clause: None
168    };
169    let outargs = args.iter().map(|arg| {
170        if let FnArg::Typed(pt) = arg {
171            let mut call_pt = pt.clone();
172            demutify_arg(&mut call_pt);
173            if let Some((newty, _)) = hm.get(&pt.ty) {
174                FnArg::Typed(PatType {
175                    attrs: Vec::default(),
176                    pat: call_pt.pat,
177                    colon_token: pt.colon_token,
178                    ty: Box::new(newty.clone())
179                })
180            } else {
181                FnArg::Typed(PatType {
182                    attrs: Vec::default(),
183                    pat: call_pt.pat,
184                    colon_token: pt.colon_token,
185                    ty: pt.ty.clone()
186                })
187            }
188        } else {
189            arg.clone()
190        }
191    }).collect();
192
193    // Finally, Reference any concretizing arguments
194    // use filter_map to remove the &self argument
195    let call_exprs = args.iter().filter_map(|arg| {
196        match arg {
197            FnArg::Typed(pt) => {
198                let mut pt2 = pt.clone();
199                demutify_arg(&mut pt2);
200                let pat = &pt2.pat;
201                if pat_is_self(pat) {
202                    None
203                } else if let Some((_, newbound)) = hm.get(&pt.ty) {
204                    if let Type::Reference(tr) = &*pt.ty {
205                        if let Type::Slice(_ts) = &*tr.elem {
206                            // Assume _ts is the generic type or we wouldn't be
207                            // here
208                            Some(quote!(
209                                &(0..#pat.len())
210                                .map(|__mockall_i| &#pat[__mockall_i] as &(dyn #newbound))
211                                .collect::<Vec<_>>()
212                            ))
213                        } else {
214                            Some(quote!(#pat))
215                        }
216                    } else if needs_muts.get(&pt.ty).cloned().unwrap_or(false) {
217                        Some(quote!(&mut #pat))
218                    } else {
219                        Some(quote!(&#pat))
220                    }
221                } else {
222                    Some(quote!(#pat))
223                }
224            },
225            FnArg::Receiver(_) => None,
226        }
227    }).collect();
228
229    // Add any necessary "mut" qualifiers to the Signature
230    let mut altsig = sig.clone();
231    for arg in altsig.inputs.iter_mut() {
232        if let FnArg::Typed(pt) = arg {
233            if needs_muts.get(&pt.ty).cloned().unwrap_or(false) {
234                if let Pat::Ident(pi) = &mut *pt.pat {
235                    pi.mutability = Some(Token![mut](pi.mutability.span()));
236                } else {
237                    compile_error(pt.pat.span(),
238                                    "This Pat type is not yet supported by Mockall when used as an argument to a concretized function.")
239                }
240            }
241        }
242    }
243
244    (outg, outargs, call_exprs, altsig)
245}
246
247fn deanonymize_lifetime(lt: &mut Lifetime) {
248    if lt.ident == "_" {
249        lt.ident = format_ident!("static");
250    }
251}
252
253fn deanonymize_path(path: &mut Path) {
254    for seg in path.segments.iter_mut() {
255        match &mut seg.arguments {
256            PathArguments::None => (),
257            PathArguments::AngleBracketed(abga) => {
258                for ga in abga.args.iter_mut() {
259                    if let GenericArgument::Lifetime(lt) = ga {
260                        deanonymize_lifetime(lt)
261                    }
262                }
263            },
264            _ => compile_error(seg.arguments.span(),
265                "Methods returning functions are TODO"),
266        }
267    }
268}
269
270/// Replace any references to the anonymous lifetime `'_` with `'static`.
271fn deanonymize(literal_type: &mut Type) {
272    match literal_type {
273        Type::Array(ta) => deanonymize(ta.elem.as_mut()),
274        Type::BareFn(tbf) => {
275            if let ReturnType::Type(_, ref mut bt) = tbf.output {
276                deanonymize(bt.as_mut());
277            }
278            for input in tbf.inputs.iter_mut() {
279                deanonymize(&mut input.ty);
280            }
281        },
282        Type::Group(tg) => deanonymize(tg.elem.as_mut()),
283        Type::Infer(_) => (),
284        Type::Never(_) => (),
285        Type::Paren(tp) => deanonymize(tp.elem.as_mut()),
286        Type::Path(tp) => {
287            if let Some(ref mut qself) = tp.qself {
288                deanonymize(qself.ty.as_mut());
289            }
290            deanonymize_path(&mut tp.path);
291        },
292        Type::Ptr(tptr) => deanonymize(tptr.elem.as_mut()),
293        Type::Reference(tr) => {
294            if let Some(lt) = tr.lifetime.as_mut() {
295                deanonymize_lifetime(lt)
296            }
297            deanonymize(tr.elem.as_mut());
298        },
299        Type::Slice(s) => deanonymize(s.elem.as_mut()),
300        Type::TraitObject(tto) => {
301            for tpb in tto.bounds.iter_mut() {
302                match tpb {
303                    TypeParamBound::Trait(tb) => deanonymize_path(&mut tb.path),
304                    TypeParamBound::Lifetime(lt) => deanonymize_lifetime(lt),
305                    _ => ()
306                }
307            }
308        },
309        Type::Tuple(tt) => {
310            for ty in tt.elems.iter_mut() {
311                deanonymize(ty)
312            }
313        }
314        x => compile_error(x.span(), "Unimplemented type for deanonymize")
315    }
316}
317
/// If there are any closures in the argument list, turn them into boxed
/// functions.
///
/// A generic type parameter `F` counts as a closure type when one of its
/// bounds names a trait whose last path segment is `Fn`, `FnMut`, or
/// `FnOnce`.  The bound may be inline, or in a `where` clause whose bounded
/// type is a bare identifier.  `F` is then replaced by `Box<dyn Bounds>`,
/// where `Bounds` is the full bound list from the declaration site where the
/// `Fn*` bound was found.  Only arguments whose type is exactly `F` are
/// rewritten; types such as `&F` are left unchanged.
///
/// # Return
///
/// * `gen` with the closure type parameters, and any simple `where`
///   predicates on them, removed.  The angle brackets are dropped if no
///   parameters remain, and the where clause is dropped if it becomes empty.
/// * `args` with each closure-typed argument rewritten to
///   `Box<dyn Bounds>`, and with any `mut` binding removed from every typed
///   argument.
/// * One expression per non-receiver argument to pass to the inner call.
///   Closure arguments are wrapped in `Box::new(..)`.
///
/// # Panics
///
/// Panics (via `assert!`) if the same type parameter is found to have an
/// `Fn*` bound more than once, e.g. two `Fn*` traits in one bound list, or
/// one in its declaration and another in a `where` clause.
fn declosurefy(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
    (Generics, Punctuated<FnArg, Token![,]>, Vec<TokenStream>)
{
    // Closure type (as a `Type` naming the parameter) => `Box<dyn ...>`
    let mut hm = HashMap::default();

    let mut save_fn_types = |ident: &Ident, bounds: &Punctuated<TypeParamBound, Token![+]>|
    {
        for tpb in bounds.iter() {
            if let TypeParamBound::Trait(tb) = tpb {
                let fident = &tb.path.segments.last().unwrap().ident;
                if ["Fn", "FnMut", "FnOnce"].iter().any(|s| fident == *s) {
                    // Box the whole bound list, so auxiliary bounds such as
                    // `Send` or `'static` are preserved.
                    let newty: Type = parse2(quote!(Box<dyn #bounds>)).unwrap();
                    let subst_ty: Type = parse2(quote!(#ident)).unwrap();
                    assert!(hm.insert(subst_ty, newty).is_none(),
                        "A generic parameter had two Fn bounds?");
                }
            }
        }
    };

    // First, build a HashMap of all Fn generic types
    for g in gen.params.iter() {
        if let GenericParam::Type(tp) = g {
            save_fn_types(&tp.ident, &tp.bounds);
        }
    }
    if let Some(wc) = &gen.where_clause {
        for pred in wc.predicates.iter() {
            if let WherePredicate::Type(pt) = pred {
                let bounded_ty = &pt.bounded_ty;
                if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
                    save_fn_types(&ident, &pt.bounds);
                } else {
                    // We can't yet handle where clauses this complicated
                }
            }
        }
    }

    // Then remove those types from both the Generics' params and where clause
    let should_remove = |ident: &Ident| {
            let ty: Type = parse2(quote!(#ident)).unwrap();
            hm.contains_key(&ty)
    };
    let params = gen.params.iter()
        .filter(|g| {
            if let GenericParam::Type(tp) = g {
                !should_remove(&tp.ident)
            } else {
                true
            }
        }).cloned()
        .collect::<Punctuated<_, _>>();
    let mut wc2 = gen.where_clause.clone();
    if let Some(wc) = &mut wc2 {
        wc.predicates = wc.predicates.iter()
            .filter(|wp| {
                if let WherePredicate::Type(pt) = wp {
                    let bounded_ty = &pt.bounded_ty;
                    if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
                        !should_remove(&ident)
                    } else {
                        // We can't yet handle where clauses this complicated
                        true
                    }
                } else {
                    true
                }
            }).cloned()
            .collect::<Punctuated<_, _>>();
        // An empty `where` would be meaningless; drop it entirely.
        if wc.predicates.is_empty() {
            wc2 = None;
        }
    }
    // Omit the angle brackets altogether when no parameters remain.
    let outg = Generics {
        lt_token: if params.is_empty() { None } else { gen.lt_token },
        gt_token: if params.is_empty() { None } else { gen.gt_token },
        params,
        where_clause: wc2
    };

    // Next substitute Box<Fn> into the arguments
    let outargs = args.iter().map(|arg| {
        if let FnArg::Typed(pt) = arg {
            let mut immutable_pt = pt.clone();
            demutify_arg(&mut immutable_pt);
            if let Some(newty) = hm.get(&pt.ty) {
                FnArg::Typed(PatType {
                    attrs: Vec::default(),
                    pat: immutable_pt.pat,
                    colon_token: pt.colon_token,
                    ty: Box::new(newty.clone())
                })
            } else {
                FnArg::Typed(PatType {
                    attrs: Vec::default(),
                    pat: immutable_pt.pat,
                    colon_token: pt.colon_token,
                    ty: pt.ty.clone()
                })
            }
        } else {
            arg.clone()
        }
    }).collect();

    // Finally, Box any closure arguments
    // use filter_map to remove the &self argument
    let callargs = args.iter().filter_map(|arg| {
        match arg {
            FnArg::Typed(pt) => {
                let mut pt2 = pt.clone();
                demutify_arg(&mut pt2);
                let pat = &pt2.pat;
                if pat_is_self(pat) {
                    None
                } else if hm.contains_key(&pt.ty) {
                    Some(quote!(Box::new(#pat)))
                } else {
                    Some(quote!(#pat))
                }
            },
            FnArg::Receiver(_) => None,
        }
    }).collect();
    (outg, outargs, callargs)
}
447
448/// Replace any "impl trait" types with "Box<dyn trait>" or equivalent.
449fn deimplify(rt: &mut ReturnType) {
450    if let ReturnType::Type(_, ty) = rt {
451        if let Type::ImplTrait(ref tit) = &**ty {
452            let needs_pin = tit.bounds
453                .iter()
454                .any(|tpb| {
455                    if let TypeParamBound::Trait(tb) = tpb {
456                        if let Some(seg) = tb.path.segments.last() {
457                            seg.ident == "Future" || seg.ident == "Stream"
458                        } else {
459                            // It might still be a Future, but we can't guess
460                            // what names it might be imported under.  Too bad.
461                            false
462                        }
463                    } else {
464                        false
465                    }
466                });
467            let bounds = &tit.bounds;
468            if needs_pin {
469                *ty = parse2(quote!(::std::pin::Pin<Box<dyn #bounds>>)).unwrap();
470            } else {
471                *ty = parse2(quote!(Box<dyn #bounds>)).unwrap();
472            }
473        }
474    }
475}
476
477/// Remove any generics that place constraints on Self.
478fn dewhereselfify(generics: &mut Generics) {
479    if let Some(ref mut wc) = &mut generics.where_clause {
480        let new_predicates = wc.predicates.iter()
481            .filter(|wp| match wp {
482                WherePredicate::Type(pt) => {
483                    pt.bounded_ty != parse2(quote!(Self)).unwrap()
484                },
485                _ => true
486            }).cloned()
487            .collect::<Punctuated<WherePredicate, Token![,]>>();
488        wc.predicates = new_predicates;
489    }
490    if generics.where_clause.as_ref()
491        .map(|wc| wc.predicates.is_empty())
492        .unwrap_or(false)
493    {
494        generics.where_clause = None;
495    }
496}
497
498/// Remove any mutability qualifiers from a method's argument list
499fn demutify(inputs: &mut Punctuated<FnArg, token::Comma>) {
500    for arg in inputs.iter_mut() {
501        match arg {
502            FnArg::Receiver(r) => if r.reference.is_none() {
503                r.mutability = None
504            },
505            FnArg::Typed(pt) => demutify_arg(pt),
506        }
507    }
508}
509
510/// Remove any "mut" from a method argument's binding.
511fn demutify_arg(arg: &mut PatType) {
512    match *arg.pat {
513        Pat::Wild(_) => {
514            compile_error(arg.span(),
515                "Mocked methods must have named arguments");
516        },
517        Pat::Ident(ref mut pat_ident) => {
518            if let Some(r) = &pat_ident.by_ref {
519                compile_error(r.span(),
520                    "Mockall does not support by-reference argument bindings");
521            }
522            if let Some((_at, subpat)) = &pat_ident.subpat {
523                compile_error(subpat.span(),
524                    "Mockall does not support subpattern bindings");
525            }
526            pat_ident.mutability = None;
527        },
528        _ => {
529            compile_error(arg.span(), "Unsupported argument type");
530        }
531    };
532}
533
534fn deselfify_path(path: &mut Path, actual: &Ident, generics: &Generics) {
535    for seg in path.segments.iter_mut() {
536        if seg.ident == "Self" {
537            seg.ident = actual.clone();
538            if let PathArguments::None = seg.arguments {
539                if !generics.params.is_empty() {
540                    let args = generics.params.iter()
541                        .map(|gp| {
542                            match gp {
543                                GenericParam::Type(tp) => {
544                                    let ident = tp.ident.clone();
545                                    GenericArgument::Type(
546                                        Type::Path(
547                                            TypePath {
548                                                qself: None,
549                                                path: Path::from(ident)
550                                            }
551                                        )
552                                    )
553                                },
554                                GenericParam::Lifetime(ld) =>{
555                                    GenericArgument::Lifetime(
556                                        ld.lifetime.clone()
557                                    )
558                                }
559                                _ => unimplemented!(),
560                            }
561                        }).collect::<Punctuated<_, _>>();
562                    seg.arguments = PathArguments::AngleBracketed(
563                        AngleBracketedGenericArguments {
564                            colon2_token: None,
565                            lt_token: generics.lt_token.unwrap(),
566                            args,
567                            gt_token: generics.gt_token.unwrap(),
568                        }
569                    );
570                }
571            } else {
572                compile_error(seg.arguments.span(),
573                    "Type arguments after Self are unexpected");
574            }
575        }
576        if let PathArguments::AngleBracketed(abga) = &mut seg.arguments
577        {
578            for arg in abga.args.iter_mut() {
579                match arg {
580                    GenericArgument::Type(ty) =>
581                        deselfify(ty, actual, generics),
582                    GenericArgument::AssocType(at) =>
583                        deselfify(&mut at.ty, actual, generics),
584                    _ => /* Nothing to do */(),
585                }
586            }
587        }
588    }
589}
590
591/// Replace any references to `Self` in `literal_type` with `actual`.
592/// `generics` is the Generics field of the parent struct.  Useful for
593/// constructor methods.
594fn deselfify(literal_type: &mut Type, actual: &Ident, generics: &Generics) {
595    match literal_type {
596        Type::Slice(s) => {
597            deselfify(s.elem.as_mut(), actual, generics);
598        },
599        Type::Array(a) => {
600            deselfify(a.elem.as_mut(), actual, generics);
601        },
602        Type::Ptr(p) => {
603            deselfify(p.elem.as_mut(), actual, generics);
604        },
605        Type::Reference(r) => {
606            deselfify(r.elem.as_mut(), actual, generics);
607        },
608        Type::Tuple(tuple) => {
609            for elem in tuple.elems.iter_mut() {
610                deselfify(elem, actual, generics);
611            }
612        }
613        Type::Path(type_path) => {
614            if let Some(ref mut qself) = type_path.qself {
615                deselfify(qself.ty.as_mut(), actual, generics);
616            }
617            deselfify_path(&mut type_path.path, actual, generics);
618        },
619        Type::Paren(p) => {
620            deselfify(p.elem.as_mut(), actual, generics);
621        },
622        Type::Group(g) => {
623            deselfify(g.elem.as_mut(), actual, generics);
624        },
625        Type::Macro(_) | Type::Verbatim(_) => {
626            compile_error(literal_type.span(),
627                "mockall_derive does not support this type as a return argument");
628        },
629        Type::TraitObject(tto) => {
630            // Change types like `dyn Self` into `dyn MockXXX`.
631            for bound in tto.bounds.iter_mut() {
632                if let TypeParamBound::Trait(t) = bound {
633                    deselfify_path(&mut t.path, actual, generics);
634                }
635            }
636        },
637        Type::ImplTrait(_) => {
638            /* Should've already been flagged as a compile_error */
639        },
640        Type::BareFn(_) => {
641            /* Bare functions can't have Self arguments.  Nothing to do */
642        },
643        Type::Infer(_) | Type::Never(_) =>
644        {
645            /* Nothing to do */
646        },
647        _ => compile_error(literal_type.span(), "Unsupported type"),
648    }
649}
650
651/// Change any `Self` in a method's arguments' types with `actual`.
652/// `generics` is the Generics field of the parent struct.
653fn deselfify_args(
654    args: &mut Punctuated<FnArg, Token![,]>,
655    actual: &Ident,
656    generics: &Generics)
657{
658    for arg in args.iter_mut() {
659        match arg {
660            FnArg::Receiver(r) => {
661                if r.colon_token.is_some() {
662                    deselfify(r.ty.as_mut(), actual, generics)
663                }
664            },
665            FnArg::Typed(pt) => deselfify(pt.ty.as_mut(), actual, generics)
666        }
667    }
668}
669
670fn find_ident_from_path(path: &Path) -> (Ident, PathArguments) {
671    if path.segments.len() != 1 {
672        compile_error(path.span(),
673            "mockall_derive only supports structs defined in the current module");
674        return (Ident::new("", path.span()), PathArguments::None);
675    }
676    let last_seg = path.segments.last().unwrap();
677    (last_seg.ident.clone(), last_seg.arguments.clone())
678}
679
680fn find_lifetimes_in_tpb(bound: &TypeParamBound) -> HashSet<Lifetime> {
681    let mut ret = HashSet::default();
682    match bound {
683        TypeParamBound::Lifetime(lt) => {
684            ret.insert(lt.clone());
685        },
686        TypeParamBound::Trait(tb) => {
687            ret.extend(find_lifetimes_in_path(&tb.path));
688        },
689        _ => ()
690    };
691    ret
692}
693
694fn find_lifetimes_in_path(path: &Path) -> HashSet<Lifetime> {
695    let mut ret = HashSet::default();
696    for seg in path.segments.iter() {
697        if let PathArguments::AngleBracketed(abga) = &seg.arguments {
698            for arg in abga.args.iter() {
699                match arg {
700                    GenericArgument::Lifetime(lt) => {
701                        ret.insert(lt.clone());
702                    },
703                    GenericArgument::Type(ty) => {
704                        ret.extend(find_lifetimes(ty));
705                    },
706                    GenericArgument::AssocType(at) => {
707                        ret.extend(find_lifetimes(&at.ty));
708                    },
709                    GenericArgument::Constraint(c) => {
710                        for bound in c.bounds.iter() {
711                            ret.extend(find_lifetimes_in_tpb(bound));
712                        }
713                    },
714                    GenericArgument::Const(_) => (),
715                    _ => ()
716                }
717            }
718        }
719    }
720    ret
721}
722
723fn find_lifetimes(ty: &Type) -> HashSet<Lifetime> {
724    match ty {
725        Type::Array(ta) => find_lifetimes(ta.elem.as_ref()),
726        Type::Group(tg) => find_lifetimes(tg.elem.as_ref()),
727        Type::Infer(_ti) => HashSet::default(),
728        Type::Never(_tn) => HashSet::default(),
729        Type::Paren(tp) => find_lifetimes(tp.elem.as_ref()),
730        Type::Path(tp) => {
731            let mut ret = find_lifetimes_in_path(&tp.path);
732            if let Some(qs) = &tp.qself {
733                ret.extend(find_lifetimes(qs.ty.as_ref()));
734            }
735            ret
736        },
737        Type::Ptr(tp) => find_lifetimes(tp.elem.as_ref()),
738        Type::Reference(tr) => {
739            let mut ret = find_lifetimes(tr.elem.as_ref());
740            if let Some(lt) = &tr.lifetime {
741                ret.insert(lt.clone());
742            }
743            ret
744        },
745        Type::Slice(ts) => find_lifetimes(ts.elem.as_ref()),
746        Type::TraitObject(tto) => {
747            let mut ret = HashSet::default();
748            for bound in tto.bounds.iter() {
749                ret.extend(find_lifetimes_in_tpb(bound));
750            }
751            ret
752        }
753        Type::Tuple(tt) => {
754            let mut ret = HashSet::default();
755            for ty in tt.elems.iter() {
756                ret.extend(find_lifetimes(ty));
757            }
758            ret
759        },
760        Type::ImplTrait(tit) => {
761            let mut ret = HashSet::default();
762            for tpb in tit.bounds.iter() {
763                ret.extend(find_lifetimes_in_tpb(tpb));
764            }
765            ret
766        },
767        _ => {
768            compile_error(ty.span(), "unsupported type in this context");
769            HashSet::default()
770        }
771    }
772}
773
774/// If the mocked signature contains any variadic parts, they need a pattern.
775/// The pattern is required for the signature of the mock function, a Rust
776/// function, even though it's not required in the signature of the foreign
777/// function.
778fn fix_elipses(sig: &mut Signature) {
779    if let Some(variadic) = &mut sig.variadic {
780        if variadic.pat.is_none() {
781            let pat = PatIdent {
782                attrs: vec![],
783                by_ref: None,
784                mutability: None,
785                ident: format_ident!("_"),
786                subpat: None
787            };
788            let colon = Token![:](variadic.span());
789            variadic.pat = Some((Box::new(Pat::Ident(pat)), colon));
790        }
791    }
792}
793
/// Filters a list of attributes down to those that should be emitted on
/// generated mock code.
///
/// Configure it with the builder-style setters, then call `format`.
struct AttrFormatter<'a>{
    /// The attributes to filter.
    attrs: &'a [Attribute],
    /// Whether `#[async_trait]` attributes are kept.
    async_trait: bool,
    /// Whether `#[make]` attributes are kept (presumably from the
    /// `trait_variant` crate, hence the name).
    trait_variant: bool,
    /// Whether `#[doc]` attributes, including doc comments, are kept.
    doc: bool,
    /// Whether `#[must_use]` attributes are kept.
    must_use: bool,
}
801
802impl<'a> AttrFormatter<'a> {
803    fn new(attrs: &'a [Attribute]) -> AttrFormatter<'a> {
804        Self {
805            attrs,
806            async_trait: true,
807            trait_variant: false,
808            doc: true,
809            must_use: false,
810        }
811    }
812
    /// Set whether `format` keeps `#[async_trait]` attributes.  `new`
    /// defaults this to `true`.
    fn async_trait(&mut self, allowed: bool) -> &mut Self {
        self.async_trait = allowed;
        self
    }

    /// Set whether `format` keeps `#[make]` attributes (the `trait_variant`
    /// setting).  `new` defaults this to `false`.
    fn trait_variant(&mut self, allowed: bool) -> &mut Self {
        self.trait_variant = allowed;
        self
    }

    /// Set whether `format` keeps `#[doc]` attributes, including doc
    /// comments.  `new` defaults this to `true`.
    fn doc(&mut self, allowed: bool) -> &mut Self {
        self.doc = allowed;
        self
    }

    /// Set whether `format` keeps `#[must_use]` attributes.  `new` defaults
    /// this to `false`.
    fn must_use(&mut self, allowed: bool) -> &mut Self {
        self.must_use = allowed;
        self
    }
832
833    // XXX This logic requires that attributes are imported with their
834    // standard names.
835    #[allow(clippy::needless_bool)]
836    #[allow(clippy::if_same_then_else)]
837    fn format(&mut self) -> Vec<Attribute> {
838        self.attrs.iter()
839            .filter(|attr| {
840                let i = attr.path().segments.last().map(|ps| &ps.ident);
841                if is_concretize(attr) {
842                    // Internally used attribute.  Never emit.
843                    false
844                } else if i.is_none() {
845                    false
846                } else if *i.as_ref().unwrap() == "derive" {
847                    // We can't usefully derive any traits.  Ignore them
848                    false
849                } else if *i.as_ref().unwrap() == "doc" {
850                    self.doc
851                } else if *i.as_ref().unwrap() == "async_trait" {
852                    self.async_trait
853                } else if *i.as_ref().unwrap() == "make" {
854                    self.trait_variant
855                } else if *i.as_ref().unwrap() == "expect" {
856                    // This probably means that there's a lint that needs to be
857                    // surpressed for the real code, but not for the mock code.
858                    // Skip it.
859                    false
860                } else if *i.as_ref().unwrap() == "inline" {
861                    // No need to inline mock functions.
862                    false
863                } else if *i.as_ref().unwrap() == "cold" {
864                    // No need for such hints on mock functions.
865                    false
866                } else if *i.as_ref().unwrap() == "instrument" {
867                    // We can't usefully instrument the mock method, so just
868                    // ignore this attribute.
869                    // https://docs.rs/tracing/0.1.23/tracing/attr.instrument.html
870                    false
871                } else if *i.as_ref().unwrap() == "link_name" {
872                    // This shows up sometimes when mocking ffi functions.  We
873                    // must not emit it on anything that isn't an ffi definition
874                    false
875                } else if *i.as_ref().unwrap() == "must_use" {
876                    self.must_use
877                } else if *i.as_ref().unwrap() == "auto_enum" {
878                    // Ignore auto_enum, because we transform the return value
879                    // into a trait object.
880                    false
881                } else {
882                    true
883                }
884            }).cloned()
885            .collect()
886    }
887}
888
889/// Determine if this Pat is any kind of `self` binding
890fn pat_is_self(pat: &Pat) -> bool {
891    if let Pat::Ident(pi) = pat {
892        pi.ident == "self"
893    } else {
894        false
895    }
896}
897
898/// Add `levels` `super::` to the path.  Return the number of levels added.
899fn supersuperfy_path(path: &mut Path, levels: usize) -> usize {
900    if let Some(t) = path.segments.last_mut() {
901        match &mut t.arguments {
902            PathArguments::None => (),
903            PathArguments::AngleBracketed(ref mut abga) => {
904                for arg in abga.args.iter_mut() {
905                    match arg {
906                        GenericArgument::Type(ref mut ty) => {
907                            *ty = supersuperfy(ty, levels);
908                        },
909                        GenericArgument::AssocType(ref mut at) => {
910                            at.ty = supersuperfy(&at.ty, levels);
911                        },
912                        GenericArgument::Constraint(ref mut constraint) => {
913                            supersuperfy_bounds(&mut constraint.bounds, levels);
914                        },
915                        _ => (),
916                    }
917                }
918            },
919            PathArguments::Parenthesized(ref mut pga) => {
920                for input in pga.inputs.iter_mut() {
921                    *input = supersuperfy(input, levels);
922                }
923                if let ReturnType::Type(_, ref mut ty) = pga.output {
924                    **ty = supersuperfy(ty, levels);
925                }
926            },
927        }
928    }
929    if let Some(t) = path.segments.first() {
930        if t.ident == "super" {
931            let mut ident = format_ident!("super");
932            ident.set_span(path.segments.span());
933            let ps = PathSegment {
934                ident,
935                arguments: PathArguments::None
936            };
937            for _ in 0..levels {
938                path.segments.insert(0, ps.clone());
939            }
940            levels
941        } else {
942            0
943        }
944    } else {
945        0
946    }
947}
948
949/// Replace any references to `super::X` in `original` with `super::super::X`.
950fn supersuperfy(original: &Type, levels: usize) -> Type {
951    let mut output = original.clone();
952    fn recurse(t: &mut Type, levels: usize) {
953        match t {
954            Type::Slice(s) => {
955                recurse(s.elem.as_mut(), levels);
956            },
957            Type::Array(a) => {
958                recurse(a.elem.as_mut(), levels);
959            },
960            Type::Ptr(p) => {
961                recurse(p.elem.as_mut(), levels);
962            },
963            Type::Reference(r) => {
964                recurse(r.elem.as_mut(), levels);
965            },
966            Type::BareFn(bfn) => {
967                if let ReturnType::Type(_, ref mut bt) = bfn.output {
968                    recurse(bt.as_mut(), levels);
969                }
970                for input in bfn.inputs.iter_mut() {
971                    recurse(&mut input.ty, levels);
972                }
973            },
974            Type::Tuple(tuple) => {
975                for elem in tuple.elems.iter_mut() {
976                    recurse(elem, levels);
977                }
978            }
979            Type::Path(type_path) => {
980                let added = supersuperfy_path(&mut type_path.path, levels);
981                if let Some(ref mut qself) = type_path.qself {
982                    recurse(qself.ty.as_mut(), levels);
983                    qself.position += added;
984                }
985            },
986            Type::Paren(p) => {
987                recurse(p.elem.as_mut(), levels);
988            },
989            Type::Group(g) => {
990                recurse(g.elem.as_mut(), levels);
991            },
992            Type::Macro(_) | Type::Verbatim(_) => {
993                compile_error(t.span(),
994                    "mockall_derive does not support this type in this position");
995            },
996            Type::TraitObject(tto) => {
997                for bound in tto.bounds.iter_mut() {
998                    if let TypeParamBound::Trait(tb) = bound {
999                        supersuperfy_path(&mut tb.path, levels);
1000                    }
1001                }
1002            },
1003            Type::ImplTrait(_) => {
1004                /* Should've already been flagged as a compile error */
1005            },
1006            Type::Infer(_) | Type::Never(_) =>
1007            {
1008                /* Nothing to do */
1009            },
1010            _ => compile_error(t.span(), "Unsupported type"),
1011        }
1012    }
1013    recurse(&mut output, levels);
1014    output
1015}
1016
1017fn supersuperfy_generics(generics: &mut Generics, levels: usize) {
1018    for param in generics.params.iter_mut() {
1019        if let GenericParam::Type(tp) = param {
1020            supersuperfy_bounds(&mut tp.bounds, levels);
1021            if let Some(ty) = tp.default.as_mut() {
1022                *ty = supersuperfy(ty, levels);
1023            }
1024        }
1025    }
1026    if let Some(wc) = generics.where_clause.as_mut() {
1027        for wp in wc.predicates.iter_mut() {
1028            if let WherePredicate::Type(pt) = wp {
1029                pt.bounded_ty = supersuperfy(&pt.bounded_ty, levels);
1030                supersuperfy_bounds(&mut pt.bounds, levels);
1031            }
1032        }
1033    }
1034}
1035
1036fn supersuperfy_bounds(
1037    bounds: &mut Punctuated<TypeParamBound, Token![+]>,
1038    levels: usize)
1039{
1040    for bound in bounds.iter_mut() {
1041        if let TypeParamBound::Trait(tb) = bound {
1042            supersuperfy_path(&mut tb.path, levels);
1043        }
1044    }
1045}
1046
1047/// Generate a suitable mockall::Key generic paramter from any Generics
1048fn gen_keyid(g: &Generics) -> impl ToTokens {
1049    match g.params.len() {
1050        0 => quote!(<()>),
1051        1 if matches!(g.params.first(), Some(GenericParam::Type(_))) => {
1052            let (_, tg, _) = g.split_for_impl();
1053            quote!(#tg)
1054        },
1055        _ => {
1056            // Rust doesn't support variadic Generics, so mockall::Key must
1057            // always have exactly one generic type.  We need to add parentheses
1058            // around whatever type generics the caller passes.
1059            let tps = g.type_params()
1060            .map(|tp| tp.ident.clone())
1061            .collect::<Punctuated::<Ident, Token![,]>>();
1062            quote!(<(#tps)>)
1063        }
1064    }
1065}
1066
1067/// Generate a mock identifier from the regular one: eg "Foo" => "MockFoo"
1068fn gen_mock_ident(ident: &Ident) -> Ident {
1069    format_ident!("Mock{}", ident)
1070}
1071
1072/// Generate an identifier for the mock struct's private module: eg "Foo" =>
1073/// "__mock_Foo"
1074fn gen_mod_ident(struct_: &Ident, trait_: Option<&Ident>) -> Ident {
1075    if let Some(t) = trait_ {
1076        format_ident!("__mock_{struct_}_{}", t)
1077    } else {
1078        format_ident!("__mock_{struct_}")
1079    }
1080}
1081
1082/// Combine two Generics structs, producing a new one that has the union of
1083/// their parameters.
1084fn merge_generics(x: &Generics, y: &Generics) -> Generics {
1085    /// Compare only the identifiers of two GenericParams
1086    fn cmp_gp_idents(x: &GenericParam, y: &GenericParam) -> bool {
1087        use GenericParam::*;
1088
1089        match (x, y) {
1090            (Type(xtp), Type(ytp)) => xtp.ident == ytp.ident,
1091            (Lifetime(xld), Lifetime(yld)) => xld.lifetime == yld.lifetime,
1092            (Const(xc), Const(yc)) => xc.ident == yc.ident,
1093            _ => false
1094        }
1095    }
1096
1097    /// Compare only the identifiers of two WherePredicates
1098    fn cmp_wp_idents(x: &WherePredicate, y: &WherePredicate) -> bool {
1099        use WherePredicate::*;
1100
1101        match (x, y) {
1102            (Type(xpt), Type(ypt)) => xpt.bounded_ty == ypt.bounded_ty,
1103            (Lifetime(xpl), Lifetime(ypl)) => xpl.lifetime == ypl.lifetime,
1104            _ => false
1105        }
1106    }
1107
1108    let mut out = if x.lt_token.is_none() && x.where_clause.is_none() {
1109        y.clone()
1110    } else if y.lt_token.is_none() && y.where_clause.is_none() {
1111        x.clone()
1112    } else {
1113        let mut out = x.clone();
1114        // First merge the params
1115        'outer_param: for yparam in y.params.iter() {
1116            // XXX: O(n^2) loop
1117            for outparam in out.params.iter_mut() {
1118                if cmp_gp_idents(outparam, yparam) {
1119                    if let (GenericParam::Type(ref mut ot),
1120                            GenericParam::Type(yt)) = (outparam, yparam)
1121                    {
1122                        ot.attrs.extend(yt.attrs.iter().cloned());
1123                        ot.colon_token = ot.colon_token.or(yt.colon_token);
1124                        ot.eq_token = ot.eq_token.or(yt.eq_token);
1125                        if ot.default.is_none() {
1126                            ot.default.clone_from(&yt.default);
1127                        }
1128                        // XXX this might result in duplicate bounds
1129                        if ot.bounds != yt.bounds {
1130                            ot.bounds.extend(yt.bounds.iter().cloned());
1131                        }
1132                    }
1133                    continue 'outer_param;
1134                }
1135            }
1136            out.params.push(yparam.clone());
1137        }
1138        out
1139    };
1140    // Then merge the where clauses
1141    match (&mut out.where_clause, &y.where_clause) {
1142        (_, None) => (),
1143        (None, Some(wc)) => out.where_clause = Some(wc.clone()),
1144        (Some(out_wc), Some(y_wc)) => {
1145            'outer_wc: for ypred in y_wc.predicates.iter() {
1146                // XXX: O(n^2) loop
1147                for outpred in out_wc.predicates.iter_mut() {
1148                    if cmp_wp_idents(outpred, ypred) {
1149                        if let (WherePredicate::Type(ref mut ot),
1150                                WherePredicate::Type(yt)) = (outpred, ypred)
1151                        {
1152                            match (&mut ot.lifetimes, &yt.lifetimes) {
1153                                (_, None) => (),
1154                                (None, Some(bl)) =>
1155                                    ot.lifetimes = Some(bl.clone()),
1156                                (Some(obl), Some(ybl)) =>
1157                                    // XXX: might result in duplicates
1158                                    obl.lifetimes.extend(
1159                                        ybl.lifetimes.iter().cloned()),
1160                            };
1161                            // XXX: might result in duplicate bounds
1162                            if ot.bounds != yt.bounds {
1163                                ot.bounds.extend(yt.bounds.iter().cloned())
1164                            }
1165                        }
1166                        continue 'outer_wc;
1167                    }
1168                }
1169                out_wc.predicates.push(ypred.clone());
1170            }
1171        }
1172    }
1173    out
1174}
1175
1176fn lifetimes_to_generic_params(lv: &Punctuated<LifetimeParam, Token![,]>)
1177    -> Punctuated<GenericParam, Token![,]>
1178{
1179    lv.iter()
1180        .map(|lt| GenericParam::Lifetime(lt.clone()))
1181        .collect()
1182}
1183
1184/// Transform a Vec of lifetimes into a Generics
1185fn lifetimes_to_generics(lv: &Punctuated<LifetimeParam, Token![,]>)-> Generics {
1186    if lv.is_empty() {
1187            Generics::default()
1188    } else {
1189        let params = lifetimes_to_generic_params(lv);
1190        Generics {
1191            lt_token: Some(Token![<](lv[0].span())),
1192            gt_token: Some(Token![>](lv[0].span())),
1193            params,
1194            where_clause: None
1195        }
1196    }
1197}
1198
1199/// Split a generics list into three: one for type generics and where predicates
1200/// that relate to the signature, one for lifetimes that relate to the arguments
1201/// only, and one for lifetimes that relate to the return type only.
1202fn split_lifetimes(
1203    generics: Generics,
1204    args: &Punctuated<FnArg, Token![,]>,
1205    rt: &ReturnType)
1206    -> (Generics,
1207        Punctuated<LifetimeParam, token::Comma>,
1208        Punctuated<LifetimeParam, token::Comma>)
1209{
1210    if generics.lt_token.is_none() {
1211        return (generics, Default::default(), Default::default());
1212    }
1213
1214    // Check which types and lifetimes are referenced by the arguments
1215    let mut alts = HashSet::<Lifetime>::default();
1216    let mut rlts = HashSet::<Lifetime>::default();
1217    for arg in args {
1218        match arg {
1219            FnArg::Receiver(r) => {
1220                if let Some((_, Some(lt))) = &r.reference {
1221                    alts.insert(lt.clone());
1222                }
1223            },
1224            FnArg::Typed(pt) => {
1225                alts.extend(find_lifetimes(pt.ty.as_ref()));
1226            },
1227        };
1228    };
1229
1230    if let ReturnType::Type(_, ty) = rt {
1231        rlts.extend(find_lifetimes(ty));
1232    }
1233
1234    let mut tv = Punctuated::new();
1235    let mut alv = Punctuated::new();
1236    let mut rlv = Punctuated::new();
1237    for p in generics.params.into_iter() {
1238        match p {
1239            GenericParam::Lifetime(ltd) if rlts.contains(&ltd.lifetime) =>
1240                rlv.push(ltd),
1241            GenericParam::Lifetime(ltd) if alts.contains(&ltd.lifetime) =>
1242                alv.push(ltd),
1243            GenericParam::Lifetime(_) => {
1244                // Probably a lifetime parameter from the impl block that isn't
1245                // used by this particular method
1246            },
1247            GenericParam::Type(_) | GenericParam::Const(_) => tv.push(p),
1248        }
1249    }
1250
1251    let tg = if tv.is_empty() {
1252        Generics::default()
1253    } else {
1254        Generics {
1255            lt_token: generics.lt_token,
1256            gt_token: generics.gt_token,
1257            params: tv,
1258            where_clause: generics.where_clause
1259        }
1260    };
1261
1262    (tg, alv, rlv)
1263}
1264
1265/// Return the visibility that should be used for expectation!, given the
1266/// original method's visibility.
1267///
1268/// # Arguments
1269/// - `vis`:    Original visibility of the item
1270/// - `levels`: How many modules will the mock item be nested in?
1271fn expectation_visibility(vis: &Visibility, levels: usize)
1272    -> Visibility
1273{
1274    if levels == 0 {
1275        return vis.clone();
1276    }
1277
1278    let in_token = Token![in](vis.span());
1279    let super_token = Token![super](vis.span());
1280    match vis {
1281        Visibility::Inherited => {
1282            // Private items need pub(in super::[...]) for each level
1283            let mut path = Path::from(super_token);
1284            for _ in 1..levels {
1285                path.segments.push(super_token.into());
1286            }
1287            Visibility::Restricted(VisRestricted{
1288                pub_token: Token![pub](vis.span()),
1289                paren_token: token::Paren::default(),
1290                in_token: Some(in_token),
1291                path: Box::new(path)
1292            })
1293        },
1294        Visibility::Restricted(vr) => {
1295            // crate => don't change
1296            // in crate::* => don't change
1297            // super => in super::super::super
1298            // self => in super::super
1299            // in anything_else => super::super::anything_else
1300            if vr.path.segments.first().unwrap().ident == "crate" {
1301                Visibility::Restricted(vr.clone())
1302            } else {
1303                let mut out = vr.clone();
1304                out.in_token = Some(in_token);
1305                for _ in 0..levels {
1306                    out.path.segments.insert(0, super_token.into());
1307                }
1308                Visibility::Restricted(out)
1309            }
1310        },
1311        _ => vis.clone()
1312    }
1313}
1314
1315fn staticize(generics: &Generics) -> Generics {
1316    let mut ret = generics.clone();
1317    for lt in ret.lifetimes_mut() {
1318        lt.lifetime = Lifetime::new("'static", Span::call_site());
1319    };
1320    ret
1321}
1322
1323fn mock_it<M: Into<MockableItem>>(inputs: M) -> TokenStream
1324{
1325    let mockable: MockableItem = inputs.into();
1326    let mock = MockItem::from(mockable);
1327    let ts = mock.into_token_stream();
1328    if env::var("MOCKALL_DEBUG").is_ok() {
1329        println!("{ts}");
1330    }
1331    ts
1332}
1333
1334fn do_mock_once(input: TokenStream) -> TokenStream
1335{
1336    let item: MockableStruct = match syn::parse2(input) {
1337        Ok(mock) => mock,
1338        Err(err) => {
1339            return err.to_compile_error();
1340        }
1341    };
1342    mock_it(item)
1343}
1344
1345fn do_mock(input: TokenStream) -> TokenStream
1346{
1347    cfg_if! {
1348        if #[cfg(reprocheck)] {
1349            let ts_a = do_mock_once(input.clone());
1350            let ts_b = do_mock_once(input.clone());
1351            assert_eq!(ts_a.to_string(), ts_b.to_string());
1352        }
1353    }
1354    do_mock_once(input)
1355}
1356
// Identity attribute macro.  `#[concretize]` is interpreted by the `mock!` and
// `#[automock]` expansions themselves, and `AttrFormatter::format` strips it
// from generated code.  Expanded on its own, it returns `input` unchanged.
// (Plain `//` comments are used because `///` docs here would be appended to
// the docs of mockall's re-export.)
#[proc_macro_attribute]
pub fn concretize(
    _attrs: proc_macro::TokenStream,
    input: proc_macro::TokenStream) -> proc_macro::TokenStream
{
    // Do nothing.  This "attribute" is processed as text by the real proc
    // macros.
    input
}
1366
// Function-like `mock!` macro: converts the compiler's token stream to a
// proc_macro2 stream, delegates to `do_mock`, and converts the result back.
// (Plain `//` comments are used because `///` docs here would be appended to
// the docs of mockall's re-export.)
#[proc_macro]
pub fn mock(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    do_mock(input.into()).into()
}
1371
1372#[proc_macro_attribute]
1373pub fn automock(attrs: proc_macro::TokenStream, input: proc_macro::TokenStream)
1374    -> proc_macro::TokenStream
1375{
1376    let attrs: proc_macro2::TokenStream = attrs.into();
1377    let input: proc_macro2::TokenStream = input.into();
1378    do_automock(attrs, input).into()
1379}
1380
1381fn do_automock_once(attrs: TokenStream, input: TokenStream) -> TokenStream {
1382    let mut output = input.clone();
1383    let attrs: Attrs = match parse2(attrs) {
1384        Ok(a) => a,
1385        Err(err) => {
1386            return err.to_compile_error();
1387        }
1388    };
1389    let item: Item = match parse2(input) {
1390        Ok(item) => item,
1391        Err(err) => {
1392            return err.to_compile_error();
1393        }
1394    };
1395    output.extend(mock_it((attrs, item)));
1396    output
1397}
1398
1399fn do_automock(attrs: TokenStream, input: TokenStream) -> TokenStream {
1400    cfg_if! {
1401        if #[cfg(reprocheck)] {
1402            let ts_a = do_automock_once(attrs.clone(), input.clone());
1403            let ts_b = do_automock_once(attrs.clone(), input.clone());
1404            assert_eq!(ts_a.to_string(), ts_b.to_string());
1405        }
1406    }
1407    do_automock_once(attrs, input)
1408}
1409
1410#[cfg(test)]
1411mod t {
1412    use super::*;
1413
1414fn assert_contains(output: &str, tokens: TokenStream) {
1415    let s = tokens.to_string();
1416    assert!(output.contains(&s), "output does not contain {:?}", &s);
1417}
1418
1419fn assert_not_contains(output: &str, tokens: TokenStream) {
1420    let s = tokens.to_string();
1421    assert!(!output.contains(&s), "output contains {:?}", &s);
1422}
1423
1424/// Various tests for overall code generation that are hard or impossible to
1425/// write as integration tests
1426mod mock {
1427    use std::str::FromStr;
1428    use super::super::*;
1429    use super::*;
1430
1431    #[test]
1432    fn inherent_method_visibility() {
1433        let code = "
1434            Foo {
1435                fn foo(&self);
1436                pub fn bar(&self);
1437                pub(crate) fn baz(&self);
1438                pub(super) fn bean(&self);
1439                pub(in crate::outer) fn boom(&self);
1440            }
1441        ";
1442        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1443        let output = do_mock(ts).to_string();
1444        assert_not_contains(&output, quote!(pub fn foo));
1445        assert!(!output.contains(") fn foo"));
1446        assert_contains(&output, quote!(pub fn bar));
1447        assert_contains(&output, quote!(pub(crate) fn baz));
1448        assert_contains(&output, quote!(pub(super) fn bean));
1449        assert_contains(&output, quote!(pub(in crate::outer) fn boom));
1450
1451        assert_not_contains(&output, quote!(pub fn expect_foo));
1452        assert!(!output.contains("pub fn expect_foo"));
1453        assert!(!output.contains(") fn expect_foo"));
1454        assert_contains(&output, quote!(pub fn expect_bar));
1455        assert_contains(&output, quote!(pub(crate) fn expect_baz));
1456        assert_contains(&output, quote!(pub(super) fn expect_bean));
1457        assert_contains(&output, quote!(pub(in crate::outer) fn expect_boom));
1458    }
1459
1460    #[test]
1461    fn must_use_struct() {
1462        let code = "
1463            #[must_use]
1464            pub Foo {}
1465        ";
1466        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1467        let output = do_mock(ts).to_string();
1468        assert_contains(&output, quote!(#[must_use] pub struct MockFoo));
1469    }
1470
1471    #[test]
1472    fn specific_impl() {
1473        let code = "
1474            pub Foo<T: 'static> {}
1475            impl Bar for Foo<u32> {
1476                fn bar(&self);
1477            }
1478            impl Bar for Foo<i32> {
1479                fn bar(&self);
1480            }
1481        ";
1482        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1483        let output = do_mock(ts).to_string();
1484        assert_contains(&output, quote!(impl Bar for MockFoo<u32>));
1485        assert_contains(&output, quote!(impl Bar for MockFoo<i32>));
1486        // Ensure we don't duplicate the checkpoint function
1487        assert_not_contains(&output, quote!(
1488            self.Bar_expectations.checkpoint();
1489            self.Bar_expectations.checkpoint();
1490        ));
1491        // The expect methods should return specific types, not generic ones
1492        assert_contains(&output, quote!(
1493            pub fn expect_bar(&mut self) -> &mut __mock_MockFoo_Bar::__bar::Expectation<u32>
1494        ));
1495        assert_contains(&output, quote!(
1496            pub fn expect_bar(&mut self) -> &mut __mock_MockFoo_Bar::__bar::Expectation<i32>
1497        ));
1498    }
1499}
1500
1501/// Various tests for overall code generation that are hard or impossible to
1502/// write as integration tests
1503mod automock {
1504    use std::str::FromStr;
1505    use super::super::*;
1506    use super::*;
1507
1508    #[test]
1509    fn doc_comments() {
1510        let code = "
1511            mod foo {
1512                /// Function docs
1513                pub fn bar() { unimplemented!() }
1514            }
1515        ";
1516        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1517        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1518        let output = do_automock(attrs_ts, ts).to_string();
1519        assert_contains(&output, quote!(#[doc=" Function docs"] pub fn bar));
1520    }
1521
1522    #[test]
1523    fn method_visibility() {
1524        let code = "
1525        impl Foo {
1526            fn foo(&self) {}
1527            pub fn bar(&self) {}
1528            pub(super) fn baz(&self) {}
1529            pub(crate) fn bang(&self) {}
1530            pub(in super::x) fn bean(&self) {}
1531        }";
1532        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1533        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1534        let output = do_automock(attrs_ts, ts).to_string();
1535        assert_not_contains(&output, quote!(pub fn foo));
1536        assert!(!output.contains(") fn foo"));
1537        assert_not_contains(&output, quote!(pub fn expect_foo));
1538        assert!(!output.contains(") fn expect_foo"));
1539        assert_contains(&output, quote!(pub fn bar));
1540        assert_contains(&output, quote!(pub fn expect_bar));
1541        assert_contains(&output, quote!(pub(super) fn baz));
1542        assert_contains(&output, quote!(pub(super) fn expect_baz));
1543        assert_contains(&output, quote!(pub ( crate ) fn bang));
1544        assert_contains(&output, quote!(pub ( crate ) fn expect_bang));
1545        assert_contains(&output, quote!(pub ( in super :: x ) fn bean));
1546        assert_contains(&output, quote!(pub ( in super :: x ) fn expect_bean));
1547    }
1548
1549    #[test]
1550    fn must_use_method() {
1551        let code = "
1552        impl Foo {
1553            #[must_use]
1554            fn foo(&self) -> i32 {42}
1555        }";
1556        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1557        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1558        let output = do_automock(attrs_ts, ts).to_string();
1559        assert_not_contains(&output, quote!(#[must_use] fn expect_foo));
1560        assert_contains(&output, quote!(#[must_use] #[allow(dead_code)] fn foo));
1561    }
1562
1563    #[test]
1564    fn must_use_static_method() {
1565        let code = "
1566        impl Foo {
1567            #[must_use]
1568            fn foo() -> i32 {42}
1569        }";
1570        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1571        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1572        let output = do_automock(attrs_ts, ts).to_string();
1573        assert_not_contains(&output, quote!(#[must_use] fn expect));
1574        assert_not_contains(&output, quote!(#[must_use] fn foo_context));
1575        assert_contains(&output, quote!(#[must_use] #[allow(dead_code)] fn foo));
1576    }
1577
1578    #[test]
1579    fn must_use_trait() {
1580        let code = "
1581        #[must_use]
1582        trait Foo {}
1583        ";
1584        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1585        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1586        let output = do_automock(attrs_ts, ts).to_string();
1587        assert_not_contains(&output, quote!(#[must_use] struct MockFoo));
1588    }
1589
1590    #[test]
1591    #[should_panic(expected = "automock does not currently support structs with elided lifetimes")]
1592    fn elided_lifetimes() {
1593        let code = "impl X<'_> {}";
1594        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1595        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1596        do_automock(attrs_ts, ts).to_string();
1597    }
1598
1599    #[test]
1600    #[should_panic(expected = "can only mock inline modules")]
1601    fn external_module() {
1602        let code = "mod foo;";
1603        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1604        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1605        do_automock(attrs_ts, ts).to_string();
1606    }
1607
1608    #[test]
1609    fn trait_visibility() {
1610        let code = "
1611        pub(super) trait Foo {}
1612        ";
1613        let attrs_ts = proc_macro2::TokenStream::from_str("").unwrap();
1614        let ts = proc_macro2::TokenStream::from_str(code).unwrap();
1615        let output = do_automock(attrs_ts, ts).to_string();
1616        assert_contains(&output, quote!(pub ( super ) struct MockFoo));
1617    }
1618}
1619
1620mod concretize_args {
1621    use super::*;
1622
1623    #[allow(clippy::needless_range_loop)] // Clippy's suggestion is worse
1624    fn check_concretize(
1625        sig: TokenStream,
1626        expected_inputs: &[TokenStream],
1627        expected_call_exprs: &[TokenStream],
1628        expected_sig_inputs: &[TokenStream])
1629    {
1630        let f: Signature = parse2(sig).unwrap();
1631        let (generics, inputs, call_exprs, altsig) = concretize_args(&f.generics, &f);
1632        assert!(generics.params.is_empty());
1633        assert_eq!(inputs.len(), expected_inputs.len());
1634        assert_eq!(call_exprs.len(), expected_call_exprs.len());
1635        for i in 0..inputs.len() {
1636            let actual = &inputs[i];
1637            let exp = &expected_inputs[i];
1638            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1639        }
1640        for i in 0..call_exprs.len() {
1641            let actual = &call_exprs[i];
1642            let exp = &expected_call_exprs[i];
1643            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1644        }
1645        for i in 0..altsig.inputs.len() {
1646            let actual = &altsig.inputs[i];
1647            let exp = &expected_sig_inputs[i];
1648            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1649        }
1650    }
1651
1652    #[test]
1653    fn bystanders() {
1654        check_concretize(
1655            quote!(fn foo<P: AsRef<Path>>(x: i32, p: P, y: &f64)),
1656            &[quote!(x: i32), quote!(p: &(dyn AsRef<Path>)), quote!(y: &f64)],
1657            &[quote!(x), quote!(&p), quote!(y)],
1658            &[quote!(x: i32), quote!(p: P), quote!(y: &f64)]
1659        );
1660    }
1661
1662    #[test]
1663    fn function_args() {
1664        check_concretize(
1665            quote!(fn foo<F1: Fn(u32) -> u32,
1666                          F2: FnMut(&mut u32) -> u32,
1667                          F3: FnOnce(u32) -> u32,
1668                          F4: Fn() + Send>(f1: F1, f2: F2, f3: F3, f4: F4)),
1669            &[quote!(f1: &(dyn Fn(u32) -> u32)),
1670              quote!(f2: &mut(dyn FnMut(&mut u32) -> u32)),
1671              quote!(f3: &(dyn FnOnce(u32) -> u32)),
1672              quote!(f4: &(dyn Fn() + Send))],
1673            &[quote!(&f1), quote!(&mut f2), quote!(&f3), quote!(&f4)],
1674            &[quote!(f1: F1), quote!(mut f2: F2), quote!(f3: F3), quote!(f4: F4)]
1675        );
1676    }
1677
1678    #[test]
1679    fn multi_bounds() {
1680        check_concretize(
1681            quote!(fn foo<P: AsRef<String> + AsMut<String>>(p: P)),
1682            &[quote!(p: &(dyn AsRef<String> + AsMut<String>))],
1683            &[quote!(&p)],
1684            &[quote!(p: P)],
1685        );
1686    }
1687
1688    #[test]
1689    fn mutable_reference_arg() {
1690        check_concretize(
1691            quote!(fn foo<P: AsMut<Path>>(p: &mut P)),
1692            &[quote!(p: &mut (dyn AsMut<Path>))],
1693            &[quote!(p)],
1694            &[quote!(p: &mut P)],
1695        );
1696    }
1697
1698    #[test]
1699    fn mutable_reference_multi_bounds() {
1700        check_concretize(
1701            quote!(fn foo<P: AsRef<String> + AsMut<String>>(p: &mut P)),
1702            &[quote!(p: &mut (dyn AsRef<String> + AsMut<String>))],
1703            &[quote!(p)],
1704            &[quote!(p: &mut P)]
1705        );
1706    }
1707
1708    #[test]
1709    fn reference_arg() {
1710        check_concretize(
1711            quote!(fn foo<P: AsRef<Path>>(p: &P)),
1712            &[quote!(p: &(dyn AsRef<Path>))],
1713            &[quote!(p)],
1714            &[quote!(p: &P)]
1715        );
1716    }
1717
1718    #[test]
1719    fn simple() {
1720        check_concretize(
1721            quote!(fn foo<P: AsRef<Path>>(p: P)),
1722            &[quote!(p: &(dyn AsRef<Path>))],
1723            &[quote!(&p)],
1724            &[quote!(p: P)],
1725        );
1726    }
1727
1728    #[test]
1729    fn slice() {
1730        check_concretize(
1731            quote!(fn foo<P: AsRef<Path>>(p: &[P])),
1732            &[quote!(p: &[&(dyn AsRef<Path>)])],
1733            &[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path>)).collect::<Vec<_>>())],
1734            &[quote!(p: &[P])]
1735        );
1736    }
1737
1738    #[test]
1739    fn slice_with_multi_bounds() {
1740        check_concretize(
1741            quote!(fn foo<P: AsRef<Path> + AsMut<String>>(p: &[P])),
1742            &[quote!(p: &[&(dyn AsRef<Path> + AsMut<String>)])],
1743            &[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path> + AsMut<String>)).collect::<Vec<_>>())],
1744            &[quote!(p: &[P])]
1745        );
1746    }
1747
1748    #[test]
1749    fn where_clause() {
1750        check_concretize(
1751            quote!(fn foo<P>(p: P) where P: AsRef<Path>),
1752            &[quote!(p: &(dyn AsRef<Path>))],
1753            &[quote!(&p)],
1754            &[quote!(p: P)]
1755        );
1756    }
1757}
1758
1759mod declosurefy {
1760    use super::*;
1761
1762    fn check_declosurefy(
1763        sig: TokenStream,
1764        expected_inputs: &[TokenStream],
1765        expected_call_exprs: &[TokenStream])
1766    {
1767        let f: Signature = parse2(sig).unwrap();
1768        let (generics, inputs, call_exprs) =
1769            declosurefy(&f.generics, &f.inputs);
1770        assert!(generics.params.is_empty());
1771        assert_eq!(inputs.len(), expected_inputs.len());
1772        assert_eq!(call_exprs.len(), expected_call_exprs.len());
1773        for i in 0..inputs.len() {
1774            let actual = &inputs[i];
1775            let exp = &expected_inputs[i];
1776            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1777        }
1778        for i in 0..call_exprs.len() {
1779            let actual = &call_exprs[i];
1780            let exp = &expected_call_exprs[i];
1781            assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
1782        }
1783    }
1784
1785    #[test]
1786    fn bounds() {
1787        check_declosurefy(
1788            quote!(fn foo<F: Fn(u32) -> u32 + Send>(f: F)),
1789            &[quote!(f: Box<dyn Fn(u32) -> u32 + Send>)],
1790            &[quote!(Box::new(f))]
1791        );
1792    }
1793
1794    #[test]
1795    fn r#fn() {
1796        check_declosurefy(
1797            quote!(fn foo<F: Fn(u32) -> u32>(f: F)),
1798            &[quote!(f: Box<dyn Fn(u32) -> u32>)],
1799            &[quote!(Box::new(f))]
1800        );
1801    }
1802
1803    #[test]
1804    fn fn_mut() {
1805        check_declosurefy(
1806            quote!(fn foo<F: FnMut(u32) -> u32>(f: F)),
1807            &[quote!(f: Box<dyn FnMut(u32) -> u32>)],
1808            &[quote!(Box::new(f))]
1809        );
1810    }
1811
1812    #[test]
1813    fn fn_once() {
1814        check_declosurefy(
1815            quote!(fn foo<F: FnOnce(u32) -> u32>(f: F)),
1816            &[quote!(f: Box<dyn FnOnce(u32) -> u32>)],
1817            &[quote!(Box::new(f))]
1818        );
1819    }
1820
1821    #[test]
1822    fn mutable_pattern() {
1823        check_declosurefy(
1824            quote!(fn foo<F: FnMut(u32) -> u32>(mut f: F)),
1825            &[quote!(f: Box<dyn FnMut(u32) -> u32>)],
1826            &[quote!(Box::new(f))]
1827        );
1828    }
1829
1830    #[test]
1831    fn where_clause() {
1832        check_declosurefy(
1833            quote!(fn foo<F>(f: F) where F: Fn(u32) -> u32),
1834            &[quote!(f: Box<dyn Fn(u32) -> u32>)],
1835            &[quote!(Box::new(f))]
1836        );
1837    }
1838
1839    #[test]
1840    fn where_clause_with_bounds() {
1841        check_declosurefy(
1842            quote!(fn foo<F>(f: F) where F: Fn(u32) -> u32 + Send),
1843            &[quote!(f: Box<dyn Fn(u32) -> u32 + Send>)],
1844            &[quote!(Box::new(f))]
1845        );
1846    }
1847}
1848
1849mod fix_elipses {
1850    use super::*;
1851
1852    fn check_fix_elipses(orig_ts: TokenStream, expected_ts: TokenStream) {
1853        let mut orig: Signature = parse2(orig_ts).unwrap();
1854        let expected: Signature = parse2(expected_ts).unwrap();
1855        fix_elipses(&mut orig);
1856        assert_eq!(quote!(#orig).to_string(), quote!(#expected).to_string());
1857    }
1858
1859    #[test]
1860    fn nopat() {
1861        check_fix_elipses(
1862            quote!(fn foo(x: i32, ...) -> i32),
1863            quote!(fn foo(x: i32, _: ...) -> i32)
1864        );
1865    }
1866
1867    #[test]
1868    fn pat() {
1869        check_fix_elipses(
1870            quote!(fn foo(x: i32, y: ...) -> i32),
1871            quote!(fn foo(x: i32, y: ...) -> i32)
1872        );
1873    }
1874}
1875
1876mod deimplify {
1877    use super::*;
1878
1879    fn check_deimplify(orig_ts: TokenStream, expected_ts: TokenStream) {
1880        let mut orig: ReturnType = parse2(orig_ts).unwrap();
1881        let expected: ReturnType = parse2(expected_ts).unwrap();
1882        deimplify(&mut orig);
1883        assert_eq!(quote!(#orig).to_string(), quote!(#expected).to_string());
1884    }
1885
1886    // Future is a special case
1887    #[test]
1888    fn impl_future() {
1889        check_deimplify(
1890            quote!(-> impl Future<Output=i32>),
1891            quote!(-> ::std::pin::Pin<Box<dyn Future<Output=i32>>>)
1892        );
1893    }
1894
1895    // Future is a special case, wherever it appears
1896    #[test]
1897    fn impl_future_reverse() {
1898        check_deimplify(
1899            quote!(-> impl Send + Future<Output=i32>),
1900            quote!(-> ::std::pin::Pin<Box<dyn Send + Future<Output=i32>>>)
1901        );
1902    }
1903
1904    // Stream is a special case
1905    #[test]
1906    fn impl_stream() {
1907        check_deimplify(
1908            quote!(-> impl Stream<Item=i32>),
1909            quote!(-> ::std::pin::Pin<Box<dyn Stream<Item=i32>>>)
1910        );
1911    }
1912
1913    #[test]
1914    fn impl_trait() {
1915        check_deimplify(
1916            quote!(-> impl Foo),
1917            quote!(-> Box<dyn Foo>)
1918        );
1919    }
1920
1921    // With extra bounds
1922    #[test]
1923    fn impl_trait2() {
1924        check_deimplify(
1925            quote!(-> impl Foo + Send),
1926            quote!(-> Box<dyn Foo + Send>)
1927        );
1928    }
1929}
1930
1931mod deselfify {
1932    use super::*;
1933
1934    fn check_deselfify(
1935        orig_ts: TokenStream,
1936        actual_ts: TokenStream,
1937        generics_ts: TokenStream,
1938        expected_ts: TokenStream)
1939    {
1940        let mut ty: Type = parse2(orig_ts).unwrap();
1941        let actual: Ident = parse2(actual_ts).unwrap();
1942        let generics: Generics = parse2(generics_ts).unwrap();
1943        let expected: Type = parse2(expected_ts).unwrap();
1944        deselfify(&mut ty, &actual, &generics);
1945        assert_eq!(quote!(#ty).to_string(),
1946                   quote!(#expected).to_string());
1947    }
1948
1949    #[test]
1950    fn arc() {
1951        check_deselfify(
1952            quote!(Arc<Self>),
1953            quote!(Foo),
1954            quote!(),
1955            quote!(Arc<Foo>)
1956        );
1957    }
1958    #[test]
1959    fn future() {
1960        check_deselfify(
1961            quote!(Box<dyn Future<Output=Self>>),
1962            quote!(Foo),
1963            quote!(),
1964            quote!(Box<dyn Future<Output=Foo>>)
1965        );
1966    }
1967
1968    #[test]
1969    fn qself() {
1970        check_deselfify(
1971            quote!(<Self as Self>::Self),
1972            quote!(Foo),
1973            quote!(),
1974            quote!(<Foo as Foo>::Foo)
1975        );
1976    }
1977
1978    #[test]
1979    fn trait_object() {
1980        check_deselfify(
1981            quote!(Box<dyn Self>),
1982            quote!(Foo),
1983            quote!(),
1984            quote!(Box<dyn Foo>)
1985        );
1986    }
1987
1988    // A trait object with multiple bounds
1989    #[test]
1990    fn trait_object2() {
1991        check_deselfify(
1992            quote!(Box<dyn Self + Send>),
1993            quote!(Foo),
1994            quote!(),
1995            quote!(Box<dyn Foo + Send>)
1996        );
1997    }
1998}
1999
2000mod dewhereselfify {
2001    use super::*;
2002
2003    #[test]
2004    fn lifetime() {
2005        let mut meth: ImplItemFn = parse2(quote!(
2006                fn foo<'a>(&self) where 'a: 'static, Self: Sized {}
2007        )).unwrap();
2008        let expected: ImplItemFn = parse2(quote!(
2009                fn foo<'a>(&self) where 'a: 'static {}
2010        )).unwrap();
2011        dewhereselfify(&mut meth.sig.generics);
2012        assert_eq!(meth, expected);
2013    }
2014
2015    #[test]
2016    fn normal_method() {
2017        let mut meth: ImplItemFn = parse2(quote!(
2018                fn foo(&self) where Self: Sized {}
2019        )).unwrap();
2020        let expected: ImplItemFn = parse2(quote!(
2021                fn foo(&self) {}
2022        )).unwrap();
2023        dewhereselfify(&mut meth.sig.generics);
2024        assert_eq!(meth, expected);
2025    }
2026
2027    #[test]
2028    fn with_real_generics() {
2029        let mut meth: ImplItemFn = parse2(quote!(
2030                fn foo<T>(&self, t: T) where Self: Sized, T: Copy {}
2031        )).unwrap();
2032        let expected: ImplItemFn = parse2(quote!(
2033                fn foo<T>(&self, t: T) where T: Copy {}
2034        )).unwrap();
2035        dewhereselfify(&mut meth.sig.generics);
2036        assert_eq!(meth, expected);
2037    }
2038}
2039
2040mod gen_keyid {
2041    use super::*;
2042
2043    fn check_gen_keyid(orig: TokenStream, expected: TokenStream) {
2044        let g: Generics = parse2(orig).unwrap();
2045        let keyid = gen_keyid(&g);
2046        assert_eq!(quote!(#keyid).to_string(), quote!(#expected).to_string());
2047    }
2048
2049    #[test]
2050    fn empty() {
2051        check_gen_keyid(quote!(), quote!(<()>));
2052    }
2053
2054    #[test]
2055    fn onetype() {
2056        check_gen_keyid(quote!(<T>), quote!(<T>));
2057    }
2058
2059    #[test]
2060    fn twotypes() {
2061        check_gen_keyid(quote!(<T, V>), quote!(<(T, V)>));
2062    }
2063}
2064
2065mod merge_generics {
2066    use super::*;
2067
2068    #[test]
2069    fn both() {
2070        let mut g1: Generics = parse2(quote!(<T: 'static, V: Copy> )).unwrap();
2071        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2072        g1.where_clause = Some(wc1);
2073
2074        let mut g2: Generics = parse2(quote!(<Q: Send, V: Clone>)).unwrap();
2075        let wc2: WhereClause = parse2(quote!(where T: Sync, Q: Debug)).unwrap();
2076        g2.where_clause = Some(wc2);
2077
2078        let gm = super::merge_generics(&g1, &g2);
2079        let gm_wc = &gm.where_clause;
2080
2081        let ge: Generics = parse2(quote!(
2082                <T: 'static, V: Copy + Clone, Q: Send>
2083        )).unwrap();
2084        let wce: WhereClause = parse2(quote!(
2085            where T: Default + Sync, Q: Debug
2086        )).unwrap();
2087
2088        assert_eq!(quote!(#ge #wce).to_string(),
2089                   quote!(#gm #gm_wc).to_string());
2090    }
2091
2092    #[test]
2093    fn eq() {
2094        let mut g1: Generics = parse2(quote!(<T: 'static, V: Copy> )).unwrap();
2095        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2096        g1.where_clause = Some(wc1.clone());
2097
2098        let gm = super::merge_generics(&g1, &g1);
2099        let gm_wc = &gm.where_clause;
2100
2101        assert_eq!(quote!(#g1 #wc1).to_string(),
2102                   quote!(#gm #gm_wc).to_string());
2103    }
2104
2105    #[test]
2106    fn lhs_only() {
2107        let mut g1: Generics = parse2(quote!(<T: 'static, V: Copy> )).unwrap();
2108        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2109        g1.where_clause = Some(wc1.clone());
2110
2111        let g2 = Generics::default();
2112
2113        let gm = super::merge_generics(&g1, &g2);
2114        let gm_wc = &gm.where_clause;
2115
2116        assert_eq!(quote!(#g1 #wc1).to_string(),
2117                   quote!(#gm #gm_wc).to_string());
2118    }
2119
2120    #[test]
2121    fn lhs_wc_only() {
2122        let mut g1 = Generics::default();
2123        let wc1: WhereClause = parse2(quote!(where T: Default)).unwrap();
2124        g1.where_clause = Some(wc1.clone());
2125
2126        let g2 = Generics::default();
2127
2128        let gm = super::merge_generics(&g1, &g2);
2129        let gm_wc = &gm.where_clause;
2130
2131        assert_eq!(quote!(#g1 #wc1).to_string(),
2132                   quote!(#gm #gm_wc).to_string());
2133    }
2134
2135    #[test]
2136    fn rhs_only() {
2137        let g1 = Generics::default();
2138        let mut g2: Generics = parse2(quote!(<Q: Send, V: Clone>)).unwrap();
2139        let wc2: WhereClause = parse2(quote!(where T: Sync, Q: Debug)).unwrap();
2140        g2.where_clause = Some(wc2.clone());
2141
2142        let gm = super::merge_generics(&g1, &g2);
2143        let gm_wc = &gm.where_clause;
2144
2145        assert_eq!(quote!(#g2 #wc2).to_string(),
2146                   quote!(#gm #gm_wc).to_string());
2147    }
2148}
2149
2150mod supersuperfy {
2151    use super::*;
2152
2153    fn check_supersuperfy(orig: TokenStream, expected: TokenStream) {
2154        let orig_ty: Type = parse2(orig).unwrap();
2155        let expected_ty: Type = parse2(expected).unwrap();
2156        let output = supersuperfy(&orig_ty, 1);
2157        assert_eq!(quote!(#output).to_string(),
2158                   quote!(#expected_ty).to_string());
2159    }
2160
2161    #[test]
2162    fn array() {
2163        check_supersuperfy(
2164            quote!([super::X; n]),
2165            quote!([super::super::X; n])
2166        );
2167    }
2168
2169    #[test]
2170    fn barefn() {
2171        check_supersuperfy(
2172            quote!(fn(super::A) -> super::B),
2173            quote!(fn(super::super::A) -> super::super::B)
2174        );
2175    }
2176
2177    #[test]
2178    fn group() {
2179        let orig = TypeGroup {
2180            group_token: token::Group::default(),
2181            elem: Box::new(parse2(quote!(super::T)).unwrap())
2182        };
2183        let expected = TypeGroup {
2184            group_token: token::Group::default(),
2185            elem: Box::new(parse2(quote!(super::super::T)).unwrap())
2186        };
2187        let output = supersuperfy(&Type::Group(orig), 1);
2188        assert_eq!(quote!(#output).to_string(),
2189                   quote!(#expected).to_string());
2190    }
2191
2192    // Just check that it doesn't panic
2193    #[test]
2194    fn infer() {
2195        check_supersuperfy( quote!(_), quote!(_));
2196    }
2197
2198    // Just check that it doesn't panic
2199    #[test]
2200    fn never() {
2201        check_supersuperfy( quote!(!), quote!(!));
2202    }
2203
2204    #[test]
2205    fn paren() {
2206        check_supersuperfy(
2207            quote!((super::X)),
2208            quote!((super::super::X))
2209        );
2210    }
2211
2212    #[test]
2213    fn path() {
2214        check_supersuperfy(
2215            quote!(::super::SuperT<u32>),
2216            quote!(::super::super::SuperT<u32>)
2217        );
2218    }
2219
2220    #[test]
2221    fn path_with_qself() {
2222        check_supersuperfy(
2223            quote!(<super::X as super::Y>::Foo<u32>),
2224            quote!(<super::super::X as super::super::Y>::Foo<u32>),
2225        );
2226    }
2227
2228    #[test]
2229    fn angle_bracketed_generic_arguments() {
2230        check_supersuperfy(
2231            quote!(mod_::T<super::X>),
2232            quote!(mod_::T<super::super::X>)
2233        );
2234    }
2235
2236    #[test]
2237    fn ptr() {
2238        check_supersuperfy(
2239            quote!(*const super::X),
2240            quote!(*const super::super::X)
2241        );
2242    }
2243
2244    #[test]
2245    fn reference() {
2246        check_supersuperfy(
2247            quote!(&'a mut super::X),
2248            quote!(&'a mut super::super::X)
2249        );
2250    }
2251
2252    #[test]
2253    fn slice() {
2254        check_supersuperfy(
2255            quote!([super::X]),
2256            quote!([super::super::X])
2257        );
2258    }
2259
2260    #[test]
2261    fn trait_object() {
2262        check_supersuperfy(
2263            quote!(dyn super::X + super::Y),
2264            quote!(dyn super::super::X + super::super::Y)
2265        );
2266    }
2267
2268    #[test]
2269    fn tuple() {
2270        check_supersuperfy(
2271            quote!((super::A, super::B)),
2272            quote!((super::super::A, super::super::B))
2273        );
2274    }
2275}
2276
2277mod supersuperfy_generics {
2278    use super::*;
2279
2280    fn check_supersuperfy_generics(
2281        orig: TokenStream,
2282        orig_wc: TokenStream,
2283        expected: TokenStream,
2284        expected_wc: TokenStream)
2285    {
2286        let mut orig_g: Generics = parse2(orig).unwrap();
2287        orig_g.where_clause = parse2(orig_wc).unwrap();
2288        let mut expected_g: Generics = parse2(expected).unwrap();
2289        expected_g.where_clause = parse2(expected_wc).unwrap();
2290        let mut output: Generics = orig_g;
2291        supersuperfy_generics(&mut output, 1);
2292        let (o_ig, o_tg, o_wc) = output.split_for_impl();
2293        let (e_ig, e_tg, e_wc) = expected_g.split_for_impl();
2294        assert_eq!(quote!(#o_ig).to_string(), quote!(#e_ig).to_string());
2295        assert_eq!(quote!(#o_tg).to_string(), quote!(#e_tg).to_string());
2296        assert_eq!(quote!(#o_wc).to_string(), quote!(#e_wc).to_string());
2297    }
2298
2299    #[test]
2300    fn default() {
2301        check_supersuperfy_generics(
2302            quote!(<T: X = super::Y>), quote!(),
2303            quote!(<T: X = super::super::Y>), quote!(),
2304        );
2305    }
2306
2307    #[test]
2308    fn empty() {
2309        check_supersuperfy_generics(quote!(), quote!(), quote!(), quote!());
2310    }
2311
2312    #[test]
2313    fn everything() {
2314        check_supersuperfy_generics(
2315            quote!(<T: super::A = super::B>),
2316            quote!(where super::C: super::D),
2317            quote!(<T: super::super::A = super::super::B>),
2318            quote!(where super::super::C: super::super::D),
2319        );
2320    }
2321
2322    #[test]
2323    fn bound() {
2324        check_supersuperfy_generics(
2325            quote!(<T: super::A>), quote!(),
2326            quote!(<T: super::super::A>), quote!(),
2327        );
2328    }
2329
2330    #[test]
2331    fn closure() {
2332        check_supersuperfy_generics(
2333            quote!(<F: Fn(u32) -> super::SuperT>), quote!(),
2334            quote!(<F: Fn(u32) -> super::super::SuperT>), quote!(),
2335        );
2336    }
2337
2338    #[test]
2339    fn wc_bounded_ty() {
2340        check_supersuperfy_generics(
2341            quote!(), quote!(where super::T: X),
2342            quote!(), quote!(where super::super::T: X),
2343        );
2344    }
2345
2346    #[test]
2347    fn wc_bounds() {
2348        check_supersuperfy_generics(
2349            quote!(), quote!(where T: super::X),
2350            quote!(), quote!(where T: super::super::X),
2351        );
2352    }
2353}
2354}