1use crate::utils::AttributeExt;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote_spanned;
4use quote::{format_ident, quote};
5use syn::parse_quote;
6use syn::spanned::Spanned;
7use syn::DeriveInput;
8use syn::Result;
9
10pub fn enum_ref(input: DeriveInput) -> TokenStream2 {
12 match enum_ref_impl(input) {
13 Ok(result) => result,
14 Err(error) => error.to_compile_error(),
15 }
16}
17
18fn enum_ref_impl(input: DeriveInput) -> Result<TokenStream2> {
20 let data = extract_enum(&input)?;
21 let ident = &input.ident;
22 let repr = repr_attr(&input);
23 let ref_ident = format_ident!("{}Ref", ident);
24 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
25 let (ref_generics, ref_lifetime) = make_ref_generics(&input.generics);
26 let (impl_ref_generics, type_ref_generics, _) = ref_generics.split_for_impl();
27 let variants = data.variants.iter().map(make_ref);
28 let arms = data.variants.iter().map(make_arm);
29 Ok(quote! {
30 const _: () = {
31 #[derive(::core::fmt::Debug)]
32 #repr
33 pub enum #ref_ident #impl_ref_generics {
34 #( #variants ),*
35 }
36
37 impl #impl_generics ::enum_ref::EnumRef for #ident #type_generics #where_clause {
38 type Ref<#ref_lifetime> where Self: #ref_lifetime = #ref_ident #type_ref_generics
39 where
40 Self: #ref_lifetime;
41
42 fn as_ref(&self) -> <Self as ::enum_ref::EnumRef>::Ref<'_> {
43 type __enum_ref_EnumRef_Ref #impl_ref_generics =
53 <#ident #type_generics as ::enum_ref::EnumRef>::Ref<#ref_lifetime>;
54 match self {
55 #(
56 Self::#arms => __enum_ref_EnumRef_Ref::#arms,
57 )*
58 }
59 }
60 }
61 };
62 })
63}
64
65pub fn enum_mut(input: DeriveInput) -> TokenStream2 {
67 match enum_mut_impl(input) {
68 Ok(result) => result,
69 Err(error) => error.to_compile_error(),
70 }
71}
72
73fn enum_mut_impl(input: DeriveInput) -> Result<TokenStream2> {
75 let data = extract_enum(&input)?;
76 let ident = &input.ident;
77 let repr = repr_attr(&input);
78 let mut_ident = format_ident!("{}Mut", ident);
79 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
80 let (ref_generics, ref_lifetime) = make_ref_generics(&input.generics);
81 let (impl_ref_generics, type_ref_generics, _) = ref_generics.split_for_impl();
82 let variants = data.variants.iter().map(make_mut);
83 let arms = data.variants.iter().map(make_arm);
84 Ok(quote! {
85 const _: () = {
86 #[derive(::core::fmt::Debug)]
87 #repr
88 pub enum #mut_ident #impl_ref_generics {
89 #( #variants ),*
90 }
91
92 impl #impl_generics ::enum_ref::EnumMut for #ident #type_generics #where_clause {
93 type Mut<#ref_lifetime> where Self: #ref_lifetime = #mut_ident #type_ref_generics
94 where
95 Self: #ref_lifetime;
96
97 fn as_mut(&mut self) -> <Self as ::enum_ref::EnumMut>::Mut<'_> {
98 type __enum_ref_EnumMut_Mut #impl_ref_generics =
108 <#ident #type_generics as ::enum_ref::EnumMut>::Mut<#ref_lifetime>;
109 match self {
110 #(
111 Self::#arms => __enum_ref_EnumMut_Mut::#arms,
112 )*
113 }
114 }
115 }
116 };
117 })
118}
119
120fn extract_enum(input: &DeriveInput) -> Result<&syn::DataEnum> {
122 let data = match &input.data {
123 syn::Data::Enum(data) => data,
124 syn::Data::Struct(_) => bail_spanned!(
125 input,
126 "derive(EnumRef) only works on `enum` types but found struct"
127 ),
128 syn::Data::Union(_) => bail_spanned!(
129 input,
130 "derive(EnumRef) only works on `enum` types but found union"
131 ),
132 };
133 Ok(data)
134}
135
136fn make_ref(variant: &syn::Variant) -> syn::Variant {
137 make_ref_variant(variant, Mutability::Ref)
138}
139
140fn make_mut(variant: &syn::Variant) -> syn::Variant {
141 make_ref_variant(variant, Mutability::Mut)
142}
143
/// Selects which kind of reference `make_ref_variant` wraps each field type in.
#[derive(Clone, Copy, Debug)]
enum Mutability {
    /// Shared reference: `&'lt T`.
    Ref,
    /// Mutable reference: `&'lt mut T`.
    Mut,
}
149
150fn make_ref_generics(generics: &syn::Generics) -> (syn::Generics, syn::Lifetime) {
154 let mut generics = generics.clone();
155 let lifetime = make_ref_lifetime();
156 generics.params.insert(0, parse_quote!(#lifetime));
157 (generics, lifetime)
158}
159
160fn make_ref_lifetime() -> syn::Lifetime {
161 parse_quote!('__enum_ref_lt)
162}
163
164fn make_ref_variant(variant: &syn::Variant, mutable: Mutability) -> syn::Variant {
165 let lt = make_ref_lifetime();
166 let mutability = matches!(mutable, Mutability::Mut).then_some(quote!(mut));
167 let mut fields = variant.fields.clone();
168 for field in &mut fields {
169 let ty = &field.ty;
170 let ref_ty: syn::TypeReference = parse_quote!(&#lt #mutability #ty);
171 field.ty = syn::Type::Reference(ref_ty);
172 }
173 syn::Variant {
174 fields,
175 ..variant.clone()
176 }
177}
178
179fn repr_attr(input: &DeriveInput) -> Option<syn::Attribute> {
180 input
181 .attrs
182 .iter()
183 .cloned()
184 .find(AttributeExt::is_repr_attribute)
185}
186
187fn make_arm(variant: &syn::Variant) -> TokenStream2 {
188 let span = variant.span();
189 let ident = &variant.ident;
190 match &variant.fields {
191 syn::Fields::Named(fields) => {
192 let names = fields.named.iter().map(|f| {
193 f.ident
194 .as_ref()
195 .expect("named fields must have identifiers")
196 });
197 quote_spanned!(span=> #ident { #(#names),* })
198 }
199 syn::Fields::Unnamed(fields) => {
200 let underscores = fields
201 .unnamed
202 .iter()
203 .enumerate()
204 .map(|(n, _field)| format_ident!("_{n}"));
205 quote_spanned!(span=> #ident (#(#underscores),*))
206 }
207 syn::Fields::Unit => {
208 quote_spanned!(span=> #ident)
209 }
210 }
211}