rustfsm_procmacro/lib.rs
1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{quote, quote_spanned};
5use std::collections::{hash_map::Entry, HashMap, HashSet};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream, Result},
9 parse_macro_input,
10 punctuated::Punctuated,
11 spanned::Spanned,
12 Error, Fields, Ident, Token, Type, Variant, Visibility,
13};
14
15/// Parses a DSL for defining finite state machines, and produces code implementing the
16/// [StateMachine](trait.StateMachine.html) trait.
17///
18/// An example state machine definition of a card reader for unlocking a door:
19/// ```
20/// # extern crate rustfsm_trait as rustfsm;
21/// use rustfsm_procmacro::fsm;
22/// use std::convert::Infallible;
23/// use rustfsm_trait::{StateMachine, TransitionResult};
24///
25/// fsm! {
26/// name CardReader; command Commands; error Infallible; shared_state SharedState;
27///
28/// Locked --(CardReadable(CardData), shared on_card_readable) --> ReadingCard;
29/// Locked --(CardReadable(CardData), shared on_card_readable) --> Locked;
30/// ReadingCard --(CardAccepted, on_card_accepted) --> DoorOpen;
31/// ReadingCard --(CardRejected, on_card_rejected) --> Locked;
32/// DoorOpen --(DoorClosed, on_door_closed) --> Locked;
33/// }
34///
35/// #[derive(Clone)]
36/// pub struct SharedState {
37/// last_id: Option<String>
38/// }
39///
40/// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
41/// pub enum Commands {
42/// StartBlinkingLight,
43/// StopBlinkingLight,
44/// ProcessData(CardData),
45/// }
46///
47/// type CardData = String;
48///
49/// /// Door is locked / idle / we are ready to read
50/// #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
51/// pub struct Locked {}
52///
53/// /// Actively reading the card
54/// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
55/// pub struct ReadingCard {
56/// card_data: CardData,
57/// }
58///
59/// /// The door is open, we shouldn't be accepting cards and should be blinking the light
60/// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
61/// pub struct DoorOpen {}
62/// impl DoorOpen {
63/// fn on_door_closed(&self) -> CardReaderTransition<Locked> {
64/// TransitionResult::ok(vec![], Locked {})
65/// }
66/// }
67///
68/// impl Locked {
69/// fn on_card_readable(&self, shared_dat: SharedState, data: CardData)
70/// -> CardReaderTransition<ReadingCardOrLocked> {
71/// match shared_dat.last_id {
72/// // Arbitrarily deny the same person entering twice in a row
73/// Some(d) if d == data => TransitionResult::ok(vec![], Locked {}.into()),
74/// _ => {
75/// // Otherwise issue a processing command. This illustrates using the same handler
76/// // for different destinations
77/// TransitionResult::ok_shared(
78/// vec![
79/// Commands::ProcessData(data.clone()),
80/// Commands::StartBlinkingLight,
81/// ],
82/// ReadingCard { card_data: data.clone() }.into(),
83/// SharedState { last_id: Some(data) }
84/// )
85/// }
86/// }
87/// }
88/// }
89///
90/// impl ReadingCard {
91/// fn on_card_accepted(&self) -> CardReaderTransition<DoorOpen> {
92/// TransitionResult::ok(vec![Commands::StopBlinkingLight], DoorOpen {})
93/// }
94/// fn on_card_rejected(&self) -> CardReaderTransition<Locked> {
95/// TransitionResult::ok(vec![Commands::StopBlinkingLight], Locked {})
96/// }
97/// }
98///
99/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
100/// let crs = CardReaderState::Locked(Locked {});
101/// let mut cr = CardReader { state: crs, shared_state: SharedState { last_id: None } };
102/// let cmds = cr.on_event_mut(CardReaderEvents::CardReadable("badguy".to_string()))?;
103/// assert_eq!(cmds[0], Commands::ProcessData("badguy".to_string()));
104/// assert_eq!(cmds[1], Commands::StartBlinkingLight);
105///
106/// let cmds = cr.on_event_mut(CardReaderEvents::CardRejected)?;
107/// assert_eq!(cmds[0], Commands::StopBlinkingLight);
108///
109/// let cmds = cr.on_event_mut(CardReaderEvents::CardReadable("goodguy".to_string()))?;
110/// assert_eq!(cmds[0], Commands::ProcessData("goodguy".to_string()));
111/// assert_eq!(cmds[1], Commands::StartBlinkingLight);
112///
113/// let cmds = cr.on_event_mut(CardReaderEvents::CardAccepted)?;
114/// assert_eq!(cmds[0], Commands::StopBlinkingLight);
115/// # Ok(())
116/// # }
117/// ```
118///
/// In the above example the definition starts with a header: `name` gives the name of the state
/// machine, `command` the type of commands it produces, `error` its error type, and the optional
/// `shared_state` the type of data shared by all states (`()` is used when it is omitted). Each of
/// these types must be defined separately.
///
/// Then each line represents a transition, where the first word is the initial state, the tuple
/// inside the arrow is `(EventType[, [shared] event_handler])`, and the word after the arrow is
/// the destination state. Here `EventType` is an enum variant, and `event_handler` is a method you
/// must define on the initial state's type, whose form depends on the event variant. The only
/// variant types allowed are unit and one-item tuple variants. For unit variants, the handler
/// takes no parameters besides its receiver. For tuple variants, it takes the variant data as its
/// parameter. If the handler is marked `shared`, it additionally receives the shared state as its
/// first parameter. In every case the handler is expected to return a `TransitionResult` to the
/// appropriate state.
129///
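/// The first transition can be interpreted as "If the machine is in the `Locked` state, when a
/// `CardReadable` event is seen, call `on_card_readable` (passing in the shared state and the
/// `CardData`) and transition to the `ReadingCard` state." Because the second transition uses the
/// same event and handler, `on_card_readable` may also choose to stay in `Locked`.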
133///
134/// The macro will generate a few things:
/// * A struct for the overall state machine, named with the provided name. Here:
/// ```ignore
/// struct CardReader {
///     state: CardReaderState,
///     shared_state: SharedState,
/// }
/// ```
/// * An enum with a variant for each state, named with the provided name + "State".
/// ```ignore
/// enum CardReaderState {
///     Locked(Locked),
///     ReadingCard(ReadingCard),
///     DoorOpen(DoorOpen),
/// }
/// ```
150///
151/// You are expected to define a type for each state, to contain that state's data. If there is
152/// no data, you can simply: `type StateName = ()`
153/// * For any instance of transitions with the same event/handler which transition to different
154/// destination states (dynamic destinations), an enum named like `DestAOrDestBOrDestC` is
155/// generated. This enum must be used as the destination "state" from those handlers.
/// * An enum with a variant for each event, named with the provided name + "Events". You are
///   expected to define the type (if any) contained in the event variant.
/// ```ignore
/// enum CardReaderEvents {
///     CardReadable(CardData),
///     CardAccepted,
///     CardRejected,
///     DoorClosed,
/// }
/// ```
/// * An implementation of the [StateMachine](trait.StateMachine.html) trait for the generated state
///   machine struct (in this case, `CardReader`).
/// * A type alias for a [TransitionResult](enum.TransitionResult.html) with the appropriate generic
///   parameters set for your machine. It is named as your machine with `Transition` appended. In
///   this case, `CardReaderTransition`.
168#[proc_macro]
169pub fn fsm(input: TokenStream) -> TokenStream {
170 let def: StateMachineDefinition = parse_macro_input!(input as StateMachineDefinition);
171 def.codegen()
172}
173
174mod kw {
175 syn::custom_keyword!(name);
176 syn::custom_keyword!(command);
177 syn::custom_keyword!(error);
178 syn::custom_keyword!(shared);
179 syn::custom_keyword!(shared_state);
180}
181
182struct StateMachineDefinition {
183 visibility: Visibility,
184 name: Ident,
185 shared_state_type: Option<Type>,
186 command_type: Ident,
187 error_type: Ident,
188 transitions: Vec<Transition>,
189}
190
191impl StateMachineDefinition {
192 fn is_final_state(&self, state: &Ident) -> bool {
193 // If no transitions go from this state, it's a final state.
194 self.transitions.iter().find(|t| t.from == *state).is_none()
195 }
196}
197
198impl Parse for StateMachineDefinition {
199 fn parse(input: ParseStream) -> Result<Self> {
200 // Parse visibility if present
201 let visibility = input.parse()?;
202 // parse the state machine name, command type, and error type
203 let (name, command_type, error_type, shared_state_type) = parse_machine_types(&input)
204 .map_err(|mut e| {
205 e.combine(Error::new(
206 e.span(),
207 "The fsm definition should begin with `name MachineName; command CommandType; \
208 error ErrorType;` optionally followed by `shared_state SharedStateType;`",
209 ));
210 e
211 })?;
212 // Then the state machine definition is simply a sequence of transitions separated by
213 // semicolons
214 let transitions: Punctuated<Transition, Token![;]> =
215 input.parse_terminated(Transition::parse)?;
216 let transitions: Vec<_> = transitions.into_iter().collect();
217 // Check for and whine about any identical transitions. We do this here because preserving
218 // the order transitions were defined in is important, so simply collecting to a set is
219 // not ideal.
220 let trans_set: HashSet<_> = transitions.iter().collect();
221 if trans_set.len() != transitions.len() {
222 return Err(syn::Error::new(
223 input.span(),
224 "Duplicate transitions are not allowed!",
225 ));
226 }
227 Ok(Self {
228 visibility,
229 name,
230 shared_state_type,
231 transitions,
232 command_type,
233 error_type,
234 })
235 }
236}
237
238fn parse_machine_types(input: &ParseStream) -> Result<(Ident, Ident, Ident, Option<Type>)> {
239 let _: kw::name = input.parse()?;
240 let name: Ident = input.parse()?;
241 input.parse::<Token![;]>()?;
242
243 let _: kw::command = input.parse()?;
244 let command_type: Ident = input.parse()?;
245 input.parse::<Token![;]>()?;
246
247 let _: kw::error = input.parse()?;
248 let error_type: Ident = input.parse()?;
249 input.parse::<Token![;]>()?;
250
251 let shared_state_type: Option<Type> = if input.peek(kw::shared_state) {
252 let _: kw::shared_state = input.parse()?;
253 let typep = input.parse()?;
254 input.parse::<Token![;]>()?;
255 Some(typep)
256 } else {
257 None
258 };
259 Ok((name, command_type, error_type, shared_state_type))
260}
261
/// One `From --(Event[, [shared] handler])--> To` line of the DSL.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct Transition {
    /// The source state.
    from: Ident,
    /// Destination state(s). Parsing always produces exactly one; `merge_transition_dests`
    /// combines transitions that differ only in destination into one entry with several.
    to: Vec<Ident>,
    /// The event variant that triggers this transition (unit or single-field tuple variant).
    event: Variant,
    /// Name of the handler method called on the source state's data, if one was given.
    handler: Option<Ident>,
    /// Whether the handler was marked `shared`, i.e. is also passed the machine's shared state.
    mutates_shared: bool,
}
270
271impl Parse for Transition {
272 fn parse(input: ParseStream) -> Result<Self> {
273 // Parse the initial state name
274 let from: Ident = input.parse()?;
275 // Parse at least one dash
276 input.parse::<Token![-]>()?;
277 while input.peek(Token![-]) {
278 input.parse::<Token![-]>()?;
279 }
280 // Parse transition information inside parens
281 let transition_info;
282 parenthesized!(transition_info in input);
283 // Get the event variant definition
284 let event: Variant = transition_info.parse()?;
285 // Reject non-unit or single-item-tuple variants
286 match &event.fields {
287 Fields::Named(_) => {
288 return Err(Error::new(
289 event.span(),
290 "Struct variants are not supported for events",
291 ))
292 }
293 Fields::Unnamed(uf) => {
294 if uf.unnamed.len() != 1 {
295 return Err(Error::new(
296 event.span(),
297 "Only tuple variants with exactly one item are supported for events",
298 ));
299 }
300 }
301 Fields::Unit => {}
302 }
303 // Check if there is an event handler, and parse it
304 let (mutates_shared, handler) = if transition_info.peek(Token![,]) {
305 transition_info.parse::<Token![,]>()?;
306 // Check for mut keyword signifying handler wants to mutate shared state
307 let mutates = if transition_info.peek(kw::shared) {
308 transition_info.parse::<kw::shared>()?;
309 true
310 } else {
311 false
312 };
313 (mutates, Some(transition_info.parse()?))
314 } else {
315 (false, None)
316 };
317 // Parse at least one dash followed by the "arrow"
318 input.parse::<Token![-]>()?;
319 while input.peek(Token![-]) {
320 input.parse::<Token![-]>()?;
321 }
322 input.parse::<Token![>]>()?;
323 // Parse the destination state
324 let to: Ident = input.parse()?;
325
326 Ok(Self {
327 from,
328 event,
329 handler,
330 to: vec![to],
331 mutates_shared,
332 })
333 }
334}
335
336impl StateMachineDefinition {
337 fn codegen(&self) -> TokenStream {
338 let visibility = self.visibility.clone();
339 // First extract all of the states into a set, and build the enum's insides
340 let states = self.all_states();
341 let state_variants = states.iter().map(|s| {
342 let statestr = s.to_string();
343 quote! {
344 #[display(fmt=#statestr)]
345 #s(#s)
346 }
347 });
348 let name = &self.name;
349 let name_str = &self.name.to_string();
350
351 let transition_result_name = Ident::new(&format!("{}Transition", name), name.span());
352 let transition_type_alias = quote! {
353 type #transition_result_name<Ds, Sm = #name> = TransitionResult<Sm, Ds>;
354 };
355
356 let state_enum_name = Ident::new(&format!("{}State", name), name.span());
357 // If user has not defined any shared state, use the unit type.
358 let shared_state_type = self
359 .shared_state_type
360 .clone()
361 .unwrap_or_else(|| syn::parse_str("()").unwrap());
362 let machine_struct = quote! {
363 #[derive(Clone)]
364 #visibility struct #name {
365 state: #state_enum_name,
366 shared_state: #shared_state_type
367 }
368 };
369 let states_enum = quote! {
370 #[derive(::derive_more::From, Clone, ::derive_more::Display)]
371 #visibility enum #state_enum_name {
372 #(#state_variants),*
373 }
374 };
375 let state_is_final_match_arms = states.iter().map(|s| {
376 let val = if self.is_final_state(s) {
377 quote! { true }
378 } else {
379 quote! { false }
380 };
381 quote! { #state_enum_name::#s(_) => #val }
382 });
383 let states_enum_impl = quote! {
384 impl #state_enum_name {
385 fn is_final(&self) -> bool {
386 match self {
387 #(#state_is_final_match_arms),*
388 }
389 }
390 }
391 };
392
393 // Build the events enum
394 let events: HashSet<Variant> = self.transitions.iter().map(|t| t.event.clone()).collect();
395 let events_enum_name = Ident::new(&format!("{}Events", name), name.span());
396 let events: Vec<_> = events
397 .into_iter()
398 .map(|v| {
399 let vname = v.ident.to_string();
400 quote! {
401 #[display(fmt=#vname)]
402 #v
403 }
404 })
405 .collect();
406 let events_enum = quote! {
407 #[derive(::derive_more::Display)]
408 #visibility enum #events_enum_name {
409 #(#events),*
410 }
411 };
412
413 // Construct the trait implementation
414 let cmd_type = &self.command_type;
415 let err_type = &self.error_type;
416 let mut statemap: HashMap<Ident, Vec<Transition>> = HashMap::new();
417 for t in &self.transitions {
418 statemap
419 .entry(t.from.clone())
420 .and_modify(|v| v.push(t.clone()))
421 .or_insert_with(|| vec![t.clone()]);
422 }
423 // Add any states without any transitions to the map
424 for s in &states {
425 if !statemap.contains_key(s) {
426 statemap.insert(s.clone(), vec![]);
427 }
428 }
429 let mut multi_dest_enums = vec![];
430 let state_branches: Vec<_> = statemap.into_iter().map(|(from, transitions)| {
431 // Merge transition dest states with the same handler
432 let transitions = merge_transition_dests(transitions);
433 let event_branches = transitions
434 .into_iter()
435 .map(|ts| {
436 let ev_variant = &ts.event.ident;
437 if let Some(ts_fn) = ts.handler.clone() {
438 let span = ts_fn.span();
439 let trans_type = match ts.to.as_slice() {
440 [] => unreachable!("There will be at least one dest state in transitions"),
441 [one_to] => quote! {
442 #transition_result_name<#one_to>
443 },
444 multi_dests => {
445 let string_dests: Vec<_> = multi_dests.iter()
446 .map(|i| i.to_string()).collect();
447 let enum_ident = Ident::new(&string_dests.join("Or"),
448 multi_dests[0].span());
449 let multi_dest_enum = quote! {
450 #[derive(::derive_more::From)]
451 #visibility enum #enum_ident {
452 #(#multi_dests(#multi_dests)),*
453 }
454 impl ::core::convert::From<#enum_ident> for #state_enum_name {
455 fn from(v: #enum_ident) -> Self {
456 match v {
457 #( #enum_ident::#multi_dests(sv) =>
458 Self::#multi_dests(sv) ),*
459 }
460 }
461 }
462 };
463 multi_dest_enums.push(multi_dest_enum);
464 quote! {
465 #transition_result_name<#enum_ident>
466 }
467 }
468 };
469 match ts.event.fields {
470 Fields::Unnamed(_) => {
471 let arglist = if ts.mutates_shared {
472 quote! {self.shared_state, val}
473 } else {
474 quote! {val}
475 };
476 quote_spanned! {span=>
477 #events_enum_name::#ev_variant(val) => {
478 let res: #trans_type = state_data.#ts_fn(#arglist);
479 res.into_general()
480 }
481 }
482 }
483 Fields::Unit => {
484 let arglist = if ts.mutates_shared {
485 quote! {self.shared_state}
486 } else {
487 quote! {}
488 };
489 quote_spanned! {span=>
490 #events_enum_name::#ev_variant => {
491 let res: #trans_type = state_data.#ts_fn(#arglist);
492 res.into_general()
493 }
494 }
495 }
496 Fields::Named(_) => unreachable!(),
497 }
498 } else {
499 // If events do not have a handler, attempt to construct the next state
500 // using `Default`.
501 if let [new_state] = ts.to.as_slice() {
502 let span = new_state.span();
503 let default_trans = quote_spanned! {span=>
504 TransitionResult::<_, #new_state>::from::<#from>(state_data).into_general()
505 };
506 let span = ts.event.span();
507 match ts.event.fields {
508 Fields::Unnamed(_) => quote_spanned! {span=>
509 #events_enum_name::#ev_variant(_val) => {
510 #default_trans
511 }
512 },
513 Fields::Unit => quote_spanned! {span=>
514 #events_enum_name::#ev_variant => {
515 #default_trans
516 }
517 },
518 Fields::Named(_) => unreachable!(),
519 }
520
521 } else {
522 unreachable!("It should be impossible to have more than one dest state in no-handler transitions")
523 }
524 }
525 })
526 // Since most states won't handle every possible event, return an error to that effect
527 .chain(std::iter::once(
528 quote! { _ => { return TransitionResult::InvalidTransition } },
529 ));
530 quote! {
531 #state_enum_name::#from(state_data) => match event {
532 #(#event_branches),*
533 }
534 }
535 }).collect();
536
537 let viz_str = self.visualize();
538
539 let trait_impl = quote! {
540 impl ::rustfsm::StateMachine for #name {
541 type Error = #err_type;
542 type State = #state_enum_name;
543 type SharedState = #shared_state_type;
544 type Event = #events_enum_name;
545 type Command = #cmd_type;
546
547 fn name(&self) -> &str {
548 #name_str
549 }
550
551 fn on_event(self, event: #events_enum_name)
552 -> ::rustfsm::TransitionResult<Self, Self::State> {
553 match self.state {
554 #(#state_branches),*
555 }
556 }
557
558 fn state(&self) -> &Self::State {
559 &self.state
560 }
561 fn set_state(&mut self, new: Self::State) {
562 self.state = new
563 }
564
565 fn shared_state(&self) -> &Self::SharedState{
566 &self.shared_state
567 }
568
569 fn on_final_state(&self) -> bool {
570 self.state.is_final()
571 }
572
573 fn from_parts(shared: Self::SharedState, state: Self::State) -> Self {
574 Self { shared_state: shared, state }
575 }
576
577 fn visualizer() -> &'static str {
578 #viz_str
579 }
580 }
581 };
582
583 let output = quote! {
584 #transition_type_alias
585 #machine_struct
586 #states_enum
587 #(#multi_dest_enums)*
588 #states_enum_impl
589 #events_enum
590 #trait_impl
591 };
592
593 output.into()
594 }
595
596 fn all_states(&self) -> HashSet<Ident> {
597 self.transitions
598 .iter()
599 .flat_map(|t| {
600 let mut states = t.to.clone();
601 states.push(t.from.clone());
602 states
603 })
604 .collect()
605 }
606
607 fn visualize(&self) -> String {
608 let transitions: Vec<String> = self
609 .transitions
610 .iter()
611 .flat_map(|t| {
612 t.to.iter()
613 .map(move |d| format!("{} --> {}: {}", t.from, d, t.event.ident))
614 })
615 // Add all final state transitions
616 .chain(
617 self.all_states()
618 .iter()
619 .filter(|s| self.is_final_state(s))
620 .map(|s| format!("{} --> [*]", s)),
621 )
622 .collect();
623 let transitions = transitions.join("\n");
624 format!("@startuml\n{}\n@enduml", transitions)
625 }
626}
627
628/// Merge transition's dest state lists for those with the same from state & handler
629fn merge_transition_dests(transitions: Vec<Transition>) -> Vec<Transition> {
630 let mut map = HashMap::<_, Transition>::new();
631 transitions.into_iter().for_each(|t| {
632 // We want to use the transition sans-destinations as the key
633 let without_dests = {
634 let mut wd = t.clone();
635 wd.to = vec![];
636 wd
637 };
638 match map.entry(without_dests) {
639 Entry::Occupied(mut e) => {
640 e.get_mut().to.extend(t.to.into_iter());
641 }
642 Entry::Vacant(v) => {
643 v.insert(t);
644 }
645 }
646 });
647 map.into_iter().map(|(_, v)| v).collect()
648}