use proc_macro::TokenStream as TokenStream1;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_macro_input, Attribute, Expr, Type};
struct Variant {
ident: Ident,
discriminant: Expr,
documentation: Vec<Attribute>,
}
/// Parsed description of the enum that `#[pod_enum]` was applied to.
struct PodEnum {
    /// Attributes of the enum other than `#[repr(..)]`; forwarded to the
    /// generated struct.
    attrs: Vec<syn::Attribute>,
    /// Visibility of the enum, reused for the struct and its constants.
    vis: syn::Visibility,
    /// Name of the enum, reused as the name of the generated struct.
    ident: Ident,
    /// Type named inside the enum's `#[repr(..)]` attribute; becomes the type
    /// of the struct's single `inner` field.
    repr: syn::Type,
    /// The enum's variants, in declaration order.
    variants: Vec<Variant>,
}
impl PodEnum {
    /// Produces the full expansion of the macro.
    ///
    /// The output consists of:
    /// * a `#[repr(transparent)]`, `Copy + Clone` struct named after the enum
    ///   that wraps a single `inner` field of the `repr` type,
    /// * `::pod_enum::PodEnum`, `bytemuck::Pod` and `bytemuck::Zeroable` impls,
    /// * one associated constant per variant,
    /// * `Debug`, `From` (both directions) and `PartialEq` impls.
    ///
    /// NOTE(review): the `unsafe impl`s of `Pod`/`Zeroable` are emitted
    /// unconditionally; presumably the `::pod_enum::PodEnum` trait constrains
    /// `Repr` so that this is sound — confirm against that trait's definition.
    fn write_impl(&self) -> TokenStream {
        let ident = &self.ident;
        let repr = &self.repr;
        let vis = &self.vis;
        let attrs = &self.attrs;
        let variants = self.write_variants();
        let debug = self.write_debug();
        let conversions = self.write_conversions();
        let partial_eq = self.write_partial_eq();
        quote!(
            #( #attrs )*
            #[derive(Copy, Clone)]
            #[repr(transparent)]
            #vis struct #ident {
                inner: #repr,
            }
            impl ::pod_enum::PodEnum for #ident {
                type Repr = #repr;
            }
            unsafe impl ::pod_enum::bytemuck::Pod for #ident {}
            unsafe impl ::pod_enum::bytemuck::Zeroable for #ident {}
            #variants
            #debug
            #conversions
            #partial_eq
        )
    }

    /// Emits an inherent `impl` holding one associated constant per variant.
    ///
    /// Each constant keeps the variant's name, documentation and the enum's
    /// visibility, and wraps the variant's discriminant expression.
    fn write_variants(&self) -> TokenStream {
        let ident = &self.ident;
        let vis = &self.vis;
        // Inside the closure `ident` is shadowed by the *variant's* identifier.
        let variants = self.variants.iter().map(
            |Variant {
                 ident,
                 discriminant,
                 documentation,
             }| {
                quote!(
                    #( #documentation )*
                    #vis const #ident: Self = Self { inner: #discriminant };
                )
            },
        );
        // Variant names are usually CamelCase, which would otherwise trigger
        // the `non_upper_case_globals` lint on the generated constants.
        quote! {
            #[allow(non_upper_case_globals)]
            impl #ident {
                #( #variants )*
            }
        }
    }

    /// Emits the `Debug` impl.
    ///
    /// A value whose `inner` matches a listed discriminant prints the variant
    /// name; any other value prints `Unknown (<value>)`, formatting the raw
    /// value with `{}`.
    fn write_debug(&self) -> TokenStream {
        let ident = &self.ident;
        let variants = self.variants.iter().map(
            |Variant {
                 ident,
                 discriminant,
                 ..
             }| {
                let name = ident.to_string();
                quote!(#discriminant => f.write_str(#name))
            },
        );
        // `write!` is spelled `::core::write!` so the expansion does not depend
        // on what `write` resolves to at the call site.
        quote!(
            impl ::core::fmt::Debug for #ident {
                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                    match self.inner {
                        #( #variants, )*
                        val => ::core::write!(f, "Unknown ({})", val),
                    }
                }
            }
        )
    }

    /// Emits infallible `From` conversions between the generated struct and
    /// its `repr` type, in both directions.
    fn write_conversions(&self) -> TokenStream {
        let ident = &self.ident;
        let repr = &self.repr;
        // The trait is named by absolute path so that a user-defined `From` in
        // scope at the call site cannot be picked up by mistake.
        quote!(
            impl ::core::convert::From<#repr> for #ident {
                fn from(inner: #repr) -> Self {
                    Self { inner }
                }
            }
            impl ::core::convert::From<#ident> for #repr {
                fn from(pod: #ident) -> Self {
                    pod.inner
                }
            }
        )
    }

    /// Emits the `PartialEq` impl.
    ///
    /// Two values compare equal only when both hold the *same listed*
    /// discriminant. Values that match no variant compare unequal to
    /// everything, including themselves.
    fn write_partial_eq(&self) -> TokenStream {
        let ident = &self.ident;
        let variants = self
            .variants
            .iter()
            .map(|Variant { discriminant, .. }| quote!((#discriminant, #discriminant) => true));
        // Absolute trait path for the same call-site-independence reason as
        // the other generated impls.
        quote!(
            impl ::core::cmp::PartialEq for #ident {
                fn eq(&self, other: &Self) -> bool {
                    match (self.inner, other.inner) {
                        #( #variants, )*
                        _ => false,
                    }
                }
            }
        )
    }
}
impl TryFrom<syn::ItemEnum> for PodEnum {
type Error = TokenStream;
fn try_from(value: syn::ItemEnum) -> Result<Self, Self::Error> {
let ident = value.ident;
let repr = value
.attrs
.iter()
.find_map(|attr| {
if &attr.path().get_ident()?.to_string() != "repr" {
return None;
}
attr.parse_args::<Type>().ok()
})
.ok_or_else(|| {
syn::Error::new(ident.span(), "Missing `#[repr(..)]` attribute")
.into_compile_error()
})?;
let attrs = value
.attrs
.into_iter()
.filter(|attr| {
attr.path()
.get_ident()
.map_or(true, |name| &name.to_string() != "repr")
})
.collect();
let variants = value
.variants
.into_iter()
.map(|variant| {
let (docs, other_attrs) =
variant
.attrs
.into_iter()
.partition::<Vec<Attribute>, _>(|attr| {
attr.path()
.get_ident()
.map_or(false, |name| &name.to_string() == "doc")
});
if !other_attrs.is_empty() {
return Err(syn::Error::new(
variant.ident.span(),
"Unexpected non-documentation item on enum variant",
)
.into_compile_error());
}
if variant.fields != syn::Fields::Unit {
return Err(syn::Error::new(
variant.ident.span(),
"Unexpected non-unit enum variant",
)
.into_compile_error());
}
let discriminant = variant
.discriminant
.ok_or_else(|| {
syn::Error::new(
variant.ident.span(),
"Missing explicit discriminant on variant",
)
.into_compile_error()
})?
.1;
Ok(Variant {
ident: variant.ident,
discriminant,
documentation: docs,
})
})
.collect::<Result<Vec<Variant>, TokenStream>>()?;
Ok(Self {
vis: value.vis,
attrs,
ident,
repr,
variants,
})
}
}
/// Attribute macro that turns a field-less enum into a plain-old-data struct.
///
/// The annotated enum must carry a `#[repr(..)]` attribute naming a type, and
/// every variant must be a unit variant with an explicit discriminant and no
/// attributes other than documentation. Arguments passed to the attribute
/// itself are ignored.
///
/// The enum is replaced by a `#[repr(transparent)]` struct of the same name
/// wrapping the `repr` type, with one associated constant per variant plus
/// `Debug`, `PartialEq` and `From` conversions to and from the `repr` type.
///
/// If the input is not an enum or fails validation, the expansion is a
/// compile error instead.
#[proc_macro_attribute]
pub fn pod_enum(_args: TokenStream1, input: TokenStream1) -> TokenStream1 {
    let item = parse_macro_input!(input as syn::ItemEnum);
    match PodEnum::try_from(item) {
        Ok(pod) => pod.write_impl().into(),
        Err(diagnostic) => diagnostic.into(),
    }
}