use std::collections::BTreeSet;
use proc_macro::TokenStream;
use proc_macro2::{Literal, TokenStream as TokenStream2};
use quote::{ToTokens, quote};
use syn::{
Data, DeriveInput, Fields, Ident, ItemEnum, ItemStruct, Token, Type, Variant, braced,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
};
/// Collects `(variant name, payload type)` pairs from an enum annotated with `enum_conversions`.
///
/// Every variant must be a tuple variant with exactly one field, and no two variants may share
/// the same payload type. Types are compared by their token text, so `Vec<u8>` and
/// `std::vec::Vec<u8>` count as different here.
///
/// # Errors
///
/// Returns a spanned `syn::Error` when:
/// * the input is not an enum;
/// * the enum declares generic parameters (the generated impls name the enum without generic
///   arguments, so they could never compile);
/// * a variant is not a single-field tuple variant;
/// * a payload type repeats one used by an earlier variant (the error points at the repeat).
fn extract_enum_variants(input: &DeriveInput) -> syn::Result<Vec<(&syn::Ident, &syn::Type)>> {
    let Data::Enum(data_enum) = &input.data else {
        return Err(syn::Error::new_spanned(
            input,
            "EnumConversions can only be used with enums",
        ));
    };

    // Without this check a generic enum would produce errors deep inside macro output
    // (e.g. `impl From<T> for MyEnum` with an unknown `T`); report it at the generics instead.
    if !input.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &input.generics,
            "EnumConversions does not support generic enums",
        ));
    }

    // Token-string form of every payload type seen so far, used to detect duplicates
    // (duplicate payload types would yield conflicting `From` impls).
    let mut distinct_types = BTreeSet::new();

    data_enum
        .variants
        .iter()
        .map(|variant: &Variant| {
            let field = match &variant.fields {
                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0],
                _ => {
                    return Err(syn::Error::new_spanned(
                        variant,
                        "EnumConversions only works with enums that have unnamed single fields",
                    ));
                }
            };
            let field_type = &field.ty;
            if !distinct_types.insert(field_type.to_token_stream().to_string()) {
                return Err(syn::Error::new_spanned(
                    field_type,
                    "EnumConversions only works with enums that have unnamed single fields of distinct types",
                ));
            }
            Ok((&variant.ident, field_type))
        })
        .collect()
}
fn generate_enum_self_conversions(enum_name: &Ident, variants: &[(&Ident, &Type)]) -> TokenStream2 {
let mut conversions = TokenStream2::new();
for (variant_name, field_type) in variants {
let from_impl = quote! {
impl From<#field_type> for #enum_name {
fn from(value: #field_type) -> Self {
#enum_name::#variant_name(value)
}
}
};
conversions.extend(from_impl);
let try_from_impl = quote! {
impl TryFrom<#enum_name> for #field_type {
type Error = #enum_name;
fn try_from(value: #enum_name) -> Result<Self, Self::Error> {
match value {
#enum_name::#variant_name(inner) => Ok(inner),
x => Err(x),
}
}
}
};
conversions.extend(try_from_impl);
let try_from_ref_impl = quote! {
impl<'a> TryFrom<&'a #enum_name> for &'a #field_type {
type Error = &'a #enum_name;
fn try_from(value: &'a #enum_name) -> Result<Self, Self::Error> {
match value {
#enum_name::#variant_name(inner) => Ok(inner),
_ => Err(value),
}
}
}
};
conversions.extend(try_from_ref_impl);
}
conversions
}
fn generate_enum_target_conversions(
enum_name: &Ident,
target_type: &Type,
variants: &[(&Ident, &Type)],
) -> TokenStream2 {
let mut conversions = TokenStream2::new();
for (variant_name, field_type) in variants {
let from_impl = quote! {
impl From<#field_type> for #target_type {
fn from(value: #field_type) -> Self {
let enum_value = #enum_name::#variant_name(value);
Self::from(enum_value)
}
}
};
conversions.extend(from_impl);
let try_from_impl = quote! {
impl TryFrom<#target_type> for #field_type {
type Error = #target_type;
fn try_from(value: #target_type) -> Result<Self, Self::Error> {
match #enum_name::try_from(value) {
Ok(#enum_name::#variant_name(inner)) => Ok(inner),
Ok(x) => Err(#target_type::from(x)),
Err(x) => Err(x),
}
}
}
};
conversions.extend(try_from_impl);
let try_from_ref_impl = quote! {
impl<'a> TryFrom<&'a #target_type> for &'a #field_type {
type Error = &'a #target_type;
fn try_from(value: &'a #target_type) -> Result<Self, Self::Error> {
match <&'a #enum_name>::try_from(value) {
Ok(#enum_name::#variant_name(inner)) => Ok(inner),
Ok(_) => Err(value),
Err(_) => Err(value),
}
}
}
};
conversions.extend(try_from_ref_impl);
}
conversions
}
struct EnumConversionsArgs {
target_types: Punctuated<Type, Token![,]>,
}
impl Parse for EnumConversionsArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(EnumConversionsArgs {
target_types: input.parse_terminated(Type::parse, Token![,])?,
})
}
}
#[proc_macro_attribute]
pub fn enum_conversions(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as EnumConversionsArgs);
let input = parse_macro_input!(item as DeriveInput);
let enum_name = &input.ident;
let variants = match extract_enum_variants(&input) {
Ok(v) => v,
Err(e) => return e.to_compile_error().into(),
};
let mut all_conversions = TokenStream2::new();
all_conversions.extend(generate_enum_self_conversions(enum_name, &variants));
for target_type in args.target_types {
all_conversions.extend(generate_enum_target_conversions(
enum_name,
&target_type,
&variants,
));
}
let expanded = quote! {
#input
#all_conversions
};
TokenStream::from(expanded)
}
/// Parsed argument of `#[common_fields(...)]`: one brace-delimited group.
struct CommonCode {
    /// Raw tokens found between the braces. They are not validated here; the caller later
    /// parses them as named struct fields.
    content: TokenStream2,
}
impl Parse for CommonCode {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let body;
        braced!(body in input);
        Ok(Self {
            content: body.parse()?,
        })
    }
}
#[proc_macro_attribute]
pub fn common_fields(attr: TokenStream, item: TokenStream) -> TokenStream {
let common_code = parse_macro_input!(attr as CommonCode);
let common_fields_tokens = common_code.content;
let mut input_enum = parse_macro_input!(item as ItemEnum);
let temp_struct_tokens = quote! {
struct TempStruct {
#common_fields_tokens
}
};
let temp_struct: Result<ItemStruct, syn::Error> = syn::parse2(temp_struct_tokens);
if let Err(err) = temp_struct {
let error_string = err.to_string();
let error_lit = Literal::string(&error_string);
return TokenStream::from(quote! {
compile_error!(#error_lit);
});
}
let temp_struct = temp_struct.unwrap();
let common_fields = match temp_struct.fields {
Fields::Named(named) => named.named,
_ => {
let error_lit = Literal::string("Expected named fields in common code block");
return TokenStream::from(quote! {
compile_error!(#error_lit);
});
}
};
for variant in &mut input_enum.variants {
if let Fields::Named(ref mut fields) = variant.fields {
for field in common_fields.iter() {
fields.named.push(field.clone());
}
} else {
let error_lit = Literal::string("Expected named variants in enum");
return TokenStream::from(quote! {
compile_error!(#error_lit);
});
}
}
quote! {
#input_enum
}
.into()
}