cenum-derive 1.0.3

A Rust derive macro for C-style primitive enums.
Documentation
extern crate proc_macro;
use crate::proc_macro::TokenStream;
use quote::quote;
use syn::*;

/// Entry point of the `#[cenum]` attribute macro.
///
/// The attribute's own arguments (`_metadata`) are ignored. The annotated item is parsed as a
/// [`DeriveInput`] and handed to `impl_cenum`, whose output replaces the item. If parsing fails,
/// `parse_macro_input!` returns the corresponding compile error tokens instead.
#[proc_macro_attribute]
pub fn cenum(_metadata: TokenStream, input: TokenStream) -> TokenStream {
    impl_cenum(&parse_macro_input!(input as DeriveInput))
}

fn impl_cenum(ast: &DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let variants = match &ast.data {
        Data::Enum(DataEnum { variants, .. }) => variants.into_iter().collect::<Vec<&Variant>>(),
        _ => panic!("not deriving cenum on an enum"),
    };
    if variants
        .iter()
        .any(|variant| variant.fields != Fields::Unit)
    {
        panic!("cannot have cenum trait on enums with fields")
    }
    let mut pairs: Vec<(String, i64)> = vec![];
    let mut current_discriminant = 0;
    let mut first_variant = true;
    for variant in &variants {
        let is_first_variant = first_variant;
        first_variant = false;
        let discriminant = match &variant.discriminant {
            Some((
                _,
                Expr::Lit(ExprLit {
                    lit: Lit::Int(lit_int),
                    ..
                }),
            )) => {
                let discriminant = lit_int.base10_parse::<i64>().unwrap();
                if !is_first_variant && discriminant < current_discriminant {
                    panic!("attempted to reuse discriminant");
                }
                current_discriminant = discriminant + 1;
                discriminant
            }
            Some((
                _,
                Expr::Unary(ExprUnary {
                    op: UnOp::Neg(_),
                    expr,
                    ..
                })
            )) => {
                match &**expr {
                    Expr::Lit(ExprLit {
                        lit: Lit::Int(lit_int),
                        ..
                    }) => {
                        let discriminant = -lit_int.base10_parse::<i64>().unwrap();
                        if !is_first_variant && discriminant < current_discriminant {
                            panic!("attempted to reuse discriminant");
                        }
                        current_discriminant = discriminant + 1;
                        discriminant
                    },
                    _ => panic!("expected integer literal as discriminant")
                }
            },
            Some(_) => panic!("expected integer literal as discriminant"),
            None => {
                if is_first_variant {
                    current_discriminant = 0;
                }
                let discriminant = current_discriminant;
                current_discriminant += 1;
                discriminant
            }
        };
        pairs.push((variant.ident.to_string(), discriminant));
    }

    let pairs_formatted = format!(
        "[{}]",
        pairs
            .iter()
            .map({ |(key, value)| format!("({}::{}, {})", name.to_string(), key, value) })
            .collect::<Vec<String>>()
            .join(", ")
    );
    let pairs_parsed: ExprArray = parse_str(&pairs_formatted).unwrap();

    let data_name = Ident::new(
        &format!("__{}_data", name.to_string()).to_uppercase(),
        name.span(),
    );
    let cache_name = Ident::new(
        &format!("__{}_cache", name.to_string()).to_uppercase(),
        name.span(),
    );
    let icache_name = Ident::new(
        &format!("__{}_icache", name.to_string()).to_uppercase(),
        name.span(),
    );
    let get_cache_name = Ident::new(&format!("__{}_get_cache", name.to_string()), name.span());
    let get_icache_name = Ident::new(&format!("__{}_get_icache", name.to_string()), name.span());

    let gen = quote! {

        #[derive(PartialEq, Eq, Hash, Clone, Debug)]
        #ast

        static #data_name: &'static [(#name, i64)] = &#pairs_parsed;
        static #cache_name: ::std::sync::atomic::AtomicPtr<::std::collections::HashMap<#name, i64>> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut());
        static #icache_name: ::std::sync::atomic::AtomicPtr<::std::collections::HashMap<i64, #name>> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut());

        #[allow(non_snake_case)]
        fn #get_cache_name() -> &'static ::std::collections::HashMap<#name, i64> {
            unsafe {
                if #cache_name.load(::std::sync::atomic::Ordering::Relaxed).is_null() {
                    let mut map_built = Box::new(::std::collections::HashMap::new());
                    for (key, value) in #data_name {
                        map_built.insert(key.clone(), *value);
                    }
                    let map_built = Box::into_raw(map_built);
                    if !#cache_name.compare_and_swap(::std::ptr::null_mut(), map_built, ::std::sync::atomic::Ordering::Relaxed).is_null() {
                        drop(Box::from_raw(map_built));
                    }
                }
                return #cache_name.load(::std::sync::atomic::Ordering::Relaxed).as_ref().unwrap();
            }
        }

        #[allow(non_snake_case)]
        fn #get_icache_name() -> &'static ::std::collections::HashMap<i64, #name> {
            unsafe {
                if #icache_name.load(::std::sync::atomic::Ordering::Relaxed).is_null() {
                    let mut map_built = Box::new(::std::collections::HashMap::new());
                    for (key, value) in #data_name {
                        map_built.insert(*value, key.clone());
                    }
                    let map_built = Box::into_raw(map_built);
                    if !#icache_name.compare_and_swap(::std::ptr::null_mut(), map_built, ::std::sync::atomic::Ordering::Relaxed).is_null() {
                        drop(Box::from_raw(map_built));
                    }
                }
                return #icache_name.load(::std::sync::atomic::Ordering::Relaxed).as_ref().unwrap();
            }
        }

        impl Cenum for #name {
            fn to_primitive(&self) -> i64 {
                return *#get_cache_name().get(self).unwrap();
            }

            fn from_primitive(value: i64) -> #name {
                return #get_icache_name().get(&value).unwrap().clone();
            }

            fn is_discriminant(value: i64) -> bool {
                return #get_icache_name().get(&value).is_some();
            }
        }

        impl ::cenum::num::ToPrimitive for #name {
            fn to_i64(&self) -> Option<i64> {
                Some(self.to_primitive() as i64)
            }

            fn to_u64(&self) -> Option<u64> {
                Some(self.to_primitive() as u64)
            }
        }


    };
    gen.into()
}