enum_ref_macro/
derive.rs

1use crate::utils::AttributeExt;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote_spanned;
4use quote::{format_ident, quote};
5use syn::parse_quote;
6use syn::spanned::Spanned;
7use syn::DeriveInput;
8use syn::Result;
9
10/// Wrapper around [`enum_ref_impl`] for error conversions.
11pub fn enum_ref(input: DeriveInput) -> TokenStream2 {
12    match enum_ref_impl(input) {
13        Ok(result) => result,
14        Err(error) => error.to_compile_error(),
15    }
16}
17
18/// Implements the `#[derive(EnumRef)]` functionality for the given `input`.
19fn enum_ref_impl(input: DeriveInput) -> Result<TokenStream2> {
20    let data = extract_enum(&input)?;
21    let ident = &input.ident;
22    let repr = repr_attr(&input);
23    let ref_ident = format_ident!("{}Ref", ident);
24    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
25    let (ref_generics, ref_lifetime) = make_ref_generics(&input.generics);
26    let (impl_ref_generics, type_ref_generics, _) = ref_generics.split_for_impl();
27    let variants = data.variants.iter().map(make_ref);
28    let arms = data.variants.iter().map(make_arm);
29    Ok(quote! {
30        const _: () = {
31            #[derive(::core::fmt::Debug)]
32            #repr
33            pub enum #ref_ident #impl_ref_generics {
34                #( #variants ),*
35            }
36
37            impl #impl_generics ::enum_ref::EnumRef for #ident #type_generics #where_clause {
38                type Ref<#ref_lifetime> where Self: #ref_lifetime = #ref_ident #type_ref_generics
39                where
40                    Self: #ref_lifetime;
41
42                fn as_ref(&self) -> <Self as ::enum_ref::EnumRef>::Ref<'_> {
43                    // This type alias is a workaround for a Rust compiler bug.
44                    //
45                    // # Note
46                    //
47                    // This is required for a workaround for this issue:
48                    // https://github.com/rust-lang/rust/issues/86935#issuecomment-1484160404
49                    //
50                    // The problem is that we cannot use associated type paths to
51                    // disambiguate enum variants with named fields.
52                    type __enum_ref_EnumRef_Ref #impl_ref_generics =
53                        <#ident #type_generics as ::enum_ref::EnumRef>::Ref<#ref_lifetime>;
54                    match self {
55                        #(
56                            Self::#arms => __enum_ref_EnumRef_Ref::#arms,
57                        )*
58                    }
59                }
60            }
61        };
62    })
63}
64
65/// Wrapper around [`enum_mut_impl`] for error conversions.
66pub fn enum_mut(input: DeriveInput) -> TokenStream2 {
67    match enum_mut_impl(input) {
68        Ok(result) => result,
69        Err(error) => error.to_compile_error(),
70    }
71}
72
73/// Implements the `#[derive(EnumMut)]` functionality for the given `input`.
74fn enum_mut_impl(input: DeriveInput) -> Result<TokenStream2> {
75    let data = extract_enum(&input)?;
76    let ident = &input.ident;
77    let repr = repr_attr(&input);
78    let mut_ident = format_ident!("{}Mut", ident);
79    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
80    let (ref_generics, ref_lifetime) = make_ref_generics(&input.generics);
81    let (impl_ref_generics, type_ref_generics, _) = ref_generics.split_for_impl();
82    let variants = data.variants.iter().map(make_mut);
83    let arms = data.variants.iter().map(make_arm);
84    Ok(quote! {
85        const _: () = {
86            #[derive(::core::fmt::Debug)]
87            #repr
88            pub enum #mut_ident #impl_ref_generics {
89                #( #variants ),*
90            }
91
92            impl #impl_generics ::enum_ref::EnumMut for #ident #type_generics #where_clause {
93                type Mut<#ref_lifetime> where Self: #ref_lifetime = #mut_ident #type_ref_generics
94                where
95                    Self: #ref_lifetime;
96
97                fn as_mut(&mut self) -> <Self as ::enum_ref::EnumMut>::Mut<'_> {
98                    // This type alias is a workaround for a Rust compiler bug.
99                    //
100                    // # Note
101                    //
102                    // This is required for a workaround for this issue:
103                    // https://github.com/rust-lang/rust/issues/86935#issuecomment-1484160404
104                    //
105                    // The problem is that we cannot use associated type paths to
106                    // disambiguate enum variants with named fields.
107                    type __enum_ref_EnumMut_Mut #impl_ref_generics =
108                        <#ident #type_generics as ::enum_ref::EnumMut>::Mut<#ref_lifetime>;
109                    match self {
110                        #(
111                            Self::#arms => __enum_ref_EnumMut_Mut::#arms,
112                        )*
113                    }
114                }
115            }
116        };
117    })
118}
119
120/// Sanitizes the input to the `EnumRef` and `EnumMut` derive macros.
121fn extract_enum(input: &DeriveInput) -> Result<&syn::DataEnum> {
122    let data = match &input.data {
123        syn::Data::Enum(data) => data,
124        syn::Data::Struct(_) => bail_spanned!(
125            input,
126            "derive(EnumRef) only works on `enum` types but found struct"
127        ),
128        syn::Data::Union(_) => bail_spanned!(
129            input,
130            "derive(EnumRef) only works on `enum` types but found union"
131        ),
132    };
133    Ok(data)
134}
135
136fn make_ref(variant: &syn::Variant) -> syn::Variant {
137    make_ref_variant(variant, Mutability::Ref)
138}
139
140fn make_mut(variant: &syn::Variant) -> syn::Variant {
141    make_ref_variant(variant, Mutability::Mut)
142}
143
/// Selects which kind of reference a generated variant field is wrapped in.
#[derive(Clone, Copy, Debug)]
enum Mutability {
    /// Fields become shared references (`&'lt T`).
    Ref,
    /// Fields become mutable references (`&'lt mut T`).
    Mut,
}
149
150/// Adds a special `'__enum_ref_lt` lifetime parameter to the start of the given `generics`.
151///
152/// Returns the adjusted generics as well as the added lifetime.
153fn make_ref_generics(generics: &syn::Generics) -> (syn::Generics, syn::Lifetime) {
154    let mut generics = generics.clone();
155    let lifetime = make_ref_lifetime();
156    generics.params.insert(0, parse_quote!(#lifetime));
157    (generics, lifetime)
158}
159
160fn make_ref_lifetime() -> syn::Lifetime {
161    parse_quote!('__enum_ref_lt)
162}
163
164fn make_ref_variant(variant: &syn::Variant, mutable: Mutability) -> syn::Variant {
165    let lt = make_ref_lifetime();
166    let mutability = matches!(mutable, Mutability::Mut).then_some(quote!(mut));
167    let mut fields = variant.fields.clone();
168    for field in &mut fields {
169        let ty = &field.ty;
170        let ref_ty: syn::TypeReference = parse_quote!(&#lt #mutability #ty);
171        field.ty = syn::Type::Reference(ref_ty);
172    }
173    syn::Variant {
174        fields,
175        ..variant.clone()
176    }
177}
178
179fn repr_attr(input: &DeriveInput) -> Option<syn::Attribute> {
180    input
181        .attrs
182        .iter()
183        .cloned()
184        .find(AttributeExt::is_repr_attribute)
185}
186
187fn make_arm(variant: &syn::Variant) -> TokenStream2 {
188    let span = variant.span();
189    let ident = &variant.ident;
190    match &variant.fields {
191        syn::Fields::Named(fields) => {
192            let names = fields.named.iter().map(|f| {
193                f.ident
194                    .as_ref()
195                    .expect("named fields must have identifiers")
196            });
197            quote_spanned!(span=> #ident { #(#names),* })
198        }
199        syn::Fields::Unnamed(fields) => {
200            let underscores = fields
201                .unnamed
202                .iter()
203                .enumerate()
204                .map(|(n, _field)| format_ident!("_{n}"));
205            quote_spanned!(span=> #ident (#(#underscores),*))
206        }
207        syn::Fields::Unit => {
208            quote_spanned!(span=> #ident)
209        }
210    }
211}