num_enum_derive/
lib.rs

1// Not supported by MSRV
2#![allow(clippy::uninlined_format_args)]
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::quote;
9use syn::{parse_macro_input, Expr, Ident};
10
11mod enum_attributes;
12mod parsing;
13use parsing::{get_crate_path, EnumInfo};
14mod utils;
15mod variant_attributes;
16
17/// Implements `Into<Primitive>` for a `#[repr(Primitive)] enum`.
18///
19/// (It actually implements `From<Enum> for Primitive`)
20///
21/// ## Allows turning an enum into a primitive.
22///
23/// ```rust
24/// use num_enum::IntoPrimitive;
25///
26/// #[derive(IntoPrimitive)]
27/// #[repr(u8)]
28/// enum Number {
29///     Zero,
30///     One,
31/// }
32///
33/// let zero: u8 = Number::Zero.into();
34/// assert_eq!(zero, 0u8);
35/// ```
36#[proc_macro_derive(IntoPrimitive, attributes(num_enum, catch_all))]
37pub fn derive_into_primitive(input: TokenStream) -> TokenStream {
38    let enum_info = parse_macro_input!(input as EnumInfo);
39    let catch_all = enum_info.catch_all();
40    let name = &enum_info.name;
41    let repr = &enum_info.repr;
42
43    let body = if let Some(catch_all_ident) = catch_all {
44        quote! {
45            match enum_value {
46                #name::#catch_all_ident(raw) => raw,
47                rest => unsafe { *(&rest as *const #name as *const Self) }
48            }
49        }
50    } else {
51        quote! { enum_value as Self }
52    };
53
54    TokenStream::from(quote! {
55        impl From<#name> for #repr {
56            #[inline]
57            fn from (enum_value: #name) -> Self
58            {
59                #body
60            }
61        }
62    })
63}
64
65/// Implements `From<Primitive>` for a `#[repr(Primitive)] enum`.
66///
67/// Turning a primitive into an enum with `from`.
68/// ----------------------------------------------
69///
70/// ```rust
71/// use num_enum::FromPrimitive;
72///
73/// #[derive(Debug, Eq, PartialEq, FromPrimitive)]
74/// #[repr(u8)]
75/// enum Number {
76///     Zero,
77///     #[num_enum(default)]
78///     NonZero,
79/// }
80///
81/// let zero = Number::from(0u8);
82/// assert_eq!(zero, Number::Zero);
83///
84/// let one = Number::from(1u8);
85/// assert_eq!(one, Number::NonZero);
86///
87/// let two = Number::from(2u8);
88/// assert_eq!(two, Number::NonZero);
89/// ```
90#[proc_macro_derive(FromPrimitive, attributes(num_enum, default, catch_all))]
91pub fn derive_from_primitive(input: TokenStream) -> TokenStream {
92    let enum_info: EnumInfo = parse_macro_input!(input);
93    let krate = get_crate_path(enum_info.crate_path.clone());
94
95    let is_naturally_exhaustive = enum_info.is_naturally_exhaustive();
96    let catch_all_body = match is_naturally_exhaustive {
97        Ok(is_naturally_exhaustive) => {
98            if is_naturally_exhaustive {
99                quote! { unreachable!("exhaustive enum") }
100            } else if let Some(default_ident) = enum_info.default() {
101                quote! { Self::#default_ident }
102            } else if let Some(catch_all_ident) = enum_info.catch_all() {
103                quote! { Self::#catch_all_ident(number) }
104            } else {
105                let span = Span::call_site();
106                let message =
107                    "#[derive(num_enum::FromPrimitive)] requires enum to be exhaustive, or a variant marked with `#[default]`, `#[num_enum(default)]`, or `#[num_enum(catch_all)`";
108                return syn::Error::new(span, message).to_compile_error().into();
109            }
110        }
111        Err(err) => {
112            return err.to_compile_error().into();
113        }
114    };
115
116    let EnumInfo {
117        ref name, ref repr, ..
118    } = enum_info;
119
120    let variant_idents: Vec<Ident> = enum_info.variant_idents();
121    let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
122    let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
123
124    debug_assert_eq!(variant_idents.len(), variant_expressions.len());
125
126    TokenStream::from(quote! {
127        impl #krate::FromPrimitive for #name {
128            type Primitive = #repr;
129
130            fn from_primitive(number: Self::Primitive) -> Self {
131                // Use intermediate const(s) so that enums defined like
132                // `Two = ONE + 1u8` work properly.
133                #![allow(non_upper_case_globals)]
134                #(
135                    #(
136                        const #expression_idents: #repr = #variant_expressions;
137                    )*
138                )*
139                #[deny(unreachable_patterns)]
140                match number {
141                    #(
142                        #( #expression_idents )|*
143                        => Self::#variant_idents,
144                    )*
145                    #[allow(unreachable_patterns)]
146                    _ => #catch_all_body,
147                }
148            }
149        }
150
151        impl ::core::convert::From<#repr> for #name {
152            #[inline]
153            fn from (
154                number: #repr,
155            ) -> Self {
156                #krate::FromPrimitive::from_primitive(number)
157            }
158        }
159
160        #[doc(hidden)]
161        impl #krate::CannotDeriveBothFromPrimitiveAndTryFromPrimitive for #name {}
162    })
163}
164
165/// Implements `TryFrom<Primitive>` for a `#[repr(Primitive)] enum`.
166///
167/// Attempting to turn a primitive into an enum with `try_from`.
168/// ----------------------------------------------
169///
170/// ```rust
171/// use num_enum::TryFromPrimitive;
172/// use std::convert::TryFrom;
173///
174/// #[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
175/// #[repr(u8)]
176/// enum Number {
177///     Zero,
178///     One,
179/// }
180///
181/// let zero = Number::try_from(0u8);
182/// assert_eq!(zero, Ok(Number::Zero));
183///
184/// let three = Number::try_from(3u8);
185/// assert_eq!(
186///     three.unwrap_err().to_string(),
187///     "No discriminant in enum `Number` matches the value `3`",
188/// );
189/// ```
190#[proc_macro_derive(TryFromPrimitive, attributes(num_enum))]
191pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream {
192    let enum_info: EnumInfo = parse_macro_input!(input);
193    let krate = get_crate_path(enum_info.crate_path.clone());
194    let EnumInfo {
195        ref name,
196        ref repr,
197        ref error_type_info,
198        ..
199    } = enum_info;
200
201    let variant_idents: Vec<Ident> = enum_info.variant_idents();
202    let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
203    let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
204
205    debug_assert_eq!(variant_idents.len(), variant_expressions.len());
206
207    let error_type = &error_type_info.name;
208    let error_constructor = &error_type_info.constructor;
209
210    TokenStream::from(quote! {
211        impl #krate::TryFromPrimitive for #name {
212            type Primitive = #repr;
213            type Error = #error_type;
214
215            const NAME: &'static str = stringify!(#name);
216
217            fn try_from_primitive (
218                number: Self::Primitive,
219            ) -> ::core::result::Result<
220                Self,
221                #error_type
222            > {
223                // Use intermediate const(s) so that enums defined like
224                // `Two = ONE + 1u8` work properly.
225                #![allow(non_upper_case_globals)]
226                #(
227                    #(
228                        const #expression_idents: #repr = #variant_expressions;
229                    )*
230                )*
231                #[deny(unreachable_patterns)]
232                match number {
233                    #(
234                        #( #expression_idents )|*
235                        => ::core::result::Result::Ok(Self::#variant_idents),
236                    )*
237                    #[allow(unreachable_patterns)]
238                    _ => ::core::result::Result::Err(
239                        #error_constructor ( number )
240                    ),
241                }
242            }
243        }
244
245        impl ::core::convert::TryFrom<#repr> for #name {
246            type Error = #error_type;
247
248            #[inline]
249            fn try_from (
250                number: #repr,
251            ) -> ::core::result::Result<Self, #error_type>
252            {
253                #krate::TryFromPrimitive::try_from_primitive(number)
254            }
255        }
256
257        #[doc(hidden)]
258        impl #krate::CannotDeriveBothFromPrimitiveAndTryFromPrimitive for #name {}
259    })
260}
261
262/// Generates a `unsafe fn unchecked_transmute_from(number: Primitive) -> Self`
263/// associated function.
264///
265/// Allows unsafely turning a primitive into an enum with unchecked_transmute_from
266/// ------------------------------------------------------------------------------
267///
268/// If you're really certain a conversion will succeed, and want to avoid a small amount of overhead, you can use unsafe
269/// code to do this conversion. Unless you have data showing that the match statement generated in the `try_from` above is a
270/// bottleneck for you, you should avoid doing this, as the unsafe code has potential to cause serious memory issues in
271/// your program.
272///
273/// Note that this derive ignores any `default`, `catch_all`, and `alternatives` attributes on the enum.
274/// If you need support for conversions from these values, you should use `TryFromPrimitive` or `FromPrimitive`.
275///
276/// ```rust
277/// use num_enum::UnsafeFromPrimitive;
278///
279/// #[derive(Debug, Eq, PartialEq, UnsafeFromPrimitive)]
280/// #[repr(u8)]
281/// enum Number {
282///     Zero,
283///     One,
284/// }
285///
286/// fn main() {
287///     assert_eq!(
288///         Number::Zero,
289///         unsafe { Number::unchecked_transmute_from(0_u8) },
290///     );
291///     assert_eq!(
292///         Number::One,
293///         unsafe { Number::unchecked_transmute_from(1_u8) },
294///     );
295/// }
296///
297/// unsafe fn undefined_behavior() {
298///     let _ = Number::unchecked_transmute_from(2); // 2 is not a valid discriminant!
299/// }
300/// ```
301#[proc_macro_derive(UnsafeFromPrimitive, attributes(num_enum))]
302pub fn derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream {
303    let enum_info = parse_macro_input!(stream as EnumInfo);
304    let krate = get_crate_path(enum_info.crate_path);
305
306    let EnumInfo {
307        ref name, ref repr, ..
308    } = enum_info;
309
310    TokenStream::from(quote! {
311        impl #krate::UnsafeFromPrimitive for #name {
312            type Primitive = #repr;
313
314            unsafe fn unchecked_transmute_from(number: Self::Primitive) -> Self {
315                ::core::mem::transmute(number)
316            }
317        }
318    })
319}
320
321/// Implements `core::default::Default` for a `#[repr(Primitive)] enum`.
322///
323/// Whichever variant has the `#[default]` or `#[num_enum(default)]` attribute will be returned.
324/// ----------------------------------------------
325///
326/// ```rust
327/// #[derive(Debug, Eq, PartialEq, num_enum::Default)]
328/// #[repr(u8)]
329/// enum Number {
330///     Zero,
331///     #[default]
332///     One,
333/// }
334///
335/// assert_eq!(Number::One, Number::default());
336/// assert_eq!(Number::One, <Number as ::core::default::Default>::default());
337/// ```
338#[proc_macro_derive(Default, attributes(num_enum, default))]
339pub fn derive_default(stream: TokenStream) -> TokenStream {
340    let enum_info = parse_macro_input!(stream as EnumInfo);
341
342    let default_ident = match enum_info.default() {
343        Some(ident) => ident,
344        None => {
345            let span = Span::call_site();
346            let message =
347                "#[derive(num_enum::Default)] requires enum to be exhaustive, or a variant marked with `#[default]` or `#[num_enum(default)]`";
348            return syn::Error::new(span, message).to_compile_error().into();
349        }
350    };
351
352    let EnumInfo { ref name, .. } = enum_info;
353
354    TokenStream::from(quote! {
355        impl ::core::default::Default for #name {
356            #[inline]
357            fn default() -> Self {
358                Self::#default_ident
359            }
360        }
361    })
362}