Skip to main content

mutatis_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{spanned::Spanned, *};
6
7mod container_attributes;
8mod field_attributes;
9use container_attributes::ContainerAttributes;
10use field_attributes::FieldBehavior;
11
12static MUTATIS_ATTRIBUTE_NAME: &str = "mutatis";
13
14#[proc_macro_derive(Mutate, attributes(mutatis))]
15pub fn derive_mutator(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
16    let input = syn::parse_macro_input!(tokens as DeriveInput);
17    expand_derive_mutator(input)
18        .unwrap_or_else(syn::Error::into_compile_error)
19        .into()
20}
21
22fn expand_derive_mutator(input: DeriveInput) -> Result<TokenStream> {
23    let container_attrs = ContainerAttributes::from_derive_input(&input)?;
24    let mutator_ty = MutatorType::new(&input, &container_attrs)?;
25
26    let mutator_type_def = gen_mutator_type_def(&input, &mutator_ty, &container_attrs)?;
27    let mutator_type_default_impl = gen_mutator_type_default_impl(&mutator_ty)?;
28    let mutator_ctor = gen_mutator_ctor(&mutator_ty)?;
29    let mutator_impl = gen_mutator_impl(&input, &mutator_ty)?;
30    let default_mutator_impl = gen_default_mutator_impl(&mutator_ty, &container_attrs)?;
31    let generate_impl = gen_generate_impl(&input, &mutator_ty, &container_attrs)?;
32
33    Ok(quote! {
34        #mutator_type_def
35        #mutator_type_default_impl
36        #mutator_ctor
37        #mutator_impl
38        #default_mutator_impl
39        #generate_impl
40    })
41}
42
43struct MutatorType {
44    ty_name: Ident,
45
46    mutator_name: Ident,
47
48    mutator_fields: Vec<MutatorField>,
49
50    /// A vec of quoted generic parameters, without any bounds but with `const`
51    /// defs, e.g. `'a`, `const N: usize`, or `T`.
52    ty_impl_generics: Vec<TokenStream>,
53
54    /// A vec of quoted generic parameters, without any bounds and without any
55    /// `const` defs, e.t. `'a`, `N`, or `T.
56    ty_name_generics: Vec<TokenStream>,
57
58    /// A vec of quoted bounds for the generics above, e.g. `A: Iterator<Item =
59    /// B>,`.
60    ty_generics_bounds: Vec<TokenStream>,
61}
62
63impl MutatorType {
64    fn new(input: &DeriveInput, container_attrs: &ContainerAttributes) -> Result<Self> {
65        let ty_name = input.ident.clone();
66
67        let mutator_name = container_attrs
68            .mutator_name
69            .clone()
70            .unwrap_or_else(|| Ident::new(&format!("{}Mutator", input.ident), input.ident.span()));
71
72        let mutator_fields = get_mutator_fields(&input)?;
73
74        let mut ty_impl_generics = vec![];
75        let mut ty_name_generics = vec![];
76        let mut ty_generics_bounds = vec![];
77
78        for gen in &input.generics.params {
79            match gen {
80                GenericParam::Lifetime(l) => {
81                    if !l.bounds.is_empty() {
82                        ty_generics_bounds.push(quote! { #l });
83                    }
84
85                    let l = &l.lifetime;
86                    ty_impl_generics.push(quote! { #l });
87                    ty_name_generics.push(quote! { #l });
88                }
89                GenericParam::Const(c) => {
90                    ty_impl_generics.push(quote! { #c });
91                    let c = &c.ident;
92                    ty_name_generics.push(quote! { #c });
93                }
94                GenericParam::Type(t) => {
95                    if !t.bounds.is_empty() {
96                        ty_generics_bounds.push(quote! { #t });
97                    }
98                    let t = &t.ident;
99                    ty_impl_generics.push(quote! { #t });
100                    ty_name_generics.push(quote! { #t });
101                }
102            }
103        }
104
105        if let Some(wc) = &input.generics.where_clause {
106            for bound in wc.predicates.iter() {
107                ty_generics_bounds.push(quote! { #bound });
108            }
109        }
110
111        Ok(Self {
112            ty_name,
113            mutator_name,
114            mutator_fields,
115            ty_impl_generics,
116            ty_name_generics,
117            ty_generics_bounds,
118        })
119    }
120
121    fn mutator_impl_generics_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
122        self.ty_impl_generics.iter().cloned().chain(
123            self.mutator_fields
124                .iter()
125                .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
126        )
127    }
128
129    /// All the `impl` generic parameters for this mutator, including those
130    /// inherited from the type that it is a mutator for.
131    fn mutator_impl_generics(&self) -> TokenStream {
132        let impl_generics = self
133            .ty_impl_generics
134            .iter()
135            .cloned()
136            .chain(
137                self.mutator_fields
138                    .iter()
139                    .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
140            )
141            .collect::<Vec<_>>();
142        if impl_generics.is_empty() {
143            quote! {}
144        } else {
145            quote! { < #( #impl_generics ),* > }
146        }
147    }
148
149    /// All the named (i.e. just the "N" and excluding "const", ":", and "usize"
150    /// in `const N: usize` generics) generic parameters for this mutator,
151    /// including those inherited from the type that it is a mutator for.
152    fn mutator_name_generics_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
153        self.ty_name_generics.iter().cloned().chain(
154            self.mutator_fields
155                .iter()
156                .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
157        )
158    }
159
160    fn mutator_impl_generics_with_defaults_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
161        self.ty_impl_generics
162            .iter()
163            .cloned()
164            .chain(self.mutator_fields.iter().filter_map(move |f| {
165                f.generic.as_ref().map(|g| {
166                    let for_ty = &f.for_ty;
167                    quote! { #g = <#for_ty as mutatis::DefaultMutate>::DefaultMutate }
168                })
169            }))
170    }
171
172    fn ty_name_with_generics(&self) -> TokenStream {
173        let ty_name = &self.ty_name;
174        if self.ty_name_generics.is_empty() {
175            quote! { #ty_name }
176        } else {
177            let ty_generics = self.ty_name_generics.iter();
178            quote! { #ty_name < #( #ty_generics ),* > }
179        }
180    }
181
182    fn mutator_name_with_generics(&self, kind: MutatorNameGenericsKind) -> TokenStream {
183        let mutator_name = &self.mutator_name;
184
185        let generics = match kind {
186            MutatorNameGenericsKind::Generics => {
187                self.mutator_name_generics_iter().collect::<Vec<_>>()
188            }
189            MutatorNameGenericsKind::Impl {
190                impl_default: false,
191            } => self.mutator_impl_generics_iter().collect::<Vec<_>>(),
192            MutatorNameGenericsKind::Impl { impl_default: true } => self
193                .mutator_impl_generics_with_defaults_iter()
194                .collect::<Vec<_>>(),
195            MutatorNameGenericsKind::JustTyGenerics => self.ty_name_generics.clone(),
196        };
197
198        if generics.is_empty() {
199            quote! { #mutator_name }
200        } else {
201            quote! { #mutator_name < #( #generics ),* > }
202        }
203    }
204
205    fn where_clause(&self, kind: WhereClauseKind) -> TokenStream {
206        let mut bounds = self.ty_generics_bounds.clone();
207
208        match kind {
209            WhereClauseKind::NoMutateBounds => {}
210            WhereClauseKind::MutateBounds => {
211                for f in &self.mutator_fields {
212                    let for_ty = &f.for_ty;
213                    if let Some(g) = f.generic.as_ref() {
214                        bounds.push(quote! { #g: mutatis::Mutate<#for_ty> });
215                    } else {
216                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
217                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
218                    }
219                }
220            }
221            WhereClauseKind::MutateAndGenerateBounds => {
222                for f in &self.mutator_fields {
223                    let for_ty = &f.for_ty;
224                    if let Some(g) = f.generic.as_ref() {
225                        bounds.push(
226                            quote! { #g: mutatis::Mutate<#for_ty> + mutatis::Generate<#for_ty> },
227                        );
228                    } else {
229                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
230                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
231                        bounds.push(quote! { <#for_ty as mutatis::DefaultMutate>::DefaultMutate: mutatis::Generate<#for_ty> });
232                    }
233                }
234            }
235            WhereClauseKind::DefaultBounds => {
236                for f in &self.mutator_fields {
237                    if let Some(g) = f.generic.as_ref() {
238                        bounds.push(quote! { #g: Default });
239                    } else {
240                        let for_ty = &f.for_ty;
241                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
242                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
243                    }
244                }
245            }
246            WhereClauseKind::DefaultMutateBounds => {
247                for f in &self.mutator_fields {
248                    let for_ty = &f.for_ty;
249                    bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
250                }
251            }
252        }
253
254        if bounds.is_empty() {
255            quote! {}
256        } else {
257            quote! { where #( #bounds ),* }
258        }
259    }
260
261    fn phantom_fields_defs<'a>(
262        &self,
263        input: &'a DeriveInput,
264    ) -> impl Iterator<Item = TokenStream> + 'a {
265        let make_phantom_field = |i, ty| {
266            let ident = Ident::new(&format!("_phantom{i}"), Span::call_site());
267            quote! { #ident : core::marker::PhantomData<#ty> , }
268        };
269
270        input
271            .generics
272            .params
273            .iter()
274            .enumerate()
275            .map(move |(i, g)| match g {
276                GenericParam::Lifetime(l) => {
277                    let l = &l.lifetime;
278                    make_phantom_field(i, quote! { & #l () })
279                }
280                GenericParam::Const(c) => {
281                    let c = &c.ident;
282                    make_phantom_field(i, quote! { [(); #c] })
283                }
284                GenericParam::Type(t) => {
285                    let t = &t.ident;
286                    make_phantom_field(i, quote! { #t })
287                }
288            })
289    }
290
291    fn phantom_fields_literals(&self) -> impl Iterator<Item = TokenStream> + '_ {
292        (0..self.ty_name_generics.len()).map(|i| {
293            let ident = Ident::new(&format!("_phantom{i}"), Span::call_site());
294            quote! { #ident : core::marker::PhantomData, }
295        })
296    }
297}
298
/// Selects which per-field predicates `MutatorType::where_clause` adds on
/// top of the input type's own bounds.
#[derive(Copy, Clone)]
enum WhereClauseKind {
    /// Only the input type's own bounds.
    NoMutateBounds,
    /// Adds two kinds of predicate:
    /// - generic mutators: `MutatorTn: Mutate<FieldTy>`;
    /// - other fields: `FieldTy: DefaultMutate`.
    MutateBounds,
    /// As `MutateBounds`, but every field's mutator must also implement
    /// `Generate<FieldTy>`.
    MutateAndGenerateBounds,
    /// Adds two kinds of predicate:
    /// - generic mutators: `MutatorTn: Default`;
    /// - other fields: `FieldTy: DefaultMutate`.
    DefaultBounds,
    /// Adds `FieldTy: DefaultMutate` for every mutator field.
    DefaultMutateBounds,
}
307
/// Selects which generic list `MutatorType::mutator_name_with_generics`
/// attaches to the mutator's name.
#[derive(Copy, Clone)]
enum MutatorNameGenericsKind {
    /// Name-position arguments: the input type's generics (bare `N` for
    /// consts) followed by the `MutatorT{n}` parameters.
    Generics,
    /// Declaration-position parameters (`const N: usize` spelled out).
    ///
    /// With `impl_default`, each `MutatorT{n}` defaults to the field type's
    /// `DefaultMutate` mutator.
    Impl { impl_default: bool },
    /// Only the input type's own generics, in name position.
    JustTyGenerics,
}
314
315struct MutatorField {
316    /// The identifier for this field inside the mutator struct.
317    ident: Ident,
318    /// The generic type parameter for this field, if any.
319    generic: Option<Ident>,
320    /// The behavior for this field.
321    behavior: FieldBehavior,
322    /// The type that this field is a mutator for.
323    for_ty: Type,
324}
325
326fn get_mutator_fields(input: &DeriveInput) -> Result<Vec<MutatorField>> {
327    let mut i = 0;
328    let mut generic = |b: &FieldBehavior| -> Option<Ident> {
329        if b.needs_generic() {
330            let g = Ident::new(&format!("MutatorT{}", i), Span::call_site());
331            i += 1;
332            Some(g)
333        } else {
334            None
335        }
336    };
337
338    match &input.data {
339        Data::Struct(data) => match &data.fields {
340            Fields::Named(fields) => fields
341                .named
342                .iter()
343                .filter_map(|f| {
344                    FieldBehavior::for_field(f)
345                        .map(|b| {
346                            b.map(|b| MutatorField {
347                                ident: f.ident.clone().unwrap(),
348                                generic: generic(&b),
349                                behavior: b,
350                                for_ty: f.ty.clone(),
351                            })
352                        })
353                        .transpose()
354                })
355                .collect(),
356            Fields::Unnamed(fields) => fields
357                .unnamed
358                .iter()
359                .enumerate()
360                .filter_map(|(i, f)| {
361                    FieldBehavior::for_field(f)
362                        .map(|b| {
363                            b.map(|b| MutatorField {
364                                ident: Ident::new(&format!("field{}", i), f.span()),
365                                generic: generic(&b),
366                                behavior: b,
367                                for_ty: f.ty.clone(),
368                            })
369                        })
370                        .transpose()
371                })
372                .collect(),
373            Fields::Unit => Ok(vec![]),
374        },
375        Data::Enum(data) => Ok(data
376            .variants
377            .iter()
378            .map(|v| {
379                let prefix = v.ident.to_string().to_lowercase();
380                match v.fields {
381                    Fields::Named(ref fields) => fields
382                        .named
383                        .iter()
384                        .filter_map(|f| {
385                            FieldBehavior::for_field(f)
386                                .map(|b| {
387                                    b.map(|b| MutatorField {
388                                        ident: Ident::new(
389                                            &format!("{prefix}_{}", f.ident.clone().unwrap()),
390                                            f.span(),
391                                        ),
392                                        generic: generic(&b),
393                                        behavior: b,
394                                        for_ty: f.ty.clone(),
395                                    })
396                                })
397                                .transpose()
398                        })
399                        .collect::<Result<Vec<_>>>(),
400                    Fields::Unnamed(ref fields) => fields
401                        .unnamed
402                        .iter()
403                        .enumerate()
404                        .filter_map(|(i, f)| {
405                            FieldBehavior::for_field(f)
406                                .map(|b| {
407                                    b.map(|b| MutatorField {
408                                        ident: Ident::new(&format!("{prefix}{i}"), f.span()),
409                                        generic: generic(&b),
410                                        behavior: b,
411                                        for_ty: f.ty.clone(),
412                                    })
413                                })
414                                .transpose()
415                        })
416                        .collect::<Result<Vec<_>>>(),
417                    Fields::Unit => Ok(vec![]),
418                }
419            })
420            .collect::<Result<Vec<_>>>()?
421            .into_iter()
422            .flat_map(|fs| fs)
423            .collect()),
424        Data::Union(_) => Err(Error::new_spanned(
425            input,
426            "cannot `derive(Mutate)` on a union",
427        )),
428    }
429}
430
431fn gen_mutator_type_def(
432    input: &DeriveInput,
433    mutator_ty: &MutatorType,
434    container_attrs: &ContainerAttributes,
435) -> Result<TokenStream> {
436    let vis = &input.vis;
437    let name = &input.ident;
438
439    let impl_default = container_attrs.default_mutate.unwrap_or(true);
440    let mutator_name =
441        mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Impl { impl_default });
442
443    let mut temp: Option<LitStr> = None;
444    let doc = container_attrs.mutator_doc.as_deref().unwrap_or_else(|| {
445        temp = Some(LitStr::new(
446            &format!(" A mutator for the `{name}` type."),
447            input.ident.span(),
448        ));
449        std::slice::from_ref(temp.as_ref().unwrap())
450    });
451
452    let where_clause = mutator_ty.where_clause(WhereClauseKind::NoMutateBounds);
453
454    let fields = mutator_ty
455        .mutator_fields
456        .iter()
457        .map(|f| {
458            let ident = &f.ident;
459            if let Some(g) = f.generic.as_ref() {
460                quote! { #ident: #g , }
461            } else {
462                let for_ty = &f.for_ty;
463                debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
464                quote! { #ident: <#for_ty as mutatis::DefaultMutate>::DefaultMutate, }
465            }
466        })
467        .collect::<Vec<_>>();
468
469    let phantoms = mutator_ty.phantom_fields_defs(input);
470
471    Ok(quote! {
472        #( #[doc = #doc] )*
473        // #[derive(Clone, Debug)]
474        #vis struct #mutator_name #where_clause {
475            #( #fields )*
476            #( #phantoms )*
477            _private: (),
478        }
479    })
480}
481
482fn gen_mutator_type_default_impl(mutator_ty: &MutatorType) -> Result<TokenStream> {
483    let impl_generics = mutator_ty.mutator_impl_generics();
484    let mutator_name = mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
485    let where_clause = mutator_ty.where_clause(WhereClauseKind::DefaultBounds);
486
487    let fields = mutator_ty
488        .mutator_fields
489        .iter()
490        .map(|f| {
491            let ident = &f.ident;
492            quote! { #ident: Default::default(), }
493        })
494        .collect::<Vec<_>>();
495
496    let phantoms = mutator_ty.phantom_fields_literals();
497
498    Ok(quote! {
499        #[automatically_derived]
500        impl #impl_generics Default for #mutator_name #where_clause {
501            fn default() -> Self {
502                Self {
503                    #( #fields )*
504                    #( #phantoms )*
505                    _private: (),
506                }
507            }
508        }
509    })
510}
511
512fn gen_mutator_ctor(mutator_ty: &MutatorType) -> Result<TokenStream> {
513    let impl_generics = mutator_ty.mutator_impl_generics();
514
515    let params = mutator_ty
516        .mutator_fields
517        .iter()
518        .filter_map(|f| {
519            f.generic.as_ref().map(|g| {
520                let ident = &f.ident;
521                quote! { #ident: #g , }
522            })
523        })
524        .collect::<Vec<_>>();
525
526    let fields = mutator_ty
527        .mutator_fields
528        .iter()
529        .map(|f| {
530            let ident = &f.ident;
531            if f.generic.is_some() {
532                quote! { #ident , }
533            } else {
534                let for_ty = &f.for_ty;
535                debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
536                quote! { #ident: mutatis::mutators::default::<#for_ty>() , }
537            }
538        })
539        .collect::<Vec<_>>();
540
541    let name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
542    let doc = format!("Construct a new `{name}` instance.");
543    let where_clause = mutator_ty.where_clause(WhereClauseKind::NoMutateBounds);
544    let phantoms = mutator_ty.phantom_fields_literals();
545
546    Ok(quote! {
547        impl #impl_generics #name #where_clause {
548            #[doc = #doc]
549            #[inline]
550            pub fn new( #( #params )* ) -> Self {
551                Self {
552                    #( #fields )*
553                    #( #phantoms )*
554                    _private: (),
555                }
556            }
557        }
558    })
559}
560
561fn gen_mutator_impl(input: &DeriveInput, mutator_ty: &MutatorType) -> Result<TokenStream> {
562    let impl_generics = mutator_ty.mutator_impl_generics();
563
564    let ty_name = mutator_ty.ty_name_with_generics();
565
566    let is_multi_variant_enum = matches!(&input.data, Data::Enum(data) if data.variants.len() > 1);
567    let where_clause = if is_multi_variant_enum {
568        mutator_ty.where_clause(WhereClauseKind::MutateAndGenerateBounds)
569    } else {
570        mutator_ty.where_clause(WhereClauseKind::MutateBounds)
571    };
572
573    let mut fields_iter = mutator_ty.mutator_fields.iter();
574    let mut make_mutation = |value| {
575        let ident = &fields_iter.next().unwrap().ident;
576        quote! { self.#ident.mutate(mutations, #value)?; }
577    };
578
579    let mutation_body = match &input.data {
580        Data::Struct(data) => match &data.fields {
581            Fields::Named(fields) => {
582                let mutations = fields
583                    .named
584                    .iter()
585                    .filter(|f| FieldBehavior::for_field(f).unwrap().is_some())
586                    .map(|f| {
587                        let ident = &f.ident;
588                        make_mutation(quote! { &mut value.#ident })
589                    });
590                quote! {
591                    #( #mutations )*
592                }
593            }
594            Fields::Unnamed(fields) => {
595                let mutations = fields
596                    .unnamed
597                    .iter()
598                    .enumerate()
599                    .filter(|(_i, f)| FieldBehavior::for_field(f).unwrap().is_some())
600                    .map(|(i, f)| {
601                        let index = Index {
602                            index: u32::try_from(i).unwrap(),
603                            span: f.span(),
604                        };
605                        make_mutation(quote! { &mut value.#index })
606                    });
607                quote! {
608                    #( #mutations )*
609                }
610            }
611            Fields::Unit => quote! {},
612        },
613
614        Data::Enum(data) => {
615            // Build the existing field-mutation match arms.
616            let mut field_mutation_arms = vec![];
617            for v in data.variants.iter() {
618                let variant_ident = &v.ident;
619                match &v.fields {
620                    Fields::Named(fields) => {
621                        let mut patterns = vec![];
622                        let mutates = fields
623                            .named
624                            .iter()
625                            .filter_map(|f| {
626                                let ident = &f.ident;
627                                if FieldBehavior::for_field(f).unwrap().is_some() {
628                                    patterns.push(quote! { ref mut #ident , });
629                                    Some(make_mutation(quote! { #ident }))
630                                } else {
631                                    patterns.push(quote! { #ident: _ , });
632                                    None
633                                }
634                            })
635                            .collect::<Vec<_>>();
636                        field_mutation_arms.push(quote! {
637                            #ty_name::#variant_ident { #( #patterns )* } => {
638                                #( #mutates )*
639                            }
640                        });
641                    }
642
643                    Fields::Unnamed(fields) => {
644                        let mut patterns = vec![];
645                        let mutates = fields
646                            .unnamed
647                            .iter()
648                            .enumerate()
649                            .filter_map(|(i, f)| {
650                                if FieldBehavior::for_field(f).unwrap().is_some() {
651                                    let binding = Ident::new(&format!("field{}", i), f.span());
652                                    patterns.push(quote! { ref mut #binding , });
653                                    Some(make_mutation(quote! { #binding }))
654                                } else {
655                                    patterns.push(quote! { _ , });
656                                    None
657                                }
658                            })
659                            .collect::<Vec<_>>();
660                        field_mutation_arms.push(quote! {
661                            #ty_name::#variant_ident( #( #patterns )* ) => {
662                                #( #mutates )*
663                            }
664                        });
665                    }
666
667                    Fields::Unit => {
668                        field_mutation_arms.push(quote! {
669                            #ty_name::#variant_ident => {}
670                        });
671                    }
672                }
673            }
674
675            // Build variant-switching mutations for enums with multiple
676            // variants.
677            let variant_switching = if data.variants.len() > 1 {
678                // Build a match to determine the current variant index.
679                let index_arms: Vec<_> = data
680                    .variants
681                    .iter()
682                    .enumerate()
683                    .map(|(v_idx, v)| {
684                        let variant_ident = &v.ident;
685                        match &v.fields {
686                            Fields::Named(_) => {
687                                quote! { #ty_name::#variant_ident { .. } => #v_idx, }
688                            }
689                            Fields::Unnamed(_) => {
690                                quote! { #ty_name::#variant_ident(..) => #v_idx, }
691                            }
692                            Fields::Unit => quote! { #ty_name::#variant_ident => #v_idx, },
693                        }
694                    })
695                    .collect();
696
697                // For each variant, build an expression that constructs a new
698                // instance of that variant. For non-ignored fields, use the
699                // mutator's `generate` method. For ignored fields, use
700                // `Default::default()`.
701                let mut variant_mutations = vec![];
702                let mut mutator_field_offset = 0usize;
703                for (v_idx, v) in data.variants.iter().enumerate() {
704                    let variant_ident = &v.ident;
705
706                    let construction = match &v.fields {
707                        Fields::Named(fields) => {
708                            let field_exprs: Vec<_> = fields
709                                .named
710                                .iter()
711                                .map(|f| {
712                                    let ident = &f.ident;
713                                    if let Some(_behavior) = FieldBehavior::for_field(f).unwrap() {
714                                        let mutator_ident =
715                                            &mutator_ty.mutator_fields[mutator_field_offset].ident;
716                                        mutator_field_offset += 1;
717                                        quote! { #ident: self.#mutator_ident.generate(ctx)? }
718                                    } else {
719                                        quote! { #ident: Default::default() }
720                                    }
721                                })
722                                .collect();
723                            quote! {
724                                *value = #ty_name::#variant_ident { #( #field_exprs ),* };
725                            }
726                        }
727                        Fields::Unnamed(fields) => {
728                            let field_exprs: Vec<_> = fields
729                                .unnamed
730                                .iter()
731                                .map(|f| {
732                                    if let Some(_behavior) = FieldBehavior::for_field(f).unwrap() {
733                                        let mutator_ident =
734                                            &mutator_ty.mutator_fields[mutator_field_offset].ident;
735                                        mutator_field_offset += 1;
736                                        quote! { self.#mutator_ident.generate(ctx)? }
737                                    } else {
738                                        quote! { Default::default() }
739                                    }
740                                })
741                                .collect();
742                            quote! {
743                                *value = #ty_name::#variant_ident( #( #field_exprs ),* );
744                            }
745                        }
746                        Fields::Unit => {
747                            quote! {
748                                *value = #ty_name::#variant_ident;
749                            }
750                        }
751                    };
752
753                    variant_mutations.push((v_idx, construction));
754                }
755
756                let num_variants = data.variants.len();
757                let group_count = num_variants - 1;
758                let group_arms: Vec<_> = variant_mutations
759                    .iter()
760                    .map(|(v_idx, construction)| {
761                        quote! {
762                            #v_idx => {
763                                #construction
764                                Ok(())
765                            }
766                        }
767                    })
768                    .collect();
769
770                quote! {
771                    let _variant_index: usize = match value {
772                        #( #index_arms )*
773                    };
774                    mutations.mutation_group(#group_count as u32, |ctx, _which| {
775                        let _target = if (_which as usize) >= _variant_index {
776                            _which as usize + 1
777                        } else {
778                            _which as usize
779                        };
780                        match _target {
781                            #( #group_arms )*
782                            _ => unreachable!(),
783                        }
784                    })?;
785                }
786            } else {
787                quote! {}
788            };
789
790            quote! {
791                #variant_switching
792                match *value {
793                    #( #field_mutation_arms )*
794                }
795            }
796        }
797
798        Data::Union(_) => {
799            return Err(Error::new_spanned(
800                input,
801                "cannot `derive(Mutate)` on a union",
802            ))
803        }
804    };
805
806    let mutate_method = quote! {
807        fn mutate(
808            &mut self,
809            mutations: &mut mutatis::Candidates,
810            value: &mut #ty_name,
811        ) -> mutatis::Result<()> {
812            #mutation_body
813
814            // Silence unused-variable warnings if every field was marked
815            // `ignore`. Allow unreachable for empty enums.
816            #[allow(unreachable_code)]
817            let _ = (mutations, value);
818
819            Ok(())
820        }
821    };
822
823    let mutation_count_body = match &input.data {
824        Data::Struct(data) => {
825            let mut field_counts = vec![];
826            let mut mutator_field_idx = 0usize;
827            match &data.fields {
828                Fields::Named(fields) => {
829                    for f in fields.named.iter() {
830                        if FieldBehavior::for_field(f).unwrap().is_some() {
831                            let field_ident = &f.ident;
832                            let mutator_ident = &mutator_ty.mutator_fields[mutator_field_idx].ident;
833                            mutator_field_idx += 1;
834                            field_counts.push(quote! {
835                                _count += self.#mutator_ident.mutation_count(
836                                    &value.#field_ident,
837                                    shrink,
838                                )?;
839                            });
840                        }
841                    }
842                }
843                Fields::Unnamed(fields) => {
844                    for (i, f) in fields.unnamed.iter().enumerate() {
845                        if FieldBehavior::for_field(f).unwrap().is_some() {
846                            let index = Index {
847                                index: u32::try_from(i).unwrap(),
848                                span: f.span(),
849                            };
850                            let mutator_ident = &mutator_ty.mutator_fields[mutator_field_idx].ident;
851                            mutator_field_idx += 1;
852                            field_counts.push(quote! {
853                                _count += self.#mutator_ident.mutation_count(
854                                    &value.#index,
855                                    shrink,
856                                )?;
857                            });
858                        }
859                    }
860                }
861                Fields::Unit => {}
862            }
863            quote! {
864                let mut _count = 0u32;
865                #(#field_counts)*
866                Some(_count)
867            }
868        }
869
870        Data::Enum(data) => {
871            let variant_switch_count = if data.variants.len() > 1 {
872                data.variants.len() - 1
873            } else {
874                0
875            };
876
877            let mut count_arms = vec![];
878            let mut mutator_field_idx = 0usize;
879            for v in data.variants.iter() {
880                let variant_ident = &v.ident;
881                match &v.fields {
882                    Fields::Named(fields) => {
883                        let mut patterns = vec![];
884                        let mut fld_counts = vec![];
885                        for f in fields.named.iter() {
886                            let ident = &f.ident;
887                            if FieldBehavior::for_field(f).unwrap().is_some() {
888                                patterns.push(quote! { ref #ident, });
889                                let mutator_ident =
890                                    &mutator_ty.mutator_fields[mutator_field_idx].ident;
891                                mutator_field_idx += 1;
892                                fld_counts.push(quote! {
893                                    _count += self.#mutator_ident.mutation_count(
894                                        #ident,
895                                        shrink,
896                                    )?;
897                                });
898                            } else {
899                                patterns.push(quote! { #ident: _, });
900                            }
901                        }
902                        count_arms.push(quote! {
903                            #ty_name::#variant_ident { #(#patterns)* } => {
904                                #(#fld_counts)*
905                            }
906                        });
907                    }
908                    Fields::Unnamed(fields) => {
909                        let mut patterns = vec![];
910                        let mut fld_counts = vec![];
911                        for (i, f) in fields.unnamed.iter().enumerate() {
912                            if FieldBehavior::for_field(f).unwrap().is_some() {
913                                let binding = Ident::new(&format!("field{}", i), f.span());
914                                patterns.push(quote! { ref #binding, });
915                                let mutator_ident =
916                                    &mutator_ty.mutator_fields[mutator_field_idx].ident;
917                                mutator_field_idx += 1;
918                                fld_counts.push(quote! {
919                                    _count += self.#mutator_ident.mutation_count(
920                                        #binding,
921                                        shrink,
922                                    )?;
923                                });
924                            } else {
925                                patterns.push(quote! { _, });
926                            }
927                        }
928                        count_arms.push(quote! {
929                            #ty_name::#variant_ident(#(#patterns)*) => {
930                                #(#fld_counts)*
931                            }
932                        });
933                    }
934                    Fields::Unit => {
935                        count_arms.push(quote! {
936                            #ty_name::#variant_ident => {}
937                        });
938                    }
939                }
940            }
941
942            if count_arms.is_empty() {
943                quote! { Some(0u32) }
944            } else {
945                quote! {
946                    let mut _count = #variant_switch_count as u32;
947                    match *value {
948                        #(#count_arms)*
949                    }
950                    Some(_count)
951                }
952            }
953        }
954
955        Data::Union(_) => quote! { None },
956    };
957
958    let mutation_count_method = quote! {
959        #[inline]
960        fn mutation_count(&self, value: &#ty_name, shrink: bool) -> core::option::Option<u32> {
961            #mutation_count_body
962        }
963    };
964
965    let mutator_name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
966
967    Ok(quote! {
968        #[automatically_derived]
969        impl #impl_generics mutatis::Mutate<#ty_name> for #mutator_name
970            #where_clause
971        {
972            #mutate_method
973            #mutation_count_method
974        }
975    })
976}
977
978fn gen_default_mutator_impl(
979    mutator_ty: &MutatorType,
980    container_attrs: &ContainerAttributes,
981) -> Result<TokenStream> {
982    let impl_default = container_attrs.default_mutate.unwrap_or(true);
983    if !impl_default {
984        return Ok(quote! {});
985    }
986
987    let ty_generics = if mutator_ty.ty_impl_generics.is_empty() {
988        quote! {}
989    } else {
990        let gens = &mutator_ty.ty_impl_generics;
991        quote! { < #( #gens ),* > }
992    };
993
994    let ty_name = mutator_ty.ty_name_with_generics();
995    let where_clause = mutator_ty.where_clause(WhereClauseKind::DefaultMutateBounds);
996    let mutator_name =
997        &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::JustTyGenerics);
998
999    Ok(quote! {
1000        #[automatically_derived]
1001        impl #ty_generics mutatis::DefaultMutate for #ty_name
1002            #where_clause
1003        {
1004            type DefaultMutate = #mutator_name;
1005        }
1006    })
1007}
1008
1009fn gen_generate_impl(
1010    input: &DeriveInput,
1011    mutator_ty: &MutatorType,
1012    container_attrs: &ContainerAttributes,
1013) -> Result<TokenStream> {
1014    let impl_generate = container_attrs.generate.unwrap_or(true);
1015    if !impl_generate {
1016        return Ok(quote! {});
1017    }
1018
1019    let impl_generics = mutator_ty.mutator_impl_generics();
1020    let ty_name = mutator_ty.ty_name_with_generics();
1021    let mutator_name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
1022    let where_clause = mutator_ty.where_clause(WhereClauseKind::MutateAndGenerateBounds);
1023
1024    let mut fields_iter = mutator_ty.mutator_fields.iter();
1025    let mut next_field_generate = || -> TokenStream {
1026        let mf = fields_iter.next().unwrap();
1027        let ident = &mf.ident;
1028        quote! { self.#ident.generate(cx)? }
1029    };
1030
1031    // For struct literals, we can't include generic args (e.g. `Foo<T> { .. }`
1032    // is invalid). Use the bare name and let Rust infer the type params.
1033    let bare_ty_name = &mutator_ty.ty_name;
1034
1035    let generate_body = match &input.data {
1036        Data::Struct(data) => match &data.fields {
1037            Fields::Named(fields) => {
1038                let field_exprs: Vec<_> = fields
1039                    .named
1040                    .iter()
1041                    .map(|f| {
1042                        let ident = &f.ident;
1043                        if FieldBehavior::for_field(f).unwrap().is_some() {
1044                            let expr = next_field_generate();
1045                            quote! { #ident: #expr }
1046                        } else {
1047                            quote! { #ident: Default::default() }
1048                        }
1049                    })
1050                    .collect();
1051                quote! { Ok(#bare_ty_name { #( #field_exprs ),* }) }
1052            }
1053            Fields::Unnamed(fields) => {
1054                let field_exprs: Vec<_> = fields
1055                    .unnamed
1056                    .iter()
1057                    .map(|f| {
1058                        if FieldBehavior::for_field(f).unwrap().is_some() {
1059                            next_field_generate()
1060                        } else {
1061                            quote! { Default::default() }
1062                        }
1063                    })
1064                    .collect();
1065                quote! { Ok(#bare_ty_name( #( #field_exprs ),* )) }
1066            }
1067            Fields::Unit => {
1068                quote! { Ok(#bare_ty_name) }
1069            }
1070        },
1071
1072        Data::Enum(data) => {
1073            if data.variants.is_empty() {
1074                quote! { unreachable!() }
1075            } else {
1076                let num_variants = data.variants.len();
1077                let mut mutator_field_offset = 0usize;
1078                let match_arms: Vec<_> = data
1079                    .variants
1080                    .iter()
1081                    .enumerate()
1082                    .map(|(v_idx, v)| {
1083                        let variant_ident = &v.ident;
1084                        let construction = match &v.fields {
1085                            Fields::Named(fields) => {
1086                                let field_exprs: Vec<_> = fields
1087                                    .named
1088                                    .iter()
1089                                    .map(|f| {
1090                                        let ident = &f.ident;
1091                                        if FieldBehavior::for_field(f).unwrap().is_some() {
1092                                            let mf =
1093                                                &mutator_ty.mutator_fields[mutator_field_offset];
1094                                            let mutator_ident = &mf.ident;
1095                                            mutator_field_offset += 1;
1096                                            quote! { #ident: self.#mutator_ident.generate(cx)? }
1097                                        } else {
1098                                            quote! { #ident: Default::default() }
1099                                        }
1100                                    })
1101                                    .collect();
1102                                quote! { #ty_name::#variant_ident { #( #field_exprs ),* } }
1103                            }
1104                            Fields::Unnamed(fields) => {
1105                                let field_exprs: Vec<_> = fields
1106                                    .unnamed
1107                                    .iter()
1108                                    .map(|f| {
1109                                        if FieldBehavior::for_field(f).unwrap().is_some() {
1110                                            let mf =
1111                                                &mutator_ty.mutator_fields[mutator_field_offset];
1112                                            let mutator_ident = &mf.ident;
1113                                            mutator_field_offset += 1;
1114                                            quote! { self.#mutator_ident.generate(cx)? }
1115                                        } else {
1116                                            quote! { Default::default() }
1117                                        }
1118                                    })
1119                                    .collect();
1120                                quote! { #ty_name::#variant_ident( #( #field_exprs ),* ) }
1121                            }
1122                            Fields::Unit => {
1123                                quote! { #ty_name::#variant_ident }
1124                            }
1125                        };
1126                        quote! { Some(#v_idx) => Ok(#construction), }
1127                    })
1128                    .collect();
1129
1130                quote! {
1131                    match cx.rng().gen_index(#num_variants) {
1132                        #( #match_arms )*
1133                        _ => unreachable!(),
1134                    }
1135                }
1136            }
1137        }
1138
1139        Data::Union(_) => {
1140            return Err(Error::new_spanned(
1141                input,
1142                "cannot `derive(Mutate)` on a union",
1143            ))
1144        }
1145    };
1146
1147    Ok(quote! {
1148        #[automatically_derived]
1149        impl #impl_generics mutatis::Generate<#ty_name> for #mutator_name
1150            #where_clause
1151        {
1152            fn generate(&mut self, cx: &mut mutatis::Context) -> mutatis::Result<#ty_name> {
1153                #generate_body
1154            }
1155        }
1156    })
1157}