Skip to main content

fnum_derive/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, ItemEnum};
5
6#[proc_macro_derive(Fnum)]
7pub fn derive_fnum(input: TokenStream) -> TokenStream {
8    let item = parse_macro_input!(input as ItemEnum);
9    let enum_name = item.ident;
10    let variants = item.variants;
11
12    let variant_idx_arms = variants.iter().enumerate().map(|(i, variant)| {
13        let ident = &variant.ident;
14        match &variant.fields {
15            syn::Fields::Named(_) => {
16                quote! {
17                    #enum_name::#ident{..} => {#i}
18                }
19            }
20            syn::Fields::Unnamed(fields) => {
21                let fs = fields.unnamed.iter().map(|_| {
22                    quote! {_}
23                }).collect::<Vec<_>>();
24                quote! {
25                    #enum_name::#ident(#(#fs),*) => {#i}
26                }
27            }
28            syn::Fields::Unit => {
29                quote! {
30                    #enum_name::#ident => {#i}
31                }
32            }
33        }
34    }).collect::<Vec<_>>();
35
36    let uninit_variant_arms = variants.iter().enumerate().map(|(i, variant)| {
37        let ident = &variant.ident;
38        match &variant.fields {
39            syn::Fields::Named(fields) => {
40                let inits = fields.named.iter().map(|field| {
41                    let name = field.ident.clone().unwrap(); // unwrap daijoubu?
42                    quote! {
43                        #name: ::std::mem::MaybeUninit::uninit().assume_init()
44                    }
45                }).collect::<Vec<_>>();
46                quote! {
47                    #i => #enum_name::#ident{#(#inits),*}
48                }
49            }
50            syn::Fields::Unnamed(fields) => {
51                let inits = fields.unnamed.iter().map(|_| quote! { ::std::mem::MaybeUninit::uninit().assume_init() }).collect::<Vec<_>>();
52                quote! {
53                    #i => #enum_name::#ident(#(#inits),*)
54                }
55            }
56            syn::Fields::Unit => {
57                quote! {
58                    #i => #enum_name::#ident
59                }
60            }
61        }
62    }).collect::<Vec<_>>();
63
64    let make_table = variants.iter().enumerate().map(|(i, variant)| {
65        let ident = &variant.ident;
66        let arm = match &variant.fields {
67            syn::Fields::Named(fields) => {
68                let field_idents = fields.named.iter().map(|field| field.ident.clone().unwrap()).collect::<Vec<_>>();
69                let pointers = field_idents.iter().map(|i| quote! {right_pointer(#i)}).collect::<Vec<_>>();
70                quote! {
71                    #enum_name::#ident{#(#field_idents),*} => {[#(#pointers),*].iter().max().unwrap() - pointer(&e)}
72                }
73            }
74            syn::Fields::Unnamed(fields) => {
75                let field_idents = fields.unnamed.iter().enumerate().map(|(i, _)| {
76                    quote::format_ident!("field{}", i)
77                }).collect::<Vec<_>>();
78                let pointers = field_idents.iter().map(|i| quote! {right_pointer(#i)}).collect::<Vec<_>>();
79                quote! {
80                    #enum_name::#ident(#(#field_idents),*) => {[#(#pointers),*].iter().max().unwrap() - pointer(&e)}
81                }
82            }
83            syn::Fields::Unit => {
84                quote! {
85                    #enum_name::#ident => {2} // dame kamo
86                }
87            }
88        };
89        quote! {{
90            let e = unsafe { #enum_name::uninit_variant(#i) };
91            let size = match &e {
92                #arm,
93                _ => unreachable!()
94            };
95            ::std::mem::forget(e);
96            size
97        }}
98    }).collect::<Vec<_>>();
99
100    let variant_num = variants.len();
101    let gen = quote! {
102        impl ::fnum::Fnum for #enum_name {
103            fn variant_count() -> usize {
104                #variant_num
105            }
106            fn variant_index(&self) -> usize {
107                match self {
108                    #(#variant_idx_arms),*
109                }
110            }
111            unsafe fn uninit_variant(idx: usize) -> Self {
112                assert!(idx < Self::variant_count());
113                match idx {
114                    #(#uninit_variant_arms,)*
115                    _ => unreachable!(),
116                }
117            }
118            fn size_of_variant(idx: usize) -> usize {
119                fn pointer<T>(t: &T) -> usize {
120                    t as *const _ as usize
121                }
122                fn right_pointer<T>(t: &T) -> usize {
123                    unsafe {(t as *const T).offset(1) as usize}
124                }
125                static TABLE: ::fnum::__Lazy<[usize; #variant_num]> = ::fnum::__Lazy::new(|| [#(#make_table),*]);
126                (*TABLE)[idx]
127            }
128        }
129    };
130    gen.into()
131}