enum_unit/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
5
6#[proc_macro_derive(EnumUnit)]
7pub fn into_unit_enum(input: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    let old_enum_name = input.ident.clone();
10    let new_enum_name = format_ident!("{}Unit", old_enum_name);
11
12    enum InputKind {
13        Struct(Vec<Ident>),
14        Enum(Vec<(Ident, Fields)>),
15    }
16
17    let kind = match input.data {
18        Data::Struct(data) => match data.fields {
19            Fields::Named(fields_named) => {
20                if fields_named.named.is_empty() {
21                    return quote! {}.into();
22                }
23                let names = fields_named
24                    .named
25                    .into_iter()
26                    .filter_map(|f| f.ident)
27                    .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
28                    .collect();
29                InputKind::Struct(names)
30            }
31            Fields::Unnamed(..) => {
32                return quote! { compile_error!("Tuple structs are not supported.") }.into();
33            }
34            Fields::Unit => return quote! {}.into(),
35        },
36        Data::Enum(data) => {
37            if data.variants.is_empty() {
38                return quote! {}.into();
39            }
40            let variants = data
41                .variants
42                .into_iter()
43                .map(|v| (v.ident, v.fields))
44                .collect();
45            InputKind::Enum(variants)
46        }
47        Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
48    };
49
50    let doc_comment = format!(
51        "Automatically generated unit-variants of [`{}`].",
52        old_enum_name
53    );
54
55    // Trait derivation
56    let derive_inner = quote! {
57        Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
58    };
59
60    #[cfg(feature = "serde")]
61    let derive_inner = quote! {
62        #derive_inner, ::serde::Serialize, ::serde::Deserialize
63    };
64
65    // Collect variant names regardless of origin
66    let variant_idents: Vec<Ident> = match &kind {
67        InputKind::Struct(fields) => fields.clone(),
68        InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
69    };
70
71    #[cfg(feature = "bitflags")]
72    let new_enum = {
73        let size = match variant_idents.len() {
74            1..=8 => quote! { u8 },
75            9..=16 => quote! { u16 },
76            17..=32 => quote! { u32 },
77            33..=64 => quote! { u64 },
78            65..=128 => quote! { u128 },
79            _ => {
80                return quote! { compile_error!("Too many fields or variants for bitflags."); }
81                    .into();
82            }
83        };
84
85        let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
86            quote! {
87                const #ident = 1 << #i;
88            }
89        });
90
91        quote! {
92            ::bitflags::bitflags! {
93                #[doc = #doc_comment]
94                #[derive(#derive_inner)]
95                pub struct #new_enum_name: #size {
96                    #(#flag_consts)*
97                }
98            }
99        }
100    };
101
102    #[cfg(not(feature = "bitflags"))]
103    let new_enum = {
104        let variants = variant_idents.iter().map(|ident| quote! { #ident, });
105        quote! {
106            #[doc = #doc_comment]
107            #[derive(#derive_inner)]
108            pub enum #new_enum_name {
109                #(#variants)*
110            }
111        }
112    };
113
114    // Only generate kind() and From<> if the original was an enum
115    let new_enum_impl = match kind {
116        InputKind::Enum(ref variants) => {
117            let match_arms = variants.iter().map(|(ident, fields)| match fields {
118                Fields::Named(_) => quote! {
119                    Self::#ident { .. } => #new_enum_name::#ident,
120                },
121                Fields::Unnamed(_) => quote! {
122                    Self::#ident(..) => #new_enum_name::#ident,
123                },
124                Fields::Unit => quote! {
125                    Self::#ident => #new_enum_name::#ident,
126                },
127            });
128
129            let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
130            quote! {
131                impl #old_enum_name {
132                    #[doc = #doc_comment]
133                    pub const fn kind(&self) -> #new_enum_name {
134                        match self {
135                            #(#match_arms)*
136                        }
137                    }
138                }
139
140                impl From<#old_enum_name> for #new_enum_name {
141                    fn from(value: #old_enum_name) -> Self {
142                        value.kind()
143                    }
144                }
145            }
146        }
147        _ => quote! {},
148    };
149
150    quote! {
151        #new_enum
152        #new_enum_impl
153    }
154    .into()
155}