#![recursion_limit = "128"]
extern crate proc_macro;
#[macro_use]
extern crate quote;
#[macro_use]
extern crate syn;
use std::iter;
use proc_macro::TokenStream;
use quote::ToTokens;
use quote::Tokens;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
Data, DataEnum, DataStruct, DeriveInput, Field, Fields, GenericParam,
Generics, Ident, Variant,
};
/// One arm of the generated `from_discr` / `to_discr` matches, describing a
/// single enum variant.
#[derive(Debug)]
struct MatchArm {
    /// The variant's identifier.
    name: Ident,
    /// For a field-less arm: its discriminant. For a data-carrying arm: the
    /// number of field-less variants, i.e. the first discriminant that
    /// data-carrying variants may use.
    offset: usize,
    /// Number of fields of the variant; `0` marks a field-less arm.
    data_fields: usize,
    /// Field accessors: identifiers for named fields, tuple indices otherwise.
    field_names: Vec<Tokens>,
    /// Field types, in the same order as `field_names`.
    type_names: Vec<Tokens>,
    /// Expression for the number of discriminants taken by the preceding
    /// data-carrying variants (`0usize` for field-less arms).
    range_high: Tokens,
}
impl MatchArm {
fn new(
name: Ident,
offset: usize,
data_fields: usize,
field_names: Vec<Tokens>,
type_names: Vec<Tokens>,
range_high: Tokens,
) -> Self {
Self {
name,
offset,
data_fields,
field_names,
type_names,
range_high,
}
}
fn unit(name: Ident, offset: usize) -> Self {
Self::new(name, offset, 0, vec![], vec![], quote!(0usize))
}
fn from_fields(
name: Ident,
fu_named: &Punctuated<Field, Comma>,
offset: usize,
range_high: &mut Tokens,
) -> Self {
let type_names = fu_named
.iter()
.map(|u| u.ty.clone().into_tokens())
.collect();
let named = fu_named[0].ident.is_some();
let field_names = if named {
fu_named
.iter()
.map(|u| u.ident.unwrap().into_tokens())
.collect()
} else {
let mut field_names = vec![];
let elements = fu_named.len();
for x in 0..elements {
let i = syn::Index::from(x);
field_names.push(quote!(#i));
}
field_names
};
let m = Self::new(
name,
offset,
field_names.len(),
field_names,
type_names,
range_high.clone(),
);
*range_high = m.print_next_range_high();
m
}
fn print_to_discr(&self, parent: &Ident) -> Tokens {
let offset = self.offset;
let name = &self.name;
if self.data_fields == 0 {
quote!(
#parent::#name => #offset,
)
} else {
let field_names = &self.field_names;
let type_names = &self.type_names;
let range_high = &self.range_high;
let offset = "e!(#offset + #range_high);
let xfield_names: Vec<_> = field_names
.iter()
.map(|f| Ident::from(format!("x{}", f)))
.collect();
let to_discr_0 = to_discr_body(
field_names,
type_names,
offset,
Some(&xfield_names),
);
quote!(
#parent::#name {
#(
#field_names: #xfield_names,
)*
} => #to_discr_0,
)
}
}
fn print_from_discr(&self, parent: &Ident) -> Tokens {
let offset = self.offset;
let name = &self.name;
if self.data_fields == 0 {
quote!(
#offset => #parent::#name,
)
} else {
let range_high = &self.range_high;
let field_names = &self.field_names;
let type_names = &self.type_names;
let range_high_next = self.print_next_range_high();
let value = "e!(x - #offset - (#range_high));
let from_discr_0 = from_discr_body(field_names, type_names, value);
quote!(
x if x >= (#offset + #range_high)
&& x < (#offset + #range_high_next) =>
#parent::#name {
#from_discr_0
},
)
}
}
fn print_num_variants(&self) -> Tokens {
if self.data_fields == 0 {
quote!(1)
} else {
total_num_variants_product(&self.type_names)
}
}
fn print_next_range_high(&self) -> Tokens {
if self.data_fields == 0 {
panic!("You don't need to call this");
} else {
let range_high = &self.range_high;
let n = self.print_num_variants();
quote! {
#range_high + #n
}
}
}
}
fn all_variants_of(variants: &Punctuated<Variant, Comma>) -> Vec<MatchArm> {
let mut x = vec![];
for v in variants {
match v.fields {
Fields::Unit => {
let name = v.ident;
let idx = x.len() as usize;
x.push(MatchArm::unit(name, idx));
}
_ => {}
}
}
let offset = x.len() as usize;
let mut range_high = quote!(0usize);
for v in variants {
match v.fields {
Fields::Unnamed(ref fu) => {
assert!(fu.unnamed.len() > 0, "This is a unit field, wtf");
let m = MatchArm::from_fields(
v.ident,
&fu.unnamed,
offset,
&mut range_high,
);
x.push(m);
}
Fields::Named(ref fu) => {
assert!(fu.named.len() > 0, "This is a named unit field, wtf");
let m = MatchArm::from_fields(
v.ident,
&fu.named,
offset,
&mut range_high,
);
x.push(m);
}
Fields::Unit => {} }
}
x
}
fn generate_rusty_enum_code(
name: &Ident,
generics: Generics,
variants: &Punctuated<Variant, Comma>,
) -> Tokens {
let enum_variants = all_variants_of(variants);
let match_arm_from = enum_variants.iter().map(|e| e.print_from_discr(name));
let match_arm_to = enum_variants.iter().map(|e| e.print_to_discr(name));
let enum_count_s = enum_variants.iter().map(|e| e.print_num_variants());
let enum_count = quote!(
#(
#enum_count_s +
)* 0usize
);
impl_enum_like(
name,
generics,
quote! {
const NUM_VARIANTS: usize = #enum_count;
fn from_discr(value: usize) -> Self {
match value {
#(
#match_arm_from
)*
_ => unreachable!()
}
}
fn to_discr(self) -> usize {
match self {
#(
#match_arm_to
)*
}
}
},
)
}
/// Implements `EnumLike` for an enum whose variants are all `Fields::Unit`.
///
/// The i-th declared variant gets discriminant `i`, so `NUM_VARIANTS` is just
/// the variant count. Explicit `= N` discriminants are not consulted.
fn generate_c_enum_code(
    name: &Ident,
    generics: Generics,
    variants: &Punctuated<Variant, Comma>,
) -> Tokens {
    let enum_count = variants.len();
    // Each `#(...)*` repetition below walks its own iterators, because an
    // iterator is used up once it has been walked.
    let from_idents = variants.iter().map(|v| &v.ident);
    let from_parents = iter::repeat(name);
    let from_discrs = 0..enum_count;
    let to_idents = variants.iter().map(|v| &v.ident);
    let to_parents = iter::repeat(name);
    let to_discrs = 0..enum_count;
    let body = quote! {
        const NUM_VARIANTS: usize = #enum_count;
        fn from_discr(value: usize) -> Self {
            match value {
                #(
                    #from_discrs => #from_parents::#from_idents,
                )*
                _ => unreachable!()
            }
        }
        fn to_discr(self) -> usize {
            match self {
                #(
                    #to_parents::#to_idents => #to_discrs,
                )*
            }
        }
    };
    impl_enum_like(name, generics, body)
}
/// Picks the enum code generator.
///
/// Enums with any variant that is not `Fields::Unit` use the general
/// ("rusty") generator. Otherwise the simpler C-like generator is used; that
/// includes the empty enum.
fn generate_enum_code(
    name: &Ident,
    generics: Generics,
    variants: &Punctuated<Variant, Comma>,
) -> Tokens {
    let has_data = variants.iter().any(|v| v.fields != Fields::Unit);
    if has_data {
        generate_rusty_enum_code(name, generics, variants)
    } else {
        generate_c_enum_code(name, generics, variants)
    }
}
/// Implements `EnumLike` for a struct with no fields: exactly one value,
/// with discriminant `0`.
///
/// `unit` selects the constructor spelling. `true` gives `Name` (for
/// `struct Name;`); `false` gives `Name {}` (for `struct Name {}` or
/// `struct Name();`).
fn generate_unit_struct_impl(
    name: &Ident,
    generics: Generics,
    unit: bool,
) -> Tokens {
    let ctor = if unit {
        quote!(#name)
    } else {
        quote!(#name {})
    };
    let body = quote! {
        const NUM_VARIANTS: usize = 1usize;
        fn from_discr(_value: usize) -> Self {
            #ctor
        }
        fn to_discr(self) -> usize {
            0usize
        }
    };
    impl_enum_like(name, generics, body)
}
fn generate_struct_many_elem(
name: &Ident,
generics: Generics,
field_names: &[Tokens],
type_names: &[Tokens],
) -> Tokens {
let value = "e!(value);
let from_discr_0 = from_discr_body(field_names, type_names, value);
let offset = "e!(0usize);
let to_discr_0 = to_discr_body(field_names, type_names, offset, None);
let total_num_variants = total_num_variants_product(type_names);
impl_enum_like(
name,
generics,
quote! {
const NUM_VARIANTS: usize = #total_num_variants;
fn from_discr(value: usize) -> Self {
Self {
#from_discr_0
}
}
fn to_discr(self) -> usize {
#to_discr_0
}
},
)
}
fn total_num_variants_product(type_names: &[Tokens]) -> Tokens {
let mut type_names_plus_one = type_names.to_vec();
type_names_plus_one.push(quote!{});
let product = num_variants_product(&type_names_plus_one);
product.last().unwrap().clone()
}
/// Prefix products of the types' `NUM_VARIANTS`: element `i` is the product
/// of the counts of types `0..i` (`1usize` for `i == 0`).
///
/// The last type's count never appears. The result has
/// `max(type_names.len(), 1)` elements; for a slice of length `n >= 1` the
/// elements are the positional weights of the fields.
fn num_variants_product(type_names: &[Tokens]) -> Vec<Tokens> {
    let mut prefixes = vec![quote!(1usize)];
    let leading = &type_names[..type_names.len().saturating_sub(1)];
    for (k, ty) in leading.iter().enumerate() {
        let factor = quote!( <#ty as ::enum_like::EnumLike>::NUM_VARIANTS );
        let next = if k == 0 {
            factor
        } else {
            let prev = &prefixes[k];
            quote! {
                #prev * #factor
            }
        };
        prefixes.push(next);
    }
    prefixes
}
/// Generates `field: <T as EnumLike>::from_discr(...)` initialisers that
/// decode `value` as a mixed-radix number.
///
/// Field `i` receives `value.wrapping_div(P_i).wrapping_rem(N_i)`, where
/// `P_i` is the product of the preceding fields' counts and `N_i` is its own
/// type's `NUM_VARIANTS`. The last field omits the remainder. Expects at
/// least one field.
fn from_discr_body(
    field_names: &[Tokens],
    type_names: &[Tokens],
    value: &Tokens,
) -> Tokens {
    let n = type_names.len();
    let divisors = num_variants_product(type_names);
    let mut moduli: Vec<Tokens> = type_names
        .iter()
        .take(n - 1)
        .map(|ty| {
            quote! {
                .wrapping_rem(<#ty as ::enum_like::EnumLike>::NUM_VARIANTS)
            }
        })
        .collect();
    // The most significant digit needs no reduction.
    moduli.push(quote!{});
    debug_assert_eq!(field_names.len(), n);
    debug_assert_eq!(type_names.len(), n);
    debug_assert_eq!(divisors.len(), n);
    debug_assert_eq!(moduli.len(), n);
    let values = std::iter::repeat(value);
    quote! {
        #(
            #field_names: <#type_names as ::enum_like::EnumLike>::from_discr(
                ( #values ).wrapping_div( #divisors ) #moduli
            ),
        )*
    }
}
/// Generates `(offset) + P_0 * to_discr(f_0) + P_1 * to_discr(f_1) + ...`.
///
/// `P_i` is the product of the preceding fields' `NUM_VARIANTS`. Field
/// values come from the pattern bindings in `xfield_names` when given,
/// otherwise from `self.<field>`.
fn to_discr_body(
    field_names: &[Tokens],
    type_names: &[Tokens],
    offset: &Tokens,
    xfield_names: Option<&[Ident]>,
) -> Tokens {
    let weights = num_variants_product(type_names);
    let operands: Vec<Tokens> = match xfield_names {
        Some(bindings) => bindings.iter().map(|b| quote!(#b)).collect(),
        None => field_names.iter().map(|f| quote!(self.#f)).collect(),
    };
    quote! {
        ( #offset )
        #(
            + #weights *
            <#type_names as ::enum_like::EnumLike>::to_discr(#operands)
        )*
    }
}
/// Implements `EnumLike` for a struct with a braced or parenthesised field
/// list.
///
/// An empty list is treated as a single-valued struct built with `Name {}`.
/// Otherwise named fields are accessed by identifier and positional fields
/// by tuple index.
fn generate_struct_with_fields(
    name: &Ident,
    generics: Generics,
    fields: &Punctuated<Field, Comma>,
) -> Tokens {
    if fields.len() == 0 {
        return generate_unit_struct_impl(name, generics, false);
    }
    let type_names: Vec<Tokens> = fields
        .iter()
        .map(|f| f.ty.clone().into_tokens())
        .collect();
    let field_names: Vec<Tokens> = fields
        .iter()
        .enumerate()
        .map(|(pos, f)| match f.ident {
            Some(id) => id.into_tokens(),
            None => {
                let idx = syn::Index::from(pos);
                quote!(#idx)
            }
        })
        .collect();
    generate_struct_many_elem(name, generics, &field_names, &type_names)
}
/// Dispatches struct code generation on the struct's field style:
/// `struct S;` versus braced or tuple field lists.
fn generate_struct_code(
    name: &Ident,
    generics: Generics,
    fields: &Fields,
) -> Tokens {
    let list = match *fields {
        Fields::Unit => return generate_unit_struct_impl(name, generics, true),
        Fields::Named(ref named) => &named.named,
        Fields::Unnamed(ref unnamed) => &unnamed.unnamed,
    };
    generate_struct_with_fields(name, generics, list)
}
/// Adds a `::enum_like::EnumLike` bound to every type parameter.
///
/// Lifetime and const parameters are left untouched.
fn add_trait_bounds(mut generics: Generics) -> Generics {
    for param in &mut generics.params {
        match *param {
            GenericParam::Type(ref mut ty_param) => {
                ty_param.bounds.push(parse_quote!(::enum_like::EnumLike))
            }
            _ => {}
        }
    }
    generics
}
/// Wraps `body` in `unsafe impl ::enum_like::EnumLike for Name<...>`.
///
/// Each type parameter is first given an `EnumLike` bound.
fn impl_enum_like(name: &Ident, generics: Generics, body: Tokens) -> Tokens {
    let bounded = add_trait_bounds(generics);
    let (impl_params, type_args, where_bounds) = bounded.split_for_impl();
    quote! {
        unsafe impl #impl_params ::enum_like::EnumLike for #name #type_args
            #where_bounds {
            #body
        }
    }
}
#[proc_macro_derive(EnumLike)]
pub fn derive_enum_like(input: TokenStream) -> TokenStream {
let input: DeriveInput = syn::parse(input).unwrap();
match input.data {
Data::Enum(DataEnum { ref variants, .. }) => {
generate_enum_code(&input.ident, input.generics, variants)
}
Data::Struct(DataStruct { ref fields, .. }) => {
generate_struct_code(&input.ident, input.generics, fields)
}
Data::Union(..) => {
panic!("#[derive(EnumLike)] is only defined for enums and structs")
}
}.into()
}