//! `discriminant_derive`: provides `#[derive(Discriminant)]`.

use proc_macro::TokenStream;
use proc_macro::TokenStream as TokenStream2;
use quote::quote;
use syn::{Data, DeriveInput, Ident, Meta, Token, punctuated::Punctuated};

6#[proc_macro_derive(Discriminant)]
10pub fn discriminant_derive(t: TokenStream) -> TokenStream {
11 let ty = TokenStream2::from(t);
12 let ast = syn::parse(ty).unwrap();
13 let repr_type = find_repr_type(&ast).unwrap();
14
15 ensure_enum_valid(&ast);
16
17 impl_discriminant_macro(&ast, &repr_type)
18}
19
20fn impl_discriminant_macro(ast: &DeriveInput, repr_type: &Ident) -> TokenStream {
21 let name = &ast.ident;
22 let imp = quote! {
23 impl Discriminant<#repr_type> for #name {
24 fn discriminant(&self) -> #repr_type {
25 unsafe { *<*const #name>::from(self).cast::<#repr_type>() }
27 }
28 }
29 };
30 imp.into()
31}
32
33fn ensure_enum_valid(ast: &DeriveInput) {
34 if let Data::Enum(data) = &ast.data {
35 if data.variants.is_empty() == false {
36 return;
37 }
38
39 panic!("Can't derive PrimitiveRepr on a zero variant enum");
40 }
41
42 panic!("Discriminant can only be derived for enums");
43}
44
45fn find_repr_type(ast: &DeriveInput) -> Option<Ident> {
46 for meta in ast
47 .attrs
48 .iter()
49 .filter(|attr| attr.path().is_ident("repr"))
50 .filter_map(|attr| {
51 attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
52 .ok()
53 })
54 .flatten()
55 {
56 if let Meta::Path(path) = meta {
57 if let Some(ident) = path.get_ident() {
58 return Some(ident.clone());
59 }
60 }
61 }
62
63 None
64}