// enum_unit/lib.rs

use convert_case::{Case, Casing};
use enum_unit_core::*;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};

7#[proc_macro_derive(EnumUnit)]
8pub fn into_unit_enum(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    let old_enum_name = input.ident.clone();
11    let new_enum_name = format_ident!("{}Unit", old_enum_name);
12
13    enum InputKind {
14        Struct(Vec<Ident>),
15        Enum(Vec<(Ident, Fields)>),
16    }
17
18    let kind = match input.data {
19        Data::Struct(data) => match data.fields {
20            Fields::Named(fields_named) => {
21                if fields_named.named.is_empty() {
22                    return quote! {}.into();
23                }
24                let names = fields_named
25                    .named
26                    .into_iter()
27                    .filter_map(|f| f.ident)
28                    .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
29                    .collect();
30                InputKind::Struct(names)
31            }
32            Fields::Unnamed(fields) => {
33                if fields.unnamed.is_empty() {
34                    return quote! {}.into();
35                }
36                let names = (0..fields.unnamed.len())
37                    .map(|i| format_ident!("{}{}", prefix(), i))
38                    .collect();
39                InputKind::Struct(names)
40            }
41            Fields::Unit => return quote! {}.into(),
42        },
43        Data::Enum(data) => {
44            if data.variants.is_empty() {
45                return quote! {}.into();
46            }
47            let variants = data
48                .variants
49                .into_iter()
50                .map(|v| (v.ident, v.fields))
51                .collect();
52            InputKind::Enum(variants)
53        }
54        Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
55    };
56
57    let doc_comment = format!("Automatically generated unit-variants of [`{old_enum_name}`].");
58
59    // Trait derivation
60    let derive_inner = quote! {
61        Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
62    };
63
64    #[cfg(feature = "serde")]
65    let derive_inner = quote! {
66        #derive_inner, ::serde::Serialize, ::serde::Deserialize
67    };
68
69    // Collect variant names regardless of origin
70    let variant_idents: Vec<Ident> = match &kind {
71        InputKind::Struct(fields) => fields.clone(),
72        InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
73    };
74
75    #[cfg(feature = "bitflags")]
76    let new_enum = {
77        let size = match variant_idents.len() {
78            1..=8 => quote! { u8 },
79            9..=16 => quote! { u16 },
80            17..=32 => quote! { u32 },
81            33..=64 => quote! { u64 },
82            65..=128 => quote! { u128 },
83            _ => {
84                return quote! { compile_error!("Too many fields or variants for bitflags."); }
85                    .into();
86            }
87        };
88
89        let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
90            quote! {
91                const #ident = 1 << #i;
92            }
93        });
94
95        quote! {
96            ::bitflags::bitflags! {
97                #[doc = #doc_comment]
98                #[derive(#derive_inner)]
99                pub struct #new_enum_name: #size {
100                    #(#flag_consts)*
101                }
102            }
103        }
104    };
105
106    #[cfg(not(feature = "bitflags"))]
107    let new_enum = {
108        let variants = variant_idents.iter().map(|ident| quote! { #ident, });
109        quote! {
110            #[doc = #doc_comment]
111            #[derive(#derive_inner)]
112            pub enum #new_enum_name {
113                #(#variants)*
114            }
115        }
116    };
117
118    // Only generate kind() and From<> if the original was an enum
119    let new_enum_impl = match kind {
120        InputKind::Enum(ref variants) => {
121            let match_arms = variants.iter().map(|(ident, fields)| match fields {
122                Fields::Named(_) => quote! {
123                    Self::#ident { .. } => #new_enum_name::#ident,
124                },
125                Fields::Unnamed(_) => quote! {
126                    Self::#ident(..) => #new_enum_name::#ident,
127                },
128                Fields::Unit => quote! {
129                    Self::#ident => #new_enum_name::#ident,
130                },
131            });
132
133            let doc_comment = format!("The [`{new_enum_name}`] of this [`{old_enum_name}`].");
134            quote! {
135                impl #old_enum_name {
136                    #[doc = #doc_comment]
137                    pub const fn kind(&self) -> #new_enum_name {
138                        match self {
139                            #(#match_arms)*
140                        }
141                    }
142                }
143
144                impl From<#old_enum_name> for #new_enum_name {
145                    fn from(value: #old_enum_name) -> Self {
146                        value.kind()
147                    }
148                }
149            }
150        }
151        _ => quote! {},
152    };
153
154    quote! {
155        #new_enum
156        #new_enum_impl
157    }
158    .into()
159}