1use proc_macro::TokenStream as TokenStream1;
6
7use proc_macro2::{Ident, TokenStream};
8use quote::quote;
9use syn::{parse_macro_input, Attribute, Expr, Type};
10
11struct Variant {
15 ident: Ident,
17 discriminant: Expr,
19 documentation: Vec<Attribute>,
21}
22
23struct PodEnum {
27 vis: syn::Visibility,
31 ident: Ident,
33 repr: Type,
35 variants: Vec<Variant>,
37 attrs: Vec<Attribute>,
42}
43
44impl PodEnum {
46 fn write_impl(&self) -> TokenStream {
48 let ident = &self.ident;
49 let repr = &self.repr;
50 let vis = &self.vis;
51 let attrs = &self.attrs;
52
53 let variants = self.write_variants();
54 let debug = self.write_debug();
55 let conversions = self.write_conversions();
56 let partial_eq = self.write_partial_eq();
57
58 quote!(
59 #( #attrs )*
60 #[derive(Copy, Clone)]
61 #[repr(transparent)]
62 #vis struct #ident {
63 inner: #repr,
64 }
65
66 impl ::pod_enum::PodEnum for #ident {
67 type Repr = #repr;
68 }
69
70 unsafe impl ::pod_enum::bytemuck::Pod for #ident {}
75 unsafe impl ::pod_enum::bytemuck::Zeroable for #ident {}
80
81 #variants
82
83 #debug
84
85 #conversions
86
87 #partial_eq
88 )
89 }
90
91 fn write_variants(&self) -> TokenStream {
93 let ident = &self.ident;
94 let vis = &self.vis;
95 let variants = self.variants.iter().map(
96 |Variant {
97 ident,
98 discriminant,
99 documentation,
100 }| {
101 quote!(
102 #( #documentation )*
103 #vis const #ident: Self = Self { inner: #discriminant };
104 )
105 },
106 );
107 quote! {
108 #[allow(non_upper_case_globals)]
110 impl #ident {
111 #( #variants )*
112 }
113 }
114 }
115
116 fn write_debug(&self) -> TokenStream {
118 let ident = &self.ident;
119 let variants = self.variants.iter().map(
120 |Variant {
121 ident,
122 discriminant,
123 ..
124 }| {
125 let name = ident.to_string();
126 quote!(#discriminant => f.write_str(#name))
127 },
128 );
129 quote!(
130 impl ::core::fmt::Debug for #ident {
132 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
133 match self.inner {
134 #( #variants, )*
135 val => write!(f, "Unknown ({})", val),
136 }
137 }
138 }
139 )
140 }
141
142 fn write_conversions(&self) -> TokenStream {
144 let ident = &self.ident;
145 let repr = &self.repr;
146
147 quote!(
148 impl From<#repr> for #ident {
149 fn from(inner: #repr) -> Self {
150 Self { inner }
151 }
152 }
153
154 impl From<#ident> for #repr {
155 fn from(pod: #ident) -> Self {
156 pod.inner
157 }
158 }
159 )
160 }
161
162 fn write_partial_eq(&self) -> TokenStream {
164 let ident = &self.ident;
165 let variants = self
166 .variants
167 .iter()
168 .map(|Variant { discriminant, .. }| quote!((#discriminant, #discriminant) => true));
169
170 quote!(
171 impl PartialEq for #ident {
176 fn eq(&self, other: &Self) -> bool {
177 match (self.inner, other.inner) {
178 #( #variants, )*
179 _ => false,
180 }
181 }
182 }
183 )
184 }
185}
186
187impl TryFrom<syn::ItemEnum> for PodEnum {
191 type Error = TokenStream;
192
193 fn try_from(value: syn::ItemEnum) -> Result<Self, Self::Error> {
194 let ident = value.ident;
195 let repr = value
196 .attrs
197 .iter()
198 .find_map(|attr| {
199 if &attr.path().get_ident()?.to_string() != "repr" {
200 return None;
201 }
202 attr.parse_args::<Type>().ok()
203 })
204 .ok_or_else(|| {
205 syn::Error::new(ident.span(), "Missing `#[repr(..)]` attribute")
206 .into_compile_error()
207 })?;
208 let attrs = value
209 .attrs
210 .into_iter()
211 .filter(|attr| {
212 attr.path()
213 .get_ident()
214 .map_or(true, |name| &name.to_string() != "repr")
215 })
216 .collect();
217 let variants = value
218 .variants
219 .into_iter()
220 .map(|variant| {
221 let (docs, other_attrs) =
222 variant
223 .attrs
224 .into_iter()
225 .partition::<Vec<Attribute>, _>(|attr| {
226 attr.path()
227 .get_ident()
228 .map_or(false, |name| &name.to_string() == "doc")
229 });
230 if !other_attrs.is_empty() {
231 return Err(syn::Error::new(
232 variant.ident.span(),
233 "Unexpected non-documentation item on enum variant",
234 )
235 .into_compile_error());
236 }
237 if variant.fields != syn::Fields::Unit {
238 return Err(syn::Error::new(
239 variant.ident.span(),
240 "Unexpected non-unit enum variant",
241 )
242 .into_compile_error());
243 }
244 let discriminant = variant
245 .discriminant
246 .ok_or_else(|| {
247 syn::Error::new(
248 variant.ident.span(),
249 "Missing explicit discriminant on variant",
250 )
251 .into_compile_error()
252 })?
253 .1;
254 Ok(Variant {
255 ident: variant.ident,
256 discriminant,
257 documentation: docs,
258 })
259 })
260 .collect::<Result<Vec<Variant>, TokenStream>>()?;
261 Ok(Self {
262 vis: value.vis,
263 attrs,
264 ident,
265 repr,
266 variants,
267 })
268 }
269}
270
271#[doc = ""]
272#[proc_macro_attribute]
273pub fn pod_enum(_args: TokenStream1, input: TokenStream1) -> TokenStream1 {
274 let ast = parse_macro_input!(input as syn::ItemEnum);
275
276 let result = match PodEnum::try_from(ast) {
277 Ok(result) => result,
278 Err(e) => return e.into(),
279 };
280
281 result.write_impl().into()
282}