repr_with_fallback/
lib.rs1use proc_macro::TokenStream;
53use quote::{quote, quote_spanned};
54use syn::{Expr, Fields, FieldsUnnamed, ItemEnum, Type, Variant};
55
56#[proc_macro]
57pub fn repr_with_fallback(input: TokenStream) -> TokenStream {
58 let ast: ItemEnum = match syn::parse(input) {
59 Ok(ast) => ast,
60 Err(_) => {
61 return quote! {
62 compile_error!("This macro expects an enum definition as its input.");
63 }
64 .into();
65 }
66 };
67
68 let (fallback_variant, repr_type) = match get_repr_type(&ast) {
69 Ok(r) => r,
70 Err(e) => return e.into(),
71 };
72
73 let (unit_variants, discriminant_exprs) = match get_discriminant_exprs(&ast) {
74 Ok(d) => d,
75 Err(e) => return e.into(),
76 };
77
78 let mut enum_without_discriminants = ast.clone();
81 enum_without_discriminants
82 .variants
83 .iter_mut()
84 .for_each(|var| var.discriminant = None);
85
86 let from_enum_impl = gen_from_enum_impl(
87 &ast,
88 repr_type,
89 fallback_variant,
90 &unit_variants,
91 &discriminant_exprs,
92 );
93 let from_repr_impl = gen_from_repr_impl(
94 &ast,
95 repr_type,
96 fallback_variant,
97 &unit_variants,
98 &discriminant_exprs,
99 );
100
101 quote! {
102 #enum_without_discriminants
103 #from_enum_impl
104 #from_repr_impl
105 }
106 .into()
107}
108
109fn get_repr_type(ast: &ItemEnum) -> Result<(&Variant, &Type), proc_macro2::TokenStream> {
110 let variants = &ast.variants;
111 let unit_variants_count = variants
112 .iter()
113 .filter(|var| var.fields == Fields::Unit)
114 .count();
115
116 let err = quote_spanned! {ast.ident.span()=>
117 compile_error!("Tthe enum must have only unit variants plus exactly one variant with exactly one unnamed field.");
118 };
119
120 if unit_variants_count != variants.len() - 1 {
121 return Err(err);
122 }
123
124 let (fallback_variant, fallback_variant_fields) = variants
126 .iter()
127 .filter_map(|var| match &var.fields {
128 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => Some((var, unnamed)),
129 _ => None,
130 })
131 .next()
132 .ok_or_else(|| err.clone())?;
133 if fallback_variant_fields.len() != 1 {
135 return Err(err);
136 }
137
138 Ok((fallback_variant, &fallback_variant_fields[0].ty))
139}
140
141fn get_discriminant_exprs(
142 ast: &ItemEnum,
143) -> Result<(Vec<&Variant>, Vec<&Expr>), proc_macro2::TokenStream> {
144 let unit_variants: Vec<_> = ast
145 .variants
146 .iter()
147 .filter(|var| var.fields == Fields::Unit)
148 .collect();
149
150 let discriminants: Result<Vec<_>, _> = unit_variants
154 .iter()
155 .map(|var| var.discriminant.as_ref().ok_or(&var.ident))
156 .collect();
157 let discriminant_exprs = match discriminants {
158 Err(ident) => {
159 return Err(quote_spanned! {ident.span()=>
160 compile_error!("All unit variants must have a discriminant.");
161 })
162 }
163 Ok(d) => d.iter().map(|d| &d.1).collect(),
165 };
166 Ok((unit_variants, discriminant_exprs))
167}
168
169fn gen_from_enum_impl(
170 ast: &ItemEnum,
171 repr_type: &Type,
172 fallback_variant: &Variant,
173 unit_variants: &[&Variant],
174 discriminant_exprs: &[&Expr],
175) -> proc_macro2::TokenStream {
176 let enum_ident = &ast.ident;
177 let unit_variant_maps =
178 unit_variants
179 .iter()
180 .zip(discriminant_exprs.iter())
181 .map(|(var, expr)| {
182 let var_ident = &var.ident;
183 quote! {
184 #enum_ident::#var_ident => #expr
185 }
186 });
187 let unit_variant_maps = quote!(#(#unit_variant_maps),*);
188
189 let fallback_ident = &fallback_variant.ident;
190 quote! {
191 impl From<#enum_ident> for #repr_type {
192 fn from(val: #enum_ident) -> Self {
193 match val {
194 #unit_variant_maps,
195 #enum_ident::#fallback_ident(x) => x,
196 }
197 }
198 }
199 }
200}
201
202fn gen_from_repr_impl(
203 ast: &ItemEnum,
204 repr_type: &Type,
205 fallback_variant: &Variant,
206 unit_variants: &[&Variant],
207 discriminant_exprs: &[&Expr],
208) -> proc_macro2::TokenStream {
209 let enum_ident = &ast.ident;
210 let unit_variant_map_iter =
211 unit_variants
212 .iter()
213 .zip(discriminant_exprs.iter())
214 .map(|(var, expr)| {
215 let var_ident = &var.ident;
216 quote! {
217 #expr => #enum_ident::#var_ident
218 }
219 });
220 let unit_variant_maps = quote!(#(#unit_variant_map_iter),*);
221
222 let fallback_ident = &fallback_variant.ident;
223 quote! {
224 impl From<#repr_type> for #enum_ident {
225 fn from(val: #repr_type) -> Self {
226 match val {
227 #unit_variant_maps,
228 x => #enum_ident::#fallback_ident(x),
229 }
230 }
231 }
232 }
233}