mutatis_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{spanned::Spanned, *};
6
7mod container_attributes;
8mod field_attributes;
9use container_attributes::ContainerAttributes;
10use field_attributes::FieldBehavior;
11
/// Name of this derive's helper attribute: the `mutatis` in `#[mutatis(...)]`,
/// which is registered through `attributes(mutatis)` on `derive_mutator`.
// NOTE(review): not referenced in this file; presumably consumed by the
// `container_attributes` / `field_attributes` modules — verify.
static MUTATIS_ATTRIBUTE_NAME: &str = "mutatis";
13
14#[proc_macro_derive(Mutate, attributes(mutatis))]
15pub fn derive_mutator(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
16    let input = syn::parse_macro_input!(tokens as DeriveInput);
17    expand_derive_mutator(input)
18        .unwrap_or_else(syn::Error::into_compile_error)
19        .into()
20}
21
22fn expand_derive_mutator(input: DeriveInput) -> Result<TokenStream> {
23    let container_attrs = ContainerAttributes::from_derive_input(&input)?;
24    let mutator_ty = MutatorType::new(&input, &container_attrs)?;
25
26    let mutator_type_def = gen_mutator_type_def(&input, &mutator_ty, &container_attrs)?;
27    let mutator_type_default_impl = gen_mutator_type_default_impl(&mutator_ty)?;
28    let mutator_ctor = gen_mutator_ctor(&mutator_ty)?;
29    let mutator_impl = gen_mutator_impl(&input, &mutator_ty)?;
30    let default_mutator_impl = gen_default_mutator_impl(&mutator_ty, &container_attrs)?;
31
32    Ok(quote! {
33        #mutator_type_def
34        #mutator_type_default_impl
35        #mutator_ctor
36        #mutator_impl
37        #default_mutator_impl
38    })
39}
40
41struct MutatorType {
42    ty_name: Ident,
43
44    mutator_name: Ident,
45
46    mutator_fields: Vec<MutatorField>,
47
48    /// A vec of quoted generic parameters, without any bounds but with `const`
49    /// defs, e.g. `'a`, `const N: usize`, or `T`.
50    ty_impl_generics: Vec<TokenStream>,
51
52    /// A vec of quoted generic parameters, without any bounds and without any
53    /// `const` defs, e.t. `'a`, `N`, or `T.
54    ty_name_generics: Vec<TokenStream>,
55
56    /// A vec of quoted bounds for the generics above, e.g. `A: Iterator<Item =
57    /// B>,`.
58    ty_generics_bounds: Vec<TokenStream>,
59}
60
61impl MutatorType {
62    fn new(input: &DeriveInput, container_attrs: &ContainerAttributes) -> Result<Self> {
63        let ty_name = input.ident.clone();
64
65        let mutator_name = container_attrs
66            .mutator_name
67            .clone()
68            .unwrap_or_else(|| Ident::new(&format!("{}Mutator", input.ident), input.ident.span()));
69
70        let mutator_fields = get_mutator_fields(&input)?;
71
72        let mut ty_impl_generics = vec![];
73        let mut ty_name_generics = vec![];
74        let mut ty_generics_bounds = vec![];
75
76        for gen in &input.generics.params {
77            match gen {
78                GenericParam::Lifetime(l) => {
79                    if !l.bounds.is_empty() {
80                        ty_generics_bounds.push(quote! { #l });
81                    }
82
83                    let l = &l.lifetime;
84                    ty_impl_generics.push(quote! { #l });
85                    ty_name_generics.push(quote! { #l });
86                }
87                GenericParam::Const(c) => {
88                    ty_impl_generics.push(quote! { #c });
89                    let c = &c.ident;
90                    ty_name_generics.push(quote! { #c });
91                }
92                GenericParam::Type(t) => {
93                    if !t.bounds.is_empty() {
94                        ty_generics_bounds.push(quote! { #t });
95                    }
96                    let t = &t.ident;
97                    ty_impl_generics.push(quote! { #t });
98                    ty_name_generics.push(quote! { #t });
99                }
100            }
101        }
102
103        if let Some(wc) = &input.generics.where_clause {
104            for bound in wc.predicates.iter() {
105                ty_generics_bounds.push(quote! { #bound });
106            }
107        }
108
109        Ok(Self {
110            ty_name,
111            mutator_name,
112            mutator_fields,
113            ty_impl_generics,
114            ty_name_generics,
115            ty_generics_bounds,
116        })
117    }
118
119    fn mutator_impl_generics_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
120        self.ty_impl_generics.iter().cloned().chain(
121            self.mutator_fields
122                .iter()
123                .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
124        )
125    }
126
127    /// All the `impl` generic parameters for this mutator, including those
128    /// inherited from the type that it is a mutator for.
129    fn mutator_impl_generics(&self) -> TokenStream {
130        let impl_generics = self
131            .ty_impl_generics
132            .iter()
133            .cloned()
134            .chain(
135                self.mutator_fields
136                    .iter()
137                    .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
138            )
139            .collect::<Vec<_>>();
140        if impl_generics.is_empty() {
141            quote! {}
142        } else {
143            quote! { < #( #impl_generics ),* > }
144        }
145    }
146
147    /// All the named (i.e. just the "N" and excluding "const", ":", and "usize"
148    /// in `const N: usize` generics) generic parameters for this mutator,
149    /// including those inherited from the type that it is a mutator for.
150    fn mutator_name_generics_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
151        self.ty_name_generics.iter().cloned().chain(
152            self.mutator_fields
153                .iter()
154                .filter_map(|f| f.generic.as_ref().map(|g| quote! { #g })),
155        )
156    }
157
158    fn mutator_impl_generics_with_defaults_iter(&self) -> impl Iterator<Item = TokenStream> + '_ {
159        self.ty_impl_generics
160            .iter()
161            .cloned()
162            .chain(self.mutator_fields.iter().filter_map(move |f| {
163                f.generic.as_ref().map(|g| {
164                    let for_ty = &f.for_ty;
165                    quote! { #g = <#for_ty as mutatis::DefaultMutate>::DefaultMutate }
166                })
167            }))
168    }
169
170    fn ty_name_with_generics(&self) -> TokenStream {
171        let ty_name = &self.ty_name;
172        if self.ty_name_generics.is_empty() {
173            quote! { #ty_name }
174        } else {
175            let ty_generics = self.ty_name_generics.iter();
176            quote! { #ty_name < #( #ty_generics ),* > }
177        }
178    }
179
180    fn mutator_name_with_generics(&self, kind: MutatorNameGenericsKind) -> TokenStream {
181        let mutator_name = &self.mutator_name;
182
183        let generics = match kind {
184            MutatorNameGenericsKind::Generics => {
185                self.mutator_name_generics_iter().collect::<Vec<_>>()
186            }
187            MutatorNameGenericsKind::Impl {
188                impl_default: false,
189            } => self.mutator_impl_generics_iter().collect::<Vec<_>>(),
190            MutatorNameGenericsKind::Impl { impl_default: true } => self
191                .mutator_impl_generics_with_defaults_iter()
192                .collect::<Vec<_>>(),
193            MutatorNameGenericsKind::JustTyGenerics => self.ty_name_generics.clone(),
194        };
195
196        if generics.is_empty() {
197            quote! { #mutator_name }
198        } else {
199            quote! { #mutator_name < #( #generics ),* > }
200        }
201    }
202
203    fn where_clause(&self, kind: WhereClauseKind) -> TokenStream {
204        let mut bounds = self.ty_generics_bounds.clone();
205
206        match kind {
207            WhereClauseKind::NoMutateBounds => {}
208            WhereClauseKind::MutateBounds => {
209                for f in &self.mutator_fields {
210                    let for_ty = &f.for_ty;
211                    if let Some(g) = f.generic.as_ref() {
212                        bounds.push(quote! { #g: mutatis::Mutate<#for_ty> });
213                    } else {
214                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
215                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
216                    }
217                }
218            }
219            WhereClauseKind::DefaultBounds => {
220                for f in &self.mutator_fields {
221                    if let Some(g) = f.generic.as_ref() {
222                        bounds.push(quote! { #g: Default });
223                    } else {
224                        let for_ty = &f.for_ty;
225                        debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
226                        bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
227                    }
228                }
229            }
230            WhereClauseKind::DefaultMutateBounds => {
231                for f in &self.mutator_fields {
232                    let for_ty = &f.for_ty;
233                    bounds.push(quote! { #for_ty: mutatis::DefaultMutate });
234                }
235            }
236        }
237
238        if bounds.is_empty() {
239            quote! {}
240        } else {
241            quote! { where #( #bounds ),* }
242        }
243    }
244
245    fn phantom_fields_defs<'a>(
246        &self,
247        input: &'a DeriveInput,
248    ) -> impl Iterator<Item = TokenStream> + 'a {
249        let make_phantom_field = |i, ty| {
250            let ident = Ident::new(&format!("_phantom{i}"), Span::call_site());
251            quote! { #ident : core::marker::PhantomData<#ty> , }
252        };
253
254        input
255            .generics
256            .params
257            .iter()
258            .enumerate()
259            .map(move |(i, g)| match g {
260                GenericParam::Lifetime(l) => {
261                    let l = &l.lifetime;
262                    make_phantom_field(i, quote! { & #l () })
263                }
264                GenericParam::Const(c) => {
265                    let c = &c.ident;
266                    make_phantom_field(i, quote! { [(); #c] })
267                }
268                GenericParam::Type(t) => {
269                    let t = &t.ident;
270                    make_phantom_field(i, quote! { #t })
271                }
272            })
273    }
274
275    fn phantom_fields_literals(&self) -> impl Iterator<Item = TokenStream> + '_ {
276        (0..self.ty_name_generics.len()).map(|i| {
277            let ident = Ident::new(&format!("_phantom{i}"), Span::call_site());
278            quote! { #ident : core::marker::PhantomData, }
279        })
280    }
281}
282
/// Selects which per-field bounds `MutatorType::where_clause` adds on top of
/// the input type's own generic bounds.
#[derive(Clone, Copy)]
enum WhereClauseKind {
    /// No per-field bounds; only the input type's own bounds.
    NoMutateBounds,
    /// `MutatorTN: mutatis::Mutate<FieldTy>` for generic fields, and
    /// `FieldTy: mutatis::DefaultMutate` for the rest.
    MutateBounds,
    /// `MutatorTN: Default` for generic fields, and
    /// `FieldTy: mutatis::DefaultMutate` for the rest.
    DefaultBounds,
    /// `FieldTy: mutatis::DefaultMutate` for every field.
    DefaultMutateBounds,
}
290
/// Selects which generic argument list `MutatorType::mutator_name_with_generics`
/// attaches to the mutator's name.
#[derive(Clone, Copy)]
enum MutatorNameGenericsKind {
    /// Bare generic names: the input type's, then the field generics.
    Generics,
    /// Impl-form generics (`const N: usize`). With `impl_default`, each field
    /// generic also gets a `= <FieldTy as DefaultMutate>::DefaultMutate` default.
    Impl { impl_default: bool },
    /// Only the input type's bare generic names.
    JustTyGenerics,
}
297
298struct MutatorField {
299    /// The identifier for this field inside the mutator struct.
300    ident: Ident,
301    /// The generic type parameter for this field, if any.
302    generic: Option<Ident>,
303    /// The behavior for this field.
304    behavior: FieldBehavior,
305    /// The type that this field is a mutator for.
306    for_ty: Type,
307}
308
309fn get_mutator_fields(input: &DeriveInput) -> Result<Vec<MutatorField>> {
310    let mut i = 0;
311    let mut generic = |b: &FieldBehavior| -> Option<Ident> {
312        if b.needs_generic() {
313            let g = Ident::new(&format!("MutatorT{}", i), Span::call_site());
314            i += 1;
315            Some(g)
316        } else {
317            None
318        }
319    };
320
321    match &input.data {
322        Data::Struct(data) => match &data.fields {
323            Fields::Named(fields) => fields
324                .named
325                .iter()
326                .filter_map(|f| {
327                    FieldBehavior::for_field(f)
328                        .map(|b| {
329                            b.map(|b| MutatorField {
330                                ident: f.ident.clone().unwrap(),
331                                generic: generic(&b),
332                                behavior: b,
333                                for_ty: f.ty.clone(),
334                            })
335                        })
336                        .transpose()
337                })
338                .collect(),
339            Fields::Unnamed(fields) => fields
340                .unnamed
341                .iter()
342                .enumerate()
343                .filter_map(|(i, f)| {
344                    FieldBehavior::for_field(f)
345                        .map(|b| {
346                            b.map(|b| MutatorField {
347                                ident: Ident::new(&format!("field{}", i), f.span()),
348                                generic: generic(&b),
349                                behavior: b,
350                                for_ty: f.ty.clone(),
351                            })
352                        })
353                        .transpose()
354                })
355                .collect(),
356            Fields::Unit => Ok(vec![]),
357        },
358        Data::Enum(data) => Ok(data
359            .variants
360            .iter()
361            .map(|v| {
362                let prefix = v.ident.to_string().to_lowercase();
363                match v.fields {
364                    Fields::Named(ref fields) => fields
365                        .named
366                        .iter()
367                        .filter_map(|f| {
368                            FieldBehavior::for_field(f)
369                                .map(|b| {
370                                    b.map(|b| MutatorField {
371                                        ident: Ident::new(
372                                            &format!("{prefix}_{}", f.ident.clone().unwrap()),
373                                            f.span(),
374                                        ),
375                                        generic: generic(&b),
376                                        behavior: b,
377                                        for_ty: f.ty.clone(),
378                                    })
379                                })
380                                .transpose()
381                        })
382                        .collect::<Result<Vec<_>>>(),
383                    Fields::Unnamed(ref fields) => fields
384                        .unnamed
385                        .iter()
386                        .enumerate()
387                        .filter_map(|(i, f)| {
388                            FieldBehavior::for_field(f)
389                                .map(|b| {
390                                    b.map(|b| MutatorField {
391                                        ident: Ident::new(&format!("{prefix}{i}"), f.span()),
392                                        generic: generic(&b),
393                                        behavior: b,
394                                        for_ty: f.ty.clone(),
395                                    })
396                                })
397                                .transpose()
398                        })
399                        .collect::<Result<Vec<_>>>(),
400                    Fields::Unit => Ok(vec![]),
401                }
402            })
403            .collect::<Result<Vec<_>>>()?
404            .into_iter()
405            .flat_map(|fs| fs)
406            .collect()),
407        Data::Union(_) => Err(Error::new_spanned(
408            input,
409            "cannot `derive(Mutate)` on a union",
410        )),
411    }
412}
413
414fn gen_mutator_type_def(
415    input: &DeriveInput,
416    mutator_ty: &MutatorType,
417    container_attrs: &ContainerAttributes,
418) -> Result<TokenStream> {
419    let vis = &input.vis;
420    let name = &input.ident;
421
422    let impl_default = container_attrs.default_mutate.unwrap_or(true);
423    let mutator_name =
424        mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Impl { impl_default });
425
426    let mut temp: Option<LitStr> = None;
427    let doc = container_attrs.mutator_doc.as_deref().unwrap_or_else(|| {
428        temp = Some(LitStr::new(
429            &format!(" A mutator for the `{name}` type."),
430            input.ident.span(),
431        ));
432        std::slice::from_ref(temp.as_ref().unwrap())
433    });
434
435    let where_clause = mutator_ty.where_clause(WhereClauseKind::NoMutateBounds);
436
437    let fields = mutator_ty
438        .mutator_fields
439        .iter()
440        .map(|f| {
441            let ident = &f.ident;
442            if let Some(g) = f.generic.as_ref() {
443                quote! { #ident: #g , }
444            } else {
445                let for_ty = &f.for_ty;
446                debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
447                quote! { #ident: <#for_ty as mutatis::DefaultMutate>::DefaultMutate, }
448            }
449        })
450        .collect::<Vec<_>>();
451
452    let phantoms = mutator_ty.phantom_fields_defs(input);
453
454    Ok(quote! {
455        #( #[doc = #doc] )*
456        // #[derive(Clone, Debug)]
457        #vis struct #mutator_name #where_clause {
458            #( #fields )*
459            #( #phantoms )*
460            _private: (),
461        }
462    })
463}
464
465fn gen_mutator_type_default_impl(mutator_ty: &MutatorType) -> Result<TokenStream> {
466    let impl_generics = mutator_ty.mutator_impl_generics();
467    let mutator_name = mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
468    let where_clause = mutator_ty.where_clause(WhereClauseKind::DefaultBounds);
469
470    let fields = mutator_ty
471        .mutator_fields
472        .iter()
473        .map(|f| {
474            let ident = &f.ident;
475            quote! { #ident: Default::default(), }
476        })
477        .collect::<Vec<_>>();
478
479    let phantoms = mutator_ty.phantom_fields_literals();
480
481    Ok(quote! {
482        #[automatically_derived]
483        impl #impl_generics Default for #mutator_name #where_clause {
484            fn default() -> Self {
485                Self {
486                    #( #fields )*
487                    #( #phantoms )*
488                    _private: (),
489                }
490            }
491        }
492    })
493}
494
495fn gen_mutator_ctor(mutator_ty: &MutatorType) -> Result<TokenStream> {
496    let impl_generics = mutator_ty.mutator_impl_generics();
497
498    let params = mutator_ty
499        .mutator_fields
500        .iter()
501        .filter_map(|f| {
502            f.generic.as_ref().map(|g| {
503                let ident = &f.ident;
504                quote! { #ident: #g , }
505            })
506        })
507        .collect::<Vec<_>>();
508
509    let fields = mutator_ty
510        .mutator_fields
511        .iter()
512        .map(|f| {
513            let ident = &f.ident;
514            if f.generic.is_some() {
515                quote! { #ident , }
516            } else {
517                let for_ty = &f.for_ty;
518                debug_assert_eq!(f.behavior, FieldBehavior::DefaultMutate);
519                quote! { #ident: mutatis::mutators::default::<#for_ty>() , }
520            }
521        })
522        .collect::<Vec<_>>();
523
524    let name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
525    let doc = format!("Construct a new `{name}` instance.");
526    let where_clause = mutator_ty.where_clause(WhereClauseKind::NoMutateBounds);
527    let phantoms = mutator_ty.phantom_fields_literals();
528
529    Ok(quote! {
530        impl #impl_generics #name #where_clause {
531            #[doc = #doc]
532            #[inline]
533            pub fn new( #( #params )* ) -> Self {
534                Self {
535                    #( #fields )*
536                    #( #phantoms )*
537                    _private: (),
538                }
539            }
540        }
541    })
542}
543
544fn gen_mutator_impl(input: &DeriveInput, mutator_ty: &MutatorType) -> Result<TokenStream> {
545    // TODO: make a list of all the individual mutations we *could* make, and
546    // then choose only one of them to actually perform.
547
548    let impl_generics = mutator_ty.mutator_impl_generics();
549
550    let ty_name = mutator_ty.ty_name_with_generics();
551    let where_clause = mutator_ty.where_clause(WhereClauseKind::MutateBounds);
552
553    let mut fields_iter = mutator_ty.mutator_fields.iter();
554    let mut make_mutation = |value| {
555        let ident = &fields_iter.next().unwrap().ident;
556        quote! { self.#ident.mutate(mutations, #value)?; }
557    };
558
559    let mutation_body = match &input.data {
560        Data::Struct(data) => match &data.fields {
561            Fields::Named(fields) => {
562                let mutations = fields
563                    .named
564                    .iter()
565                    .filter(|f| FieldBehavior::for_field(f).unwrap().is_some())
566                    .map(|f| {
567                        let ident = &f.ident;
568                        make_mutation(quote! { &mut value.#ident })
569                    });
570                quote! {
571                    #( #mutations )*
572                }
573            }
574            Fields::Unnamed(fields) => {
575                let mutations = fields
576                    .unnamed
577                    .iter()
578                    .enumerate()
579                    .filter(|(_i, f)| FieldBehavior::for_field(f).unwrap().is_some())
580                    .map(|(i, f)| {
581                        let index = Index {
582                            index: u32::try_from(i).unwrap(),
583                            span: f.span(),
584                        };
585                        make_mutation(quote! { &mut value.#index })
586                    });
587                quote! {
588                    #( #mutations )*
589                }
590            }
591            Fields::Unit => quote! {},
592        },
593
594        Data::Enum(data) => {
595            // TODO: add support for changing from one enum variant to another.
596
597            let mut variants = vec![];
598            for v in data.variants.iter() {
599                let variant_ident = &v.ident;
600                match &v.fields {
601                    Fields::Named(fields) => {
602                        let mut patterns = vec![];
603                        let mutates = fields
604                            .named
605                            .iter()
606                            .filter_map(|f| {
607                                let ident = &f.ident;
608                                if FieldBehavior::for_field(f).unwrap().is_some() {
609                                    patterns.push(quote! { #ident , });
610                                    Some(make_mutation(quote! { #ident }))
611                                } else {
612                                    patterns.push(quote! { #ident: _ , });
613                                    None
614                                }
615                            })
616                            .collect::<Vec<_>>();
617                        variants.push(quote! {
618                            #ty_name::#variant_ident { #( #patterns )* } => {
619                                #( #mutates )*
620                            }
621                        });
622                    }
623
624                    Fields::Unnamed(fields) => {
625                        let mut patterns = vec![];
626                        let mutates = fields
627                            .unnamed
628                            .iter()
629                            .enumerate()
630                            .filter_map(|(i, f)| {
631                                if FieldBehavior::for_field(f).unwrap().is_some() {
632                                    let binding = Ident::new(&format!("field{}", i), f.span());
633                                    patterns.push(quote! { #binding , });
634                                    Some(make_mutation(quote! { #binding }))
635                                } else {
636                                    patterns.push(quote! { _ , });
637                                    None
638                                }
639                            })
640                            .collect::<Vec<_>>();
641                        variants.push(quote! {
642                            #ty_name::#variant_ident( #( #patterns )* ) => {
643                                #( #mutates )*
644                            }
645                        });
646                    }
647
648                    Fields::Unit => {
649                        variants.push(quote! {
650                            #ty_name::#variant_ident => {}
651                        });
652                    }
653                }
654            }
655
656            quote! {
657                match value {
658                    #( #variants )*
659                }
660            }
661        }
662
663        Data::Union(_) => {
664            return Err(Error::new_spanned(
665                input,
666                "cannot `derive(Mutate)` on a union",
667            ))
668        }
669    };
670
671    let mutate_method = quote! {
672        fn mutate(
673            &mut self,
674            mutations: &mut mutatis::Candidates,
675            value: &mut #ty_name,
676        ) -> mutatis::Result<()> {
677            #mutation_body
678
679            // Silence unused-variable warnings if every field was marked `ignore`.
680            let _ = (mutations, value);
681
682            Ok(())
683        }
684    };
685
686    let mutator_name = &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::Generics);
687
688    Ok(quote! {
689        #[automatically_derived]
690        impl #impl_generics mutatis::Mutate<#ty_name> for #mutator_name
691            #where_clause
692        {
693            #mutate_method
694        }
695    })
696}
697
698fn gen_default_mutator_impl(
699    mutator_ty: &MutatorType,
700    container_attrs: &ContainerAttributes,
701) -> Result<TokenStream> {
702    let impl_default = container_attrs.default_mutate.unwrap_or(true);
703    if !impl_default {
704        return Ok(quote! {});
705    }
706
707    let ty_generics = if mutator_ty.ty_impl_generics.is_empty() {
708        quote! {}
709    } else {
710        let gens = &mutator_ty.ty_impl_generics;
711        quote! { < #( #gens ),* > }
712    };
713
714    let ty_name = mutator_ty.ty_name_with_generics();
715    let where_clause = mutator_ty.where_clause(WhereClauseKind::DefaultMutateBounds);
716    let mutator_name =
717        &mutator_ty.mutator_name_with_generics(MutatorNameGenericsKind::JustTyGenerics);
718
719    Ok(quote! {
720        #[automatically_derived]
721        impl #ty_generics mutatis::DefaultMutate for #ty_name
722            #where_clause
723        {
724            type DefaultMutate = #mutator_name;
725        }
726    })
727}