bolero_generator_derive/
lib.rs

1extern crate proc_macro;
2
3mod generator_attr;
4
5use generator_attr::GeneratorAttr;
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use proc_macro_crate::{crate_name, FoundCrate};
9use quote::{quote, quote_spanned, ToTokens};
10use syn::{
11    parse_macro_input, parse_quote, spanned::Spanned, Data, DataEnum, DataStruct, DataUnion,
12    DeriveInput, Error, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Ident,
13};
14
15fn crate_ident(from: FoundCrate) -> Ident {
16    let krate = match from {
17        FoundCrate::Itself => String::from("crate"),
18        FoundCrate::Name(n) => n,
19    };
20    Ident::new(&krate, Span::call_site())
21}
22
23fn crate_path() -> TokenStream2 {
24    // prefer referring to the generator crate, if present
25    if let Ok(krate) = crate_name("bolero-generator") {
26        let krate = crate_ident(krate);
27        return quote!(#krate);
28    }
29    if let Ok(krate) = crate_name("bolero") {
30        let krate = crate_ident(krate);
31        return quote!(#krate::generator::bolero_generator);
32    }
33    // fallback to using `::bolero_generator` if for whatever reason
34    // we can't find the crate in the `Cargo.toml`
35    quote!(::bolero_generator)
36}
37
38/// Derive the an implementation of `TypeGenerator` for the given type.
39///
40/// The `#[generator(my_custom_generator())]` attribute can be used
41/// to customize how fields are generated. If no generator is specified,
42/// the `TypeGenerator` implementation will be used.
43#[proc_macro_derive(TypeGenerator, attributes(generator))]
44pub fn derive_type_generator(input: TokenStream) -> TokenStream {
45    let krate = crate_path();
46    let derive_input = parse_macro_input!(input as DeriveInput);
47    let name = derive_input.ident;
48
49    // Add `T: TypeGenerator` bounds to each generic type `T`
50    let generics = add_trait_bound(derive_input.generics, &krate);
51
52    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
53
54    // The `generate` and `mutate` methods depend on the data type
55    let (generate_method, mutate_method) = match derive_input.data {
56        Data::Struct(data) => generate_struct_type_gen(&krate, &name, data),
57        Data::Enum(data) => generate_enum_type_gen(&krate, &name, data),
58        Data::Union(data) => generate_union_type_gen(&krate, &name, data),
59    };
60
61    // Generate the implementation for the type
62    quote!(
63        #[automatically_derived]
64        impl #impl_generics #krate::TypeGenerator for #name #ty_generics #where_clause {
65            #generate_method
66
67            #mutate_method
68        }
69    )
70    .into()
71}
72
73/// Add a bound `T: TypeGenerator` to each type parameter `T`
74fn add_trait_bound(mut generics: Generics, krate: &TokenStream2) -> Generics {
75    generics.params.iter_mut().for_each(|param| {
76        if let GenericParam::Type(type_param) = param {
77            type_param.bounds.push(parse_quote!(#krate::TypeGenerator));
78        }
79    });
80    generics
81}
82
83/// Create the `generate` and `mutate` methods for the derived `TypeGenerator` impl of a struct
84fn generate_struct_type_gen(
85    krate: &TokenStream2,
86    name: &Ident,
87    data_struct: DataStruct,
88) -> (TokenStream2, TokenStream2) {
89    let value = generate_fields_type_gen(krate, name, &data_struct.fields);
90    let destructure = generate_fields_type_destructure(name, &data_struct.fields);
91    let mutate_body = generate_fields_type_mutate(krate, &data_struct.fields);
92    let driver_cache = generate_fields_type_driver_cache(krate, &data_struct.fields);
93
94    let generate_method = quote!(
95        #[inline]
96        fn generate<__BOLERO_DRIVER: #krate::driver::Driver>(__bolero_driver: &mut __BOLERO_DRIVER) -> Option<Self> {
97            __bolero_driver.enter_product::<Self, _, _>(
98                |__bolero_driver| Some(#value)
99            )
100        }
101    );
102    let mutate_method = quote!(
103        #[inline]
104        fn mutate<__BOLERO_DRIVER: #krate::driver::Driver>(&mut self, __bolero_driver: &mut __BOLERO_DRIVER) -> Option<()> {
105            __bolero_driver.enter_product::<Self, _, _>(
106                |__bolero_driver| {
107                    let #destructure = self;
108                    #mutate_body
109                    Some(())
110                }
111            )
112        }
113
114        #[inline]
115        fn driver_cache<__BOLERO_DRIVER: #krate::driver::Driver>(self, __bolero_driver: &mut __BOLERO_DRIVER) {
116            let #destructure = self;
117            #driver_cache
118        }
119    );
120    (generate_method, mutate_method)
121}
122
123/// Create the `generate` and `mutate` methods for the derived `TypeGenerator` impl of an enum
124fn generate_enum_type_gen(
125    krate: &TokenStream2,
126    name: &Ident,
127    data_enum: DataEnum,
128) -> (TokenStream2, TokenStream2) {
129    let variant_max = data_enum.variants.len();
130    let base_case: usize = 0;
131
132    let variant_names: Vec<_> = data_enum
133        .variants
134        .iter()
135        .map(|variant| {
136            let span = variant.span();
137            let name = variant.ident.to_string();
138            quote_spanned!(span=> #name,)
139        })
140        .collect();
141    let variant_names = quote_spanned!(name.span()=> &[#(#variant_names)*]);
142
143    let gen_variants: Vec<_> = data_enum
144        .variants
145        .iter()
146        .enumerate()
147        .map(|(idx, variant)| {
148            let variant_name = &variant.ident;
149            let span = variant_name.span();
150            let constructor = quote_spanned!(span=> #name::#variant_name);
151            let value = generate_fields_type_gen(krate, constructor, &variant.fields);
152
153            let idx = lower_type_index(idx, variant_max, span);
154            quote_spanned!(span=> #idx => #value,)
155        })
156        .collect();
157
158    let gen_lookup: Vec<_> = data_enum
159        .variants
160        .iter()
161        .enumerate()
162        .map(|(idx, variant)| {
163            let variant_name = &variant.ident;
164            let span = variant_name.span();
165            let constructor = quote_spanned!(span=> #name::#variant_name);
166            let wildcard = generate_fields_type_wildcard(constructor, &variant.fields);
167            let idx = lower_type_index(idx, variant_max, span);
168            quote_spanned!(span=> #wildcard => #idx,)
169        })
170        .collect();
171
172    let gen_mutate: Vec<_> = data_enum
173        .variants
174        .iter()
175        .map(|variant| {
176            let variant_name = &variant.ident;
177            let span = variant_name.span();
178            let constructor = quote_spanned!(span=> #name::#variant_name);
179            let destructure = generate_fields_type_destructure(constructor, &variant.fields);
180            let mutate = generate_fields_type_mutate(krate, &variant.fields);
181
182            quote_spanned!(span=> #destructure => {
183                #mutate
184                Some(())
185            })
186        })
187        .collect();
188
189    let gen_driver_cache: Vec<_> = data_enum
190        .variants
191        .iter()
192        .map(|variant| {
193            let variant_name = &variant.ident;
194            let span = variant_name.span();
195            let constructor = quote_spanned!(span=> #name::#variant_name);
196            let destructure = generate_fields_type_destructure(constructor, &variant.fields);
197            let driver_cache = generate_fields_type_driver_cache(krate, &variant.fields);
198
199            quote_spanned!(span=> #destructure => {
200                #driver_cache
201            })
202        })
203        .collect();
204
205    let generate_method = quote!(
206        #[inline]
207        fn generate<__BOLERO_DRIVER: #krate::driver::Driver>(__bolero_driver: &mut __BOLERO_DRIVER) -> Option<Self> {
208            __bolero_driver.enter_sum::<Self, _, _>(
209                Some(#variant_names),
210                #variant_max,
211                #base_case,
212                |__bolero_driver, __bolero_selection| {
213                    Some(match __bolero_selection {
214                        #(#gen_variants)*
215                        _ => unreachable!("Value outside of range"),
216                    })
217                }
218            )
219        }
220    );
221
222    let mutate_method = quote!(
223        #[inline]
224        fn mutate<__BOLERO_DRIVER: #krate::driver::Driver>(&mut self, __bolero_driver: &mut __BOLERO_DRIVER) -> Option<()> {
225            __bolero_driver.enter_sum::<Self, _, _>(
226                Some(#variant_names),
227                #variant_max,
228                #base_case,
229                |__bolero_driver, __bolero_new_selection| {
230                    let __bolero_prev_selection = match self {
231                        #(#gen_lookup)*
232                    };
233
234                    if __bolero_prev_selection == __bolero_new_selection {
235                        match self {
236                            #(#gen_mutate)*
237                        }
238                    } else {
239                        let next = match __bolero_new_selection {
240                            #(#gen_variants)*
241                            _ => unreachable!("Value outside of range"),
242                        };
243                        match ::core::mem::replace(self, next) {
244                            #(#gen_driver_cache)*
245                        }
246                        Some(())
247                    }
248                }
249            )
250        }
251
252        #[inline]
253        fn driver_cache<__BOLERO_DRIVER: #krate::driver::Driver>(self, __bolero_driver: &mut __BOLERO_DRIVER) {
254            match self {
255                #(#gen_driver_cache)*
256            }
257        }
258    );
259    (generate_method, mutate_method)
260}
261
262/// Create the `generate` and `mutate` methods for the derived `TypeGenerator` impl of a union
263fn generate_union_type_gen(
264    krate: &TokenStream2,
265    name: &Ident,
266    data_union: DataUnion,
267) -> (TokenStream2, TokenStream2) {
268    let span = name.span();
269    let field_max = data_union.fields.named.len();
270    let field_upper = lower_type_index(field_max, field_max, name.span());
271
272    let base_case: usize = 0;
273
274    let variant_names: Vec<_> = data_union
275        .fields
276        .named
277        .iter()
278        .enumerate()
279        .map(|(idx, variant)| {
280            let span = variant.span();
281            let name = if let Some(name) = variant.ident.as_ref() {
282                name.to_string()
283            } else {
284                format!("<UnnamedUnionVariant{idx}>")
285            };
286            quote_spanned!(span=> #name,)
287        })
288        .collect();
289    let variant_names = quote_spanned!(name.span()=> &[#(#variant_names)*]);
290
291    let fields: Vec<_> = data_union
292        .fields
293        .named
294        .iter()
295        .enumerate()
296        .map(|(idx, field)| {
297            let field_name = &field.ident;
298            let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
299
300            let idx = lower_type_index(
301                idx,
302                field_max,
303                field_name.as_ref().map(|n| n.span()).unwrap_or(span),
304            );
305            let span = generator.span();
306            let value = generator.value_generate();
307            quote_spanned!(span=>
308                #idx => Some(#name { #field_name: #value }),
309            )
310        })
311        .collect();
312
313    let generate_method = quote!(
314        #[inline]
315        fn generate<__BOLERO_DRIVER: #krate::driver::Driver>(__bolero_driver: &mut __BOLERO_DRIVER) -> Option<Self> {
316            __bolero_driver.enter_sum::<Self, _, _>(
317                Some(#variant_names),
318                #field_upper,
319                #base_case,
320                |__bolero_driver, __bolero_selection| {
321                    match __bolero_selection {
322                        #(#fields)*
323                        _ => unreachable!("Value outside of range"),
324                    }
325                }
326            )
327        }
328    );
329
330    // The `mutate` method doesn't apply to unions
331    let mutate_method = quote!();
332
333    (generate_method, mutate_method)
334}
335
336fn lower_type_index(value: usize, max: usize, span: Span) -> TokenStream2 {
337    assert!(value <= max);
338
339    if max == 0 {
340        return Error::new(span, "Empty enums cannot be generated").to_compile_error();
341    }
342
343    quote_spanned!(span=> #value)
344}
345
346fn generate_fields_type_gen<C: ToTokens>(
347    krate: &TokenStream2,
348    constructor: C,
349    fields: &Fields,
350) -> TokenStream2 {
351    match fields {
352        Fields::Named(fields) => generate_fields_named_type_gen(krate, constructor, fields),
353        Fields::Unnamed(fields) => generate_fields_unnamed_type_gen(krate, constructor, fields),
354        Fields::Unit => quote!(#constructor),
355    }
356}
357
358fn generate_fields_type_mutate(krate: &TokenStream2, fields: &Fields) -> TokenStream2 {
359    match fields {
360        Fields::Named(fields) => generate_fields_named_type_mutate(krate, fields),
361        Fields::Unnamed(fields) => generate_fields_unnamed_type_mutate(krate, fields),
362        Fields::Unit => quote!(),
363    }
364}
365
366fn generate_fields_type_driver_cache(krate: &TokenStream2, fields: &Fields) -> TokenStream2 {
367    match fields {
368        Fields::Named(fields) => generate_fields_named_type_driver_cache(krate, fields),
369        Fields::Unnamed(fields) => generate_fields_unnamed_type_driver_cache(krate, fields),
370        Fields::Unit => quote!(),
371    }
372}
373
374fn generate_fields_type_wildcard<C: ToTokens>(constructor: C, fields: &Fields) -> TokenStream2 {
375    match fields {
376        Fields::Named(_) => quote!(#constructor { .. }),
377        Fields::Unnamed(fields) => generate_fields_unnamed_type_wildcard(constructor, fields),
378        Fields::Unit => quote!(#constructor),
379    }
380}
381
382fn generate_fields_type_destructure<C: ToTokens>(constructor: C, fields: &Fields) -> TokenStream2 {
383    match fields {
384        Fields::Named(fields) => generate_fields_named_type_destructure(constructor, fields),
385        Fields::Unnamed(fields) => generate_fields_unnamed_type_destructure(constructor, fields),
386        Fields::Unit => quote!(#constructor),
387    }
388}
389
390fn generate_fields_unnamed_type_gen<C: ToTokens>(
391    krate: &TokenStream2,
392    constructor: C,
393    fields: &FieldsUnnamed,
394) -> TokenStream2 {
395    let fields = fields.unnamed.iter().map(|field| {
396        let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
397        let value = generator.value_generate();
398        quote!(#value)
399    });
400    quote!(#constructor ( #(#fields,)* ))
401}
402
403fn generate_fields_unnamed_type_mutate(
404    krate: &TokenStream2,
405    fields: &FieldsUnnamed,
406) -> TokenStream2 {
407    let fields = fields.unnamed.iter().enumerate().map(|(index, field)| {
408        let value = Ident::new(&format!("__bolero_unnamed_{index}"), field.span());
409        let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
410
411        let span = generator.span();
412        quote_spanned!(span=>
413            #krate::ValueGenerator::mutate(&(#generator), __bolero_driver, #value)?
414        )
415    });
416    quote!(#(#fields;)*)
417}
418
419fn generate_fields_unnamed_type_driver_cache(
420    krate: &TokenStream2,
421    fields: &FieldsUnnamed,
422) -> TokenStream2 {
423    let fields = fields.unnamed.iter().enumerate().map(|(index, field)| {
424        let value = Ident::new(&format!("__bolero_unnamed_{index}"), field.span());
425        let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
426
427        let span = generator.span();
428        quote_spanned!(span=>
429            #krate::ValueGenerator::driver_cache(&(#generator), __bolero_driver, #value)
430        )
431    });
432    quote!(#(#fields;)*)
433}
434
435fn generate_fields_unnamed_type_wildcard<C: ToTokens>(
436    constructor: C,
437    fields: &FieldsUnnamed,
438) -> TokenStream2 {
439    let fields = fields.unnamed.iter().map(|_| quote!(_));
440    quote!(#constructor (#(#fields),*))
441}
442
443fn generate_fields_unnamed_type_destructure<C: ToTokens>(
444    constructor: C,
445    fields: &FieldsUnnamed,
446) -> TokenStream2 {
447    let fields = fields
448        .unnamed
449        .iter()
450        .enumerate()
451        .map(|(index, field)| Ident::new(&format!("__bolero_unnamed_{index}"), field.span()));
452    quote!(#constructor (#(#fields),*))
453}
454
455fn generate_fields_named_type_gen<C: ToTokens>(
456    krate: &TokenStream2,
457    constructor: C,
458    fields: &FieldsNamed,
459) -> TokenStream2 {
460    let fields = fields.named.iter().map(|field| {
461        let name = &field.ident;
462        let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
463        let value = generator.value_generate();
464        let span = generator.span();
465        quote_spanned!(span=>
466            #name: #value
467        )
468    });
469    quote!(#constructor { #(#fields,)* })
470}
471
472fn generate_fields_named_type_mutate(krate: &TokenStream2, fields: &FieldsNamed) -> TokenStream2 {
473    let fields = fields.named.iter().map(|field| {
474        let name = &field.ident;
475        let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
476
477        let span = generator.span();
478        quote_spanned!(span=>
479            #krate::ValueGenerator::mutate(&(#generator), __bolero_driver, #name)?
480        )
481    });
482    quote!(#(#fields;)*)
483}
484
485fn generate_fields_named_type_driver_cache(
486    krate: &TokenStream2,
487    fields: &FieldsNamed,
488) -> TokenStream2 {
489    let fields = fields.named.iter().map(|field| {
490        let name = &field.ident;
491        let generator = GeneratorAttr::from_attrs(krate, field.attrs.iter());
492
493        let span = generator.span();
494        quote_spanned!(span=>
495            #krate::ValueGenerator::driver_cache(&(#generator), __bolero_driver, #name)
496        )
497    });
498    quote!(#(#fields;)*)
499}
500
501fn generate_fields_named_type_destructure<C: ToTokens>(
502    constructor: C,
503    fields: &FieldsNamed,
504) -> TokenStream2 {
505    let fields = fields.named.iter().map(|field| &field.ident);
506    quote!(#constructor { #(#fields,)* })
507}