#![cfg_attr(feature = "nightly", warn(clippy::pedantic))]
#![recursion_limit = "128"]
extern crate proc_macro;
extern crate proc_macro2;
#[macro_use] extern crate quote;
extern crate syn;
use std::iter;
use proc_macro2::*;
use syn::*;
#[proc_macro_derive(EnumRepr, attributes(EnumReprType))]
pub fn enum_repr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let derive = syn::parse::<DeriveInput>(input)
.expect("#[derive(EnumRepr)] could not parse input");
let repr_ty = get_repr_type(&derive);
let vars = get_vars(&derive);
validate(&vars);
let ty = derive.ident;
let vis = derive.vis;
let (names, discrs): (Vec<_>, Vec<_>) = vars.iter()
.map(|x| (
x.ident.clone(),
x.discriminant.as_ref().unwrap().1.clone()
)).unzip();
let vars_len = vars.len();
let (names2, discrs2) = (names.clone(), discrs.clone());
let repr_ty2 = repr_ty.clone();
let repr_ty3 = repr_ty.clone();
let ty_repeat = iter::repeat(ty.clone()).take(vars_len);
let repr_ty_repeat = iter::repeat(repr_ty.clone()).take(vars_len);
let repr_ty_repeat2 = iter::repeat(repr_ty.clone()).take(vars_len);
let (impl_generics, ty_generics, where_clause) =
derive.generics.split_for_impl();
let gen = quote! {
impl #impl_generics #ty #ty_generics #where_clause {
#vis fn repr(&self) -> #repr_ty2 {
use #ty::*;
match self {
#( #names2 => #discrs2 as #repr_ty_repeat ),*
}
}
#vis fn from_repr(x: #repr_ty3) -> Option<#ty> {
match x {
#( x if x == #discrs as #repr_ty_repeat2 => Some(#ty_repeat :: #names),)*
_ => None,
}
}
}
};
gen.into()
}
/// Reads the integer type named by the enum's `#[EnumReprType = "..."]`
/// attribute and returns it as an identifier.
///
/// Only `name = "string"` attributes whose name is `EnumReprType` are
/// considered; every other attribute is ignored.
///
/// # Panics
///
/// Panics if the attribute is missing or appears more than once. It also
/// panics if the string is not usable as an identifier, since that panic
/// comes from `Ident::new`.
fn get_repr_type(derive: &DeriveInput) -> Ident {
    // The "missing" and "duplicated" cases share one diagnostic.
    fn once_violation() -> ! {
        panic!("specify #[EnumReprType = \"...\"] exactly once for an enum")
    }

    let declared = derive.attrs.iter().filter_map(|attr| {
        let (name, value) = match attr.interpret_meta() {
            Some(Meta::NameValue(MetaNameValue {
                ident,
                lit: Lit::Str(value),
                ..
            })) => (ident, value),
            _ => return None,
        };
        if name == "EnumReprType" {
            Some(value)
        } else {
            None
        }
    });

    let mut chosen = None;
    for lit in declared {
        if chosen.is_some() {
            once_violation();
        }
        chosen = Some(Ident::new(&lit.value(), Span::call_site()));
    }
    chosen.unwrap_or_else(|| once_violation())
}
/// Returns an owned copy of the derive input's enum variants.
///
/// # Panics
///
/// Panics when the derive input is not an enum.
fn get_vars(
    derive: &DeriveInput
) -> punctuated::Punctuated<Variant, token::Comma> {
    if let Data::Enum(ref data_enum) = derive.data {
        return data_enum.variants.clone();
    }
    panic!("#[derive(EnumRepr)] is only implemented for enums")
}
/// Checks that every variant is written as `Name = <discriminant>`.
///
/// This means each variant must be a unit variant (no named or tuple fields)
/// and must carry an explicit discriminant.
///
/// # Panics
///
/// Panics on the first variant that violates either rule. The compiler
/// reports the panic as an error at the derive site.
fn validate(vars: &punctuated::Punctuated<Variant, token::Comma>) {
    for var in vars {
        match var.fields {
            Fields::Named(_) | Fields::Unnamed(_) =>
                panic!("the enum's fields must \
                        be in the \"ident = number literal\" form"),
            Fields::Unit => ()
        }
        // The derive needs each discriminant expression to generate both
        // `repr` and `from_repr`. Reject a missing one here with a clear
        // message, rather than failing later while generating code.
        if var.discriminant.is_none() {
            panic!(
                "the enum's fields must be in the \"ident = number literal\" \
                 form; variant `{}` has no explicit discriminant",
                var.ident
            );
        }
    }
}