extern crate proc_macro;
mod config;
mod discriminant;
mod repr;
use config::Config;
use discriminant::Discriminant;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use repr::Repr;
use std::collections::HashSet;
use syn::Attribute;
use syn::{
parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility,
};
fn set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream {
tokens
.into_iter()
.map(|mut tt| {
tt.set_span(span);
tt
})
.collect()
}
fn check_no_alias<'a>(
enum_: &ItemEnum,
variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone,
) -> syn::Result<TokenStream> {
let mut values: HashSet<i128> = HashSet::new();
for (_, variant, span) in variants {
if let &Discriminant::Literal(value) = variant {
if !values.insert(value) {
return Err(Error::new(
span,
format!("discriminant value `{value}` assigned more than once"),
));
}
} else {
let mut checking_enum = syn::ItemEnum {
ident: format_ident!("_Check{}", enum_.ident),
vis: Visibility::Inherited,
..enum_.clone()
};
checking_enum.attrs.retain(|attr| {
matches!(
attr.path().to_token_stream().to_string().as_str(),
"repr" | "allow" | "warn" | "deny" | "forbid"
)
});
return Ok(quote!(
#[allow(dead_code)]
#checking_enum
));
}
}
Ok(TokenStream::default())
}
/// Builds an `impl ::core::fmt::Debug for #ident` for the generated newtype.
///
/// `variants` and `attrs` are consumed in lockstep, one attribute list per variant name.
/// A value equal to one of the variant constants is written as that variant's name via
/// `Formatter::pad`. Any other value falls back to `debug_tuple`, printing
/// `#ident(<inner value>)`. Only `#[cfg(...)]` attributes are copied onto each match arm,
/// so arms for compiled-out variants are removed along with their constants.
fn emit_debug_impl<'a>(
    ident: &Ident,
    variants: impl Iterator<Item = &'a Ident> + Clone,
    attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone,
) -> TokenStream {
    let arms = attrs.zip(variants).map(|(variant_attrs, name)| {
        let cfgs = variant_attrs
            .iter()
            .filter(|attr| attr.path().is_ident("cfg"));
        quote!(#(#cfgs)* Self::#name => stringify!(#name),)
    });
    quote! {
        impl ::core::fmt::Debug for #ident {
            fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                // NOTE(review): presumably silences warnings about the trailing wildcard arm
                // in the generated match.
                #![allow(unreachable_patterns)]
                let s = match *self {
                    #(#arms)*
                    _ => {
                        return fmt.debug_tuple(stringify!(#ident)).field(&self.0).finish();
                    }
                };
                fmt.pad(s)
            }
        }
    }
}
fn path_matches_prelude_derive(
got_path: &syn::Path,
expected_path_after_std: &[&'static str],
) -> bool {
let &[a, b] = expected_path_after_std else {
unimplemented!("checking against stdlib paths with != 2 parts");
};
let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect();
if segments
.iter()
.any(|segment| !matches!(segment.arguments, syn::PathArguments::None))
{
return false;
}
match &segments[..] {
[maybe_core_or_std, maybe_a, maybe_b] => {
(maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std")
&& maybe_a.ident == a
&& maybe_b.ident == b
}
[maybe_a, maybe_b] => {
maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none()
}
[maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(),
_ => false,
}
}
fn open_enum_impl(
enum_: ItemEnum,
Config {
allow_alias,
repr_visibility,
}: Config,
) -> Result<TokenStream, Error> {
let mut struct_attrs: Vec<TokenStream> = Vec::with_capacity(enum_.attrs.len() + 5);
struct_attrs.push(quote!(#[allow(clippy::exhaustive_structs)]));
if !enum_.generics.params.is_empty() {
return Err(Error::new(enum_.generics.span(), "enum cannot be generic"));
}
let mut variants = Vec::with_capacity(enum_.variants.len());
let mut last_field = Discriminant::Literal(-1);
for variant in &enum_.variants {
if !matches!(variant.fields, syn::Fields::Unit) {
return Err(Error::new(variant.span(), "enum cannot contain fields"));
}
let (value, value_span) = if let Some((_, discriminant)) = &variant.discriminant {
let span = discriminant.span();
(Discriminant::new(discriminant.clone())?, span)
} else {
last_field = last_field
.next_value()
.ok_or_else(|| Error::new(variant.span(), "enum discriminant overflowed"))?;
(last_field.clone(), variant.ident.span())
};
last_field = value.clone();
variants.push((&variant.ident, value, value_span, &variant.attrs))
}
let mut impl_attrs: Vec<TokenStream> = vec![quote!(#[allow(non_upper_case_globals)])];
let mut explicit_repr: Option<Repr> = None;
let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)];
let mut make_custom_debug_impl = false;
for attr in &enum_.attrs {
let mut include_in_struct = true;
match attr.path().to_token_stream().to_string().as_str() {
"derive" => {
if let Ok(derive_paths) =
attr.parse_args_with(Punctuated::<syn::Path, syn::Token![,]>::parse_terminated)
{
for derive in &derive_paths {
const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
const EQ_PATH: &[&str] = &["cmp", "Eq"];
const DEBUG_PATH: &[&str] = &["fmt", "Debug"];
if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH)
|| path_matches_prelude_derive(derive, EQ_PATH)
{
continue;
}
if path_matches_prelude_derive(derive, DEBUG_PATH) && !allow_alias {
make_custom_debug_impl = true;
continue;
}
extra_derives.push(derive.to_token_stream());
}
include_in_struct = false;
}
}
"allow" | "warn" | "deny" | "forbid" => impl_attrs.push(attr.to_token_stream()),
"repr" => {
assert!(explicit_repr.is_none(), "duplicate explicit repr");
explicit_repr = Some(attr.parse_args()?);
include_in_struct = false;
}
"non_exhaustive" => {
return Err(Error::new(attr.path().span(), "`non_exhaustive` cannot be applied to an open enum; it is already non-exhaustive"));
}
_ => {}
}
if include_in_struct {
struct_attrs.push(attr.to_token_stream());
}
}
let typecheck_repr: Repr = explicit_repr.unwrap_or(Repr::Isize);
let inner_repr = match explicit_repr {
Some(explicit_repr) => {
struct_attrs.push(quote!(#[repr(transparent)]));
explicit_repr
}
None => {
repr::autodetect_inner_repr(variants.iter().map(|v| &v.1))
}
};
if !extra_derives.is_empty() {
struct_attrs.push(quote!(#[derive(#(#extra_derives),*)]));
}
let alias_check = if allow_alias {
TokenStream::default()
} else {
check_no_alias(&enum_, variants.iter().map(|(i, v, s, _)| (*i, v, *s)))?
};
let syn::ItemEnum { ident, vis, .. } = enum_;
let debug_impl = if make_custom_debug_impl {
emit_debug_impl(
&ident,
variants.iter().map(|(i, _, _, _)| *i),
variants.iter().map(|(_, _, _, a)| *a),
)
} else {
TokenStream::default()
};
let fields = variants
.into_iter()
.map(|(name, value, value_span, attrs)| {
let mut value = value.into_token_stream();
value = set_token_stream_span(value, value_span);
let inner = if typecheck_repr == inner_repr {
value
} else {
quote!(::core::convert::identity::<#typecheck_repr>(#value) as #inner_repr)
};
quote!(
#(#attrs)*
pub const #name: #ident = #ident(#inner);
)
});
Ok(quote! {
#(#struct_attrs)*
#vis struct #ident(#repr_visibility #inner_repr);
#(#impl_attrs)*
impl #ident {
#(
#fields
)*
}
#debug_impl
#alias_check
})
}
/// Attribute macro entry point: turns the annotated `enum` into an "open" newtype struct
/// with one associated constant per variant.
///
/// The item is parsed as a `syn::ItemEnum` and the attribute arguments as a `Config`. A
/// parse failure, or any error from the expansion, is emitted as a compile error.
#[proc_macro_attribute]
pub fn open_enum(
    attrs: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let enum_ = parse_macro_input!(input as syn::ItemEnum);
    let config = parse_macro_input!(attrs as Config);
    let expanded = match open_enum_impl(enum_, config) {
        Ok(tokens) => tokens,
        Err(err) => err.into_compile_error(),
    };
    expanded.into()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `path` as a `syn::Path`, panicking with the offending input on failure.
    fn parse_path(path: &str) -> syn::Path {
        syn::parse_str(path).unwrap_or_else(|err| panic!("failed to parse `{path}`: {err}"))
    }

    #[test]
    fn test_path_matches_stdlib_derive() {
        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];
        for success_case in [
            "::core::fmt::Debug",
            "::std::fmt::Debug",
            "core::fmt::Debug",
            "std::fmt::Debug",
            "fmt::Debug",
            "Debug",
        ] {
            assert!(
                path_matches_prelude_derive(&parse_path(success_case), DEBUG_PATH),
                "{success_case}"
            );
        }
        for fail_case in [
            "::fmt::Debug",
            "::Debug",
            "zerocopy::AsBytes",
            "::zerocopy::AsBytes",
            "PartialEq",
            "core::cmp::Eq",
            // A three-segment path must be rooted at `core` or `std`.
            "other::fmt::Debug",
            "::other::fmt::Debug",
            // Paths carrying generic arguments are never treated as the std derive.
            "Debug<u8>",
            "core::fmt::Debug<u8>",
        ] {
            assert!(
                !path_matches_prelude_derive(&parse_path(fail_case), DEBUG_PATH),
                "{fail_case}"
            );
        }
    }

    #[test]
    fn test_path_matches_cmp_derives() {
        const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
        const EQ_PATH: &[&str] = &["cmp", "Eq"];
        assert!(path_matches_prelude_derive(&parse_path("PartialEq"), PARTIAL_EQ_PATH));
        assert!(path_matches_prelude_derive(
            &parse_path("::core::cmp::PartialEq"),
            PARTIAL_EQ_PATH
        ));
        assert!(path_matches_prelude_derive(&parse_path("Eq"), EQ_PATH));
        assert!(path_matches_prelude_derive(&parse_path("std::cmp::Eq"), EQ_PATH));
        assert!(!path_matches_prelude_derive(&parse_path("PartialEq"), EQ_PATH));
        assert!(!path_matches_prelude_derive(&parse_path("Eq"), PARTIAL_EQ_PATH));
        assert!(!path_matches_prelude_derive(&parse_path("fmt::Eq"), EQ_PATH));
    }
}