1use core::panic;
2use std::collections::HashMap;
3
4use proc_macro::{self};
5use proc_macro2::{Span, TokenStream};
6use quote::{ToTokens, format_ident, quote};
7use syn::{
8 Attribute, DataEnum, DeriveInput, Field, FieldsNamed, FieldsUnnamed, Ident, LitInt,
9 MetaNameValue, Type, parse_macro_input, punctuated::Punctuated, token::Comma,
10};
11
12fn generate_product_shrink<
13 Iter: IntoIterator<Item = Field> + Clone,
14 IdentKind: Clone + ToTokens + ToString,
15>(
16 fields: &Iter,
17 constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
18 make_ident: impl Fn(&str) -> IdentKind,
19 self_helper: impl Fn(Ident, &IdentKind, usize) -> TokenStream,
20) -> TokenStream {
21 let self_copies = fields
22 .clone()
23 .into_iter()
24 .enumerate()
25 .map(|(idx, field)| {
26 let ident = field
27 .ident
28 .clone()
29 .map(|ident| ident.to_string())
30 .unwrap_or(idx.to_string());
31 let unique_self = format_ident!("self_{}", ident);
32 quote! {
33 let #unique_self = <Self as ::std::clone::Clone>::clone(&self);
34 }
35 })
36 .collect::<Vec<_>>();
37
38 let cloning_iterator_madness: TokenStream = fields
39 .clone()
40 .into_iter()
41 .enumerate()
42 .map(|(idx, field)| {
43 let ident = make_ident(
44 &field
45 .ident
46 .clone()
47 .map(|ident| ident.to_string())
48 .unwrap_or(idx.to_string()),
49 );
50 let other_idents = fields
51 .clone()
52 .into_iter()
53 .enumerate()
54 .map(|(idx, field)| {
55 make_ident(
56 &field
57 .ident
58 .clone()
59 .map(|ident| ident.to_string())
60 .unwrap_or(idx.to_string()),
61 )
62 })
63 .filter(|e| e.to_string() != ident.to_string())
64 .map(|field_ident| {
65 let unique_self_toks = self_helper(
66 format_ident!("self_{}", ident.to_string()),
67 &field_ident,
68 fields.clone().into_iter().collect::<Vec<_>>().len(),
69 );
70 (
71 field_ident.clone(),
72 quote! {::core::clone::Clone::clone(#unique_self_toks)},
73 )
74 })
75 .collect::<Vec<_>>();
76 constructor(&field.ty, &ident, &other_idents)
77 })
78 .collect::<Vec<_>>()
79 .iter()
80 .rev()
81 .cloned()
82 .reduce(|a, b| quote! {::std::iter::Iterator::chain(#a, #b)})
83 .unwrap_or(quote! {});
84
85 quote! {
86 #(#self_copies)*
87 ::std::boxed::Box::new(#cloning_iterator_madness)
88 }
89}
90fn generate_product_shrink_simple<
91 Iter: IntoIterator<Item = Field> + Clone,
92 IdentKind: Clone + ToTokens + ToString,
93>(
94 fields: &Iter,
95 constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
96 make_ident: impl Fn(&str) -> IdentKind,
97) -> TokenStream {
98 generate_product_shrink(
99 fields,
100 constructor,
101 make_ident,
102 |unique_self, field_ident, _| quote! {&#unique_self.#field_ident},
103 )
104}
105
106fn make_enum_puller(pull: usize, others: usize, variant: &Ident, source: &Ident) -> TokenStream {
107 let v_puller = [quote! {__quickcheck_derive_match_puller}];
108 let pull_pattern = (0..(pull))
109 .map(|_| quote! {_})
110 .chain(v_puller.iter().cloned())
111 .chain((pull..others).map(|_| quote! {_}));
112
113 quote! {if let Self::#variant(#(#pull_pattern),*) = &#source {
114 __quickcheck_derive_match_puller
115 } else {
116 ::core::unreachable!()
117 }}
118}
119
120struct ArbitraryImpl {
121 arbitrary: TokenStream,
122 shrink: TokenStream,
123}
124
125fn make_named_struct_arbitrary(fields_named: &FieldsNamed) -> ArbitraryImpl {
126 let field_arbitrary_generators = fields_named
127 .named
128 .iter()
129 .map(|field| {
130 let identifier = &field.ident;
131 let ty = &field.ty;
132 quote! {
133 #identifier: <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
134 }
135 })
136 .collect::<Vec<_>>();
137 ArbitraryImpl {
138 shrink: generate_product_shrink_simple(
139 &fields_named.named,
140 |ty, ident, other_idents| {
141 let other_idents_initialisers = other_idents
142 .iter()
143 .map(|(ident, toks)| {
144 quote! {#ident: #toks}
145 })
146 .collect::<Vec<_>>();
147 quote! {
148 ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
149 move |__quickcheck_derive_moving| Self {#ident: __quickcheck_derive_moving, #(#other_idents_initialisers),*})
150 }
151 },
152 |ident_str| Ident::new(ident_str, Span::call_site()),
153 ),
154 arbitrary: quote! {
155 Self {
156 #(#field_arbitrary_generators),*
157 }
158 },
159 }
160}
161
162fn make_unnamed_struct_arbitrary(fields_unnamed: &FieldsUnnamed) -> ArbitraryImpl {
163 let field_arbitrary_generators = fields_unnamed
164 .unnamed
165 .iter()
166 .map(|field| {
167 let ty = &field.ty;
168 quote! {
169 <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
170 }
171 })
172 .collect::<Vec<_>>();
173 ArbitraryImpl {
174 arbitrary: quote! {
175 Self(#(#field_arbitrary_generators),*)
176 },
177 shrink: generate_product_shrink_simple::<_, LitInt>(
178 &fields_unnamed.unnamed,
179 |ty, ident, other_idents| {
180 let mut idents_all = other_idents.clone();
181 idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
182 idents_all.sort_by(|(a, _), (b, _)| {
183 a.base10_parse::<u64>()
184 .unwrap()
185 .cmp(&b.base10_parse().unwrap())
186 });
187 let initialiser_list = idents_all
188 .iter()
189 .map(|(_, stream)| stream)
190 .collect::<Vec<_>>();
191
192 quote! {
193 ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
194 move |__quickcheck_derive_moving| Self(#(#initialiser_list),*))
195 }
196 },
197 |ident_str| LitInt::new(ident_str, Span::call_site()),
198 ),
199 }
200}
201
/// Recursion strategy declared on an enum variant through
/// `#[quickcheck(recursive = ...)]`.
///
/// The ordering places `None` before both recursive strategies.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum RecursiveKind {
    /// Not recursive: fields are generated with the caller's generator.
    None = 0,
    /// Fields are generated with a generator one size smaller (at least 1).
    Linear = 1,
    /// Fields are generated with a generator of half the size (at least 1).
    Exponential = 2,
}
208
209#[derive(Clone, Copy, Debug)]
210struct EnumAtrributes {
211 recursive: RecursiveKind,
212}
213
214fn get_enum_attrs(attrs: &Vec<Attribute>) -> EnumAtrributes {
215 const RECURSION_INVALID_KIND: &str =
216 "quickcheck recursive strategies must be one of None, Linear, Exponential";
217
218 let all_attrs = attrs
219 .iter()
220 .filter(|attr| attr.meta.path().is_ident("quickcheck"))
221 .map(|attr| {
222 attr.parse_args_with(Punctuated::<MetaNameValue, Comma>::parse_terminated)
223 .expect("quickcheck attribute must have comma seperated arguments")
224 .iter()
225 .map(|arg| {
226 (
227 arg.path
228 .get_ident()
229 .expect("quickcheck arguments must be of the form `ident = value`")
230 .to_string(),
231 match &arg.value {
232 syn::Expr::Path(v) => v.path.require_ident().expect("quickcheck recursive strategies must be one of None, Linear, Exponential").to_string(),
233 _ => panic!("quickcheck values must be literals"),
234 },
235 )
236 })
237 .collect::<HashMap<_, _>>()
238 })
239 .map(|key_values| EnumAtrributes {
240 recursive: match key_values
241 .get("recursive")
242 .cloned()
243 {
244 Some(v) => match v.as_str() {
245 "None" => RecursiveKind::None,
246 "Linear" => RecursiveKind::Linear,
247 "Exponential" => RecursiveKind::Exponential,
248 _ => panic!("{}", RECURSION_INVALID_KIND)
249 },
250 None => RecursiveKind::None,
251 },
252 })
253 .collect::<Vec<_>>();
254
255 match all_attrs.len() {
256 0 => EnumAtrributes {
257 recursive: RecursiveKind::None,
258 },
259 1 => all_attrs[0],
260 _ => panic!("quickcheck attribute may only be applied once to each field"),
261 }
262}
263
264fn make_enum_arbitrary(ident: &Ident, data_enum: &DataEnum) -> ArbitraryImpl {
265 let num_variants = data_enum.variants.len();
266
267 let mut initialisers = data_enum
268 .variants
269 .iter()
270 .map(|variant| {
271 (
272 &variant.ident,
273 match variant.fields.len() {
274 0 => (quote! {}, RecursiveKind::None),
275 _ => {
276 let attrs = get_enum_attrs(&variant.attrs);
277 let new_g = match attrs.recursive {
278 RecursiveKind::Exponential => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) / 2, 1))},
279 RecursiveKind::Linear => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) - 1, 1))},
280 RecursiveKind::None => quote! {g}
281 };
282 let field_arbitrary_generators = variant
283 .fields
284 .iter()
285 .map(|field| {
286 let ty = &field.ty;
287 quote! {<#ty as ::quickcheck::Arbitrary>::arbitrary(#new_g)}
288 })
289 .collect::<Vec<_>>();
290 (quote! {(#(#field_arbitrary_generators),*)}, attrs.recursive)
291 }
292 },
293 )
294 })
295 .map(|(ident, (initialiser_list, recursive))| {
296 (quote! {Self::#ident #initialiser_list}, recursive)
297 })
298 .enumerate()
299 .map(|(index, (constructor, recursive))| {
300 (quote! {#index => #constructor}, recursive)
301 })
302 .collect::<Vec<_>>();
303
304 initialisers.sort_by_key(|(_, recursive)| *recursive);
305 let num_recursive = initialisers
306 .iter()
307 .filter(|(_, recursive)| !matches!(recursive, RecursiveKind::None))
308 .count();
309 let initialisers = initialisers
310 .into_iter()
311 .map(|(toks, _)| toks)
312 .collect::<Vec<_>>();
313
314 let enum_name = &ident;
315 let arm_matchers = data_enum
316 .variants
317 .iter()
318 .map(|variant| {
319 let variant_ident = &variant.ident;
320 let shrinker = generate_product_shrink::<_, LitInt>(
321 &variant.fields,
322 |ty, ident, other_idents| {
323 let mut idents_all = other_idents.clone();
324 idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
325 idents_all.sort_by(|(a, _), (b, _)| {
326 a.base10_parse::<u64>()
327 .unwrap()
328 .cmp(&b.base10_parse().unwrap())
329 });
330 let initialiser_list = idents_all
331 .iter()
332 .map(|(_, stream)| stream)
333 .collect::<Vec<_>>();
334
335 let puller = make_enum_puller(
336 ident.base10_parse().unwrap(),
337 other_idents.len(),
338 &variant.ident,
339 &Ident::new("self", Span::call_site()),
340 );
341
342 quote! {
343 ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(
344 #puller
345 ),
346 move |__quickcheck_derive_moving| Self::#variant_ident(#(#initialiser_list),*))
347 }
348 },
349 |ident_str| LitInt::new(ident_str, Span::call_site()),
350 |ident, field, num_fields| {
351 make_enum_puller(
352 field.base10_parse().unwrap(),
353 num_fields - 1,
354 variant_ident,
355 &ident,
356 )
357 },
358 );
359
360 let underscores = (0..variant.fields.len())
361 .map(|_| quote! {_})
362 .collect::<Vec<_>>();
363
364 match variant.fields.is_empty() {
365 true => quote! {#enum_name::#variant_ident => ::std::boxed::Box::new(::quickcheck::empty_shrinker())},
366 false => quote! {#enum_name::#variant_ident(#(#underscores),*) => {#shrinker}} ,
367 }
368
369 })
370 .collect::<Vec<_>>();
371
372 ArbitraryImpl {
373 arbitrary: quote! {
374 match <::core::primitive::usize as ::quickcheck::Arbitrary>::arbitrary(g) % (
375 if ::quickcheck::Gen::size(g) > 1 {
376 #num_variants
377 } else {
378 #num_variants - #num_recursive
379 }) {
380 #(#initialisers),*,
381 _ => ::core::unreachable!()
382 }
383 },
384 shrink: quote! {
385 match &self {
386 #(#arm_matchers),*
387 }
388 },
389 }
390}
391
392#[proc_macro_derive(QuickCheck, attributes(quickcheck))]
409pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
410 let DeriveInput {
411 ident,
412 data,
413 generics,
414 ..
415 } = parse_macro_input!(input);
416 let ArbitraryImpl { arbitrary, shrink } = match data {
417 syn::Data::Struct(data_struct) => match data_struct.fields {
418 syn::Fields::Named(fields_named) => make_named_struct_arbitrary(&fields_named),
419 syn::Fields::Unnamed(fields_unnamed) => make_unnamed_struct_arbitrary(&fields_unnamed),
420 syn::Fields::Unit => ArbitraryImpl {
421 arbitrary: quote! {Self},
422 shrink: quote! {::quickcheck::empty_shrinker()},
423 },
424 },
425 syn::Data::Enum(data_enum) => make_enum_arbitrary(&ident, &data_enum),
426 syn::Data::Union(_) => ArbitraryImpl {
427 shrink: quote! {::quickcheck::empty_shrinker()},
428 arbitrary: {
429 syn::Error::new_spanned(&ident, "Cannot derive QuickCheck for a union yet")
430 .to_compile_error()
431 },
432 },
433 };
434
435 let generics_unconstrained = generics
436 .lifetimes()
437 .map(|lifetime| lifetime.lifetime.to_token_stream())
438 .chain(
439 generics
440 .type_params()
441 .map(|type_param| type_param.ident.to_token_stream()),
442 )
443 .collect::<Vec<_>>();
444
445 let generics_arbitrary = generics
446 .lifetimes()
447 .map(|lifetime| lifetime.to_token_stream())
448 .chain(generics.type_params().map(|type_param| {
449 let colon = match type_param.bounds.len() {
450 0 => quote! {:},
451 _ => quote! {+},
452 };
453 quote! {#type_param #colon ::quickcheck::Arbitrary}
454 }))
455 .collect::<Vec<_>>();
456
457 let generics_unconstrained_tokens = match generics_unconstrained.len() {
458 0 => quote! {},
459 _ => quote! {<#(#generics_unconstrained),*>},
460 };
461 let generics_arbitrary_tokens = match generics_arbitrary.len() {
462 0 => quote! {},
463 _ => quote! {<#(#generics_arbitrary),*>},
464 };
465
466 if !generics.lifetimes().collect::<Vec<_>>().is_empty() {
467 return syn::Error::new_spanned(
468 &ident,
469 "Cannot derive QuickCheck for a type with lifetimes yet",
470 )
471 .to_compile_error()
472 .into();
473 }
474
475 let output = quote! {
476 impl #generics_arbitrary_tokens ::quickcheck::Arbitrary for #ident #generics_unconstrained_tokens
477 where
478 #ident #generics_unconstrained_tokens : ::core::clone::Clone {
479 fn arbitrary(g: &mut ::quickcheck::Gen) -> Self {
480 #arbitrary
481 }
482
483 fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
484 #shrink
485 }
486 }
487 };
488 output.into()
489}