use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
Data, DataEnum, DeriveInput, Token, Type,
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
};
/// Parsed arguments of `#[rawenum(...)]`: the listed types in source order.
///
/// The list is comma-separated, may end with a trailing comma, and may be empty
/// (emptiness is diagnosed by the macro itself, not by the parser).
struct AllowedTypes {
    types: Vec<Type>,
}

impl Parse for AllowedTypes {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Consumes `Type (, Type)* ,?` up to the end of the stream; a missing
        // separator or a stray comma surfaces as the underlying parse error.
        let list = syn::punctuated::Punctuated::<Type, Token![,]>::parse_terminated(input)?;
        Ok(Self {
            types: list.into_iter().collect(),
        })
    }
}
#[proc_macro_attribute]
pub fn rawenum(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let name = &input.ident;
let allowed_types = parse_macro_input!(attr as AllowedTypes);
let specified_types = allowed_types.types;
if specified_types.is_empty() {
return syn::Error::new_spanned(
input,
"at least one integer type must be specified, e.g., #[rawenum(i32)]",
)
.to_compile_error()
.into();
}
let Data::Enum(DataEnum { variants, .. }) = &input.data else {
return syn::Error::new_spanned(input, "rawenum can only be applied to enums")
.to_compile_error()
.into();
};
let mut all_generated_methods = Vec::new();
const SUPPORTED_TYPES: &[&str] = &["i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64"];
for specified_type in specified_types {
let (type_ident, type_span) = match &specified_type {
Type::Path(type_path) => {
if let Some(segment) = type_path.path.segments.last() {
(segment.ident.clone(), segment.span())
} else {
return syn::Error::new_spanned(specified_type, "invalid type specified")
.to_compile_error()
.into();
}
}
_ => {
return syn::Error::new_spanned(
specified_type,
"expected an integer type identifier (e.g., i32)",
)
.to_compile_error()
.into();
}
};
let type_str = type_ident.to_string();
if !SUPPORTED_TYPES.contains(&type_str.as_str()) {
return syn::Error::new_spanned(
specified_type,
format!(
"unsupported integer type '{}'. Supported types are {}.",
type_str,
SUPPORTED_TYPES.join(", ")
),
)
.to_compile_error()
.into();
}
let fn_name = format_ident!("from_{}", type_str, span = type_span);
let mut local_generated_consts = Vec::new();
let mut local_match_arms = Vec::new();
for variant in variants {
let variant_name = &variant.ident; let variant_span = variant_name.span();
let const_name = format_ident!(
"__RAWENUM_{}_DISCRIMINANT_{}_{}",
name.to_string().to_uppercase(),
variant_name.to_string().to_uppercase(),
type_str.to_uppercase(),
span = variant_span
);
local_generated_consts.push(quote! {
const #const_name: #specified_type = #name::#variant_name as #specified_type;
});
local_match_arms.push(quote! {
#const_name => Some(Self::#variant_name),
});
}
local_match_arms.push(quote! {
_ => None,
});
let method_code = quote! {
#[allow(dead_code)] pub fn #fn_name(value: #specified_type) -> Option<Self> {
#( #local_generated_consts )*
match value {
#( #local_match_arms )*
}
}
};
all_generated_methods.push(method_code);
}
let expanded = quote! {
#input
impl #name {
#( #all_generated_methods )* }
};
expanded.into()
}