cenum_derive/
lib.rs

1extern crate proc_macro;
2use crate::proc_macro::TokenStream;
3use quote::quote;
4use syn::*;
5
6#[proc_macro_attribute]
7pub fn cenum(_metadata: TokenStream, input: TokenStream) -> TokenStream {
8    let ast = parse_macro_input!(input as DeriveInput);
9    impl_cenum(&ast)
10}
11
12fn impl_cenum(ast: &DeriveInput) -> TokenStream {
13    let name = &ast.ident;
14    let variants = match &ast.data {
15        Data::Enum(DataEnum { variants, .. }) => variants.into_iter().collect::<Vec<&Variant>>(),
16        _ => panic!("not deriving cenum on an enum"),
17    };
18    if variants
19        .iter()
20        .any(|variant| variant.fields != Fields::Unit)
21    {
22        panic!("cannot have cenum trait on enums with fields")
23    }
24    let mut pairs: Vec<(String, i64)> = vec![];
25    let mut current_discriminant = 0;
26    let mut first_variant = true;
27    for variant in &variants {
28        let is_first_variant = first_variant;
29        first_variant = false;
30        let discriminant = match &variant.discriminant {
31            Some((
32                _,
33                Expr::Lit(ExprLit {
34                    lit: Lit::Int(lit_int),
35                    ..
36                }),
37            )) => {
38                let discriminant = lit_int.base10_parse::<i64>().unwrap();
39                if !is_first_variant && discriminant < current_discriminant {
40                    panic!("attempted to reuse discriminant");
41                }
42                current_discriminant = discriminant + 1;
43                discriminant
44            }
45            Some((
46                _,
47                Expr::Unary(ExprUnary {
48                    op: UnOp::Neg(_),
49                    expr,
50                    ..
51                })
52            )) => {
53                match &**expr {
54                    Expr::Lit(ExprLit {
55                        lit: Lit::Int(lit_int),
56                        ..
57                    }) => {
58                        let discriminant = -lit_int.base10_parse::<i64>().unwrap();
59                        if !is_first_variant && discriminant < current_discriminant {
60                            panic!("attempted to reuse discriminant");
61                        }
62                        current_discriminant = discriminant + 1;
63                        discriminant
64                    },
65                    _ => panic!("expected integer literal as discriminant")
66                }
67            },
68            Some(_) => panic!("expected integer literal as discriminant"),
69            None => {
70                if is_first_variant {
71                    current_discriminant = 0;
72                }
73                let discriminant = current_discriminant;
74                current_discriminant += 1;
75                discriminant
76            }
77        };
78        pairs.push((variant.ident.to_string(), discriminant));
79    }
80
81    let pairs_formatted = format!(
82        "[{}]",
83        pairs
84            .iter()
85            .map({ |(key, value)| format!("({}::{}, {})", name.to_string(), key, value) })
86            .collect::<Vec<String>>()
87            .join(", ")
88    );
89    let pairs_parsed: ExprArray = parse_str(&pairs_formatted).unwrap();
90
91    let data_name = Ident::new(
92        &format!("__{}_data", name.to_string()).to_uppercase(),
93        name.span(),
94    );
95    let cache_name = Ident::new(
96        &format!("__{}_cache", name.to_string()).to_uppercase(),
97        name.span(),
98    );
99    let icache_name = Ident::new(
100        &format!("__{}_icache", name.to_string()).to_uppercase(),
101        name.span(),
102    );
103    let get_cache_name = Ident::new(&format!("__{}_get_cache", name.to_string()), name.span());
104    let get_icache_name = Ident::new(&format!("__{}_get_icache", name.to_string()), name.span());
105
106    let gen = quote! {
107
108        #[derive(PartialEq, Eq, Hash, Clone, Debug)]
109        #ast
110
111        static #data_name: &'static [(#name, i64)] = &#pairs_parsed;
112        static #cache_name: ::std::sync::atomic::AtomicPtr<::std::collections::HashMap<#name, i64>> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut());
113        static #icache_name: ::std::sync::atomic::AtomicPtr<::std::collections::HashMap<i64, #name>> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut());
114
115        #[allow(non_snake_case)]
116        fn #get_cache_name() -> &'static ::std::collections::HashMap<#name, i64> {
117            unsafe {
118                if #cache_name.load(::std::sync::atomic::Ordering::Relaxed).is_null() {
119                    let mut map_built = Box::new(::std::collections::HashMap::new());
120                    for (key, value) in #data_name {
121                        map_built.insert(key.clone(), *value);
122                    }
123                    let map_built = Box::into_raw(map_built);
124                    if !#cache_name.compare_and_swap(::std::ptr::null_mut(), map_built, ::std::sync::atomic::Ordering::Relaxed).is_null() {
125                        drop(Box::from_raw(map_built));
126                    }
127                }
128                return #cache_name.load(::std::sync::atomic::Ordering::Relaxed).as_ref().unwrap();
129            }
130        }
131
132        #[allow(non_snake_case)]
133        fn #get_icache_name() -> &'static ::std::collections::HashMap<i64, #name> {
134            unsafe {
135                if #icache_name.load(::std::sync::atomic::Ordering::Relaxed).is_null() {
136                    let mut map_built = Box::new(::std::collections::HashMap::new());
137                    for (key, value) in #data_name {
138                        map_built.insert(*value, key.clone());
139                    }
140                    let map_built = Box::into_raw(map_built);
141                    if !#icache_name.compare_and_swap(::std::ptr::null_mut(), map_built, ::std::sync::atomic::Ordering::Relaxed).is_null() {
142                        drop(Box::from_raw(map_built));
143                    }
144                }
145                return #icache_name.load(::std::sync::atomic::Ordering::Relaxed).as_ref().unwrap();
146            }
147        }
148
149        impl Cenum for #name {
150            fn to_primitive(&self) -> i64 {
151                return *#get_cache_name().get(self).unwrap();
152            }
153
154            fn from_primitive(value: i64) -> #name {
155                return #get_icache_name().get(&value).unwrap().clone();
156            }
157
158            fn is_discriminant(value: i64) -> bool {
159                return #get_icache_name().get(&value).is_some();
160            }
161        }
162
163        impl ::cenum::num::ToPrimitive for #name {
164            fn to_i64(&self) -> Option<i64> {
165                Some(self.to_primitive() as i64)
166            }
167
168            fn to_u64(&self) -> Option<u64> {
169                Some(self.to_primitive() as u64)
170            }
171        }
172
173
174    };
175    gen.into()
176}