1use convert_case::{Case, Casing};
2use enum_unit_core::*;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
6
7#[proc_macro_derive(EnumUnit)]
8pub fn into_unit_enum(input: TokenStream) -> TokenStream {
9 let input = parse_macro_input!(input as DeriveInput);
10 let old_enum_name = input.ident.clone();
11 let new_enum_name = format_ident!("{}Unit", old_enum_name);
12
13 enum InputKind {
14 Struct(Vec<Ident>),
15 Enum(Vec<(Ident, Fields)>),
16 }
17
18 let kind = match input.data {
19 Data::Struct(data) => match data.fields {
20 Fields::Named(fields_named) => {
21 if fields_named.named.is_empty() {
22 return quote! {}.into();
23 }
24 let names = fields_named
25 .named
26 .into_iter()
27 .filter_map(|f| f.ident)
28 .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
29 .collect();
30 InputKind::Struct(names)
31 }
32 Fields::Unnamed(fields) => {
33 if fields.unnamed.is_empty() {
34 return quote! {}.into();
35 }
36 let names = (0..fields.unnamed.len())
37 .map(|i| format_ident!("{}{}", prefix(), i))
38 .collect();
39 InputKind::Struct(names)
40 }
41 Fields::Unit => return quote! {}.into(),
42 },
43 Data::Enum(data) => {
44 if data.variants.is_empty() {
45 return quote! {}.into();
46 }
47 let variants = data
48 .variants
49 .into_iter()
50 .map(|v| (v.ident, v.fields))
51 .collect();
52 InputKind::Enum(variants)
53 }
54 Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
55 };
56
57 let doc_comment = format!("Automatically generated unit-variants of [`{old_enum_name}`].");
58
59 let derive_inner = quote! {
61 Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
62 };
63
64 #[cfg(feature = "serde")]
65 let derive_inner = quote! {
66 #derive_inner, ::serde::Serialize, ::serde::Deserialize
67 };
68
69 let variant_idents: Vec<Ident> = match &kind {
71 InputKind::Struct(fields) => fields.clone(),
72 InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
73 };
74
75 #[cfg(feature = "bitflags")]
76 let new_enum = {
77 let size = match variant_idents.len() {
78 1..=8 => quote! { u8 },
79 9..=16 => quote! { u16 },
80 17..=32 => quote! { u32 },
81 33..=64 => quote! { u64 },
82 65..=128 => quote! { u128 },
83 _ => {
84 return quote! { compile_error!("Too many fields or variants for bitflags."); }
85 .into();
86 }
87 };
88
89 let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
90 quote! {
91 const #ident = 1 << #i;
92 }
93 });
94
95 quote! {
96 ::bitflags::bitflags! {
97 #[doc = #doc_comment]
98 #[derive(#derive_inner)]
99 pub struct #new_enum_name: #size {
100 #(#flag_consts)*
101 }
102 }
103 }
104 };
105
106 #[cfg(not(feature = "bitflags"))]
107 let new_enum = {
108 let variants = variant_idents.iter().map(|ident| quote! { #ident, });
109 quote! {
110 #[doc = #doc_comment]
111 #[derive(#derive_inner)]
112 pub enum #new_enum_name {
113 #(#variants)*
114 }
115 }
116 };
117
118 let new_enum_impl = match kind {
120 InputKind::Enum(ref variants) => {
121 let match_arms = variants.iter().map(|(ident, fields)| match fields {
122 Fields::Named(_) => quote! {
123 Self::#ident { .. } => #new_enum_name::#ident,
124 },
125 Fields::Unnamed(_) => quote! {
126 Self::#ident(..) => #new_enum_name::#ident,
127 },
128 Fields::Unit => quote! {
129 Self::#ident => #new_enum_name::#ident,
130 },
131 });
132
133 let doc_comment = format!("The [`{new_enum_name}`] of this [`{old_enum_name}`].");
134 quote! {
135 impl #old_enum_name {
136 #[doc = #doc_comment]
137 pub const fn kind(&self) -> #new_enum_name {
138 match self {
139 #(#match_arms)*
140 }
141 }
142 }
143
144 impl From<#old_enum_name> for #new_enum_name {
145 fn from(value: #old_enum_name) -> Self {
146 value.kind()
147 }
148 }
149 }
150 }
151 _ => quote! {},
152 };
153
154 quote! {
155 #new_enum
156 #new_enum_impl
157 }
158 .into()
159}