use std::{collections::HashMap, fmt::Debug};
#[cfg(feature = "events")]
use crossbeam_channel::{Receiver, Sender};
use derive_builder::Builder;
use crate::{
edge::{Edge, EdgeBuilder},
prelude::*,
stack::Stack,
};
/// Configuration-time finite state machine.
///
/// States, their parent links and the transitions between them are collected here
/// (through the `ConfigurableStateMachine` impl) and then moved into a
/// [`RunningFiniteStateMachine`] by `start`.
///
/// Type parameters: `S` is the state type, `I` the input/event type, `O` the
/// optional output attached to a transition and `C` the context that state hooks
/// and transition mutations receive as `&mut C`.
#[derive(Debug)]
pub struct FiniteStateMachine<'a, S, I, O = (), C = ()> {
    /// Definition (state value plus enter/leave hooks) of every known state.
    states: HashMap<S, StateDefWrapper<'a, S, C>>,
    /// Child state -> parent state links used for hierarchical states.
    parent_refs: HashMap<S, S>,
    /// Outgoing transitions, keyed by their source state.
    edges: HashMap<S, Vec<Edge<'a, S, I, O, C>>>,
    /// State the running machine starts in; `start` needs this to be set.
    initial_state: Option<S>,
}
impl<'a, S, I, O, C> Default for FiniteStateMachine<'a, S, I, O, C> {
fn default() -> Self {
Self {
states: Default::default(),
parent_refs: Default::default(),
edges: Default::default(),
initial_state: Default::default(),
}
}
}
impl<'a, S, I, O, C> FiniteStateMachine<'a, S, I, O, C>
where
    S: StateLike,
{
    /// Looks up the stored definition of `state`, returning `None` if it has none.
    fn get_state_definition(&self, state: &S) -> Option<&StateDefWrapper<'_, S, C>> {
        let definitions = &self.states;
        definitions.get(state)
    }
}
impl<'a, S, I, O, C> FiniteStateMachine<'a, S, I, O, C>
where
    S: StateLike,
    I: PartialEq,
{
    /// Registers (or replaces) the transition `source_state --event--> target_state`.
    ///
    /// Both endpoint states receive a bare definition if they do not have one yet.
    ///
    /// A source state keeps at most one edge per `event`. The running machine takes
    /// the *first* edge whose message matches, so a second edge for the same event
    /// would never be reached. An existing edge for `event` is therefore overwritten
    /// in place, even when it pointed at a different target.
    ///
    /// `output` is what `consume` returns for this transition. `mutation`, if given,
    /// is applied to the context while the transition runs.
    fn inner_add_transition(
        &mut self,
        source_state: S,
        event: I,
        target_state: S,
        output: Option<O>,
        mutation: Option<fn(&mut C)>,
    ) {
        self.states
            .entry(source_state.clone())
            .or_insert_with(|| StateDefWrapper::new(source_state.clone()));
        self.states
            .entry(target_state.clone())
            .or_insert_with(|| StateDefWrapper::new(target_state.clone()));

        let source_state_edges = self.edges.entry(source_state.clone()).or_default();
        let existing = source_state_edges
            .iter()
            .position(|edge| *edge.message() == event);

        let edge = EdgeBuilder::default()
            .message(event)
            .source_state(source_state)
            .target_state(target_state)
            .output(output)
            .mutation(mutation)
            .build()
            .unwrap();

        match existing {
            Some(index) => source_state_edges[index] = edge,
            None => source_state_edges.push(edge),
        }
    }
}
impl<'a, S, I, O, C> ConfigurableStateMachine<'a> for FiniteStateMachine<'a, S, I, O, C>
where
S: StateLike,
I: PartialEq,
C: Default,
{
type Input = I;
type State = S;
type Context = C;
type Output = O;
type StateMachine = RunningFiniteStateMachine<'a, S, I, O, C>;
fn set_initial_state(&mut self, initial_state: Self::State) {
self.initial_state = Some(initial_state);
}
fn define_state(
&mut self,
state: Self::State,
options: state::Options<'a, Self>,
) -> Option<Self::State> {
if let Some(parent) = options.parent {
if self.get_state_definition(&parent).is_none() {
self.define_state(
parent.clone(),
state::OptionsBuilder::default().build().unwrap(),
);
}
self.parent_refs.insert(state.clone(), parent);
}
let old_value = self.states.remove(&state);
self.states.insert(
state.clone(),
StateDefWrapper::with_mutations(state, options.on_enter, options.on_leave),
);
old_value.map(|w| w.value().clone())
}
fn add_transition(
&mut self,
source_state: Self::State,
event: Self::Input,
target_state: Self::State,
output: Option<Self::Output>,
mutation: Option<fn(&mut Self::Context)>,
) {
self.inner_add_transition(source_state, event, target_state, output, mutation);
}
fn start(
self,
) -> Result<RunningFiniteStateMachine<'a, S, I, O, C>, Error<'a, Self::State, Self::Input>>
{
#[cfg(feature = "events")]
let (sender, receiver) = crossbeam_channel::unbounded();
Ok(RunningFiniteStateMachine {
context: C::default(),
state: Stack::new(self.initial_state.unwrap()),
edges: self.edges,
parent_refs: self.parent_refs,
states: self.states,
#[cfg(feature = "events")]
receiver,
#[cfg(feature = "events")]
sender,
})
}
}
/// A finite state machine that has been started and can consume inputs.
///
/// Produced by `FiniteStateMachine::start`, or built directly with `new`,
/// `with_context` or the derived builder.
#[derive(Debug, Clone, Builder)]
pub struct RunningFiniteStateMachine<'a, S, I, O = (), C = ()> {
    /// Context handed as `&mut C` to state hooks and transition mutations.
    context: C,
    /// Definition (state value plus enter/leave hooks) of every known state.
    states: HashMap<S, StateDefWrapper<'a, S, C>>,
    /// Child state -> parent state links.
    parent_refs: HashMap<S, S>,
    /// Active states with the current state on top.
    ///
    /// After a transition or `set_state` this also holds the current state's
    /// ancestors; a freshly constructed machine holds only its start state.
    state: Stack<S>,
    /// Outgoing transitions, keyed by their source state.
    edges: HashMap<S, Vec<Edge<'a, S, I, O, C>>>,
    /// Sending half of the state-change event channel.
    #[cfg(feature = "events")]
    sender: Sender<StateMachineEvent<S>>,
    /// Receiving half kept by the machine; `events()` hands out clones of it.
    #[cfg(feature = "events")]
    receiver: Receiver<StateMachineEvent<S>>,
}
impl<'a, S, I, O, C> RunningStateMachine<'a> for RunningFiniteStateMachine<'a, S, I, O, C>
where
    S: StateLike,
    I: PartialEq,
{
    type Input = I;
    type State = S;
    type Context = C;
    type Output = O;

    /// Returns every state that has a definition, in unspecified (hash map) order.
    fn states(&self) -> Vec<&Self::State> {
        self.states.keys().collect()
    }

    /// Returns the current (innermost) state.
    fn state(&self) -> &Self::State {
        self.inner_state()
    }

    /// Returns the machine's context.
    fn context(&self) -> &Self::Context {
        &self.context
    }

    /// Returns `true` when the current state has no outgoing transitions.
    ///
    /// States that only ever appear as a transition target have no entry in the
    /// edge map at all. They count as final instead of causing an index panic.
    fn is_final(&self) -> bool {
        self.edges
            .get(self.state())
            .is_none_or(|edges| edges.is_empty())
    }

    /// Returns `true` if `state` is the current state or one of its ancestors.
    ///
    /// The full ancestor chain is walked rather than just the immediate parent.
    /// A freshly started machine's stack holds only its start state, so the
    /// grandparents of that state would otherwise not be considered active.
    fn is_active(&self, state: Self::State) -> bool {
        self.state.contains(&state) || self.get_parents(self.state()).contains(&state)
    }

    /// Feeds `event` to the machine; see `inner_consume` for the exact semantics.
    fn consume(
        &mut self,
        event: Self::Input,
    ) -> Result<Option<&Self::Output>, Error<'_, Self::State, Self::Input>> {
        self.inner_consume(event)
    }
}
impl<'a, S, I, O, C> StateMachineMut<'a> for RunningFiniteStateMachine<'a, S, I, O, C>
where
    S: StateLike,
    I: PartialEq,
{
    /// Jumps straight to `value` without running hooks or transition mutations.
    ///
    /// The active-state stack is rebuilt from the ancestors of `value` (nearest
    /// parent first), with `value` itself pushed last so it becomes current.
    fn set_state(&mut self, value: Self::State) {
        let ancestors = self.get_parents(&value);
        let active = ancestors.into_iter().chain(std::iter::once(value));
        self.state = Stack::from_iter(active);
    }

    /// Gives mutable access to the machine's context.
    fn context_mut(&mut self) -> &mut Self::Context {
        &mut self.context
    }
}
// The previous header `RunningFiniteStateMachine<'_, S, I, C>` bound `C` to the
// *output* parameter, so `Default` only existed for machines with a `()` context.
// Naming all four type parameters fixes that. The result is a strict superset of
// the types the old impl covered.
impl<S, I, O, C> Default for RunningFiniteStateMachine<'_, S, I, O, C>
where
    S: StateLike + Default,
    C: Default,
{
    /// Creates a machine in `S::default()` with `C::default()` as its context.
    fn default() -> Self {
        Self::new(S::default())
    }
}
impl<'a, S, I, O, C> RunningFiniteStateMachine<'a, S, I, O, C> {
    /// Returns the current (innermost) state: the top of the active-state stack.
    ///
    /// # Panics
    /// Panics if the stack's `top()` yields nothing.
    fn inner_state(&self) -> &S {
        self.state
            .top()
            .expect("Tried to access current state but no current state is set")
    }

    /// Returns a receiver for this machine's state-change events.
    ///
    /// Each call hands out a clone of the machine's own receiver. All clones share
    /// one channel, so every event is delivered to only one of them.
    #[cfg(feature = "events")]
    pub fn events(&self) -> crossbeam_channel::Receiver<StateMachineEvent<S>> {
        self.receiver.clone()
    }
}
impl<'a, S, I, O, C> RunningFiniteStateMachine<'a, S, I, O, C> {
fn inner_new(initial_context: C, initial_state: S) -> Self {
#[cfg(feature = "events")]
let (sender, receiver) = crossbeam_channel::unbounded();
Self {
context: initial_context,
states: Default::default(),
state: Stack::new(initial_state),
parent_refs: Default::default(),
edges: Default::default(),
#[cfg(feature = "events")]
sender,
#[cfg(feature = "events")]
receiver,
}
}
pub fn with_context(initial_state: S, initial_context: C) -> Self {
Self::inner_new(initial_context, initial_state)
}
}
impl<'a, S, I, O, C> RunningFiniteStateMachine<'a, S, I, O, C>
where
    S: StateLike,
    C: Default,
{
    /// Creates a machine in `initial_state` whose context is `C::default()`.
    pub fn new(initial_state: S) -> Self {
        let context = C::default();
        Self::inner_new(context, initial_state)
    }
}
impl<'a, S, I, O, C> RunningFiniteStateMachine<'a, S, I, O, C>
where
C: 'a,
S: StateLike,
I: PartialEq,
{
fn get_state_definition(&self, state: &S) -> Option<&StateDefWrapper<'_, S, C>> {
self.states.get(state)
}
fn get_parents(&self, state: &S) -> Vec<S> {
let mut parents = vec![];
let mut root = state;
while let Some(parent) = self.parent_refs.get(root) {
parents.push(parent.clone());
root = parent;
}
parents
}
fn inner_consume(&mut self, event: I) -> Result<Option<&O>, Error<'_, S, I>> {
let edges = self.edges.get(self.inner_state());
if let Some(edges) = edges {
let edge = edges.iter().find(|e| e.message() == &event);
if let Some(edge) = edge {
let source_state_definition = self.get_state_definition(edge.source_state());
let target_state_definition = self.get_state_definition(edge.target_state());
if let Some(source_state_definition) = source_state_definition
&& let Some(target_state_definition) = target_state_definition
{
if *target_state_definition.is_virtual() {
return Err(Error::TransitionImpossible((edge.source_state(), event)));
}
if let Some(on_source_state_leave) = source_state_definition.on_leave() {
on_source_state_leave(&mut self.context);
}
let source_state_parents = self.get_parents(edge.source_state());
for parent in source_state_parents {
let def = self.get_state_definition(&parent);
if let Some(parent_definition) = def
&& let Some(on_parent_state_leave) = parent_definition.on_enter()
{
on_parent_state_leave(&mut self.context);
}
}
if let Some(mutation) = edge.mutation() {
(*mutation)(&mut self.context);
}
let mut target_state_parents = self.get_parents(edge.target_state());
target_state_parents.reverse();
for parent in target_state_parents {
let def = self.get_state_definition(&parent);
if let Some(parent_definition) = def
&& let Some(on_parent_state_enter) = parent_definition.on_enter()
{
on_parent_state_enter(&mut self.context);
}
}
if let Some(target_state_definition) =
self.get_state_definition(edge.target_state())
{
if let Some(on_target_state_enter) = target_state_definition.on_enter() {
on_target_state_enter(&mut self.context);
}
} else {
return Err(Error::UnknownState);
}
let target_state = edge.target_state().clone();
let target_state_parents = self.get_parents(&target_state);
let target_state_iter = [target_state];
let state_items = target_state_parents
.iter()
.chain(target_state_iter.iter())
.cloned();
self.state = Stack::from_iter(state_items);
#[cfg(feature = "events")]
let _ = self.emit(StateMachineEvent::StateChanged((
edge.source_state().clone(),
edge.target_state().clone(),
)));
Ok(edge.output().as_ref())
} else {
Err(Error::UnknownState)
}
} else {
Err(Error::TransitionImpossible((self.inner_state(), event)))
}
} else {
Err(Error::TransitionImpossible((self.inner_state(), event)))
}
}
}
#[cfg(feature = "events")]
impl<'a, S, I, O, C> EventProducer<S> for RunningFiniteStateMachine<'a, S, I, O, C> {
    /// Publishes `event` on the machine's event channel.
    ///
    /// # Errors
    /// Returns whatever error `Sender::send` reports for the channel.
    fn emit(
        &self,
        event: StateMachineEvent<S>,
    ) -> Result<(), crossbeam_channel::SendError<StateMachineEvent<S>>> {
        let sender = &self.sender;
        sender.send(event)
    }
}
#[cfg(test)]
mod test {
    use fsmy_dsl::Primitive;
    use super::*;

    #[derive(Debug, Primitive)]
    enum State {
        Open,
        Closed,
        Broken,
        Offline,
        Repairing,
    }

    #[derive(Debug, PartialEq)]
    enum Message {
        Push,
        Pull,
        Kick,
        StartRepairing,
    }

    #[derive(Default, Debug, Clone)]
    pub struct Context {
        counter: usize,
        message: &'static str,
    }

    /// Door machine shared by the tests.
    ///
    /// Closed <-> Open via Pull/Push (Pull bumps the counter). Closed -Kick-> Broken,
    /// where Broken is a child of Open. Broken -StartRepairing-> Repairing, where
    /// Repairing is a virtual child of the virtual Offline. The machine starts Closed.
    fn machine() -> FiniteStateMachine<'static, State, Message, (), Context> {
        let mut machine = FiniteStateMachine::<State, Message, (), Context>::default();
        machine.add_transition(State::Open, Message::Push, State::Closed, None, None);
        machine.add_transition_with_mutation(
            State::Closed,
            Message::Pull,
            State::Open,
            None,
            |c| c.counter += 1,
        );
        machine.define_state(
            State::Broken,
            state::OptionsBuilder::default()
                .parent(Some(State::Open))
                .build()
                .unwrap(),
        );
        machine.add_transition(State::Closed, Message::Kick, State::Broken, None, None);
        machine.define_state(
            State::Offline,
            state::OptionsBuilder::default()
                .is_virtual(true)
                .on_enter(Some(|c: &mut Context| {
                    c.message = "Offline";
                }))
                .on_leave(Some(|c: &mut Context| {
                    c.message = "Online";
                }))
                .build()
                .unwrap(),
        );
        machine.define_state(
            State::Repairing,
            state::OptionsBuilder::default()
                .is_virtual(true)
                .parent(Some(State::Offline))
                .on_enter(Some(|c: &mut Context| {
                    c.counter = 100;
                }))
                .on_leave(Some(|c: &mut Context| {
                    c.counter = 0;
                }))
                .build()
                .unwrap(),
        );
        machine.add_transition(
            State::Broken,
            Message::StartRepairing,
            State::Repairing,
            None,
            None,
        );
        machine.set_initial_state(State::Closed);
        machine
    }

    #[test]
    fn transition() {
        let mut machine = machine().start().unwrap();
        assert!(machine.consume(Message::Pull).is_ok());
        assert_eq!(*machine.state(), State::Open);
        assert!(machine.consume(Message::Push).is_ok());
        assert_eq!(*machine.state(), State::Closed);
        assert!(machine.consume(Message::Push).is_err());
        assert_eq!(*machine.state(), State::Closed);
    }

    #[test]
    fn define_state_with_mutations() {
        let mut machine = machine();
        machine.define_state(
            State::Open,
            state::Options {
                is_virtual: false,
                parent: None,
                on_enter: Some(|c| c.message = "Now in opened state"),
                on_leave: Some(|c| c.message = "Leaving opened state"),
            },
        );
        let mut machine = machine.start().unwrap();
        assert!(machine.consume(Message::Pull).is_ok());
        assert_eq!(machine.context().message, "Now in opened state");
        assert!(machine.consume(Message::Push).is_ok());
        assert_eq!(machine.context().message, "Leaving opened state");
    }

    #[test]
    fn transition_with_mutation() {
        let mut machine = machine().start().unwrap();
        assert!(machine.consume(Message::Pull).is_ok());
        assert_eq!(*machine.state(), State::Open);
        assert_eq!(machine.context().counter, 1);
        assert!(machine.consume(Message::Push).is_ok());
        assert_eq!(*machine.state(), State::Closed);
        assert_eq!(machine.context().counter, 1);
        assert!(machine.consume(Message::Pull).is_ok());
        assert_eq!(*machine.state(), State::Open);
        assert_eq!(machine.context().counter, 2);
    }

    #[test]
    fn parents() {
        let mut machine = machine().start().unwrap();
        assert!(machine.consume(Message::Kick).is_ok());
        assert_eq!(*machine.state(), State::Broken);
        assert!(machine.is_active(State::Broken));
        assert!(machine.is_active(State::Open));
        assert!(!machine.is_active(State::Closed));
        assert!(machine.consume(Message::StartRepairing).is_ok());
        assert_eq!(machine.context().message, "Offline");
    }

    // NOTE(review): Repairing is declared virtual, yet this test expects the
    // transition into it to succeed. `define_state` does not forward `is_virtual`
    // to the stored definition, so confirm which behavior is intended.
    #[test]
    fn virtual_states() {
        let mut machine = machine();
        machine.define_state(
            State::Broken,
            state::OptionsBuilder::default()
                .parent(Some(State::Open))
                .build()
                .unwrap(),
        );
        machine.add_transition(State::Closed, Message::Kick, State::Broken, None, None);
        let mut machine = machine.start().unwrap();
        assert!(machine.consume(Message::Kick).is_ok());
        assert_eq!(*machine.state(), State::Broken);
        assert!(machine.is_active(State::Broken));
        assert!(machine.consume(Message::StartRepairing).is_ok());
        assert_eq!(*machine.state(), State::Repairing);
    }

    #[test]
    #[cfg(feature = "events")]
    fn events() {
        let mut machine = machine().start().unwrap();
        let events = machine.events();
        // The listener is joined below. Otherwise a failing assertion inside the
        // detached thread would be lost and this test would pass regardless.
        let listener = std::thread::spawn(move || {
            let event = events.recv();
            assert!(event.is_ok());
            let event = event.unwrap();
            match event {
                StateMachineEvent::StateChanged((from, to)) => {
                    assert_eq!(from, State::Closed);
                    assert_eq!(to, State::Open);
                }
            }
        });
        assert!(machine.consume(Message::Pull).is_ok());
        listener.join().expect("event listener thread panicked");
    }
}