use std::{ffi::OsStr, fs, path::PathBuf};

use proc_macro::{Span, TokenStream};
use quote::ToTokens;
use syn::{
    Attribute, Expr, ExprLit, Ident, Item,
    Item::{Enum, Struct, Type, Union},
    ItemEnum, Lit, Meta, MetaNameValue, Token, parse_file, parse_macro_input,
    punctuated::Punctuated,
};
use walkdir::WalkDir;
/// Name of the `enum_dispatch` attribute that this crate strips from, and re-emits on,
/// the enums it processes.
const ENUM_DISPATCH: &str = "enum_dispatch";
/// Returns `true` if `attrs` contains an `enum_builder_variant(...)` marker whose single
/// argument is the identifier `enum_name`.
///
/// The marker is recognised by the last segment of its path. Both the imported form
/// `#[enum_builder_variant(Foo)]` and a crate-qualified form such as
/// `#[enum_builder::enum_builder_variant(Foo)]` therefore count.
///
/// Only the list form `name(...)` is considered. A marker whose argument is not exactly
/// one identifier (e.g. `crate::Foo` or `Foo,`) does not match.
fn valid_variant(enum_name: &Ident, attrs: Vec<Attribute>) -> bool {
    attrs.iter().any(|attr| {
        let Meta::List(list) = &attr.meta else {
            return false;
        };
        let is_marker = list
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "enum_builder_variant");
        // Parse rather than stringify, so token spacing can never affect the comparison.
        is_marker && list.parse_args::<Ident>().is_ok_and(|ident| ident == *enum_name)
    })
}
/// Returns `item` as tokens with every `enum_dispatch` attribute removed.
///
/// [`enum_builder`] falls back to this when it cannot generate variants: either the
/// invoking file's path is unknown, or no marked types were found.
/// NOTE(review): presumably the attribute is stripped because `enum_dispatch` cannot
/// handle the un-populated placeholder enum — confirm.
///
/// Every attribute whose path is exactly `enum_dispatch` is removed, in any form
/// (`#[enum_dispatch]`, `#[enum_dispatch(...)]`). The original removed only the last
/// such attribute.
///
/// A qualified path such as `#[enum_dispatch::enum_dispatch]` is left in place.
fn remove_enum_dispatch(mut item: ItemEnum) -> TokenStream {
    item.attrs
        .retain(|attr| !attr.path().is_ident(ENUM_DISPATCH));
    item.to_token_stream().into()
}
#[proc_macro_attribute]
pub fn enum_builder(attrs: TokenStream, item: TokenStream) -> TokenStream {
let parsed_item = parse_macro_input!(item);
let Enum(item_enum) = parsed_item else {
return parsed_item.to_token_stream().into();
};
let Some(dir) = Span::call_site().local_file() else {
return remove_enum_dispatch(item_enum);
};
let mut dir = dir.parent().unwrap().to_owned();
let mut enum_variants: Vec<String> = vec![];
let attrs =
parse_macro_input!(attrs with Punctuated::<MetaNameValue, Token![,]>::parse_terminated);
for attr in attrs {
let name = attr.path.to_token_stream().to_string();
if name != "path" {
continue;
}
dir = dir.join(
attr.value
.to_token_stream()
.to_string()
.trim_matches('"')
.to_owned(),
);
}
for entry in WalkDir::new(dir) {
let Ok(entry) = entry else { continue };
let path = entry.path();
if path.is_dir() {
continue;
}
if path.extension() != Some(OsStr::new("rs")) {
continue;
};
let src = fs::read_to_string(path)
.expect(format!("unable to read file {}", path.to_string_lossy()).as_str());
let syntax = parse_file(&src)
.expect(format!("unable to parse file {}", path.to_string_lossy()).as_str());
for item in syntax.items {
let ident;
let generics;
match item {
Struct(item) => {
if !valid_variant(&item_enum.ident, item.attrs) {
continue;
}
ident = item.ident;
generics = item.generics.to_token_stream().to_string();
}
Type(item) => {
if !valid_variant(&item_enum.ident, item.attrs) {
continue;
}
ident = item.ident;
generics = item.generics.to_token_stream().to_string();
}
Enum(item) => {
if !valid_variant(&item_enum.ident, item.attrs) {
continue;
}
ident = item.ident;
generics = item.generics.to_token_stream().to_string();
}
Union(item) => {
if !valid_variant(&item_enum.ident, item.attrs) {
continue;
}
ident = item.ident;
generics = item.generics.to_token_stream().to_string();
}
_ => continue,
}
enum_variants.push(format!("{}({}{})", ident, ident, generics));
}
}
if enum_variants.is_empty() {
return remove_enum_dispatch(item_enum);
}
format!(
"#[enum_dispatch]\nenum {}<'a> {{ {} }}",
item_enum.ident,
enum_variants.join(",\n")
)
.parse()
.unwrap()
}
/// Marker attribute that tags a type as a variant of the enum named in its argument,
/// e.g. `#[enum_builder_variant(MyEnum)]`.
///
/// The attribute itself is a no-op: the annotated item is passed through untouched, and
/// the argument tokens are ignored here. The marker only matters when `enum_builder`
/// scans source files for it.
#[proc_macro_attribute]
pub fn enum_builder_variant(_target_enum: TokenStream, item: TokenStream) -> TokenStream {
    // Intentionally identity: all the work happens when the marker is found by scanning.
    item
}