1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5#[proc_macro_derive(EnumUnit)]
6pub fn into_unit_enum(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
9
10 let variants = if let Data::Enum(data_enum) = input.data {
12 data_enum.variants
13 } else {
14 return quote! { compile_error!("Unsupported structure (enum's only)") }.into();
15 };
16
17 if variants.is_empty() {
19 return quote! {}.into();
20 }
21
22 let old_enum_name = input.ident;
24 let new_enum_name = quote::format_ident!("{}Unit", old_enum_name);
25
26 let match_arms = variants.iter().map(|variant| {
28 let ident = &variant.ident;
29
30 match &variant.fields {
32 Fields::Unit => {
33 quote! {
34 #old_enum_name::#ident => #new_enum_name::#ident,
35 }
36 }
37 Fields::Unnamed(_) => {
38 quote! {
39 #old_enum_name::#ident(..) => #new_enum_name::#ident,
40 }
41 }
42 Fields::Named(_) => {
43 quote! {
44 #old_enum_name::#ident { .. } => #new_enum_name::#ident,
45 }
46 }
47 }
48 });
49
50 let derive_inner = quote! {
52 Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
53 };
54
55 #[cfg(feature = "serde")]
57 let derive_inner = quote! {
58 #derive_inner, ::serde::Serialize, ::serde::Deserialize
59 };
60
61 let doc_comment = format!(
62 "Automatically generated unit-variants of [`{}`].",
63 old_enum_name
64 );
65
66 #[cfg(not(feature = "bitflags"))]
68 let new_enum = {
69 let flag_arms = variants.iter().map(|variant| {
70 let ident = &variant.ident;
71 quote! { #ident, }
72 });
73
74 quote! {
75 #[doc = #doc_comment]
76 #[derive(#derive_inner)]
77 pub enum #new_enum_name {
78 #(#flag_arms)*
79 }
80 }
81 };
82
83 #[cfg(feature = "bitflags")]
84 let new_enum = {
85 let size = match variants.len() {
87 1..=8 => quote! { u8 },
88 9..=16 => quote! { u16 },
89 17..=32 => quote! { u32 },
90 33..=64 => quote! { u64 },
91 65..=128 => quote! { u128 },
92 _ => return quote! { compile_error!("Enum has too many variants."); }.into(),
93 };
94
95 let flag_arms = variants.iter().enumerate().map(|(i, variant)| {
96 let ident = &variant.ident;
97 quote! {
98 const #ident = 1 << #i;
99 }
100 });
101
102 quote! {
103 ::bitflags::bitflags! {
104 #[doc = #doc_comment]
105 #[derive(#derive_inner)]
106 pub struct #new_enum_name: #size {
107 #(#flag_arms)*
108 }
109 }
110 }
111 };
112
113 let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
114
115 let new_enum_impl = quote! {
117 impl #old_enum_name {
118 #[doc = #doc_comment]
119 pub const fn kind(&self) -> #new_enum_name {
120 match self {
121 #(#match_arms)*
122 }
123 }
124 }
125
126 impl From<#old_enum_name> for #new_enum_name {
127 fn from(value: #old_enum_name) -> Self {
128 value.kind()
129 }
130 }
131 };
132
133 quote! {
135 #new_enum
136 #new_enum_impl
137 }
138 .into()
139}