use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::Hash;
/// A finite state machine over states of type `S`, defined by which
/// `from → to` transitions it allows.
///
/// Implementors must themselves be `Send + Sync`.
pub trait StateMachine<S>: Send + Sync
where
    S: Clone + PartialEq + Eq + Hash + Send + Sync,
{
    /// Returns `true` if moving from `from` to `to` is allowed.
    fn can_transition(&self, from: &S, to: &S) -> bool;

    /// Returns every state directly reachable from `from`.
    ///
    /// The order of the returned states is implementation-defined.
    fn next_states(&self, from: &S) -> Vec<S>;

    /// Validates the `from → to` transition.
    ///
    /// # Errors
    ///
    /// Returns [`StateMachineError::IllegalTransition`] carrying clones of both
    /// states when [`can_transition`](Self::can_transition) rejects the move.
    fn transition(&self, from: &S, to: &S) -> Result<(), StateMachineError<S>>
    where
        S: fmt::Debug,
    {
        let allowed = self.can_transition(from, to);
        if !allowed {
            // States are cloned only on the failure path, so the error owns them.
            return Err(StateMachineError::IllegalTransition {
                from: from.clone(),
                to: to.clone(),
            });
        }
        Ok(())
    }
}
#[derive(Debug, thiserror::Error)]
pub enum StateMachineError<S: fmt::Debug> {
#[error("非法的状态转换: {from:?} → {to:?}")]
IllegalTransition { from: S, to: S },
}
/// A [`StateMachine`] backed by an explicit adjacency table.
///
/// Each key is a source state mapped to the set of states it may move to. A
/// state absent from the map has no outgoing transitions.
///
/// Trait bounds are placed on the `impl` blocks rather than on the struct, so
/// the type can be named, cloned and debug-printed without requiring `Hash`.
#[derive(Debug, Clone)]
pub struct SimpleStateMachine<S> {
    /// `from` state → set of allowed `to` states.
    transitions: HashMap<S, HashSet<S>>,
}
impl<S> SimpleStateMachine<S>
where
    S: Clone + PartialEq + Eq + Hash,
{
    /// Builds a state machine from the allowed `(from, to)` transition pairs.
    ///
    /// Accepts any iterable of pairs: a `Vec`, an array, or an iterator, so
    /// callers need not allocate a `Vec` just to construct the machine.
    /// Duplicate pairs are collapsed. A state that appears only as a target
    /// has no outgoing transitions.
    pub fn new(pairs: impl IntoIterator<Item = (S, S)>) -> Self {
        let mut transitions: HashMap<S, HashSet<S>> = HashMap::new();
        for (from, to) in pairs {
            transitions.entry(from).or_default().insert(to);
        }
        Self { transitions }
    }
}
/// Answers transition queries by looking `from` up in the adjacency table.
/// A state with no entry permits no transitions.
impl<S> StateMachine<S> for SimpleStateMachine<S>
where
    S: Clone + PartialEq + Eq + Hash + Send + Sync,
{
    fn can_transition(&self, from: &S, to: &S) -> bool {
        match self.transitions.get(from) {
            Some(allowed) => allowed.contains(to),
            None => false,
        }
    }

    /// The returned order follows the internal `HashSet` iteration order and
    /// is therefore unspecified.
    fn next_states(&self, from: &S) -> Vec<S> {
        match self.transitions.get(from) {
            Some(allowed) => allowed.iter().cloned().collect(),
            None => Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state enum used to exercise the state machine.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum TestStatus {
        Init,
        Processing,
        Done,
        Failed,
    }

    /// Builds the fixture machine:
    /// `Init → Processing → {Done, Failed}` and `Failed → Init`.
    /// `Done` has no outgoing transitions.
    fn create_test_sm() -> SimpleStateMachine<TestStatus> {
        SimpleStateMachine::new(vec![
            (TestStatus::Init, TestStatus::Processing),
            (TestStatus::Processing, TestStatus::Done),
            (TestStatus::Processing, TestStatus::Failed),
            (TestStatus::Failed, TestStatus::Init),
        ])
    }

    #[test]
    fn test_valid_transitions() {
        let sm = create_test_sm();
        assert!(sm.can_transition(&TestStatus::Init, &TestStatus::Processing));
        assert!(sm.can_transition(&TestStatus::Processing, &TestStatus::Done));
        assert!(sm.can_transition(&TestStatus::Processing, &TestStatus::Failed));
        assert!(sm.can_transition(&TestStatus::Failed, &TestStatus::Init));
    }

    #[test]
    fn test_invalid_transitions() {
        let sm = create_test_sm();
        assert!(!sm.can_transition(&TestStatus::Init, &TestStatus::Done));
        assert!(!sm.can_transition(&TestStatus::Done, &TestStatus::Init));
        assert!(!sm.can_transition(&TestStatus::Done, &TestStatus::Processing));
        assert!(!sm.can_transition(&TestStatus::Init, &TestStatus::Failed));
    }

    #[test]
    fn test_self_transition_requires_declaration() {
        let sm = create_test_sm();
        // No `(Init, Init)` pair was declared, so staying put is not implicit.
        assert!(!sm.can_transition(&TestStatus::Init, &TestStatus::Init));
    }

    #[test]
    fn test_next_states() {
        let sm = create_test_sm();
        let next = sm.next_states(&TestStatus::Processing);
        assert_eq!(next.len(), 2);
        assert!(next.contains(&TestStatus::Done));
        assert!(next.contains(&TestStatus::Failed));
    }

    #[test]
    fn test_next_states_of_terminal_state_is_empty() {
        let sm = create_test_sm();
        assert!(sm.next_states(&TestStatus::Done).is_empty());
    }

    #[test]
    fn test_transition_ok() {
        let sm = create_test_sm();
        assert!(sm.transition(&TestStatus::Init, &TestStatus::Processing).is_ok());
    }

    #[test]
    fn test_transition_err() {
        let sm = create_test_sm();
        let err = sm
            .transition(&TestStatus::Done, &TestStatus::Init)
            .expect_err("Done has no outgoing transitions");
        // Pin the rendered message as well as the payload, not merely `is_err()`.
        let message = err.to_string();
        assert!(matches!(
            err,
            StateMachineError::IllegalTransition {
                from: TestStatus::Done,
                to: TestStatus::Init,
            }
        ));
        assert_eq!(message, "非法的状态转换: Done → Init");
    }
}