1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
5
6#[proc_macro_derive(EnumUnit)]
7pub fn into_unit_enum(input: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9 let old_enum_name = input.ident.clone();
10 let new_enum_name = format_ident!("{}Unit", old_enum_name);
11
12 enum InputKind {
13 Struct(Vec<Ident>),
14 Enum(Vec<(Ident, Fields)>),
15 }
16
17 let kind = match input.data {
18 Data::Struct(data) => match data.fields {
19 Fields::Named(fields_named) => {
20 if fields_named.named.is_empty() {
21 return quote! {}.into();
22 }
23 let names = fields_named
24 .named
25 .into_iter()
26 .filter_map(|f| f.ident)
27 .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
28 .collect();
29 InputKind::Struct(names)
30 }
31 Fields::Unnamed(fields) => {
32 if fields.unnamed.is_empty() {
33 return quote! {}.into();
34 }
35 let names = (0..fields.unnamed.len())
36 .map(|i| format_ident!("F{}", i))
37 .collect();
38 InputKind::Struct(names)
39 }
40 Fields::Unit => return quote! {}.into(),
41 },
42 Data::Enum(data) => {
43 if data.variants.is_empty() {
44 return quote! {}.into();
45 }
46 let variants = data
47 .variants
48 .into_iter()
49 .map(|v| (v.ident, v.fields))
50 .collect();
51 InputKind::Enum(variants)
52 }
53 Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
54 };
55
56 let doc_comment = format!(
57 "Automatically generated unit-variants of [`{}`].",
58 old_enum_name
59 );
60
61 let derive_inner = quote! {
63 Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
64 };
65
66 #[cfg(feature = "serde")]
67 let derive_inner = quote! {
68 #derive_inner, ::serde::Serialize, ::serde::Deserialize
69 };
70
71 let variant_idents: Vec<Ident> = match &kind {
73 InputKind::Struct(fields) => fields.clone(),
74 InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
75 };
76
77 #[cfg(feature = "bitflags")]
78 let new_enum = {
79 let size = match variant_idents.len() {
80 1..=8 => quote! { u8 },
81 9..=16 => quote! { u16 },
82 17..=32 => quote! { u32 },
83 33..=64 => quote! { u64 },
84 65..=128 => quote! { u128 },
85 _ => {
86 return quote! { compile_error!("Too many fields or variants for bitflags."); }
87 .into();
88 }
89 };
90
91 let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
92 quote! {
93 const #ident = 1 << #i;
94 }
95 });
96
97 quote! {
98 ::bitflags::bitflags! {
99 #[doc = #doc_comment]
100 #[derive(#derive_inner)]
101 pub struct #new_enum_name: #size {
102 #(#flag_consts)*
103 }
104 }
105 }
106 };
107
108 #[cfg(not(feature = "bitflags"))]
109 let new_enum = {
110 let variants = variant_idents.iter().map(|ident| quote! { #ident, });
111 quote! {
112 #[doc = #doc_comment]
113 #[derive(#derive_inner)]
114 pub enum #new_enum_name {
115 #(#variants)*
116 }
117 }
118 };
119
120 let new_enum_impl = match kind {
122 InputKind::Enum(ref variants) => {
123 let match_arms = variants.iter().map(|(ident, fields)| match fields {
124 Fields::Named(_) => quote! {
125 Self::#ident { .. } => #new_enum_name::#ident,
126 },
127 Fields::Unnamed(_) => quote! {
128 Self::#ident(..) => #new_enum_name::#ident,
129 },
130 Fields::Unit => quote! {
131 Self::#ident => #new_enum_name::#ident,
132 },
133 });
134
135 let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
136 quote! {
137 impl #old_enum_name {
138 #[doc = #doc_comment]
139 pub const fn kind(&self) -> #new_enum_name {
140 match self {
141 #(#match_arms)*
142 }
143 }
144 }
145
146 impl From<#old_enum_name> for #new_enum_name {
147 fn from(value: #old_enum_name) -> Self {
148 value.kind()
149 }
150 }
151 }
152 }
153 _ => quote! {},
154 };
155
156 quote! {
157 #new_enum
158 #new_enum_impl
159 }
160 .into()
161}