derive_quickcheck_arbitrary/
lib.rs

1//! Derive macro for [`quickcheck::Arbitrary`](https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html).
2//!
3//! Expands to calling [`Arbitrary::arbitrary`](https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html#tymethod.arbitrary)
4//! on every field of a struct.
5//!
6//! ```
7//! use derive_quickcheck_arbitrary::Arbitrary;
8//!
9//! #[derive(Clone, Arbitrary)]
10//! struct Yakshaver {
11//!     id: usize,
12//!     name: String,
13//! }
14//! ```
15//!
16//! You can customise field generation by either:
17//! - providing a callable that accepts [`&mut quickcheck::Gen`](https://docs.rs/quickcheck/latest/quickcheck/struct.Gen.html).
18//! - always using the default value
19//! ```
//! # use derive_quickcheck_arbitrary::Arbitrary;
//! # use quickcheck::Arbitrary;
21//! # mod num { pub fn clamp(input: usize, min: usize, max: usize) -> usize { todo!() } }
22//! #[derive(Clone, Arbitrary)]
23//! struct Yakshaver {
24//!     /// Must be less than 10_000
25//!     #[arbitrary(gen(|g| num::clamp(usize::arbitrary(g), 0, 10_000) ))]
26//!     id: usize,
27//!     name: String,
28//!     #[arbitrary(default)]
29//!     always_false: bool,
30//! }
31//! ```
32//!
33//! You can skip enum variants:
34//! ```
35//! # use derive_quickcheck_arbitrary::Arbitrary;
36//! #[derive(Clone, Arbitrary)]
37//! enum YakType {
38//!     Domestic {
39//!         name: String,
40//!     },
41//!     Wild,
42//!     #[arbitrary(skip)]
43//!     Alien,
44//! }
45//! ```
46//!
47//! You can add bounds for generic structs:
48//! ```
49//! # use derive_quickcheck_arbitrary::Arbitrary;
50//! # use quickcheck::Arbitrary;
51//! #[derive(Clone, Arbitrary)]
52//! #[arbitrary(where(T: Arbitrary))]
53//! struct GenericYak<T> {
54//!     name: T,
55//! }
56//! ```
57
58use proc_macro2::{Ident, Span, TokenStream};
59use quote::{quote, ToTokens as _};
60use structmeta::{NameArgs, StructMeta};
61use syn::{
62    parse::{Parse, ParseStream, Parser as _},
63    parse_macro_input,
64    punctuated::Punctuated,
65    spanned::Spanned as _,
66    token::{Brace, Colon, Comma},
67    AttrStyle, Attribute, DataEnum, DataStruct, DeriveInput, Expr, ExprStruct, FieldValue, Fields,
68    Index, Member, Path, PathSegment, Token, Variant, WhereClause, WherePredicate,
69};
70
71// TODO: https://docs.rs/proc-macro-crate/latest/proc_macro_crate/
72// TODO: https://crates.io/crates/parse-variants
73
74#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
75pub fn derive_arbitrary(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
76    let user_struct = parse_macro_input!(input as DeriveInput);
77    expand_arbitrary(user_struct)
78        .unwrap_or_else(syn::Error::into_compile_error)
79        .into()
80}
81
82fn expand_arbitrary(input: DeriveInput) -> syn::Result<TokenStream> {
83    let struct_name = input.ident.clone();
84    let generics = input.generics.clone();
85    let gen_name = &quote!(g);
86    let predicates = match get_one_arg(&input.attrs, input.span())? {
87        Some(Arg::Where(preds)) => preds,
88        None => Punctuated::new(),
89        Some(Arg::Default | Arg::Gen(_) | Arg::Skip) => {
90            return Err(syn::Error::new(
91                input.span(),
92                "only `where` is valid for items",
93            ))
94        }
95    };
96    let where_clause = WhereClause {
97        where_token: Token![where](Span::call_site()),
98        predicates,
99    };
100
101    let ctor = match input.data {
102        syn::Data::Struct(DataStruct { fields, .. }) => expr_struct(
103            path_of_idents([struct_name.clone()]),
104            field_values(fields, gen_name)?,
105        )
106        .into_token_stream(),
107        syn::Data::Enum(DataEnum { variants, .. }) => {
108            let span = variants.span();
109            let variant_ctors = variants
110                .into_iter()
111                .filter_map(
112                    |Variant {
113                         attrs,
114                         ident,
115                         fields,
116                         ..
117                     }| match get_one_arg(&attrs, span) {
118                        Ok(None) => match field_values(fields, gen_name) {
119                            Ok(fields) => {
120                                let variant_ctor = expr_struct(
121                                    path_of_idents([struct_name.clone(), ident]),
122                                    fields,
123                                );
124                                Some(Ok(variant_ctor))
125                            }
126                            Err(e) => Some(Err(e)),
127                        },
128                        Ok(Some(Arg::Skip)) => None,
129                        Ok(Some(Arg::Gen(_) | Arg::Default | Arg::Where(_))) => {
130                            Some(Err(syn::Error::new(
131                                span,
132                                "`gen`, `default` and `where` are not valid for enum variants", // TODO: probably could be
133                            )))
134                        }
135                        Err(e) => Some(Err(e)),
136                    },
137                )
138                .collect::<Result<Vec<_>, _>>()?;
139            quote!(
140                let options = [ #(#variant_ctors,)* ];
141                #gen_name.choose(options.as_slice()).expect("no variants to choose from").clone()
142            )
143        }
144        syn::Data::Union(_) => {
145            return Err(syn::Error::new_spanned(
146                input,
147                "#[derive(Arbitrary)] is not supported on `union`s",
148            ))
149        }
150    };
151
152    Ok(quote! {
153        impl #generics ::quickcheck::Arbitrary for #struct_name #generics
154            #where_clause
155        {
156            fn arbitrary(#gen_name: &mut ::quickcheck::Gen) -> Self {
157                #ctor
158            }
159        }
160    })
161}
162
163fn field_values(
164    fields: Fields,
165    gen_name: &TokenStream,
166) -> syn::Result<Punctuated<FieldValue, Comma>> {
167    fields
168        .into_iter()
169        .enumerate()
170        .map(|(ix, field)| {
171            let value = match get_one_arg(&field.attrs, field.span())? {
172                Some(Arg::Skip | Arg::Where(_)) => {
173                    return Err(syn::Error::new_spanned(
174                        field,
175                        "`skip` and `where` are not valid for members",
176                    ))
177                }
178                Some(Arg::Gen(custom)) => {
179                    let ty = field.ty;
180                    quote! {
181                        (
182                            ( #custom ) as ( fn(&mut ::quickcheck::Gen) -> #ty )
183                        ) // cast to fn pointer
184                        (&mut *#gen_name) // call it
185                    }
186                }
187                Some(Arg::Default) => {
188                    quote!(::core::default::Default::default())
189                }
190                None => quote!(::quickcheck::Arbitrary::arbitrary(#gen_name)),
191            };
192            Ok(FieldValue {
193                attrs: vec![],
194                member: match field.ident {
195                    Some(name) => Member::Named(name),
196                    None => Member::Unnamed(Index::from(ix)),
197                },
198                colon_token: Some(Colon::default()),
199                expr: Expr::Verbatim(value),
200            })
201        })
202        .collect()
203}
204
205fn expr_struct(path: Path, field_values: Punctuated<FieldValue, Comma>) -> ExprStruct {
206    ExprStruct {
207        attrs: vec![],
208        qself: None,
209        path,
210        brace_token: Brace::default(),
211        fields: field_values,
212        dot2_token: None,
213        rest: None,
214    }
215}
216
217fn path_of_idents(idents: impl IntoIterator<Item = Ident>) -> Path {
218    Path {
219        leading_colon: None,
220        segments: Punctuated::from_iter(idents.into_iter().map(|ident| PathSegment {
221            ident,
222            arguments: syn::PathArguments::None,
223        })),
224    }
225}
226
227#[derive(Clone)]
228enum Arg {
229    Skip,
230    Gen(TokenStream),
231    Default,
232    Where(Punctuated<WherePredicate, Comma>),
233}
234
// Raw arguments of an `#[arbitrary(...)]` attribute as parsed by the
// `StructMeta` derive: each field is one accepted argument name (`r#where`
// spells the `where` argument). `impl Parse for Arg` then requires exactly one
// of them to be set.
#[derive(StructMeta, Debug, Default)]
struct AttrArgs {
    // `gen(...)`: the generator tokens, left unparsed.
    gen: Option<NameArgs<TokenStream>>,
    // `skip`: a bare flag.
    skip: bool,
    // `default`: a bare flag.
    default: bool,
    // `where(...)`: tokens later parsed as comma-separated where-predicates.
    r#where: Option<NameArgs<TokenStream>>,
}
242
243impl Parse for Arg {
244    fn parse(input: ParseStream) -> syn::Result<Self> {
245        let mut hint = syn::Error::new(
246            input.span(),
247            "expected one of  `gen`, `default`, `where` or `skip`",
248        );
249        match AttrArgs::parse(input) {
250            // inner error
251            Err(e) => {
252                hint.combine(e);
253                Err(hint)
254            }
255            // nothing
256            Ok(AttrArgs {
257                gen: None,
258                r#where: None,
259                skip: false,
260                default: false,
261            }) => Err(hint),
262            // just `skip`
263            Ok(AttrArgs {
264                skip: true,
265
266                gen: None,
267                default: false,
268                r#where: None,
269            }) => Ok(Arg::Skip),
270            // just `gen`
271            Ok(AttrArgs {
272                gen: Some(NameArgs { name_span: _, args }),
273
274                r#where: None,
275                skip: false,
276                default: false,
277            }) => Ok(Arg::Gen(args)),
278
279            // just `where`
280            Ok(AttrArgs {
281                r#where: Some(NameArgs { name_span: _, args }),
282
283                gen: None,
284                skip: false,
285                default: false,
286            }) => Ok(Arg::Where(Punctuated::parse_terminated.parse2(args)?)), // just `default`
287            Ok(AttrArgs {
288                default: true,
289
290                r#where: None,
291                gen: None,
292                skip: false,
293            }) => Ok(Arg::Default),
294            // some combination of arguments
295            Ok(AttrArgs { .. }) => Err(hint),
296        }
297    }
298}
299
300fn get_one_arg(attrs: &[Attribute], parent_span: Span) -> syn::Result<Option<Arg>> {
301    let configs = attrs
302        .iter()
303        .filter(|it| it.path().is_ident("arbitrary"))
304        .map(
305            |attr @ Attribute {
306                 pound_token: _,
307                 style,
308                 bracket_token: _,
309                 meta: _,
310             }| {
311                match style {
312                    AttrStyle::Outer => attr.parse_args::<Arg>(),
313                    AttrStyle::Inner(_) => Err(syn::Error::new_spanned(
314                        attr,
315                        "only outer attributes are supported: `#[arbitrary(...)]`",
316                    )),
317                }
318            },
319        )
320        .collect::<Result<Vec<_>, _>>()?;
321    match configs.as_slice() {
322        [] => Ok(None),
323        [one] => Ok(Some(one.clone())),
324        _too_many => Err(syn::Error::new(
325            parent_span,
326            "`#[arbitrary(...)]` can only be specified once",
327        )),
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    use structmeta::NameArgs;
    use syn::parse_quote;

    /// Expected `NameArgs` for `name(args)`. The span is irrelevant because the
    /// test-only `PartialEq` for `AttrArgs` in this module ignores spans.
    fn named(args: TokenStream) -> Option<NameArgs<TokenStream>> {
        Some(NameArgs {
            name_span: Span::call_site(),
            args,
        })
    }

    #[test]
    fn readme() {
        let status = std::process::Command::new("cargo")
            .args(["rdme", "--check"])
            .output()
            .expect("couldn't run `cargo rdme`")
            .status;
        assert!(
            status.success(),
            "README.md is out of date - bless the new version by running `cargo rdme`"
        )
    }

    #[test]
    fn attr_args() {
        let parsed_skip: AttrArgs = parse_quote!(skip);
        assert_eq!(
            AttrArgs {
                skip: true,
                ..Default::default()
            },
            parsed_skip,
        );

        let parsed_default: AttrArgs = parse_quote!(default);
        assert_eq!(
            AttrArgs {
                default: true,
                ..Default::default()
            },
            parsed_default,
        );

        let parsed_gen: AttrArgs = parse_quote!(gen(some_fn));
        assert_eq!(
            AttrArgs {
                gen: named(quote!(some_fn)),
                ..Default::default()
            },
            parsed_gen,
        );

        let parsed_where: AttrArgs = parse_quote!(where(foo));
        assert_eq!(
            AttrArgs {
                r#where: named(quote!(foo)),
                ..Default::default()
            },
            parsed_where,
        );
    }

    #[test]
    fn trybuild() {
        let cases = trybuild::TestCases::new();
        cases.pass("trybuild/pass/**/*.rs");
        cases.compile_fail("trybuild/fail/**/*.rs")
    }

    // Test-only equality. `NameArgs` holds a `Span` and a `TokenStream`, which
    // are compared here by the tokens' string rendering; spans are ignored.
    impl PartialEq for AttrArgs {
        fn eq(&self, other: &Self) -> bool {
            fn rendered(arg: &Option<NameArgs<TokenStream>>) -> Option<String> {
                arg.as_ref().map(|it| it.args.to_string())
            }
            // Exhaustive destructuring: adding a field to `AttrArgs` must
            // force this comparison to be revisited.
            let AttrArgs {
                gen,
                skip,
                default,
                r#where,
            } = self;
            let AttrArgs {
                gen: other_gen,
                skip: other_skip,
                default: other_default,
                r#where: other_where,
            } = other;
            skip == other_skip
                && default == other_default
                && rendered(gen) == rendered(other_gen)
                && rendered(r#where) == rendered(other_where)
        }
    }
}