#![doc = include_str!("../README.md")]
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::parse::{Parse, Parser};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use crate::case::RenameRule;
mod case;
/// One argument inside an `#[enum_dict(...)]` attribute: an identifier that may be
/// followed by `= <expr>`, e.g. `rename = "..."`.
struct Argument {
/// The argument name (the identifier before any `=`).
ident: syn::Ident,
/// The `=` token and the value expression; `None` when no `=` follows the name.
expr: Option<(syn::Token![=], syn::Expr)>,
}
impl Parse for Argument {
    /// Parses `ident` or `ident = expr` from the stream.
    ///
    /// The value is only parsed when the next token is `=`; any parse failure is
    /// propagated to the caller.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let ident: syn::Ident = input.parse()?;

        // Bare flag form: nothing to read after the name.
        if !input.peek(syn::Token![=]) {
            return Ok(Self { ident, expr: None });
        }

        let assign = input.parse::<syn::Token![=]>()?;
        let value = input.parse::<syn::Expr>()?;
        Ok(Self {
            ident,
            expr: Some((assign, value)),
        })
    }
}
#[proc_macro_derive(DictKey, attributes(enum_dict))]
pub fn derive_dict_key(input: TokenStream) -> TokenStream {
derive_dict_key_inner(input.into()).into()
}
pub(crate) fn derive_dict_key_inner(input: TokenStream2) -> TokenStream2 {
let input: syn::DeriveInput = match syn::parse2(input) {
Ok(input) => input,
Err(err) => return err.to_compile_error(),
};
let syn::Data::Enum(data) = input.data else {
return syn::Error::new(input.span(), "DictKey can only be derived for enums").to_compile_error();
};
let mut rename_all = RenameRule::None;
let mut errors = TokenStream2::new();
for attr in input.attrs {
if !attr.path().is_ident("enum_dict") {
continue;
}
let syn::Meta::List(meta_list) = attr.meta else {
errors.extend(syn::Error::new(attr.span(), "expected #[enum_dict(...)]").to_compile_error());
continue;
};
let args = match Punctuated::<Argument, syn::Token![,]>::parse_terminated.parse2(meta_list.tokens) {
Ok(args) => args,
Err(err) => {
errors.extend(err.to_compile_error());
continue;
}
};
for arg in args {
if arg.ident == "rename_all" {
let Some((
_,
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}),
)) = arg.expr
else {
errors
.extend(syn::Error::new(arg.ident.span(), "expected rename_all = \"...\"").to_compile_error());
continue;
};
match RenameRule::from_str(&lit_str.value()) {
Ok(rule) => rename_all = rule,
Err(err) => errors.extend(syn::Error::new(lit_str.span(), err.to_string()).to_compile_error()),
};
} else {
errors.extend(
syn::Error::new(arg.ident.span(), "unknown attribute for enum_dict derive").to_compile_error(),
);
}
}
}
let mut ident_names = TokenStream2::new();
let mut from_index_arms = TokenStream2::new();
let mut as_index_arms = TokenStream2::new();
let variant_count = data.variants.len();
for (index, variant) in data.variants.into_iter().enumerate() {
let syn::Fields::Unit = &variant.fields else {
errors.extend(
syn::Error::new(variant.span(), "DictKey can only be derived for unit variants").to_compile_error(),
);
continue;
};
let ident = &variant.ident;
let mut name = rename_all.apply(&ident.to_string());
for attr in variant.attrs {
if !attr.path().is_ident("enum_dict") {
continue;
}
let syn::Meta::List(meta_list) = attr.meta else {
errors.extend(syn::Error::new(attr.span(), "expected #[enum_dict(...)]").to_compile_error());
continue;
};
let args = match Punctuated::<Argument, syn::Token![,]>::parse_terminated.parse2(meta_list.tokens) {
Ok(args) => args,
Err(err) => {
errors.extend(err.to_compile_error());
continue;
}
};
for arg in args {
if arg.ident == "rename" {
let Some((
_,
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}),
)) = arg.expr
else {
errors
.extend(syn::Error::new(arg.ident.span(), "expected rename = \"...\"").to_compile_error());
continue;
};
name = lit_str.value();
} else {
errors.extend(
syn::Error::new(arg.ident.span(), "unknown attribute for enum_dict derive").to_compile_error(),
);
}
}
}
from_index_arms.extend(quote! { #index => Self::#ident, });
as_index_arms.extend(quote! { Self::#ident => #index, });
ident_names.extend(quote! { #name, });
}
if !errors.is_empty() {
return errors;
}
let ident = &input.ident;
quote! {
#[automatically_derived]
impl ::enum_dict::DictKey for #ident {
type Array<T> = [T; #variant_count];
const VARIANTS: &'static [&'static str] = &[#ident_names];
fn from_index(index: usize) -> Self {
match index {
#from_index_arms
_ => panic!("invalid index for DictKey"),
}
}
#[inline]
fn as_index(&self) -> usize {
match self {
#as_index_arms
}
}
}
}
}
#[cfg(test)]
mod test {
    use std::env::var;
    use std::fs::{create_dir_all, read_to_string, write};
    use std::path::{Path, PathBuf};

    use macro_expand::Context;
    use pretty_assertions::StrComparison;
    use prettyplease::unparse;
    use walkdir::WalkDir;

    use super::*;

    /// A fixture whose expanded output differs from the stored expectation.
    struct TestDiff {
        /// Fixture path relative to the input directory.
        path: PathBuf,
        /// Contents of the stored expected-output file.
        expect: String,
        /// Freshly generated, pretty-printed expansion.
        actual: String,
    }

    /// Expands every `.rs` file under `fixtures/input` and compares the pretty-printed
    /// result with the file at the same relative path under `fixtures/output`.
    ///
    /// When the `EMIT` environment variable is set and non-empty, mismatching or
    /// missing outputs are (re)written and the test does not fail. A fixture with
    /// no expected file is never recorded as a diff.
    #[test]
    fn fixtures() {
        let input_dir = "fixtures/input";
        let output_dir = "fixtures/output";
        let will_emit = var("EMIT").is_ok_and(|value| !value.is_empty());
        let mut diffs = Vec::new();

        // Directory entries that fail to read are skipped.
        for entry in WalkDir::new(input_dir).into_iter().filter_map(Result::ok) {
            let input_path = entry.path();
            let is_rust_file = input_path.is_file() && input_path.extension().is_some_and(|ext| ext == "rs");
            if !is_rust_file {
                continue;
            }

            let relative = input_path.strip_prefix(input_dir).unwrap();
            let output_path = Path::new(output_dir).join(relative);

            let source = read_to_string(input_path).unwrap();
            let tokens = source.parse().unwrap();

            // A fresh expansion context per fixture, with only the derive under test.
            let mut ctx = Context::new();
            ctx.register_proc_macro_derive(
                "DictKey".to_string(),
                derive_dict_key_inner,
                vec!["enum_dict".to_string()],
            );
            let expanded = ctx.transform(tokens);
            let actual = unparse(&syn::parse2(expanded).unwrap());

            // `None` means the expected file could not be read (e.g. it does not exist).
            let expected = read_to_string(&output_path).ok();
            if expected.as_deref() == Some(actual.as_str()) {
                continue;
            }

            if will_emit {
                create_dir_all(output_path.parent().unwrap()).unwrap();
                write(output_path, &actual).unwrap();
            }

            if let Some(expect) = expected {
                diffs.push(TestDiff {
                    path: relative.to_path_buf(),
                    expect,
                    actual,
                });
            }
        }

        for diff in &diffs {
            eprintln!("diff {}", diff.path.display());
            eprintln!("{}", StrComparison::new(&diff.expect, &diff.actual));
        }

        if !diffs.is_empty() && !will_emit {
            panic!("Some tests failed");
        }
    }
}