use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Expr, Lit, PatLit, parse_macro_input};
#[proc_macro_derive(Enum)]
pub fn derive_enum(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;
let enum_array_name = syn::Ident::new(&format!("{}Array", name), name.span());
let data = match input.data {
Data::Enum(e) => e,
_ => panic!("Enum derive only works on enums"),
};
let mut has_repr = false;
for attr in input.attrs {
if attr.path().is_ident("repr") {
has_repr = true;
}
}
if !has_repr {
panic!("Enum must have repr(u*) or repr(i*) or repr(C)");
}
let mut variants = Vec::new();
let mut expected = 0usize;
for v in data.variants.iter() {
let ident = &v.ident;
let value = match &v.discriminant {
Some((_, expr)) => match expr {
Expr::Lit(PatLit {
lit: Lit::Int(int), ..
}) => int.base10_parse::<usize>().unwrap(),
_ => panic!("Enum values must be integer literals"),
},
None => expected,
};
if value != expected {
panic!("Enum variants must be contiguous starting at 0");
}
variants.push(ident.clone());
expected += 1;
}
let count = variants.len();
let expanded = quote! {
impl #name {
pub const VALUES: [#name; #count] = [
#( #name::#variants ),*
];
const COUNT: usize = #count;
fn to_index(self) -> usize {
self as usize
}
fn as_str(self) -> &'static str {
match self {
#( #name::#variants => stringify!(#variants), )*
}
}
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct #enum_array_name<T> {
pub data: [T; #count],
}
impl<T> #enum_array_name<T> {
pub fn new(init: T) -> Self
where
T: Copy,
{
Self {
data: [init; #count],
}
}
pub fn from_array(data: [T; #count]) -> Self {
Self { data }
}
pub fn init_with<F: FnMut(#name) -> T>(mut f: F) -> Self {
let mut data: [::core::mem::MaybeUninit<T>; #count] =
unsafe { ::core::mem::MaybeUninit::<[::core::mem::MaybeUninit<T>; #count]>::uninit().assume_init() };
for (i, v) in #name::VALUES.iter().enumerate() {
data[i].write(f(*v));
}
let data = unsafe {
::core::ptr::read(&data as *const _ as *const [T; #count])
};
Self { data }
}
pub fn try_init_with<F: FnMut(#name) -> Result<T, E>, E>(mut f: F) -> Result<Self, E> {
let mut data: [::core::mem::MaybeUninit<Option<T>>; #count] =
unsafe { ::core::mem::MaybeUninit::<[::core::mem::MaybeUninit<Option<T>>; #count]>::uninit().assume_init() };
for i in 0..#count {
data[i].write(None);
}
let mut data: [Option<T>; #count] = unsafe {
::core::ptr::read(&data as *const _ as *const [Option<T>; #count])
};
for (i, v) in #name::VALUES.iter().enumerate() {
data[i] = Some(f(*v)?);
}
let mut unwrapped_data: [::core::mem::MaybeUninit<T>; #count] =
unsafe { ::core::mem::MaybeUninit::<[::core::mem::MaybeUninit<T>; #count]>::uninit().assume_init() };
for i in 0..#count {
unwrapped_data[i].write(data[i].take().unwrap());
}
Ok(Self {
data: unsafe { ::core::ptr::read(&unwrapped_data as *const _ as *const [T; #count])}
})
}
pub fn get(&self, key: #name) -> &T {
&self.data[key.to_index()]
}
pub fn get_mut(&mut self, key: #name) -> &mut T {
&mut self.data[key.to_index()]
}
}
impl <T> ::core::ops::Index<#name> for #enum_array_name<T> {
type Output = T;
fn index(&self, key: #name) -> &Self::Output {
&self.data[key.to_index()]
}
}
impl <T> ::core::ops::IndexMut<#name> for #enum_array_name<T> {
fn index_mut(&mut self, key: #name) -> &mut Self::Output {
&mut self.data[key.to_index()]
}
}
impl<T: ::core::fmt::Debug> ::core::fmt::Debug for #enum_array_name<T> {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
let mut list = f.debug_list();
#(
list.entry(&format_args!(
"{}: {:?}",
stringify!(#variants),
&self.data[#name::#variants as usize]
));
)*
list.finish()
}
}
impl<T: Default> Default for #enum_array_name<T> {
fn default() -> Self {
Self::init_with(|_| T::default())
}
}
};
TokenStream::from(expanded)
}