extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::parse::Error;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{parse_macro_input, Expr, Ident, ItemEnum, ItemStruct, Variant};
#[proc_macro_attribute]
pub fn const_table(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut errors = proc_macro2::TokenStream::new();
let input_item = parse_macro_input!(item as syn::Item);
let input_item = if let syn::Item::Enum(e) = input_item {
e
} else {
let span = input_item.span();
let message = "the const_table attribute may only be applied to enums";
return Error::new(span, message).to_compile_error().into();
};
if !input_item.generics.params.is_empty() {
let span = input_item.generics.params.span();
let message = "a const_table enum cannot be generic";
errors.extend(Error::new(span, message).to_compile_error());
}
let (enum_attrs, repr_type) = {
let mut attrs = Vec::with_capacity(input_item.attrs.len());
let mut repr = None;
for attr in input_item.attrs {
if attr.path.is_ident("derive") {
let mut conflict_found = false;
if let Ok(syn::Meta::List(derive_attr)) = attr.parse_meta() {
for arg in &derive_attr.nested {
if let syn::NestedMeta::Meta(syn::Meta::Path(p)) = arg {
if p.is_ident("Copy") || p.is_ident("Clone") ||
p.is_ident("Debug") || p.is_ident("Hash") ||
p.is_ident("PartialEq") || p.is_ident("Eq")
{
let span = p.span();
let message = format!("the {} trait is already implemented by the const_table macro", p.get_ident().unwrap());
errors.extend(Error::new(span, message).to_compile_error());
conflict_found = true;
}
}
}
}
if conflict_found {
continue;
}
}
if attr.path.is_ident("repr") {
let ident: Ident = attr.parse_args().unwrap();
if ident != "u8" && ident != "u16" && ident != "u32" && ident != "u64" {
let span = attr.tokens.span();
let message = "unsupported repr hint for a const_table enum: expected one of u8, u16, u32 or u64 (default is u32)";
errors.extend(Error::new(span, message).to_compile_error());
continue;
}
repr = Some(ident);
} else {
attrs.push(attr);
}
}
(attrs, repr.unwrap_or_else(|| Ident::new("u32", Span::call_site())))
};
let mut input_variants = input_item.variants.iter();
let first_variant = input_variants.next();
let (variants, value_exprs): (Punctuated<Variant, syn::token::Comma>, Vec<Expr>) = input_variants.map(|variant| {
if !variant.fields.is_empty() {
let span = variant.fields.span();
let message = "in a const_table enum, only the first variant should have fields";
errors.extend(Error::new(span, message).to_compile_error());
}
if let Some((_, expr)) = &variant.discriminant {
let v = Variant {
discriminant: None,
fields: syn::Fields::Unit,
..(*variant).clone()
};
(v, expr.clone())
} else {
let span = variant.span();
let message = "in a const_table enum, all but the first variant should have a discriminant expression";
errors.extend(Error::new(span, message).to_compile_error());
let empty_expr = Expr::Tuple(syn::ExprTuple {
attrs: Vec::new(), paren_token: syn::token::Paren { span: variant.ident.span() }, elems: Punctuated::new()
});
(variant.clone(), empty_expr)
}
}).unzip();
if variants.is_empty() {
let span = input_item.brace_token.span;
let message = "a const_table enum needs at least one variant with a discriminant expression";
errors.extend(Error::new(span, message).to_compile_error());
return errors.into();
}
let struct_decl = if let Some(v) = first_variant {
use syn::Fields::Named;
if let Named(fields) = &v.fields {
ItemStruct {
attrs: v.attrs.clone(),
vis: input_item.vis.clone(),
struct_token: syn::token::Struct {
span: Span::call_site(),
},
ident: v.ident.clone(),
generics: Default::default(),
fields: Named((*fields).clone()),
semi_token: None,
}
} else {
let span = v.span();
let message = "the first variant of a const_table enum should have named fields to specify the table layout";
errors.extend(Error::new(span, message).to_compile_error());
return errors.into();
}
} else {
let span = input_item.brace_token.span;
let message = "a const_table enum needs at least one variant with named fields to specify the table layout";
errors.extend(Error::new(span, message).to_compile_error());
return errors.into();
};
let struct_name = &struct_decl.ident;
let table_size = variants.len();
let enum_decl = ItemEnum {
attrs: enum_attrs,
variants,
..input_item
};
let enum_name = &enum_decl.ident;
let expanded = quote! {
#errors
#[repr(#repr_type)]
#[derive(core::marker::Copy, core::clone::Clone, core::fmt::Debug, core::hash::Hash, core::cmp::PartialEq, core::cmp::Eq)]
#enum_decl
#struct_decl
impl #enum_name {
pub const COUNT: usize = #table_size;
pub fn iter() -> impl core::iter::DoubleEndedIterator<Item = Self> {
(0..Self::COUNT).map(|i| unsafe { core::mem::transmute(i as #repr_type) })
}
}
impl core::ops::Deref for #enum_name {
type Target = #struct_name;
fn deref(&self) -> &Self::Target {
use #enum_name::*;
const TABLE: [#struct_name; #table_size] = [ #(#value_exprs),* ];
&TABLE[*self as usize]
}
}
impl core::convert::TryFrom<#repr_type> for #enum_name {
type Error = #repr_type;
fn try_from(i: #repr_type) -> core::result::Result<Self, #repr_type> {
if (i as usize) < Self::COUNT {
core::result::Result::Ok(unsafe { core::mem::transmute(i) })
} else {
core::result::Result::Err(i)
}
}
}
};
expanded.into()
}