use darling::{FromDeriveInput, FromField, FromVariant};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::{format_ident, quote};
use std::collections::HashMap;
use syn::{parse_macro_input, DeriveInput};
use tui_dispatch_shared::{infer_action_category, pascal_to_snake_case};
/// Container-level options for `#[derive(Action)]`, read from `#[action(...)]` on the enum.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(action), supports(enum_any))]
struct ActionOpts {
/// Identifier of the enum the derive is applied to.
ident: syn::Ident,
/// The enum's variants (the derive is restricted to enums via `supports(enum_any)`).
data: darling::ast::Data<ActionVariant, ()>,
/// When set, `derive_action` additionally emits the `<Name>Category` enum, the
/// `category()` / `category_enum()` accessors, the `is_*` predicates and the
/// `ActionCategory` impl.
#[darling(default)]
infer_categories: bool,
/// When set, `derive_action` additionally emits the `<Name>Dispatcher` trait.
/// The flag is only consulted inside the `infer_categories` branch, so it has no
/// effect unless `infer_categories` is set as well.
#[darling(default)]
generate_dispatcher: bool,
}
/// Variant-level options for `#[derive(Action)]`, read from `#[action(...)]` on a variant.
#[derive(Debug, FromVariant)]
#[darling(attributes(action))]
struct ActionVariant {
/// Identifier of the variant.
ident: syn::Ident,
/// Shape of the variant (unit, tuple or struct); only the style is inspected.
fields: darling::ast::Fields<()>,
/// Explicit category for this variant; when present it replaces the category
/// that would otherwise be inferred from the variant name.
#[darling(default)]
category: Option<String>,
/// When set, the variant gets no category at all. This is checked before
/// `category`, so it wins if both are given.
#[darling(default)]
skip_category: bool,
}
/// Converts a `PascalCase` name to `snake_case`.
///
/// Thin wrapper around `tui_dispatch_shared::pascal_to_snake_case`; the unit tests in
/// this file expect acronyms to be kept together (`"APIFetch"` -> `"api_fetch"`).
fn to_snake_case(s: &str) -> String {
pascal_to_snake_case(s)
}
/// Converts a `snake_case` string to `PascalCase`.
///
/// The input is split on `_`; the first character of every segment is upper-cased and
/// the rest of the segment is copied unchanged. Empty segments (leading, trailing or
/// doubled underscores) contribute nothing to the result.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            // `to_uppercase` may yield more than one char (e.g. `ß` -> `SS`).
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// Infers an action category from a variant name, or `None` if there is none.
///
/// Thin wrapper around `tui_dispatch_shared::infer_action_category`; the unit tests in
/// this file expect e.g. `"APIFetchStart"` to map to `Some("api")`.
fn infer_category(name: &str) -> Option<String> {
infer_action_category(name)
}
/// Emits the call `tui_dispatch::debug::<debug_fn>(&<binding>)` for one bound field.
fn render_field(debug_fn: &Ident, binding: &Ident) -> proc_macro2::TokenStream {
    quote! { tui_dispatch::debug::#debug_fn(&#binding) }
}

/// Builds one `match` arm per variant that renders the variant's fields as a `String`.
///
/// `debug_fn` is the name of the function under `tui_dispatch::debug` applied to every
/// field (`debug_string` for `params`, `debug_string_pretty` for `params_pretty`).
///
/// Rendering rules:
/// * unit variants and variants without fields produce an empty string;
/// * a single-field tuple variant produces the rendered field as-is;
/// * a multi-field tuple variant produces `(a, b, ...)`;
/// * a struct variant produces `{name: a, other: b, ...}`.
fn build_params_arms<'a>(
    name: &Ident,
    variants: impl Iterator<Item = &'a syn::Variant>,
    debug_fn: &Ident,
) -> Vec<proc_macro2::TokenStream> {
    variants
        .map(|v| {
            let variant_name = &v.ident;
            match &v.fields {
                syn::Fields::Unit => quote! {
                    #name::#variant_name => ::std::string::String::new()
                },
                syn::Fields::Unnamed(fields) => {
                    // Tuple fields are bound positionally as `_0`, `_1`, ...
                    let bindings: Vec<Ident> = (0..fields.unnamed.len())
                        .map(|i| format_ident!("_{}", i))
                        .collect();
                    match bindings.as_slice() {
                        // `Variant()` has nothing to render; an empty `vec![]` here
                        // would leave the element type un-inferable in generated code.
                        [] => quote! {
                            #name::#variant_name(..) => ::std::string::String::new()
                        },
                        [only] => {
                            let value = render_field(debug_fn, only);
                            quote! {
                                #name::#variant_name(#only) => {
                                    #value
                                }
                            }
                        }
                        _ => {
                            let parts: Vec<_> = bindings
                                .iter()
                                .map(|binding| render_field(debug_fn, binding))
                                .collect();
                            quote! {
                                #name::#variant_name(#(#bindings),*) => {
                                    let values = ::std::vec![#(#parts),*];
                                    ::std::format!("({})", values.join(", "))
                                }
                            }
                        }
                    }
                }
                syn::Fields::Named(fields) => {
                    let field_names: Vec<&Ident> = fields
                        .named
                        .iter()
                        .filter_map(|f| f.ident.as_ref())
                        .collect();
                    if field_names.is_empty() {
                        quote! {
                            #name::#variant_name { .. } => ::std::string::String::new()
                        }
                    } else {
                        let parts: Vec<_> = field_names
                            .iter()
                            .map(|field| {
                                let label = field.to_string();
                                let value = render_field(debug_fn, field);
                                quote! {
                                    ::std::format!("{}: {}", #label, #value)
                                }
                            })
                            .collect();
                        quote! {
                            #name::#variant_name { #(#field_names),*, .. } => {
                                let values = ::std::vec![#(#parts),*];
                                ::std::format!("{{{}}}", values.join(", "))
                            }
                        }
                    }
                }
            }
        })
        .collect()
}

/// Checks that `category` can be turned into the identifiers generated for it.
///
/// Every category becomes an enum variant (`to_pascal_case(category)`), an `is_<category>`
/// predicate and, optionally, a `dispatch_<category>` method. `format_ident!` panics on
/// strings that are not identifiers, so invalid names are reported here as a compile
/// error spanned at the offending variant instead.
fn validate_category(variant: &Ident, category: &str) -> syn::Result<()> {
    // Parsing alone would accept surrounding whitespace, hence the round-trip comparison.
    let is_ident = |candidate: &str| {
        syn::parse_str::<Ident>(candidate).map_or(false, |parsed| parsed == candidate)
    };
    let pascal = to_pascal_case(category);
    if !is_ident(&pascal) || !is_ident(&format!("is_{}", category)) {
        return Err(syn::Error::new_spanned(
            variant,
            format!(
                "category `{}` cannot be used to build identifiers; \
                 use a snake_case name such as `search` or `did_load`",
                category
            ),
        ));
    }
    if pascal == "Uncategorized" {
        return Err(syn::Error::new_spanned(
            variant,
            format!(
                "category `{}` collides with the built-in `Uncategorized` category; \
                 use `#[action(skip_category)]` instead",
                category
            ),
        ));
    }
    Ok(())
}

/// Derive macro for `tui_dispatch::Action`.
///
/// Always generates:
/// * `impl tui_dispatch::Action` whose `name()` returns the variant name;
/// * `impl tui_dispatch::ActionParams` whose `params()` / `params_pretty()` render the
///   variant's fields through `tui_dispatch::debug::debug_string` /
///   `debug_string_pretty`.
///
/// With `#[action(infer_categories)]` on the enum it additionally generates:
/// * a `<Name>Category` enum (one variant per category, sorted by name, plus
///   `Uncategorized`) with `all()` and `name()`;
/// * inherent `category()`, `category_enum()` and one `is_<category>()` predicate per
///   category;
/// * `impl tui_dispatch::ActionCategory`.
///
/// A variant's category is `None` when it carries `#[action(skip_category)]`, otherwise
/// the explicit `#[action(category = "...")]`, otherwise whatever `infer_category`
/// derives from the variant name.
///
/// With `#[action(infer_categories, generate_dispatcher)]` a `<Name>Dispatcher` trait is
/// generated as well. `generate_dispatcher` without `infer_categories` has no effect.
#[proc_macro_derive(Action, attributes(action))]
pub fn derive_action(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let opts = match ActionOpts::from_derive_input(&input) {
        Ok(opts) => opts,
        Err(e) => return e.write_errors().into(),
    };
    let name = &opts.ident;

    let variants = match &opts.data {
        darling::ast::Data::Enum(variants) => variants,
        _ => {
            return syn::Error::new_spanned(&input, "Action can only be derived for enums")
                .to_compile_error()
                .into();
        }
    };
    // The darling view only keeps the field style; the raw syn variants are needed to
    // bind the individual fields when rendering parameters.
    let syn_variants = match &input.data {
        syn::Data::Enum(data) => &data.variants,
        _ => unreachable!("non-enum input was rejected above"),
    };

    let name_arms = variants.iter().map(|v| {
        let variant_name = &v.ident;
        let variant_str = variant_name.to_string();
        match &v.fields.style {
            darling::ast::Style::Unit => quote! {
                #name::#variant_name => #variant_str
            },
            darling::ast::Style::Tuple => quote! {
                #name::#variant_name(..) => #variant_str
            },
            darling::ast::Style::Struct => quote! {
                #name::#variant_name { .. } => #variant_str
            },
        }
    });

    let params_arms = build_params_arms(name, syn_variants.iter(), &format_ident!("debug_string"));
    let params_pretty_arms = build_params_arms(
        name,
        syn_variants.iter(),
        &format_ident!("debug_string_pretty"),
    );

    let mut expanded = quote! {
        impl tui_dispatch::Action for #name {
            fn name(&self) -> &'static str {
                match self {
                    #(#name_arms),*
                }
            }
        }
        impl tui_dispatch::ActionParams for #name {
            fn params(&self) -> ::std::string::String {
                match self {
                    #(#params_arms),*
                }
            }
            fn params_pretty(&self) -> ::std::string::String {
                match self {
                    #(#params_pretty_arms),*
                }
            }
        }
    };

    if opts.infer_categories {
        // category name -> variants in that category
        let mut categories: HashMap<String, Vec<&Ident>> = HashMap::new();
        // every variant with its resolved category, in declaration order
        let mut variant_categories: Vec<(&Ident, Option<String>)> = Vec::new();
        for v in variants.iter() {
            let cat = if v.skip_category {
                None
            } else if let Some(explicit_cat) = &v.category {
                Some(explicit_cat.clone())
            } else {
                infer_category(&v.ident.to_string())
            };
            if let Some(category) = &cat {
                if let Err(err) = validate_category(&v.ident, category) {
                    return err.to_compile_error().into();
                }
                categories
                    .entry(category.clone())
                    .or_default()
                    .push(&v.ident);
            }
            variant_categories.push((&v.ident, cat));
        }

        // Sorted so the generated enum does not depend on HashMap iteration order.
        let mut sorted_categories: Vec<String> = categories.keys().cloned().collect();
        sorted_categories.sort();

        // `Variant { .. }` matches unit, tuple and struct variants alike.
        let category_arms: Vec<_> = variant_categories
            .iter()
            .map(|(variant, cat)| {
                let cat_expr = match cat {
                    Some(c) => quote! { ::core::option::Option::Some(#c) },
                    None => quote! { ::core::option::Option::None },
                };
                quote! { #name::#variant { .. } => #cat_expr }
            })
            .collect();

        let category_enum_name = format_ident!("{}Category", name);
        let category_variants: Vec<_> = sorted_categories
            .iter()
            .map(|c| format_ident!("{}", to_pascal_case(c)))
            .collect();

        let category_enum_arms: Vec<_> = variant_categories
            .iter()
            .map(|(variant, cat)| {
                let cat_variant = match cat {
                    Some(c) => format_ident!("{}", to_pascal_case(c)),
                    None => format_ident!("Uncategorized"),
                };
                quote! { #name::#variant { .. } => #category_enum_name::#cat_variant }
            })
            .collect();

        let predicates: Vec<_> = sorted_categories
            .iter()
            .map(|cat| {
                let predicate_name = format_ident!("is_{}", cat);
                let patterns: Vec<_> = categories[cat]
                    .iter()
                    .map(|v| quote! { #name::#v { .. } })
                    .collect();
                let doc = format!(
                    "Returns true if this action belongs to the `{}` category.",
                    cat
                );
                quote! {
                    #[doc = #doc]
                    pub fn #predicate_name(&self) -> bool {
                        ::core::matches!(self, #(#patterns)|*)
                    }
                }
            })
            .collect();

        let category_enum_doc = format!(
            "Action categories for [`{}`].\n\n\
             Use [`{}::category_enum()`] to get the category of an action.",
            name, name
        );

        expanded.extend(quote! {
            #[doc = #category_enum_doc]
            #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
            pub enum #category_enum_name {
                #(#category_variants,)*
                Uncategorized,
            }
            impl #category_enum_name {
                pub fn all() -> &'static [Self] {
                    &[#(Self::#category_variants,)* Self::Uncategorized]
                }
                pub fn name(&self) -> &'static str {
                    match self {
                        #(Self::#category_variants => #sorted_categories,)*
                        Self::Uncategorized => "uncategorized",
                    }
                }
            }
            impl #name {
                pub fn category(&self) -> ::core::option::Option<&'static str> {
                    match self {
                        #(#category_arms,)*
                    }
                }
                pub fn category_enum(&self) -> #category_enum_name {
                    match self {
                        #(#category_enum_arms,)*
                    }
                }
                #(#predicates)*
            }
            impl tui_dispatch::ActionCategory for #name {
                type Category = #category_enum_name;
                fn category(&self) -> ::core::option::Option<&'static str> {
                    #name::category(self)
                }
                fn category_enum(&self) -> Self::Category {
                    #name::category_enum(self)
                }
            }
        });

        if opts.generate_dispatcher {
            let dispatcher_trait_name = format_ident!("{}Dispatcher", name);
            let dispatch_methods: Vec<_> = sorted_categories
                .iter()
                .map(|cat| {
                    let method_name = format_ident!("dispatch_{}", cat);
                    let doc = format!("Handle actions in the `{}` category.", cat);
                    // The default body ignores `action`; keep the readable parameter
                    // name and silence the lint in the user's crate instead.
                    quote! {
                        #[doc = #doc]
                        #[allow(unused_variables)]
                        fn #method_name(&mut self, action: &#name) -> bool {
                            false
                        }
                    }
                })
                .collect();
            let dispatch_arms: Vec<_> = sorted_categories
                .iter()
                .map(|cat| {
                    let method_name = format_ident!("dispatch_{}", cat);
                    let cat_variant = format_ident!("{}", to_pascal_case(cat));
                    quote! {
                        #category_enum_name::#cat_variant => self.#method_name(action)
                    }
                })
                .collect();
            let dispatcher_doc = format!(
                "Dispatcher trait for [`{}`].\n\n\
                 Implement the `dispatch_*` methods for each category you want to handle.\n\
                 The [`dispatch()`](Self::dispatch) method automatically routes to the correct handler.",
                name
            );
            expanded.extend(quote! {
                #[doc = #dispatcher_doc]
                pub trait #dispatcher_trait_name {
                    #(#dispatch_methods)*
                    #[allow(unused_variables)]
                    fn dispatch_uncategorized(&mut self, action: &#name) -> bool {
                        false
                    }
                    fn dispatch(&mut self, action: &#name) -> bool {
                        match action.category_enum() {
                            #(#dispatch_arms,)*
                            #category_enum_name::Uncategorized => self.dispatch_uncategorized(action),
                        }
                    }
                }
            });
        }
    }

    TokenStream::from(expanded)
}
#[proc_macro_derive(BindingContext)]
pub fn derive_binding_context(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let expanded = match &input.data {
syn::Data::Enum(data) => {
for variant in &data.variants {
if !matches!(variant.fields, syn::Fields::Unit) {
return syn::Error::new_spanned(
variant,
"BindingContext can only be derived for enums with unit variants",
)
.to_compile_error()
.into();
}
}
let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
let variant_strings: Vec<_> = variant_names
.iter()
.map(|v| to_snake_case(&v.to_string()))
.collect();
let name_arms = variant_names
.iter()
.zip(variant_strings.iter())
.map(|(v, s)| {
quote! { #name::#v => #s }
});
let from_name_arms = variant_names
.iter()
.zip(variant_strings.iter())
.map(|(v, s)| {
quote! { #s => ::core::option::Option::Some(#name::#v) }
});
let all_variants = variant_names.iter().map(|v| quote! { #name::#v });
quote! {
impl tui_dispatch::BindingContext for #name {
fn name(&self) -> &'static str {
match self {
#(#name_arms),*
}
}
fn from_name(name: &str) -> ::core::option::Option<Self> {
match name {
#(#from_name_arms,)*
_ => ::core::option::Option::None,
}
}
fn all() -> &'static [Self] {
static ALL: &[#name] = &[#(#all_variants),*];
ALL
}
}
}
}
_ => {
return syn::Error::new_spanned(input, "BindingContext can only be derived for enums")
.to_compile_error()
.into();
}
};
TokenStream::from(expanded)
}
#[proc_macro_derive(ComponentId)]
pub fn derive_component_id(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let expanded = match &input.data {
syn::Data::Enum(data) => {
for variant in &data.variants {
if !matches!(variant.fields, syn::Fields::Unit) {
return syn::Error::new_spanned(
variant,
"ComponentId can only be derived for enums with unit variants",
)
.to_compile_error()
.into();
}
}
let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
let variant_strings: Vec<_> = variant_names.iter().map(|v| v.to_string()).collect();
let name_arms = variant_names
.iter()
.zip(variant_strings.iter())
.map(|(v, s)| {
quote! { #name::#v => #s }
});
quote! {
impl tui_dispatch::ComponentId for #name {
fn name(&self) -> &'static str {
match self {
#(#name_arms),*
}
}
}
}
}
_ => {
return syn::Error::new_spanned(input, "ComponentId can only be derived for enums")
.to_compile_error()
.into();
}
};
TokenStream::from(expanded)
}
/// Container-level input for `#[derive(DebugState)]`.
///
/// Restricted to structs with named fields via `supports(struct_named)`; declares
/// `debug_state` as the container attribute namespace, although no container option is
/// defined here.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(debug_state), supports(struct_named))]
struct DebugStateOpts {
/// Identifier of the struct; also used as the default section name.
ident: syn::Ident,
/// The struct's fields together with their `#[debug(...)]` options.
data: darling::ast::Data<(), DebugStateField>,
}
/// Field-level options for `#[derive(DebugState)]`, read from `#[debug(...)]`.
#[derive(Debug, FromField)]
#[darling(attributes(debug))]
struct DebugStateField {
/// Field identifier; fields without one are skipped when entries are generated.
ident: Option<syn::Ident>,
/// Section the field is listed under; defaults to the struct's name.
#[darling(default)]
section: Option<String>,
/// When set, the field is left out of the debug output entirely.
#[darling(default)]
skip: bool,
/// Format string handed to `format!` with the field as its only argument.
/// Takes precedence over `debug_fmt`.
#[darling(default)]
format: Option<String>,
/// Entry label; defaults to the field's name.
#[darling(default)]
label: Option<String>,
/// When set (and `format` is not), the value is rendered with `format!("{:?}", ..)`
/// instead of `tui_dispatch::debug::debug_string`.
#[darling(default)]
debug_fmt: bool,
}
/// Derive macro for `tui_dispatch::debug::DebugState`.
///
/// Generates `debug_sections()`, which returns one `DebugSection` per distinct section
/// name, in the order the sections first appear among the struct's fields. Each
/// non-skipped field becomes one `.entry(label, value)` call on its section.
///
/// Per-field `#[debug(...)]` options:
/// * `skip` - leave the field out;
/// * `section = "..."` - section name (defaults to the struct's name);
/// * `label = "..."` - entry label (defaults to the field's name);
/// * `format = "..."` - render with `format!(<format>, self.field)`;
/// * `debug_fmt` - render with `format!("{:?}", self.field)`;
/// * otherwise the value is rendered with `tui_dispatch::debug::debug_string`.
#[proc_macro_derive(DebugState, attributes(debug, debug_state))]
pub fn derive_debug_state(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let opts = match DebugStateOpts::from_derive_input(&input) {
        Ok(opts) => opts,
        Err(e) => return e.write_errors().into(),
    };
    let name = &opts.ident;
    let default_section = name.to_string();

    let fields = match &opts.data {
        darling::ast::Data::Struct(fields) => fields,
        _ => {
            return syn::Error::new_spanned(&input, "DebugState can only be derived for structs")
                .to_compile_error()
                .into();
        }
    };

    // Group fields by section while remembering the order of first appearance, so the
    // generated output does not depend on HashMap iteration order.
    let mut sections: HashMap<String, Vec<&DebugStateField>> = HashMap::new();
    let mut section_order: Vec<String> = Vec::new();
    for field in fields.iter().filter(|field| !field.skip) {
        let section_name = field
            .section
            .clone()
            .unwrap_or_else(|| default_section.clone());
        if !sections.contains_key(&section_name) {
            section_order.push(section_name.clone());
        }
        sections.entry(section_name).or_default().push(field);
    }

    let section_code: Vec<_> = section_order
        .iter()
        .map(|section_name| {
            let entry_calls: Vec<_> = sections[section_name]
                .iter()
                .filter_map(|field| {
                    let field_ident = field.ident.as_ref()?;
                    let label = field
                        .label
                        .clone()
                        .unwrap_or_else(|| field_ident.to_string());
                    // `format` wins over `debug_fmt`, which wins over the default.
                    let value_expr = if let Some(fmt) = &field.format {
                        quote! { ::std::format!(#fmt, self.#field_ident) }
                    } else if field.debug_fmt {
                        quote! { ::std::format!("{:?}", self.#field_ident) }
                    } else {
                        quote! { tui_dispatch::debug::debug_string(&self.#field_ident) }
                    };
                    Some(quote! {
                        .entry(#label, #value_expr)
                    })
                })
                .collect();
            quote! {
                tui_dispatch::debug::DebugSection::new(#section_name)
                    #(#entry_calls)*
            }
        })
        .collect();

    let expanded = quote! {
        impl tui_dispatch::debug::DebugState for #name {
            fn debug_sections(&self) -> ::std::vec::Vec<tui_dispatch::debug::DebugSection> {
                ::std::vec![
                    #(#section_code),*
                ]
            }
        }
    };
    TokenStream::from(expanded)
}
/// Field-level options for `#[derive(FeatureFlags)]`, read from `#[flag(...)]`.
#[derive(Debug, FromField)]
#[darling(attributes(flag))]
struct FeatureFlagsField {
/// Field identifier; doubles as the flag's name.
ident: Option<syn::Ident>,
/// Field type; only fields whose type is the plain path `bool` are treated as flags.
ty: syn::Type,
/// Value the flag takes in the generated `Default` impl; `false` when omitted.
#[darling(default)]
default: Option<bool>,
}
/// Container-level input for `#[derive(FeatureFlags)]`.
///
/// Restricted to structs with named fields via `supports(struct_named)`; declares
/// `feature_flags` as the container attribute namespace, although no container option
/// is defined here.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(feature_flags), supports(struct_named))]
struct FeatureFlagsOpts {
/// Identifier of the struct the derive is applied to.
ident: syn::Ident,
/// The struct's fields together with their `#[flag(...)]` options.
data: darling::ast::Data<(), FeatureFlagsField>,
}
#[proc_macro_derive(FeatureFlags, attributes(flag, feature_flags))]
pub fn derive_feature_flags(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let opts = match FeatureFlagsOpts::from_derive_input(&input) {
Ok(opts) => opts,
Err(e) => return e.write_errors().into(),
};
let name = &opts.ident;
let fields = match &opts.data {
darling::ast::Data::Struct(fields) => fields,
_ => {
return syn::Error::new_spanned(
&input,
"FeatureFlags can only be derived for structs with named fields",
)
.to_compile_error()
.into();
}
};
let bool_fields: Vec<_> = fields
.iter()
.filter_map(|f| {
let ident = f.ident.as_ref()?;
if let syn::Type::Path(type_path) = &f.ty {
if type_path.path.is_ident("bool") {
return Some((ident.clone(), f.default.unwrap_or(false)));
}
}
None
})
.collect();
if bool_fields.is_empty() {
return syn::Error::new_spanned(
&input,
"FeatureFlags struct must have at least one bool field",
)
.to_compile_error()
.into();
}
let is_enabled_arms: Vec<_> = bool_fields
.iter()
.map(|(ident, _)| {
let name_str = ident.to_string();
quote! { #name_str => ::core::option::Option::Some(self.#ident) }
})
.collect();
let set_arms: Vec<_> = bool_fields
.iter()
.map(|(ident, _)| {
let name_str = ident.to_string();
quote! {
#name_str => {
self.#ident = enabled;
true
}
}
})
.collect();
let flag_names: Vec<_> = bool_fields
.iter()
.map(|(ident, _)| ident.to_string())
.collect();
let default_fields: Vec<_> = bool_fields
.iter()
.map(|(ident, default)| {
quote! { #ident: #default }
})
.collect();
let expanded = quote! {
impl tui_dispatch::FeatureFlags for #name {
fn is_enabled(&self, name: &str) -> ::core::option::Option<bool> {
match name {
#(#is_enabled_arms,)*
_ => ::core::option::Option::None,
}
}
fn set(&mut self, name: &str, enabled: bool) -> bool {
match name {
#(#set_arms)*
_ => false,
}
}
fn all_flags() -> &'static [&'static str] {
&[#(#flag_names),*]
}
}
impl ::core::default::Default for #name {
fn default() -> Self {
Self {
#(#default_fields,)*
}
}
}
};
TokenStream::from(expanded)
}
// Unit tests for the local name-conversion wrappers.
#[cfg(test)]
mod tests {
use super::*;
// Runs of capitals are expected to be kept together as one word.
#[test]
fn test_to_snake_case_handles_acronyms() {
assert_eq!(to_snake_case("APIFetch"), "api_fetch");
assert_eq!(to_snake_case("HTTPResult"), "http_result");
}
// Category inference is expected to treat acronyms the same way, both at the start
// and in the middle of a variant name.
#[test]
fn test_infer_category_handles_acronyms() {
assert_eq!(infer_category("APIFetchStart"), Some("api".to_string()));
assert_eq!(
infer_category("SearchHTTPStart"),
Some("search_http".to_string())
);
}
}