use std::collections::HashMap;
use std::hash::Hash;
use std::time::Duration;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FsmTransition<S, D> {
pub next: S,
pub data: D,
pub timeout: Option<Duration>,
}
pub trait FiniteStateMachine {
type State: Clone + Eq + 'static;
type Data: Clone + 'static;
type Msg: Send + 'static;
fn initial_state(&self) -> Self::State;
fn initial_data(&self) -> Self::Data;
fn transition(
&mut self,
current: &Self::State,
data: &Self::Data,
msg: Self::Msg,
) -> Option<FsmTransition<Self::State, Self::Data>>;
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Clone, Eq, PartialEq, Debug)]
enum S {
Idle,
Running,
}
struct TrafficLight;
enum M {
Go,
Stop,
}
impl FiniteStateMachine for TrafficLight {
type State = S;
type Data = u32;
type Msg = M;
fn initial_state(&self) -> S {
S::Idle
}
fn initial_data(&self) -> u32 {
0
}
fn transition(&mut self, s: &S, d: &u32, m: M) -> Option<FsmTransition<S, u32>> {
match (s, m) {
(S::Idle, M::Go) => Some(FsmTransition { next: S::Running, data: d + 1, timeout: None }),
(S::Running, M::Stop) => Some(FsmTransition { next: S::Idle, data: *d, timeout: None }),
_ => None,
}
}
}
#[test]
fn transitions_idle_to_running() {
let mut fsm = TrafficLight;
let t = fsm.transition(&S::Idle, &0, M::Go).unwrap();
assert_eq!(t.next, S::Running);
assert_eq!(t.data, 1);
}
#[test]
fn transitions_running_to_idle_on_stop() {
let mut fsm = TrafficLight;
let t = fsm.transition(&S::Running, &5, M::Stop).unwrap();
assert_eq!(t.next, S::Idle);
assert_eq!(t.data, 5);
}
}
type StateHandler<S, D, M> = Box<dyn FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static>;
type TransitionHook<S> = Box<dyn FnMut(&S, &S) + Send + 'static>;
type TerminationHook<S, D> = Box<dyn FnMut(&S, &D) + Send + 'static>;
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FsmStopReason {
Normal,
Shutdown,
Failure(String),
}
pub struct FsmBuilder<S: Clone + Eq + Hash + 'static, D: Clone + 'static, M: 'static> {
initial_state: Option<S>,
initial_data: Option<D>,
handlers: HashMap<S, StateHandler<S, D, M>>,
fallback: Option<StateHandler<S, D, M>>,
on_transition: Option<TransitionHook<S>>,
on_termination: Option<TerminationHook<S, D>>,
}
impl<S, D, M> Default for FsmBuilder<S, D, M>
where
S: Clone + Eq + Hash + 'static,
D: Clone + 'static,
M: 'static,
{
fn default() -> Self {
Self {
initial_state: None,
initial_data: None,
handlers: HashMap::new(),
fallback: None,
on_transition: None,
on_termination: None,
}
}
}
impl<S, D, M> FsmBuilder<S, D, M>
where
S: Clone + Eq + Hash + 'static,
D: Clone + 'static,
M: 'static,
{
pub fn new() -> Self {
Self::default()
}
pub fn start_with(mut self, state: S, data: D) -> Self {
self.initial_state = Some(state);
self.initial_data = Some(data);
self
}
pub fn when_state<F>(mut self, state: S, handler: F) -> Self
where
F: FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static,
{
self.handlers.insert(state, Box::new(handler));
self
}
pub fn whenever<F>(mut self, handler: F) -> Self
where
F: FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static,
{
self.fallback = Some(Box::new(handler));
self
}
pub fn on_transition<F>(mut self, hook: F) -> Self
where
F: FnMut(&S, &S) + Send + 'static,
{
self.on_transition = Some(Box::new(hook));
self
}
pub fn on_termination<F>(mut self, hook: F) -> Self
where
F: FnMut(&S, &D) + Send + 'static,
{
self.on_termination = Some(Box::new(hook));
self
}
pub fn build(self) -> Fsm<S, D, M> {
let initial_state = self.initial_state.expect("FsmBuilder: start_with(state, data) is required");
let initial_data = self.initial_data.expect("FsmBuilder: start_with(state, data) is required");
Fsm {
current_state: initial_state.clone(),
current_data: initial_data,
initial_state,
handlers: self.handlers,
fallback: self.fallback,
on_transition: self.on_transition,
on_termination: self.on_termination,
terminated: false,
}
}
}
pub struct Fsm<S: Clone + Eq + Hash + 'static, D: Clone + 'static, M: 'static> {
current_state: S,
current_data: D,
initial_state: S,
handlers: HashMap<S, StateHandler<S, D, M>>,
fallback: Option<StateHandler<S, D, M>>,
on_transition: Option<TransitionHook<S>>,
on_termination: Option<TerminationHook<S, D>>,
terminated: bool,
}
impl<S, D, M> Fsm<S, D, M>
where
S: Clone + Eq + Hash + 'static,
D: Clone + 'static,
M: 'static,
{
pub fn state(&self) -> &S {
&self.current_state
}
pub fn data(&self) -> &D {
&self.current_data
}
pub fn initial_state(&self) -> &S {
&self.initial_state
}
pub fn is_terminated(&self) -> bool {
self.terminated
}
pub fn handle(&mut self, msg: M) -> Option<&S> {
if self.terminated {
return None;
}
let attempted = if let Some(handler) = self.handlers.get_mut(&self.current_state) {
handler(&self.current_state, &self.current_data, msg)
} else {
None
};
let transition = match attempted {
Some(t) => Some(t),
None => {
self.fallback.as_mut().and_then(|f| {
let _ = f;
None
})
}
};
if let Some(t) = transition {
if self.current_state != t.next {
if let Some(hook) = self.on_transition.as_mut() {
hook(&self.current_state, &t.next);
}
}
self.current_state = t.next;
self.current_data = t.data;
}
Some(&self.current_state)
}
pub fn terminate(&mut self, _reason: FsmStopReason) {
if self.terminated {
return;
}
if let Some(hook) = self.on_termination.as_mut() {
hook(&self.current_state, &self.current_data);
}
self.terminated = true;
}
}
#[cfg(test)]
mod builder_tests {
use super::*;
use std::sync::{Arc, Mutex};
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
enum St {
Closed,
Open,
}
enum Cmd {
Lock,
Unlock,
}
#[test]
fn builder_drives_transitions() {
let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
.start_with(St::Closed, 0)
.when_state(St::Closed, |_s, d, m| match m {
Cmd::Unlock => Some(FsmTransition { next: St::Open, data: d + 1, timeout: None }),
Cmd::Lock => None,
})
.when_state(St::Open, |_s, d, m| match m {
Cmd::Lock => Some(FsmTransition { next: St::Closed, data: *d, timeout: None }),
Cmd::Unlock => None,
})
.build();
assert_eq!(fsm.state(), &St::Closed);
fsm.handle(Cmd::Unlock);
assert_eq!(fsm.state(), &St::Open);
assert_eq!(*fsm.data(), 1);
fsm.handle(Cmd::Lock);
assert_eq!(fsm.state(), &St::Closed);
}
#[test]
fn on_transition_hook_fires() {
let log: Arc<Mutex<Vec<(St, St)>>> = Arc::new(Mutex::new(Vec::new()));
let log_clone = log.clone();
let mut fsm = FsmBuilder::<St, (), Cmd>::new()
.start_with(St::Closed, ())
.when_state(St::Closed, |_, _, _| Some(FsmTransition { next: St::Open, data: (), timeout: None }))
.on_transition(move |from, to| {
log_clone.lock().unwrap().push((from.clone(), to.clone()));
})
.build();
fsm.handle(Cmd::Unlock);
let entries = log.lock().unwrap().clone();
assert_eq!(entries, vec![(St::Closed, St::Open)]);
}
#[test]
fn on_termination_hook_fires() {
let calls: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
let c = calls.clone();
let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
.start_with(St::Closed, 7)
.when_state(St::Closed, |_, _, _| None)
.on_termination(move |_s, d| {
*c.lock().unwrap() = *d;
})
.build();
fsm.terminate(FsmStopReason::Normal);
assert_eq!(*calls.lock().unwrap(), 7);
fsm.terminate(FsmStopReason::Normal);
assert_eq!(*calls.lock().unwrap(), 7);
}
#[test]
fn handle_after_terminate_returns_none() {
let mut fsm = FsmBuilder::<St, (), Cmd>::new()
.start_with(St::Closed, ())
.when_state(St::Closed, |_, _, _| Some(FsmTransition { next: St::Open, data: (), timeout: None }))
.build();
fsm.terminate(FsmStopReason::Normal);
assert!(fsm.handle(Cmd::Unlock).is_none());
}
#[test]
fn no_transition_keeps_state() {
let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
.start_with(St::Closed, 0)
.when_state(St::Closed, |_, _, _| None)
.build();
fsm.handle(Cmd::Unlock);
assert_eq!(fsm.state(), &St::Closed);
}
}