Skip to main content

mutatis_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{spanned::Spanned, *};
6
7mod container_attributes;
8mod field_attributes;
9use container_attributes::ContainerAttributes;
10use field_attributes::FieldBehavior;
11
/// Name of the helper attribute (`#[mutatis(...)]`) declared by
/// `#[proc_macro_derive(Mutate, attributes(mutatis))]` below.
// NOTE(review): not referenced in this part of the file; presumably consumed
// by the `container_attributes` / `field_attributes` parsers — confirm.
static MUTATIS_ATTRIBUTE_NAME: &str = "mutatis";
13
14#[proc_macro_derive(Mutate, attributes(mutatis))]
15pub fn derive_mutator(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
16    let input = syn::parse_macro_input!(tokens as DeriveInput);
17    expand_derive_mutator(input)
18        .unwrap_or_else(syn::Error::into_compile_error)
19        .into()
20}
21
22fn expand_derive_mutator(input: DeriveInput) -> Result<TokenStream> {
23    let container_attrs = ContainerAttributes::from_derive_input(&input)?;
24    let mutator_ty = MutatorType::new(&input, &container_attrs)?;
25
26    let mutator_type_def = gen_mutator_type_def(&input, &mutator_ty, &container_attrs)?;
27    let mutator_type_default_impl = gen_mutator_type_default_impl(&mutator_ty)?;
28    let mutator_ctor = gen_mutator_ctor(&mutator_ty)?;
29    let mutator_impl = gen_mutator_impl(&input, &mutator_ty)?;
30    let default_mutator_impl = gen_default_mutator_impl(&mutator_ty, &container_attrs)?;
31    let generate_impl = gen_generate_impl(&input, &mutator_ty, &container_attrs)?;
32
33    Ok(quote! {
34        #mutator_type_def
35        #mutator_type_default_impl
36        #mutator_ctor
37        #mutator_impl
38        #default_mutator_impl
39        #generate_impl
40    })
41}
42
43struct MutatorType {
44    ty_name: Ident,
45
46    mutator_name: Ident,
47
48    mutator_fields: Vec<MutatorField>,
49
50    /// A vec of quoted generic parameters, without any bounds but with `const`
51    /// defs, e.g. `'a`, `const N: usize`, or `T`.
52    ty_impl_generics: Vec<TokenStream>,
53
54    /// A vec of quoted generic parameters, without any bounds and without any
55    /// `const` defs, e.t. `'a`, `N`, or `T.
56    ty_name_generics: Vec<TokenStream>,
57
58    /// A vec of quoted bounds for the generics above, e.g. `A: Iterator<Item =
59    /// B>,`.
60    ty_generics_bounds: Vec<TokenStream>,
61}
62
63impl MutatorType {
64    fn new(input: &DeriveInput, container_attrs: &ContainerAttributes) -> Result<Self> {
65        let ty_name = input.ident.clone();
66
67        let mutator_name = container_attrs
68            .mutator_name
69            .clone()
70            .unwrap_or_else(|| Ident::new(&format!("{}Mutator", input.ident), input.ident.span()));
71
72        let mutator_fields = get_mutator_fields(&input)?;
73
74        let mut ty_impl_generics = vec![];
75        let mut ty_name_generics = vec![];
76        let mut ty_generics_bounds = vec![];
77
78        for gen in &input.generics.params {
79            match gen {
80                GenericParam::Lifetime(l) => {
81                    if !l.bounds.is_empty() {
82                        ty_generics_bounds.push(quote! { #l });
83                    }
84
85                    let l = &l.lifetime;
86                    ty_impl_generics.push(quote! { #l });
87                    ty_name_generics.push(quote! { #l });
88                }
89                GenericParam::Const(c) => {
90                    ty_impl_generics.push(quote! { #c });
91                    let c = &c.ident;
92                    ty_name_generics.push(quote! { #c });
93                }
94                GenericParam::Type(t) => {
95                    if !t.bounds.is_empty() {
96                        ty_generics_bounds.push(quote! { #t });
97                    }
98                    let t = &t.ident;
99                    ty_impl_generics.push(quote! { #t });
100                    ty_name_generics.push(quote! { #t });
101                }
102            }
103        }
104
105        if let Some(wc) = &input.generics.where_clause {
106            for bound in wc.predicates.iter() {
107                ty_generics_bounds.push(quote! { #bound });
108            }
109        }
110
111        Ok(Self {
112            ty_name,
113            mutator_name,
114            mutator_fields,
115            ty_impl_generics,
116            ty_name_generics,
117            ty_generics_bounds,
118        })
119    }
120
121    fn mutator_impl_generics_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
122        self.ty_impl_generics.iter().cloned().chain(
123            self.mutator_fields
124                .iter()
125                .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
126        )
127    }
128
129    /// All the `impl` generic parameters for this mutator, including those
130    /// inherited from the type that it is a mutator for.
131    fn mutator_impl_generics(&self) -> TokenStream {
132        let impl_generics = self
133            .ty_impl_generics
134            .iter()
135            .cloned()
136            .chain(
137                self.mutator_fields
138                    .iter()
139                    .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
140            )
141            .collect::<Vec<_>>();
142        if impl_generics.is_empty() {
143            quote! {}
144        } else {
145            quote! { < #( #impl_generics ),* > }
146        }
147    }
148
149    /// All the named (i.e. just the "N" and excluding "const", ":", and "usize"
150    /// in `const N: usize` generics) generic parameters for this mutator,
151    /// including those inherited from the type that it is a mutator for.
152    fn mutator_name_generics_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
153        self.ty_name_generics.iter().cloned().chain(
154            self.mutator_fields
155                .iter()
156                .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
157        )
158    }
159
160    fn mutator_impl_generics_with_defaults_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
161        self.ty_impl_generics
162            .iter()
163            .cloned()
164            .chain(self.mutator_fields.iter().filter_map(move |f| {
165                f.generic.as_ref().map(|g| {
166                    let for_ty = &f.for_ty;
167                    quote! { #g = <#for_ty as mutatis::DefaultMutate>::DefaultMutate }
168                })
169            }))
170    }
171
172    fn ty_name_with_generics(&self) -> TokenStream {
173        let ty_name = &self.ty_name;
174        if self.ty_name_generics.is_empty() {
175            quote! { #ty_name }
176        } else {
177            let ty_generics = self.ty_name_generics.iter();
178            quote! { #ty_name < #( #ty_generics ),* > }
179        }
180    }
181
182    fn mutator_name_with_generics(&self, kind: MutatorNameGenericsKind) -> TokenStream {
183        let mutator_name = &self.mutator_name;
184
185        let generics = match kind {
186            MutatorNameGenericsKind::Generics => {
187                self.mutator_name_generics_iter().collect::<Vec<_>>()
188            }
189            MutatorNameGenericsKind::Impl {
190                impl_default: false,
191            } => self.mutator_impl_generics_iter().collect::<Vec<_>>(),
192            MutatorNameGenericsKind::Impl { impl_default: true } => self
193                .mutator_impl_generics_with_defaults_iter()
194                .collect::<Vec<_>>(),
195            MutatorNameGenericsKind::JustTyGenerics => self.ty_name_generics.clone(),
196        };
197
198        if generics.is_empty() {
199            quote! { #mutator_name }
200        } else {
201            quote! { #mutator_name < #( #generics ),* > }
202        }
203    }
204
205    fn where_clause(&self, kind: WhereClauseKind) -> TokenStream {
206        let mut bounds = self.ty_generics_bounds.clone();
207
208        match kind {
209            WhereClauseKind::NoMutateBounds => {}
210            WhereClauseKind::MutateBounds => {
211                for f in &self.mutator_fields {
212                    let for_ty = &f.for_ty;
213                    if let Some(g) = f.generic.as_ref() {
214                        bounds.push(quote! { #g: mutatis::Mutate<#for_ty> });
215                    } else {
216                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
217                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
218                    }
219                }
220            }
221            WhereClauseKind::MutateAndGenerateBounds => {
222                for f in &self.mutator_fields {
223                    let for_ty = &f.for_ty;
224                    if let Some(g) = f.generic.as_ref() {
225                        bounds.push(
226                            quote! { #g: mutatis::Mutate<#for_ty> + mutatis::Generate<#for_ty> },
227                        );
228                    } else {
229                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
230                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
231                        bounds.push(quote! { <#for_ty as mutatis::DefaultMutate>::DefaultMutate: mutatis::Generate<#for_ty> });
232                    }
233                }
234            }
235            WhereClauseKind::DefaultBounds => {
236                for f in &self.mutator_fields {
237                    if let Some(g) = f.generic.as_ref() {
238                        bounds.push(quote! { #g: Default });
239                    } else {
240                        let for_ty = &f.for_ty;
241                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
242                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
243                    }
244                }
245            }
246            WhereClauseKind::DefaultMutateBounds => {
247                for f in &self.mutator_fields {
248                    let for_ty = &f.for_ty;
249                    bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
250                }
251            }
252        }
253
254        if bounds.is_empty() {
255            quote! {}
256        } else {
257            quote! { where #( #bounds ),* }
258        }
259    }
260
261    fn phantom_fields_defs<'a>(
262        &self,
263        input: &'a DeriveInput,
264    ) -> impl Iterator<Item = TokenStream> + 'a {
265        let make_phantom_field = |i, ty| {
266            let ident = Ident::new(&format!("_phantom{i}"), Span::call_site());
267            quote! { #ident : core::marker::PhantomData<#ty> , }
268        };
269
270        input
271            .generics
272            .params
273            .iter()
274            .enumerate()
275            .map(move |(i, g)| match g {
276                GenericParam::Lifetime(l) => {
277                    let l = &l.lifetime;
278                    make_phantom_field(i, quote! { & #l () })
279                }
280                GenericParam::Const(c) => {
281                    let c = &c.ident;
282                    make_phantom_field(i, quote! { [(); #c] })
283                }
284                GenericParam::Type(t) => {
285                    let t = &t.ident;
286                    make_phantom_field(i, quote! { #t })
287                }
288            })
289    }
290
291    fn phantom_fields_literals(&self) -> impl Iterator<Item = TokenStream> + '_ {
292        (0..self.ty_name_generics.len()).map(|i| {
293            let ident = Ident::new(&format!("_phantom{i}"), Span::call_site());
294            quote! { #ident : core::marker::PhantomData, }
295        })
296    }
297}
298
/// Selects which per-field bounds `MutatorType::where_clause` adds on top of
/// the bounds inherited from the deriving type's generics.
#[derive(Copy, Clone)]
enum WhereClauseKind {
    /// Inherited bounds only.
    NoMutateBounds,
    /// Every field mutator must be able to `Mutate` its field.
    MutateBounds,
    /// As `MutateBounds`, and every field mutator must also `Generate` its field.
    MutateAndGenerateBounds,
    /// Generic field mutators must be `Default`; other fields need `DefaultMutate`.
    DefaultBounds,
    /// Every field type must implement `DefaultMutate`.
    DefaultMutateBounds,
}
307
/// Selects which generic parameter list `MutatorType::mutator_name_with_generics`
/// appends to the mutator's name.
#[derive(Copy, Clone)]
enum MutatorNameGenericsKind {
    /// Bare parameter names (`'a`, `T`, `N`, `MutatorT0`, ...), for use after
    /// an `impl<...>` header.
    Generics,
    /// Declarations (`const N: usize`, ...) for the struct definition; with
    /// `impl_default`, field mutator parameters get `DefaultMutate` defaults.
    Impl { impl_default: bool },
    /// Only the deriving type's own parameter names, relying on the defaults
    /// for the field mutator parameters.
    JustTyGenerics,
}
314
315struct MutatorField {
316    /// The identifier for this field inside the mutator struct.
317    ident: Ident,
318    /// The generic type parameter for this field, if any.
319    generic: Option<Ident>,
320    /// The behavior for this field.
321    behavior: FieldBehavior,
322    /// The type that this field is a mutator for.
323    for_ty: Type,
324}
325
326fn get_mutator_fields(input: &DeriveInput) -> Result<Vec<MutatorField>> {
327    let mut i = 0;
328    let mut generic = |b: &FieldBehavior| -> Option<Ident> {
329        if b.needs_generic() {
330            let g = Ident::new(&format!("MutatorT{}", i), Span::call_site());
331            i += 1;
332            Some(g)
333        } else {
334            None
335        }
336    };
337
338    match &input.data {
339        Data::Struct(data) => match &data.fields {
340            Fields::Named(fields) => fields
341                .named
342                .iter()
343                .filter_map(|f| {
344                    FieldBehavior::for_field(f)
345                        .map(|b| {
346                            b.map(|b| MutatorField {
347                                ident: f.ident.clone().unwrap(),
348                                generic: generic(&b),
349                                behavior: b,
350                                for_ty: f.ty.clone(),
351                            })
352                        })
353                        .transpose()
354                })
355                .collect(),
356            Fields::Unnamed(fields) => fields
357                .unnamed
358                .iter()
359                .enumerate()
360                .filter_map(|(i, f)| {
361                    FieldBehavior::for_field(f)
362                        .map(|b| {
363                            b.map(|b| MutatorField {
364                                ident: Ident::new(&format!("field{}", i), f.span()),
365                                generic: generic(&b),
366                                behavior: b,
367                                for_ty: f.ty.clone(),
368                            })
369                        })
370                        .transpose()
371                })
372                .collect(),
373            Fields::Unit => Ok(vec![]),
374        },
375        Data::Enum(data) => Ok(data
376            .variants
377            .iter()
378            .map(|v| {
379                let prefix = v.ident.to_string().to_lowercase();
380                match v.fields {
381                    Fields::Named(ref fields) => fields
382                        .named
383                        .iter()
384                        .filter_map(|f| {
385                            FieldBehavior::for_field(f)
386                                .map(|b| {
387                                    b.map(|b| MutatorField {
388                                        ident: Ident::new(
389                                            &format!("{prefix}_{}", f.ident.clone().unwrap()),
390                                            f.span(),
391                                        ),
392                                        generic: generic(&b),
393                                        behavior: b,
394                                        for_ty: f.ty.clone(),
395                                    })
396                                })
397                                .transpose()
398                        })
399                        .collect::<Result<Vec<_>>>(),
400                    Fields::Unnamed(ref fields) => fields
401                        .unnamed
402                        .iter()
403                        .enumerate()
404                        .filter_map(|(i, f)| {
405                            FieldBehavior::for_field(f)
406                                .map(|b| {
407                                    b.map(|b| MutatorField {
408                                        ident: Ident::new(&format!("{prefix}{i}"), f.span()),
409                                        generic: generic(&b),
410                                        behavior: b,
411                                        for_ty: f.ty.clone(),
412                                    })
413                                })
414                                .transpose()
415                        })
416                        .collect::<Result<Vec<_>>>(),
417                    Fields::Unit => Ok(vec![]),
418                }
419            })
420            .collect::<Result<Vec<_>>>()?
421            .into_iter()
422            .flat_map(|fs| fs)
423            .collect()),
424        Data::Union(_) => Err(Error::new_spanned(
425            input,
426            "cannot `derive(Mutate)` on a union",
427        )),
428    }
429}
430
431fn gen_mutator_type_def(
432    input: &DeriveInput,
433    mutator_ty: &MutatorType,
434    container_attrs: &ContainerAttributes,
435) -> Result<TokenStream> {
436    let vis = &input.vis;
437    let name = &input.ident;
438
439    let impl_default = container_attrs.default_mutate.unwrap_or(true);
440    let mutator_name =
441        mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Impl { impl_default });
442
443    let mut temp: Option<LitStr> = None;
444    let doc = container_attrs.mutator_doc.as_deref().unwrap_or_else(|| {
445        temp = Some(LitStr::new(
446            &format!(" A mutator for the `{name}` type."),
447            input.ident.span(),
448        ));
449        std::slice::from_ref(temp.as_ref().unwrap())
450    });
451
452    let where_clause = mutator_ty.where_clause(WhereClauseKind::NoMutateBounds);
453
454    let fields = mutator_ty
455        .mutator_fields
456        .iter()
457        .map(|f| {
458            let ident = &f.ident;
459            if let Some(g) = f.generic.as_ref() {
460                quote! { #ident: #g , }
461            } else {
462                let for_ty = &f.for_ty;
463                debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
464                quote! { #ident: <#for_ty as mutatis::DefaultMutate>::DefaultMutate, }
465            }
466        })
467        .collect::<Vec<_>>();
468
469    let phantoms = mutator_ty.phantom_fields_defs(input);
470
471    Ok(quote! {
472        #( #[doc = #doc] )*
473        // #[derive(Clone, Debug)]
474        #vis struct #mutator_name #where_clause {
475            #( #fields )*
476            #( #phantoms )*
477            _private: (),
478        }
479    })
480}
481
482fn gen_mutator_type_default_impl(mutator_ty: &MutatorType) -> Result<TokenStream> {
483    let impl_generics = mutator_ty.mutator_impl_generics();
484    let mutator_name = mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
485    let where_clause = mutator_ty.where_clause(WhereClauseKind::DefaultBounds);
486
487    let fields = mutator_ty
488        .mutator_fields
489        .iter()
490        .map(|f| {
491            let ident = &f.ident;
492            quote! { #ident: Default::default(), }
493        })
494        .collect::<Vec<_>>();
495
496    let phantoms = mutator_ty.phantom_fields_literals();
497
498    Ok(quote! {
499        #[automatically_derived]
500        impl #impl_generics Default for #mutator_name #where_clause {
501            fn default() -> Self {
502                Self {
503                    #( #fields )*
504                    #( #phantoms )*
505                    _private: (),
506                }
507            }
508        }
509    })
510}
511
512fn gen_mutator_ctor(mutator_ty: &MutatorType) -> Result<TokenStream> {
513    let impl_generics = mutator_ty.mutator_impl_generics();
514
515    let params = mutator_ty
516        .mutator_fields
517        .iter()
518        .filter_map(|f| {
519            f.generic.as_ref().map(|g| {
520                let ident = &f.ident;
521                quote! { #ident: #g , }
522            })
523        })
524        .collect::<Vec<_>>();
525
526    let fields = mutator_ty
527        .mutator_fields
528        .iter()
529        .map(|f| {
530            let ident = &f.ident;
531            if f.generic.is_some() {
532                quote! { #ident , }
533            } else {
534                let for_ty = &f.for_ty;
535                debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
536                quote! { #ident: mutatis::mutators::default::<#for_ty>() , }
537            }
538        })
539        .collect::<Vec<_>>();
540
541    let name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
542    let doc = format!("Construct a new `{name}` instance.");
543    let where_clause = mutator_ty.where_clause(WhereClauseKind::NoMutateBounds);
544    let phantoms = mutator_ty.phantom_fields_literals();
545
546    Ok(quote! {
547        impl #impl_generics #name #where_clause {
548            #[doc = #doc]
549            #[inline]
550            pub fn new( #( #params )* ) -> Self {
551                Self {
552                    #( #fields )*
553                    #( #phantoms )*
554                    _private: (),
555                }
556            }
557        }
558    })
559}
560
561fn gen_mutator_impl(input: &DeriveInput, mutator_ty: &MutatorType) -> Result<TokenStream> {
562    let impl_generics = mutator_ty.mutator_impl_generics();
563
564    let ty_name = mutator_ty.ty_name_with_generics();
565
566    let is_multi_variant_enum = matches!(&input.data, Data::Enum(data) if data.variants.len() > 1);
567    let where_clause = if is_multi_variant_enum {
568        mutator_ty.where_clause(WhereClauseKind::MutateAndGenerateBounds)
569    } else {
570        mutator_ty.where_clause(WhereClauseKind::MutateBounds)
571    };
572
573    let mut fields_iter = mutator_ty.mutator_fields.iter();
574    let mut make_mutation = |value| {
575        let ident = &fields_iter.next().unwrap().ident;
576        quote! { self.#ident.mutate(mutations, #value)?; }
577    };
578
579    let mutation_body = match &input.data {
580        Data::Struct(data) => match &data.fields {
581            Fields::Named(fields) => {
582                let mutations = fields
583                    .named
584                    .iter()
585                    .filter(|f| FieldBehavior::for_field(f).unwrap().is_some())
586                    .map(|f| {
587                        let ident = &f.ident;
588                        make_mutation(quote! { &mut value.#ident })
589                    });
590                quote! {
591                    #( #mutations )*
592                }
593            }
594            Fields::Unnamed(fields) => {
595                let mutations = fields
596                    .unnamed
597                    .iter()
598                    .enumerate()
599                    .filter(|(_i, f)| FieldBehavior::for_field(f).unwrap().is_some())
600                    .map(|(i, f)| {
601                        let index = Index {
602                            index: u32::try_from(i).unwrap(),
603                            span: f.span(),
604                        };
605                        make_mutation(quote! { &mut value.#index })
606                    });
607                quote! {
608                    #( #mutations )*
609                }
610            }
611            Fields::Unit => quote! {},
612        },
613
614        Data::Enum(data) => {
615            // Build the existing field-mutation match arms.
616            let mut field_mutation_arms = vec![];
617            for v in data.variants.iter() {
618                let variant_ident = &v.ident;
619                match &v.fields {
620                    Fields::Named(fields) => {
621                        let mut patterns = vec![];
622                        let mutates = fields
623                            .named
624                            .iter()
625                            .filter_map(|f| {
626                                let ident = &f.ident;
627                                if FieldBehavior::for_field(f).unwrap().is_some() {
628                                    patterns.push(quote! { ref mut #ident , });
629                                    Some(make_mutation(quote! { #ident }))
630                                } else {
631                                    patterns.push(quote! { #ident: _ , });
632                                    None
633                                }
634                            })
635                            .collect::<Vec<_>>();
636                        field_mutation_arms.push(quote! {
637                            #ty_name::#variant_ident { #( #patterns )* } => {
638                                #( #mutates )*
639                            }
640                        });
641                    }
642
643                    Fields::Unnamed(fields) => {
644                        let mut patterns = vec![];
645                        let mutates = fields
646                            .unnamed
647                            .iter()
648                            .enumerate()
649                            .filter_map(|(i, f)| {
650                                if FieldBehavior::for_field(f).unwrap().is_some() {
651                                    let binding = Ident::new(&format!("field{}", i), f.span());
652                                    patterns.push(quote! { ref mut #binding , });
653                                    Some(make_mutation(quote! { #binding }))
654                                } else {
655                                    patterns.push(quote! { _ , });
656                                    None
657                                }
658                            })
659                            .collect::<Vec<_>>();
660                        field_mutation_arms.push(quote! {
661                            #ty_name::#variant_ident( #( #patterns )* ) => {
662                                #( #mutates )*
663                            }
664                        });
665                    }
666
667                    Fields::Unit => {
668                        field_mutation_arms.push(quote! {
669                            #ty_name::#variant_ident => {}
670                        });
671                    }
672                }
673            }
674
675            // Build variant-switching mutations for enums with multiple
676            // variants.
677            let variant_switching = if data.variants.len() > 1 {
678                // Build a match to determine the current variant index.
679                let index_arms: Vec<_> = data
680                    .variants
681                    .iter()
682                    .enumerate()
683                    .map(|(v_idx, v)| {
684                        let variant_ident = &v.ident;
685                        match &v.fields {
686                            Fields::Named(_) => {
687                                quote! { #ty_name::#variant_ident { .. } => #v_idx, }
688                            }
689                            Fields::Unnamed(_) => {
690                                quote! { #ty_name::#variant_ident(..) => #v_idx, }
691                            }
692                            Fields::Unit => quote! { #ty_name::#variant_ident => #v_idx, },
693                        }
694                    })
695                    .collect();
696
697                // For each variant, build an expression that constructs a new
698                // instance of that variant. For non-ignored fields, use the
699                // mutator's `generate` method. For ignored fields, use
700                // `Default::default()`.
701                let mut variant_mutations = vec![];
702                let mut mutator_field_offset = 0usize;
703                for (v_idx, v) in data.variants.iter().enumerate() {
704                    let variant_ident = &v.ident;
705
706                    let construction = match &v.fields {
707                        Fields::Named(fields) => {
708                            let field_exprs: Vec<_> = fields
709                                .named
710                                .iter()
711                                .map(|f| {
712                                    let ident = &f.ident;
713                                    if let Some(_behavior) = FieldBehavior::for_field(f).unwrap() {
714                                        let mutator_ident =
715                                            &mutator_ty.mutator_fields[mutator_field_offset].ident;
716                                        mutator_field_offset += 1;
717                                        quote! { #ident: self.#mutator_ident.generate(ctx)? }
718                                    } else {
719                                        quote! { #ident: Default::default() }
720                                    }
721                                })
722                                .collect();
723                            quote! {
724                                *value = #ty_name::#variant_ident { #( #field_exprs ),* };
725                            }
726                        }
727                        Fields::Unnamed(fields) => {
728                            let field_exprs: Vec<_> = fields
729                                .unnamed
730                                .iter()
731                                .map(|f| {
732                                    if let Some(_behavior) = FieldBehavior::for_field(f).unwrap() {
733                                        let mutator_ident =
734                                            &mutator_ty.mutator_fields[mutator_field_offset].ident;
735                                        mutator_field_offset += 1;
736                                        quote! { self.#mutator_ident.generate(ctx)? }
737                                    } else {
738                                        quote! { Default::default() }
739                                    }
740                                })
741                                .collect();
742                            quote! {
743                                *value = #ty_name::#variant_ident( #( #field_exprs ),* );
744                            }
745                        }
746                        Fields::Unit => {
747                            quote! {
748                                *value = #ty_name::#variant_ident;
749                            }
750                        }
751                    };
752
753                    variant_mutations.push((v_idx, construction));
754                }
755
756                let num_variants = data.variants.len();
757                let switch_stmts: Vec<_> = variant_mutations
758                    .iter()
759                    .map(|(v_idx, construction)| {
760                        quote! {
761                            if _variant_index != #v_idx {
762                                mutations.mutation(|ctx| {
763                                    #construction
764                                    Ok(())
765                                })?;
766                            }
767                        }
768                    })
769                    .collect();
770
771                quote! {
772                    let _variant_index: usize = match value {
773                        #( #index_arms )*
774                    };
775                    let _ = #num_variants;
776                    #( #switch_stmts )*
777                }
778            } else {
779                quote! {}
780            };
781
782            quote! {
783                #variant_switching
784                match *value {
785                    #( #field_mutation_arms )*
786                }
787            }
788        }
789
790        Data::Union(_) => {
791            return Err(Error::new_spanned(
792                input,
793                "cannot `derive(Mutate)` on a union",
794            ))
795        }
796    };
797
798    let mutate_method = quote! {
799        fn mutate(
800            &mut self,
801            mutations: &mut mutatis::Candidates,
802            value: &mut #ty_name,
803        ) -> mutatis::Result<()> {
804            #mutation_body
805
806            // Silence unused-variable warnings if every field was marked
807            // `ignore`. Allow unreachable for empty enums.
808            #[allow(unreachable_code)]
809            let _ = (mutations, value);
810
811            Ok(())
812        }
813    };
814
815    let mutator_name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
816
817    Ok(quote! {
818        #[automatically_derived]
819        impl #impl_generics mutatis::Mutate<#ty_name> for #mutator_name
820            #where_clause
821        {
822            #mutate_method
823        }
824    })
825}
826
827fn gen_default_mutator_impl(
828    mutator_ty: &MutatorType,
829    container_attrs: &ContainerAttributes,
830) -> Result<TokenStream> {
831    let impl_default = container_attrs.default_mutate.unwrap_or(true);
832    if !impl_default {
833        return Ok(quote! {});
834    }
835
836    let ty_generics = if mutator_ty.ty_impl_generics.is_empty() {
837        quote! {}
838    } else {
839        let gens = &mutator_ty.ty_impl_generics;
840        quote! { < #( #gens ),* > }
841    };
842
843    let ty_name = mutator_ty.ty_name_with_generics();
844    let where_clause = mutator_ty.where_clause(WhereClauseKind::DefaultMutateBounds);
845    let mutator_name =
846        &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::JustTyGenerics);
847
848    Ok(quote! {
849        #[automatically_derived]
850        impl #ty_generics mutatis::DefaultMutate for #ty_name
851            #where_clause
852        {
853            type DefaultMutate = #mutator_name;
854        }
855    })
856}
857
858fn gen_generate_impl(
859    input: &DeriveInput,
860    mutator_ty: &MutatorType,
861    container_attrs: &ContainerAttributes,
862) -> Result<TokenStream> {
863    let impl_generate = container_attrs.generate.unwrap_or(true);
864    if !impl_generate {
865        return Ok(quote! {});
866    }
867
868    let impl_generics = mutator_ty.mutator_impl_generics();
869    let ty_name = mutator_ty.ty_name_with_generics();
870    let mutator_name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
871    let where_clause = mutator_ty.where_clause(WhereClauseKind::MutateAndGenerateBounds);
872
873    let mut fields_iter = mutator_ty.mutator_fields.iter();
874    let mut next_field_generate = || -> TokenStream {
875        let mf = fields_iter.next().unwrap();
876        let ident = &mf.ident;
877        quote! { self.#ident.generate(cx)? }
878    };
879
880    // For struct literals, we can't include generic args (e.g. `Foo<T> { .. }`
881    // is invalid). Use the bare name and let Rust infer the type params.
882    let bare_ty_name = &mutator_ty.ty_name;
883
884    let generate_body = match &input.data {
885        Data::Struct(data) => match &data.fields {
886            Fields::Named(fields) => {
887                let field_exprs: Vec<_> = fields
888                    .named
889                    .iter()
890                    .map(|f| {
891                        let ident = &f.ident;
892                        if FieldBehavior::for_field(f).unwrap().is_some() {
893                            let expr = next_field_generate();
894                            quote! { #ident: #expr }
895                        } else {
896                            quote! { #ident: Default::default() }
897                        }
898                    })
899                    .collect();
900                quote! { Ok(#bare_ty_name { #( #field_exprs ),* }) }
901            }
902            Fields::Unnamed(fields) => {
903                let field_exprs: Vec<_> = fields
904                    .unnamed
905                    .iter()
906                    .map(|f| {
907                        if FieldBehavior::for_field(f).unwrap().is_some() {
908                            next_field_generate()
909                        } else {
910                            quote! { Default::default() }
911                        }
912                    })
913                    .collect();
914                quote! { Ok(#bare_ty_name( #( #field_exprs ),* )) }
915            }
916            Fields::Unit => {
917                quote! { Ok(#bare_ty_name) }
918            }
919        },
920
921        Data::Enum(data) => {
922            if data.variants.is_empty() {
923                quote! { unreachable!() }
924            } else {
925                let num_variants = data.variants.len();
926                let mut mutator_field_offset = 0usize;
927                let match_arms: Vec<_> = data
928                    .variants
929                    .iter()
930                    .enumerate()
931                    .map(|(v_idx, v)| {
932                        let variant_ident = &v.ident;
933                        let construction = match &v.fields {
934                            Fields::Named(fields) => {
935                                let field_exprs: Vec<_> = fields
936                                    .named
937                                    .iter()
938                                    .map(|f| {
939                                        let ident = &f.ident;
940                                        if FieldBehavior::for_field(f).unwrap().is_some() {
941                                            let mf =
942                                                &mutator_ty.mutator_fields[mutator_field_offset];
943                                            let mutator_ident = &mf.ident;
944                                            mutator_field_offset += 1;
945                                            quote! { #ident: self.#mutator_ident.generate(cx)? }
946                                        } else {
947                                            quote! { #ident: Default::default() }
948                                        }
949                                    })
950                                    .collect();
951                                quote! { #ty_name::#variant_ident { #( #field_exprs ),* } }
952                            }
953                            Fields::Unnamed(fields) => {
954                                let field_exprs: Vec<_> = fields
955                                    .unnamed
956                                    .iter()
957                                    .map(|f| {
958                                        if FieldBehavior::for_field(f).unwrap().is_some() {
959                                            let mf =
960                                                &mutator_ty.mutator_fields[mutator_field_offset];
961                                            let mutator_ident = &mf.ident;
962                                            mutator_field_offset += 1;
963                                            quote! { self.#mutator_ident.generate(cx)? }
964                                        } else {
965                                            quote! { Default::default() }
966                                        }
967                                    })
968                                    .collect();
969                                quote! { #ty_name::#variant_ident( #( #field_exprs ),* ) }
970                            }
971                            Fields::Unit => {
972                                quote! { #ty_name::#variant_ident }
973                            }
974                        };
975                        quote! { Some(#v_idx) => Ok(#construction), }
976                    })
977                    .collect();
978
979                quote! {
980                    match cx.rng().gen_index(#num_variants) {
981                        #( #match_arms )*
982                        _ => unreachable!(),
983                    }
984                }
985            }
986        }
987
988        Data::Union(_) => {
989            return Err(Error::new_spanned(
990                input,
991                "cannot `derive(Mutate)` on a union",
992            ))
993        }
994    };
995
996    Ok(quote! {
997        #[automatically_derived]
998        impl #impl_generics mutatis::Generate<#ty_name> for #mutator_name
999            #where_clause
1000        {
1001            fn generate(&mut self, cx: &mut mutatis::Context) -> mutatis::Result<#ty_name> {
1002                #generate_body
1003            }
1004        }
1005    })
1006}