1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
5
6#[proc_macro_derive(EnumUnit)]
7pub fn into_unit_enum(input: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9 let old_enum_name = input.ident.clone();
10 let new_enum_name = format_ident!("{}Unit", old_enum_name);
11
12 enum InputKind {
13 Struct(Vec<Ident>),
14 Enum(Vec<(Ident, Fields)>),
15 }
16
17 let kind = match input.data {
18 Data::Struct(data) => match data.fields {
19 Fields::Named(fields_named) => {
20 if fields_named.named.is_empty() {
21 return quote! {}.into();
22 }
23 let names = fields_named
24 .named
25 .into_iter()
26 .filter_map(|f| f.ident)
27 .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
28 .collect();
29 InputKind::Struct(names)
30 }
31 Fields::Unnamed(..) => {
32 return quote! { compile_error!("Tuple structs are not supported.") }.into();
33 }
34 Fields::Unit => return quote! {}.into(),
35 },
36 Data::Enum(data) => {
37 if data.variants.is_empty() {
38 return quote! {}.into();
39 }
40 let variants = data
41 .variants
42 .into_iter()
43 .map(|v| (v.ident, v.fields))
44 .collect();
45 InputKind::Enum(variants)
46 }
47 Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
48 };
49
50 let doc_comment = format!(
51 "Automatically generated unit-variants of [`{}`].",
52 old_enum_name
53 );
54
55 let derive_inner = quote! {
57 Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
58 };
59
60 #[cfg(feature = "serde")]
61 let derive_inner = quote! {
62 #derive_inner, ::serde::Serialize, ::serde::Deserialize
63 };
64
65 let variant_idents: Vec<Ident> = match &kind {
67 InputKind::Struct(fields) => fields.clone(),
68 InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
69 };
70
71 #[cfg(feature = "bitflags")]
72 let new_enum = {
73 let size = match variant_idents.len() {
74 1..=8 => quote! { u8 },
75 9..=16 => quote! { u16 },
76 17..=32 => quote! { u32 },
77 33..=64 => quote! { u64 },
78 65..=128 => quote! { u128 },
79 _ => {
80 return quote! { compile_error!("Too many fields or variants for bitflags."); }
81 .into();
82 }
83 };
84
85 let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
86 quote! {
87 const #ident = 1 << #i;
88 }
89 });
90
91 quote! {
92 ::bitflags::bitflags! {
93 #[doc = #doc_comment]
94 #[derive(#derive_inner)]
95 pub struct #new_enum_name: #size {
96 #(#flag_consts)*
97 }
98 }
99 }
100 };
101
102 #[cfg(not(feature = "bitflags"))]
103 let new_enum = {
104 let variants = variant_idents.iter().map(|ident| quote! { #ident, });
105 quote! {
106 #[doc = #doc_comment]
107 #[derive(#derive_inner)]
108 pub enum #new_enum_name {
109 #(#variants)*
110 }
111 }
112 };
113
114 let new_enum_impl = match kind {
116 InputKind::Enum(ref variants) => {
117 let match_arms = variants.iter().map(|(ident, fields)| match fields {
118 Fields::Named(_) => quote! {
119 Self::#ident { .. } => #new_enum_name::#ident,
120 },
121 Fields::Unnamed(_) => quote! {
122 Self::#ident(..) => #new_enum_name::#ident,
123 },
124 Fields::Unit => quote! {
125 Self::#ident => #new_enum_name::#ident,
126 },
127 });
128
129 let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
130 quote! {
131 impl #old_enum_name {
132 #[doc = #doc_comment]
133 pub const fn kind(&self) -> #new_enum_name {
134 match self {
135 #(#match_arms)*
136 }
137 }
138 }
139
140 impl From<#old_enum_name> for #new_enum_name {
141 fn from(value: #old_enum_name) -> Self {
142 value.kind()
143 }
144 }
145 }
146 }
147 _ => quote! {},
148 };
149
150 quote! {
151 #new_enum
152 #new_enum_impl
153 }
154 .into()
155}