quickcheck_arbitrary_derive/
lib.rs

1use core::panic;
2use std::collections::HashMap;
3
4use proc_macro::{self};
5use proc_macro2::{Span, TokenStream};
6use quote::{ToTokens, format_ident, quote};
7use syn::{
8    Attribute, DataEnum, DeriveInput, Field, FieldsNamed, FieldsUnnamed, Ident, LitInt,
9    MetaNameValue, Type, parse_macro_input, punctuated::Punctuated, token::Comma,
10};
11
12fn generate_product_shrink<
13    Iter: IntoIterator<Item = Field> + Clone,
14    IdentKind: Clone + ToTokens + ToString,
15>(
16    fields: &Iter,
17    constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
18    make_ident: impl Fn(&str) -> IdentKind,
19    self_helper: impl Fn(Ident, &IdentKind, usize) -> TokenStream,
20) -> TokenStream {
21    let self_copies = fields
22        .clone()
23        .into_iter()
24        .enumerate()
25        .map(|(idx, field)| {
26            let ident = field
27                .ident
28                .clone()
29                .map(|ident| ident.to_string())
30                .unwrap_or(idx.to_string());
31            let unique_self = format_ident!("self_{}", ident);
32            quote! {
33                let #unique_self = <Self as ::std::clone::Clone>::clone(&self);
34            }
35        })
36        .collect::<Vec<_>>();
37
38    let cloning_iterator_madness: TokenStream = fields
39        .clone()
40        .into_iter()
41        .enumerate()
42        .map(|(idx, field)| {
43            let ident = make_ident(
44                &field
45                    .ident
46                    .clone()
47                    .map(|ident| ident.to_string())
48                    .unwrap_or(idx.to_string()),
49            );
50            let other_idents = fields
51                .clone()
52                .into_iter()
53                .enumerate()
54                .map(|(idx, field)| {
55                    make_ident(
56                        &field
57                            .ident
58                            .clone()
59                            .map(|ident| ident.to_string())
60                            .unwrap_or(idx.to_string()),
61                    )
62                })
63                .filter(|e| e.to_string() != ident.to_string())
64                .map(|field_ident| {
65                    let unique_self_toks = self_helper(
66                        format_ident!("self_{}", ident.to_string()),
67                        &field_ident,
68                        fields.clone().into_iter().collect::<Vec<_>>().len(),
69                    );
70                    (
71                        field_ident.clone(),
72                        quote! {::core::clone::Clone::clone(#unique_self_toks)},
73                    )
74                })
75                .collect::<Vec<_>>();
76            constructor(&field.ty, &ident, &other_idents)
77        })
78        .collect::<Vec<_>>()
79        .iter()
80        .rev()
81        .cloned()
82        .reduce(|a, b| quote! {::std::iter::Iterator::chain(#a, #b)})
83        .unwrap_or(quote! {});
84
85    quote! {
86        #(#self_copies)*
87        ::std::boxed::Box::new(#cloning_iterator_madness)
88    }
89}
90fn generate_product_shrink_simple<
91    Iter: IntoIterator<Item = Field> + Clone,
92    IdentKind: Clone + ToTokens + ToString,
93>(
94    fields: &Iter,
95    constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
96    make_ident: impl Fn(&str) -> IdentKind,
97) -> TokenStream {
98    generate_product_shrink(
99        fields,
100        constructor,
101        make_ident,
102        |unique_self, field_ident, _| quote! {&#unique_self.#field_ident},
103    )
104}
105
106fn make_enum_puller(pull: usize, others: usize, variant: &Ident, source: &Ident) -> TokenStream {
107    let v_puller = [quote! {__quickcheck_derive_match_puller}];
108    let pull_pattern = (0..(pull))
109        .map(|_| quote! {_})
110        .chain(v_puller.iter().cloned())
111        .chain((pull..others).map(|_| quote! {_}));
112
113    quote! {if let Self::#variant(#(#pull_pattern),*) = &#source {
114        __quickcheck_derive_match_puller
115    } else {
116        ::core::unreachable!()
117    }}
118}
119
120struct ArbitraryImpl {
121    arbitrary: TokenStream,
122    shrink: TokenStream,
123}
124
125fn make_named_struct_arbitrary(fields_named: &FieldsNamed) -> ArbitraryImpl {
126    let field_arbitrary_generators = fields_named
127        .named
128        .iter()
129        .map(|field| {
130            let identifier = &field.ident;
131            let ty = &field.ty;
132            quote! {
133                #identifier: <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
134            }
135        })
136        .collect::<Vec<_>>();
137    ArbitraryImpl {
138        shrink: generate_product_shrink_simple(
139            &fields_named.named,
140            |ty, ident, other_idents| {
141                let other_idents_initialisers = other_idents
142                    .iter()
143                    .map(|(ident, toks)| {
144                        quote! {#ident: #toks}
145                    })
146                    .collect::<Vec<_>>();
147                quote! {
148                    ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
149                        move |__quickcheck_derive_moving| Self {#ident: __quickcheck_derive_moving, #(#other_idents_initialisers),*})
150                }
151            },
152            |ident_str| Ident::new(ident_str, Span::call_site()),
153        ),
154        arbitrary: quote! {
155            Self {
156                #(#field_arbitrary_generators),*
157            }
158        },
159    }
160}
161
162fn make_unnamed_struct_arbitrary(fields_unnamed: &FieldsUnnamed) -> ArbitraryImpl {
163    let field_arbitrary_generators = fields_unnamed
164        .unnamed
165        .iter()
166        .map(|field| {
167            let ty = &field.ty;
168            quote! {
169                <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
170            }
171        })
172        .collect::<Vec<_>>();
173    ArbitraryImpl {
174        arbitrary: quote! {
175            Self(#(#field_arbitrary_generators),*)
176        },
177        shrink: generate_product_shrink_simple::<_, LitInt>(
178            &fields_unnamed.unnamed,
179            |ty, ident, other_idents| {
180                let mut idents_all = other_idents.clone();
181                idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
182                idents_all.sort_by(|(a, _), (b, _)| {
183                    a.base10_parse::<u64>()
184                        .unwrap()
185                        .cmp(&b.base10_parse().unwrap())
186                });
187                let initialiser_list = idents_all
188                    .iter()
189                    .map(|(_, stream)| stream)
190                    .collect::<Vec<_>>();
191
192                quote! {
193                    ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
194                        move |__quickcheck_derive_moving| Self(#(#initialiser_list),*))
195                }
196            },
197            |ident_str| LitInt::new(ident_str, Span::call_site()),
198        ),
199    }
200}
201
/// Strategy for generating the fields of an enum variant that may (directly or
/// indirectly) contain its own type.
///
/// The derived ordering follows declaration order: `None < Linear < Exponential`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum RecursiveKind {
    /// Fields are generated with the caller's generator unchanged.
    None,
    /// Fields are generated with a generator whose size is reduced by one.
    Linear,
    /// Fields are generated with a generator whose size is halved.
    Exponential,
}

/// Options parsed from an enum variant's `#[quickcheck(...)]` attribute.
#[derive(Clone, Copy, Debug)]
struct EnumAtrributes {
    /// Recursion strategy for the variant; `RecursiveKind::None` if unspecified.
    recursive: RecursiveKind,
}
213
214fn get_enum_attrs(attrs: &Vec<Attribute>) -> EnumAtrributes {
215    const RECURSION_INVALID_KIND: &str =
216        "quickcheck recursive strategies must be one of None, Linear, Exponential";
217
218    let all_attrs = attrs
219        .iter()
220        .filter(|attr| attr.meta.path().is_ident("quickcheck"))
221        .map(|attr| {
222            attr.parse_args_with(Punctuated::<MetaNameValue, Comma>::parse_terminated)
223                .expect("quickcheck attribute must have comma seperated arguments")
224                .iter()
225                .map(|arg| {
226                    (
227                        arg.path
228                            .get_ident()
229                            .expect("quickcheck arguments must be of the form `ident = value`")
230                            .to_string(),
231                        match &arg.value {
232                            syn::Expr::Path(v) => v.path.require_ident().expect("quickcheck recursive strategies must be one of None, Linear, Exponential").to_string(),
233                            _ => panic!("quickcheck values must be literals"),
234                        },
235                    )
236                })
237                .collect::<HashMap<_, _>>()
238        })
239        .map(|key_values| EnumAtrributes {
240            recursive: match key_values
241                .get("recursive")
242                .cloned()
243            {
244                Some(v) => match v.as_str() {
245                    "None" => RecursiveKind::None,
246                    "Linear" => RecursiveKind::Linear,
247                    "Exponential" => RecursiveKind::Exponential,
248                    _ => panic!("{}", RECURSION_INVALID_KIND)
249                },
250                None => RecursiveKind::None,
251            },
252        })
253        .collect::<Vec<_>>();
254
255    match all_attrs.len() {
256        0 => EnumAtrributes {
257            recursive: RecursiveKind::None,
258        },
259        1 => all_attrs[0],
260        _ => panic!("quickcheck attribute may only be applied once to each field"),
261    }
262}
263
264fn make_enum_arbitrary(ident: &Ident, data_enum: &DataEnum) -> ArbitraryImpl {
265    let num_variants = data_enum.variants.len();
266
267    let mut initialisers = data_enum
268        .variants
269        .iter()
270        .map(|variant| {
271            (
272                &variant.ident,
273                match variant.fields.len() {
274                    0 => (quote! {}, RecursiveKind::None),
275                    _ => {
276                        let attrs = get_enum_attrs(&variant.attrs);
277                        let new_g = match attrs.recursive {
278                            RecursiveKind::Exponential => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) / 2, 1))},
279                            RecursiveKind::Linear => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) - 1, 1))},
280                            RecursiveKind::None => quote! {g}
281                        };
282                        let field_arbitrary_generators = variant
283                            .fields
284                            .iter()
285                            .map(|field| {
286                                let ty = &field.ty;
287                                quote! {<#ty as ::quickcheck::Arbitrary>::arbitrary(#new_g)}
288                            })
289                            .collect::<Vec<_>>();
290                        (quote! {(#(#field_arbitrary_generators),*)}, attrs.recursive)
291                    }
292                },
293            )
294        })
295        .map(|(ident, (initialiser_list, recursive))| {
296            (quote! {Self::#ident #initialiser_list}, recursive)
297        })
298        .enumerate()
299        .map(|(index, (constructor, recursive))| {
300            (quote! {#index => #constructor}, recursive)
301        })
302        .collect::<Vec<_>>();
303
304    initialisers.sort_by_key(|(_, recursive)| *recursive);
305    let num_recursive = initialisers
306        .iter()
307        .filter(|(_, recursive)| !matches!(recursive, RecursiveKind::None))
308        .count();
309    let initialisers = initialisers
310        .into_iter()
311        .map(|(toks, _)| toks)
312        .collect::<Vec<_>>();
313
314    let enum_name = &ident;
315    let arm_matchers = data_enum
316        .variants
317        .iter()
318        .map(|variant| {
319            let variant_ident = &variant.ident;
320            let shrinker = generate_product_shrink::<_, LitInt>(
321                &variant.fields,
322                |ty, ident, other_idents| {
323                    let mut idents_all = other_idents.clone();
324                    idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
325                    idents_all.sort_by(|(a, _), (b, _)| {
326                        a.base10_parse::<u64>()
327                            .unwrap()
328                            .cmp(&b.base10_parse().unwrap())
329                    });
330                    let initialiser_list = idents_all
331                        .iter()
332                        .map(|(_, stream)| stream)
333                        .collect::<Vec<_>>();
334
335                    let puller = make_enum_puller(
336                        ident.base10_parse().unwrap(),
337                        other_idents.len(),
338                        &variant.ident,
339                        &Ident::new("self", Span::call_site()),
340                    );
341
342                    quote! {
343                        ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(
344                            #puller
345                        ),
346                        move |__quickcheck_derive_moving| Self::#variant_ident(#(#initialiser_list),*))
347                    }
348                },
349                |ident_str| LitInt::new(ident_str, Span::call_site()),
350                |ident, field, num_fields| {
351                   make_enum_puller(
352                        field.base10_parse().unwrap(),
353                        num_fields - 1,
354                        variant_ident,
355                        &ident,
356                    )
357                },
358            );
359
360            let underscores = (0..variant.fields.len())
361                .map(|_| quote! {_})
362                .collect::<Vec<_>>();
363
364             match variant.fields.is_empty() {
365                true => quote! {#enum_name::#variant_ident => ::std::boxed::Box::new(::quickcheck::empty_shrinker())},
366                false => quote! {#enum_name::#variant_ident(#(#underscores),*) => {#shrinker}} ,
367            }
368
369        })
370        .collect::<Vec<_>>();
371
372    ArbitraryImpl {
373        arbitrary: quote! {
374            match <::core::primitive::usize as ::quickcheck::Arbitrary>::arbitrary(g) % (
375            if ::quickcheck::Gen::size(g) > 1 {
376                #num_variants
377            } else {
378                #num_variants - #num_recursive
379            }) {
380                #(#initialisers),*,
381                _ => ::core::unreachable!()
382            }
383        },
384        shrink: quote! {
385            match &self {
386                #(#arm_matchers),*
387            }
388        },
389    }
390}
391
392/// Generates an implementation of `quickcheck::Arbitrary`.
393///
394/// You can annotate an enum variant with `#[quickcheck(recursive = None | Linear | Exponential)]` to allow for testing of potentially infinitely large types
395///
396/// ```rs
397/// #[derive(Clone, QuickCheck, Debug)]
398/// enum Tree<T> {
399///     #[quickcheck(recursive = Exponential)]
400///     Branch(Vec<Tree<T>>),
401///     Leaf(T),
402/// }
403/// ```
404///
405/// Use exponential for types that exponentially grow with depth (like trees).
406///
407/// Use linear for types that linearly grow with depth (like linked lists).
408#[proc_macro_derive(QuickCheck, attributes(quickcheck))]
409pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
410    let DeriveInput {
411        ident,
412        data,
413        generics,
414        ..
415    } = parse_macro_input!(input);
416    let ArbitraryImpl { arbitrary, shrink } = match data {
417        syn::Data::Struct(data_struct) => match data_struct.fields {
418            syn::Fields::Named(fields_named) => make_named_struct_arbitrary(&fields_named),
419            syn::Fields::Unnamed(fields_unnamed) => make_unnamed_struct_arbitrary(&fields_unnamed),
420            syn::Fields::Unit => ArbitraryImpl {
421                arbitrary: quote! {Self},
422                shrink: quote! {::quickcheck::empty_shrinker()},
423            },
424        },
425        syn::Data::Enum(data_enum) => make_enum_arbitrary(&ident, &data_enum),
426        syn::Data::Union(_) => ArbitraryImpl {
427            shrink: quote! {::quickcheck::empty_shrinker()},
428            arbitrary: {
429                syn::Error::new_spanned(&ident, "Cannot derive QuickCheck for a union yet")
430                    .to_compile_error()
431            },
432        },
433    };
434
435    let generics_unconstrained = generics
436        .lifetimes()
437        .map(|lifetime| lifetime.lifetime.to_token_stream())
438        .chain(
439            generics
440                .type_params()
441                .map(|type_param| type_param.ident.to_token_stream()),
442        )
443        .collect::<Vec<_>>();
444
445    let generics_arbitrary = generics
446        .lifetimes()
447        .map(|lifetime| lifetime.to_token_stream())
448        .chain(generics.type_params().map(|type_param| {
449            let colon = match type_param.bounds.len() {
450                0 => quote! {:},
451                _ => quote! {+},
452            };
453            quote! {#type_param #colon ::quickcheck::Arbitrary}
454        }))
455        .collect::<Vec<_>>();
456
457    let generics_unconstrained_tokens = match generics_unconstrained.len() {
458        0 => quote! {},
459        _ => quote! {<#(#generics_unconstrained),*>},
460    };
461    let generics_arbitrary_tokens = match generics_arbitrary.len() {
462        0 => quote! {},
463        _ => quote! {<#(#generics_arbitrary),*>},
464    };
465
466    if !generics.lifetimes().collect::<Vec<_>>().is_empty() {
467        return syn::Error::new_spanned(
468            &ident,
469            "Cannot derive QuickCheck for a type with lifetimes yet",
470        )
471        .to_compile_error()
472        .into();
473    }
474
475    let output = quote! {
476        impl #generics_arbitrary_tokens ::quickcheck::Arbitrary for #ident #generics_unconstrained_tokens
477        where
478            #ident #generics_unconstrained_tokens : ::core::clone::Clone {
479            fn arbitrary(g: &mut ::quickcheck::Gen) -> Self {
480                #arbitrary
481            }
482
483            fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
484                #shrink
485            }
486        }
487    };
488    output.into()
489}