quickcheck_arbitrary_derive/
lib.rs

1use core::panic;
2use std::collections::HashMap;
3
4use proc_macro::{self};
5use proc_macro2::{Span, TokenStream};
6use quote::{ToTokens, format_ident, quote};
7use syn::{
8    Attribute, DataEnum, DeriveInput, Field, FieldsNamed, FieldsUnnamed, Ident, LitInt,
9    MetaNameValue, Type, parse_macro_input, punctuated::Punctuated, token::Comma,
10};
11
12fn generate_product_shrink<
13    Iter: IntoIterator<Item = Field> + Clone,
14    IdentKind: Clone + ToTokens + ToString,
15>(
16    fields: &Iter,
17    constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
18    make_ident: impl Fn(&str) -> IdentKind,
19    self_helper: impl Fn(Ident, &IdentKind, usize) -> TokenStream,
20) -> TokenStream {
21    let self_copies = fields
22        .clone()
23        .into_iter()
24        .enumerate()
25        .map(|(idx, field)| {
26            let ident = field
27                .ident
28                .clone()
29                .map(|ident| ident.to_string())
30                .unwrap_or(idx.to_string());
31            let unique_self = format_ident!("self_{}", ident);
32            quote! {
33                let #unique_self = <Self as ::std::clone::Clone>::clone(&self);
34            }
35        })
36        .collect::<Vec<_>>();
37
38    let cloning_iterator_madness: TokenStream = fields
39        .clone()
40        .into_iter()
41        .enumerate()
42        .map(|(idx, field)| {
43            let ident = make_ident(
44                &field
45                    .ident
46                    .clone()
47                    .map(|ident| ident.to_string())
48                    .unwrap_or(idx.to_string()),
49            );
50            let other_idents = fields
51                .clone()
52                .into_iter()
53                .enumerate()
54                .map(|(idx, field)| {
55                    make_ident(
56                        &field
57                            .ident
58                            .clone()
59                            .map(|ident| ident.to_string())
60                            .unwrap_or(idx.to_string()),
61                    )
62                })
63                .filter(|e| e.to_string() != ident.to_string())
64                .map(|field_ident| {
65                    let unique_self_toks = self_helper(
66                        format_ident!("self_{}", ident.to_string()),
67                        &field_ident,
68                        fields.clone().into_iter().collect::<Vec<_>>().len(),
69                    );
70                    (
71                        field_ident.clone(),
72                        quote! {::core::clone::Clone::clone(#unique_self_toks)},
73                    )
74                })
75                .collect::<Vec<_>>();
76            constructor(&field.ty, &ident, &other_idents)
77        })
78        .collect::<Vec<_>>()
79        .iter()
80        .rev()
81        .cloned()
82        .reduce(|a, b| quote! {::std::iter::Iterator::chain(#a, #b)})
83        .unwrap_or(quote! {});
84
85    quote! {
86        #(#self_copies)*
87        ::std::boxed::Box::new(#cloning_iterator_madness)
88    }
89}
90fn generate_product_shrink_simple<
91    Iter: IntoIterator<Item = Field> + Clone,
92    IdentKind: Clone + ToTokens + ToString,
93>(
94    fields: &Iter,
95    constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
96    make_ident: impl Fn(&str) -> IdentKind,
97) -> TokenStream {
98    generate_product_shrink(
99        fields,
100        constructor,
101        make_ident,
102        |unique_self, field_ident, _| quote! {&#unique_self.#field_ident},
103    )
104}
105
106fn make_enum_puller(pull: usize, others: usize, variant: &Ident, source: &Ident) -> TokenStream {
107    let v_puller = [quote! {__quickcheck_derive_match_puller}];
108    let pull_pattern = (0..(pull))
109        .map(|_| quote! {_})
110        .chain(v_puller.iter().cloned())
111        .chain((pull..others).map(|_| quote! {_}));
112
113    quote! {if let Self::#variant(#(#pull_pattern),*) = &#source {
114        __quickcheck_derive_match_puller
115    } else {
116        ::core::unreachable!()
117    }}
118}
119
120struct ArbitraryImpl {
121    arbitrary: TokenStream,
122    shrink: TokenStream,
123}
124
125fn make_named_struct_arbitrary(fields_named: &FieldsNamed) -> ArbitraryImpl {
126    let field_arbitrary_generators = fields_named
127        .named
128        .iter()
129        .map(|field| {
130            let identifier = &field.ident;
131            let ty = &field.ty;
132            quote! {
133                #identifier: <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
134            }
135        })
136        .collect::<Vec<_>>();
137    ArbitraryImpl {
138        shrink: generate_product_shrink_simple(
139            &fields_named.named,
140            |ty, ident, other_idents| {
141                let other_idents_initialisers = other_idents
142                    .iter()
143                    .map(|(ident, toks)| {
144                        quote! {#ident: #toks}
145                    })
146                    .collect::<Vec<_>>();
147                quote! {
148                    ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
149                        move |__quickcheck_derive_moving| Self {#ident: __quickcheck_derive_moving, #(#other_idents_initialisers),*})
150                }
151            },
152            |ident_str| Ident::new(ident_str, Span::call_site()),
153        ),
154        arbitrary: quote! {
155            Self {
156                #(#field_arbitrary_generators),*
157            }
158        },
159    }
160}
161
162fn make_unnamed_struct_arbitrary(fields_unnamed: &FieldsUnnamed) -> ArbitraryImpl {
163    let field_arbitrary_generators = fields_unnamed
164        .unnamed
165        .iter()
166        .map(|field| {
167            let ty = &field.ty;
168            quote! {
169                <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
170            }
171        })
172        .collect::<Vec<_>>();
173    ArbitraryImpl {
174        arbitrary: quote! {
175            Self(#(#field_arbitrary_generators),*)
176        },
177        shrink: generate_product_shrink_simple::<_, LitInt>(
178            &fields_unnamed.unnamed,
179            |ty, ident, other_idents| {
180                let mut idents_all = other_idents.clone();
181                idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
182                idents_all.sort_by(|(a, _), (b, _)| {
183                    a.base10_parse::<u64>()
184                        .unwrap()
185                        .cmp(&b.base10_parse().unwrap())
186                });
187                let initialiser_list = idents_all
188                    .iter()
189                    .map(|(_, stream)| stream)
190                    .collect::<Vec<_>>();
191
192                quote! {
193                    ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
194                        move |__quickcheck_derive_moving| Self(#(#initialiser_list),*))
195                }
196            },
197            |ident_str| LitInt::new(ident_str, Span::call_site()),
198        ),
199    }
200}
201
/// Recursion strategy selected for an enum variant via `#[quickcheck(recursive = ...)]`.
///
/// Variants are declared in the order `None < Linear < Exponential` (discriminants 0, 1, 2);
/// the derived `Ord` follows that order, which the enum derive uses when sorting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum RecursiveKind {
    /// Not recursive: fields are generated with the caller's generator.
    None,
    /// Recursion that grows linearly with depth.
    Linear,
    /// Recursion that grows exponentially with depth.
    Exponential,
}
208
209#[derive(Clone, Copy, Debug)]
210struct EnumAtrributes {
211    recursive: RecursiveKind,
212}
213
214fn get_enum_attrs(attrs: &Vec<Attribute>) -> EnumAtrributes {
215    const RECURSION_INVALID_KIND: &str =
216        "quickcheck recursive strategies must be one of None, Linear, Exponential";
217
218    dbg!(attrs);
219    let all_attrs = attrs
220        .iter()
221        .filter(|attr| attr.meta.path().is_ident("quickcheck"))
222        .map(|attr| {
223            attr.parse_args_with(Punctuated::<MetaNameValue, Comma>::parse_terminated)
224                .expect("quickcheck attribute must have comma seperated arguments")
225                .iter()
226                .map(|arg| {
227                    (
228                        arg.path
229                            .get_ident()
230                            .expect("quickcheck arguments must be of the form `ident = value`")
231                            .to_string(),
232                        match &arg.value {
233                            syn::Expr::Path(v) => v.path.require_ident().expect("quickcheck recursive strategies must be one of None, Linear, Exponential").to_string(),
234                            _ => panic!("quickcheck values must be literals"),
235                        },
236                    )
237                })
238                .collect::<HashMap<_, _>>()
239        })
240        .map(|key_values| EnumAtrributes {
241            recursive: match key_values
242                .get("recursive")
243                .cloned()
244            {
245                Some(v) => match v.as_str() {
246                    "None" => RecursiveKind::None,
247                    "Linear" => RecursiveKind::Linear,
248                    "Exponential" => RecursiveKind::Exponential,
249                    _ => panic!("{}", RECURSION_INVALID_KIND)
250                },
251                None => RecursiveKind::None,
252            },
253        })
254        .collect::<Vec<_>>();
255
256    match all_attrs.len() {
257        0 => EnumAtrributes {
258            recursive: RecursiveKind::None,
259        },
260        1 => all_attrs[0],
261        _ => panic!("quickcheck attribute may only be applied once to each field"),
262    }
263}
264
265fn make_enum_arbitrary(ident: &Ident, data_enum: &DataEnum) -> ArbitraryImpl {
266    let num_variants = data_enum.variants.len();
267
268    let mut initialisers = data_enum
269        .variants
270        .iter()
271        .map(|variant| {
272            (
273                &variant.ident,
274                match variant.fields.len() {
275                    0 => (quote! {}, RecursiveKind::None),
276                    _ => {
277                        let attrs = get_enum_attrs(&variant.attrs);
278                        let new_g = match attrs.recursive {
279                            RecursiveKind::Exponential => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) / 2, 0))},
280                            RecursiveKind::Linear => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) - 1, 0))},
281                            RecursiveKind::None => quote! {g}
282                        };
283                        let field_arbitrary_generators = variant
284                            .fields
285                            .iter()
286                            .map(|field| {
287                                let ty = &field.ty;
288                                quote! {<#ty as ::quickcheck::Arbitrary>::arbitrary(#new_g)}
289                            })
290                            .collect::<Vec<_>>();
291                        (quote! {(#(#field_arbitrary_generators),*)}, attrs.recursive)
292                    }
293                },
294            )
295        })
296        .map(|(ident, (initialiser_list, recursive))| {
297            (quote! {Self::#ident #initialiser_list}, recursive)
298        })
299        .enumerate()
300        .map(|(index, (constructor, recursive))| {
301            (quote! {#index => #constructor}, recursive)
302        })
303        .collect::<Vec<_>>();
304
305    initialisers.sort_by_key(|(_, recursive)| *recursive);
306    let num_recursive = initialisers
307        .iter()
308        .filter(|(_, recursive)| !matches!(recursive, RecursiveKind::None))
309        .count();
310    let initialisers = initialisers
311        .into_iter()
312        .map(|(toks, _)| toks)
313        .collect::<Vec<_>>();
314
315    let enum_name = &ident;
316    let arm_matchers = data_enum
317        .variants
318        .iter()
319        .map(|variant| {
320            let variant_ident = &variant.ident;
321            let shrinker = generate_product_shrink::<_, LitInt>(
322                &variant.fields,
323                |ty, ident, other_idents| {
324                    let mut idents_all = other_idents.clone();
325                    idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
326                    idents_all.sort_by(|(a, _), (b, _)| {
327                        a.base10_parse::<u64>()
328                            .unwrap()
329                            .cmp(&b.base10_parse().unwrap())
330                    });
331                    let initialiser_list = idents_all
332                        .iter()
333                        .map(|(_, stream)| stream)
334                        .collect::<Vec<_>>();
335
336                    let puller = make_enum_puller(
337                        ident.base10_parse().unwrap(),
338                        other_idents.len(),
339                        &variant.ident,
340                        &Ident::new("self", Span::call_site()),
341                    );
342
343                    quote! {
344                        ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(
345                            #puller
346                        ),
347                        move |__quickcheck_derive_moving| Self::#variant_ident(#(#initialiser_list),*))
348                    }
349                },
350                |ident_str| LitInt::new(ident_str, Span::call_site()),
351                |ident, field, num_fields| {
352                   make_enum_puller(
353                        field.base10_parse().unwrap(),
354                        num_fields - 1,
355                        variant_ident,
356                        &ident,
357                    )
358                },
359            );
360
361            let underscores = (0..variant.fields.len())
362                .map(|_| quote! {_})
363                .collect::<Vec<_>>();
364
365             match variant.fields.is_empty() {
366                true => quote! {#enum_name::#variant_ident => ::std::boxed::Box::new(::quickcheck::empty_shrinker())},
367                false => quote! {#enum_name::#variant_ident(#(#underscores),*) => {#shrinker}} ,
368            }
369
370        })
371        .collect::<Vec<_>>();
372
373    ArbitraryImpl {
374        arbitrary: quote! {
375            match <::core::primitive::usize as ::quickcheck::Arbitrary>::arbitrary(g) % (
376            if ::quickcheck::Gen::size(g) > 0 {
377                #num_variants
378            } else {
379                #num_variants - #num_recursive
380            }) {
381                #(#initialisers),*,
382                _ => ::core::unreachable!()
383            }
384        },
385        shrink: quote! {
386            match &self {
387                #(#arm_matchers),*
388            }
389        },
390    }
391}
392
393/// Generates an implementation of `quickcheck::Arbitrary`
394/// You can annotate an enum variant with `#[quickcheck(recursive = None | Linear | Exponential)]` to allow for testing of potentially infinitely large types
395/// ```rs
396/// #[derive(Clone, QuickCheck, Debug)]
397/// enum Tree<T> {
398///     #[quickcheck(recursive = Exponential)]
399///     Branch(Vec<Tree<T>>),
400///     Leaf(T),
401/// }
402/// ```
403/// Use exponential for types that exponentially grow with depth (like trees)
404/// Use linear for types that linearly grow with depth (like linked lists)
405#[proc_macro_derive(QuickCheck, attributes(quickcheck))]
406pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
407    let DeriveInput {
408        ident,
409        data,
410        generics,
411        ..
412    } = parse_macro_input!(input);
413    let ArbitraryImpl { arbitrary, shrink } = match data {
414        syn::Data::Struct(data_struct) => match data_struct.fields {
415            syn::Fields::Named(fields_named) => make_named_struct_arbitrary(&fields_named),
416            syn::Fields::Unnamed(fields_unnamed) => make_unnamed_struct_arbitrary(&fields_unnamed),
417            syn::Fields::Unit => ArbitraryImpl {
418                arbitrary: quote! {Self},
419                shrink: quote! {::quickcheck::empty_shrinker()},
420            },
421        },
422        syn::Data::Enum(data_enum) => make_enum_arbitrary(&ident, &data_enum),
423        syn::Data::Union(_) => ArbitraryImpl {
424            shrink: quote! {::quickcheck::empty_shrinker()},
425            arbitrary: {
426                syn::Error::new_spanned(&ident, "Cannot derive QuickCheck for a union yet")
427                    .to_compile_error()
428            },
429        },
430    };
431
432    let generics_unconstrained = generics
433        .lifetimes()
434        .map(|lifetime| lifetime.lifetime.to_token_stream())
435        .chain(
436            generics
437                .type_params()
438                .map(|type_param| type_param.ident.to_token_stream()),
439        )
440        .collect::<Vec<_>>();
441
442    let generics_arbitrary = generics
443        .lifetimes()
444        .map(|lifetime| lifetime.to_token_stream())
445        .chain(generics.type_params().map(|type_param| {
446            let colon = match type_param.bounds.len() {
447                0 => quote! {:},
448                _ => quote! {+},
449            };
450            quote! {#type_param #colon ::quickcheck::Arbitrary}
451        }))
452        .collect::<Vec<_>>();
453
454    let generics_unconstrained_tokens = match generics_unconstrained.len() {
455        0 => quote! {},
456        _ => quote! {<#(#generics_unconstrained),*>},
457    };
458    let generics_arbitrary_tokens = match generics_arbitrary.len() {
459        0 => quote! {},
460        _ => quote! {<#(#generics_arbitrary),*>},
461    };
462
463    if !generics.lifetimes().collect::<Vec<_>>().is_empty() {
464        return syn::Error::new_spanned(
465            &ident,
466            "Cannot derive QuickCheck for a type with lifetimes yet",
467        )
468        .to_compile_error()
469        .into();
470    }
471
472    let output = quote! {
473        impl #generics_arbitrary_tokens ::quickcheck::Arbitrary for #ident #generics_unconstrained_tokens
474        where
475            #ident #generics_unconstrained_tokens : ::core::clone::Clone {
476            fn arbitrary(g: &mut ::quickcheck::Gen) -> Self {
477                #arbitrary
478            }
479
480            fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
481                #shrink
482            }
483        }
484    };
485    output.into()
486}