enum_unit/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
5
6#[proc_macro_derive(EnumUnit)]
7pub fn into_unit_enum(input: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    let old_enum_name = input.ident.clone();
10    let new_enum_name = format_ident!("{}Unit", old_enum_name);
11
12    enum InputKind {
13        Struct(Vec<Ident>),
14        Enum(Vec<(Ident, Fields)>),
15    }
16
17    let kind = match input.data {
18        Data::Struct(data) => match data.fields {
19            Fields::Named(fields_named) => {
20                if fields_named.named.is_empty() {
21                    return quote! {}.into();
22                }
23                let names = fields_named
24                    .named
25                    .into_iter()
26                    .filter_map(|f| f.ident)
27                    .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
28                    .collect();
29                InputKind::Struct(names)
30            }
31            Fields::Unnamed(fields) => {
32                if fields.unnamed.is_empty() {
33                    return quote! {}.into();
34                }
35                let names = (0..fields.unnamed.len())
36                    .map(|i| format_ident!("F{}", i))
37                    .collect();
38                InputKind::Struct(names)
39            }
40            Fields::Unit => return quote! {}.into(),
41        },
42        Data::Enum(data) => {
43            if data.variants.is_empty() {
44                return quote! {}.into();
45            }
46            let variants = data
47                .variants
48                .into_iter()
49                .map(|v| (v.ident, v.fields))
50                .collect();
51            InputKind::Enum(variants)
52        }
53        Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
54    };
55
56    let doc_comment = format!(
57        "Automatically generated unit-variants of [`{}`].",
58        old_enum_name
59    );
60
61    // Trait derivation
62    let derive_inner = quote! {
63        Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
64    };
65
66    #[cfg(feature = "serde")]
67    let derive_inner = quote! {
68        #derive_inner, ::serde::Serialize, ::serde::Deserialize
69    };
70
71    // Collect variant names regardless of origin
72    let variant_idents: Vec<Ident> = match &kind {
73        InputKind::Struct(fields) => fields.clone(),
74        InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
75    };
76
77    #[cfg(feature = "bitflags")]
78    let new_enum = {
79        let size = match variant_idents.len() {
80            1..=8 => quote! { u8 },
81            9..=16 => quote! { u16 },
82            17..=32 => quote! { u32 },
83            33..=64 => quote! { u64 },
84            65..=128 => quote! { u128 },
85            _ => {
86                return quote! { compile_error!("Too many fields or variants for bitflags."); }
87                    .into();
88            }
89        };
90
91        let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
92            quote! {
93                const #ident = 1 << #i;
94            }
95        });
96
97        quote! {
98            ::bitflags::bitflags! {
99                #[doc = #doc_comment]
100                #[derive(#derive_inner)]
101                pub struct #new_enum_name: #size {
102                    #(#flag_consts)*
103                }
104            }
105        }
106    };
107
108    #[cfg(not(feature = "bitflags"))]
109    let new_enum = {
110        let variants = variant_idents.iter().map(|ident| quote! { #ident, });
111        quote! {
112            #[doc = #doc_comment]
113            #[derive(#derive_inner)]
114            pub enum #new_enum_name {
115                #(#variants)*
116            }
117        }
118    };
119
120    // Only generate kind() and From<> if the original was an enum
121    let new_enum_impl = match kind {
122        InputKind::Enum(ref variants) => {
123            let match_arms = variants.iter().map(|(ident, fields)| match fields {
124                Fields::Named(_) => quote! {
125                    Self::#ident { .. } => #new_enum_name::#ident,
126                },
127                Fields::Unnamed(_) => quote! {
128                    Self::#ident(..) => #new_enum_name::#ident,
129                },
130                Fields::Unit => quote! {
131                    Self::#ident => #new_enum_name::#ident,
132                },
133            });
134
135            let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
136            quote! {
137                impl #old_enum_name {
138                    #[doc = #doc_comment]
139                    pub const fn kind(&self) -> #new_enum_name {
140                        match self {
141                            #(#match_arms)*
142                        }
143                    }
144                }
145
146                impl From<#old_enum_name> for #new_enum_name {
147                    fn from(value: #old_enum_name) -> Self {
148                        value.kind()
149                    }
150                }
151            }
152        }
153        _ => quote! {},
154    };
155
156    quote! {
157        #new_enum
158        #new_enum_impl
159    }
160    .into()
161}