repr_with_fallback/
lib.rs

//! Automatically generate [`From`] impls in both directions (and therefore [`Into`] impls) for
//! `enum`s with custom discriminant values and a fallback variant.
//!
//! # Usage
//! ```
//! use repr_with_fallback::repr_with_fallback;
//!
//! repr_with_fallback! {
//!     /// A DNSSEC algorithm.
//!     #[derive(Debug, PartialEq)]
//!     pub enum Algorithm {
//!         /// ...
//!         RSASHA256 = 8,
//!         RSASHA512 = 10,
//!         ECDSAP256SHA256 = 13,
//!         ECDSAP384SHA384 = 14,
//!         ED25519 = 15,
//!         Unassigned(u8),
//!     }
//! }
//!
//! assert_eq!(u8::from(Algorithm::ED25519), 15);
//! assert_eq!(Algorithm::from(15), Algorithm::ED25519);
//!
//! assert_eq!(u8::from(Algorithm::Unassigned(17)), 17);
//! assert_eq!(Algorithm::from(17), Algorithm::Unassigned(17));
//! ```
//!
//! There are two restrictions imposed on the `enum`:
//! 1. It must have only unit variants, except for exactly one variant with exactly one unnamed
//!    field. This is used as the fallback variant, and the type of its field must match the type of
//!    the discriminants.
//! 2. Every unit variant must have a discriminant value provided. The discriminants are used as
//!    `match` patterns in the generated conversion from the repr type, so they must be valid
//!    patterns, such as literals or constants.
//!
//! Conversion from the repr type checks the unit variants first. A value equal to one of their
//! discriminants therefore never produces the fallback variant. For example, `Algorithm::from(15)`
//! above is `Algorithm::ED25519`, even though `Algorithm::Unassigned(15)` converts to `15`.
//!
//! The repr type does not need to be numeric:
//! ```
//! use repr_with_fallback::repr_with_fallback;
//!
//! repr_with_fallback! {
//!     pub enum Strings {
//!         Foo = "static",
//!         Bar = "string",
//!         Baz = "slices",
//!         Spam = "work",
//!         Eggs = "too",
//!         Unknown(&'static str),
//!     }
//! }
//!
//! let s: &'static str = Strings::Foo.into();
//! assert_eq!(s, "static");
//! ```
51
52use proc_macro::TokenStream;
53use quote::{quote, quote_spanned};
54use syn::{Expr, Fields, FieldsUnnamed, ItemEnum, Type, Variant};
55
56#[proc_macro]
57pub fn repr_with_fallback(input: TokenStream) -> TokenStream {
58    let ast: ItemEnum = match syn::parse(input) {
59        Ok(ast) => ast,
60        Err(_) => {
61            return quote! {
62                compile_error!("This macro expects an enum definition as its input.");
63            }
64            .into();
65        }
66    };
67
68    let (fallback_variant, repr_type) = match get_repr_type(&ast) {
69        Ok(r) => r,
70        Err(e) => return e.into(),
71    };
72
73    let (unit_variants, discriminant_exprs) = match get_discriminant_exprs(&ast) {
74        Ok(d) => d,
75        Err(e) => return e.into(),
76    };
77
78    // create new enum definition, this time without the discriminants (they would cause a compile
79    // error)
80    let mut enum_without_discriminants = ast.clone();
81    enum_without_discriminants
82        .variants
83        .iter_mut()
84        .for_each(|var| var.discriminant = None);
85
86    let from_enum_impl = gen_from_enum_impl(
87        &ast,
88        repr_type,
89        fallback_variant,
90        &unit_variants,
91        &discriminant_exprs,
92    );
93    let from_repr_impl = gen_from_repr_impl(
94        &ast,
95        repr_type,
96        fallback_variant,
97        &unit_variants,
98        &discriminant_exprs,
99    );
100
101    quote! {
102        #enum_without_discriminants
103        #from_enum_impl
104        #from_repr_impl
105    }
106    .into()
107}
108
109fn get_repr_type(ast: &ItemEnum) -> Result<(&Variant, &Type), proc_macro2::TokenStream> {
110    let variants = &ast.variants;
111    let unit_variants_count = variants
112        .iter()
113        .filter(|var| var.fields == Fields::Unit)
114        .count();
115
116    let err = quote_spanned! {ast.ident.span()=>
117        compile_error!("Tthe enum must have only unit variants plus exactly one variant with exactly one unnamed field.");
118    };
119
120    if unit_variants_count != variants.len() - 1 {
121        return Err(err);
122    }
123
124    // check that we have one fallback variant
125    let (fallback_variant, fallback_variant_fields) = variants
126        .iter()
127        .filter_map(|var| match &var.fields {
128            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => Some((var, unnamed)),
129            _ => None,
130        })
131        .next()
132        .ok_or_else(|| err.clone())?;
133    // check that the fallback variant has one field
134    if fallback_variant_fields.len() != 1 {
135        return Err(err);
136    }
137
138    Ok((fallback_variant, &fallback_variant_fields[0].ty))
139}
140
141fn get_discriminant_exprs(
142    ast: &ItemEnum,
143) -> Result<(Vec<&Variant>, Vec<&Expr>), proc_macro2::TokenStream> {
144    let unit_variants: Vec<_> = ast
145        .variants
146        .iter()
147        .filter(|var| var.fields == Fields::Unit)
148        .collect();
149
150    // check that all unit variants have a discriminant.
151    // this is either `Ok(d)` with `d` being all discriminants, or `Err(i`) with `i` being the ident
152    // of the first variant without a discriminant
153    let discriminants: Result<Vec<_>, _> = unit_variants
154        .iter()
155        .map(|var| var.discriminant.as_ref().ok_or(&var.ident))
156        .collect();
157    let discriminant_exprs = match discriminants {
158        Err(ident) => {
159            return Err(quote_spanned! {ident.span()=>
160                compile_error!("All unit variants must have a discriminant.");
161            })
162        }
163        // get the expression of the discriminant (discard the " = " part)
164        Ok(d) => d.iter().map(|d| &d.1).collect(),
165    };
166    Ok((unit_variants, discriminant_exprs))
167}
168
169fn gen_from_enum_impl(
170    ast: &ItemEnum,
171    repr_type: &Type,
172    fallback_variant: &Variant,
173    unit_variants: &[&Variant],
174    discriminant_exprs: &[&Expr],
175) -> proc_macro2::TokenStream {
176    let enum_ident = &ast.ident;
177    let unit_variant_maps =
178        unit_variants
179            .iter()
180            .zip(discriminant_exprs.iter())
181            .map(|(var, expr)| {
182                let var_ident = &var.ident;
183                quote! {
184                    #enum_ident::#var_ident => #expr
185                }
186            });
187    let unit_variant_maps = quote!(#(#unit_variant_maps),*);
188
189    let fallback_ident = &fallback_variant.ident;
190    quote! {
191        impl From<#enum_ident> for #repr_type {
192            fn from(val: #enum_ident) -> Self {
193                match val {
194                    #unit_variant_maps,
195                    #enum_ident::#fallback_ident(x) => x,
196                }
197            }
198        }
199    }
200}
201
202fn gen_from_repr_impl(
203    ast: &ItemEnum,
204    repr_type: &Type,
205    fallback_variant: &Variant,
206    unit_variants: &[&Variant],
207    discriminant_exprs: &[&Expr],
208) -> proc_macro2::TokenStream {
209    let enum_ident = &ast.ident;
210    let unit_variant_map_iter =
211        unit_variants
212            .iter()
213            .zip(discriminant_exprs.iter())
214            .map(|(var, expr)| {
215                let var_ident = &var.ident;
216                quote! {
217                    #expr => #enum_ident::#var_ident
218                }
219            });
220    let unit_variant_maps = quote!(#(#unit_variant_map_iter),*);
221
222    let fallback_ident = &fallback_variant.ident;
223    quote! {
224        impl From<#repr_type> for #enum_ident {
225            fn from(val: #repr_type) -> Self {
226                match val {
227                    #unit_variant_maps,
228                    x => #enum_ident::#fallback_ident(x),
229                }
230            }
231        }
232    }
233}