// enum_unit/lib.rs

1use convert_case::{Case, Casing};
2use enum_unit_core::*;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
6
7#[proc_macro_derive(EnumUnit)]
8pub fn into_unit_enum(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    let old_enum_name = input.ident.clone();
11    let new_enum_name = format_ident!("{}Unit", old_enum_name);
12
13    enum InputKind {
14        Struct(Vec<Ident>),
15        Enum(Vec<(Ident, Fields)>),
16    }
17
18    let kind = match input.data {
19        Data::Struct(data) => match data.fields {
20            Fields::Named(fields_named) => {
21                if fields_named.named.is_empty() {
22                    return quote! {}.into();
23                }
24                let names = fields_named
25                    .named
26                    .into_iter()
27                    .filter_map(|f| f.ident)
28                    .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
29                    .collect();
30                InputKind::Struct(names)
31            }
32            Fields::Unnamed(fields) => {
33                if fields.unnamed.is_empty() {
34                    return quote! {}.into();
35                }
36                let names = (0..fields.unnamed.len())
37                    .map(|i| format_ident!("{}{}", prefix(), i))
38                    .collect();
39                InputKind::Struct(names)
40            }
41            Fields::Unit => return quote! {}.into(),
42        },
43        Data::Enum(data) => {
44            if data.variants.is_empty() {
45                return quote! {}.into();
46            }
47            let variants = data
48                .variants
49                .into_iter()
50                .map(|v| (v.ident, v.fields))
51                .collect();
52            InputKind::Enum(variants)
53        }
54        Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
55    };
56
57    let doc_comment = format!(
58        "Automatically generated unit-variants of [`{}`].",
59        old_enum_name
60    );
61
62    // Trait derivation
63    let derive_inner = quote! {
64        Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
65    };
66
67    #[cfg(feature = "serde")]
68    let derive_inner = quote! {
69        #derive_inner, ::serde::Serialize, ::serde::Deserialize
70    };
71
72    // Collect variant names regardless of origin
73    let variant_idents: Vec<Ident> = match &kind {
74        InputKind::Struct(fields) => fields.clone(),
75        InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
76    };
77
78    #[cfg(feature = "bitflags")]
79    let new_enum = {
80        let size = match variant_idents.len() {
81            1..=8 => quote! { u8 },
82            9..=16 => quote! { u16 },
83            17..=32 => quote! { u32 },
84            33..=64 => quote! { u64 },
85            65..=128 => quote! { u128 },
86            _ => {
87                return quote! { compile_error!("Too many fields or variants for bitflags."); }
88                    .into();
89            }
90        };
91
92        let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
93            quote! {
94                const #ident = 1 << #i;
95            }
96        });
97
98        quote! {
99            ::bitflags::bitflags! {
100                #[doc = #doc_comment]
101                #[derive(#derive_inner)]
102                pub struct #new_enum_name: #size {
103                    #(#flag_consts)*
104                }
105            }
106        }
107    };
108
109    #[cfg(not(feature = "bitflags"))]
110    let new_enum = {
111        let variants = variant_idents.iter().map(|ident| quote! { #ident, });
112        quote! {
113            #[doc = #doc_comment]
114            #[derive(#derive_inner)]
115            pub enum #new_enum_name {
116                #(#variants)*
117            }
118        }
119    };
120
121    // Only generate kind() and From<> if the original was an enum
122    let new_enum_impl = match kind {
123        InputKind::Enum(ref variants) => {
124            let match_arms = variants.iter().map(|(ident, fields)| match fields {
125                Fields::Named(_) => quote! {
126                    Self::#ident { .. } => #new_enum_name::#ident,
127                },
128                Fields::Unnamed(_) => quote! {
129                    Self::#ident(..) => #new_enum_name::#ident,
130                },
131                Fields::Unit => quote! {
132                    Self::#ident => #new_enum_name::#ident,
133                },
134            });
135
136            let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
137            quote! {
138                impl #old_enum_name {
139                    #[doc = #doc_comment]
140                    pub const fn kind(&self) -> #new_enum_name {
141                        match self {
142                            #(#match_arms)*
143                        }
144                    }
145                }
146
147                impl From<#old_enum_name> for #new_enum_name {
148                    fn from(value: #old_enum_name) -> Self {
149                        value.kind()
150                    }
151                }
152            }
153        }
154        _ => quote! {},
155    };
156
157    quote! {
158        #new_enum
159        #new_enum_impl
160    }
161    .into()
162}