use rustc_hash::FxHashMap;
use slotmap::{new_key_type, SlotMap};
use smallvec::SmallVec;
new_key_type! {
pub struct FsmId;
}
pub type StateId = u32;
pub type EventId = u32;
pub type Guard = Box<dyn Fn() -> bool + Send>;
pub type Action = Box<dyn FnMut() + Send>;
pub struct Transition {
pub from_state: StateId,
pub event: EventId,
pub to_state: StateId,
pub guard: Option<Guard>,
pub actions: SmallVec<[Action; 2]>,
}
impl Transition {
pub fn new(from: StateId, event: EventId, to: StateId) -> Self {
Self {
from_state: from,
event,
to_state: to,
guard: None,
actions: SmallVec::new(),
}
}
pub fn with_guard<F: Fn() -> bool + Send + 'static>(mut self, guard: F) -> Self {
self.guard = Some(Box::new(guard));
self
}
pub fn with_action<F: FnMut() + Send + 'static>(mut self, action: F) -> Self {
self.actions.push(Box::new(action));
self
}
}
pub struct StateMachineBuilder {
initial_state: StateId,
transitions: Vec<Transition>,
entry_callbacks: FxHashMap<StateId, Vec<Action>>,
exit_callbacks: FxHashMap<StateId, Vec<Action>>,
}
impl StateMachineBuilder {
pub fn new(initial_state: StateId) -> Self {
Self {
initial_state,
transitions: Vec::new(),
entry_callbacks: FxHashMap::default(),
exit_callbacks: FxHashMap::default(),
}
}
pub fn transition(mut self, transition: Transition) -> Self {
self.transitions.push(transition);
self
}
pub fn on(mut self, from: StateId, event: EventId, to: StateId) -> Self {
self.transitions.push(Transition::new(from, event, to));
self
}
pub fn on_enter<F: FnMut() + Send + 'static>(mut self, state: StateId, action: F) -> Self {
self.entry_callbacks
.entry(state)
.or_default()
.push(Box::new(action));
self
}
pub fn on_exit<F: FnMut() + Send + 'static>(mut self, state: StateId, action: F) -> Self {
self.exit_callbacks
.entry(state)
.or_default()
.push(Box::new(action));
self
}
pub fn build(self) -> StateMachine {
StateMachine {
current_state: self.initial_state,
transitions: self.transitions,
entry_callbacks: self.entry_callbacks,
exit_callbacks: self.exit_callbacks,
history: Vec::new(),
}
}
}
pub struct StateMachine {
current_state: StateId,
transitions: Vec<Transition>,
entry_callbacks: FxHashMap<StateId, Vec<Action>>,
exit_callbacks: FxHashMap<StateId, Vec<Action>>,
history: Vec<(StateId, EventId, StateId)>,
}
impl StateMachine {
pub fn new(initial_state: StateId, transitions: Vec<Transition>) -> Self {
Self {
current_state: initial_state,
transitions,
entry_callbacks: FxHashMap::default(),
exit_callbacks: FxHashMap::default(),
history: Vec::new(),
}
}
pub fn builder(initial_state: StateId) -> StateMachineBuilder {
StateMachineBuilder::new(initial_state)
}
pub fn current_state(&self) -> StateId {
self.current_state
}
pub fn is_in(&self, state: StateId) -> bool {
self.current_state == state
}
pub fn history(&self) -> &[(StateId, EventId, StateId)] {
&self.history
}
pub fn clear_history(&mut self) {
self.history.clear();
}
pub fn can_send(&self, event: EventId) -> bool {
let current = self.current_state;
self.transitions.iter().any(|t| {
t.from_state == current && t.event == event && {
match &t.guard {
Some(guard) => guard(),
None => true,
}
}
})
}
pub fn send(&mut self, event: EventId) -> StateId {
let current = self.current_state;
let transition_idx = self.transitions.iter().position(|t| {
t.from_state == current && t.event == event && {
match &t.guard {
Some(guard) => guard(),
None => true,
}
}
});
let Some(idx) = transition_idx else {
return current;
};
let to_state = self.transitions[idx].to_state;
if let Some(callbacks) = self.exit_callbacks.get_mut(¤t) {
for callback in callbacks.iter_mut() {
callback();
}
}
for action in self.transitions[idx].actions.iter_mut() {
action();
}
self.current_state = to_state;
self.history.push((current, event, to_state));
if let Some(callbacks) = self.entry_callbacks.get_mut(&to_state) {
for callback in callbacks.iter_mut() {
callback();
}
}
to_state
}
pub fn on_enter<F: FnMut() + Send + 'static>(&mut self, state: StateId, callback: F) {
self.entry_callbacks
.entry(state)
.or_default()
.push(Box::new(callback));
}
pub fn on_exit<F: FnMut() + Send + 'static>(&mut self, state: StateId, callback: F) {
self.exit_callbacks
.entry(state)
.or_default()
.push(Box::new(callback));
}
}
pub struct FsmRuntime {
machines: SlotMap<FsmId, StateMachine>,
}
impl FsmRuntime {
pub fn new() -> Self {
Self {
machines: SlotMap::with_key(),
}
}
pub fn create(&mut self, machine: StateMachine) -> FsmId {
self.machines.insert(machine)
}
pub fn create_simple(&mut self, initial_state: StateId, transitions: Vec<Transition>) -> FsmId {
self.machines
.insert(StateMachine::new(initial_state, transitions))
}
pub fn get(&self, id: FsmId) -> Option<&StateMachine> {
self.machines.get(id)
}
pub fn get_mut(&mut self, id: FsmId) -> Option<&mut StateMachine> {
self.machines.get_mut(id)
}
pub fn send(&mut self, id: FsmId, event: EventId) -> Option<StateId> {
self.machines.get_mut(id).map(|fsm| fsm.send(event))
}
pub fn current_state(&self, id: FsmId) -> Option<StateId> {
self.machines.get(id).map(|fsm| fsm.current_state())
}
pub fn remove(&mut self, id: FsmId) -> Option<StateMachine> {
self.machines.remove(id)
}
pub fn len(&self) -> usize {
self.machines.len()
}
pub fn is_empty(&self) -> bool {
self.machines.is_empty()
}
}
impl Default for FsmRuntime {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
const IDLE: StateId = 0;
const HOVERED: StateId = 1;
const PRESSED: StateId = 2;
const DISABLED: StateId = 3;
const POINTER_ENTER: EventId = 1;
const POINTER_LEAVE: EventId = 2;
const POINTER_DOWN: EventId = 3;
const POINTER_UP: EventId = 4;
const DISABLE: EventId = 5;
#[test]
fn test_simple_transitions() {
let mut fsm = StateMachine::new(
IDLE,
vec![
Transition::new(IDLE, POINTER_ENTER, HOVERED),
Transition::new(HOVERED, POINTER_LEAVE, IDLE),
Transition::new(HOVERED, POINTER_DOWN, PRESSED),
Transition::new(PRESSED, POINTER_UP, HOVERED),
],
);
assert_eq!(fsm.current_state(), IDLE);
fsm.send(POINTER_ENTER);
assert_eq!(fsm.current_state(), HOVERED);
fsm.send(POINTER_DOWN);
assert_eq!(fsm.current_state(), PRESSED);
fsm.send(POINTER_UP);
assert_eq!(fsm.current_state(), HOVERED);
fsm.send(POINTER_LEAVE);
assert_eq!(fsm.current_state(), IDLE);
}
#[test]
fn test_invalid_event_no_transition() {
let mut fsm = StateMachine::new(IDLE, vec![Transition::new(IDLE, POINTER_ENTER, HOVERED)]);
fsm.send(POINTER_DOWN);
assert_eq!(fsm.current_state(), IDLE);
}
#[test]
fn test_guard_conditions() {
let enabled = Arc::new(Mutex::new(true));
let enabled_clone = enabled.clone();
let mut fsm = StateMachine::builder(IDLE)
.transition(
Transition::new(IDLE, POINTER_ENTER, HOVERED)
.with_guard(move || *enabled_clone.lock().unwrap()),
)
.build();
fsm.send(POINTER_ENTER);
assert_eq!(fsm.current_state(), HOVERED);
fsm.current_state = IDLE;
*enabled.lock().unwrap() = false;
fsm.send(POINTER_ENTER);
assert_eq!(fsm.current_state(), IDLE);
}
#[test]
fn test_entry_exit_callbacks() {
let entry_count = Arc::new(Mutex::new(0));
let exit_count = Arc::new(Mutex::new(0));
let entry_clone = entry_count.clone();
let exit_clone = exit_count.clone();
let mut fsm = StateMachine::builder(IDLE)
.on(IDLE, POINTER_ENTER, HOVERED)
.on(HOVERED, POINTER_LEAVE, IDLE)
.on_enter(HOVERED, move || {
*entry_clone.lock().unwrap() += 1;
})
.on_exit(HOVERED, move || {
*exit_clone.lock().unwrap() += 1;
})
.build();
assert_eq!(*entry_count.lock().unwrap(), 0);
assert_eq!(*exit_count.lock().unwrap(), 0);
fsm.send(POINTER_ENTER);
assert_eq!(*entry_count.lock().unwrap(), 1);
assert_eq!(*exit_count.lock().unwrap(), 0);
fsm.send(POINTER_LEAVE);
assert_eq!(*entry_count.lock().unwrap(), 1);
assert_eq!(*exit_count.lock().unwrap(), 1);
fsm.send(POINTER_ENTER);
assert_eq!(*entry_count.lock().unwrap(), 2);
}
#[test]
fn test_transition_actions() {
let action_count = Arc::new(Mutex::new(0));
let action_clone = action_count.clone();
let mut fsm = StateMachine::builder(IDLE)
.transition(
Transition::new(IDLE, POINTER_ENTER, HOVERED).with_action(move || {
*action_clone.lock().unwrap() += 1;
}),
)
.build();
fsm.send(POINTER_ENTER);
assert_eq!(*action_count.lock().unwrap(), 1);
}
#[test]
fn test_history() {
let mut fsm = StateMachine::new(
IDLE,
vec![
Transition::new(IDLE, POINTER_ENTER, HOVERED),
Transition::new(HOVERED, POINTER_DOWN, PRESSED),
],
);
fsm.send(POINTER_ENTER);
fsm.send(POINTER_DOWN);
let history = fsm.history();
assert_eq!(history.len(), 2);
assert_eq!(history[0], (IDLE, POINTER_ENTER, HOVERED));
assert_eq!(history[1], (HOVERED, POINTER_DOWN, PRESSED));
}
#[test]
fn test_can_send() {
let fsm = StateMachine::new(IDLE, vec![Transition::new(IDLE, POINTER_ENTER, HOVERED)]);
assert!(fsm.can_send(POINTER_ENTER));
assert!(!fsm.can_send(POINTER_DOWN));
}
#[test]
fn test_fsm_runtime() {
let mut runtime = FsmRuntime::new();
let fsm1 = runtime.create_simple(IDLE, vec![Transition::new(IDLE, POINTER_ENTER, HOVERED)]);
let fsm2 = runtime.create_simple(IDLE, vec![Transition::new(IDLE, POINTER_DOWN, PRESSED)]);
assert_eq!(runtime.len(), 2);
runtime.send(fsm1, POINTER_ENTER);
assert_eq!(runtime.current_state(fsm1), Some(HOVERED));
assert_eq!(runtime.current_state(fsm2), Some(IDLE));
runtime.remove(fsm1);
assert_eq!(runtime.len(), 1);
assert_eq!(runtime.current_state(fsm1), None);
}
}