pod_enum_macros/
lib.rs

1//! A proc-macro for the `pod-enum` crate
2//!
3//! Consider importing that crate instead
4
5use proc_macro::TokenStream as TokenStream1;
6
7use proc_macro2::{Ident, TokenStream};
8use quote::quote;
9use syn::{parse_macro_input, Attribute, Expr, Type};
10
11/// A variant of a [`PodEnum`]
12///
13/// This has already been parsed to ensure that no errors are possible during code generation.
14struct Variant {
15    /// The name of this variant
16    ident: Ident,
17    /// The discriminant given for this variant
18    discriminant: Expr,
19    /// The documentation attributes on this variant
20    documentation: Vec<Attribute>,
21}
22
23/// An enum given to this macro
24///
25/// This has already been parsed to ensure that no errors are possible during code generation.
26struct PodEnum {
27    /// The visibility of this enum (e.g. `pub` or `pub(crate)`)
28    ///
29    /// We forward this visibility to the new type we generate and all of the variants.
30    vis: syn::Visibility,
31    /// The name of this type
32    ident: Ident,
33    /// The type used for the representation
34    repr: Type,
35    /// The variants of this enum, along with their discriminants and documentation
36    variants: Vec<Variant>,
37    /// The attributes to apply to the output
38    ///
39    /// This is all attributes on the input except the `#[repr(..)]` attribute, which is filtered
40    /// out to be handled separately.
41    attrs: Vec<Attribute>,
42}
43
44/// Code generation
45impl PodEnum {
46    /// Write all methods associated with this type
47    fn write_impl(&self) -> TokenStream {
48        let ident = &self.ident;
49        let repr = &self.repr;
50        let vis = &self.vis;
51        let attrs = &self.attrs;
52
53        let variants = self.write_variants();
54        let debug = self.write_debug();
55        let conversions = self.write_conversions();
56        let partial_eq = self.write_partial_eq();
57
58        quote!(
59            #( #attrs )*
60            #[derive(Copy, Clone)]
61            #[repr(transparent)]
62            #vis struct #ident {
63                inner: #repr,
64            }
65
66            impl ::pod_enum::PodEnum for #ident {
67                type Repr = #repr;
68            }
69
70            // SAFETY:
71            // The `PodEnum` trait (implemented above) checks that our internal type is
72            // `Pod`, and since we're #[`repr(transparent)]` with one field, we can also
73            // implement `Pod`.
74            unsafe impl ::pod_enum::bytemuck::Pod for #ident {}
75            // SAFETY:
76            // The `PodEnum` trait (implemented above) checks that our internal type is
77            // `Pod` (which implies `Zeroable`), and since we're #[`repr(transparent)]`
78            // with one field, we can also implement `Zeroable`.
79            unsafe impl ::pod_enum::bytemuck::Zeroable for #ident {}
80
81            #variants
82
83            #debug
84
85            #conversions
86
87            #partial_eq
88        )
89    }
90
91    /// Write out all variants of this enum as constants
92    fn write_variants(&self) -> TokenStream {
93        let ident = &self.ident;
94        let vis = &self.vis;
95        let variants = self.variants.iter().map(
96            |Variant {
97                 ident,
98                 discriminant,
99                 documentation,
100             }| {
101                quote!(
102                    #( #documentation )*
103                    #vis const #ident: Self = Self { inner: #discriminant };
104                )
105            },
106        );
107        quote! {
108            /// The variants of this enum
109            #[allow(non_upper_case_globals)]
110            impl #ident {
111                #( #variants )*
112            }
113        }
114    }
115
116    /// Write the debug impl
117    fn write_debug(&self) -> TokenStream {
118        let ident = &self.ident;
119        let variants = self.variants.iter().map(
120            |Variant {
121                 ident,
122                 discriminant,
123                 ..
124             }| {
125                let name = ident.to_string();
126                quote!(#discriminant => f.write_str(#name))
127            },
128        );
129        quote!(
130            /// Display which variant this is, or call it unknown and show the discriminant
131            impl ::core::fmt::Debug for #ident {
132                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
133                    match self.inner {
134                        #( #variants, )*
135                        val => write!(f, "Unknown ({})", val),
136                    }
137                }
138            }
139        )
140    }
141
142    /// Write conversions to and from the underlying base type
143    fn write_conversions(&self) -> TokenStream {
144        let ident = &self.ident;
145        let repr = &self.repr;
146
147        quote!(
148            impl From<#repr> for #ident {
149                fn from(inner: #repr) -> Self {
150                    Self { inner }
151                }
152            }
153
154            impl From<#ident> for #repr {
155                fn from(pod: #ident) -> Self {
156                    pod.inner
157                }
158            }
159        )
160    }
161
162    /// Write [`PartialEq`] implementation
163    fn write_partial_eq(&self) -> TokenStream {
164        let ident = &self.ident;
165        let variants = self
166            .variants
167            .iter()
168            .map(|Variant { discriminant, .. }| quote!((#discriminant, #discriminant) => true));
169
170        quote!(
171            /// Listed variants compare as equaling itself and unequal to anything else
172            ///
173            /// Any variant not listed compares as unequal to everything, including itself.
174            /// Thus, we only implement `PartialEq` and not [`Eq`].
175            impl PartialEq for #ident {
176                fn eq(&self, other: &Self) -> bool {
177                    match (self.inner, other.inner) {
178                        #( #variants, )*
179                        _ => false,
180                    }
181                }
182            }
183        )
184    }
185}
186
187/// Attempt to parse the input from a Rust enum definition
188///
189/// In the event of an error, we try to return as many compile errors as we can.
190impl TryFrom<syn::ItemEnum> for PodEnum {
191    type Error = TokenStream;
192
193    fn try_from(value: syn::ItemEnum) -> Result<Self, Self::Error> {
194        let ident = value.ident;
195        let repr = value
196            .attrs
197            .iter()
198            .find_map(|attr| {
199                if &attr.path().get_ident()?.to_string() != "repr" {
200                    return None;
201                }
202                attr.parse_args::<Type>().ok()
203            })
204            .ok_or_else(|| {
205                syn::Error::new(ident.span(), "Missing `#[repr(..)]` attribute")
206                    .into_compile_error()
207            })?;
208        let attrs = value
209            .attrs
210            .into_iter()
211            .filter(|attr| {
212                attr.path()
213                    .get_ident()
214                    .map_or(true, |name| &name.to_string() != "repr")
215            })
216            .collect();
217        let variants = value
218            .variants
219            .into_iter()
220            .map(|variant| {
221                let (docs, other_attrs) =
222                    variant
223                        .attrs
224                        .into_iter()
225                        .partition::<Vec<Attribute>, _>(|attr| {
226                            attr.path()
227                                .get_ident()
228                                .map_or(false, |name| &name.to_string() == "doc")
229                        });
230                if !other_attrs.is_empty() {
231                    return Err(syn::Error::new(
232                        variant.ident.span(),
233                        "Unexpected non-documentation item on enum variant",
234                    )
235                    .into_compile_error());
236                }
237                if variant.fields != syn::Fields::Unit {
238                    return Err(syn::Error::new(
239                        variant.ident.span(),
240                        "Unexpected non-unit enum variant",
241                    )
242                    .into_compile_error());
243                }
244                let discriminant = variant
245                    .discriminant
246                    .ok_or_else(|| {
247                        syn::Error::new(
248                            variant.ident.span(),
249                            "Missing explicit discriminant on variant",
250                        )
251                        .into_compile_error()
252                    })?
253                    .1;
254                Ok(Variant {
255                    ident: variant.ident,
256                    discriminant,
257                    documentation: docs,
258                })
259            })
260            .collect::<Result<Vec<Variant>, TokenStream>>()?;
261        Ok(Self {
262            vis: value.vis,
263            attrs,
264            ident,
265            repr,
266            variants,
267        })
268    }
269}
270
271#[doc = ""]
272#[proc_macro_attribute]
273pub fn pod_enum(_args: TokenStream1, input: TokenStream1) -> TokenStream1 {
274    let ast = parse_macro_input!(input as syn::ItemEnum);
275
276    let result = match PodEnum::try_from(ast) {
277        Ok(result) => result,
278        Err(e) => return e.into(),
279    };
280
281    result.write_impl().into()
282}