discriminant_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{Data, DeriveInput, Ident, Meta, Token, punctuated::Punctuated};
5
6/// Derive macro for `Discriminant<T>` trait
7///
8/// This macro will check whether if it is a enum and has `#[repr(...)]` attribute
9#[proc_macro_derive(Discriminant)]
10pub fn discriminant_derive(t: TokenStream) -> TokenStream {
11    let ty = TokenStream2::from(t);
12    let ast = syn::parse(ty).unwrap();
13    let repr_type = find_repr_type(&ast).unwrap();
14
15    ensure_enum_valid(&ast);
16
17    impl_discriminant_macro(&ast, &repr_type)
18}
19
20fn impl_discriminant_macro(ast: &DeriveInput, repr_type: &Ident) -> TokenStream {
21    let name = &ast.ident;
22    let imp = quote! {
23        impl Discriminant<#repr_type> for #name {
24            fn discriminant(&self) -> #repr_type {
25                // Should be safe here
26                unsafe { *<*const #name>::from(self).cast::<#repr_type>() }
27            }
28        }
29    };
30    imp.into()
31}
32
33fn ensure_enum_valid(ast: &DeriveInput) {
34    if let Data::Enum(data) = &ast.data {
35        if data.variants.is_empty() == false {
36            return;
37        }
38
39        panic!("Can't derive PrimitiveRepr on a zero variant enum");
40    }
41
42    panic!("Discriminant can only be derived for enums");
43}
44
45fn find_repr_type(ast: &DeriveInput) -> Option<Ident> {
46    for meta in ast
47        .attrs
48        .iter()
49        .filter(|attr| attr.path().is_ident("repr"))
50        .filter_map(|attr| {
51            attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
52                .ok()
53        })
54        .flatten()
55    {
56        if let Meta::Path(path) = meta {
57            if let Some(ident) = path.get_ident() {
58                return Some(ident.clone());
59            }
60        }
61    }
62
63    None
64}