1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#![doc = include_str!("../README.md")]

use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::ToTokens;
use syn::{parse_macro_input, spanned::Spanned, Error, ItemEnum};

/// auto implement the TryFrom<Literal> trait and Into<Literal> trait
/// where the `literal` must be only one type
#[proc_macro_derive(LiteralEnum, attributes(lit))]
pub fn easy(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemEnum);

    match real_easy(item) {
        Ok(v) => v.into(),
        Err(err) => err.into_compile_error().into(),
    }
}

/// Validates a `#[derive(LiteralEnum)]` input and generates its conversion impls.
///
/// Requirements checked, in order, for every variant:
/// - the enum has no generics;
/// - the variant is a unit variant;
/// - it carries exactly one `#[lit = <literal>]` attribute;
/// - the literal's type (as computed by `lit_to_ty`) matches the first variant's;
/// - no earlier variant declared the same literal *value*.
///
/// Returns an empty token stream for an enum with no variants.
///
/// # Errors
///
/// Returns a spanned [`Error`] describing the first requirement that is violated.
fn real_easy(item: ItemEnum) -> Result<TokenStream2, Error> {
    use std::collections::HashMap;

    if !item.generics.to_token_stream().is_empty() {
        return Err(Error::new(item.generics.span(), "generics is forbidden"));
    }

    let mut ty: Option<syn::Type> = None;
    let mut lit_value = Vec::with_capacity(item.variants.len());
    let mut var_ident = Vec::with_capacity(item.variants.len());

    // Keyed by the literal's *value* rather than its spelling, so that `1` / `0x1` /
    // `1u32` or `"a"` / `"\x61"` are reported as duplicates instead of yielding an
    // unreachable arm in the generated `TryFrom` match.
    let mut seen = HashMap::<String, syn::Lit>::new();

    for var in item.variants.iter() {
        if !matches!(var.fields, syn::Fields::Unit) {
            return Err(Error::new(
                var.span(),
                "every variant must be Unit kind, like `None`",
            ));
        }

        let mut lit_attrs = var.attrs.iter().filter(|attr| attr.path().is_ident("lit"));
        let Some(attr) = lit_attrs.next() else {
            return Err(Error::new(
                var.span(),
                "every variant must provide the `lit` attribute, like `#[lit = 42]`",
            ));
        };
        // A second `#[lit]` would otherwise be silently ignored.
        if let Some(extra) = lit_attrs.next() {
            return Err(Error::new(
                extra.span(),
                "the `lit` attribute must appear only once per variant",
            ));
        }

        let syn::Meta::NameValue(ref name_value) = attr.meta else {
            return Err(Error::new(attr.meta.span(), "the format should be like: `#[lit = 42]`"));
        };

        let syn::Expr::Lit(syn::ExprLit { ref lit, .. }) = name_value.value else {
            return Err(Error::new(name_value.span(), "the value should be a literal"));
        };

        // Also rejects unsupported literal kinds via `?`.
        let this_ty = lit_to_ty(lit)?;
        if let Some(ref t) = ty {
            if t.to_token_stream().to_string() != this_ty.to_token_stream().to_string() {
                return Err(Error::new(
                    lit.span(),
                    "All the literals must be the same type",
                ));
            }
        } else {
            ty = Some(this_ty);
        }

        let lit_str = lit.to_token_stream().to_string();
        let key = lit_key(lit);
        if let Some(first) = seen.get(&key) {
            let mut err = Error::new(lit.span(), format!("{} is declared twice", lit_str));
            err.combine(Error::new(
                first.span(),
                format!("{} is declared here first", first.to_token_stream()),
            ));
            return Err(err);
        }
        seen.insert(key, lit.clone());

        var_ident.push(var.ident.clone());
        lit_value.push(lit.clone());
    }

    let enum_ident = item.ident;

    match ty {
        Some(lit_ty) => Ok(derive(enum_ident, var_ident, lit_ty, lit_value)),
        None => Ok(TokenStream2::new()),
    }
}

/// Returns a key identifying `lit` by its value rather than its source spelling, so
/// equal literals written differently (`1` / `0x1` / `1_u32`, `"a"` / `r"a"`) collide.
///
/// Only meaningful when comparing literals of the same kind, which `real_easy`
/// guarantees by checking the literal type before looking up the key.
fn lit_key(lit: &syn::Lit) -> String {
    match lit {
        syn::Lit::Str(s) => s.value(),
        syn::Lit::ByteStr(b) => format!("{:?}", b.value()),
        syn::Lit::Byte(b) => b.value().to_string(),
        syn::Lit::Char(c) => c.value().to_string(),
        syn::Lit::Int(i) => i.base10_digits().to_owned(),
        syn::Lit::Bool(b) => b.value.to_string(),
        other => other.to_token_stream().to_string(),
    }
}

/// Generates the conversion impls for `enum_ident`.
///
/// `var_ident[i]` is paired with `lit_value[i]` (both are filled in lockstep by the
/// caller). Emits:
/// - `TryFrom<lit_ty> for enum_ident`, whose error type hands back the unmatched value;
/// - `From<enum_ident> for lit_ty`, which also gives the enum `Into<lit_ty>` through
///   the standard library's blanket `impl<T, U: From<T>> Into<U> for T`.
///
/// Standard items are named by absolute `::core` paths so the expansion does not
/// depend on the caller's prelude (`TryFrom` is only in the 2021 prelude) and is not
/// affected by user items that shadow `Result`, `Ok` or `Err`.
fn derive(
    enum_ident: syn::Ident,
    var_ident: Vec<syn::Ident>,
    lit_ty: syn::Type,
    lit_value: Vec<syn::Lit>,
) -> TokenStream2 {
    quote::quote! {
        impl ::core::convert::TryFrom<#lit_ty> for #enum_ident {
            type Error = #lit_ty;

            fn try_from(value: #lit_ty) -> ::core::result::Result<Self, Self::Error> {
                match value {
                    #(#lit_value => ::core::result::Result::Ok(Self::#var_ident),)*
                    _ => ::core::result::Result::Err(value),
                }
            }
        }

        impl ::core::convert::From<#enum_ident> for #lit_ty {
            fn from(value: #enum_ident) -> #lit_ty {
                match value {
                    #(#enum_ident::#var_ident => #lit_value,)*
                }
            }
        }
    }
}

fn lit_to_ty(lit: &syn::Lit) -> Result<syn::Type, Error> {
    let ty = match lit {
        syn::Lit::Str(_) => syn::parse_str("&'static str").unwrap(),
        syn::Lit::ByteStr(_) => syn::parse_str("&'static [u8]").unwrap(),
        syn::Lit::Byte(_) => syn::parse_str("u8").unwrap(),
        syn::Lit::Char(_) => syn::parse_str("char").unwrap(),
        syn::Lit::Int(int) => {
            if int.suffix().is_empty() {
                syn::parse_str("u32").unwrap()
            } else {
                syn::parse_str(int.suffix()).unwrap()
            }
        }

        syn::Lit::Bool(_) => syn::parse_str("bool").unwrap(),
        // syn::Lit::Float(_) => syn::parse_str("f64").unwrap(), // floating-point types cannot be used in patterns
        // syn::Lit::Verbatim(_) => syn::parse_str("&'static str").unwrap(),
        _ => return Err(Error::new(lit.span(), "This type is not supported")),
    };

    Ok(ty)
}