use crate::input::Action;
use crate::state::{State, StateResult};
use serde_json::Value;
use std::collections::VecDeque;
use std::marker::PhantomData;
/// Runs a queue of boxed states one at a time, forwarding input or ticks to
/// the active state and collecting the outputs of states as they complete.
pub struct StateMachine<Ctx, S>
where
    S: ?Sized,
{
    /// The state currently receiving input; `None` before `start` and once the run ends.
    cur_state: Option<Box<S>>,
    /// States still waiting to run, taken from the front.
    queue: VecDeque<Box<S>>,
    /// Outputs recorded by completed states, keyed by the names they reported.
    results: serde_json::Map<String, Value>,
    /// Ties `Ctx` to the type without storing one; `fn() -> Ctx` keeps the
    /// machine from appearing to own a `Ctx` value.
    _ctx: PhantomData<fn() -> Ctx>,
}
impl<Ctx, S> StateMachine<Ctx, S>
where
    S: State<Ctx, S> + ?Sized,
{
    /// Builds an idle machine that will run `queue` front to back.
    ///
    /// No state is active until [`start`](Self::start) is called.
    pub fn new(queue: VecDeque<Box<S>>) -> Self {
        Self {
            cur_state: None,
            queue,
            results: serde_json::Map::new(),
            _ctx: PhantomData,
        }
    }

    /// Pops the first queued state, makes it active and runs its `on_enter` hook.
    ///
    /// Returns `None` when a state became active, or `Some` with the collected
    /// results object when the queue was already empty.
    ///
    /// NOTE(review): a state that was already active is replaced without its
    /// `on_exit` hook being called.
    pub fn start(&mut self, ctx: &Ctx) -> Option<Value> {
        self.cur_state = self.queue.pop_front();
        match self.cur_state.as_mut() {
            Some(first) => {
                first.on_enter(ctx);
                None
            }
            None => Some(self.take_results()),
        }
    }

    /// Feeds one step to the active state: `Some(action)` goes to
    /// `handle_action`, `None` goes to `tick`.
    ///
    /// If the state reports completion, the machine moves to the next state.
    /// Returns `Some` with the collected results once no states remain,
    /// otherwise `None`.
    ///
    /// # Panics
    ///
    /// Panics if no state is active (before `start`, or after the run ended).
    pub fn advance(&mut self, action: Option<Action>, ctx: &Ctx) -> Option<Value> {
        let active = self
            .cur_state
            .as_mut()
            .expect("advance() called on an inactive machine");
        let outcome = if let Some(a) = action {
            active.handle_action(a, ctx)
        } else {
            active.tick(ctx)
        };
        outcome.and_then(|done| self.transition(done, ctx))
    }

    /// Borrows the active state, if any.
    pub fn current_state(&self) -> Option<&S> {
        self.cur_state.as_deref()
    }

    /// Mutably borrows the active state, if any.
    pub fn current_state_mut(&mut self) -> Option<&mut S> {
        self.cur_state.as_deref_mut()
    }

    /// Abandons the active state and everything still queued, and returns the
    /// results collected so far.
    ///
    /// Neither the active nor the queued states get an `on_exit` call here.
    pub fn finish(&mut self) -> Value {
        self.cur_state = None;
        self.queue.clear();
        self.take_results()
    }

    /// Completes the active state and moves on.
    ///
    /// Runs `on_exit` on the state being left, records its output, and splices
    /// its continuation states in ahead of the existing queue. Then it either
    /// enters the next state (returning `None`) or, with nothing left, goes
    /// idle and returns the collected results.
    fn transition(&mut self, done: StateResult<S>, ctx: &Ctx) -> Option<Value> {
        if let Some(leaving) = self.cur_state.as_mut() {
            leaving.on_exit(ctx);
        }
        store_output(&mut self.results, done.output);
        prepend_states(&mut self.queue, done.then);

        let Some(mut entering) = self.queue.pop_front() else {
            self.cur_state = None;
            return Some(self.take_results());
        };
        entering.on_enter(ctx);
        // The previous state is dropped only after the new one has entered.
        self.cur_state = Some(entering);
        None
    }

    /// Moves the collected results out as a JSON object, leaving an empty map behind.
    fn take_results(&mut self) -> Value {
        let collected = std::mem::take(&mut self.results);
        Value::Object(collected)
    }
}
/// Records a completed state's `(key, value)` output, if it produced one, as a
/// JSON string under `key`. An existing entry with the same key is overwritten.
fn store_output(results: &mut serde_json::Map<String, Value>, output: Option<(String, String)>) {
    let Some((key, text)) = output else {
        return;
    };
    results.insert(key, Value::String(text));
}
/// Splices `then` onto the front of `queue`, keeping `then`'s own order, so
/// the first element of `then` becomes the next one popped from `queue`.
fn prepend_states<S: ?Sized>(queue: &mut VecDeque<Box<S>>, then: Vec<Box<S>>) {
    let mut pending = then;
    // Taking from the back and pushing to the front preserves `then`'s order.
    while let Some(state) = pending.pop() {
        queue.push_front(state);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    struct TestCtx;

    /// Minimal states used to drive the machine in tests.
    enum TestState {
        /// Completes after `done_after` ticks, emitting `val_<key>` under `output_key`.
        Tick {
            ticks: usize,
            done_after: usize,
            output_key: Option<String>,
        },
        /// Completes on its first tick and hands over to a one-tick child state.
        Spawner {
            child_output_key: String,
            own_output: Option<(String, String)>,
        },
    }

    impl TestState {
        fn tick_state(done_after: usize, output_key: Option<&str>) -> Box<Self> {
            Box::new(Self::Tick {
                ticks: 0,
                done_after,
                output_key: output_key.map(String::from),
            })
        }
    }

    impl State<TestCtx, TestState> for TestState {
        fn tick(&mut self, _ctx: &TestCtx) -> Option<StateResult<TestState>> {
            match self {
                TestState::Tick {
                    ticks,
                    done_after,
                    output_key,
                } => {
                    *ticks += 1;
                    if *ticks >= *done_after {
                        Some(StateResult {
                            output: output_key.as_ref().map(|k| (k.clone(), format!("val_{k}"))),
                            then: vec![],
                        })
                    } else {
                        None
                    }
                }
                TestState::Spawner {
                    child_output_key,
                    own_output,
                } => {
                    let child = TestState::tick_state(1, Some(child_output_key));
                    Some(StateResult {
                        output: own_output.take(),
                        then: vec![child],
                    })
                }
            }
        }

        fn handle_action(&mut self, _action: Action, _ctx: &TestCtx) -> Option<StateResult<TestState>> {
            None
        }
    }

    fn ctx() -> TestCtx {
        TestCtx
    }

    /// Returns the `output_key` of the active `Tick` state, if any, so tests
    /// can check which state the machine is currently running.
    fn current_output_key(sm: &StateMachine<TestCtx, TestState>) -> Option<&str> {
        match sm.current_state()? {
            TestState::Tick { output_key, .. } => output_key.as_deref(),
            TestState::Spawner { .. } => None,
        }
    }

    #[test]
    fn start_with_empty_queue_returns_done() {
        let mut sm = StateMachine::<TestCtx, TestState>::new(VecDeque::new());
        match sm.start(&ctx()) {
            Some(v) => assert_eq!(v, Value::Object(serde_json::Map::new())),
            None => panic!("expected Some(Value)"),
        }
    }

    #[test]
    fn single_state_runs_to_completion() {
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(2, Some("name"))]));
        let c = ctx();
        assert!(sm.start(&c).is_none());
        assert!(sm.advance(None, &c).is_none());
        match sm.advance(None, &c) {
            Some(v) => {
                assert_eq!(v.get("name").and_then(|v| v.as_str()), Some("val_name"));
            }
            None => panic!("expected Some(Value)"),
        }
    }

    #[test]
    fn advance_with_action_dispatches_handle_action() {
        // A one-tick state would finish (returning Some) if `tick` were called,
        // so a `None` here shows the action was routed to `handle_action`.
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(1, None)]));
        let c = ctx();
        sm.start(&c);
        let action = Some(Action::Submit("hello".into()));
        assert!(sm.advance(action, &c).is_none());
    }

    #[test]
    fn sequential_states_chain() {
        let mut sm = StateMachine::new(VecDeque::from(vec![
            TestState::tick_state(1, Some("a")),
            TestState::tick_state(1, Some("b")),
        ]));
        let c = ctx();
        sm.start(&c);
        assert!(sm.advance(None, &c).is_none());
        match sm.advance(None, &c) {
            Some(v) => {
                assert_eq!(v.get("a").and_then(|v| v.as_str()), Some("val_a"));
                assert_eq!(v.get("b").and_then(|v| v.as_str()), Some("val_b"));
            }
            None => panic!("expected Some(Value)"),
        }
    }

    #[test]
    fn continuation_states_are_spliced_before_queue() {
        let spawner = Box::new(TestState::Spawner {
            child_output_key: "child".into(),
            own_output: Some(("spawner".into(), "done".into())),
        });
        let tail = TestState::tick_state(1, Some("tail"));
        let mut sm = StateMachine::new(VecDeque::from(vec![spawner, tail]));
        let c = ctx();
        sm.start(&c);
        // The spawner completes; its child must become active before the
        // tail that was already queued.
        assert!(sm.advance(None, &c).is_none());
        assert_eq!(current_output_key(&sm), Some("child"));
        assert!(sm.advance(None, &c).is_none());
        assert_eq!(current_output_key(&sm), Some("tail"));
        match sm.advance(None, &c) {
            Some(v) => {
                assert_eq!(v.get("spawner").and_then(|v| v.as_str()), Some("done"));
                assert_eq!(v.get("child").and_then(|v| v.as_str()), Some("val_child"));
                assert_eq!(v.get("tail").and_then(|v| v.as_str()), Some("val_tail"));
            }
            None => panic!("expected Some(Value)"),
        }
    }

    #[test]
    fn finish_drains_machine() {
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(100, None)]));
        let c = ctx();
        sm.start(&c);
        let v = sm.finish();
        assert_eq!(v, Value::Object(serde_json::Map::new()));
        assert!(sm.current_state().is_none());
        // The queue was cleared too, so restarting finishes immediately.
        assert_eq!(sm.start(&c), Some(Value::Object(serde_json::Map::new())));
    }

    #[test]
    fn current_state_accessors() {
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(1, None)]));
        assert!(sm.current_state().is_none());
        sm.start(&ctx());
        assert!(sm.current_state().is_some());
        assert!(sm.current_state_mut().is_some());
    }
}