use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::hash::Hash;
use qubit_atomic::AtomicRef;
use qubit_cas::{CasDecision, CasError, CasExecutor, CasSuccess};
use crate::{StateMachineBuilder, StateMachineError, StateMachineResult, Transition};
/// An immutable finite-state-machine definition.
///
/// Describes the known states, which of them are initial or final, and the
/// `(source, event) -> target` transitions between them. The machine itself
/// stores no "current" state; a current state lives in an external
/// `AtomicRef<S>` supplied by the caller.
#[derive(Debug, Clone)]
pub struct StateMachine<
    S: Copy + Eq + Hash + Debug + 'static,
    E: Copy + Eq + Hash + Debug + 'static,
> {
    /// Every state the machine knows about.
    states: HashSet<S>,
    /// States the machine may start in.
    initial_states: HashSet<S>,
    /// States designated as final.
    final_states: HashSet<S>,
    /// The declared transitions, kept for introspection.
    transitions: HashSet<Transition<S, E>>,
    /// Lookup table from `(source, event)` to the target state.
    transition_map: HashMap<(S, E), S>,
    /// Executor used to perform compare-and-swap updates of a shared state.
    cas_executor: CasExecutor<S, StateMachineError<S, E>>,
}
impl<S, E> StateMachine<S, E>
where
    S: Copy + Eq + Hash + Debug + 'static,
    E: Copy + Eq + Hash + Debug + 'static,
{
    /// Returns a fresh [`StateMachineBuilder`] for assembling a state machine.
    #[must_use]
    pub fn builder() -> StateMachineBuilder<S, E> {
        StateMachineBuilder::new()
    }

    /// Creates a state machine from a builder plus its pre-computed transitions.
    ///
    /// `states`, `initial_states` and `final_states` are moved out of
    /// `builder`; `transitions` and `transition_map` are stored as given.
    /// `transition_map` is the table consulted by [`Self::transition_target`]
    /// and [`Self::trigger`], so it is expected to agree with `transitions`
    /// (NOTE(review): consistency is not checked here — presumably the
    /// builder guarantees it).
    ///
    /// The CAS executor is created with `CasExecutor::latency_first()`.
    pub(crate) fn new(
        builder: StateMachineBuilder<S, E>,
        transitions: HashSet<Transition<S, E>>,
        transition_map: HashMap<(S, E), S>,
    ) -> Self {
        Self {
            states: builder.states,
            initial_states: builder.initial_states,
            final_states: builder.final_states,
            transitions,
            transition_map,
            cas_executor: CasExecutor::latency_first(),
        }
    }

    /// Returns every state known to this machine.
    pub const fn states(&self) -> &HashSet<S> {
        &self.states
    }

    /// Returns the states designated as initial.
    pub const fn initial_states(&self) -> &HashSet<S> {
        &self.initial_states
    }

    /// Returns the states designated as final.
    pub const fn final_states(&self) -> &HashSet<S> {
        &self.final_states
    }

    /// Returns the declared transitions.
    pub const fn transitions(&self) -> &HashSet<Transition<S, E>> {
        &self.transitions
    }

    /// Returns `true` if `state` is one of [`Self::states`].
    pub fn contains_state(&self, state: S) -> bool {
        self.states.contains(&state)
    }

    /// Returns `true` if `state` is one of [`Self::initial_states`].
    pub fn is_initial_state(&self, state: S) -> bool {
        self.initial_states.contains(&state)
    }

    /// Returns `true` if `state` is one of [`Self::final_states`].
    pub fn is_final_state(&self, state: S) -> bool {
        self.final_states.contains(&state)
    }

    /// Looks up the state reached from `source` on `event`.
    ///
    /// Returns `None` when no transition is registered for the pair. The
    /// lookup does not check that `source` is a known state.
    pub fn transition_target(&self, source: S, event: E) -> Option<S> {
        self.transition_map.get(&(source, event)).copied()
    }

    /// Applies `event` to the state held in `state` with a compare-and-swap
    /// update, returning the newly stored state.
    ///
    /// # Errors
    ///
    /// - [`StateMachineError::UnknownState`] if the current value of `state`
    ///   is not one of [`Self::states`].
    /// - [`StateMachineError::UnknownTransition`] if no transition exists for
    ///   the current state and `event`.
    /// - [`StateMachineError::CasConflict`] if the CAS executor fails without
    ///   an abort error (presumably retries exhausted under contention).
    pub fn trigger(&self, state: &AtomicRef<S>, event: E) -> StateMachineResult<S, E> {
        let (_, new_state) = self.change_state(state, event)?;
        Ok(new_state)
    }

    /// Like [`Self::trigger`], but on success also invokes
    /// `on_success(old_state, new_state)` after the new state has been stored.
    ///
    /// `on_success` is not called when the transition fails.
    ///
    /// # Errors
    ///
    /// Same as [`Self::trigger`].
    pub fn trigger_with<F>(
        &self,
        state: &AtomicRef<S>,
        event: E,
        on_success: F,
    ) -> StateMachineResult<S, E>
    where
        F: FnOnce(S, S),
    {
        let (old_state, new_state) = self.change_state(state, event)?;
        on_success(old_state, new_state);
        Ok(new_state)
    }

    /// Attempts [`Self::trigger`] and returns whether it succeeded.
    ///
    /// Any error is discarded.
    pub fn try_trigger(&self, state: &AtomicRef<S>, event: E) -> bool {
        self.trigger(state, event).is_ok()
    }

    /// Attempts [`Self::trigger_with`] and returns whether it succeeded.
    ///
    /// Any error is discarded; `on_success` runs only on success.
    pub fn try_trigger_with<F>(&self, state: &AtomicRef<S>, event: E, on_success: F) -> bool
    where
        F: FnOnce(S, S),
    {
        self.trigger_with(state, event, on_success).is_ok()
    }

    /// Performs the CAS update and returns `(old_state, new_state)`.
    fn change_state(
        &self,
        state: &AtomicRef<S>,
        event: E,
    ) -> Result<(S, S), StateMachineError<S, E>> {
        // The decision closure may be evaluated more than once if the executor
        // retries, so it only reads the immutable lookup tables.
        let outcome = self.cas_executor.execute(state, |current_state: &S| {
            match self.next_state(*current_state, event) {
                // Store the target state and also report it as the output.
                Ok(new_state) => CasDecision::update(new_state, new_state),
                // Stop without writing; the error is surfaced to the caller.
                Err(error) => CasDecision::abort(error),
            }
        });
        outcome
            .into_result()
            .map(Self::state_change_from_success)
            .map_err(Self::state_error_from_cas_error)
    }

    /// Computes the target state for `event` from `current_state`.
    ///
    /// Fails with `UnknownState` if `current_state` is not a known state, or
    /// with `UnknownTransition` if no transition is registered for the pair.
    fn next_state(&self, current_state: S, event: E) -> Result<S, StateMachineError<S, E>> {
        if !self.contains_state(current_state) {
            return Err(StateMachineError::UnknownState {
                state: current_state,
            });
        }
        self.transition_target(current_state, event)
            .ok_or(StateMachineError::UnknownTransition {
                source: current_state,
                event,
            })
    }

    /// Converts a successful CAS outcome into `(old_state, new_state)`.
    fn state_change_from_success(success: CasSuccess<S, S>) -> (S, S) {
        match success {
            CasSuccess::Updated {
                previous, current, ..
            } => (*previous, *current),
            // `change_state` only produces update/abort decisions, so this arm
            // is defensive; if reached, it is reported as "no change".
            CasSuccess::Finished { current, .. } => (*current, *current),
        }
    }

    /// Converts a failed CAS outcome into a [`StateMachineError`].
    ///
    /// An error carried by the CAS failure (the abort error produced by
    /// `next_state`) is returned unchanged. A failure without one is reported
    /// as `CasConflict` together with the number of attempts made.
    fn state_error_from_cas_error(
        error: CasError<S, StateMachineError<S, E>>,
    ) -> StateMachineError<S, E> {
        match error.error() {
            Some(cause) => *cause,
            None => StateMachineError::CasConflict {
                attempts: error.attempts(),
            },
        }
    }
}