open-enum-derive 0.5.2

An attribute for generating "open" C-like enums (enums that accept any integer value) by using a newtype struct and associated constants.
Documentation
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate proc_macro;

mod config;
mod discriminant;
mod repr;

use config::Config;

use discriminant::Discriminant;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use repr::Repr;
use std::collections::HashSet;
use syn::Attribute;
use syn::{
    parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility,
};

/// Sets the span for every token tree in the token stream
fn set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream {
    tokens
        .into_iter()
        .map(|mut tt| {
            tt.set_span(span);
            tt
        })
        .collect()
}

/// Guards against two variants sharing one discriminant value.
///
/// Variants are visited in order. Literal discriminants are compared at macro
/// expansion time, so a repeated literal is reported as an `Err` located at the
/// offending variant's span. As soon as a non-literal discriminant is seen, the
/// function instead returns a private copy of the enum (named `_Check<Name>`),
/// which is emitted alongside the real output as a static check.
///
/// Returns an empty token stream when every discriminant is a distinct literal.
fn check_no_alias<'a>(
    enum_: &ItemEnum,
    variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone,
) -> syn::Result<TokenStream> {
    // Attributes that are kept on the checking enum; everything else is dropped.
    const KEPT_ATTRS: &[&str] = &["repr", "allow", "warn", "deny", "forbid"];

    let mut seen: HashSet<i128> = HashSet::new();
    for (_, discriminant, span) in variants {
        match discriminant {
            &Discriminant::Literal(value) => {
                if seen.insert(value) {
                    continue;
                }
                return Err(Error::new(
                    span,
                    format!("discriminant value `{value}` assigned more than once"),
                ));
            }
            _ => {
                // This discriminant is not a literal, so fall back to emitting
                // a renamed, private copy of the enum.
                let mut shadow = enum_.clone();
                shadow.ident = format_ident!("_Check{}", enum_.ident);
                shadow.vis = Visibility::Inherited;
                shadow.attrs.retain(|attr| {
                    let name = attr.path().to_token_stream().to_string();
                    KEPT_ATTRS.contains(&name.as_str())
                });
                return Ok(quote!(
                    #[allow(dead_code)]
                    #shadow
                ));
            }
        }
    }
    Ok(TokenStream::default())
}

fn emit_debug_impl<'a>(
    ident: &Ident,
    variants: impl Iterator<Item = &'a Ident> + Clone,
    attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone,
) -> TokenStream {
    let attrs = attrs.map(|attrs| {
        // Only allow "#[cfg(...)]" attributes
        let iter = attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
        quote!(#(#iter)*)
    });
    quote!(impl ::core::fmt::Debug for #ident {
        fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            #![allow(unreachable_patterns)]
            let s = match *self {
                #( #attrs Self::#variants => stringify!(#variants), )*
                _ => {
                    return fmt.debug_tuple(stringify!(#ident)).field(&self.0).finish();
                }
            };
            fmt.pad(s)
        }
    })
}

/// Reports whether `got_path` names the standard-library item described by
/// `expected_path_after_std`, a two-element `[module, name]` slice such as
/// `["fmt", "Debug"]`.
///
/// Accepted spellings are:
/// - `core::module::Name` / `std::module::Name`, with or without a leading `::`;
/// - `module::Name`, without a leading `::`;
/// - `Name`, without a leading `::`.
///
/// A path in which any segment carries generic or parenthesized arguments never
/// matches.
///
/// # Panics
///
/// Panics (via `unimplemented!`) if `expected_path_after_std` does not contain
/// exactly two elements.
fn path_matches_prelude_derive(
    got_path: &syn::Path,
    expected_path_after_std: &[&'static str],
) -> bool {
    let (module, name) = match expected_path_after_std {
        &[module, name] => (module, name),
        _ => unimplemented!("checking against stdlib paths with != 2 parts"),
    };

    let has_arguments = got_path
        .segments
        .iter()
        .any(|segment| !matches!(segment.arguments, syn::PathArguments::None));
    if has_arguments {
        return false;
    }

    // The shorter spellings are only accepted when not anchored with `::`.
    let rooted = got_path.leading_colon.is_some();
    let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect();
    match segments.as_slice() {
        // `core::fmt::Debug` or `some_crate::module::Name`
        [root, first, second] => {
            let from_std = root.ident == "core" || root.ident == "std";
            from_std && first.ident == module && second.ident == name
        }
        // `fmt::Debug` or `module::Name`
        [first, second] => first.ident == module && second.ident == name && !rooted,
        // `Debug` or `Name`
        [only] => only.ident == name && !rooted,
        _ => false,
    }
}

fn open_enum_impl(
    enum_: ItemEnum,
    Config {
        allow_alias,
        repr_visibility,
    }: Config,
) -> Result<TokenStream, Error> {
    // Does the enum define a `#[repr()]`?
    let mut struct_attrs: Vec<TokenStream> = Vec::with_capacity(enum_.attrs.len() + 5);
    struct_attrs.push(quote!(#[allow(clippy::exhaustive_structs)]));

    if !enum_.generics.params.is_empty() {
        return Err(Error::new(enum_.generics.span(), "enum cannot be generic"));
    }
    let mut variants = Vec::with_capacity(enum_.variants.len());
    let mut last_field = Discriminant::Literal(-1);
    for variant in &enum_.variants {
        if !matches!(variant.fields, syn::Fields::Unit) {
            return Err(Error::new(variant.span(), "enum cannot contain fields"));
        }

        let (value, value_span) = if let Some((_, discriminant)) = &variant.discriminant {
            let span = discriminant.span();
            (Discriminant::new(discriminant.clone())?, span)
        } else {
            last_field = last_field
                .next_value()
                .ok_or_else(|| Error::new(variant.span(), "enum discriminant overflowed"))?;
            (last_field.clone(), variant.ident.span())
        };
        last_field = value.clone();
        variants.push((&variant.ident, value, value_span, &variant.attrs))
    }

    let mut impl_attrs: Vec<TokenStream> = vec![quote!(#[allow(non_upper_case_globals)])];
    let mut explicit_repr: Option<Repr> = None;

    // To make `match` seamless, derive(PartialEq, Eq) if they aren't already.
    let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)];

    let mut make_custom_debug_impl = false;
    for attr in &enum_.attrs {
        let mut include_in_struct = true;
        // Turns out `is_ident` does a `to_string` every time
        match attr.path().to_token_stream().to_string().as_str() {
            "derive" => {
                if let Ok(derive_paths) =
                    attr.parse_args_with(Punctuated::<syn::Path, syn::Token![,]>::parse_terminated)
                {
                    for derive in &derive_paths {
                        // These derives are treated specially
                        const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
                        const EQ_PATH: &[&str] = &["cmp", "Eq"];
                        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

                        if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH)
                            || path_matches_prelude_derive(derive, EQ_PATH)
                        {
                            // This derive is always included, exclude it.
                            continue;
                        }
                        if path_matches_prelude_derive(derive, DEBUG_PATH) && !allow_alias {
                            make_custom_debug_impl = true;
                            // Don't include this derive since we're generating a special one.
                            continue;
                        }
                        extra_derives.push(derive.to_token_stream());
                    }
                    include_in_struct = false;
                }
            }
            // Copy linting attribute to the impl.
            "allow" | "warn" | "deny" | "forbid" => impl_attrs.push(attr.to_token_stream()),
            "repr" => {
                assert!(explicit_repr.is_none(), "duplicate explicit repr");
                explicit_repr = Some(attr.parse_args()?);
                include_in_struct = false;
            }
            "non_exhaustive" => {
                // technically it's exhaustive if the enum covers the full integer range
                return Err(Error::new(attr.path().span(), "`non_exhaustive` cannot be applied to an open enum; it is already non-exhaustive"));
            }
            _ => {}
        }
        if include_in_struct {
            struct_attrs.push(attr.to_token_stream());
        }
    }

    // The proper repr to type-check against
    let typecheck_repr: Repr = explicit_repr.unwrap_or(Repr::Isize);

    // The actual representation of the value.
    let inner_repr = match explicit_repr {
        Some(explicit_repr) => {
            // If there is an explicit repr, emit #[repr(transparent)].
            struct_attrs.push(quote!(#[repr(transparent)]));
            explicit_repr
        }
        None => {
            // If there isn't an explicit repr, determine an appropriate sized integer that will fit.
            // Interpret all discriminant expressions as isize.
            repr::autodetect_inner_repr(variants.iter().map(|v| &v.1))
        }
    };

    if !extra_derives.is_empty() {
        struct_attrs.push(quote!(#[derive(#(#extra_derives),*)]));
    }

    let alias_check = if allow_alias {
        TokenStream::default()
    } else {
        check_no_alias(&enum_, variants.iter().map(|(i, v, s, _)| (*i, v, *s)))?
    };

    let syn::ItemEnum { ident, vis, .. } = enum_;

    let debug_impl = if make_custom_debug_impl {
        emit_debug_impl(
            &ident,
            variants.iter().map(|(i, _, _, _)| *i),
            variants.iter().map(|(_, _, _, a)| *a),
        )
    } else {
        TokenStream::default()
    };

    let fields = variants
        .into_iter()
        .map(|(name, value, value_span, attrs)| {
            let mut value = value.into_token_stream();
            value = set_token_stream_span(value, value_span);
            let inner = if typecheck_repr == inner_repr {
                value
            } else {
                quote!(::core::convert::identity::<#typecheck_repr>(#value) as #inner_repr)
            };
            quote!(
                #(#attrs)*
                pub const #name: #ident = #ident(#inner);
            )
        });

    Ok(quote! {
        #(#struct_attrs)*
        #vis struct #ident(#repr_visibility #inner_repr);

        #(#impl_attrs)*
        impl #ident {
            #(
                #fields
            )*
        }
        #debug_impl
        #alias_check
    })
}

/// Attribute macro entry point: rewrites the annotated enum as an open enum.
///
/// `input` is parsed as an enum item and `attrs` as the macro's `Config`; a
/// parse failure in either is returned by `parse_macro_input!` as a compile
/// error. The parsed pieces are handed to `open_enum_impl`, and an error from
/// that expansion is likewise turned into compile-error tokens.
#[proc_macro_attribute]
pub fn open_enum(
    attrs: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let enum_ = parse_macro_input!(input as ItemEnum);
    let config = parse_macro_input!(attrs as Config);
    let expanded = match open_enum_impl(enum_, config) {
        Ok(tokens) => tokens,
        Err(err) => err.into_compile_error(),
    };
    proc_macro::TokenStream::from(expanded)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which spellings of a derive path are recognized as `fmt::Debug`.
    #[test]
    fn test_path_matches_stdlib_derive() {
        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

        // Spellings that must be recognized as the standard `Debug`.
        const ACCEPTED: [&str; 6] = [
            "::core::fmt::Debug",
            "::std::fmt::Debug",
            "core::fmt::Debug",
            "std::fmt::Debug",
            "fmt::Debug",
            "Debug",
        ];
        // Anchored short paths and unrelated derives must be rejected.
        const REJECTED: [&str; 6] = [
            "::fmt::Debug",
            "::Debug",
            "zerocopy::AsBytes",
            "::zerocopy::AsBytes",
            "PartialEq",
            "core::cmp::Eq",
        ];

        let is_debug = |text: &str| -> bool {
            let path: syn::Path = syn::parse_str(text).unwrap();
            path_matches_prelude_derive(&path, DEBUG_PATH)
        };

        for success_case in ACCEPTED {
            assert!(is_debug(success_case), "{success_case}");
        }

        for fail_case in REJECTED {
            assert!(!is_debug(fail_case), "{fail_case}");
        }
    }
}