1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use proc_macro2::{Ident, Span, TokenStream};
use synstructure::*;

// Registers `arbitrary_derive` (below) as the implementation of `#[derive(Arbitrary)]`.
decl_derive!([Arbitrary] => arbitrary_derive);

/// Generates the `arbitrary` and `arbitrary_take_rest` methods.
///
/// With exactly one variant (a struct, or an enum with a single variant) the
/// value is built straight from its fields and no discriminant is read.
/// Otherwise a `u32` is read from the input and mapped onto
/// `0..variants.len()` to select which variant to build. Because
/// `(x * n) >> 32 < n` for any `x < 2^32`, the generated `_ => unreachable!()`
/// arm can only be hit when there are no variants at all.
fn gen_arbitrary_method(variants: &[VariantInfo]) -> TokenStream {
    if let [variant] = variants {
        // struct (or single-variant enum)
        let con = variant.construct(|_, _| quote! { Arbitrary::arbitrary(u)? });
        let con_rest = construct_take_rest(variant);
        return quote! {
            fn arbitrary(u: &mut Unstructured<'_>) -> Result<Self> {
                Ok(#con)
            }

            fn arbitrary_take_rest(mut u: Unstructured<'_>) -> Result<Self> {
                Ok(#con_rest)
            }
        };
    }

    // enum: one match arm per variant, keyed by the variant's index.
    let mut variant_tokens = TokenStream::new();
    let mut variant_tokens_take_rest = TokenStream::new();

    for (index, variant) in variants.iter().enumerate() {
        // `usize -> u64` is lossless on every target Rust supports.
        let index = index as u64;
        let constructor = variant.construct(|_, _| quote! { Arbitrary::arbitrary(u)? });
        variant_tokens.extend(quote! { #index => #constructor, });

        let constructor_take_rest = construct_take_rest(variant);
        variant_tokens_take_rest.extend(quote! { #index => #constructor_take_rest, });
    }

    let count = variants.len() as u64;

    quote! {
        fn arbitrary(u: &mut Unstructured<'_>) -> Result<Self> {
            // Use a multiply + shift to generate a ranged random number
            // with slight bias. For details, see:
            // https://lemire.me/blog/2016/06/30/fast-random-shuffling
            Ok(match (u64::from(<u32 as Arbitrary>::arbitrary(u)?) * #count) >> 32 {
                #variant_tokens
                _ => unreachable!()
            })
        }

        fn arbitrary_take_rest(mut u: Unstructured<'_>) -> Result<Self> {
            // Use a multiply + shift to generate a ranged random number
            // with slight bias. For details, see:
            // https://lemire.me/blog/2016/06/30/fast-random-shuffling
            Ok(match (u64::from(<u32 as Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
                #variant_tokens_take_rest
                _ => unreachable!()
            })
        }
    }
}

/// Builds the constructor expression used by the generated `arbitrary_take_rest`.
///
/// Every field except the last is drawn with `Arbitrary::arbitrary(&mut u)`.
/// The last field receives the remaining input by value through
/// `Arbitrary::arbitrary_take_rest(u)`.
fn construct_take_rest(v: &VariantInfo) -> TokenStream {
    let len = v.bindings().len();
    v.construct(|_, idx| {
        // Compare with `idx + 1 == len` instead of `idx == len - 1` so the
        // check never relies on `len > 0` (the subtraction would underflow).
        if idx + 1 == len {
            quote! { Arbitrary::arbitrary_take_rest(u)? }
        } else {
            quote! { Arbitrary::arbitrary(&mut u)? }
        }
    })
}

/// Generates `size_hint`, the bounds on how many input bytes `arbitrary` consumes.
///
/// Each variant's hint is the `and` of its field types' hints, and the
/// variants are combined with `or`. When there is not exactly one variant, the
/// generated `arbitrary` first reads a `u32` discriminant to pick the variant.
/// That `u32`'s hint is therefore `and`-ed in front so the bounds include the
/// discriminant bytes. Previously an enum of unit variants reported
/// `(0, Some(0))` despite consuming a `u32`.
fn gen_size_hint_method(s: &Structure) -> TokenStream {
    let sizes = s.variants().iter().map(|v| {
        let tys = v.ast().fields.iter().map(|f| &f.ty);
        quote! {
            arbitrary::size_hint::and_all(&[
                #( <#tys as Arbitrary>::size_hint() ),*
            ])
        }
    });
    let variants_hint = quote! {
        arbitrary::size_hint::or_all(&[ #( #sizes ),* ])
    };

    // Mirror `gen_arbitrary_method`: only a lone variant skips the discriminant.
    let body = if s.variants().len() == 1 {
        variants_hint
    } else {
        quote! {
            arbitrary::size_hint::and(
                <u32 as Arbitrary>::size_hint(),
                #variants_hint,
            )
        }
    };

    quote! {
        fn size_hint() -> (usize, Option<usize>) {
            #body
        }
    }
}

fn gen_shrink_method(s: &Structure) -> TokenStream {
    let variants = s.each_variant(|v| {
        if v.bindings().is_empty() {
            return quote! {
                Box::new(None.into_iter())
            };
        }

        let mut value_idents = vec![];
        let mut shrinker_idents = vec![];
        let mut shrinker_exprs = vec![];
        for (i, b) in v.bindings().iter().enumerate() {
            value_idents.push(Ident::new(&format!("value{}", i), Span::call_site()));
            shrinker_idents.push(Ident::new(&format!("shrinker{}", i), Span::call_site()));
            shrinker_exprs.push(quote! { Arbitrary::shrink(#b) });
        }
        let cons = v.construct(|_, i| &value_idents[i]);
        let shrinker_idents = &shrinker_idents;
        quote! {
            #( let mut #shrinker_idents = #shrinker_exprs; )*
            Box::new(std::iter::from_fn(move || {
                #( let #value_idents = #shrinker_idents.next()?; )*
                Some(#cons)
            }))
        }
    });

    quote! {
        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            match self {
                #variants
            }
        }
    }
}

/// Entry point for `#[derive(Arbitrary)]`.
///
/// Generates the `arbitrary`/`arbitrary_take_rest`, `size_hint` and `shrink`
/// methods and wraps them in an `impl Arbitrary for ...` via `Structure::gen_impl`.
fn arbitrary_derive(s: Structure) -> TokenStream {
    let (arbitrary_fns, size_hint_fn, shrink_fn) = (
        gen_arbitrary_method(s.variants()),
        gen_size_hint_method(&s),
        gen_shrink_method(&s),
    );

    let impl_body = quote! {
        use arbitrary::{Arbitrary, Unstructured, Result};
        gen impl Arbitrary for @Self {
            #arbitrary_fns
            #size_hint_fn
            #shrink_fn
        }
    };
    s.gen_impl(impl_body)
}