use std::{collections::HashSet};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use rustc_hash::{FxBuildHasher};
use syn::{parse_macro_input, DeriveInput, Ident, Type};
#[proc_macro_derive(ConstEnum)]
pub fn derive_const_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
impl_const_enum(parse_macro_input!(input as DeriveInput)).unwrap_or_else(|error| error.into_compile_error()).into()
}
thread_local! {
static PRIMITIVE_TYPES: HashSet<Ident, FxBuildHasher> = {
let type_tokens = [
quote! { u8 },
quote! { u16 },
quote! { u32 },
quote! { u64 },
quote! { u128 },
quote! { usize },
quote! { i8 },
quote! { i16 },
quote! { i32 },
quote! { i64 },
quote! { i128 },
quote! { isize },
];
HashSet::from_iter(type_tokens.into_iter().map(|tokens| syn::parse2(tokens).unwrap()))
};
}
fn impl_const_enum(item: DeriveInput) -> Result<TokenStream, syn::Error> {
let syn::Data::Enum(enum_data) = item.data else { return Err(syn::Error::new_spanned(item, "ConstEnum: expected an enum")); };
let representation: Option<Type> = item.attrs.iter().find_map(|attribute| {
attribute.meta
.path()
.get_ident()
.and_then(|ident| {
(ident == "repr")
.then(|| attribute.parse_args().ok())
.flatten()
})
});
let enum_name: Ident = item.ident;
let variant_count: usize = enum_data.variants.len();
let variant_idents = enum_data.variants.iter().map(|variant| &variant.ident);
let variant_names = enum_data.variants.iter().map(|variant| variant.ident.to_string());
let enum_discriminant: Option<Type> = representation.and_then(|ty| {
let Type::Path(type_path) = &ty else { return None; };
if !type_path.qself.is_none() { return None; };
let ident: &Ident = type_path.path.get_ident()?;
PRIMITIVE_TYPES.with(|primitives| primitives.contains(ident)).then_some(ty)
});
let crate_name: Ident = Ident::new("cenum_utils", Span::call_site());
let count_impl: TokenStream = quote! {
impl ::#crate_name::EnumCount for #enum_name {
const COUNT: usize = #variant_count;
}
};
let names_impl: TokenStream = quote! {
impl ::#crate_name::EnumNames for #enum_name {
const NAMES: &[&str] = &[#(#variant_names),*];
}
};
let discriminants_impl: Option<TokenStream> = enum_discriminant.map(|discriminant| {
quote! {
impl ::#crate_name::EnumDiscriminants for #enum_name {
type Discriminant = #discriminant;
const DISCRIMINANTS: &[Self::Discriminant] = &[#(Self::#variant_idents as #discriminant),*];
}
}
});
Ok(quote! { #count_impl #names_impl #discriminants_impl })
}