quickcheck_arbitrary_derive/
lib.rs

1use proc_macro::{self};
2use proc_macro2::{Span, TokenStream};
3use quote::{ToTokens, format_ident, quote};
4use syn::{
5    DataEnum, DeriveInput, Field, FieldsNamed, FieldsUnnamed, Ident, LitInt, Type,
6    parse_macro_input,
7};
8
9fn generate_product_shrink<
10    Iter: IntoIterator<Item = Field> + Clone,
11    IdentKind: Clone + ToTokens + ToString,
12>(
13    fields: &Iter,
14    constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
15    make_ident: impl Fn(&str) -> IdentKind,
16    self_helper: impl Fn(Ident, &IdentKind, usize) -> TokenStream,
17) -> TokenStream {
18    let self_copies = fields
19        .clone()
20        .into_iter()
21        .enumerate()
22        .map(|(idx, field)| {
23            let ident = field
24                .ident
25                .clone()
26                .map(|ident| ident.to_string())
27                .unwrap_or(idx.to_string());
28            let unique_self = format_ident!("self_{}", ident);
29            quote! {
30                let #unique_self = <Self as ::std::clone::Clone>::clone(&self);
31            }
32        })
33        .collect::<Vec<_>>();
34
35    let cloning_iterator_madness: TokenStream = fields
36        .clone()
37        .into_iter()
38        .enumerate()
39        .map(|(idx, field)| {
40            let ident = make_ident(
41                &field
42                    .ident
43                    .clone()
44                    .map(|ident| ident.to_string())
45                    .unwrap_or(idx.to_string()),
46            );
47            let other_idents = fields
48                .clone()
49                .into_iter()
50                .enumerate()
51                .map(|(idx, field)| {
52                    make_ident(
53                        &field
54                            .ident
55                            .clone()
56                            .map(|ident| ident.to_string())
57                            .unwrap_or(idx.to_string()),
58                    )
59                })
60                .filter(|e| e.to_string() != ident.to_string())
61                .map(|field_ident| {
62                    let unique_self_toks = self_helper(
63                        format_ident!("self_{}", ident.to_string()),
64                        &field_ident,
65                        fields.clone().into_iter().collect::<Vec<_>>().len(),
66                    );
67                    (
68                        field_ident.clone(),
69                        quote! {::core::clone::Clone::clone(#unique_self_toks)},
70                    )
71                })
72                .collect::<Vec<_>>();
73            constructor(&field.ty, &ident, &other_idents)
74        })
75        .collect::<Vec<_>>()
76        .iter()
77        .rev()
78        .cloned()
79        .reduce(|a, b| quote! {::std::iter::Iterator::chain(#a, #b)})
80        .unwrap_or(quote! {});
81
82    quote! {
83        #(#self_copies)*
84        ::std::boxed::Box::new(#cloning_iterator_madness)
85    }
86}
87fn generate_product_shrink_simple<
88    Iter: IntoIterator<Item = Field> + Clone,
89    IdentKind: Clone + ToTokens + ToString,
90>(
91    fields: &Iter,
92    constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
93    make_ident: impl Fn(&str) -> IdentKind,
94) -> TokenStream {
95    generate_product_shrink(
96        fields,
97        constructor,
98        make_ident,
99        |unique_self, field_ident, _| quote! {&#unique_self.#field_ident},
100    )
101}
102
103fn make_enum_puller(pull: usize, others: usize, variant: &Ident, source: &Ident) -> TokenStream {
104    let v_puller = [quote! {__quickcheck_derive_match_puller}];
105    let pull_pattern = (0..(pull))
106        .map(|_| quote! {_})
107        .chain(v_puller.iter().cloned())
108        .chain((pull..others).map(|_| quote! {_}));
109
110    quote! {if let Self::#variant(#(#pull_pattern),*) = &#source {
111        __quickcheck_derive_match_puller
112    } else {
113        ::core::unreachable!()
114    }}
115}
116
117struct ArbitraryImpl {
118    arbitrary: TokenStream,
119    shrink: TokenStream,
120}
121
122fn make_named_struct_arbitrary(fields_named: &FieldsNamed) -> ArbitraryImpl {
123    let field_arbitrary_generators = fields_named
124        .named
125        .iter()
126        .map(|field| {
127            let identifier = &field.ident;
128            let ty = &field.ty;
129            quote! {
130                #identifier: <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
131            }
132        })
133        .collect::<Vec<_>>();
134    ArbitraryImpl {
135        shrink: generate_product_shrink_simple(
136            &fields_named.named,
137            |ty, ident, other_idents| {
138                let other_idents_initialisers = other_idents
139                    .iter()
140                    .map(|(ident, toks)| {
141                        quote! {#ident: #toks}
142                    })
143                    .collect::<Vec<_>>();
144                quote! {
145                    ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
146                        move |__quickcheck_derive_moving| Self {#ident: __quickcheck_derive_moving, #(#other_idents_initialisers),*})
147                }
148            },
149            |ident_str| Ident::new(ident_str, Span::call_site()),
150        ),
151        arbitrary: quote! {
152            Self {
153                #(#field_arbitrary_generators),*
154            }
155        },
156    }
157}
158
159fn make_unnamed_struct_arbitrary(fields_unnamed: &FieldsUnnamed) -> ArbitraryImpl {
160    let field_arbitrary_generators = fields_unnamed
161        .unnamed
162        .iter()
163        .map(|field| {
164            let ty = &field.ty;
165            quote! {
166                <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
167            }
168        })
169        .collect::<Vec<_>>();
170    ArbitraryImpl {
171        arbitrary: quote! {
172            Self(#(#field_arbitrary_generators),*)
173        },
174        shrink: generate_product_shrink_simple::<_, LitInt>(
175            &fields_unnamed.unnamed,
176            |ty, ident, other_idents| {
177                let mut idents_all = other_idents.clone();
178                idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
179                idents_all.sort_by(|(a, _), (b, _)| {
180                    a.base10_parse::<u64>()
181                        .unwrap()
182                        .cmp(&b.base10_parse().unwrap())
183                });
184                let initialiser_list = idents_all
185                    .iter()
186                    .map(|(_, stream)| stream)
187                    .collect::<Vec<_>>();
188
189                quote! {
190                    ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
191                        move |__quickcheck_derive_moving| Self(#(#initialiser_list),*))
192                }
193            },
194            |ident_str| LitInt::new(ident_str, Span::call_site()),
195        ),
196    }
197}
198
199fn make_enum_arbitrary(ident: &Ident, data_enum: &DataEnum) -> ArbitraryImpl {
200    let num_variants = data_enum.variants.len();
201    let initialisers = data_enum
202        .variants
203        .iter()
204        .map(|variant| {
205            (
206                &variant.ident,
207                match variant.fields.len() {
208                    0 => quote! {},
209                    _ => {
210                        let field_arbitrary_generators = variant
211                            .fields
212                            .iter()
213                            .map(|field| {
214                                let ty = &field.ty;
215                                quote! {<#ty as ::quickcheck::Arbitrary>::arbitrary(g)}
216                            })
217                            .collect::<Vec<_>>();
218                        quote! {(#(#field_arbitrary_generators),*)}
219                    }
220                },
221            )
222        })
223        .map(|initialiser| {
224            let ident = initialiser.0;
225            let initialiser_list = initialiser.1;
226            quote! {Self::#ident #initialiser_list}
227        })
228        .enumerate()
229        .map(|(index, constructor)| {
230            quote! {#index => #constructor}
231        })
232        .collect::<Vec<_>>();
233
234    let enum_name = &ident;
235    let arm_matchers = data_enum
236        .variants
237        .iter()
238        .map(|variant| {
239            let variant_ident = &variant.ident;
240            let shrinker = generate_product_shrink::<_, LitInt>(
241                &variant.fields,
242                |ty, ident, other_idents| {
243                    let mut idents_all = other_idents.clone();
244                    idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
245                    idents_all.sort_by(|(a, _), (b, _)| {
246                        a.base10_parse::<u64>()
247                            .unwrap()
248                            .cmp(&b.base10_parse().unwrap())
249                    });
250                    let initialiser_list = idents_all
251                        .iter()
252                        .map(|(_, stream)| stream)
253                        .collect::<Vec<_>>();
254
255                    let puller = make_enum_puller(
256                        ident.base10_parse().unwrap(),
257                        other_idents.len(),
258                        &variant.ident,
259                        &Ident::new("self", Span::call_site()),
260                    );
261
262                    quote! {
263                        ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(
264                            #puller
265                        ),
266                        move |__quickcheck_derive_moving| Self::#variant_ident(#(#initialiser_list),*))
267                    }
268                },
269                |ident_str| LitInt::new(ident_str, Span::call_site()),
270                |ident, field, num_fields| {
271                   make_enum_puller(
272                        field.base10_parse().unwrap(),
273                        num_fields - 1,
274                        variant_ident,
275                        &ident,
276                    )
277                },
278            );
279
280            let underscores = (0..variant.fields.len())
281                .map(|_| quote! {_})
282                .collect::<Vec<_>>();
283            quote! {#enum_name::#variant_ident(#(#underscores),*) => {#shrinker}}
284        })
285        .collect::<Vec<_>>();
286
287    ArbitraryImpl {
288        arbitrary: quote! {
289            match <::core::primitive::usize as ::quickcheck::Arbitrary>::arbitrary(g) % #num_variants {
290                #(#initialisers),*,
291                _ => ::core::unreachable!()
292            }
293        },
294        shrink: quote! {
295            match &self {
296                #(#arm_matchers),*
297            }
298        },
299    }
300}
301
302#[proc_macro_derive(QuickCheck)]
303pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
304    let DeriveInput {
305        ident,
306        data,
307        generics,
308        ..
309    } = parse_macro_input!(input);
310    let ArbitraryImpl { arbitrary, shrink } = match data {
311        syn::Data::Struct(data_struct) => match data_struct.fields {
312            syn::Fields::Named(fields_named) => make_named_struct_arbitrary(&fields_named),
313            syn::Fields::Unnamed(fields_unnamed) => make_unnamed_struct_arbitrary(&fields_unnamed),
314            syn::Fields::Unit => ArbitraryImpl {
315                arbitrary: quote! {Self},
316                shrink: quote! {::quickcheck::empty_shrinker()},
317            },
318        },
319        syn::Data::Enum(data_enum) => make_enum_arbitrary(&ident, &data_enum),
320        syn::Data::Union(_) => ArbitraryImpl {
321            shrink: quote! {::quickcheck::empty_shrinker()},
322            arbitrary: {
323                syn::Error::new_spanned(&ident, "Cannot derive QuickCheck for a union yet")
324                    .to_compile_error()
325            },
326        },
327    };
328
329    let generics_unconstrained = generics
330        .lifetimes()
331        .map(|lifetime| lifetime.lifetime.to_token_stream())
332        .chain(
333            generics
334                .type_params()
335                .map(|type_param| type_param.ident.to_token_stream()),
336        )
337        .collect::<Vec<_>>();
338
339    let generics_arbitrary = generics
340        .lifetimes()
341        .map(|lifetime| lifetime.to_token_stream())
342        .chain(generics.type_params().map(|type_param| {
343            let colon = match type_param.bounds.len() {
344                0 => quote! {:},
345                _ => quote! {+},
346            };
347            quote! {#type_param #colon ::quickcheck::Arbitrary}
348        }))
349        .collect::<Vec<_>>();
350
351    let generics_unconstrained_tokens = match generics_unconstrained.len() {
352        0 => quote! {},
353        _ => quote! {<#(#generics_unconstrained),*>},
354    };
355    let generics_arbitrary_tokens = match generics_arbitrary.len() {
356        0 => quote! {},
357        _ => quote! {<#(#generics_arbitrary),*>},
358    };
359
360    if !generics.lifetimes().collect::<Vec<_>>().is_empty() {
361        return syn::Error::new_spanned(
362            &ident,
363            "Cannot derive QuickCheck for a type with lifetimes yet",
364        )
365        .to_compile_error()
366        .into();
367    }
368
369    let output = quote! {
370        impl #generics_arbitrary_tokens ::quickcheck::Arbitrary for #ident #generics_unconstrained_tokens
371        where
372            #ident #generics_unconstrained_tokens : ::core::clone::Clone {
373            fn arbitrary(g: &mut ::quickcheck::Gen) -> Self {
374                #arbitrary
375            }
376
377            fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
378                #shrink
379            }
380        }
381    };
382    output.into()
383}