use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, parse_quote, AttrStyle};
/// Evaluates a `Result<T, syn::Error>` and yields the `Ok` value.
///
/// On `Err`, the enclosing proc-macro function returns immediately with the
/// error rendered via `to_compile_error()` and converted with `.into()`, so
/// the failure surfaces as a compiler diagnostic instead of a panic.
macro_rules! compile_error_unless_ok {
    ($fallible:expr) => {
        match $fallible {
            Err(err) => return err.to_compile_error().into(),
            Ok(inner) => inner,
        }
    };
}
/// Attribute macro that adds `#[repr(...)]` to an enum and generates inherent
/// `from_discriminant` and `discriminant` methods for it.
///
/// The attribute arguments are forwarded verbatim into `#[repr(...)]`. The
/// first primitive integer type among them (e.g. `u8` in
/// `#[discriminant(C, u8)]`) becomes the discriminant type of both methods.
/// Invalid arguments produce a compile error rather than a panic.
#[proc_macro_attribute]
pub fn discriminant(arguments: TokenStream, item: TokenStream) -> TokenStream {
    let enum_item = parse_macro_input!(item as syn::ItemEnum);
    let enum_name = &enum_item.ident;
    // Forward generic parameters and the where-clause so generic enums get a
    // well-formed `impl` block. All three are empty for non-generic enums.
    let (impl_generics, type_generics, where_clause) = enum_item.generics.split_for_impl();
    let arguments: TokenStream2 = arguments.into();
    let repr_type = compile_error_unless_ok!(get_repr_type(arguments.clone()));
    let from_discriminant_code = generate_from_discriminant_function(&repr_type, &enum_item);
    let discriminant_code = generate_discriminant_function(&repr_type);
    quote! {
        #[repr(#arguments)]
        #enum_item
        impl #impl_generics #enum_name #type_generics #where_clause {
            #from_discriminant_code
            #discriminant_code
        }
    }
    .into()
}
/// Derives `IntoDiscriminant` for an enum.
///
/// The enum must carry a `#[repr(...)]` (or `#[discriminant(...)]`) attribute
/// naming a primitive integer type. That type becomes `DiscriminantType`, and
/// the generated `discriminant(&self)` reads the value from the enum's memory.
/// A missing or invalid representation produces a compile error.
#[proc_macro_derive(IntoDiscriminant)]
pub fn derive_into_discriminant(item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as syn::DeriveInput);
    let enum_name = &input.ident;
    // Forward generics so the trait impl is well-formed for generic enums.
    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
    let repr_args = compile_error_unless_ok!(get_repr_args("IntoDiscriminant", &input));
    let repr_type = compile_error_unless_ok!(get_repr_type(repr_args));
    let discriminant_code = generate_discriminant_function(&repr_type);
    quote! {
        impl #impl_generics IntoDiscriminant for #enum_name #type_generics #where_clause {
            type DiscriminantType = #repr_type;
            #discriminant_code
        }
    }
    .into()
}
/// Derives `FromDiscriminant` for an enum.
///
/// The enum must carry a `#[repr(...)]` (or `#[discriminant(...)]`) attribute
/// naming a primitive integer type, which becomes `DiscriminantType`. The
/// generated `from_discriminant` returns the matching unit variant, or `None`.
#[proc_macro_derive(FromDiscriminant)]
pub fn derive_from_discriminant(item: TokenStream) -> TokenStream {
    // The same tokens are parsed twice: as a `DeriveInput` for the shared
    // attribute lookup in `get_repr_args`, and as an `ItemEnum` for the variant
    // walk. Parsing as `ItemEnum` also rejects non-enum items.
    let cloned_item = item.clone();
    let input = parse_macro_input!(item as syn::DeriveInput);
    let enum_item = parse_macro_input!(cloned_item as syn::ItemEnum);
    let enum_name = &enum_item.ident;
    // Forward generics so the trait impl is well-formed for generic enums.
    let (impl_generics, type_generics, where_clause) = enum_item.generics.split_for_impl();
    let repr_args = compile_error_unless_ok!(get_repr_args("FromDiscriminant", &input));
    let repr_type = compile_error_unless_ok!(get_repr_type(repr_args));
    let from_discriminant_code = generate_from_discriminant_function(&repr_type, &enum_item);
    quote! {
        impl #impl_generics FromDiscriminant for #enum_name #type_generics #where_clause {
            type DiscriminantType = #repr_type;
            #from_discriminant_code
        }
    }
    .into()
}
/// Picks the discriminant type out of `repr`-style arguments.
///
/// Scans the top-level tokens of `arguments` (e.g. `C, u8`) and returns the
/// first identifier that names a primitive integer type, as a `syn::Path`.
/// Nested groups such as `align(4)` are not inspected.
///
/// # Errors
/// Returns a `syn::Error` spanning `arguments` when no primitive integer type
/// is present.
fn get_repr_type(arguments: TokenStream2) -> Result<syn::Path, syn::Error> {
    const ALLOWED_TYPES: [&str; 12] = [
        "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize",
    ];
    arguments
        .clone()
        .into_iter()
        .find_map(|token_tree| match token_tree {
            // Build the path directly from the user's identifier. This keeps its
            // span for diagnostics and avoids a string re-parse plus `unwrap`.
            proc_macro2::TokenTree::Ident(ident)
                if ALLOWED_TYPES.contains(&ident.to_string().as_str()) =>
            {
                Some(syn::Path::from(ident))
            }
            _ => None,
        })
        .ok_or_else(|| {
            syn::Error::new_spanned(
                arguments,
                "Valid enum representation type expected as argument to the discriminant \
                 macro, e.g., #[discriminant(u8)]",
            )
        })
}
fn get_repr_args(macro_name: &str, input: &syn::DeriveInput) -> Result<TokenStream2, syn::Error> {
let x = input
.attrs
.iter()
.filter(|attr| matches!(attr.style, AttrStyle::Outer))
.filter(|attr| {
let path = attr.path();
path.is_ident("repr") || path.is_ident("discriminant")
})
.filter_map(|attr| attr.meta.require_list().ok())
.next()
.ok_or_else(|| {
syn::Error::new_spanned(
input,
format!(
"When deriving {} on an enum, you also need to specify \
representation type with #[repr()] or #[discriminant()]",
macro_name
),
)
})?;
Ok(x.tokens.clone())
}
/// Returns the identifiers of the enum's unit variants, paired index-wise with
/// an expression evaluating to each one's discriminant.
///
/// Implicit discriminants follow rustc's rule: one more than the previous
/// variant's, starting at 0. The running count includes every variant, even
/// those with fields that are left out of the result. Otherwise unit variants
/// after a data variant would get the wrong value. Implicit values are written
/// as `(last_explicit) + k` (or just `k`). The parentheses guard against
/// operator precedence in the explicit expression, e.g. `1 << 3`. This form
/// also avoids nesting the expressions deeper with every variant.
fn enum_unit_variants(enum_item: &syn::ItemEnum) -> (Vec<proc_macro2::Ident>, Vec<syn::Expr>) {
    // Most recent explicit discriminant expression, if any, and how many
    // variants the current one is past it (0 for the variant carrying it).
    let mut base: Option<&syn::Expr> = None;
    let mut offset: usize = 0;
    let mut names = Vec::new();
    let mut values = Vec::new();

    for variant in &enum_item.variants {
        let value: syn::Expr = if let Some((_, explicit)) = &variant.discriminant {
            base = Some(explicit);
            offset = 0;
            explicit.clone()
        } else {
            let literal = proc_macro2::Literal::usize_unsuffixed(offset);
            match base {
                Some(previous) => parse_quote!( (#previous) + #literal ),
                None => parse_quote!( #literal ),
            }
        };
        offset += 1;

        if matches!(variant.fields, syn::Fields::Unit) {
            names.push(variant.ident.clone());
            values.push(value);
        }
    }

    (names, values)
}
/// Generates `fn from_discriminant(discriminant: <repr_type>) -> Option<Self>`.
///
/// The generated function maps each unit variant's discriminant value (as
/// computed by `enum_unit_variants`) to that variant, and returns `None` for
/// any other value. Variants with fields are never produced. Arms use guards
/// (`d if d == <expr>`) rather than patterns because the discriminants are
/// arbitrary `syn::Expr`s, not necessarily literals.
fn generate_from_discriminant_function(
    repr_type: &syn::Path,
    enum_item: &syn::ItemEnum,
) -> TokenStream2 {
    let (variant_names, discriminants) = enum_unit_variants(enum_item);
    let enum_name = &enum_item.ident;
    // `Option`/`Some`/`None` are written as absolute paths. The generated tokens
    // resolve at the macro call site, where a glob-imported variant named `None`
    // (for example) would otherwise shadow the prelude names.
    quote! {
        fn from_discriminant(discriminant: #repr_type) -> ::core::option::Option<Self> {
            match discriminant {
                #( discriminant if discriminant == #discriminants =>
                    ::core::option::Option::Some(#enum_name::#variant_names), )*
                _ => ::core::option::Option::None,
            }
        }
    }
}
/// Generates `fn discriminant(&self) -> <repr_type>`, which reads the enum's
/// discriminant straight from memory through a pointer cast.
fn generate_discriminant_function(repr_type: &syn::Path) -> TokenStream2 {
    // Reinterpret `&self` as a pointer to the repr integer and dereference it.
    // NOTE(review): this is only sound when the enum has a primitive
    // representation (`#[repr(u8)]`, `#[repr(C, u8)]`, ...). For those, the Rust
    // reference guarantees the discriminant is stored first. Callers in this
    // file derive `repr_type` from those repr arguments.
    let read_discriminant = quote! {
        *<*const _>::from(self).cast::<#repr_type>()
    };
    quote! {
        fn discriminant(&self) -> #repr_type {
            unsafe {
                #read_discriminant
            }
        }
    }
}