rustfsm_procmacro/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{quote, quote_spanned};
5use std::collections::{hash_map::Entry, HashMap, HashSet};
6use syn::{
7    parenthesized,
8    parse::{Parse, ParseStream, Result},
9    parse_macro_input,
10    punctuated::Punctuated,
11    spanned::Spanned,
12    Error, Fields, Ident, Token, Type, Variant, Visibility,
13};
14
15/// Parses a DSL for defining finite state machines, and produces code implementing the
16/// [StateMachine](trait.StateMachine.html) trait.
17///
18/// An example state machine definition of a card reader for unlocking a door:
19/// ```
20/// # extern crate rustfsm_trait as rustfsm;
21/// use rustfsm_procmacro::fsm;
22/// use std::convert::Infallible;
23/// use rustfsm_trait::{StateMachine, TransitionResult};
24///
25/// fsm! {
26///     name CardReader; command Commands; error Infallible; shared_state SharedState;
27///
28///     Locked --(CardReadable(CardData), shared on_card_readable) --> ReadingCard;
29///     Locked --(CardReadable(CardData), shared on_card_readable) --> Locked;
30///     ReadingCard --(CardAccepted, on_card_accepted) --> DoorOpen;
31///     ReadingCard --(CardRejected, on_card_rejected) --> Locked;
32///     DoorOpen --(DoorClosed, on_door_closed) --> Locked;
33/// }
34///
35/// #[derive(Clone)]
36/// pub struct SharedState {
37///     last_id: Option<String>
38/// }
39///
40/// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
41/// pub enum Commands {
42///     StartBlinkingLight,
43///     StopBlinkingLight,
44///     ProcessData(CardData),
45/// }
46///
47/// type CardData = String;
48///
49/// /// Door is locked / idle / we are ready to read
50/// #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
51/// pub struct Locked {}
52///
53/// /// Actively reading the card
54/// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
55/// pub struct ReadingCard {
56///     card_data: CardData,
57/// }
58///
59/// /// The door is open, we shouldn't be accepting cards and should be blinking the light
60/// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
61/// pub struct DoorOpen {}
62/// impl DoorOpen {
63///     fn on_door_closed(&self) -> CardReaderTransition<Locked> {
64///         TransitionResult::ok(vec![], Locked {})
65///     }
66/// }
67///
68/// impl Locked {
69///     fn on_card_readable(&self, shared_dat: SharedState, data: CardData)
70///       -> CardReaderTransition<ReadingCardOrLocked> {
71///         match shared_dat.last_id {
72///             // Arbitrarily deny the same person entering twice in a row
73///             Some(d) if d == data => TransitionResult::ok(vec![], Locked {}.into()),
74///             _ => {
75///                 // Otherwise issue a processing command. This illustrates using the same handler
76///                 // for different destinations
77///                 TransitionResult::ok_shared(
78///                     vec![
79///                         Commands::ProcessData(data.clone()),
80///                         Commands::StartBlinkingLight,
81///                     ],
82///                     ReadingCard { card_data: data.clone() }.into(),
83///                     SharedState { last_id: Some(data) }
84///                 )
85///             }   
86///         }
87///     }
88/// }
89///
90/// impl ReadingCard {
91///     fn on_card_accepted(&self) -> CardReaderTransition<DoorOpen> {
92///         TransitionResult::ok(vec![Commands::StopBlinkingLight], DoorOpen {})
93///     }
94///     fn on_card_rejected(&self) -> CardReaderTransition<Locked> {
95///         TransitionResult::ok(vec![Commands::StopBlinkingLight], Locked {})
96///     }
97/// }
98///
99/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
100/// let crs = CardReaderState::Locked(Locked {});
101/// let mut cr = CardReader { state: crs, shared_state: SharedState { last_id: None } };
102/// let cmds = cr.on_event_mut(CardReaderEvents::CardReadable("badguy".to_string()))?;
103/// assert_eq!(cmds[0], Commands::ProcessData("badguy".to_string()));
104/// assert_eq!(cmds[1], Commands::StartBlinkingLight);
105///
106/// let cmds = cr.on_event_mut(CardReaderEvents::CardRejected)?;
107/// assert_eq!(cmds[0], Commands::StopBlinkingLight);
108///
109/// let cmds = cr.on_event_mut(CardReaderEvents::CardReadable("goodguy".to_string()))?;
110/// assert_eq!(cmds[0], Commands::ProcessData("goodguy".to_string()));
111/// assert_eq!(cmds[1], Commands::StartBlinkingLight);
112///
113/// let cmds = cr.on_event_mut(CardReaderEvents::CardAccepted)?;
114/// assert_eq!(cmds[0], Commands::StopBlinkingLight);
115/// # Ok(())
116/// # }
117/// ```
118///
119/// In the above example the first word is the name of the state machine, then after the comma the
120/// type (which you must define separately) of commands produced by the machine.
121///
122/// then each line represents a transition, where the first word is the initial state, the tuple
123/// inside the arrow is `(eventtype[, event handler])`, and the word after the arrow is the
124/// destination state. here `eventtype` is an enum variant , and `event_handler` is a function you
125/// must define outside the enum whose form depends on the event variant. the only variant types
126/// allowed are unit and one-item tuple variants. For unit variants, the function takes no
127/// parameters. For the tuple variants, the function takes the variant data as its parameter. In
128/// either case the function is expected to return a `TransitionResult` to the appropriate state.
129///
130/// The first transition can be interpreted as "If the machine is in the locked state, when a
131/// `CardReadable` event is seen, call `on_card_readable` (pasing in `CardData`) and transition to
132/// the `ReadingCard` state.
133///
134/// The macro will generate a few things:
135/// * A struct for the overall state machine, named with the provided name. Here:
136///   ```ignore
137///   struct CardMachine {
138///       state: CardMachineState,
139///       shared_state: CardId,
140///   }
141///   ```
142/// * An enum with a variant for each state, named with the provided name + "State".
143///   ```ignore
144///   enum CardMachineState {
145///       Locked(Locked),
146///       ReadingCard(ReadingCard),
147///       Unlocked(Unlocked),
148///   }
149///   ```
150///
151///   You are expected to define a type for each state, to contain that state's data. If there is
152///   no data, you can simply: `type StateName = ()`
153/// * For any instance of transitions with the same event/handler which transition to different
154///   destination states (dynamic destinations), an enum named like `DestAOrDestBOrDestC` is
155///   generated. This enum must be used as the destination "state" from those handlers.
156/// * An enum with a variant for each event. You are expected to define the type (if any) contained
157///   in the event variant.
158///   ```ignore
159///   enum CardMachineEvents {
160///     CardReadable(CardData)
161///   }
162///   ```
163/// * An implementation of the [StateMachine](trait.StateMachine.html) trait for the generated state
164///   machine enum (in this case, `CardMachine`)
165/// * A type alias for a [TransitionResult](enum.TransitionResult.html) with the appropriate generic
166///   parameters set for your machine. It is named as your machine with `Transition` appended. In
167///   this case, `CardMachineTransition`.
168#[proc_macro]
169pub fn fsm(input: TokenStream) -> TokenStream {
170    let def: StateMachineDefinition = parse_macro_input!(input as StateMachineDefinition);
171    def.codegen()
172}
173
174mod kw {
175    syn::custom_keyword!(name);
176    syn::custom_keyword!(command);
177    syn::custom_keyword!(error);
178    syn::custom_keyword!(shared);
179    syn::custom_keyword!(shared_state);
180}
181
182struct StateMachineDefinition {
183    visibility: Visibility,
184    name: Ident,
185    shared_state_type: Option<Type>,
186    command_type: Ident,
187    error_type: Ident,
188    transitions: Vec<Transition>,
189}
190
191impl StateMachineDefinition {
192    fn is_final_state(&self, state: &Ident) -> bool {
193        // If no transitions go from this state, it's a final state.
194        self.transitions.iter().find(|t| t.from == *state).is_none()
195    }
196}
197
198impl Parse for StateMachineDefinition {
199    fn parse(input: ParseStream) -> Result<Self> {
200        // Parse visibility if present
201        let visibility = input.parse()?;
202        // parse the state machine name, command type, and error type
203        let (name, command_type, error_type, shared_state_type) = parse_machine_types(&input)
204            .map_err(|mut e| {
205                e.combine(Error::new(
206                    e.span(),
207                    "The fsm definition should begin with `name MachineName; command CommandType; \
208                    error ErrorType;` optionally followed by `shared_state SharedStateType;`",
209                ));
210                e
211            })?;
212        // Then the state machine definition is simply a sequence of transitions separated by
213        // semicolons
214        let transitions: Punctuated<Transition, Token![;]> =
215            input.parse_terminated(Transition::parse)?;
216        let transitions: Vec<_> = transitions.into_iter().collect();
217        // Check for and whine about any identical transitions. We do this here because preserving
218        // the order transitions were defined in is important, so simply collecting to a set is
219        // not ideal.
220        let trans_set: HashSet<_> = transitions.iter().collect();
221        if trans_set.len() != transitions.len() {
222            return Err(syn::Error::new(
223                input.span(),
224                "Duplicate transitions are not allowed!",
225            ));
226        }
227        Ok(Self {
228            visibility,
229            name,
230            shared_state_type,
231            transitions,
232            command_type,
233            error_type,
234        })
235    }
236}
237
238fn parse_machine_types(input: &ParseStream) -> Result<(Ident, Ident, Ident, Option<Type>)> {
239    let _: kw::name = input.parse()?;
240    let name: Ident = input.parse()?;
241    input.parse::<Token![;]>()?;
242
243    let _: kw::command = input.parse()?;
244    let command_type: Ident = input.parse()?;
245    input.parse::<Token![;]>()?;
246
247    let _: kw::error = input.parse()?;
248    let error_type: Ident = input.parse()?;
249    input.parse::<Token![;]>()?;
250
251    let shared_state_type: Option<Type> = if input.peek(kw::shared_state) {
252        let _: kw::shared_state = input.parse()?;
253        let typep = input.parse()?;
254        input.parse::<Token![;]>()?;
255        Some(typep)
256    } else {
257        None
258    };
259    Ok((name, command_type, error_type, shared_state_type))
260}
261
262#[derive(Debug, Clone, Eq, PartialEq, Hash)]
263struct Transition {
264    from: Ident,
265    to: Vec<Ident>,
266    event: Variant,
267    handler: Option<Ident>,
268    mutates_shared: bool,
269}
270
271impl Parse for Transition {
272    fn parse(input: ParseStream) -> Result<Self> {
273        // Parse the initial state name
274        let from: Ident = input.parse()?;
275        // Parse at least one dash
276        input.parse::<Token![-]>()?;
277        while input.peek(Token![-]) {
278            input.parse::<Token![-]>()?;
279        }
280        // Parse transition information inside parens
281        let transition_info;
282        parenthesized!(transition_info in input);
283        // Get the event variant definition
284        let event: Variant = transition_info.parse()?;
285        // Reject non-unit or single-item-tuple variants
286        match &event.fields {
287            Fields::Named(_) => {
288                return Err(Error::new(
289                    event.span(),
290                    "Struct variants are not supported for events",
291                ))
292            }
293            Fields::Unnamed(uf) => {
294                if uf.unnamed.len() != 1 {
295                    return Err(Error::new(
296                        event.span(),
297                        "Only tuple variants with exactly one item are supported for events",
298                    ));
299                }
300            }
301            Fields::Unit => {}
302        }
303        // Check if there is an event handler, and parse it
304        let (mutates_shared, handler) = if transition_info.peek(Token![,]) {
305            transition_info.parse::<Token![,]>()?;
306            // Check for mut keyword signifying handler wants to mutate shared state
307            let mutates = if transition_info.peek(kw::shared) {
308                transition_info.parse::<kw::shared>()?;
309                true
310            } else {
311                false
312            };
313            (mutates, Some(transition_info.parse()?))
314        } else {
315            (false, None)
316        };
317        // Parse at least one dash followed by the "arrow"
318        input.parse::<Token![-]>()?;
319        while input.peek(Token![-]) {
320            input.parse::<Token![-]>()?;
321        }
322        input.parse::<Token![>]>()?;
323        // Parse the destination state
324        let to: Ident = input.parse()?;
325
326        Ok(Self {
327            from,
328            event,
329            handler,
330            to: vec![to],
331            mutates_shared,
332        })
333    }
334}
335
336impl StateMachineDefinition {
337    fn codegen(&self) -> TokenStream {
338        let visibility = self.visibility.clone();
339        // First extract all of the states into a set, and build the enum's insides
340        let states = self.all_states();
341        let state_variants = states.iter().map(|s| {
342            let statestr = s.to_string();
343            quote! {
344                #[display(fmt=#statestr)]
345                #s(#s)
346            }
347        });
348        let name = &self.name;
349        let name_str = &self.name.to_string();
350
351        let transition_result_name = Ident::new(&format!("{}Transition", name), name.span());
352        let transition_type_alias = quote! {
353            type #transition_result_name<Ds, Sm = #name> = TransitionResult<Sm, Ds>;
354        };
355
356        let state_enum_name = Ident::new(&format!("{}State", name), name.span());
357        // If user has not defined any shared state, use the unit type.
358        let shared_state_type = self
359            .shared_state_type
360            .clone()
361            .unwrap_or_else(|| syn::parse_str("()").unwrap());
362        let machine_struct = quote! {
363            #[derive(Clone)]
364            #visibility struct #name {
365                state: #state_enum_name,
366                shared_state: #shared_state_type
367            }
368        };
369        let states_enum = quote! {
370            #[derive(::derive_more::From, Clone, ::derive_more::Display)]
371            #visibility enum #state_enum_name {
372                #(#state_variants),*
373            }
374        };
375        let state_is_final_match_arms = states.iter().map(|s| {
376            let val = if self.is_final_state(s) {
377                quote! { true }
378            } else {
379                quote! { false }
380            };
381            quote! { #state_enum_name::#s(_) => #val }
382        });
383        let states_enum_impl = quote! {
384            impl #state_enum_name {
385                fn is_final(&self) -> bool {
386                    match self {
387                        #(#state_is_final_match_arms),*
388                    }
389                }
390            }
391        };
392
393        // Build the events enum
394        let events: HashSet<Variant> = self.transitions.iter().map(|t| t.event.clone()).collect();
395        let events_enum_name = Ident::new(&format!("{}Events", name), name.span());
396        let events: Vec<_> = events
397            .into_iter()
398            .map(|v| {
399                let vname = v.ident.to_string();
400                quote! {
401                    #[display(fmt=#vname)]
402                    #v
403                }
404            })
405            .collect();
406        let events_enum = quote! {
407            #[derive(::derive_more::Display)]
408            #visibility enum #events_enum_name {
409                #(#events),*
410            }
411        };
412
413        // Construct the trait implementation
414        let cmd_type = &self.command_type;
415        let err_type = &self.error_type;
416        let mut statemap: HashMap<Ident, Vec<Transition>> = HashMap::new();
417        for t in &self.transitions {
418            statemap
419                .entry(t.from.clone())
420                .and_modify(|v| v.push(t.clone()))
421                .or_insert_with(|| vec![t.clone()]);
422        }
423        // Add any states without any transitions to the map
424        for s in &states {
425            if !statemap.contains_key(s) {
426                statemap.insert(s.clone(), vec![]);
427            }
428        }
429        let mut multi_dest_enums = vec![];
430        let state_branches: Vec<_> = statemap.into_iter().map(|(from, transitions)| {
431            // Merge transition dest states with the same handler
432            let transitions = merge_transition_dests(transitions);
433            let event_branches = transitions
434                .into_iter()
435                .map(|ts| {
436                    let ev_variant = &ts.event.ident;
437                    if let Some(ts_fn) = ts.handler.clone() {
438                        let span = ts_fn.span();
439                        let trans_type = match ts.to.as_slice() {
440                            [] => unreachable!("There will be at least one dest state in transitions"),
441                            [one_to] => quote! {
442                                            #transition_result_name<#one_to>
443                                        },
444                            multi_dests => {
445                                let string_dests: Vec<_> = multi_dests.iter()
446                                    .map(|i| i.to_string()).collect();
447                                let enum_ident = Ident::new(&string_dests.join("Or"),
448                                                            multi_dests[0].span());
449                                let multi_dest_enum = quote! {
450                                    #[derive(::derive_more::From)]
451                                    #visibility enum #enum_ident {
452                                        #(#multi_dests(#multi_dests)),*
453                                    }
454                                    impl ::core::convert::From<#enum_ident> for #state_enum_name {
455                                        fn from(v: #enum_ident) -> Self {
456                                            match v {
457                                                #( #enum_ident::#multi_dests(sv) =>
458                                                    Self::#multi_dests(sv) ),*
459                                            }
460                                        }
461                                    }
462                                };
463                                multi_dest_enums.push(multi_dest_enum);
464                                quote! {
465                                    #transition_result_name<#enum_ident>
466                                }
467                            }
468                        };
469                        match ts.event.fields {
470                            Fields::Unnamed(_) => {
471                                let arglist = if ts.mutates_shared {
472                                    quote! {self.shared_state, val}
473                                } else {
474                                    quote! {val}
475                                };
476                                quote_spanned! {span=>
477                                    #events_enum_name::#ev_variant(val) => {
478                                        let res: #trans_type = state_data.#ts_fn(#arglist);
479                                        res.into_general()
480                                    }
481                                }
482                            }
483                            Fields::Unit => {
484                                let arglist = if ts.mutates_shared {
485                                    quote! {self.shared_state}
486                                } else {
487                                    quote! {}
488                                };
489                                quote_spanned! {span=>
490                                    #events_enum_name::#ev_variant => {
491                                        let res: #trans_type = state_data.#ts_fn(#arglist);
492                                        res.into_general()
493                                    }
494                                }
495                            }
496                            Fields::Named(_) => unreachable!(),
497                        }
498                    } else {
499                        // If events do not have a handler, attempt to construct the next state
500                        // using `Default`.
501                        if let [new_state] = ts.to.as_slice() {
502                            let span = new_state.span();
503                            let default_trans = quote_spanned! {span=>
504                            TransitionResult::<_, #new_state>::from::<#from>(state_data).into_general()
505                        };
506                            let span = ts.event.span();
507                            match ts.event.fields {
508                                Fields::Unnamed(_) => quote_spanned! {span=>
509                                #events_enum_name::#ev_variant(_val) => {
510                                    #default_trans
511                                }
512                            },
513                                Fields::Unit => quote_spanned! {span=>
514                                #events_enum_name::#ev_variant => {
515                                    #default_trans
516                                }
517                            },
518                                Fields::Named(_) => unreachable!(),
519                            }
520
521                        } else {
522                            unreachable!("It should be impossible to have more than one dest state in no-handler transitions")
523                        }
524                    }
525                })
526                // Since most states won't handle every possible event, return an error to that effect
527                .chain(std::iter::once(
528                    quote! { _ => { return TransitionResult::InvalidTransition } },
529                ));
530            quote! {
531                #state_enum_name::#from(state_data) => match event {
532                    #(#event_branches),*
533                }
534            }
535        }).collect();
536
537        let viz_str = self.visualize();
538
539        let trait_impl = quote! {
540            impl ::rustfsm::StateMachine for #name {
541                type Error = #err_type;
542                type State = #state_enum_name;
543                type SharedState = #shared_state_type;
544                type Event = #events_enum_name;
545                type Command = #cmd_type;
546
547                fn name(&self) -> &str {
548                  #name_str
549                }
550
551                fn on_event(self, event: #events_enum_name)
552                  -> ::rustfsm::TransitionResult<Self, Self::State> {
553                    match self.state {
554                        #(#state_branches),*
555                    }
556                }
557
558                fn state(&self) -> &Self::State {
559                    &self.state
560                }
561                fn set_state(&mut self, new: Self::State) {
562                    self.state = new
563                }
564
565                fn shared_state(&self) -> &Self::SharedState{
566                    &self.shared_state
567                }
568
569                fn on_final_state(&self) -> bool {
570                    self.state.is_final()
571                }
572
573                fn from_parts(shared: Self::SharedState, state: Self::State) -> Self {
574                    Self { shared_state: shared, state }
575                }
576
577                fn visualizer() -> &'static str {
578                    #viz_str
579                }
580            }
581        };
582
583        let output = quote! {
584            #transition_type_alias
585            #machine_struct
586            #states_enum
587            #(#multi_dest_enums)*
588            #states_enum_impl
589            #events_enum
590            #trait_impl
591        };
592
593        output.into()
594    }
595
596    fn all_states(&self) -> HashSet<Ident> {
597        self.transitions
598            .iter()
599            .flat_map(|t| {
600                let mut states = t.to.clone();
601                states.push(t.from.clone());
602                states
603            })
604            .collect()
605    }
606
607    fn visualize(&self) -> String {
608        let transitions: Vec<String> = self
609            .transitions
610            .iter()
611            .flat_map(|t| {
612                t.to.iter()
613                    .map(move |d| format!("{} --> {}: {}", t.from, d, t.event.ident))
614            })
615            // Add all final state transitions
616            .chain(
617                self.all_states()
618                    .iter()
619                    .filter(|s| self.is_final_state(s))
620                    .map(|s| format!("{} --> [*]", s)),
621            )
622            .collect();
623        let transitions = transitions.join("\n");
624        format!("@startuml\n{}\n@enduml", transitions)
625    }
626}
627
628/// Merge transition's dest state lists for those with the same from state & handler
629fn merge_transition_dests(transitions: Vec<Transition>) -> Vec<Transition> {
630    let mut map = HashMap::<_, Transition>::new();
631    transitions.into_iter().for_each(|t| {
632        // We want to use the transition sans-destinations as the key
633        let without_dests = {
634            let mut wd = t.clone();
635            wd.to = vec![];
636            wd
637        };
638        match map.entry(without_dests) {
639            Entry::Occupied(mut e) => {
640                e.get_mut().to.extend(t.to.into_iter());
641            }
642            Entry::Vacant(v) => {
643                v.insert(t);
644            }
645        }
646    });
647    map.into_iter().map(|(_, v)| v).collect()
648}