#![recursion_limit = "512"]
extern crate proc_macro;
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
bracketed,
parse::{Parse, ParseBuffer},
parse2, parse_macro_input, parse_quote,
punctuated::Punctuated,
spanned::Spanned,
Error, ItemEnum, Path, Token,
};
#[derive(Debug)]
struct Args {
derives: Vec<Path>,
traits: Vec<Path>,
}
impl Parse for Args {
fn parse(input: &ParseBuffer) -> Result<Self, Error> {
let kvps = Punctuated::<_, Token![,]>::parse_terminated_with(input, |input| {
let key: Path = input.parse()?;
let key = key.segments.first().unwrap().value().ident.to_string();
input.parse::<Token![=]>()?;
let value;
bracketed!(value in input);
let value = Punctuated::<Path, Token![,]>::parse_terminated(&value)?;
let value = value.into_iter().collect();
Ok((key, value))
})?;
let mut args = Args {
derives: Vec::new(),
traits: Vec::new(),
};
for (key, value) in kvps.into_iter() {
match &key[..] {
"derive" => args.derives = value,
"trait_obj" => args.traits = value,
_ => unimplemented!("error"),
}
}
Ok(args)
}
}
#[doc(hidden)]
#[proc_macro_attribute]
pub fn tyenum(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let Args {
derives,
traits,
} = parse_macro_input!(attr as Args);
let item: TokenStream = item.into();
let mut tyenum: ItemEnum = parse2(item).unwrap();
let name = &tyenum.ident;
let mut impls = Vec::new();
let mut tyenum_trait_object_impl = Vec::new();
let mut tyenum_derive_variants = Vec::new();
for v in tyenum.variants.iter_mut() {
let ty: &dyn ToTokens;
let ident;
{
let mut iter = v.fields.iter();
ty = if let Some(f) = iter.next() {
if let Some(f) = iter.next() {
return Error::new(f.span(), "maximum one field in variants allowed").to_compile_error().into();
}
&f.ty
} else {
&v.ident
};
ident = &v.ident;
impls.push(quote! {
impl From<#ty> for #name {
fn from(variant: #ty) -> Self {
#name::#ident(variant)
}
}
impl std::convert::TryFrom<#name> for #ty {
type Error = tyenum::TryFromTyenumError;
fn try_from(e: #name) -> Result<Self, tyenum::TryFromTyenumError> {
if let #name::#ident(variant) = e {
Ok(variant)
} else {
Err(tyenum::TryFromTyenumError)
}
}
}
impl tyenum::IsTypeOf<#name> for #ty {
fn is_type_of(e: &#name) -> bool {
if let #name::#ident(_) = e {
true
} else {
false
}
}
}
});
tyenum_trait_object_impl.push(quote! {#name::#ident(ref mut v) => v.trait_obj()});
tyenum_derive_variants.push(quote! {#name::#ident(v)});
}
*v = parse_quote!(#ident(#ty));
}
let tyenum_trait_object_impl_ref = &tyenum_trait_object_impl;
let trait_name = Ident::new(&format!("{}ToTraitObject", name), Span::call_site());
for trt in traits {
impls.push(quote! {
impl<'a> #name {
fn trait_obj(&'a mut self) -> Option<&'a mut dyn #trt> {
match self {
#(#tyenum_trait_object_impl_ref),*
}
}
}
trait #trait_name<'a, T> {
fn trait_obj(&'a mut self) -> Option<T>;
}
impl<'a, I> #trait_name<'a, &'a mut dyn #trt> for I {
default fn trait_obj(&'a mut self) -> Option<&'a mut dyn #trt> {
None
}
}
impl<'a, I: #trt> #trait_name<'a, &'a mut dyn #trt> for I {
fn trait_obj(&'a mut self) -> Option<&'a mut dyn #trt> {
Some(self)
}
}
})
}
let tyenum_derive_variants_ref = &tyenum_derive_variants;
for drv in derives {
impls.push(quote! {
impl std::ops::Deref for #name {
type Target = dyn #drv;
fn deref(&self) -> &Self::Target {
match self {
#(#tyenum_derive_variants_ref => v as & Self::Target),*
}
}
}
impl std::ops::DerefMut for #name {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
#(#tyenum_derive_variants_ref => v as &mut Self::Target),*
}
}
}
});
}
quote!(
impl #name {
fn is<T: tyenum::IsTypeOf<#name>>(&self) -> bool {
T::is_type_of(self)
}
}
#tyenum
#(#impls)*
)
.into()
}