#![deny(missing_debug_implementations, missing_docs)]
#![no_std]
#![recursion_limit = "128"]
extern crate alloc;
extern crate proc_macro;
use alloc::vec::Vec;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use proc_macro_roids::{namespace_parameters, FieldsExt};
use quote::quote;
use syn::{
parse_macro_input, parse_quote, Attribute, Data, DataEnum, DeriveInput, Field, Fields, LitStr,
Meta, Path,
};
/// Attribute names copied verbatim from an enum variant onto the struct generated for that
/// variant: documentation, conditional compilation, and the `allow`/`deny` lint levels.
const ATTRIBUTES_TO_COPY: &[&str] = &[
    "doc",
    "cfg",
    "allow",
    "deny",
];
#[cfg(not(tarpaulin_include))]
#[proc_macro_derive(EnumVariantType, attributes(evt))]
pub fn enum_variant_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
enum_variant_type_impl(ast).into()
}
#[inline]
fn enum_variant_type_impl(ast: DeriveInput) -> proc_macro2::TokenStream {
let enum_name = &ast.ident;
let vis = &ast.vis;
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
let data_enum = data_enum(&ast);
let variants = &data_enum.variants;
let mut wrap_in_module = None::<Ident>;
let mut derive_for_all_variants = None::<Attribute>;
let mut marker_trait_paths = Vec::<Path>::new();
let mut repr_c = false;
for attr in ast.attrs.iter() {
if attr.path().is_ident("repr") {
if let Meta::List(list) = &attr.meta {
list.parse_nested_meta(|parse_nested_meta| {
if parse_nested_meta.path.is_ident("C") {
repr_c = true;
}
Ok(())
})
.unwrap_or_else(|e| panic!("Failed to parse repr attribute. Error: {}", e));
}
} else if attr.path().is_ident("evt") {
attr.parse_nested_meta(|nested_meta| {
if nested_meta.path.is_ident("module") {
let module_name: LitStr = nested_meta
.value()
.and_then(|value| value.parse())
.unwrap_or_else(|e| {
panic!(
"Expected `evt` attribute argument in the form: \
`#[evt(module = \"some_module_name\")]`. Error: {}",
e
)
});
wrap_in_module = Some(Ident::new(&module_name.value(), Span::call_site()));
return Ok(());
}
if nested_meta.path.is_ident("derive") {
let mut items = Vec::new();
nested_meta.parse_nested_meta(|parse_nested_meta| {
items.push(parse_nested_meta.path);
Ok(())
})?;
derive_for_all_variants = Some(parse_quote! {
#[derive( #(#items),* )]
});
return Ok(());
}
if nested_meta.path.is_ident("implement_marker_traits") {
nested_meta.parse_nested_meta(|parse_nested_meta| {
marker_trait_paths.push(parse_nested_meta.path);
Ok(())
})?;
return Ok(());
}
panic!(
"Unexpected usage of `evt` attribute, please see examples at:\n\
<https://docs.rs/enum_variant_type/>"
)
})
.unwrap_or_else(|e| {
panic!("Failed to process evt attribute. Error: {}", e);
});
}
}
let mut struct_declarations = proc_macro2::TokenStream::new();
let ns: Path = parse_quote!(evt);
let skip: Path = parse_quote!(skip);
let struct_declarations_iter = variants.iter()
.filter(|variant| !proc_macro_roids::contains_tag(&variant.attrs, &ns, &skip))
.map(|variant| {
let variant_name = &variant.ident;
let attrs_to_copy = variant
.attrs
.iter()
.filter(|attribute| {
ATTRIBUTES_TO_COPY
.iter()
.any(|attr_to_copy| attribute.path().is_ident(attr_to_copy))
})
.collect::<Vec<&Attribute>>();
let evt_meta_lists = namespace_parameters(&variant.attrs, &ns);
let mut variant_struct_attrs = evt_meta_lists
.into_iter()
.fold(
proc_macro2::TokenStream::new(),
|mut attrs_tokens, variant_struct_attr| {
attrs_tokens.extend(quote!(#[#variant_struct_attr]));
attrs_tokens
},
);
if repr_c {
variant_struct_attrs.extend(quote! {
#[repr(C)]
})
}
let variant_fields = &variant.fields;
let fields_with_vis = variant_fields
.iter()
.cloned()
.map(|mut field| {
field.vis = vis.clone();
field
})
.collect::<Vec<Field>>();
let data_struct = match variant_fields {
Fields::Unit => quote! {
struct #variant_name;
},
Fields::Unnamed(..) => {
quote! {
struct #variant_name #ty_generics (#(#fields_with_vis,)*) #where_clause;
}
}
Fields::Named(..) => quote! {
struct #variant_name #ty_generics #where_clause {
#(#fields_with_vis,)*
}
},
};
let construction_form = variant_fields.construction_form();
let deconstruct_variant_struct = if variant_fields.is_unit() {
proc_macro2::TokenStream::new()
} else {
quote! {
let #variant_name #construction_form = variant_struct;
}
};
let impl_from_variant_for_enum = quote! {
impl #impl_generics core::convert::From<#variant_name #ty_generics>
for #enum_name #ty_generics
#where_clause {
fn from(variant_struct: #variant_name #ty_generics) -> Self {
#deconstruct_variant_struct
#enum_name::#variant_name #construction_form
}
}
};
let impl_try_from_enum_for_variant = quote! {
impl #impl_generics core::convert::TryFrom<#enum_name #ty_generics>
for #variant_name #ty_generics
#where_clause {
type Error = #enum_name #ty_generics;
fn try_from(enum_variant: #enum_name #ty_generics) -> Result<Self, Self::Error> {
if let #enum_name::#variant_name #construction_form = enum_variant {
core::result::Result::Ok(#variant_name #construction_form)
} else {
core::result::Result::Err(enum_variant)
}
}
}
};
quote! {
#(#attrs_to_copy)*
#derive_for_all_variants
#variant_struct_attrs
#vis #data_struct
#impl_from_variant_for_enum
#impl_try_from_enum_for_variant
#(impl #ty_generics #marker_trait_paths for #variant_name #ty_generics {})*
}
});
struct_declarations.extend(struct_declarations_iter);
if let Some(module_to_wrap_in) = wrap_in_module {
quote! {
#vis mod #module_to_wrap_in {
use super::*;
#struct_declarations
}
}
} else {
struct_declarations
}
}
fn data_enum(ast: &DeriveInput) -> &DataEnum {
if let Data::Enum(data_enum) = &ast.data {
data_enum
} else {
panic!("`EnumVariantType` derive can only be used on an enum.");
}
}
#[cfg(test)]
mod tests {
extern crate alloc;
use alloc::string::ToString;
use pretty_assertions::assert_eq;
use quote::quote;
use syn::{parse_quote, DeriveInput};
use super::enum_variant_type_impl;
#[test]
fn generates_correct_tokens_for_basic_enum() {
let ast: DeriveInput = parse_quote! {
pub enum MyEnum {
#[evt(derive(Clone, Copy, Debug, PartialEq))]
Unit,
#[evt(derive(Debug))]
Tuple(u32, u64),
Struct {
field_0: u32,
field_1: u64,
},
}
};
let actual_tokens = enum_variant_type_impl(ast);
let expected_tokens = quote! {
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Unit;
impl core::convert::From<Unit> for MyEnum {
fn from(variant_struct: Unit) -> Self {
MyEnum::Unit
}
}
impl core::convert::TryFrom<MyEnum> for Unit {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::Unit = enum_variant {
core::result::Result::Ok(Unit)
} else {
core::result::Result::Err(enum_variant)
}
}
}
#[derive(Debug)]
pub struct Tuple(pub u32, pub u64,);
impl core::convert::From<Tuple> for MyEnum {
fn from(variant_struct: Tuple) -> Self {
let Tuple(_0, _1,) = variant_struct;
MyEnum::Tuple(_0, _1,)
}
}
impl core::convert::TryFrom<MyEnum> for Tuple {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::Tuple(_0, _1,) = enum_variant {
core::result::Result::Ok(Tuple(_0, _1,))
} else {
core::result::Result::Err(enum_variant)
}
}
}
pub struct Struct {
pub field_0: u32,
pub field_1: u64,
}
impl core::convert::From<Struct> for MyEnum {
fn from(variant_struct: Struct) -> Self {
let Struct { field_0, field_1, } = variant_struct;
MyEnum::Struct { field_0, field_1, }
}
}
impl core::convert::TryFrom<MyEnum> for Struct {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::Struct { field_0, field_1, } = enum_variant {
core::result::Result::Ok(Struct { field_0, field_1, })
} else {
core::result::Result::Err(enum_variant)
}
}
}
};
assert_eq!(expected_tokens.to_string(), actual_tokens.to_string());
}
#[test]
fn skips_variants_marked_with_evt_skip() {
let ast: DeriveInput = parse_quote! {
pub enum MyEnum {
#[evt(derive(Clone, Copy, Debug, PartialEq))]
Unit,
#[evt(skip)]
UnitSkipped,
}
};
let actual_tokens = enum_variant_type_impl(ast);
let expected_tokens = quote! {
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Unit;
impl core::convert::From<Unit> for MyEnum {
fn from(variant_struct: Unit) -> Self {
MyEnum::Unit
}
}
impl core::convert::TryFrom<MyEnum> for Unit {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::Unit = enum_variant {
core::result::Result::Ok(Unit)
} else {
core::result::Result::Err(enum_variant)
}
}
}
};
assert_eq!(expected_tokens.to_string(), actual_tokens.to_string());
}
#[test]
fn put_variants_in_module() {
let ast: DeriveInput = parse_quote! {
#[evt(module = "example")]
pub enum MyEnum {
A,
B
}
};
let actual_tokens = enum_variant_type_impl(ast);
let expected_tokens = quote! {
pub mod example {
use super::*;
pub struct A;
impl core::convert::From<A> for MyEnum {
fn from(variant_struct: A) -> Self {
MyEnum::A
}
}
impl core::convert::TryFrom<MyEnum> for A {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::A = enum_variant {
core::result::Result::Ok(A)
} else {
core::result::Result::Err(enum_variant)
}
}
}
pub struct B;
impl core::convert::From<B> for MyEnum {
fn from(variant_struct: B) -> Self {
MyEnum::B
}
}
impl core::convert::TryFrom<MyEnum> for B {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::B = enum_variant {
core::result::Result::Ok(B)
} else {
core::result::Result::Err(enum_variant)
}
}
}
}
};
assert_eq!(expected_tokens.to_string(), actual_tokens.to_string());
}
#[test]
fn derive_traits_for_all_variants() {
let ast: DeriveInput = parse_quote! {
#[evt(derive(Debug))]
pub enum MyEnum {
A,
#[evt(derive(Clone))]
B
}
};
let actual_tokens = enum_variant_type_impl(ast);
let expected_tokens = quote! {
#[derive(Debug)]
pub struct A;
impl core::convert::From<A> for MyEnum {
fn from(variant_struct: A) -> Self {
MyEnum::A
}
}
impl core::convert::TryFrom<MyEnum> for A {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::A = enum_variant {
core::result::Result::Ok(A)
} else {
core::result::Result::Err(enum_variant)
}
}
}
#[derive(Debug)]
#[derive(Clone)]
pub struct B;
impl core::convert::From<B> for MyEnum {
fn from(variant_struct: B) -> Self {
MyEnum::B
}
}
impl core::convert::TryFrom<MyEnum> for B {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::B = enum_variant {
core::result::Result::Ok(B)
} else {
core::result::Result::Err(enum_variant)
}
}
}
};
assert_eq!(expected_tokens.to_string(), actual_tokens.to_string());
}
#[test]
fn derive_marker_trait() {
let ast: DeriveInput = parse_quote! {
#[evt(implement_marker_traits(MarkerTrait1))]
pub enum MyEnum {
A,
B
}
};
let actual_tokens = enum_variant_type_impl(ast);
let expected_tokens = quote! {
pub struct A;
impl core::convert::From<A> for MyEnum {
fn from(variant_struct: A) -> Self {
MyEnum::A
}
}
impl core::convert::TryFrom<MyEnum> for A {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::A = enum_variant {
core::result::Result::Ok(A)
} else {
core::result::Result::Err(enum_variant)
}
}
}
impl MarkerTrait1 for A {}
pub struct B;
impl core::convert::From<B> for MyEnum {
fn from(variant_struct: B) -> Self {
MyEnum::B
}
}
impl core::convert::TryFrom<MyEnum> for B {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::B = enum_variant {
core::result::Result::Ok(B)
} else {
core::result::Result::Err(enum_variant)
}
}
}
impl MarkerTrait1 for B {}
};
assert_eq!(expected_tokens.to_string(), actual_tokens.to_string());
}
#[test]
fn derive_marker_repr() {
let ast: DeriveInput = parse_quote! {
#[derive(Debug)]
#[repr(C)]
pub enum MyEnum {
A { i: i64 },
B { i: i64 },
}
};
let actual_tokens = enum_variant_type_impl(ast);
let expected_tokens = quote! {
#[repr(C)]
pub struct A { pub i: i64, }
impl core::convert::From<A> for MyEnum {
fn from(variant_struct: A) -> Self {
let A { i, } = variant_struct;
MyEnum::A { i, }
}
}
impl core::convert::TryFrom<MyEnum> for A {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::A { i, } = enum_variant {
core::result::Result::Ok(A { i, })
} else {
core::result::Result::Err(enum_variant)
}
}
}
#[repr(C)]
pub struct B { pub i: i64, }
impl core::convert::From<B> for MyEnum {
fn from(variant_struct: B) -> Self {
let B { i, } = variant_struct;
MyEnum::B { i, }
}
}
impl core::convert::TryFrom<MyEnum> for B {
type Error = MyEnum;
fn try_from(enum_variant: MyEnum) -> Result<Self, Self::Error> {
if let MyEnum::B { i, } = enum_variant {
core::result::Result::Ok(B { i, })
} else {
core::result::Result::Err(enum_variant)
}
}
}
};
assert_eq!(expected_tokens.to_string(), actual_tokens.to_string());
}
}