tui_dispatch_macros/
lib.rs

1//! Procedural macros for tui-dispatch
2
3use darling::{FromDeriveInput, FromField, FromVariant};
4use proc_macro::TokenStream;
5use proc_macro2::Ident;
6use quote::{format_ident, quote};
7use std::collections::HashMap;
8use syn::{parse_macro_input, DeriveInput};
9
10/// Container-level attributes for #[derive(Action)]
11#[derive(Debug, FromDeriveInput)]
12#[darling(attributes(action), supports(enum_any))]
13struct ActionOpts {
14    ident: syn::Ident,
15    data: darling::ast::Data<ActionVariant, ()>,
16
17    /// Enable automatic category inference from variant name prefixes
18    #[darling(default)]
19    infer_categories: bool,
20
21    /// Generate dispatcher trait
22    #[darling(default)]
23    generate_dispatcher: bool,
24}
25
26/// Variant-level attributes
27#[derive(Debug, FromVariant)]
28#[darling(attributes(action))]
29struct ActionVariant {
30    ident: syn::Ident,
31    fields: darling::ast::Fields<()>,
32
33    /// Explicit category override
34    #[darling(default)]
35    category: Option<String>,
36
37    /// Exclude from category inference
38    #[darling(default)]
39    skip_category: bool,
40}
41
/// Action verbs recognised by category inference (`infer_category`).
///
/// A variant name is split into PascalCase words. The words that precede the first
/// verb from this list (searched from the second word onwards) become the category,
/// e.g. `ConnectionFormSubmit` -> `connection_form`. A name that *starts* with one of
/// these verbs (e.g. `OpenConnectionForm`) is a primary action and stays
/// uncategorized.
///
/// Only verbs belong here: nouns such as `Form`, `Panel` or `Field` must not be
/// added, or they would cut category prefixes short.
// `rustfmt::skip` keeps exactly one group label per line; rustfmt otherwise
// re-flows the labels onto the end of the previous group's line.
#[rustfmt::skip]
const ACTION_VERBS: &[&str] = &[
    // State transitions
    "Start", "End", "Open", "Close", "Submit", "Confirm", "Cancel",
    // Navigation
    "Next", "Prev", "Up", "Down", "Left", "Right", "Enter", "Exit", "Escape",
    // CRUD operations
    "Add", "Remove", "Clear", "Update", "Set", "Get", "Load", "Save", "Delete", "Create",
    // Visibility
    "Show", "Hide", "Enable", "Disable", "Toggle",
    // Focus
    "Focus", "Blur", "Select",
    // Movement
    "Move", "Copy", "Cycle", "Reset", "Scroll",
];
56
/// Split a PascalCase string into its words.
///
/// A new word begins at every uppercase character except the first character of
/// the input; every other character extends the current word. Acronyms are
/// therefore split letter by letter (`"UI"` -> `["U", "I"]`). An empty input yields
/// an empty vector.
fn split_pascal_case(s: &str) -> Vec<String> {
    let mut words = Vec::new();
    let mut word_start = 0;

    for (offset, ch) in s.char_indices() {
        // `offset > word_start` means the current word already holds a character.
        if ch.is_uppercase() && offset > word_start {
            words.push(s[word_start..offset].to_string());
            word_start = offset;
        }
    }

    if word_start < s.len() {
        words.push(s[word_start..].to_string());
    }
    words
}
74
/// Convert PascalCase to snake_case.
///
/// Every uppercase character after the first one is preceded by `_` and replaced
/// by its lowercase form, so acronyms are split letter by letter
/// (`"UIMode"` -> `"u_i_mode"`). Characters whose lowercase form spans several code
/// points (e.g. `'İ'` -> `"i\u{307}"`) are expanded in full instead of being
/// truncated to their first code point.
fn to_snake_case(s: &str) -> String {
    // Room for the original text plus a handful of separators.
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
90
/// Convert snake_case to PascalCase.
///
/// Each `_`-separated segment has its first character uppercased and the rest kept
/// as-is; the segments are then concatenated. Empty segments (from leading,
/// trailing or repeated underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut letters = segment.chars();
        if let Some(head) = letters.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
103
104/// Infer category from a variant name using naming patterns
105fn infer_category(name: &str) -> Option<String> {
106    let parts = split_pascal_case(name);
107    if parts.is_empty() {
108        return None;
109    }
110
111    // Check for "Did" prefix (async results)
112    if parts[0] == "Did" {
113        return Some("async_result".to_string());
114    }
115
116    // If only one part, no category
117    if parts.len() < 2 {
118        return None;
119    }
120
121    // Find the longest prefix that ends before an action verb
122    // e.g., ["Connection", "Form", "Submit"] -> "connection_form"
123    // e.g., ["Search", "Add", "Char"] -> "search"
124    // e.g., ["Value", "Viewer", "Scroll", "Up"] -> "value_viewer"
125
126    let first_is_verb = ACTION_VERBS.contains(&parts[0].as_str());
127
128    let mut prefix_end = parts.len();
129    let mut found_verb = false;
130    for (i, part) in parts.iter().enumerate().skip(1) {
131        if ACTION_VERBS.contains(&part.as_str()) {
132            prefix_end = i;
133            found_verb = true;
134            break;
135        }
136    }
137
138    // Skip if first part is an action verb - these are primary actions, not categorized
139    // e.g., "OpenConnectionForm" → "Open" is the verb, "ConnectionForm" is the object
140    // e.g., "NextItem" → "Next" is the verb, "Item" is the object
141    if first_is_verb {
142        return None;
143    }
144
145    // Skip if no verb found in the name - can't determine meaningful category
146    if !found_verb {
147        return None;
148    }
149
150    if prefix_end == 0 {
151        return None;
152    }
153
154    let prefix_parts: Vec<&str> = parts[..prefix_end].iter().map(|s| s.as_str()).collect();
155    let prefix = prefix_parts.join("");
156
157    Some(to_snake_case(&prefix))
158}
159
160/// Derive macro for the Action trait
161///
162/// Generates a `name()` method that returns the variant name as a static string.
163///
164/// With `#[action(infer_categories)]`, also generates:
165/// - `category() -> Option<&'static str>` - Get action's category
166/// - `category_enum() -> {Name}Category` - Get category as enum
167/// - `is_{category}()` predicates for each category
168/// - `{Name}Category` enum with all discovered categories
169///
170/// With `#[action(generate_dispatcher)]`, also generates:
171/// - `{Name}Dispatcher` trait with category-based dispatch methods
172///
173/// # Example
174/// ```ignore
175/// #[derive(Action, Clone, Debug)]
176/// #[action(infer_categories, generate_dispatcher)]
177/// enum MyAction {
178///     SearchStart,
179///     SearchClear,
180///     ConnectionFormOpen,
181///     ConnectionFormSubmit,
182///     DidConnect,
183///     Tick,  // uncategorized
184/// }
185///
186/// let action = MyAction::SearchStart;
187/// assert_eq!(action.name(), "SearchStart");
188/// assert_eq!(action.category(), Some("search"));
189/// assert!(action.is_search());
190/// ```
191#[proc_macro_derive(Action, attributes(action))]
192pub fn derive_action(input: TokenStream) -> TokenStream {
193    let input = parse_macro_input!(input as DeriveInput);
194
195    // Try to parse with darling for attributes
196    let opts = match ActionOpts::from_derive_input(&input) {
197        Ok(opts) => opts,
198        Err(e) => return e.write_errors().into(),
199    };
200
201    let name = &opts.ident;
202
203    let variants = match &opts.data {
204        darling::ast::Data::Enum(variants) => variants,
205        _ => {
206            return syn::Error::new_spanned(&input, "Action can only be derived for enums")
207                .to_compile_error()
208                .into();
209        }
210    };
211
212    // Generate basic name() implementation
213    let name_arms = variants.iter().map(|v| {
214        let variant_name = &v.ident;
215        let variant_str = variant_name.to_string();
216
217        match &v.fields.style {
218            darling::ast::Style::Unit => quote! {
219                #name::#variant_name => #variant_str
220            },
221            darling::ast::Style::Tuple => quote! {
222                #name::#variant_name(..) => #variant_str
223            },
224            darling::ast::Style::Struct => quote! {
225                #name::#variant_name { .. } => #variant_str
226            },
227        }
228    });
229
230    let mut expanded = quote! {
231        impl tui_dispatch::Action for #name {
232            fn name(&self) -> &'static str {
233                match self {
234                    #(#name_arms),*
235                }
236            }
237        }
238    };
239
240    // If category inference is enabled, generate category-related code
241    if opts.infer_categories {
242        // Collect categories and their variants
243        let mut categories: HashMap<String, Vec<&Ident>> = HashMap::new();
244        let mut variant_categories: Vec<(&Ident, Option<String>)> = Vec::new();
245
246        for v in variants.iter() {
247            let cat = if v.skip_category {
248                None
249            } else if let Some(ref explicit_cat) = v.category {
250                Some(explicit_cat.clone())
251            } else {
252                infer_category(&v.ident.to_string())
253            };
254
255            variant_categories.push((&v.ident, cat.clone()));
256
257            if let Some(ref category) = cat {
258                categories
259                    .entry(category.clone())
260                    .or_default()
261                    .push(&v.ident);
262            }
263        }
264
265        // Sort categories for deterministic output
266        let mut sorted_categories: Vec<_> = categories.keys().cloned().collect();
267        sorted_categories.sort();
268
269        // Create deduplicated category match arms
270        let category_arms_dedup: Vec<_> = variant_categories
271            .iter()
272            .map(|(variant, cat)| {
273                let cat_expr = match cat {
274                    Some(c) => quote! { ::core::option::Option::Some(#c) },
275                    None => quote! { ::core::option::Option::None },
276                };
277                // Use wildcard pattern to handle all field types
278                quote! { #name::#variant { .. } => #cat_expr }
279            })
280            .collect();
281
282        // Generate category enum
283        let category_enum_name = format_ident!("{}Category", name);
284        let category_variants: Vec<_> = sorted_categories
285            .iter()
286            .map(|c| format_ident!("{}", to_pascal_case(c)))
287            .collect();
288        let category_variant_names: Vec<_> = sorted_categories.clone();
289
290        // Generate category_enum() method arms
291        let category_enum_arms: Vec<_> = variant_categories
292            .iter()
293            .map(|(variant, cat)| {
294                let cat_variant = match cat {
295                    Some(c) => format_ident!("{}", to_pascal_case(c)),
296                    None => format_ident!("Uncategorized"),
297                };
298                quote! { #name::#variant { .. } => #category_enum_name::#cat_variant }
299            })
300            .collect();
301
302        // Generate is_* predicates
303        let predicates: Vec<_> = sorted_categories
304            .iter()
305            .map(|cat| {
306                let predicate_name = format_ident!("is_{}", cat);
307                let cat_variants = categories.get(cat).unwrap();
308                let patterns: Vec<_> = cat_variants
309                    .iter()
310                    .map(|v| quote! { #name::#v { .. } })
311                    .collect();
312                let doc = format!(
313                    "Returns true if this action belongs to the `{}` category.",
314                    cat
315                );
316
317                quote! {
318                    #[doc = #doc]
319                    pub fn #predicate_name(&self) -> bool {
320                        matches!(self, #(#patterns)|*)
321                    }
322                }
323            })
324            .collect();
325
326        // Add category-related implementations
327        let category_enum_doc = format!(
328            "Action categories for [`{}`].\n\n\
329             Use [`{}::category_enum()`] to get the category of an action.",
330            name, name
331        );
332
333        expanded = quote! {
334            #expanded
335
336            #[doc = #category_enum_doc]
337            #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
338            pub enum #category_enum_name {
339                #(#category_variants,)*
340                /// Actions that don't belong to any specific category.
341                Uncategorized,
342            }
343
344            impl #category_enum_name {
345                /// Get all category values
346                pub fn all() -> &'static [Self] {
347                    &[#(Self::#category_variants,)* Self::Uncategorized]
348                }
349
350                /// Get category name as string
351                pub fn name(&self) -> &'static str {
352                    match self {
353                        #(Self::#category_variants => #category_variant_names,)*
354                        Self::Uncategorized => "uncategorized",
355                    }
356                }
357            }
358
359            impl #name {
360                /// Get the action's category (if categorized)
361                pub fn category(&self) -> ::core::option::Option<&'static str> {
362                    match self {
363                        #(#category_arms_dedup,)*
364                    }
365                }
366
367                /// Get the category as an enum value
368                pub fn category_enum(&self) -> #category_enum_name {
369                    match self {
370                        #(#category_enum_arms,)*
371                    }
372                }
373
374                #(#predicates)*
375            }
376
377            impl tui_dispatch::ActionCategory for #name {
378                type Category = #category_enum_name;
379
380                fn category(&self) -> ::core::option::Option<&'static str> {
381                    #name::category(self)
382                }
383
384                fn category_enum(&self) -> Self::Category {
385                    #name::category_enum(self)
386                }
387            }
388        };
389
390        // Generate dispatcher trait if requested
391        if opts.generate_dispatcher {
392            let dispatcher_trait_name = format_ident!("{}Dispatcher", name);
393
394            let dispatch_methods: Vec<_> = sorted_categories
395                .iter()
396                .map(|cat| {
397                    let method_name = format_ident!("dispatch_{}", cat);
398                    let doc = format!("Handle actions in the `{}` category.", cat);
399                    quote! {
400                        #[doc = #doc]
401                        fn #method_name(&mut self, action: &#name) -> bool {
402                            false
403                        }
404                    }
405                })
406                .collect();
407
408            let dispatch_arms: Vec<_> = sorted_categories
409                .iter()
410                .map(|cat| {
411                    let method_name = format_ident!("dispatch_{}", cat);
412                    let cat_variant = format_ident!("{}", to_pascal_case(cat));
413                    quote! {
414                        #category_enum_name::#cat_variant => self.#method_name(action)
415                    }
416                })
417                .collect();
418
419            let dispatcher_doc = format!(
420                "Dispatcher trait for [`{}`].\n\n\
421                 Implement the `dispatch_*` methods for each category you want to handle.\n\
422                 The [`dispatch()`](Self::dispatch) method automatically routes to the correct handler.",
423                name
424            );
425
426            expanded = quote! {
427                #expanded
428
429                #[doc = #dispatcher_doc]
430                pub trait #dispatcher_trait_name {
431                    #(#dispatch_methods)*
432
433                    /// Handle uncategorized actions.
434                    fn dispatch_uncategorized(&mut self, action: &#name) -> bool {
435                        false
436                    }
437
438                    /// Main dispatch entry point - routes to category-specific handlers.
439                    fn dispatch(&mut self, action: &#name) -> bool {
440                        match action.category_enum() {
441                            #(#dispatch_arms,)*
442                            #category_enum_name::Uncategorized => self.dispatch_uncategorized(action),
443                        }
444                    }
445                }
446            };
447        }
448    }
449
450    TokenStream::from(expanded)
451}
452
453/// Derive macro for the BindingContext trait
454///
455/// Generates implementations for `name()`, `from_name()`, and `all()` methods.
456/// The context name is derived from the variant name converted to snake_case.
457///
458/// # Example
459/// ```ignore
460/// #[derive(BindingContext, Clone, Copy, PartialEq, Eq, Hash)]
461/// enum MyContext {
462///     Default,
463///     Search,
464///     ConnectionForm,
465/// }
466///
467/// // Generated names: "default", "search", "connection_form"
468/// assert_eq!(MyContext::Default.name(), "default");
469/// assert_eq!(MyContext::from_name("search"), Some(MyContext::Search));
470/// ```
471#[proc_macro_derive(BindingContext)]
472pub fn derive_binding_context(input: TokenStream) -> TokenStream {
473    let input = parse_macro_input!(input as DeriveInput);
474    let name = &input.ident;
475
476    let expanded = match &input.data {
477        syn::Data::Enum(data) => {
478            // Check that all variants are unit variants
479            for variant in &data.variants {
480                if !matches!(variant.fields, syn::Fields::Unit) {
481                    return syn::Error::new_spanned(
482                        variant,
483                        "BindingContext can only be derived for enums with unit variants",
484                    )
485                    .to_compile_error()
486                    .into();
487                }
488            }
489
490            let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
491            let variant_strings: Vec<_> = variant_names
492                .iter()
493                .map(|v| to_snake_case(&v.to_string()))
494                .collect();
495
496            let name_arms = variant_names
497                .iter()
498                .zip(variant_strings.iter())
499                .map(|(v, s)| {
500                    quote! { #name::#v => #s }
501                });
502
503            let from_name_arms = variant_names
504                .iter()
505                .zip(variant_strings.iter())
506                .map(|(v, s)| {
507                    quote! { #s => ::core::option::Option::Some(#name::#v) }
508                });
509
510            let all_variants = variant_names.iter().map(|v| quote! { #name::#v });
511
512            quote! {
513                impl tui_dispatch::BindingContext for #name {
514                    fn name(&self) -> &'static str {
515                        match self {
516                            #(#name_arms),*
517                        }
518                    }
519
520                    fn from_name(name: &str) -> ::core::option::Option<Self> {
521                        match name {
522                            #(#from_name_arms,)*
523                            _ => ::core::option::Option::None,
524                        }
525                    }
526
527                    fn all() -> &'static [Self] {
528                        static ALL: &[#name] = &[#(#all_variants),*];
529                        ALL
530                    }
531                }
532            }
533        }
534        _ => {
535            return syn::Error::new_spanned(input, "BindingContext can only be derived for enums")
536                .to_compile_error()
537                .into();
538        }
539    };
540
541    TokenStream::from(expanded)
542}
543
544/// Derive macro for the ComponentId trait
545///
546/// Generates implementations for `name()` method that returns the variant name.
547///
548/// # Example
549/// ```ignore
550/// #[derive(ComponentId, Clone, Copy, PartialEq, Eq, Hash, Debug)]
551/// enum MyComponentId {
552///     Sidebar,
553///     MainContent,
554///     StatusBar,
555/// }
556///
557/// assert_eq!(MyComponentId::Sidebar.name(), "Sidebar");
558/// ```
559#[proc_macro_derive(ComponentId)]
560pub fn derive_component_id(input: TokenStream) -> TokenStream {
561    let input = parse_macro_input!(input as DeriveInput);
562    let name = &input.ident;
563
564    let expanded = match &input.data {
565        syn::Data::Enum(data) => {
566            // Check that all variants are unit variants
567            for variant in &data.variants {
568                if !matches!(variant.fields, syn::Fields::Unit) {
569                    return syn::Error::new_spanned(
570                        variant,
571                        "ComponentId can only be derived for enums with unit variants",
572                    )
573                    .to_compile_error()
574                    .into();
575                }
576            }
577
578            let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
579            let variant_strings: Vec<_> = variant_names.iter().map(|v| v.to_string()).collect();
580
581            let name_arms = variant_names
582                .iter()
583                .zip(variant_strings.iter())
584                .map(|(v, s)| {
585                    quote! { #name::#v => #s }
586                });
587
588            quote! {
589                impl tui_dispatch::ComponentId for #name {
590                    fn name(&self) -> &'static str {
591                        match self {
592                            #(#name_arms),*
593                        }
594                    }
595                }
596            }
597        }
598        _ => {
599            return syn::Error::new_spanned(input, "ComponentId can only be derived for enums")
600                .to_compile_error()
601                .into();
602        }
603    };
604
605    TokenStream::from(expanded)
606}
607
608// ============================================================================
609// DebugState derive macro
610// ============================================================================
611
612/// Container-level attributes for #[derive(DebugState)]
613#[derive(Debug, FromDeriveInput)]
614#[darling(attributes(debug_state), supports(struct_named))]
615struct DebugStateOpts {
616    ident: syn::Ident,
617    data: darling::ast::Data<(), DebugStateField>,
618}
619
620/// Field-level attributes for DebugState
621#[derive(Debug, FromField)]
622#[darling(attributes(debug))]
623struct DebugStateField {
624    ident: Option<syn::Ident>,
625
626    /// Section name for this field (groups fields together)
627    #[darling(default)]
628    section: Option<String>,
629
630    /// Skip this field in debug output
631    #[darling(default)]
632    skip: bool,
633
634    /// Custom display format (e.g., "{:?}" for Debug, "{:#?}" for pretty Debug)
635    #[darling(default)]
636    format: Option<String>,
637
638    /// Custom label for this field (defaults to field name)
639    #[darling(default)]
640    label: Option<String>,
641
642    /// Use Debug trait instead of Display
643    #[darling(default)]
644    debug_fmt: bool,
645}
646
647/// Derive macro for the DebugState trait
648///
649/// Automatically generates `debug_sections()` implementation from struct fields.
650///
651/// # Attributes
652///
653/// - `#[debug(section = "Name")]` - Group field under a section
654/// - `#[debug(skip)]` - Exclude field from debug output
655/// - `#[debug(label = "Custom Label")]` - Use custom label instead of field name
656/// - `#[debug(debug_fmt)]` - Use `{:?}` format instead of `Display`
657/// - `#[debug(format = "{:#?}")]` - Use custom format string
658///
659/// # Example
660///
661/// ```ignore
662/// use tui_dispatch::DebugState;
663///
664/// #[derive(DebugState)]
665/// struct AppState {
666///     #[debug(section = "Connection")]
667///     host: String,
668///     #[debug(section = "Connection")]
669///     port: u16,
670///
671///     #[debug(section = "UI")]
672///     scroll_offset: usize,
673///
674///     #[debug(skip)]
675///     internal_cache: HashMap<String, Data>,
676///
677///     #[debug(section = "Stats", debug_fmt)]
678///     status: ConnectionStatus,
679/// }
680/// ```
681///
682/// Fields without a section attribute are grouped under a section named after
683/// the struct (e.g., "AppState").
684#[proc_macro_derive(DebugState, attributes(debug, debug_state))]
685pub fn derive_debug_state(input: TokenStream) -> TokenStream {
686    let input = parse_macro_input!(input as DeriveInput);
687
688    let opts = match DebugStateOpts::from_derive_input(&input) {
689        Ok(opts) => opts,
690        Err(e) => return e.write_errors().into(),
691    };
692
693    let name = &opts.ident;
694    let default_section = name.to_string();
695
696    let fields = match &opts.data {
697        darling::ast::Data::Struct(fields) => fields,
698        _ => {
699            return syn::Error::new_spanned(&input, "DebugState can only be derived for structs")
700                .to_compile_error()
701                .into();
702        }
703    };
704
705    // Group fields by section
706    let mut sections: HashMap<String, Vec<&DebugStateField>> = HashMap::new();
707    let mut section_order: Vec<String> = Vec::new();
708
709    for field in fields.iter() {
710        if field.skip {
711            continue;
712        }
713
714        let section_name = field
715            .section
716            .clone()
717            .unwrap_or_else(|| default_section.clone());
718
719        if !section_order.contains(&section_name) {
720            section_order.push(section_name.clone());
721        }
722
723        sections.entry(section_name).or_default().push(field);
724    }
725
726    // Generate code for each section
727    let section_code: Vec<_> = section_order
728        .iter()
729        .map(|section_name| {
730            let fields_in_section = sections.get(section_name).unwrap();
731
732            let entry_calls: Vec<_> = fields_in_section
733                .iter()
734                .filter_map(|field| {
735                    let field_ident = field.ident.as_ref()?;
736                    let label = field
737                        .label
738                        .clone()
739                        .unwrap_or_else(|| field_ident.to_string());
740
741                    let value_expr = if let Some(ref fmt) = field.format {
742                        quote! { format!(#fmt, self.#field_ident) }
743                    } else if field.debug_fmt {
744                        quote! { format!("{:?}", self.#field_ident) }
745                    } else {
746                        quote! { self.#field_ident.to_string() }
747                    };
748
749                    Some(quote! {
750                        .entry(#label, #value_expr)
751                    })
752                })
753                .collect();
754
755            quote! {
756                tui_dispatch::debug::DebugSection::new(#section_name)
757                    #(#entry_calls)*
758            }
759        })
760        .collect();
761
762    let expanded = quote! {
763        impl tui_dispatch::debug::DebugState for #name {
764            fn debug_sections(&self) -> ::std::vec::Vec<tui_dispatch::debug::DebugSection> {
765                ::std::vec![
766                    #(#section_code),*
767                ]
768            }
769        }
770    };
771
772    TokenStream::from(expanded)
773}