use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields};
/// Per-variant metadata gathered while validating the enum passed to
/// `#[derive(EnumParam)]`.
struct VariantInfo {
    /// Identifier of the variant; emitted verbatim in the generated impl.
    ident: syn::Ident,
    /// True when the variant carries a `#[default]` attribute.
    is_default: bool,
    /// Label taken from `#[name = "..."]`, falling back to the identifier text.
    display_name: String,
}
pub fn derive_enum_param_impl(input: DeriveInput) -> syn::Result<TokenStream> {
let data_enum = match &input.data {
Data::Enum(e) => e,
Data::Struct(_) => {
return Err(syn::Error::new_spanned(
&input,
"#[derive(EnumParam)] only supports enums, not structs",
))
}
Data::Union(_) => {
return Err(syn::Error::new_spanned(
&input,
"#[derive(EnumParam)] only supports enums, not unions",
))
}
};
let mut variants = Vec::new();
for variant in &data_enum.variants {
match &variant.fields {
Fields::Unit => {}
Fields::Named(_) => {
return Err(syn::Error::new_spanned(
variant,
"#[derive(EnumParam)] only supports unit variants (no fields)",
))
}
Fields::Unnamed(_) => {
return Err(syn::Error::new_spanned(
variant,
"#[derive(EnumParam)] only supports unit variants (no tuple fields)",
))
}
}
let display_name = extract_name_attribute(&variant.attrs)?
.unwrap_or_else(|| variant.ident.to_string());
let is_default = has_default_attribute(&variant.attrs);
variants.push(VariantInfo {
ident: variant.ident.clone(),
display_name,
is_default,
});
}
if variants.is_empty() {
return Err(syn::Error::new_spanned(
&input,
"#[derive(EnumParam)] requires at least one variant",
));
}
let default_indices: Vec<usize> = variants
.iter()
.enumerate()
.filter(|(_, v)| v.is_default)
.map(|(i, _)| i)
.collect();
if default_indices.len() > 1 {
return Err(syn::Error::new_spanned(
&input,
"#[derive(EnumParam)] only one variant can be marked as #[default]",
));
}
let default_index = default_indices.first().copied().unwrap_or(0);
let enum_name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let count = variants.len();
let from_index_arms: Vec<TokenStream> = variants
.iter()
.enumerate()
.map(|(idx, v)| {
let ident = &v.ident;
quote! { #idx => Some(#enum_name::#ident), }
})
.collect();
let to_index_arms: Vec<TokenStream> = variants
.iter()
.enumerate()
.map(|(idx, v)| {
let ident = &v.ident;
quote! { #enum_name::#ident => #idx, }
})
.collect();
let name_arms: Vec<TokenStream> = variants
.iter()
.enumerate()
.map(|(idx, v)| {
let name = &v.display_name;
quote! { #idx => #name, }
})
.collect();
let names_array: Vec<&str> = variants.iter().map(|v| v.display_name.as_str()).collect();
let default_ident = &variants[default_index].ident;
Ok(quote! {
impl #impl_generics ::beamer::core::param_types::EnumParamValue for #enum_name #ty_generics #where_clause {
const COUNT: usize = #count;
const DEFAULT_INDEX: usize = #default_index;
fn from_index(index: usize) -> Option<Self> {
match index {
#(#from_index_arms)*
_ => None,
}
}
fn to_index(self) -> usize {
match self {
#(#to_index_arms)*
}
}
fn default_value() -> Self {
#enum_name::#default_ident
}
fn name(index: usize) -> &'static str {
match index {
#(#name_arms)*
_ => "",
}
}
fn names() -> &'static [&'static str] {
&[#(#names_array),*]
}
}
})
}
/// Reads a variant's display name from its `#[name = "..."]` attribute.
///
/// Returns `Ok(None)` when no `name` attribute is present, so the caller can
/// fall back to the variant identifier.
///
/// # Errors
///
/// Fails when a `name` attribute is not in `name = value` form, when its value
/// is not a string literal, or when the attribute appears more than once on the
/// same variant (rather than silently ignoring the later ones).
fn extract_name_attribute(attrs: &[syn::Attribute]) -> syn::Result<Option<String>> {
    let mut name: Option<String> = None;
    for attr in attrs.iter().filter(|attr| attr.path().is_ident("name")) {
        let name_value = attr.meta.require_name_value()?;
        let lit_str = match &name_value.value {
            syn::Expr::Lit(syn::ExprLit {
                lit: syn::Lit::Str(lit_str),
                ..
            }) => lit_str,
            other => {
                return Err(syn::Error::new_spanned(
                    other,
                    "expected string literal for #[name = \"...\"]",
                ))
            }
        };
        if name.is_some() {
            return Err(syn::Error::new_spanned(
                attr,
                "duplicate #[name = \"...\"] attribute; only one is allowed per variant",
            ));
        }
        name = Some(lit_str.value());
    }
    Ok(name)
}
/// Reports whether any attribute in `attrs` has the bare path `default`,
/// i.e. whether the variant is marked `#[default]`.
fn has_default_attribute(attrs: &[syn::Attribute]) -> bool {
    for attr in attrs {
        if attr.path().is_ident("default") {
            return true;
        }
    }
    false
}