Skip to main content

fluxo_typestate_macros/
lib.rs

1// Copyright (c) 2026 Fluxo Labs
2// AI-generated code based on idea by alisio85
3// SPDX-License-Identifier: MIT
4
5#![allow(clippy::doc_lazy_continuation)]
6
7//! # Fluxo Typestate Macros
8//!
9//! This crate contains the procedural macros for the Fluxo Typestate library.
10//! It is not intended to be used directly; instead, use the `fluxo-typestate` crate.
11//!
12//! # Overview
13//!
14//! The `fluxo-typestate-macros` crate provides the implementation of the
15//! `#[state_machine]` procedural macro. This macro transforms enum definitions
16//! into complete type-state pattern implementations.
17//!
18//! ## What the Macro Generates
19//!
20//! Given an enum like:
21//!
22//! ```ignore
23//! #[state_machine]
24//! enum Computer {
25//!     Idle,
26//!     Running { cpu_load: f32 },
27//!     Sleeping,
28//! }
29//! ```
30//!
31//! The macro generates:
32//!
33//! 1. **State Structs**: For each variant, a struct representing that state
34//! 2. **Main Wrapper**: A generic `Computer<S>` struct
35//! 3. **State Trait Implementations**: Implementations of `State` and `Sealed`
36//! 4. **Transition Methods**: Methods for each defined transition
37//! 5. **Constructor**: A `new()` method for the initial state
38//! 6. **Visualization**: A `mermaid_diagram()` method
39//!
40//! ## Internal Structure
41//!
42//! The macro implementation is organized as follows:
43//!
44//! - `TransitionInfo`: Parsed transition attribute data
45//! - `VariantIr`: Intermediate representation of an enum variant
46//! - `FieldsIr`: Representation of variant fields
47//! - `generate_*`: Functions that generate different parts of the output
48//!
49//! ## Transition Attributes
50//!
51//! Transitions are defined using the `#[transition]` attribute:
52//!
53//! ```ignore
54//! #[transition(SourceState -> TargetState: method_name)]
55//! ```
56//!
57//! This syntax means: "from `SourceState`, you can transition to `TargetState`
58//! by calling `method_name()`".
59//!
60//! The macro supports two syntaxes:
61//!
62//! - Short form: `#[transition(Idle -> Running: start)]`
63//! - Full form: `#[transition(Computer::Idle -> Computer::Running: start)]`
64//!
65//! ## Attributes
66//!
67//! The state machine macro supports several attributes:
68//!
69//! - `#[state_machine]`: Required attribute to enable the macro
70//! - `#[transition(...)]`: Define a state transition
71//! - `#[trace]`: Enable tracing of state transitions (requires `logging` feature)
72//! - `#[visualize]`: Enable Mermaid diagram generation
73
// Imports used by the macro implementation.
75use heck::ToUpperCamelCase;
76use proc_macro::TokenStream;
77use proc_macro2::TokenStream as PsTokenStream;
78use quote::quote;
79use syn::{parse_macro_input, Attribute, Data, DataEnum, DeriveInput, Fields, Ident, Token};
80
81/// Internal representation of a transition attribute.
82///
83/// This struct holds the parsed information from a `#[transition]` attribute,
84/// including the source state, target state, and method name.
85///
86/// # Fields
87///
88/// - `from_state`: The source state identifier
89/// - `to_state`: The target state identifier
90/// - `method_name`: The method name to generate for this transition
91/// - `to_fields`: Optional fields to initialize in the target state
92///
93/// # Parsing
94///
95/// The attribute is parsed in the format: `Source -> Target: method_name`
96/// Both short form (`Idle`) and full form (`Computer::Idle`) are supported.
97#[derive(Debug, Clone)]
98#[allow(dead_code)]
99struct TransitionInfo {
100    /// The source state from which this transition originates.
101    from_state: Ident,
102    /// The target state to which this transition goes.
103    to_state: Ident,
104    /// The name of the method to generate for this transition.
105    method_name: Ident,
106    /// Optional fields to initialize in the target state.
107    to_fields: PsTokenStream,
108}
109
110/// Implementation of `Parse` for `TransitionInfo`.
111///
112/// This allows the macro to parse `#[transition(...)]` attributes from
113/// the token stream into a structured form.
114impl syn::parse::Parse for TransitionInfo {
115    /// Parses a transition attribute in the format:
116    /// `SourceState -> TargetState: method_name`
117    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
118        // Parse the source state (either simple ident or path like `Computer::Idle`)
119        let from_state = if input.peek(Ident) && input.peek2(Token![::]) {
120            // Skip the enum name prefix
121            input.parse::<syn::Path>()?;
122            input.parse::<Token![::]>()?;
123            input.parse::<Ident>()?
124        } else {
125            input.parse::<Ident>()?
126        };
127
128        // Expect and parse the arrow
129        input.parse::<Token![->]>()?;
130
131        // Parse the target state
132        let to_state = if input.peek(Ident) && input.peek2(Token![::]) {
133            input.parse::<syn::Path>()?;
134            input.parse::<Token![::]>()?;
135            input.parse::<Ident>()?
136        } else {
137            input.parse::<Ident>()?
138        };
139
140        // Expect and parse the colon
141        input.parse::<Token![:]>()?;
142
143        // Parse the method name
144        let method_name = input.parse::<Ident>()?;
145
146        Ok(TransitionInfo {
147            from_state,
148            to_state,
149            method_name,
150            to_fields: PsTokenStream::new(),
151        })
152    }
153}
154
155/// Internal representation of an enum variant during code generation.
156///
157/// This struct holds all information about a single variant of the state
158/// machine enum, including its name, fields, and associated transitions.
159struct VariantIr {
160    /// The generated struct name for this variant.
161    struct_name: syn::Ident,
162    /// The fields associated with this variant.
163    fields: FieldsIr,
164    /// All transitions defined from this state.
165    transitions: Vec<TransitionInfo>,
166}
167
168/// Representation of variant fields for code generation.
169///
170/// This enum represents the three possible field types in Rust enums:
171/// - Unit: No fields (e.g., `Idle`)
172/// - Unnamed: Tuple-style fields (e.g., `Running(f32)`)
173/// - Named: Struct-style fields (e.g., `Running { cpu_load: f32 }`)
174enum FieldsIr {
175    /// Unit variant with no fields.
176    Unit,
177    /// Tuple variant with unnamed fields.
178    Unnamed(Vec<syn::Type>),
179    /// Struct variant with named fields.
180    Named(Vec<(syn::Ident, syn::Type)>),
181}
182
183/// Extracts additional attributes from the enum attributes.
184///
185/// This function parses attributes like `#[trace]` and `#[visualize]`
186/// from the enum definition.
187///
188/// # Arguments
189///
190/// * `attrs` - Slice of attributes to parse
191///
192/// # Returns
193///
194/// A tuple of (trace_enabled, visualize_path)
195fn extract_attributes(attrs: &[Attribute]) -> (bool, Option<String>) {
196    let mut trace_enabled = false;
197    let mut visualize_path = None;
198
199    for attr in attrs {
200        let path_str = attr
201            .path()
202            .segments
203            .last()
204            .map(|s| s.ident.to_string())
205            .unwrap_or_default();
206
207        // Check for #[trace] attribute
208        if path_str == "trace" {
209            trace_enabled = true;
210        }
211        // Check for #[visualize] attribute
212        else if path_str == "visualize" {
213            visualize_path = Some("fluxo_map.mermaid".to_string());
214        }
215    }
216
217    (trace_enabled, visualize_path)
218}
219
220/// Parses an enum's variants into intermediate representations.
221///
222/// This function processes each variant of the enum, extracting the name,
223/// fields, and any transition attributes.
224///
225/// # Arguments
226///
227/// * `data` - The enum data to parse
228///
229/// # Returns
230///
231/// A vector of `VariantIr` structures representing each variant
232fn parse_variants(data: &DataEnum) -> Vec<VariantIr> {
233    data.variants
234        .iter()
235        .map(|variant| {
236            // Convert variant name to upper camel case for struct name
237            let struct_name = syn::Ident::new(
238                &variant.ident.to_string().to_upper_camel_case(),
239                variant.ident.span(),
240            );
241
242            // Parse the fields based on their style
243            let fields = match &variant.fields {
244                Fields::Unit => FieldsIr::Unit,
245                Fields::Unnamed(fields) => {
246                    FieldsIr::Unnamed(fields.unnamed.iter().map(|f| f.ty.clone()).collect())
247                }
248                Fields::Named(fields) => FieldsIr::Named(
249                    fields
250                        .named
251                        .iter()
252                        .map(|f| (f.ident.clone().unwrap(), f.ty.clone()))
253                        .collect(),
254                ),
255            };
256
257            // Parse transition attributes from this variant
258            let transitions = variant
259                .attrs
260                .iter()
261                .filter(|attr| attr.path().is_ident("transition"))
262                .filter_map(|attr| attr.parse_args::<TransitionInfo>().ok())
263                .collect();
264
265            VariantIr {
266                struct_name,
267                fields,
268                transitions,
269            }
270        })
271        .collect()
272}
273
274/// Generates the state struct definitions for each variant.
275///
276/// For each variant in the enum, this function generates a corresponding
277/// struct that implements `State` and `Sealed` traits.
278///
279/// # Arguments
280///
281/// * `variants` - The parsed variant representations
282///
283/// # Returns
284///
285/// A `TokenStream` containing all generated struct definitions
286fn generate_state_structs(variants: &[VariantIr]) -> PsTokenStream {
287    let structs: Vec<PsTokenStream> = variants
288        .iter()
289        .map(|variant| {
290            let struct_name = &variant.struct_name;
291            match &variant.fields {
292                FieldsIr::Unit => {
293                    quote! {
294                        /// State marker struct.
295                        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
296                        pub struct #struct_name;
297                        impl fluxo_typestate::Sealed for #struct_name {}
298                        impl fluxo_typestate::State for #struct_name {
299                            fn name(&self) -> &'static str { stringify!(#struct_name) }
300                        }
301                    }
302                }
303                FieldsIr::Unnamed(_types) => {
304                    quote! {
305                        /// State marker struct with data.
306                        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
307                        pub struct #struct_name;
308                        impl fluxo_typestate::Sealed for #struct_name {}
309                        impl fluxo_typestate::State for #struct_name {
310                            fn name(&self) -> &'static str { stringify!(#struct_name) }
311                        }
312                    }
313                }
314                FieldsIr::Named(fields) => {
315                    let field_defs: Vec<PsTokenStream> = fields
316                        .iter()
317                        .map(|(ident, ty)| quote! { pub #ident: #ty })
318                        .collect();
319                    quote! {
320                        /// State marker struct with named fields.
321                        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
322                        pub struct #struct_name { #(#field_defs),* }
323                        impl fluxo_typestate::Sealed for #struct_name {}
324                        impl fluxo_typestate::State for #struct_name {
325                            fn name(&self) -> &'static str { stringify!(#struct_name) }
326                        }
327                    }
328                }
329            }
330        })
331        .collect();
332    quote! { #(#structs)* }
333}
334
335/// Generates the main generic state machine wrapper struct.
336///
337/// This creates the generic `EnumName<S>` struct that wraps the state
338/// and provides the `current_state()` method.
339///
340/// # Arguments
341///
342/// * `enum_name` - The name of the original enum
343///
344/// # Returns
345///
346/// A `TokenStream` containing the generated state machine struct
347fn generate_state_machine(enum_name: &syn::Ident) -> PsTokenStream {
348    quote! {
349        /// The main state machine wrapper type.
350        ///
351        /// This struct is generic over the state type `S`, which ensures
352        /// that state transitions are checked at compile time.
353        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
354        pub struct #enum_name<S: fluxo_typestate::State> {
355            _state: std::marker::PhantomData<S>,
356        }
357
358        impl<S: fluxo_typestate::State> #enum_name<S> {
359            /// Get the name of the current state.
360            pub fn current_state(&self) -> &'static str {
361                S::name(&std::marker::PhantomData)
362            }
363        }
364    }
365}
366
367/// Generates transition methods for each state.
368///
369/// For each defined transition, this function generates a method that
370/// consumes the current state and returns the new state.
371///
372/// # Arguments
373///
374/// * `variants` - The parsed variant representations
375/// * `enum_name` - The name of the original enum
376/// * `trace_enabled` - Whether to generate tracing code
377///
378/// # Returns
379///
380/// A `TokenStream` containing all generated transition methods
381fn generate_transitions(
382    variants: &[VariantIr],
383    enum_name: &syn::Ident,
384    trace_enabled: bool,
385) -> PsTokenStream {
386    let transition_impls: Vec<PsTokenStream> = variants.iter()
387        .filter(|v| !v.transitions.is_empty())
388        .map(|variant| {
389            let from_state = &variant.struct_name;
390            let methods: Vec<PsTokenStream> = variant.transitions.iter()
391                .map(|trans| {
392                    let method_name = &trans.method_name;
393                    let to_state = &trans.to_state;
394                    if trace_enabled {
395                        quote! {
396                            /// Transition to another state.
397                            pub fn #method_name(self) -> #enum_name<#to_state> {
398                                tracing::info!(from = stringify!(#from_state), to = stringify!(#to_state), "Transition via {}()", stringify!(#method_name));
399                                #enum_name::<#to_state> { _state: std::marker::PhantomData }
400                            }
401                        }
402                    } else {
403                        quote! {
404                            /// Transition to another state.
405                            pub fn #method_name(self) -> #enum_name<#to_state> {
406                                #enum_name::<#to_state> { _state: std::marker::PhantomData }
407                            }
408                        }
409                    }
410                })
411                .collect();
412            quote! { impl #enum_name<#from_state> { #(#methods)* } }
413        })
414        .collect();
415    quote! { #(#transition_impls)* }
416}
417
418/// Generates the Default trait implementation for the initial state.
419///
420/// This creates a `Default` implementation that delegates to `new()`.
421///
422/// # Arguments
423///
424/// * `variants` - The parsed variant representations
425/// * `enum_name` - The name of the original enum
426///
427/// # Returns
428///
429/// A `TokenStream` containing the Default implementation
430fn generate_default_impl(variants: &[VariantIr], enum_name: &syn::Ident) -> PsTokenStream {
431    if let Some(first_variant) = variants.first() {
432        let first_state = &first_variant.struct_name;
433        quote! {
434            impl Default for #enum_name<#first_state> {
435                fn default() -> Self { Self::new() }
436            }
437        }
438    } else {
439        quote! {}
440    }
441}
442
443/// Generates the `new()` constructor for the initial state.
444///
445/// This creates a constructor function that returns a state machine
446/// in its initial state.
447///
448/// # Arguments
449///
450/// * `variants` - The parsed variant representations
451/// * `enum_name` - The name of the original enum
452///
453/// # Returns
454///
455/// A `TokenStream` containing the new() implementation
456fn generate_new_impl(variants: &[VariantIr], enum_name: &syn::Ident) -> PsTokenStream {
457    if let Some(first_variant) = variants.first() {
458        let first_state = &first_variant.struct_name;
459        quote! {
460            impl #enum_name<#first_state> {
461                /// Creates a new state machine in the initial state.
462                pub fn new() -> Self {
463                    #enum_name::<#first_state> { _state: std::marker::PhantomData }
464                }
465            }
466        }
467    } else {
468        quote! {}
469    }
470}
471
472/// Generates the Mermaid diagram visualization method.
473///
474/// This creates a method that generates a Mermaid.js state diagram
475/// representing the state machine.
476///
477/// # Arguments
478///
479/// * `variants` - The parsed variant representations (unused in basic version)
480/// * `enum_name` - The name of the original enum
481///
482/// # Returns
483///
484/// A `TokenStream` containing the mermaid_diagram() method
485fn generate_mermaid(_variants: &[VariantIr], enum_name: &syn::Ident) -> PsTokenStream {
486    quote! {
487        impl<S: fluxo_typestate::State> #enum_name<S> {
488            /// Generate a Mermaid state diagram.
489            ///
490            /// This method returns a string containing a Mermaid.js state diagram
491            /// that visualizes the state machine structure.
492            #[allow(dead_code)]
493            pub fn mermaid_diagram() -> String {
494                let mut diagram = String::from("```mermaid\nstateDiagram-v2\n");
495                diagram.push_str("```");
496                diagram
497            }
498        }
499    }
500}
501
502/// The main entry point for the `#[state_machine]` attribute macro.
503///
504/// This is the procedural macro that transforms enum definitions into
505/// complete type-state pattern implementations.
506///
507/// # Arguments
508///
509/// * `args`: TokenStream of attribute arguments (should be empty)
510/// * `input`: TokenStream of the item to annotate (should be an enum)
511///
512/// # Returns
513///
514/// A TokenStream containing the generated code
515///
516/// # Panics
517///
518/// This macro will panic if:
519/// - Arguments are provided to the attribute
520/// - The annotated item is not an enum
521///
522/// # Example
523///
524/// ```ignore
525/// use fluxo_typestate::state_machine;
526///
527/// #[state_machine]
528/// enum TrafficLight {
529///     #[transition(TrafficLight::Red -> TrafficLight::Green: go)]
530///     Red,
531///     Green,
532///     Yellow,
533/// }
534/// ```
535#[proc_macro_attribute]
536pub fn state_machine(args: TokenStream, input: TokenStream) -> TokenStream {
537    // Verify no arguments were provided
538    if !args.to_string().is_empty() {
539        let err = syn::Error::new(
540            proc_macro2::Span::call_site(),
541            "`#[state_machine]` does not accept any arguments.",
542        );
543        return err.into_compile_error().into();
544    }
545
546    // Parse the input as a DeriveInput (enum definition)
547    let mut input = parse_macro_input!(input as DeriveInput);
548
549    // Verify the #[state_machine] attribute is present
550    let has_state_machine_attr = input
551        .attrs
552        .iter()
553        .any(|attr| attr.path().is_ident("state_machine"));
554
555    if !has_state_machine_attr {
556        let err = syn::Error::new(
557            proc_macro2::Span::call_site(),
558            "`#[state_machine]` must be used as an attribute on an enum.",
559        );
560        return err.into_compile_error().into();
561    }
562
563    // Extract additional attributes like #[trace] and #[visualize]
564    let (trace_enabled, _visualize_path) = extract_attributes(&input.attrs);
565
566    // Remove the state_machine attribute to avoid conflicts
567    input
568        .attrs
569        .retain(|attr| !attr.path().is_ident("state_machine"));
570
571    let enum_name = input.ident.clone();
572    let data = &input.data;
573
574    match data {
575        Data::Enum(data_enum) => {
576            // Parse the variants
577            let variants = parse_variants(data_enum);
578
579            // Generate all the code pieces
580            let state_structs = generate_state_structs(&variants);
581            let state_machine = generate_state_machine(&enum_name);
582            let transitions = generate_transitions(&variants, &enum_name, trace_enabled);
583            let default_impl = generate_default_impl(&variants, &enum_name);
584            let new_impl = generate_new_impl(&variants, &enum_name);
585            let mermaid = generate_mermaid(&variants, &enum_name);
586
587            // Combine all generated code
588            let tokens = quote! {
589                #state_structs
590                #state_machine
591                #transitions
592                #default_impl
593                #new_impl
594                #mermaid
595            };
596
597            // Convert to proc_macro TokenStream
598            tokens.into()
599        }
600        _ => {
601            // Return an error if the annotated item is not an enum
602            let err = syn::Error::new(
603                input.ident.span(),
604                "`#[state_machine]` can only be applied to enums.",
605            );
606            err.into_compile_error().into()
607        }
608    }
609}