1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use proc_macro::TokenStream;

use quote::quote;
use syn::{Error, Expr};
use syn::spanned::Spanned;

#[proc_macro_attribute]
pub fn bitenum(attrs: TokenStream, body: TokenStream) -> TokenStream {
    let scalar_type = match parse_scalar(attrs) {
        Ok(ident) => ident,
        Err(err) => { return err.to_compile_error().into(); }
    };

    let mut enum_type = match syn::parse::<syn::ItemEnum>(body) {
        Ok(item) => item,
        Err(err) => { return err.to_compile_error().into(); }
    };

    let enum_name = enum_type.ident.clone();

    let must_has_discriminant = enum_type.variants.iter().nth(0).map(|v| v.discriminant.is_some()).unwrap_or(false);

    for (idx, variant) in enum_type.variants.iter_mut().enumerate() {
        if must_has_discriminant != variant.discriminant.is_some() {
            return Error::new(
                variant.span(),
                "all variants must rather have or don't have values",
            ).to_compile_error().into();
        }

        if variant.discriminant.is_none() {
            variant.discriminant = Some((
                syn::token::Eq::default(),
                syn::parse_str::<Expr>(format!("1 << {}", idx).as_str()).unwrap()
            ));
        }
    }

    TokenStream::from(quote! {
        #[repr(#scalar_type)]
        #enum_type

        impl Into<#scalar_type> for #enum_name {
            fn into(self) -> #scalar_type { self as #scalar_type }
        }

        impl bitenum::BitEnum for #enum_name {
            type Scalar = #scalar_type;
        }
    })
}


fn parse_scalar(tokens: TokenStream) -> syn::Result<syn::Ident> {
    match syn::parse::<syn::Ident>(tokens) {
        Ok(ident) => match ident.to_string().as_str() {
            "u8" | "u16" | "u32" | "u64" | "u128" => Ok(ident),
            _ => Err(Error::new(ident.span(), "must be unsigned integer type"))
        },
        Err(err) => Err(err)
    }
}