#![deny(missing_docs)]
extern crate proc_macro;
use proc_macro::TokenStream as NativeTokenStream;
use proc_macro2::{Delimiter, Span, TokenStream, TokenTree};
use quote::quote;
/// One variant of the enum being derived: the data needed to emit one arm of the
/// generated `From<repr>` conversion.
struct EnumVariant {
/// The variant's identifier, e.g. `Foo` in `Foo = 1`.
name: syn::Ident,
/// The variant's discriminant expression, i.e. the right-hand side of `=`.
value: syn::Expr,
}
/// Derives `const` conversions between a fieldless enum and the integer type named in its
/// `#[repr(...)]` attribute.
///
/// Two impls are generated:
///
/// * `impl const From<Enum> for Repr`, which casts a variant to its discriminant;
/// * `impl const From<Repr> for Enum`, which maps a discriminant back to its variant and
///   panics with `"invalid value provided"` for any other value.
///
/// The generated code uses `impl const` trait impls, so the using crate needs a toolchain
/// and feature set that accept that syntax.
///
/// # Panics
///
/// Panics during expansion (reported as a compile error at the derive site) when the input
/// is not an enum, when a variant lacks an explicit discriminant, or when the enum has no
/// usable `#[repr(...)]` attribute.
#[proc_macro_derive(ConstEnum)]
pub fn const_enum(input: NativeTokenStream) -> NativeTokenStream {
    let syn::DeriveInput {
        ident: enum_name,
        attrs,
        data,
        ..
    } = syn::parse_macro_input!(input as syn::DeriveInput);

    // Variants are inspected before the repr attribute, so that when both are invalid the
    // variant error is the one reported.
    let variants = get_enum_variants(&data);
    let repr_type = get_enum_repr_type(&attrs);
    let from_repr_body = build_from_match(&enum_name, &variants);

    // A fieldless enum casts straight to its discriminant.
    let enum_to_repr = quote! {
        impl const core::convert::From<#enum_name> for #repr_type {
            fn from(value: #enum_name) -> Self {
                value as Self
            }
        }
    };
    let repr_to_enum = quote! {
        impl const core::convert::From<#repr_type> for #enum_name {
            fn from(value: #repr_type) -> Self {
                #from_repr_body
            }
        }
    };

    quote!(#enum_to_repr #repr_to_enum).into()
}
/// Returns the primitive integer type named by the enum's outer `#[repr(...)]` attribute(s).
///
/// Every outer `repr` attribute is inspected. Non-integer representation hints such as `C`,
/// `packed` or `align(N)` are skipped, so `#[repr(u8)]`, `#[repr(C, u8)]` and
/// `#[repr(C)] #[repr(u8)]` all resolve to `u8`.
///
/// # Panics
///
/// Panics when:
///
/// * no outer `repr` attribute is present;
/// * a `repr` attribute is not of the form `repr(...)`;
/// * no integer type is named;
/// * more than one integer type is named.
fn get_enum_repr_type(attrs: &[syn::Attribute]) -> syn::Ident {
    // The integer types Rust accepts as an enum representation.
    const INTEGER_REPRS: [&str; 12] = [
        "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize",
    ];

    let repr = syn::Ident::new("repr", Span::call_site());
    let repr_attrs: Vec<&syn::Attribute> = attrs
        .iter()
        .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer) && attr.path.is_ident(&repr))
        .collect();
    if repr_attrs.is_empty() {
        panic!("repr attribute not found on enum");
    }

    let mut repr_type: Option<syn::Ident> = None;
    for attr in repr_attrs {
        // The attribute body must be exactly one parenthesised group: `repr(...)`.
        let mut tokens = attr.tokens.clone().into_iter();
        let group = match (tokens.next(), tokens.next()) {
            (Some(TokenTree::Group(group)), None) => group,
            (Some(_), None) => panic!("malformed repr attribute, unexpected token"),
            _ => panic!("malformed repr attribute, expected repr(TYPE)"),
        };
        if group.delimiter() != Delimiter::Parenthesis {
            panic!("malformed repr attribute, expected repr(TYPE)");
        }

        // Only top-level identifiers can name an integer type. Nested groups, such as the
        // `(8)` in `align(8)`, are deliberately not descended into.
        for item in group.stream() {
            if let TokenTree::Ident(ident) = item {
                let name = ident.to_string();
                if INTEGER_REPRS.contains(&name.as_str()) {
                    if repr_type.is_some() {
                        panic!("malformed repr attribute, expected single type");
                    }
                    repr_type = Some(ident);
                }
            }
        }
    }

    repr_type.unwrap_or_else(|| {
        panic!("malformed repr attribute, expected an integer type such as repr(u8)")
    })
}
/// Collects the name and explicit discriminant expression of every variant of the derived
/// enum, in declaration order.
///
/// # Panics
///
/// Panics when:
///
/// * the input is a struct or union;
/// * a variant carries fields (such a variant cannot be cast to or rebuilt from an integer);
/// * a variant has no explicit `= VALUE` discriminant.
fn get_enum_variants(data: &syn::Data) -> Vec<EnumVariant> {
    let data = match data {
        syn::Data::Enum(data) => data,
        syn::Data::Struct(_) => panic!("unexpected struct, const-enum only supports enums"),
        syn::Data::Union(_) => panic!("unexpected union, const-enum only supports enums"),
    };

    data.variants
        .iter()
        .map(|variant| {
            let name = variant.ident.clone();
            if !matches!(variant.fields, syn::Fields::Unit) {
                panic!(
                    "variant `{}` has fields, const-enum only supports fieldless variants",
                    name
                );
            }
            let (_, value) = variant.discriminant.as_ref().unwrap_or_else(|| {
                panic!(
                    "variant `{}` has no explicit discriminant, const-enum requires `{} = VALUE`",
                    name, name
                )
            });
            EnumVariant {
                name,
                value: value.clone(),
            }
        })
        .collect()
}
fn build_from_match(enum_name: &syn::Ident, variants: &Vec<EnumVariant>) -> TokenStream {
let mut match_arms = TokenStream::new();
variants.iter().for_each(|variant| {
let (name, value) = (&variant.name, &variant.value);
match_arms.extend(quote! {
#value => #enum_name::#name,
});
});
match_arms.extend(quote! {
_ => panic!("invalid value provided"),
});
return quote! {
match value {
#match_arms
}
};
}