#![doc = include_str!("../README.md")]
use quote::{
quote,
ToTokens,
};
use syn::{
Meta,
Data,
Type,
DataEnum,
Attribute,
DeriveInput,
MetaNameValue,
parse_macro_input,
};
use thiserror::Error;
use proc_macro::TokenStream;
#[derive(Error, Debug)]
enum Error {
#[error("`{0}` can only be derived for enums")]
DeriveForNonEnum(String),
#[error("Missing #[armtype = ...] attribute {0}, required for `{1}`-derived enum")]
MissingArmType(String, String),
#[error("Missing #[value = ...] attribute, expected for `{0}`-derived enum")]
MissingValue(String),
#[error("Attemping to parse non-literal attribute for `value`: not yet supported")]
NonLiteralValue,
}
#[proc_macro_derive(Const, attributes(value, armtype))]
pub fn thisenum_const(input: TokenStream) -> TokenStream {
let name = "Const";
let input = parse_macro_input!(input as DeriveInput);
let enum_name = &input.ident;
let variants = match input.data {
Data::Enum(DataEnum { variants, .. }) => variants,
_ => panic!("{}", Error::DeriveForNonEnum(name.into())),
};
let (type_name, deref) = match get_deref_type(&input.attrs) {
Some((type_name, deref)) => (type_name, deref),
None => panic!("{}", Error::MissingArmType("applied to enum".into(), name.into())),
};
let type_name_raw = match get_type(&input.attrs) {
Some(type_name_raw) => type_name_raw,
None => panic!("{}", Error::MissingArmType("applied to enum".into(), name.into())),
};
let (variant_match_arms, variant_inv_match_arms): (Vec<_>, Vec<_>) = variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;
let num_args = match variant.fields {
syn::Fields::Named(syn::FieldsNamed { ref named, .. }) => named.len(),
syn::Fields::Unnamed(syn::FieldsUnnamed { ref unnamed, .. }) => unnamed.len(),
syn::Fields::Unit => 0,
};
let value = match get_val(name.into(), &variant.attrs) {
Ok(value) => value,
Err(e) => panic!("{}", e),
};
let args_tokens = match num_args {
0 => quote! {},
_ => {
let args = (0..num_args).map(|_| quote! { _ });
quote! { ( #(#args),* ) }
},
};
let vma = match deref {
true => quote! { #enum_name::#variant_name #args_tokens => #value, },
false => quote! { #enum_name::#variant_name #args_tokens => &#value, },
};
match num_args == 0 {
true => (vma, Some(quote! { #value => Ok(#enum_name::#variant_name), })),
false => (vma, None),
}
})
.unzip();
let any_args = variant_inv_match_arms
.iter()
.any(|vima| vima.is_none());
let variant_par_eq_lhs = match deref {
true => quote! { &self.value() == other },
false => quote! { self.value() == other },
};
let variant_par_eq_rhs = match deref {
true => quote! { &other.value() == self },
false => quote! { other.value() == self },
};
let into_impl = match deref {
false => quote! {
#[automatically_derived]
#[doc = concat!(" [`Into`] implementation for [`", stringify!(#enum_name), "`]")]
impl ::std::convert::Into<#type_name_raw> for #enum_name {
#[inline]
fn into(self) -> #type_name_raw {
*self.value()
}
}
},
true => quote! { },
};
let mut expanded = quote! {
#[automatically_derived]
impl #enum_name {
#[inline]
#[doc = concat!(" * [`&'static ", stringify!(#type_name), "`]")]
pub fn value(&self) -> &'static #type_name {
match self {
#( #variant_match_arms )*
}
}
}
#[automatically_derived]
#[doc = concat!(" [`PartialEq<", stringify!(#type_name_raw) ,">`] implementation for [`", stringify!(#enum_name), "`]")]
#[doc = concat!(" This is the LHS of the [`PartialEq`] implementation between [`", stringify!(#enum_name), "`] and [`", stringify!(#type_name_raw), "`]")]
impl ::std::cmp::PartialEq<#type_name_raw> for #enum_name {
#[inline]
fn eq(&self, other: &#type_name_raw) -> bool {
#variant_par_eq_lhs
}
}
#[automatically_derived]
#[doc = concat!(" [`PartialEq<", stringify!(#enum_name) ,">`] implementation for [`", stringify!(#type_name_raw), "`]")]
#[doc = concat!(" This is the RHS of the [`PartialEq`] implementation between [`", stringify!(#enum_name), "`] and [`", stringify!(#type_name_raw), "`]")]
impl ::std::cmp::PartialEq<#enum_name> for #type_name_raw {
#[inline]
fn eq(&self, other: &#enum_name) -> bool {
#variant_par_eq_rhs
}
}
#into_impl
};
if !any_args {
let variant_inv_match_arms = variant_inv_match_arms.into_iter().map(|vima| vima.unwrap());
expanded = quote! {
#expanded
#[automatically_derived]
#[doc = concat!(" [`TryFrom`] implementation for [`", stringify!(#enum_name), "`]")]
impl ::std::convert::TryFrom<#type_name_raw> for #enum_name {
type Error = ();
#[inline]
fn try_from(value: #type_name_raw) -> Result<Self, Self::Error> {
match value {
#( #variant_inv_match_arms )*
_ => Err(()),
}
}
}
};
}
TokenStream::from(expanded)
}
#[proc_macro_derive(ConstEach, attributes(value, armtype))]
pub fn thisenum_const_each(input: TokenStream) -> TokenStream {
let name = "ConstEach";
let input = parse_macro_input!(input as DeriveInput);
let enum_name = &input.ident;
let variants = match input.data {
Data::Enum(DataEnum { variants, .. }) => variants,
_ => panic!("{}", Error::DeriveForNonEnum(name.into())),
};
let variant_code = variants.iter().map(|variant| {
let variant_name = &variant.ident;
match (get_type(&variant.attrs), get_val(name.into(), &variant.attrs)) {
(Some(typ), Ok(value)) => quote! {
#enum_name::#variant_name => {
let val: &dyn ::std::any::Any = &(#value as #typ);
val.downcast_ref::<T>()
},
},
(None, Ok(value)) => quote! {
#enum_name::#variant_name => {
let val: &dyn ::std::any::Any = &#value;
val.downcast_ref::<T>()
},
},
(_, Err(_)) => quote! { #enum_name::#variant_name => None, },
}
});
let expanded = quote! {
#[automatically_derived]
#[doc = concat!(" [`ConstEach`] implementation for [`", stringify!(#enum_name), "`]")]
impl #enum_name {
pub fn value<T: 'static>(&self) -> Option<&'static T> {
match self {
#( #variant_code )*
_ => None,
}
}
}
};
TokenStream::from(expanded)
}
fn get_val(name: String, attrs: &[Attribute]) -> Result<proc_macro2::TokenStream, Error> {
for attr in attrs {
if !attr.path.is_ident("value") { continue; }
match attr.parse_meta() {
Ok(meta) => match meta {
Meta::NameValue(MetaNameValue { lit, .. }) => return Ok(lit.into_token_stream()),
Meta::List(list) => {
let tokens = list.nested.iter().map(|nested_meta| {
match nested_meta {
syn::NestedMeta::Lit(lit) => lit.to_token_stream(),
syn::NestedMeta::Meta(meta) => meta.to_token_stream(),
}
});
return Ok(quote! { #( #tokens )* });
}
Meta::Path(_) => return Ok(meta.into_token_stream())
},
Err(_) => {
return Err(Error::NonLiteralValue);
},
}
}
Err(Error::MissingValue(name))
}
fn get_deref_type(attrs: &[Attribute]) -> Option<(Type, bool)> {
for attr in attrs {
if !attr.path.is_ident("armtype") { continue; }
let tokens = match attr.parse_args::<proc_macro2::TokenStream>() {
Ok(tokens) => tokens,
Err(_) => return None,
};
let deref = tokens
.to_string()
.trim()
.starts_with('&');
let tokens = match deref {
true => {
let mut tokens = tokens.into_iter();
let _ = tokens.next();
tokens.collect::<proc_macro2::TokenStream>()
}
false => tokens,
};
return match syn::parse2::<Type>(tokens).ok() {
Some(type_name) => Some((type_name, deref)),
None => None
}
}
None
}
fn get_type(attrs: &[Attribute]) -> Option<Type> {
for attr in attrs {
if !attr.path.is_ident("armtype") { continue; }
let tokens = match attr.parse_args::<proc_macro2::TokenStream>() {
Ok(tokens) => tokens,
Err(_) => return None,
};
return syn::parse2::<Type>(
tokens
.into_iter()
.collect::<proc_macro2::TokenStream>()
).ok()
}
None
}