1use crate::input::Action;
3use crate::state::{State, StateResult};
4use serde_json::Value;
5use std::collections::VecDeque;
6use std::marker::PhantomData;
7
8pub struct StateMachine<Ctx, S: ?Sized> {
14 cur_state: Option<Box<S>>,
15 queue: VecDeque<Box<S>>,
16 results: serde_json::Map<String, Value>,
17 _ctx: PhantomData<fn() -> Ctx>,
20}
21
22impl<Ctx, S: State<Ctx, S> + ?Sized> StateMachine<Ctx, S> {
24 pub fn new(queue: VecDeque<Box<S>>) -> Self {
25 Self {
26 cur_state: None,
27 queue,
28 results: serde_json::Map::new(),
29 _ctx: PhantomData,
30 }
31 }
32
33 pub fn start(&mut self, ctx: &Ctx) -> Option<Value> {
38 self.cur_state = self.queue.pop_front();
39 let Some(state) = &mut self.cur_state else {
40 return Some(self.take_results());
41 };
42 state.on_enter(ctx);
43 None
44 }
45
46 pub fn advance(&mut self, action: Option<Action>, ctx: &Ctx) -> Option<Value> {
53 let state = self
54 .cur_state
55 .as_mut()
56 .expect("advance() called on an inactive machine");
57
58 let result = match action {
59 Some(a) => state.handle_action(a, ctx),
60 None => state.tick(ctx),
61 };
62
63 match result {
64 None => None,
65 Some(done) => self.transition(done, ctx),
66 }
67 }
68
69 pub fn current_state(&self) -> Option<&S> {
70 self.cur_state.as_deref()
71 }
72
73 pub fn current_state_mut(&mut self) -> Option<&mut S> {
74 self.cur_state.as_deref_mut()
75 }
76
77 pub fn finish(&mut self) -> Value {
79 self.cur_state = None;
80 self.queue.clear();
81 self.take_results()
82 }
83
84 fn transition(&mut self, done: StateResult<S>, ctx: &Ctx) -> Option<Value> {
85 if let Some(state) = &mut self.cur_state {
86 state.on_exit(ctx);
87 }
88
89 store_output(&mut self.results, done.output);
90 prepend_states(&mut self.queue, done.then);
91
92 match self.queue.pop_front() {
93 Some(mut next) => {
94 next.on_enter(ctx);
95 self.cur_state = Some(next);
96 None
97 }
98 None => {
99 self.cur_state = None;
100 Some(self.take_results())
101 }
102 }
103 }
104
105 fn take_results(&mut self) -> Value {
106 Value::Object(std::mem::take(&mut self.results))
107 }
108}
109
110fn store_output(results: &mut serde_json::Map<String, Value>, output: Option<(String, String)>) {
111 if let Some((k, v)) = output {
112 results.insert(k, Value::String(v));
113 }
114}
115
/// Splices `then` onto the front of `queue`, keeping `then`'s own order, so
/// continuation states run before anything that was already queued.
fn prepend_states<S: ?Sized>(queue: &mut VecDeque<Box<S>>, then: Vec<Box<S>>) {
    let mut pending = then;
    // Taking from the back and pushing to the front leaves `then[0]` first.
    while let Some(state) = pending.pop() {
        queue.push_front(state);
    }
}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit context; the test states never read it.
    struct TestCtx;

    /// Minimal `State` implementations used to drive the machine.
    enum TestState {
        /// Completes on the `done_after`-th tick, optionally emitting
        /// `output_key -> "val_<output_key>"`.
        Tick {
            ticks: usize,
            done_after: usize,
            output_key: Option<String>,
        },
        /// Completes on its first tick, emitting `own_output` and scheduling a
        /// one-tick child state that emits `child_output_key`.
        Spawner {
            child_output_key: String,
            own_output: Option<(String, String)>,
        },
    }

    impl TestState {
        fn tick_state(done_after: usize, output_key: Option<&str>) -> Box<Self> {
            Box::new(Self::Tick {
                ticks: 0,
                done_after,
                output_key: output_key.map(String::from),
            })
        }
    }

    impl State<TestCtx, TestState> for TestState {
        fn tick(&mut self, _ctx: &TestCtx) -> Option<StateResult<TestState>> {
            match self {
                TestState::Tick {
                    ticks,
                    done_after,
                    output_key,
                } => {
                    *ticks += 1;
                    if *ticks >= *done_after {
                        Some(StateResult {
                            output: output_key.as_ref().map(|k| (k.clone(), format!("val_{k}"))),
                            then: vec![],
                        })
                    } else {
                        None
                    }
                }
                TestState::Spawner {
                    child_output_key,
                    own_output,
                } => {
                    let child = TestState::tick_state(1, Some(child_output_key));
                    Some(StateResult {
                        output: own_output.take(),
                        then: vec![child],
                    })
                }
            }
        }

        fn handle_action(
            &mut self,
            _action: Action,
            _ctx: &TestCtx,
        ) -> Option<StateResult<TestState>> {
            None
        }
    }

    fn ctx() -> TestCtx {
        TestCtx
    }

    /// The results value produced when nothing has recorded output.
    fn empty_object() -> Value {
        Value::Object(serde_json::Map::new())
    }

    /// Reads `key` from a results object as a string slice.
    fn str_at<'a>(v: &'a Value, key: &str) -> Option<&'a str> {
        v.get(key).and_then(Value::as_str)
    }

    #[test]
    fn start_with_empty_queue_returns_done() {
        let mut sm = StateMachine::<TestCtx, TestState>::new(VecDeque::new());
        assert_eq!(sm.start(&ctx()), Some(empty_object()));
        assert!(sm.current_state().is_none());
    }

    #[test]
    fn single_state_runs_to_completion() {
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(
            2,
            Some("name"),
        )]));
        let c = ctx();

        assert!(sm.start(&c).is_none());
        assert!(sm.advance(None, &c).is_none());

        let v = sm.advance(None, &c).expect("expected Some(Value)");
        assert_eq!(str_at(&v, "name"), Some("val_name"));
        assert!(sm.current_state().is_none());
    }

    #[test]
    fn advance_with_action_dispatches_handle_action() {
        // The state would complete on its first tick, so staying active shows
        // the action was routed to `handle_action` rather than `tick`.
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(1, None)]));
        let c = ctx();
        sm.start(&c);

        let action = Some(Action::Submit("hello".into()));
        assert!(sm.advance(action, &c).is_none());
        assert!(sm.current_state().is_some());
    }

    #[test]
    #[should_panic(expected = "advance() called on an inactive machine")]
    fn advance_before_start_panics() {
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(1, None)]));
        sm.advance(None, &ctx());
    }

    #[test]
    fn sequential_states_chain() {
        let mut sm = StateMachine::new(VecDeque::from(vec![
            TestState::tick_state(1, Some("a")),
            TestState::tick_state(1, Some("b")),
        ]));
        let c = ctx();
        sm.start(&c);

        assert!(sm.advance(None, &c).is_none());

        let v = sm.advance(None, &c).expect("expected Some(Value)");
        assert_eq!(str_at(&v, "a"), Some("val_a"));
        assert_eq!(str_at(&v, "b"), Some("val_b"));
    }

    #[test]
    fn continuation_states_are_spliced_before_queue() {
        let spawner = Box::new(TestState::Spawner {
            child_output_key: "child".into(),
            own_output: Some(("spawner".into(), "done".into())),
        });
        let tail = TestState::tick_state(1, Some("tail"));
        let mut sm = StateMachine::new(VecDeque::from(vec![spawner, tail]));
        let c = ctx();
        sm.start(&c);

        // Spawner completes and its child becomes active.
        assert!(sm.advance(None, &c).is_none());
        // Child completes; the original tail becomes active.
        assert!(sm.advance(None, &c).is_none());

        // Tail completes and the queue is exhausted.
        let v = sm.advance(None, &c).expect("expected Some(Value)");
        assert_eq!(str_at(&v, "spawner"), Some("done"));
        assert_eq!(str_at(&v, "child"), Some("val_child"));
        assert_eq!(str_at(&v, "tail"), Some("val_tail"));
    }

    #[test]
    fn finish_drains_machine() {
        let mut sm = StateMachine::new(VecDeque::from(vec![
            TestState::tick_state(100, None),
            TestState::tick_state(100, None),
        ]));
        let c = ctx();
        sm.start(&c);

        assert_eq!(sm.finish(), empty_object());
        assert!(sm.current_state().is_none());
        // The pending queue was cleared too, so restarting finishes immediately.
        assert_eq!(sm.start(&c), Some(empty_object()));
    }

    #[test]
    fn finish_returns_results_collected_so_far() {
        let mut sm = StateMachine::new(VecDeque::from(vec![
            TestState::tick_state(1, Some("a")),
            TestState::tick_state(100, Some("b")),
        ]));
        let c = ctx();
        sm.start(&c);
        assert!(sm.advance(None, &c).is_none());

        let v = sm.finish();
        assert_eq!(str_at(&v, "a"), Some("val_a"));
        assert!(v.get("b").is_none());
        assert!(sm.current_state().is_none());
    }

    #[test]
    fn current_state_accessors() {
        let mut sm = StateMachine::new(VecDeque::from(vec![TestState::tick_state(1, None)]));

        assert!(sm.current_state().is_none());
        sm.start(&ctx());
        assert!(sm.current_state().is_some());
        assert!(sm.current_state_mut().is_some());
    }
}