use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::hash::Hash;
use crate::{StateMachine, StateMachineBuildError, Transition};
/// Builder that accumulates the states, initial/final state markers, and
/// transitions from which `build` constructs a [`StateMachine`].
///
/// The `add_*` / `set_*` methods only record what they are given; all
/// consistency checks are performed by `build`.
#[derive(Debug, Clone)]
pub struct StateMachineBuilder<S, E>
where
S: Copy + Eq + Hash + Debug,
E: Copy + Eq + Hash + Debug,
{
/// Every state registered with the builder.
pub(crate) states: HashSet<S>,
/// States marked as initial; `build` requires each to be present in `states`.
pub(crate) initial_states: HashSet<S>,
/// States marked as final; `build` requires each to be present in `states`.
pub(crate) final_states: HashSet<S>,
/// Transitions in the order they were added; they are validated and
/// de-duplicated only when `build` runs.
pub(crate) transitions: Vec<Transition<S, E>>,
}
impl<S, E> StateMachineBuilder<S, E>
where
    S: Copy + Eq + Hash + Debug + 'static,
    E: Copy + Eq + Hash + Debug + 'static,
{
    /// Creates a builder with no states, no initial/final markers and no
    /// transitions.
    pub fn new() -> Self {
        Self {
            states: HashSet::default(),
            initial_states: HashSet::default(),
            final_states: HashSet::default(),
            transitions: Vec::default(),
        }
    }

    /// Registers a single state and returns the builder for chaining.
    pub fn add_state(self, state: S) -> Self {
        self.add_states(&[state])
    }

    /// Registers every state in `states`; states already registered are kept
    /// once.
    pub fn add_states(mut self, states: &[S]) -> Self {
        self.states.extend(states);
        self
    }

    /// Marks `state` as an initial state.
    ///
    /// The state is not checked against the registered states here; that
    /// happens in [`build`](Self::build).
    pub fn set_initial_state(self, state: S) -> Self {
        self.set_initial_states(&[state])
    }

    /// Marks every state in `states` as an initial state, in addition to any
    /// marked earlier.
    pub fn set_initial_states(mut self, states: &[S]) -> Self {
        self.initial_states.extend(states);
        self
    }

    /// Marks `state` as a final state.
    ///
    /// The state is not checked against the registered states here; that
    /// happens in [`build`](Self::build).
    pub fn set_final_state(self, state: S) -> Self {
        self.set_final_states(&[state])
    }

    /// Marks every state in `states` as a final state, in addition to any
    /// marked earlier.
    pub fn set_final_states(mut self, states: &[S]) -> Self {
        self.final_states.extend(states);
        self
    }

    /// Queues the transition created by `Transition::new(source, event, target)`.
    ///
    /// No validation happens until [`build`](Self::build).
    pub fn add_transition(self, source: S, event: E, target: S) -> Self {
        let transition = Transition::new(source, event, target);
        self.add_transition_value(transition)
    }

    /// Queues an already constructed transition after those added earlier.
    pub fn add_transition_value(mut self, transition: Transition<S, E>) -> Self {
        self.transitions.push(transition);
        self
    }

    /// Validates the collected configuration and hands it to
    /// `StateMachine::new`.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    ///
    /// 1. `InitialStateNotRegistered` / `FinalStateNotRegistered` when an
    ///    initial or final state was never registered as a state.
    /// 2. For each transition, in the order it was added:
    ///    `TransitionSourceNotRegistered` or `TransitionTargetNotRegistered`
    ///    when an endpoint is unknown, then `DuplicateTransition` when the
    ///    same `(source, event)` pair was already mapped to a different
    ///    target.
    pub fn build(self) -> Result<StateMachine<S, E>, StateMachineBuildError<S, E>> {
        self.validate_registered_states()?;

        let mut transition_set = HashSet::new();
        let mut transition_map = HashMap::new();

        // Validation and insertion are interleaved per transition so that the
        // reported error always belongs to the earliest offending transition.
        self.transitions.iter().copied().try_for_each(|transition| {
            self.validate_transition(transition)?;
            Self::insert_transition(transition, &mut transition_set, &mut transition_map)
        })?;

        Ok(StateMachine::new(self, transition_set, transition_map))
    }

    /// Ensures every initial state, then every final state, is a registered
    /// state.
    ///
    /// Initial states are checked before final states. Within each set the
    /// reported state is whichever unregistered one the set iterates first.
    fn validate_registered_states(&self) -> Result<(), StateMachineBuildError<S, E>> {
        if let Some(&state) = self.initial_states.difference(&self.states).next() {
            return Err(StateMachineBuildError::InitialStateNotRegistered { state });
        }
        if let Some(&state) = self.final_states.difference(&self.states).next() {
            return Err(StateMachineBuildError::FinalStateNotRegistered { state });
        }
        Ok(())
    }

    /// Ensures both endpoints of `transition` are registered states.
    ///
    /// The source is checked before the target, so a transition with two
    /// unknown endpoints reports the source.
    fn validate_transition(
        &self,
        transition: Transition<S, E>,
    ) -> Result<(), StateMachineBuildError<S, E>> {
        let source = transition.source();
        let event = transition.event();
        let target = transition.target();

        if !self.states.contains(&source) {
            Err(StateMachineBuildError::TransitionSourceNotRegistered {
                source,
                event,
                target,
            })
        } else if !self.states.contains(&target) {
            Err(StateMachineBuildError::TransitionTargetNotRegistered {
                source,
                event,
                target,
            })
        } else {
            Ok(())
        }
    }

    /// Records `transition` in both lookup structures.
    ///
    /// Fails with `DuplicateTransition` when the `(source, event)` pair is
    /// already mapped to a different target. Adding the same
    /// `(source, event, target)` again is accepted and is not an error.
    fn insert_transition(
        transition: Transition<S, E>,
        transition_set: &mut HashSet<Transition<S, E>>,
        transition_map: &mut HashMap<(S, E), S>,
    ) -> Result<(), StateMachineBuildError<S, E>> {
        let key = (transition.source(), transition.event());
        let new_target = transition.target();

        match transition_map.get(&key).copied() {
            Some(existing_target) if existing_target != new_target => {
                let (source, event) = key;
                Err(StateMachineBuildError::DuplicateTransition {
                    source,
                    event,
                    existing_target,
                    new_target,
                })
            }
            // Either the pair is new or it already maps to the same target.
            Some(_) | None => {
                transition_set.insert(transition);
                transition_map.insert(key, new_target);
                Ok(())
            }
        }
    }
}
impl<S, E> Default for StateMachineBuilder<S, E>
where
S: Copy + Eq + Hash + Debug + 'static,
E: Copy + Eq + Hash + Debug + 'static,
{
fn default() -> Self {
Self::new()
}
}