discriminant-macro 0.1.0

Better `Discriminant` and enum operations
Documentation
use proc_macro::TokenStream as TokenStream1;
use proc_macro2::{Span, TokenStream};
use proc_macro_error::{abort, proc_macro_error};
use syn::punctuated::Punctuated;
use syn::*;
use template_quote::quote;

/// Produces a pseudo-random `u64` without extra dependencies.
///
/// A fresh `RandomState` carries per-instance random SipHash keys. Finishing a
/// hasher that was fed no data therefore yields a value that differs between
/// calls (and between compilations). It is used only to make generated type
/// names unlikely to collide with user items.
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let state = RandomState::new();
    let hasher = state.build_hasher();
    hasher.finish()
}

fn internal(input: ItemEnum) -> TokenStream {
    let krate: Path = input
        .attrs
        .iter()
        .filter_map(|a| match &a.meta {
            Meta::List(MetaList { path, tokens, .. }) => {
                if let (true, krate) = (path.is_ident("discriminant"), parse_quote!(#tokens)) {
                    Some(krate)
                } else {
                    None
                }
            }
            _ => None,
        })
        .next()
        .unwrap_or(parse_quote!(::discriminant));
    let discriminant_attrs = input
        .attrs
        .iter()
        .filter_map(|a| match &a.meta {
            Meta::NameValue(MetaNameValue { path, value, .. })
                if path.is_ident("discriminant_attr") =>
            {
                let s: LitStr = parse2(quote! {#value}).unwrap();
                Some(s.value())
            }
            _ => None,
        })
        .collect::<Vec<_>>();
    let discriminant_attrs = core::convert::identity::<ItemStruct>(
        parse_str(&format!("{} struct S {{}}", discriminant_attrs.join(""))).unwrap(),
    )
    .attrs;
    let specified_repr = discriminant_attrs
        .iter()
        .chain(&input.attrs)
        .filter_map(|a| match &a.meta {
            Meta::List(MetaList { path, tokens, .. }) if path.is_ident("repr") => {
                if let Ok(reprs) = parse::Parser::parse2(
                    Punctuated::<Meta, Token![,]>::parse_terminated,
                    tokens.clone(),
                ) {
                    reprs
                        .iter()
                        .filter_map(|r| Some(r.path().get_ident()?.to_string()))
                        .filter_map(|r| match r.as_str() {
                            "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32"
                            | "i64" | "isize" => Some(Ident::new(&r, Span::call_site())),
                            _ => None,
                        })
                        .next()
                } else {
                    None
                }
            }
            _ => None,
        })
        .next();
    let repr = specified_repr
        .clone()
        .unwrap_or(Ident::new("isize", Span::call_site()));
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let discriminant_enum_ident = Ident::new(
        &format!("__Discriminant_{}_{}", &input.ident, random() % 1000),
        Span::call_site(),
    );
    let disc_indices = input
        .variants
        .iter()
        .scan(parse_quote!(0), |acc, variant| {
            if let Some((_, expr)) = &variant.discriminant {
                *acc = expr.clone();
            }
            let ret = acc.clone();
            *acc = parse_quote!(#ret + 1);
            Some(ret)
        })
        .collect::<Vec<Expr>>();
    quote! {
        #[repr(#repr)]
        #(#discriminant_attrs)*
        #[derive(
            ::core::marker::Copy,
            ::core::clone::Clone,
            ::core::fmt::Debug,
            ::core::hash::Hash,
            ::core::cmp::PartialEq,
            ::core::cmp::Eq,
        )]
        #{&input.vis} enum #discriminant_enum_ident {
            #(for variant in &input.variants) {
                #{
                    variant.attrs.iter().filter_map(|a| match &a.meta {
                        Meta::NameValue(MetaNameValue{path, value, ..}) if path.is_ident("discriminant_attr") => {
                            let s: LitStr = parse2(quote! {#value}).unwrap();
                            let discriminant_attrs = core::convert::identity::<ItemStruct>(
                                parse_str(&format!("{} struct S {{}}", s.value())).unwrap()
                            ).attrs;
                            Some(quote!{#(#discriminant_attrs)*})
                        },
                        _ => None,
                    }).next()
                }
                #{&variant.ident}
                #(if let Some((eq_token, expr)) = &variant.discriminant) {
                    #eq_token #expr
                },
            }
        }

        impl ::core::fmt::Display for #discriminant_enum_ident {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                <Self as ::core::fmt::Debug>::fmt(self, f)
            }
        }

        impl ::core::cmp::PartialOrd for #discriminant_enum_ident {
            fn partial_cmp(&self, other: &Self) -> ::core::option::Option<::core::cmp::Ordering> {
                (*self as #repr).partial_cmp(&(*other as #repr))
            }
        }

        impl ::core::cmp::Ord for #discriminant_enum_ident {
            fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
                (*self as #repr).cmp(&(*other as #repr))
            }
        }

        #[automatically_derived]
        unsafe impl #impl_generics #krate::Enum for #{&input.ident}
        #ty_generics #where_clause
        {
            type Discriminant = #discriminant_enum_ident;

            fn discriminant(&self) -> Self::Discriminant {
                match self {
                    #(for Variant{ident, fields, ..} in &input.variants) {
                        Self::#ident
                        #(if let Fields::Unnamed(_) = fields) { (..) }
                        #(if let Fields::Named(_) = fields) { {..} }
                        => #discriminant_enum_ident::#ident,
                    }
                }
            }
        }

        impl ::core::convert::TryFrom<#repr> for #discriminant_enum_ident {
            type Error = ();
            fn try_from(value: #repr) -> ::core::result::Result<Self, Self::Error> {
                #(for (variant, disc) in input.variants.iter().zip(&disc_indices)) {
                    if value == #disc { ::core::result::Result::Ok(Self::#{&variant.ident}) } else
                }
                { ::core::result::Result::Err(()) }
            }
        }

        impl ::core::convert::Into<#repr> for #discriminant_enum_ident {
            fn into(self) -> #repr {
                self as #repr
            }
        }

        unsafe impl #krate::Discriminant for #discriminant_enum_ident {
            type Repr = #repr;
            fn all() -> impl ::core::iter::Iterator<Item = Self> {
                struct Iter(::core::option::Option<#discriminant_enum_ident>);
                impl ::core::iter::Iterator for Iter {
                    type Item = #discriminant_enum_ident;
                    fn next(&mut self) -> Option<Self::Item> {
                        match self.0 {
                            #(for (curr, next) in input.variants.iter().zip(
                                    input.variants.iter().skip(1).map(Some).chain(core::iter::once(None))
                            )) {
                                ::core::option::Option::Some(#discriminant_enum_ident::#{&curr.ident}) => {
                                    let ret = self.0;
                                    self.0 = #(if let Some(next) = next) {
                                        Some(#discriminant_enum_ident::#{&next.ident})
                                    } #(else) { None };
                                    ret
                                }
                            }
                            ::core::option::Option::None => ::core::option::Option::None,
                        }
                    }
                    fn size_hint(&self) -> (
                        ::core::primitive::usize,
                        ::core::option::Option<::core::primitive::usize>
                    ) {
                        let n = Self(self.0).count();
                        (n, ::core::option::Option::Some(n))
                    }
                    fn count(self) -> usize {
                        match self.0 {
                            #(for (n, variant) in input.variants.iter().enumerate()) {
                                ::core::option::Option::Some(#discriminant_enum_ident::#{&variant.ident}) => #{disc_indices.len() - n},
                            }
                            ::core::option::Option::None => 0,
                        }
                    }
                    fn last(self) -> Option<Self::Item> {
                        #(if let Some(last) = &input.variants.iter().last()) {
                            self.0.map(|_| #discriminant_enum_ident::#{&last.ident})
                        } #(else) {
                            ::core::option::Option::None
                        }
                    }
                }
                #(if let Some(item) = input.variants.iter().next()) {
                    Iter(::core::option::Option::Some(#discriminant_enum_ident::#{&item.ident}))
                } #(else) {
                    Iter(::core::option::Option::None)
                }
            }
        }
    }
}

#[proc_macro_derive(Enum, attributes(discriminant, discriminant_attr))]
#[proc_macro_error]
pub fn derive_enum(input: TokenStream1) -> TokenStream1 {
    internal(parse(input).unwrap_or_else(|_| {
        abort!(
            Span::call_site(),
            "#[derive(Enum)] is only applicative on enums."
        )
    }))
    .into()
}