use check_keyword::CheckKeyword;
use heck::ToSnekCase;
use proc_macro::TokenStream;
use proc_macro_error2::{emit_error, proc_macro_error};
use quote::{format_ident, quote};
use syn::{Fields, Ident, ItemEnum, parse_macro_input};
/// Everything the code generator needs to know about a single enum variant.
struct VariantInfo {
    /// The variant's identifier exactly as written in the enum (e.g. `FooBar`).
    normal: Ident,
    /// Identifier of the generated struct field for this variant: the snake_case form of
    /// `normal`, or the `#[field_name = "..."]` override, passed through `into_safe`.
    snake: Ident,
    /// The variant's payload shape (unit, tuple, or named fields), cloned from the input.
    fields: Fields,
}
#[proc_macro_error]
#[proc_macro_derive(
VariantsStruct,
attributes(struct_bounds, struct_derive, struct_name, field_name, struct_attr)
)]
pub fn variants_struct(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as ItemEnum);
let enum_ident = input.ident.clone();
let mut struct_ident = format_ident!("{}Struct", input.ident);
let visibility = input.vis.clone();
let mut bounds = quote! {};
let mut derives = vec![];
let mut attrs = vec![];
for attr in input.clone().attrs {
if attr.path().is_ident("struct_bounds") {
let syn::Meta::List(l) = attr.meta else {
emit_error!(
attr,
"struct_bounds must be of the form #[struct_bounds(Bound)]"
);
return quote! {}.into();
};
bounds = l.tokens;
} else if attr.path().is_ident("struct_derive") {
attr.parse_nested_meta(|meta| {
derives.push(meta.path);
Ok(())
})
.unwrap();
} else if attr.path().is_ident("struct_name") {
if let syn::Meta::NameValue(syn::MetaNameValue { value, .. }) = attr.meta {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = value
{
struct_ident = format_ident!("{}", lit_str.value());
} else {
emit_error!(value, "must be a str literal");
}
}
} else if attr.path().is_ident("struct_attr") {
let syn::Meta::List(l) = attr.meta else {
emit_error!(attr, "struct_attr must be of the form #[struct_attr(attr)]");
return quote! {}.into();
};
attrs.push(l.tokens);
}
}
if input.variants.is_empty() {
return (quote! {
#[derive(#(#derives),*)]
#visibility struct #struct_ident;
})
.into();
}
let vars: Vec<_> = input
.clone()
.variants
.iter()
.map(|var| {
let mut names = vec![];
for attr in &var.attrs {
if attr.path().is_ident("field_name") {
if let syn::Meta::NameValue(syn::MetaNameValue { value, .. }) = &attr.meta {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = value
{
names.push(lit_str.value());
} else {
emit_error!(value, "must be a str literal");
}
}
}
}
let snake = if names.is_empty() {
format_ident!("{}", var.ident.to_string().to_snek_case().into_safe())
} else {
format_ident!("{}", names.first().unwrap().into_safe())
};
VariantInfo {
normal: var.ident.clone(),
snake,
fields: var.fields.clone(),
}
})
.collect();
let mut field_idents = vec![];
let mut field_names = vec![];
let mut struct_fields = vec![];
let mut get_uncheckeds = vec![];
let mut get_mut_uncheckeds = vec![];
let mut gets = vec![];
let mut get_muts = vec![];
let mut new_args = vec![];
let mut new_fields = vec![];
for VariantInfo {
normal,
snake,
fields,
} in &vars
{
field_idents.push(snake.clone());
field_names.push(snake.to_string());
match fields {
Fields::Unit => {
struct_fields.push(quote! { pub #snake: T });
gets.push(quote! { &#enum_ident::#normal => Some(&self.#snake) });
get_muts.push(quote! { &#enum_ident::#normal => Some(&mut self.#snake) });
get_uncheckeds.push(quote! { &#enum_ident::#normal => &self.#snake });
get_mut_uncheckeds.push(quote! { &#enum_ident::#normal => &mut self.#snake });
new_args.push(quote! {#snake: T});
new_fields.push(quote! {#snake});
}
Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
if unnamed.len() == 1 {
let ty = unnamed.first().unwrap().clone().ty;
struct_fields.push(quote! {
pub #snake: std::collections::HashMap<#ty, T>
});
gets.push(quote! {
&#enum_ident::#normal(key) => self.#snake.get(&key)
});
get_muts.push(quote! {
&#enum_ident::#normal(key) => self.#snake.get_mut(&key)
});
get_uncheckeds.push(quote! {
&#enum_ident::#normal(key) => self.#snake.get(&key)
.expect("tuple variant key not found in hashmap")
});
get_mut_uncheckeds.push(quote! {
&#enum_ident::#normal(key) => self.#snake.get_mut(&key)
.expect("tuple variant key not found in hashmap")
});
new_fields.push(quote! {#snake: std::collections::HashMap::new()});
} else {
emit_error!(unnamed, "only tuples with one value are allowed");
}
}
Fields::Named(syn::FieldsNamed { named, .. }) => {
if named.len() == 1 {
let ty = named.first().unwrap().clone().ty;
let ident = named.first().unwrap().ident.clone().unwrap();
struct_fields.push(quote! {
pub #snake: std::collections::HashMap<#ty, T>
});
gets.push(quote! {
&#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
});
get_muts.push(quote! {
&#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
});
get_uncheckeds.push(quote! {
&#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
.expect("tuple variant key not found in hashmap")
});
get_mut_uncheckeds.push(quote! {
&#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
.expect("tuple variant key not found in hashmap")
});
new_fields.push(quote! {#snake: std::collections::HashMap::new()});
} else {
emit_error!(named, "only structs with one field are allowed");
}
}
}
}
(quote! {
#[derive(#(#derives),*)]
#(#[#attrs])*
#visibility struct #struct_ident<T: #bounds> {
#(#struct_fields),*
}
impl<T: #bounds> #struct_ident<T> {
pub fn new(#(#new_args),*) -> #struct_ident<T> {
#struct_ident {
#(#new_fields),*
}
}
pub fn get_unchecked(&self, var: &#enum_ident) -> &T {
match var {
#(#get_uncheckeds),*
}
}
pub fn get_mut_unchecked(&mut self, var: &#enum_ident) -> &mut T {
match var {
#(#get_mut_uncheckeds),*
}
}
pub fn get(&self, var: &#enum_ident) -> Option<&T> {
match var {
#(#gets),*
}
}
pub fn get_mut(&mut self, var: &#enum_ident) -> Option<&mut T> {
match var {
#(#get_muts),*
}
}
}
impl<T: #bounds> std::ops::Index<#enum_ident> for #struct_ident<T> {
type Output = T;
fn index(&self, var: #enum_ident) -> &T {
self.get_unchecked(&var)
}
}
impl<T: #bounds> std::ops::IndexMut<#enum_ident> for #struct_ident<T> {
fn index_mut(&mut self, var: #enum_ident) -> &mut T {
self.get_mut_unchecked(&var)
}
}
impl<T: #bounds> std::ops::Index<&#enum_ident> for #struct_ident<T> {
type Output = T;
fn index(&self, var: &#enum_ident) -> &T {
self.get_unchecked(var)
}
}
impl<T: #bounds> std::ops::IndexMut<&#enum_ident> for #struct_ident<T> {
fn index_mut(&mut self, var: &#enum_ident) -> &mut T {
self.get_mut_unchecked(var)
}
}
})
.into()
}