1use convert_case::{Case, Casing};
2use enum_unit_core::*;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
6
7#[proc_macro_derive(EnumUnit)]
8pub fn into_unit_enum(input: TokenStream) -> TokenStream {
9 let input = parse_macro_input!(input as DeriveInput);
10 let old_enum_name = input.ident.clone();
11 let new_enum_name = format_ident!("{}Unit", old_enum_name);
12
13 enum InputKind {
14 Struct(Vec<Ident>),
15 Enum(Vec<(Ident, Fields)>),
16 }
17
18 let kind = match input.data {
19 Data::Struct(data) => match data.fields {
20 Fields::Named(fields_named) => {
21 if fields_named.named.is_empty() {
22 return quote! {}.into();
23 }
24 let names = fields_named
25 .named
26 .into_iter()
27 .filter_map(|f| f.ident)
28 .map(|ident| format_ident!("{}", ident.to_string().to_case(Case::Pascal)))
29 .collect();
30 InputKind::Struct(names)
31 }
32 Fields::Unnamed(fields) => {
33 if fields.unnamed.is_empty() {
34 return quote! {}.into();
35 }
36 let names = (0..fields.unnamed.len())
37 .map(|i| format_ident!("{}{}", prefix(), i))
38 .collect();
39 InputKind::Struct(names)
40 }
41 Fields::Unit => return quote! {}.into(),
42 },
43 Data::Enum(data) => {
44 if data.variants.is_empty() {
45 return quote! {}.into();
46 }
47 let variants = data
48 .variants
49 .into_iter()
50 .map(|v| (v.ident, v.fields))
51 .collect();
52 InputKind::Enum(variants)
53 }
54 Data::Union(..) => return quote! { compile_error!("Unions are not supported.") }.into(),
55 };
56
57 let doc_comment = format!(
58 "Automatically generated unit-variants of [`{}`].",
59 old_enum_name
60 );
61
62 let derive_inner = quote! {
64 Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord
65 };
66
67 #[cfg(feature = "serde")]
68 let derive_inner = quote! {
69 #derive_inner, ::serde::Serialize, ::serde::Deserialize
70 };
71
72 let variant_idents: Vec<Ident> = match &kind {
74 InputKind::Struct(fields) => fields.clone(),
75 InputKind::Enum(variants) => variants.iter().map(|(ident, _)| ident.clone()).collect(),
76 };
77
78 #[cfg(feature = "bitflags")]
79 let new_enum = {
80 let size = match variant_idents.len() {
81 1..=8 => quote! { u8 },
82 9..=16 => quote! { u16 },
83 17..=32 => quote! { u32 },
84 33..=64 => quote! { u64 },
85 65..=128 => quote! { u128 },
86 _ => {
87 return quote! { compile_error!("Too many fields or variants for bitflags."); }
88 .into();
89 }
90 };
91
92 let flag_consts = variant_idents.iter().enumerate().map(|(i, ident)| {
93 quote! {
94 const #ident = 1 << #i;
95 }
96 });
97
98 quote! {
99 ::bitflags::bitflags! {
100 #[doc = #doc_comment]
101 #[derive(#derive_inner)]
102 pub struct #new_enum_name: #size {
103 #(#flag_consts)*
104 }
105 }
106 }
107 };
108
109 #[cfg(not(feature = "bitflags"))]
110 let new_enum = {
111 let variants = variant_idents.iter().map(|ident| quote! { #ident, });
112 quote! {
113 #[doc = #doc_comment]
114 #[derive(#derive_inner)]
115 pub enum #new_enum_name {
116 #(#variants)*
117 }
118 }
119 };
120
121 let new_enum_impl = match kind {
123 InputKind::Enum(ref variants) => {
124 let match_arms = variants.iter().map(|(ident, fields)| match fields {
125 Fields::Named(_) => quote! {
126 Self::#ident { .. } => #new_enum_name::#ident,
127 },
128 Fields::Unnamed(_) => quote! {
129 Self::#ident(..) => #new_enum_name::#ident,
130 },
131 Fields::Unit => quote! {
132 Self::#ident => #new_enum_name::#ident,
133 },
134 });
135
136 let doc_comment = format!("The [`{}`] of this [`{}`].", new_enum_name, old_enum_name);
137 quote! {
138 impl #old_enum_name {
139 #[doc = #doc_comment]
140 pub const fn kind(&self) -> #new_enum_name {
141 match self {
142 #(#match_arms)*
143 }
144 }
145 }
146
147 impl From<#old_enum_name> for #new_enum_name {
148 fn from(value: #old_enum_name) -> Self {
149 value.kind()
150 }
151 }
152 }
153 }
154 _ => quote! {},
155 };
156
157 quote! {
158 #new_enum
159 #new_enum_impl
160 }
161 .into()
162}