enum_arr 0.1.2

A crate for enum-indexed arrays, inspired by the Odin programming language.
Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Expr, Lit, PatLit, parse_macro_input};

#[proc_macro_derive(Enum)]
pub fn derive_enum(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    let enum_array_name = syn::Ident::new(&format!("{}Array", name), name.span());

    let data = match input.data {
        Data::Enum(e) => e,
        _ => panic!("Enum derive only works on enums"),
    };

    let mut has_repr = false;
    for attr in input.attrs {
        if attr.path().is_ident("repr") {
            has_repr = true;
        }
    }
    if !has_repr {
        panic!("Enum must have repr(u*) or repr(i*) or repr(C)");
    }

    let mut variants = Vec::new();
    let mut expected = 0usize;

    for v in data.variants.iter() {
        let ident = &v.ident;

        let value = match &v.discriminant {
            Some((_, expr)) => match expr {
                Expr::Lit(PatLit {
                    lit: Lit::Int(int), ..
                }) => int.base10_parse::<usize>().unwrap(),
                _ => panic!("Enum values must be integer literals"),
            },
            None => expected,
        };

        if value != expected {
            panic!("Enum variants must be contiguous starting at 0");
        }

        variants.push(ident.clone());
        expected += 1;
    }

    let count = variants.len();

    let expanded = quote! {
        impl #name {
            pub const VALUES: [#name; #count] = [
                #( #name::#variants ),*
            ];
            const COUNT: usize = #count;

            fn to_index(self) -> usize {
                self as usize
            }

            fn as_str(self) -> &'static str {
                match self {
                    #( #name::#variants => stringify!(#variants), )*
                }
            }
        }

        #[derive(Clone, Copy, PartialEq, Eq, Hash)]
        pub struct #enum_array_name<T> {
            pub data: [T; #count],
        }

        impl<T> #enum_array_name<T> {
            pub fn new(init: T) -> Self
            where
                T: Copy,
            {
                Self {
                    data: [init; #count],
                }
            }

            pub fn from_array(data: [T; #count]) -> Self {
                Self { data }
            }

            pub fn init_with<F: FnMut(#name) -> T>(mut f: F) -> Self {
                let mut data: [::core::mem::MaybeUninit<T>; #count] =
                unsafe { ::core::mem::MaybeUninit::<[::core::mem::MaybeUninit<T>; #count]>::uninit().assume_init() };

                for (i, v) in #name::VALUES.iter().enumerate() {
                    data[i].write(f(*v));
                }

                let data = unsafe {
                    ::core::ptr::read(&data as *const _ as *const [T; #count])
                };

                Self { data }
            }

            pub fn try_init_with<F: FnMut(#name) -> Result<T, E>, E>(mut f: F) -> Result<Self, E> {
                let mut data: [::core::mem::MaybeUninit<Option<T>>; #count] =
                unsafe { ::core::mem::MaybeUninit::<[::core::mem::MaybeUninit<Option<T>>; #count]>::uninit().assume_init() };

                // set all data to None first:
                for i in 0..#count {
                    data[i].write(None);
                }
                let mut data: [Option<T>; #count] = unsafe {
                    ::core::ptr::read(&data as *const _ as *const [Option<T>; #count])
                };

                // try initialize all values:
                for (i, v) in #name::VALUES.iter().enumerate() {
                    data[i] = Some(f(*v)?);
                }

                // if all successfully initialized, unwrap all values:
                let mut unwrapped_data: [::core::mem::MaybeUninit<T>; #count] =
                unsafe { ::core::mem::MaybeUninit::<[::core::mem::MaybeUninit<T>; #count]>::uninit().assume_init() };
                for i in 0..#count {
                    unwrapped_data[i].write(data[i].take().unwrap());
                }

                Ok(Self {
                    data: unsafe { ::core::ptr::read(&unwrapped_data as *const _ as *const [T; #count])}
                })
            }

            pub fn get(&self, key: #name) -> &T {
                &self.data[key.to_index()]
            }

            pub fn get_mut(&mut self, key: #name) -> &mut T {
                &mut self.data[key.to_index()]
            }
        }

        impl <T> ::core::ops::Index<#name> for #enum_array_name<T> {
            type Output = T;

            fn index(&self, key: #name) -> &Self::Output {
                &self.data[key.to_index()]
            }
        }

        impl <T> ::core::ops::IndexMut<#name> for #enum_array_name<T> {
            fn index_mut(&mut self, key: #name) -> &mut Self::Output {
                &mut self.data[key.to_index()]
            }
        }

        impl<T: ::core::fmt::Debug> ::core::fmt::Debug for #enum_array_name<T> {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                let mut list = f.debug_list();
                #(
                    list.entry(&format_args!(
                        "{}: {:?}",
                        stringify!(#variants),
                        &self.data[#name::#variants as usize]
                    ));
                )*

                list.finish()
            }
        }

        impl<T: Default> Default for #enum_array_name<T> {
            fn default() -> Self {
                Self::init_with(|_| T::default())
            }
        }
    };

    TokenStream::from(expanded)
}