1use proc_macro2::{Ident, Span, TokenStream};
59use quote::{quote, ToTokens as _};
60use structmeta::{NameArgs, StructMeta};
61use syn::{
62 parse::{Parse, ParseStream, Parser as _},
63 parse_macro_input,
64 punctuated::Punctuated,
65 spanned::Spanned as _,
66 token::{Brace, Colon, Comma},
67 AttrStyle, Attribute, DataEnum, DataStruct, DeriveInput, Expr, ExprStruct, FieldValue, Fields,
68 Index, Member, Path, PathSegment, Token, Variant, WhereClause, WherePredicate,
69};
70
71#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
75pub fn derive_arbitrary(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
76 let user_struct = parse_macro_input!(input as DeriveInput);
77 expand_arbitrary(user_struct)
78 .unwrap_or_else(syn::Error::into_compile_error)
79 .into()
80}
81
82fn expand_arbitrary(input: DeriveInput) -> syn::Result<TokenStream> {
83 let struct_name = input.ident.clone();
84 let generics = input.generics.clone();
85 let gen_name = "e!(g);
86 let predicates = match get_one_arg(&input.attrs, input.span())? {
87 Some(Arg::Where(preds)) => preds,
88 None => Punctuated::new(),
89 Some(Arg::Default | Arg::Gen(_) | Arg::Skip) => {
90 return Err(syn::Error::new(
91 input.span(),
92 "only `where` is valid for items",
93 ))
94 }
95 };
96 let where_clause = WhereClause {
97 where_token: Token),
98 predicates,
99 };
100
101 let ctor = match input.data {
102 syn::Data::Struct(DataStruct { fields, .. }) => expr_struct(
103 path_of_idents([struct_name.clone()]),
104 field_values(fields, gen_name)?,
105 )
106 .into_token_stream(),
107 syn::Data::Enum(DataEnum { variants, .. }) => {
108 let span = variants.span();
109 let variant_ctors = variants
110 .into_iter()
111 .filter_map(
112 |Variant {
113 attrs,
114 ident,
115 fields,
116 ..
117 }| match get_one_arg(&attrs, span) {
118 Ok(None) => match field_values(fields, gen_name) {
119 Ok(fields) => {
120 let variant_ctor = expr_struct(
121 path_of_idents([struct_name.clone(), ident]),
122 fields,
123 );
124 Some(Ok(variant_ctor))
125 }
126 Err(e) => Some(Err(e)),
127 },
128 Ok(Some(Arg::Skip)) => None,
129 Ok(Some(Arg::Gen(_) | Arg::Default | Arg::Where(_))) => {
130 Some(Err(syn::Error::new(
131 span,
132 "`gen`, `default` and `where` are not valid for enum variants", )))
134 }
135 Err(e) => Some(Err(e)),
136 },
137 )
138 .collect::<Result<Vec<_>, _>>()?;
139 quote!(
140 let options = [ #(#variant_ctors,)* ];
141 #gen_name.choose(options.as_slice()).expect("no variants to choose from").clone()
142 )
143 }
144 syn::Data::Union(_) => {
145 return Err(syn::Error::new_spanned(
146 input,
147 "#[derive(Arbitrary)] is not supported on `union`s",
148 ))
149 }
150 };
151
152 Ok(quote! {
153 impl #generics ::quickcheck::Arbitrary for #struct_name #generics
154 #where_clause
155 {
156 fn arbitrary(#gen_name: &mut ::quickcheck::Gen) -> Self {
157 #ctor
158 }
159 }
160 })
161}
162
163fn field_values(
164 fields: Fields,
165 gen_name: &TokenStream,
166) -> syn::Result<Punctuated<FieldValue, Comma>> {
167 fields
168 .into_iter()
169 .enumerate()
170 .map(|(ix, field)| {
171 let value = match get_one_arg(&field.attrs, field.span())? {
172 Some(Arg::Skip | Arg::Where(_)) => {
173 return Err(syn::Error::new_spanned(
174 field,
175 "`skip` and `where` are not valid for members",
176 ))
177 }
178 Some(Arg::Gen(custom)) => {
179 let ty = field.ty;
180 quote! {
181 (
182 ( #custom ) as ( fn(&mut ::quickcheck::Gen) -> #ty )
183 ) (&mut *#gen_name) }
186 }
187 Some(Arg::Default) => {
188 quote!(::core::default::Default::default())
189 }
190 None => quote!(::quickcheck::Arbitrary::arbitrary(#gen_name)),
191 };
192 Ok(FieldValue {
193 attrs: vec![],
194 member: match field.ident {
195 Some(name) => Member::Named(name),
196 None => Member::Unnamed(Index::from(ix)),
197 },
198 colon_token: Some(Colon::default()),
199 expr: Expr::Verbatim(value),
200 })
201 })
202 .collect()
203}
204
205fn expr_struct(path: Path, field_values: Punctuated<FieldValue, Comma>) -> ExprStruct {
206 ExprStruct {
207 attrs: vec![],
208 qself: None,
209 path,
210 brace_token: Brace::default(),
211 fields: field_values,
212 dot2_token: None,
213 rest: None,
214 }
215}
216
217fn path_of_idents(idents: impl IntoIterator<Item = Ident>) -> Path {
218 Path {
219 leading_colon: None,
220 segments: Punctuated::from_iter(idents.into_iter().map(|ident| PathSegment {
221 ident,
222 arguments: syn::PathArguments::None,
223 })),
224 }
225}
226
227#[derive(Clone)]
228enum Arg {
229 Skip,
230 Gen(TokenStream),
231 Default,
232 Where(Punctuated<WherePredicate, Comma>),
233}
234
/// Raw arguments of one `#[arbitrary(...)]` attribute, as parsed by the `StructMeta` derive.
/// `Arg::parse` then requires exactly one of them to be present.
// Field names double as the attribute's argument names (`gen(...)`, `skip`, `default`,
// `where(...)`); the parser is generated from these fields, so keep names and types as-is.
#[derive(StructMeta, Debug, Default)]
struct AttrArgs {
    /// `gen(<tokens>)`: custom generator tokens.
    gen: Option<NameArgs<TokenStream>>,
    /// `skip` flag.
    skip: bool,
    /// `default` flag.
    default: bool,
    /// `where(<tokens>)`: raw tokens, later parsed as where-predicates by `Arg::parse`.
    r#where: Option<NameArgs<TokenStream>>,
}
242
243impl Parse for Arg {
244 fn parse(input: ParseStream) -> syn::Result<Self> {
245 let mut hint = syn::Error::new(
246 input.span(),
247 "expected one of `gen`, `default`, `where` or `skip`",
248 );
249 match AttrArgs::parse(input) {
250 Err(e) => {
252 hint.combine(e);
253 Err(hint)
254 }
255 Ok(AttrArgs {
257 gen: None,
258 r#where: None,
259 skip: false,
260 default: false,
261 }) => Err(hint),
262 Ok(AttrArgs {
264 skip: true,
265
266 gen: None,
267 default: false,
268 r#where: None,
269 }) => Ok(Arg::Skip),
270 Ok(AttrArgs {
272 gen: Some(NameArgs { name_span: _, args }),
273
274 r#where: None,
275 skip: false,
276 default: false,
277 }) => Ok(Arg::Gen(args)),
278
279 Ok(AttrArgs {
281 r#where: Some(NameArgs { name_span: _, args }),
282
283 gen: None,
284 skip: false,
285 default: false,
286 }) => Ok(Arg::Where(Punctuated::parse_terminated.parse2(args)?)), Ok(AttrArgs {
288 default: true,
289
290 r#where: None,
291 gen: None,
292 skip: false,
293 }) => Ok(Arg::Default),
294 Ok(AttrArgs { .. }) => Err(hint),
296 }
297 }
298}
299
300fn get_one_arg(attrs: &[Attribute], parent_span: Span) -> syn::Result<Option<Arg>> {
301 let configs = attrs
302 .iter()
303 .filter(|it| it.path().is_ident("arbitrary"))
304 .map(
305 |attr @ Attribute {
306 pound_token: _,
307 style,
308 bracket_token: _,
309 meta: _,
310 }| {
311 match style {
312 AttrStyle::Outer => attr.parse_args::<Arg>(),
313 AttrStyle::Inner(_) => Err(syn::Error::new_spanned(
314 attr,
315 "only outer attributes are supported: `#[arbitrary(...)]`",
316 )),
317 }
318 },
319 )
320 .collect::<Result<Vec<_>, _>>()?;
321 match configs.as_slice() {
322 [] => Ok(None),
323 [one] => Ok(Some(one.clone())),
324 _too_many => Err(syn::Error::new(
325 parent_span,
326 "`#[arbitrary(...)]` can only be specified once",
327 )),
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    use structmeta::NameArgs;
    use syn::parse_quote;

    /// README.md must match the crate docs, as checked by the external `cargo rdme`
    /// subcommand (the `cargo-rdme` tool must be installed for this test to pass).
    #[test]
    fn readme() {
        let output = std::process::Command::new("cargo")
            .args(["rdme", "--check"])
            .output()
            .expect("couldn't run `cargo rdme`");
        // Show cargo's stderr: a failing exit status is also what you get when, e.g., the
        // `rdme` subcommand is not installed, which "out of date" alone would hide.
        assert!(
            output.status.success(),
            "README.md is out of date - bless the new version by running `cargo rdme`\n\
             `cargo rdme --check` stderr:\n{}",
            String::from_utf8_lossy(&output.stderr)
        )
    }

    /// Each `#[arbitrary(...)]` keyword parses into the matching `AttrArgs` field.
    #[test]
    fn attr_args() {
        assert_eq!(
            AttrArgs {
                skip: true,
                ..Default::default()
            },
            parse_quote!(skip),
        );
        assert_eq!(
            AttrArgs {
                default: true,
                ..Default::default()
            },
            parse_quote!(default),
        );
        assert_eq!(
            AttrArgs {
                gen: Some(NameArgs {
                    name_span: Span::call_site(),
                    args: quote!(some_fn)
                }),
                ..Default::default()
            },
            parse_quote!(gen(some_fn)),
        );
        assert_eq!(
            AttrArgs {
                r#where: Some(NameArgs {
                    name_span: Span::call_site(),
                    args: quote!(foo)
                }),
                ..Default::default()
            },
            parse_quote!(where(foo)),
        );
    }

    /// Compile-pass and compile-fail UI tests for the derive, run by `trybuild`.
    #[test]
    fn trybuild() {
        let t = trybuild::TestCases::new();
        t.pass("trybuild/pass/**/*.rs");
        t.compile_fail("trybuild/fail/**/*.rs")
    }

    // Test-only equality: `NameArgs` spans are ignored and argument tokens are compared by
    // their string rendering, so `parse_quote!` output can be compared with hand-built values.
    impl PartialEq for AttrArgs {
        fn eq(&self, other: &Self) -> bool {
            fn norm(t: &AttrArgs) -> (Option<String>, &bool, &bool, Option<String>) {
                let AttrArgs {
                    gen,
                    skip,
                    default,
                    r#where,
                } = t;
                (
                    gen.as_ref().map(|it| it.args.to_string()),
                    skip,
                    default,
                    r#where.as_ref().map(|it| it.args.to_string()),
                )
            }
            norm(self) == norm(other)
        }
    }
}