repr_discriminant_derive/
lib.rs

//! Derive macro to implement a discriminant method for enums with a specific representation type.
2
use proc_macro::TokenStream;
use quote::{ToTokens, quote};
use syn::{
    Attribute, Data, DeriveInput, Ident, ImplGenerics, Meta, Token, Type, TypeGenerics, TypePath,
    WhereClause, parse_macro_input, punctuated::Punctuated,
};
8
/// Primitive integer types accepted as `T` in `#[repr(T)]`.
///
/// The order matters only for diagnostics: the list is printed verbatim in the
/// "can only be used with the following types" error.
const SUPPORTED_TYPES: &[&str] = &[
    // Signed integers.
    "i8", "i16", "i32", "i64", "i128", "isize",
    // Unsigned integers.
    "u8", "u16", "u32", "u64", "u128", "usize",
];
12
13/// Attribute macro to implement a discriminant method for enums with a specific representation type.
14///
15/// # Panics
16///
17/// This macro will panic if the input type is not an enum with a valid `#[repr(T)]`.
18#[proc_macro_derive(ReprDiscriminant)]
19pub fn repr_discriminant(input: TokenStream) -> TokenStream {
20    let input = parse_macro_input!(input as DeriveInput);
21
22    let Data::Enum(_) = input.data else {
23        unimplemented!("`ReprDiscriminant` can only be derived for enums")
24    };
25
26    let repr_type: Type = input
27        .attrs
28        .iter()
29        .filter(|attr| attr.path().is_ident("repr"))
30        .find_map(|attr| attr.parse_args().ok())
31        .expect("`#[repr(T)]` is required");
32
33    assert!(
34        SUPPORTED_TYPES.contains(&repr_type.to_token_stream().to_string().as_str()),
35        "`ReprDiscriminant` can only be used with the following types: {SUPPORTED_TYPES:?}"
36    );
37
38    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
39    impl_all(
40        &impl_generics,
41        &input.ident,
42        &ty_generics,
43        where_clause,
44        &repr_type,
45    )
46    .into()
47}
48
49fn impl_all(
50    impl_generics: &ImplGenerics<'_>,
51    name: &Ident,
52    ty_generics: &TypeGenerics<'_>,
53    where_clause: Option<&WhereClause>,
54    repr_type: &Type,
55) -> proc_macro2::TokenStream {
56    let const_impl = impl_const(impl_generics, name, ty_generics, where_clause, repr_type);
57    let trait_impl = impl_trait(impl_generics, name, ty_generics, where_clause, repr_type);
58
59    quote! {
60        #const_impl
61
62        #trait_impl
63    }
64}
65
66fn impl_const(
67    impl_generics: &ImplGenerics<'_>,
68    name: &Ident,
69    ty_generics: &TypeGenerics<'_>,
70    where_clause: Option<&WhereClause>,
71    repr_type: &Type,
72) -> proc_macro2::TokenStream {
73    quote! {
74        impl #impl_generics #name #ty_generics #where_clause {
75            /// Returns the discriminant value of the enum.
76            ///
77            /// # Safety
78            ///
79            /// This method is safe, because the macro guarantees that the enum is repr(T).
80            pub const fn discriminant(&self) -> #repr_type {
81                #[allow(unsafe_code)]
82                unsafe {
83                    *::core::ptr::from_ref(self)
84                        .cast::<#repr_type>()
85                }
86            }
87        }
88    }
89}
90
91fn impl_trait(
92    impl_generics: &ImplGenerics<'_>,
93    name: &Ident,
94    ty_generics: &TypeGenerics<'_>,
95    where_clause: Option<&WhereClause>,
96    repr_type: &Type,
97) -> proc_macro2::TokenStream {
98    quote! {
99        #[expect(unsafe_code)]
100        unsafe impl #impl_generics ::repr_discriminant::ReprDiscriminant for #name #ty_generics #where_clause {
101            type Repr = #repr_type;
102
103            fn repr_discriminant(&self) -> Self::Repr {
104                self.discriminant()
105            }
106        }
107    }
108}