Skip to main content

enum_discriminant_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, AttrStyle};
5
/// Unwraps the `Ok` value of a `Result` inside a macro entry point. On `Err`, the enclosing
/// function returns `error.to_compile_error().into()` immediately instead of continuing.
macro_rules! compile_error_unless_ok {
    ($result:expr) => {
        match $result {
            Err(failure) => return failure.to_compile_error().into(),
            Ok(unwrapped) => unwrapped,
        }
    };
}
14
15/// Adds a `discriminant()` function to get the numeric value of enum variants. Also adds
16/// a `from_discriminant()` function to create unit type enum variants from discriminants.
17///
18/// The `discriminant()` function relies on casting, as described in the
19/// [Rust language documentation](https://doc.rust-lang.org/std/mem/fn.discriminant.html#accessing-the-numeric-value-of-the-discriminant).
20/// The `from_discriminant()` function, on the other hand, is essentially a `match`
21/// statement with all the unit type variants.
22#[proc_macro_attribute]
23pub fn discriminant(arguments: TokenStream, item: TokenStream) -> TokenStream {
24    let enum_item = parse_macro_input!(item as syn::ItemEnum);
25    let enum_name = &enum_item.ident;
26
27    let arguments: TokenStream2 = arguments.into();
28
29    let repr_type = compile_error_unless_ok!(get_repr_type(arguments.clone()));
30
31    let from_discriminant_code = generate_from_discriminant_function(&repr_type, &enum_item);
32    let discriminant_code = generate_discriminant_function(&repr_type);
33
34    quote! {
35        #[repr(#arguments)]
36        #enum_item
37
38        impl #enum_name {
39            #from_discriminant_code
40
41            #discriminant_code
42        }
43    }
44    .into()
45}
46
47/// Derive macro generating an impl for the `IntoDiscriminant` trait for enums. The trait
48/// adds a `discriminant()` function to get the numeric value of enum variants.
49///
50/// The `discriminant()` function relies on casting, as described in the
51/// [Rust language documentation](https://doc.rust-lang.org/std/mem/fn.discriminant.html#accessing-the-numeric-value-of-the-discriminant).
52#[proc_macro_derive(IntoDiscriminant)]
53pub fn derive_into_discriminant(item: TokenStream) -> TokenStream {
54    let input = parse_macro_input!(item as syn::DeriveInput);
55    let enum_name = &input.ident;
56
57    let repr_args = compile_error_unless_ok!(get_repr_args("IntoDiscriminant", &input));
58    let repr_type = compile_error_unless_ok!(get_repr_type(repr_args));
59
60    let discriminant_code = generate_discriminant_function(&repr_type);
61
62    quote! {
63        impl IntoDiscriminant for #enum_name {
64            type DiscriminantType = #repr_type;
65
66            #discriminant_code
67        }
68    }
69    .into()
70}
71
72/// Derive macro generating an impl for the `FromDiscriminant` trait for enums. The trait
73/// adds a `from_discriminant()` function to create unit type enum variants from
74/// discriminants.
75///
76/// The `from_discriminant()` function is essentially a `match` statement with all the
77/// unit type variants.
78#[proc_macro_derive(FromDiscriminant)]
79pub fn derive_from_discriminant(item: TokenStream) -> TokenStream {
80    let cloned_item = item.clone();
81    let input = parse_macro_input!(item as syn::DeriveInput);
82    let enum_item = parse_macro_input!(cloned_item as syn::ItemEnum);
83    let enum_name = &enum_item.ident;
84
85    let repr_args = compile_error_unless_ok!(get_repr_args("FromDiscriminant", &input));
86    let repr_type = compile_error_unless_ok!(get_repr_type(repr_args));
87
88    let from_discriminant_code = generate_from_discriminant_function(&repr_type, &enum_item);
89
90    quote! {
91        impl FromDiscriminant for #enum_name {
92            type DiscriminantType = #repr_type;
93
94            #from_discriminant_code
95        }
96    }
97    .into()
98}
99
100/// Returns the first valid representation type found in the arguments or a compile error
101/// if none is found.
102fn get_repr_type(arguments: TokenStream2) -> Result<syn::Path, syn::Error> {
103    let allowed_types = [
104        "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize",
105    ];
106
107    arguments
108        .clone()
109        .into_iter()
110        // Filter arguments that are allowed types and convert them to syn::Path
111        .filter_map(|token_tree| {
112            if let proc_macro2::TokenTree::Ident(ident) = token_tree {
113                let ident_str = ident.to_string();
114                if allowed_types.contains(&ident_str.as_str()) {
115                    return Some(syn::parse_str::<syn::Path>(&ident_str).unwrap());
116                }
117            }
118            None
119        })
120        .next()
121        // On error, return a compile error as a TokenStream
122        .ok_or_else(|| {
123            syn::Error::new_spanned(
124                arguments,
125                "Valid enum representation type expected as argument to the discriminant \
126                 macro, e.g., #[discriminant(u8)]",
127            )
128        })
129}
130
131// Finds the first `repr` or `discriminant` attribute in the input and returns its
132// arguments. This is used to determine the representation type of the enum.
133fn get_repr_args(macro_name: &str, input: &syn::DeriveInput) -> Result<TokenStream2, syn::Error> {
134    let x = input
135        .attrs
136        .iter()
137        .filter(|attr| matches!(attr.style, AttrStyle::Outer))
138        .filter(|attr| {
139            let path = attr.path();
140            path.is_ident("repr") || path.is_ident("discriminant")
141        })
142        .filter_map(|attr| attr.meta.require_list().ok())
143        .next()
144        .ok_or_else(|| {
145            syn::Error::new_spanned(
146                input,
147                format!(
148                    "When deriving {} on an enum, you also need to specify \
149                     representation type with #[repr()] or #[discriminant()]",
150                    macro_name
151                ),
152            )
153        })?;
154    Ok(x.tokens.clone())
155}
156
157// Returns a tuple of the names and discriminants of the unit variants of an enum. The
158// discriminants are returned as expressions, since explicit input discriminants can be
159// constant expressions.
160fn enum_unit_variants(enum_item: &syn::ItemEnum) -> (Vec<proc_macro2::Ident>, Vec<syn::Expr>) {
161    let mut previous_expr: Option<syn::Expr> = None;
162    enum_item
163        .variants
164        .iter()
165        .filter(|variant| matches!(variant.fields, syn::Fields::Unit))
166        .map(|variant| {
167            let expr = if let Some(discriminant) = &variant.discriminant {
168                discriminant.1.clone()
169            } else if let Some(ref old_expr) = previous_expr {
170                parse_quote!( 1 + #old_expr )
171            } else {
172                parse_quote!(0)
173            };
174            previous_expr = Some(expr.clone());
175            (variant.ident.clone(), expr)
176        })
177        .unzip()
178}
179
180fn generate_from_discriminant_function(
181    repr_type: &syn::Path,
182    enum_item: &syn::ItemEnum,
183) -> TokenStream2 {
184    let (variant_names, discriminants) = enum_unit_variants(enum_item);
185    let enum_name = &enum_item.ident;
186
187    quote! {
188        /// Creates an enum variant from a discriminant numeric value if there is a unit
189        /// type variant with that value.
190        fn from_discriminant(discriminant: #repr_type) -> Option<Self> {
191            match discriminant {
192                // Match arm guard needed in case discriminant is not a literal but
193                // constant other expression
194                #( discriminant if discriminant == #discriminants =>
195                    Some(#enum_name::#variant_names), )*
196                _ => None,
197            }
198        }
199    }
200}
201
202fn generate_discriminant_function(repr_type: &syn::Path) -> TokenStream2 {
203    quote! {
204         /// Returns the discriminant numeric value of an enum variant.
205         fn discriminant(&self) -> #repr_type {
206            // See https://doc.rust-lang.org/core/mem/fn.discriminant.html
207            unsafe {
208                *<*const _>::from(self).cast::<#repr_type>()
209            }
210        }
211    }
212}