#![deny(missing_docs)]
use std::fmt::{Debug};
use std::ops::Deref;
use std::sync::Arc;
use thiserror::Error;
use crate::ToState::{Calc, Same, To};
/// A running state machine: its current state, user data and a shared transition table.
///
/// Instances are normally produced by [`LockedStateMachineFactory::build`] and driven with
/// [`StateMachine::handle_event`].
pub struct StateMachine<'a, TEvent, TState: PartialEq<TState> + Clone + Send + 'a, TData, TErr = Box<dyn std::error::Error>>
{
    /// The state the machine is currently in.
    pub state: TState,
    /// Transitions evaluated, in order, by [`StateMachine::handle_event`]; shared between
    /// machines built from the same locked factory.
    pub transitions: Arc<Vec<StateMachineTransition<'a, TEvent, TState, TData, TErr>>>,
    /// User data handed (by shared reference) to transition callbacks.
    pub data: TData,
    /// When `true`, [`StateMachine::handle_event`] repeats passes over the transitions until
    /// a pass leaves the state unchanged.
    pub cycle: bool,
}

// Manual impl: `#[derive(Clone)]` would also demand `TEvent: Clone` and `TErr: Clone`, which
// the default `TErr = Box<dyn Error>` can never satisfy. Only the owned state/data need it.
impl<'a, TEvent, TState, TData, TErr> Clone for StateMachine<'a, TEvent, TState, TData, TErr>
where
    TState: PartialEq<TState> + Clone + Send + 'a,
    TData: Clone,
{
    fn clone(&self) -> Self {
        Self {
            state: self.state.clone(),
            transitions: Arc::clone(&self.transitions),
            data: self.data.clone(),
            cycle: self.cycle,
        }
    }
}

// Manual impl for the same reason as `Clone`: no `Default` bound on `TEvent`/`TErr`.
impl<'a, TEvent, TState, TData, TErr> Default for StateMachine<'a, TEvent, TState, TData, TErr>
where
    TState: PartialEq<TState> + Clone + Send + Default + 'a,
    TData: Default,
{
    fn default() -> Self {
        Self {
            state: TState::default(),
            transitions: Arc::new(Vec::new()),
            data: TData::default(),
            cycle: false,
        }
    }
}
impl <'a, TEvent, TState: PartialEq<TState> + Clone + Send + Eq + PartialEq + 'a, TData, TErr> StateMachine<'a, TEvent, TState, TData, TErr>
{
fn new(cycle: bool, initial_state: TState, initial_data: TData) -> Self {
Self {
cycle,
state: initial_state,
data: initial_data,
transitions: Arc::new(Vec::new()),
}
}
pub fn with_transitions(mut self, transitions: Arc<Vec<StateMachineTransition<'a, TEvent, TState, TData, TErr>>>) -> Self {
self.transitions = transitions.clone();
self
}
pub fn handle_event(&mut self, event: TEvent) -> Result<&TState, StateMachineError<TState, TErr>> {
loop {
let mut transition_occurred = false;
for transition in self.transitions.deref() {
let from_state_matches = match &transition.from_state {
FromState::Any => true,
FromState::AnyOf(states) => states.iter().any(|s|s == &self.state),
FromState::From(state) => state == &self.state
};
if from_state_matches {
let to_state = match &transition.get_to_state {
To(to_state) => to_state.clone(),
Calc(get_to_state) => {
let data = StateTransitionToStateData {
data: &mut self.data,
event: &event,
from: &self.state,
};
get_to_state.deref()(data)
},
Same => self.state.clone()
};
let transition_effect_data = StateTransitionEffectData {
name: &transition.name,
data: &mut self.data,
event: &event,
from: &self.state,
to: &to_state
};
if let Some(predicate) = &transition.event_predicate {
if !predicate(&transition_effect_data) {
continue;
}
}
if let Some(effect) = &transition.effect {
effect(transition_effect_data)
.map_err(|e| StateMachineError::EffectError(self.state.clone(), to_state.clone(), e))?;
}
if &self.state != &to_state {
self.state = to_state;
transition_occurred = true;
}
}
}
if !self.cycle || !transition_occurred {
break;
}
}
Ok(&self.state)
}
}
/// A frozen transition table produced by [`StateMachineFactory::lock`].
///
/// Every [`StateMachine`] built from it shares the same `Arc`-ed transition list.
pub struct LockedStateMachineFactory<'a, TEvent, TState: PartialEq<TState> + Clone + Send + 'a, TData = (), TErr = Box<dyn std::error::Error>> {
    /// Shared, immutable transition list handed to each built machine.
    transitions: Arc<Vec<StateMachineTransition<'a, TEvent, TState, TData, TErr>>>,
    /// Whether built machines run in cycle mode.
    cycle: bool,
}

// Cloning only bumps the reference count of the shared transition list, so no bounds on
// the event, data or error types are required.
impl<'a, TEvent, TState, TData, TErr> Clone for LockedStateMachineFactory<'a, TEvent, TState, TData, TErr>
where
    TState: PartialEq<TState> + Clone + Send + 'a,
{
    fn clone(&self) -> Self {
        Self {
            transitions: Arc::clone(&self.transitions),
            cycle: self.cycle,
        }
    }
}
impl <'a, TEvent, TState: PartialEq<TState> + Clone + Send + Eq + PartialEq + 'a, TData, TErr> LockedStateMachineFactory<'a, TEvent, TState, TData, TErr> {
pub fn build(&self, initial_state: TState, initial_data: TData) -> StateMachine<'a, TEvent, TState, TData, TErr> {
StateMachine::new(self.cycle, initial_state, initial_data).with_transitions(self.transitions.clone())
}
}
/// Builder that collects transitions and then [`lock`](StateMachineFactory::lock)s them into a
/// [`LockedStateMachineFactory`].
pub struct StateMachineFactory<'a, TEvent, TState: PartialEq<TState> + Clone + Send + 'a, TData, TErr = Box<dyn std::error::Error>> {
    /// Whether machines built from this factory run in cycle mode.
    cycle: bool,
    /// Transitions in registration (and therefore evaluation) order.
    transitions: Vec<StateMachineTransition<'a, TEvent, TState, TData, TErr>>,
}

// Manual impl: `#[derive(Default)]` would require every type parameter (including the default
// `TErr = Box<dyn Error>`) to be `Default`, although an empty factory needs none of them.
impl<'a, TEvent, TState, TData, TErr> Default for StateMachineFactory<'a, TEvent, TState, TData, TErr>
where
    TState: PartialEq<TState> + Clone + Send + 'a,
{
    fn default() -> Self {
        Self {
            cycle: false,
            transitions: Vec::new(),
        }
    }
}
impl<'a, TEvent, TState: PartialEq<TState> + Clone + Send + Eq + PartialEq + 'a, TData, TErr> StateMachineFactory<'a, TEvent, TState, TData, TErr> {
    /// Creates an empty factory with no transitions and `cycle` disabled.
    pub fn new() -> Self {
        Self {
            cycle: false,
            transitions: Vec::new(),
        }
    }

    /// Sets whether built machines keep re-scanning their transitions until a pass leaves the
    /// state unchanged (see [`StateMachine::handle_event`]).
    pub fn cycle(mut self, cycle: bool) -> Self {
        self.cycle = cycle;
        self
    }

    /// Freezes the collected transitions into a [`LockedStateMachineFactory`].
    ///
    /// The transitions are moved into an `Arc`, so every machine built from the locked factory
    /// shares them without copying.
    pub fn lock(self) -> LockedStateMachineFactory<'a, TEvent, TState, TData, TErr> {
        LockedStateMachineFactory {
            cycle: self.cycle,
            transitions: Arc::new(self.transitions),
        }
    }

    /// Appends an already-constructed transition.
    ///
    /// Transitions are evaluated in the order they are added. The other builders in this
    /// `impl` delegate to this method.
    pub fn with_custom_transition(mut self, transition: StateMachineTransition<'a, TEvent, TState, TData, TErr>) -> Self
    {
        self.transitions.push(transition);
        self
    }

    /// Adds a named transition without predicate or effect: it fires for any event whenever
    /// the current state matches `from_state`.
    pub fn with_named_auto_transition(self, name: impl Into<String>, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            Some(name.into()),
            None,
            from_state.into(),
            get_to_state.into(),
            None,
        ))
    }

    /// Adds a named, unconditional transition that runs `effect` when it fires.
    ///
    /// An `Err` from `effect` aborts [`StateMachine::handle_event`] with
    /// [`StateMachineError::EffectError`] and leaves the state unchanged.
    pub fn with_named_transition_effect(self, name: impl Into<String>, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, effect: impl Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            Some(name.into()),
            None,
            from_state.into(),
            get_to_state.into(),
            Some(Box::new(effect)),
        ))
    }

    /// Adds a named transition that fires only when `event_predicate` returns `true`.
    ///
    /// The predicate sees the event, the data, the current state and the resolved target.
    pub fn with_named_predicated_transition(self, name: impl Into<String>, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, event_predicate: impl Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            Some(name.into()),
            Some(Box::new(event_predicate)),
            from_state.into(),
            get_to_state.into(),
            None,
        ))
    }

    /// Adds a named transition guarded by `event_predicate` that runs `effect` when it fires.
    pub fn with_named_predicated_transition_effect(self, name: impl Into<String>, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, event_predicate: impl Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a, effect: impl Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            Some(name.into()),
            Some(Box::new(event_predicate)),
            from_state.into(),
            get_to_state.into(),
            Some(Box::new(effect)),
        ))
    }

    /// Adds an unnamed transition without predicate or effect: it fires for any event
    /// whenever the current state matches `from_state`.
    pub fn with_auto_transition(self, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            None,
            None,
            from_state.into(),
            get_to_state.into(),
            None,
        ))
    }

    /// Adds an unnamed, unconditional transition that runs `effect` when it fires.
    pub fn with_transition_effect(self, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, effect: impl Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            None,
            None,
            from_state.into(),
            get_to_state.into(),
            Some(Box::new(effect)),
        ))
    }

    /// Adds an unnamed transition that fires only when `event_predicate` returns `true`.
    pub fn with_predicated_transition(self, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, event_predicate: impl Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            None,
            Some(Box::new(event_predicate)),
            from_state.into(),
            get_to_state.into(),
            None,
        ))
    }

    /// Adds an unnamed transition guarded by `event_predicate` that runs `effect` when it fires.
    pub fn with_predicated_transition_effect(self, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, event_predicate: impl Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a, effect: impl Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            None,
            Some(Box::new(event_predicate)),
            from_state.into(),
            get_to_state.into(),
            Some(Box::new(effect)),
        ))
    }
}
impl<'a, TEvent, TState: PartialEq<TState> + Clone + Send + Eq + PartialEq + 'a, TData, TErr> StateMachineFactory<'a, TEvent, TState, TData, TErr>
where
    TEvent: PartialEq<TEvent> + Sync,
{
    /// Builds a predicate that accepts a transition only when the handled event equals `event`.
    ///
    /// The closure copies the `&'a TEvent` reference (`move`), so it never borrows a local of
    /// the calling builder and satisfies the `'a` bound on boxed predicates; `TEvent: Sync`
    /// makes that reference `Send`.
    fn event_equals_predicate(
        event: &'a TEvent,
    ) -> Box<dyn Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a> {
        Box::new(move |e| *event == *e.event)
    }

    /// Adds a named transition that fires only while handling an event equal to `event`.
    pub fn with_named_event_transition(self, name: impl Into<String>, event: &'a TEvent, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            Some(name.into()),
            Some(Self::event_equals_predicate(event)),
            from_state.into(),
            get_to_state.into(),
            None,
        ))
    }

    /// Adds a named transition for events equal to `event` that runs `effect` when it fires.
    ///
    /// An `Err` from `effect` aborts [`StateMachine::handle_event`] with
    /// [`StateMachineError::EffectError`] and leaves the state unchanged.
    pub fn with_named_event_transition_effect(self, name: impl Into<String>, event: &'a TEvent, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, effect: impl Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            Some(name.into()),
            Some(Self::event_equals_predicate(event)),
            from_state.into(),
            get_to_state.into(),
            Some(Box::new(effect)),
        ))
    }

    /// Adds an unnamed transition that fires only while handling an event equal to `event`.
    pub fn with_event_transition(self, event: &'a TEvent, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            None,
            Some(Self::event_equals_predicate(event)),
            from_state.into(),
            get_to_state.into(),
            None,
        ))
    }

    /// Adds an unnamed transition for events equal to `event` that runs `effect` when it fires.
    pub fn with_event_transition_effect(self, event: &'a TEvent, from_state: impl Into<FromState<TState>>, get_to_state: impl Into<ToState<TEvent, TState, TData>>, effect: impl Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a) -> Self
    {
        self.with_custom_transition(StateMachineTransition::new(
            None,
            Some(Self::event_equals_predicate(event)),
            from_state.into(),
            get_to_state.into(),
            Some(Box::new(effect)),
        ))
    }
}
/// Errors returned by [`StateMachine::handle_event`].
#[derive(Error, Debug)]
pub enum StateMachineError<TState: Send + Clone + Eq + PartialEq, TErr = Box<dyn std::error::Error>> {
    /// A transition effect failed.
    ///
    /// Fields: the state the machine was in when the effect ran, the target state of the
    /// transition, and the error returned by the effect. The transition was not applied.
    #[error("error running effect moving from state {0:?} to {1:?}: {2:?}")]
    EffectError(TState, TState, TErr)
}
/// One entry in a state machine's transition table.
///
/// A transition applies when `from_state` matches the current state and the optional event
/// predicate accepts; its optional effect then runs before the state is set to the resolved
/// target. See [`StateMachine::handle_event`] for the exact evaluation order.
pub struct StateMachineTransition<'a, TEvent, TState: PartialEq<TState> + Clone + Send + 'a, TData, TErr = Box<dyn std::error::Error>>
{
    /// Optional label, exposed to callbacks as [`StateTransitionEffectData::name`].
    name: Option<String>,
    /// States this transition may fire from.
    from_state: FromState<TState>,
    /// How the target state is chosen.
    get_to_state: ToState<TEvent, TState, TData>,
    /// Optional guard; the transition is skipped when it returns `false`.
    event_predicate: Option<Box<dyn Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a>>,
    /// Optional side effect; an `Err` aborts event handling with [`StateMachineError::EffectError`].
    effect: Option<Box<dyn Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a>>
}
impl<'a, TEvent, TState: PartialEq<TState> + Clone + Send + 'a, TData, TErr> StateMachineTransition<'a, TEvent, TState, TData, TErr> {
    /// Creates a transition from its parts, for use with
    /// [`StateMachineFactory::with_custom_transition`].
    ///
    /// Public so that callers outside this crate can actually build the custom transitions
    /// that `with_custom_transition` accepts (all fields are private).
    ///
    /// * `name` – optional label passed to predicates and effects.
    /// * `event_predicate` – optional guard; returning `false` skips the transition.
    /// * `from_state` – states the transition may fire from.
    /// * `get_to_state` – how the target state is chosen.
    /// * `effect` – optional side effect run before the state changes; an `Err` aborts
    ///   event handling and leaves the state unchanged.
    pub fn new(
        name: Option<String>,
        event_predicate: Option<Box<dyn Fn(&StateTransitionEffectData<TEvent, TState, TData>) -> bool + Send + 'a>>,
        from_state: FromState<TState>,
        get_to_state: ToState<TEvent, TState, TData>,
        effect: Option<Box<dyn Fn(StateTransitionEffectData<TEvent, TState, TData>) -> Result<(), TErr> + Send + 'a>>,
    ) -> Self
    {
        Self {
            name,
            event_predicate,
            from_state,
            get_to_state,
            effect
        }
    }
}
/// The set of current states a transition may fire from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FromState<TState: PartialEq<TState> + Clone> {
    /// Matches every state.
    Any,
    /// Matches when the current state equals any of the listed states.
    AnyOf(Vec<TState>),
    /// Matches only this exact state.
    From(TState)
}
impl <TState: PartialEq<TState> + Clone> From<TState> for FromState<TState> {
fn from(value: TState) -> Self {
FromState::From(value)
}
}
/// How a transition chooses the state to move to.
pub enum ToState<TEvent, TState: PartialEq<TState> + Clone + Send, TData> {
    /// Stay in the current state; effects still run but no state change is recorded.
    Same,
    /// Move to this fixed state.
    To(TState),
    /// Compute the target from the event, the data and the current state.
    ///
    /// Note: the boxed closure has no explicit lifetime (so it must be `'static`) and no
    /// `Send` bound, unlike transition predicates and effects.
    Calc(Box<dyn Fn(StateTransitionToStateData<TEvent, TState, TData>) -> TState>)
}
impl <TEvent, TState: PartialEq<TState> + Clone + Send, TData> From<TState> for ToState<TEvent, TState, TData> {
fn from(value: TState) -> Self {
ToState::<TEvent, TState, TData>::To(value)
}
}
/// Context handed to transition predicates (by reference) and effects (by value).
pub struct StateTransitionEffectData<'a, TEvent, TState, TData> {
    /// The transition's name, if one was given when it was registered.
    pub name: &'a Option<String>,
    /// The event currently being handled.
    pub event: &'a TEvent,
    /// The state machine's user data.
    pub data: &'a TData,
    /// The state the machine is in when the transition is evaluated.
    pub from: &'a TState,
    /// The target state resolved for this transition.
    pub to: &'a TState
}

// Every field is a shared reference, so the struct is trivially copyable. A derived `Clone`
// would needlessly require `TEvent`, `TState` and `TData` to be `Clone`.
impl<TEvent, TState, TData> Clone for StateTransitionEffectData<'_, TEvent, TState, TData> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<TEvent, TState, TData> Copy for StateTransitionEffectData<'_, TEvent, TState, TData> {}
/// Context handed to a [`ToState::Calc`] closure when resolving a transition's target.
pub struct StateTransitionToStateData<'a, TEvent, TState, TData> {
    /// The event currently being handled.
    pub event: &'a TEvent,
    /// The state machine's user data.
    pub data: &'a TData,
    /// The state the machine is in when the target is computed.
    pub from: &'a TState,
}

// Every field is a shared reference, so the struct is trivially copyable. A derived `Clone`
// would needlessly require `TEvent`, `TState` and `TData` to be `Clone`.
impl<TEvent, TState, TData> Clone for StateTransitionToStateData<'_, TEvent, TState, TData> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<TEvent, TState, TData> Copy for StateTransitionToStateData<'_, TEvent, TState, TData> {}
#[cfg(test)]
mod unit_tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use anyhow::anyhow;
    use thiserror::Error;
    use crate::{StateMachineFactory, StateMachineError};
    use crate::FromState::From;
    use crate::ToState::To;

    /// Event transitions fire only for their own event and only from their source state,
    /// and their effects run when they fire.
    #[test]
    fn test_state_machine() {
        #[derive(Eq, PartialEq)]
        enum StateMachineMessage {
            GoToTwo,
            GoToThree,
        }
        let go_to_two_happened = AtomicBool::new(false);
        let go_to_three_happened = AtomicBool::new(false);
        let mut sm = StateMachineFactory::<StateMachineMessage, u32, ()>::new()
            .with_event_transition_effect(
                &StateMachineMessage::GoToTwo,
                1,
                2,
                |_| {
                    go_to_two_happened.store(true, Ordering::SeqCst);
                    Ok(())
                },
            )
            .with_event_transition_effect(
                &StateMachineMessage::GoToThree,
                2,
                3,
                |_| {
                    go_to_three_happened.store(true, Ordering::SeqCst);
                    Ok(())
                },
            )
            .lock()
            .build(1, ());
        assert_eq!(1, sm.state);
        // No GoToThree transition leaves state 1, so nothing changes.
        assert_eq!(&1, sm.handle_event(StateMachineMessage::GoToThree).expect("unexpected error"));
        assert_eq!(&2, sm.handle_event(StateMachineMessage::GoToTwo).expect("unexpected error"));
        assert!(go_to_two_happened.load(Ordering::SeqCst), "effect from GoToTwo did not happen when expected");
        assert_eq!(2, sm.state);
        // No GoToTwo transition leaves state 2, so nothing changes.
        assert_eq!(&2, sm.handle_event(StateMachineMessage::GoToTwo).expect("unexpected error"));
        assert_eq!(&3, sm.handle_event(StateMachineMessage::GoToThree).expect("unexpected error"));
        assert!(go_to_three_happened.load(Ordering::SeqCst), "effect from GoToThree did not happen when expected");
        assert_eq!(3, sm.state);
    }

    /// With `cycle` enabled, an auto transition chains after an event transition.
    #[test]
    fn test_double_transition() -> anyhow::Result<()> {
        #[derive(Eq, PartialEq)]
        enum StateMachineMessage {
            GoToTwo,
        }
        let factory = StateMachineFactory::<StateMachineMessage, u32, ()>::new()
            .cycle(true)
            .with_event_transition(
                &StateMachineMessage::GoToTwo,
                1,
                2,
            )
            .with_auto_transition(
                2,
                3,
            )
            .lock();
        let mut sm = factory.build(1, ());
        assert_eq!(1, sm.state);
        match sm.handle_event(StateMachineMessage::GoToTwo) {
            Ok(state) => {
                assert_eq!(3, *state);
            }
            Err(StateMachineError::EffectError(from, to, e)) => {
                return Err(anyhow!("error changing state from {} to {}: {}", from, to, e));
            }
        };
        assert_eq!(3, sm.state);
        Ok(())
    }

    /// A failing effect aborts the transition and reports the source and target states.
    #[test]
    fn test_effect_error() -> anyhow::Result<()> {
        #[derive(Eq, PartialEq, Debug)]
        enum StateMachineMessage {
            GoToTwo,
        }
        #[derive(Error, Debug, Eq, PartialEq)]
        enum TestError {
            #[error("test error")]
            TestError,
        }
        let mut sm = StateMachineFactory::new()
            .with_event_transition_effect(
                &StateMachineMessage::GoToTwo,
                From(1),
                To(2),
                |_| {
                    Err(TestError::TestError)
                },
            )
            .lock()
            .build(1, ());
        let outcome = match sm.handle_event(StateMachineMessage::GoToTwo) {
            Ok(_) => Err(anyhow!("expected an error")),
            Err(StateMachineError::EffectError(from, to, cause)) => {
                assert_eq!(1, from);
                assert_eq!(2, to);
                assert_eq!(cause, TestError::TestError);
                Ok(())
            }
        };
        // The failed transition must not have been applied.
        assert_eq!(1, sm.state);
        outcome
    }

    /// `FromState::AnyOf` matches every listed state and nothing else.
    #[test]
    fn test_any_of_auto_transition() {
        #[derive(Eq, PartialEq)]
        enum StateMachineMessage {
            Tick,
        }
        let factory = StateMachineFactory::<StateMachineMessage, u32, ()>::new()
            .with_auto_transition(crate::FromState::AnyOf(vec![1, 2]), 3)
            .lock();
        for start in [1, 2] {
            let mut sm = factory.build(start, ());
            assert_eq!(&3, sm.handle_event(StateMachineMessage::Tick).expect("unexpected error"));
        }
        let mut sm = factory.build(4, ());
        assert_eq!(&4, sm.handle_event(StateMachineMessage::Tick).expect("unexpected error"));
    }
}