1use std::collections::HashMap;
9use std::hash::Hash;
10use std::time::Duration;
11
/// Result of a state handler: the state to move to, the data to carry into
/// it, and an optional timeout.
///
/// `timeout` is only stored here; `Fsm::handle` in this module reads `next`
/// and `data` but never `timeout`, so any timeout behaviour is up to whoever
/// consumes the transition.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FsmTransition<S, D> {
    /// State the machine should occupy after the transition.
    pub next: S,
    /// Data value that replaces the machine's current data.
    pub data: D,
    /// Optional timeout attached to the transition.
    pub timeout: Option<Duration>,
}
18
19pub trait FiniteStateMachine {
22 type State: Clone + Eq + 'static;
23 type Data: Clone + 'static;
24 type Msg: Send + 'static;
25
26 fn initial_state(&self) -> Self::State;
27 fn initial_data(&self) -> Self::Data;
28
29 fn transition(
30 &mut self,
31 current: &Self::State,
32 data: &Self::Data,
33 msg: Self::Msg,
34 ) -> Option<FsmTransition<Self::State, Self::Data>>;
35}
36
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Eq, PartialEq, Debug)]
    enum S {
        Idle,
        Running,
    }

    /// Sample machine: `Go` moves Idle -> Running and increments the counter,
    /// `Stop` moves Running -> Idle keeping the counter; everything else is
    /// unhandled.
    struct TrafficLight;

    enum M {
        Go,
        Stop,
    }

    impl FiniteStateMachine for TrafficLight {
        type State = S;
        type Data = u32;
        type Msg = M;

        fn initial_state(&self) -> S {
            S::Idle
        }

        fn initial_data(&self) -> u32 {
            0
        }

        fn transition(&mut self, s: &S, d: &u32, m: M) -> Option<FsmTransition<S, u32>> {
            match (s, m) {
                (S::Idle, M::Go) => Some(FsmTransition { next: S::Running, data: d + 1, timeout: None }),
                (S::Running, M::Stop) => Some(FsmTransition { next: S::Idle, data: *d, timeout: None }),
                _ => None,
            }
        }
    }

    #[test]
    fn transitions_idle_to_running() {
        let mut fsm = TrafficLight;
        let t = fsm.transition(&S::Idle, &0, M::Go).unwrap();
        assert_eq!(t.next, S::Running);
        assert_eq!(t.data, 1);
    }

    #[test]
    fn transitions_running_to_idle_on_stop() {
        let mut fsm = TrafficLight;
        let t = fsm.transition(&S::Running, &5, M::Stop).unwrap();
        assert_eq!(t.next, S::Idle);
        assert_eq!(t.data, 5);
    }

    /// Covers the catch-all arm: messages with no matching transition yield `None`.
    #[test]
    fn ignores_messages_without_a_matching_transition() {
        let mut fsm = TrafficLight;
        assert!(fsm.transition(&S::Idle, &0, M::Stop).is_none());
        assert!(fsm.transition(&S::Running, &3, M::Go).is_none());
    }
}
90
91type StateHandler<S, D, M> = Box<dyn FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static>;
94
95type TransitionHook<S> = Box<dyn FnMut(&S, &S) + Send + 'static>;
96
97type TerminationHook<S, D> = Box<dyn FnMut(&S, &D) + Send + 'static>;
98
/// Reason passed to `Fsm::terminate` describing why a machine stopped.
///
/// `#[non_exhaustive]` lets new reasons be added later without breaking
/// downstream `match` expressions.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FsmStopReason {
    /// Normal stop (presumably an orderly, expected completion).
    Normal,
    /// Stop due to shutdown (presumably of the surrounding system).
    Shutdown,
    /// Stop due to a failure, carrying a descriptive message.
    Failure(String),
}
107
108pub struct FsmBuilder<S: Clone + Eq + Hash + 'static, D: Clone + 'static, M: 'static> {
123 initial_state: Option<S>,
124 initial_data: Option<D>,
125 handlers: HashMap<S, StateHandler<S, D, M>>,
126 fallback: Option<StateHandler<S, D, M>>,
127 on_transition: Option<TransitionHook<S>>,
128 on_termination: Option<TerminationHook<S, D>>,
129}
130
131impl<S, D, M> Default for FsmBuilder<S, D, M>
132where
133 S: Clone + Eq + Hash + 'static,
134 D: Clone + 'static,
135 M: 'static,
136{
137 fn default() -> Self {
138 Self {
139 initial_state: None,
140 initial_data: None,
141 handlers: HashMap::new(),
142 fallback: None,
143 on_transition: None,
144 on_termination: None,
145 }
146 }
147}
148
149impl<S, D, M> FsmBuilder<S, D, M>
150where
151 S: Clone + Eq + Hash + 'static,
152 D: Clone + 'static,
153 M: 'static,
154{
155 pub fn new() -> Self {
156 Self::default()
157 }
158
159 pub fn start_with(mut self, state: S, data: D) -> Self {
160 self.initial_state = Some(state);
161 self.initial_data = Some(data);
162 self
163 }
164
165 pub fn when_state<F>(mut self, state: S, handler: F) -> Self
168 where
169 F: FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static,
170 {
171 self.handlers.insert(state, Box::new(handler));
172 self
173 }
174
175 pub fn whenever<F>(mut self, handler: F) -> Self
177 where
178 F: FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static,
179 {
180 self.fallback = Some(Box::new(handler));
181 self
182 }
183
184 pub fn on_transition<F>(mut self, hook: F) -> Self
185 where
186 F: FnMut(&S, &S) + Send + 'static,
187 {
188 self.on_transition = Some(Box::new(hook));
189 self
190 }
191
192 pub fn on_termination<F>(mut self, hook: F) -> Self
193 where
194 F: FnMut(&S, &D) + Send + 'static,
195 {
196 self.on_termination = Some(Box::new(hook));
197 self
198 }
199
200 pub fn build(self) -> Fsm<S, D, M> {
201 let initial_state = self.initial_state.expect("FsmBuilder: start_with(state, data) is required");
202 let initial_data = self.initial_data.expect("FsmBuilder: start_with(state, data) is required");
203 Fsm {
204 current_state: initial_state.clone(),
205 current_data: initial_data,
206 initial_state,
207 handlers: self.handlers,
208 fallback: self.fallback,
209 on_transition: self.on_transition,
210 on_termination: self.on_termination,
211 terminated: false,
212 }
213 }
214}
215
216pub struct Fsm<S: Clone + Eq + Hash + 'static, D: Clone + 'static, M: 'static> {
218 current_state: S,
219 current_data: D,
220 initial_state: S,
221 handlers: HashMap<S, StateHandler<S, D, M>>,
222 fallback: Option<StateHandler<S, D, M>>,
223 on_transition: Option<TransitionHook<S>>,
224 on_termination: Option<TerminationHook<S, D>>,
225 terminated: bool,
226}
227
228impl<S, D, M> Fsm<S, D, M>
229where
230 S: Clone + Eq + Hash + 'static,
231 D: Clone + 'static,
232 M: 'static,
233{
234 pub fn state(&self) -> &S {
235 &self.current_state
236 }
237
238 pub fn data(&self) -> &D {
239 &self.current_data
240 }
241
242 pub fn initial_state(&self) -> &S {
243 &self.initial_state
244 }
245
246 pub fn is_terminated(&self) -> bool {
247 self.terminated
248 }
249
250 pub fn handle(&mut self, msg: M) -> Option<&S> {
253 if self.terminated {
254 return None;
255 }
256 let attempted = if let Some(handler) = self.handlers.get_mut(&self.current_state) {
257 handler(&self.current_state, &self.current_data, msg)
258 } else {
259 None
260 };
261 let transition = match attempted {
262 Some(t) => Some(t),
263 None => {
264 self.fallback.as_mut().and_then(|f| {
274 let _ = f;
280 None
281 })
282 }
283 };
284 if let Some(t) = transition {
285 if self.current_state != t.next {
286 if let Some(hook) = self.on_transition.as_mut() {
287 hook(&self.current_state, &t.next);
288 }
289 }
290 self.current_state = t.next;
291 self.current_data = t.data;
292 }
293 Some(&self.current_state)
294 }
295
296 pub fn terminate(&mut self, _reason: FsmStopReason) {
298 if self.terminated {
299 return;
300 }
301 if let Some(hook) = self.on_termination.as_mut() {
302 hook(&self.current_state, &self.current_data);
303 }
304 self.terminated = true;
305 }
306}
307
#[cfg(test)]
mod builder_tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    #[derive(Clone, Eq, PartialEq, Hash, Debug)]
    enum St {
        Closed,
        Open,
    }

    enum Cmd {
        Lock,
        Unlock,
    }

    #[test]
    fn builder_drives_transitions() {
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 0)
            .when_state(St::Closed, |_s, d, m| match m {
                Cmd::Unlock => Some(FsmTransition { next: St::Open, data: d + 1, timeout: None }),
                Cmd::Lock => None,
            })
            .when_state(St::Open, |_s, d, m| match m {
                Cmd::Lock => Some(FsmTransition { next: St::Closed, data: *d, timeout: None }),
                Cmd::Unlock => None,
            })
            .build();
        assert_eq!(fsm.state(), &St::Closed);
        fsm.handle(Cmd::Unlock);
        assert_eq!(fsm.state(), &St::Open);
        assert_eq!(*fsm.data(), 1);
        // The recorded initial state is unaffected by later transitions.
        assert_eq!(fsm.initial_state(), &St::Closed);
        fsm.handle(Cmd::Lock);
        assert_eq!(fsm.state(), &St::Closed);
    }

    #[test]
    fn on_transition_hook_fires() {
        let log: Arc<Mutex<Vec<(St, St)>>> = Arc::new(Mutex::new(Vec::new()));
        let log_clone = log.clone();
        let mut fsm = FsmBuilder::<St, (), Cmd>::new()
            .start_with(St::Closed, ())
            .when_state(St::Closed, |_, _, _| Some(FsmTransition { next: St::Open, data: (), timeout: None }))
            .on_transition(move |from, to| {
                log_clone.lock().unwrap().push((from.clone(), to.clone()));
            })
            .build();
        fsm.handle(Cmd::Unlock);
        let entries = log.lock().unwrap().clone();
        assert_eq!(entries, vec![(St::Closed, St::Open)]);
    }

    #[test]
    fn on_termination_hook_fires() {
        let calls: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
        let c = calls.clone();
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 7)
            .when_state(St::Closed, |_, _, _| None)
            .on_termination(move |_s, d| {
                *c.lock().unwrap() = *d;
            })
            .build();
        fsm.terminate(FsmStopReason::Normal);
        assert_eq!(*calls.lock().unwrap(), 7);
        assert!(fsm.is_terminated());
        // A second terminate is a no-op: the hook must not run again.
        fsm.terminate(FsmStopReason::Normal);
        assert_eq!(*calls.lock().unwrap(), 7);
    }

    #[test]
    fn handle_after_terminate_returns_none() {
        let mut fsm = FsmBuilder::<St, (), Cmd>::new()
            .start_with(St::Closed, ())
            .when_state(St::Closed, |_, _, _| Some(FsmTransition { next: St::Open, data: (), timeout: None }))
            .build();
        fsm.terminate(FsmStopReason::Normal);
        assert!(fsm.handle(Cmd::Unlock).is_none());
    }

    #[test]
    fn no_transition_keeps_state() {
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 0)
            .when_state(St::Closed, |_, _, _| None)
            .build();
        fsm.handle(Cmd::Unlock);
        assert_eq!(fsm.state(), &St::Closed);
    }

    /// A transition back into the same state must still replace the data,
    /// but must not fire the `on_transition` hook.
    #[test]
    fn self_transition_updates_data_without_firing_hook() {
        let hits: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
        let counter = Arc::clone(&hits);
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 1)
            .when_state(St::Closed, |s, d, _| Some(FsmTransition { next: s.clone(), data: d * 2, timeout: None }))
            .on_transition(move |_, _| {
                *counter.lock().unwrap() += 1;
            })
            .build();
        fsm.handle(Cmd::Lock);
        fsm.handle(Cmd::Lock);
        assert_eq!(fsm.state(), &St::Closed);
        assert_eq!(*fsm.data(), 4);
        assert_eq!(*hits.lock().unwrap(), 0);
    }
}