1use core::panic;
2use std::collections::HashMap;
3
4use proc_macro::{self};
5use proc_macro2::{Span, TokenStream};
6use quote::{ToTokens, format_ident, quote};
7use syn::{
8 Attribute, DataEnum, DeriveInput, Field, FieldsNamed, FieldsUnnamed, Ident, LitInt,
9 MetaNameValue, Type, parse_macro_input, punctuated::Punctuated, token::Comma,
10};
11
12fn generate_product_shrink<
13 Iter: IntoIterator<Item = Field> + Clone,
14 IdentKind: Clone + ToTokens + ToString,
15>(
16 fields: &Iter,
17 constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
18 make_ident: impl Fn(&str) -> IdentKind,
19 self_helper: impl Fn(Ident, &IdentKind, usize) -> TokenStream,
20) -> TokenStream {
21 let self_copies = fields
22 .clone()
23 .into_iter()
24 .enumerate()
25 .map(|(idx, field)| {
26 let ident = field
27 .ident
28 .clone()
29 .map(|ident| ident.to_string())
30 .unwrap_or(idx.to_string());
31 let unique_self = format_ident!("self_{}", ident);
32 quote! {
33 let #unique_self = <Self as ::std::clone::Clone>::clone(&self);
34 }
35 })
36 .collect::<Vec<_>>();
37
38 let cloning_iterator_madness: TokenStream = fields
39 .clone()
40 .into_iter()
41 .enumerate()
42 .map(|(idx, field)| {
43 let ident = make_ident(
44 &field
45 .ident
46 .clone()
47 .map(|ident| ident.to_string())
48 .unwrap_or(idx.to_string()),
49 );
50 let other_idents = fields
51 .clone()
52 .into_iter()
53 .enumerate()
54 .map(|(idx, field)| {
55 make_ident(
56 &field
57 .ident
58 .clone()
59 .map(|ident| ident.to_string())
60 .unwrap_or(idx.to_string()),
61 )
62 })
63 .filter(|e| e.to_string() != ident.to_string())
64 .map(|field_ident| {
65 let unique_self_toks = self_helper(
66 format_ident!("self_{}", ident.to_string()),
67 &field_ident,
68 fields.clone().into_iter().collect::<Vec<_>>().len(),
69 );
70 (
71 field_ident.clone(),
72 quote! {::core::clone::Clone::clone(#unique_self_toks)},
73 )
74 })
75 .collect::<Vec<_>>();
76 constructor(&field.ty, &ident, &other_idents)
77 })
78 .collect::<Vec<_>>()
79 .iter()
80 .rev()
81 .cloned()
82 .reduce(|a, b| quote! {::std::iter::Iterator::chain(#a, #b)})
83 .unwrap_or(quote! {});
84
85 quote! {
86 #(#self_copies)*
87 ::std::boxed::Box::new(#cloning_iterator_madness)
88 }
89}
90fn generate_product_shrink_simple<
91 Iter: IntoIterator<Item = Field> + Clone,
92 IdentKind: Clone + ToTokens + ToString,
93>(
94 fields: &Iter,
95 constructor: impl Fn(&Type, &IdentKind, &Vec<(IdentKind, TokenStream)>) -> TokenStream,
96 make_ident: impl Fn(&str) -> IdentKind,
97) -> TokenStream {
98 generate_product_shrink(
99 fields,
100 constructor,
101 make_ident,
102 |unique_self, field_ident, _| quote! {&#unique_self.#field_ident},
103 )
104}
105
106fn make_enum_puller(pull: usize, others: usize, variant: &Ident, source: &Ident) -> TokenStream {
107 let v_puller = [quote! {__quickcheck_derive_match_puller}];
108 let pull_pattern = (0..(pull))
109 .map(|_| quote! {_})
110 .chain(v_puller.iter().cloned())
111 .chain((pull..others).map(|_| quote! {_}));
112
113 quote! {if let Self::#variant(#(#pull_pattern),*) = &#source {
114 __quickcheck_derive_match_puller
115 } else {
116 ::core::unreachable!()
117 }}
118}
119
120struct ArbitraryImpl {
121 arbitrary: TokenStream,
122 shrink: TokenStream,
123}
124
125fn make_named_struct_arbitrary(fields_named: &FieldsNamed) -> ArbitraryImpl {
126 let field_arbitrary_generators = fields_named
127 .named
128 .iter()
129 .map(|field| {
130 let identifier = &field.ident;
131 let ty = &field.ty;
132 quote! {
133 #identifier: <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
134 }
135 })
136 .collect::<Vec<_>>();
137 ArbitraryImpl {
138 shrink: generate_product_shrink_simple(
139 &fields_named.named,
140 |ty, ident, other_idents| {
141 let other_idents_initialisers = other_idents
142 .iter()
143 .map(|(ident, toks)| {
144 quote! {#ident: #toks}
145 })
146 .collect::<Vec<_>>();
147 quote! {
148 ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
149 move |__quickcheck_derive_moving| Self {#ident: __quickcheck_derive_moving, #(#other_idents_initialisers),*})
150 }
151 },
152 |ident_str| Ident::new(ident_str, Span::call_site()),
153 ),
154 arbitrary: quote! {
155 Self {
156 #(#field_arbitrary_generators),*
157 }
158 },
159 }
160}
161
162fn make_unnamed_struct_arbitrary(fields_unnamed: &FieldsUnnamed) -> ArbitraryImpl {
163 let field_arbitrary_generators = fields_unnamed
164 .unnamed
165 .iter()
166 .map(|field| {
167 let ty = &field.ty;
168 quote! {
169 <#ty as ::quickcheck::Arbitrary>::arbitrary(g)
170 }
171 })
172 .collect::<Vec<_>>();
173 ArbitraryImpl {
174 arbitrary: quote! {
175 Self(#(#field_arbitrary_generators),*)
176 },
177 shrink: generate_product_shrink_simple::<_, LitInt>(
178 &fields_unnamed.unnamed,
179 |ty, ident, other_idents| {
180 let mut idents_all = other_idents.clone();
181 idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
182 idents_all.sort_by(|(a, _), (b, _)| {
183 a.base10_parse::<u64>()
184 .unwrap()
185 .cmp(&b.base10_parse().unwrap())
186 });
187 let initialiser_list = idents_all
188 .iter()
189 .map(|(_, stream)| stream)
190 .collect::<Vec<_>>();
191
192 quote! {
193 ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(&self.#ident),
194 move |__quickcheck_derive_moving| Self(#(#initialiser_list),*))
195 }
196 },
197 |ident_str| LitInt::new(ident_str, Span::call_site()),
198 ),
199 }
200}
201
/// Recursion strategy of an enum variant, selected with
/// `#[quickcheck(recursive = ...)]`.
///
/// The derived `Ord` follows declaration order (implicit discriminants 0, 1
/// and 2), so `None` sorts before `Linear`, which sorts before `Exponential`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum RecursiveKind {
    /// The variant's fields are generated with the caller's generator.
    None,
    /// The variant's fields use a generator whose size is one smaller.
    Linear,
    /// The variant's fields use a generator of half the size.
    Exponential,
}
208
209#[derive(Clone, Copy, Debug)]
210struct EnumAtrributes {
211 recursive: RecursiveKind,
212}
213
214fn get_enum_attrs(attrs: &Vec<Attribute>) -> EnumAtrributes {
215 const RECURSION_INVALID_KIND: &str =
216 "quickcheck recursive strategies must be one of None, Linear, Exponential";
217
218 dbg!(attrs);
219 let all_attrs = attrs
220 .iter()
221 .filter(|attr| attr.meta.path().is_ident("quickcheck"))
222 .map(|attr| {
223 attr.parse_args_with(Punctuated::<MetaNameValue, Comma>::parse_terminated)
224 .expect("quickcheck attribute must have comma seperated arguments")
225 .iter()
226 .map(|arg| {
227 (
228 arg.path
229 .get_ident()
230 .expect("quickcheck arguments must be of the form `ident = value`")
231 .to_string(),
232 match &arg.value {
233 syn::Expr::Path(v) => v.path.require_ident().expect("quickcheck recursive strategies must be one of None, Linear, Exponential").to_string(),
234 _ => panic!("quickcheck values must be literals"),
235 },
236 )
237 })
238 .collect::<HashMap<_, _>>()
239 })
240 .map(|key_values| EnumAtrributes {
241 recursive: match key_values
242 .get("recursive")
243 .cloned()
244 {
245 Some(v) => match v.as_str() {
246 "None" => RecursiveKind::None,
247 "Linear" => RecursiveKind::Linear,
248 "Exponential" => RecursiveKind::Exponential,
249 _ => panic!("{}", RECURSION_INVALID_KIND)
250 },
251 None => RecursiveKind::None,
252 },
253 })
254 .collect::<Vec<_>>();
255
256 match all_attrs.len() {
257 0 => EnumAtrributes {
258 recursive: RecursiveKind::None,
259 },
260 1 => all_attrs[0],
261 _ => panic!("quickcheck attribute may only be applied once to each field"),
262 }
263}
264
265fn make_enum_arbitrary(ident: &Ident, data_enum: &DataEnum) -> ArbitraryImpl {
266 let num_variants = data_enum.variants.len();
267
268 let mut initialisers = data_enum
269 .variants
270 .iter()
271 .map(|variant| {
272 (
273 &variant.ident,
274 match variant.fields.len() {
275 0 => (quote! {}, RecursiveKind::None),
276 _ => {
277 let attrs = get_enum_attrs(&variant.attrs);
278 let new_g = match attrs.recursive {
279 RecursiveKind::Exponential => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) / 2, 0))},
280 RecursiveKind::Linear => quote! {&mut ::quickcheck::Gen::new(::std::cmp::max(::quickcheck::Gen::size(g) - 1, 0))},
281 RecursiveKind::None => quote! {g}
282 };
283 let field_arbitrary_generators = variant
284 .fields
285 .iter()
286 .map(|field| {
287 let ty = &field.ty;
288 quote! {<#ty as ::quickcheck::Arbitrary>::arbitrary(#new_g)}
289 })
290 .collect::<Vec<_>>();
291 (quote! {(#(#field_arbitrary_generators),*)}, attrs.recursive)
292 }
293 },
294 )
295 })
296 .map(|(ident, (initialiser_list, recursive))| {
297 (quote! {Self::#ident #initialiser_list}, recursive)
298 })
299 .enumerate()
300 .map(|(index, (constructor, recursive))| {
301 (quote! {#index => #constructor}, recursive)
302 })
303 .collect::<Vec<_>>();
304
305 initialisers.sort_by_key(|(_, recursive)| *recursive);
306 let num_recursive = initialisers
307 .iter()
308 .filter(|(_, recursive)| !matches!(recursive, RecursiveKind::None))
309 .count();
310 let initialisers = initialisers
311 .into_iter()
312 .map(|(toks, _)| toks)
313 .collect::<Vec<_>>();
314
315 let enum_name = &ident;
316 let arm_matchers = data_enum
317 .variants
318 .iter()
319 .map(|variant| {
320 let variant_ident = &variant.ident;
321 let shrinker = generate_product_shrink::<_, LitInt>(
322 &variant.fields,
323 |ty, ident, other_idents| {
324 let mut idents_all = other_idents.clone();
325 idents_all.push((ident.clone(), quote! {__quickcheck_derive_moving}));
326 idents_all.sort_by(|(a, _), (b, _)| {
327 a.base10_parse::<u64>()
328 .unwrap()
329 .cmp(&b.base10_parse().unwrap())
330 });
331 let initialiser_list = idents_all
332 .iter()
333 .map(|(_, stream)| stream)
334 .collect::<Vec<_>>();
335
336 let puller = make_enum_puller(
337 ident.base10_parse().unwrap(),
338 other_idents.len(),
339 &variant.ident,
340 &Ident::new("self", Span::call_site()),
341 );
342
343 quote! {
344 ::std::iter::Iterator::map(<#ty as ::quickcheck::Arbitrary>::shrink(
345 #puller
346 ),
347 move |__quickcheck_derive_moving| Self::#variant_ident(#(#initialiser_list),*))
348 }
349 },
350 |ident_str| LitInt::new(ident_str, Span::call_site()),
351 |ident, field, num_fields| {
352 make_enum_puller(
353 field.base10_parse().unwrap(),
354 num_fields - 1,
355 variant_ident,
356 &ident,
357 )
358 },
359 );
360
361 let underscores = (0..variant.fields.len())
362 .map(|_| quote! {_})
363 .collect::<Vec<_>>();
364
365 match variant.fields.is_empty() {
366 true => quote! {#enum_name::#variant_ident => ::std::boxed::Box::new(::quickcheck::empty_shrinker())},
367 false => quote! {#enum_name::#variant_ident(#(#underscores),*) => {#shrinker}} ,
368 }
369
370 })
371 .collect::<Vec<_>>();
372
373 ArbitraryImpl {
374 arbitrary: quote! {
375 match <::core::primitive::usize as ::quickcheck::Arbitrary>::arbitrary(g) % (
376 if ::quickcheck::Gen::size(g) > 0 {
377 #num_variants
378 } else {
379 #num_variants - #num_recursive
380 }) {
381 #(#initialisers),*,
382 _ => ::core::unreachable!()
383 }
384 },
385 shrink: quote! {
386 match &self {
387 #(#arm_matchers),*
388 }
389 },
390 }
391}
392
393#[proc_macro_derive(QuickCheck, attributes(quickcheck))]
406pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
407 let DeriveInput {
408 ident,
409 data,
410 generics,
411 ..
412 } = parse_macro_input!(input);
413 let ArbitraryImpl { arbitrary, shrink } = match data {
414 syn::Data::Struct(data_struct) => match data_struct.fields {
415 syn::Fields::Named(fields_named) => make_named_struct_arbitrary(&fields_named),
416 syn::Fields::Unnamed(fields_unnamed) => make_unnamed_struct_arbitrary(&fields_unnamed),
417 syn::Fields::Unit => ArbitraryImpl {
418 arbitrary: quote! {Self},
419 shrink: quote! {::quickcheck::empty_shrinker()},
420 },
421 },
422 syn::Data::Enum(data_enum) => make_enum_arbitrary(&ident, &data_enum),
423 syn::Data::Union(_) => ArbitraryImpl {
424 shrink: quote! {::quickcheck::empty_shrinker()},
425 arbitrary: {
426 syn::Error::new_spanned(&ident, "Cannot derive QuickCheck for a union yet")
427 .to_compile_error()
428 },
429 },
430 };
431
432 let generics_unconstrained = generics
433 .lifetimes()
434 .map(|lifetime| lifetime.lifetime.to_token_stream())
435 .chain(
436 generics
437 .type_params()
438 .map(|type_param| type_param.ident.to_token_stream()),
439 )
440 .collect::<Vec<_>>();
441
442 let generics_arbitrary = generics
443 .lifetimes()
444 .map(|lifetime| lifetime.to_token_stream())
445 .chain(generics.type_params().map(|type_param| {
446 let colon = match type_param.bounds.len() {
447 0 => quote! {:},
448 _ => quote! {+},
449 };
450 quote! {#type_param #colon ::quickcheck::Arbitrary}
451 }))
452 .collect::<Vec<_>>();
453
454 let generics_unconstrained_tokens = match generics_unconstrained.len() {
455 0 => quote! {},
456 _ => quote! {<#(#generics_unconstrained),*>},
457 };
458 let generics_arbitrary_tokens = match generics_arbitrary.len() {
459 0 => quote! {},
460 _ => quote! {<#(#generics_arbitrary),*>},
461 };
462
463 if !generics.lifetimes().collect::<Vec<_>>().is_empty() {
464 return syn::Error::new_spanned(
465 &ident,
466 "Cannot derive QuickCheck for a type with lifetimes yet",
467 )
468 .to_compile_error()
469 .into();
470 }
471
472 let output = quote! {
473 impl #generics_arbitrary_tokens ::quickcheck::Arbitrary for #ident #generics_unconstrained_tokens
474 where
475 #ident #generics_unconstrained_tokens : ::core::clone::Clone {
476 fn arbitrary(g: &mut ::quickcheck::Gen) -> Self {
477 #arbitrary
478 }
479
480 fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
481 #shrink
482 }
483 }
484 };
485 output.into()
486}