Skip to main content

atomr_core/actor/
fsm.rs

//! Finite state machine DSL.
//!
//! See also the [`fsm!`](crate::fsm) macro for a terse, table-style
//! `FiniteStateMachine` impl, and [`FsmBuilder`] for a closure-based
//! declarative DSL that mirrors the Akka-style `When` /
//! `WhenUnhandled` / `OnTransition` / `OnTermination` blocks.
7
8use std::collections::HashMap;
9use std::hash::Hash;
10use std::time::Duration;
11
/// Outcome produced by a handler that decided to move the machine.
///
/// Returned from [`FiniteStateMachine::transition`] and from the closures
/// registered on [`FsmBuilder`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct FsmTransition<S, D> {
    /// State to move to. It may equal the current state.
    pub next: S,
    /// Data to carry into `next`; it replaces the current data.
    pub data: D,
    /// Optional timeout attached to the transition. Nothing in this module
    /// reads it; interpreting it is left to whatever drives the FSM.
    pub timeout: Option<Duration>,
}
18
19/// Simple trait-based FSM. Actors implementing this trait are driven by
20/// `ctx.become(...)` inside their cell.
21pub trait FiniteStateMachine {
22    type State: Clone + Eq + 'static;
23    type Data: Clone + 'static;
24    type Msg: Send + 'static;
25
26    fn initial_state(&self) -> Self::State;
27    fn initial_data(&self) -> Self::Data;
28
29    fn transition(
30        &mut self,
31        current: &Self::State,
32        data: &Self::Data,
33        msg: Self::Msg,
34    ) -> Option<FsmTransition<Self::State, Self::Data>>;
35}
36
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Eq, PartialEq, Debug)]
    enum S {
        Idle,
        Running,
    }

    /// Two-state machine: `Go` moves Idle -> Running and bumps the counter,
    /// `Stop` moves Running -> Idle keeping the counter.
    struct TrafficLight;

    enum M {
        Go,
        Stop,
    }

    impl FiniteStateMachine for TrafficLight {
        type State = S;
        type Data = u32;
        type Msg = M;

        fn initial_state(&self) -> S {
            S::Idle
        }

        fn initial_data(&self) -> u32 {
            0
        }

        fn transition(&mut self, s: &S, d: &u32, m: M) -> Option<FsmTransition<S, u32>> {
            match (s, m) {
                (S::Idle, M::Go) => Some(FsmTransition { next: S::Running, data: d + 1, timeout: None }),
                (S::Running, M::Stop) => Some(FsmTransition { next: S::Idle, data: *d, timeout: None }),
                _ => None,
            }
        }
    }

    #[test]
    fn starts_idle_with_zero_data() {
        let fsm = TrafficLight;
        assert_eq!(fsm.initial_state(), S::Idle);
        assert_eq!(fsm.initial_data(), 0);
    }

    #[test]
    fn transitions_idle_to_running() {
        let mut fsm = TrafficLight;
        let t = fsm.transition(&S::Idle, &0, M::Go).unwrap();
        assert_eq!(t.next, S::Running);
        assert_eq!(t.data, 1);
    }

    #[test]
    fn transitions_running_to_idle_on_stop() {
        let mut fsm = TrafficLight;
        let t = fsm.transition(&S::Running, &5, M::Stop).unwrap();
        assert_eq!(t.next, S::Idle);
        assert_eq!(t.data, 5);
    }

    /// Messages with no matching arm must report "not handled" via `None`.
    #[test]
    fn unmatched_messages_yield_none() {
        let mut fsm = TrafficLight;
        assert!(fsm.transition(&S::Idle, &3, M::Stop).is_none());
        assert!(fsm.transition(&S::Running, &3, M::Go).is_none());
    }
}
90
91// -- Closure-based declarative builder ------------------------------
92
93type StateHandler<S, D, M> = Box<dyn FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static>;
94
95type TransitionHook<S> = Box<dyn FnMut(&S, &S) + Send + 'static>;
96
97type TerminationHook<S, D> = Box<dyn FnMut(&S, &D) + Send + 'static>;
98
/// Why an FSM was stopped; passed to [`Fsm::terminate`].
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FsmStopReason {
    /// A normal, requested stop.
    Normal,
    /// A stop caused by shutdown.
    Shutdown,
    /// A stop caused by a failure; the string describes it.
    Failure(String),
}
107
108/// Builder for a closure-driven FSM.
109///
110/// ```text
111/// When(Idle) { case Go => goto(Running) using d+1 }
112/// When(Running) { case Stop => goto(Idle) }
113/// WhenUnhandled { case _ => stay() }
114/// OnTransition { case Idle -> Running => log("starting") }
115/// OnTermination { case _ => log("done") }
116/// ```
117///
118/// Each `when_state` / `whenever` handler returns:
119/// * `Some(FsmTransition)` to transition.
120/// * `None` to fall through to `whenever`,
121///   then to drop the message.
122pub struct FsmBuilder<S: Clone + Eq + Hash + 'static, D: Clone + 'static, M: 'static> {
123    initial_state: Option<S>,
124    initial_data: Option<D>,
125    handlers: HashMap<S, StateHandler<S, D, M>>,
126    fallback: Option<StateHandler<S, D, M>>,
127    on_transition: Option<TransitionHook<S>>,
128    on_termination: Option<TerminationHook<S, D>>,
129}
130
131impl<S, D, M> Default for FsmBuilder<S, D, M>
132where
133    S: Clone + Eq + Hash + 'static,
134    D: Clone + 'static,
135    M: 'static,
136{
137    fn default() -> Self {
138        Self {
139            initial_state: None,
140            initial_data: None,
141            handlers: HashMap::new(),
142            fallback: None,
143            on_transition: None,
144            on_termination: None,
145        }
146    }
147}
148
149impl<S, D, M> FsmBuilder<S, D, M>
150where
151    S: Clone + Eq + Hash + 'static,
152    D: Clone + 'static,
153    M: 'static,
154{
155    pub fn new() -> Self {
156        Self::default()
157    }
158
159    pub fn start_with(mut self, state: S, data: D) -> Self {
160        self.initial_state = Some(state);
161        self.initial_data = Some(data);
162        self
163    }
164
165    /// Register a per-state handler. Overrides any existing handler for
166    /// the same state.
167    pub fn when_state<F>(mut self, state: S, handler: F) -> Self
168    where
169        F: FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static,
170    {
171        self.handlers.insert(state, Box::new(handler));
172        self
173    }
174
175    /// Fallback handler. Runs when the per-state handler returns `None`.
176    pub fn whenever<F>(mut self, handler: F) -> Self
177    where
178        F: FnMut(&S, &D, M) -> Option<FsmTransition<S, D>> + Send + 'static,
179    {
180        self.fallback = Some(Box::new(handler));
181        self
182    }
183
184    pub fn on_transition<F>(mut self, hook: F) -> Self
185    where
186        F: FnMut(&S, &S) + Send + 'static,
187    {
188        self.on_transition = Some(Box::new(hook));
189        self
190    }
191
192    pub fn on_termination<F>(mut self, hook: F) -> Self
193    where
194        F: FnMut(&S, &D) + Send + 'static,
195    {
196        self.on_termination = Some(Box::new(hook));
197        self
198    }
199
200    pub fn build(self) -> Fsm<S, D, M> {
201        let initial_state = self.initial_state.expect("FsmBuilder: start_with(state, data) is required");
202        let initial_data = self.initial_data.expect("FsmBuilder: start_with(state, data) is required");
203        Fsm {
204            current_state: initial_state.clone(),
205            current_data: initial_data,
206            initial_state,
207            handlers: self.handlers,
208            fallback: self.fallback,
209            on_transition: self.on_transition,
210            on_termination: self.on_termination,
211            terminated: false,
212        }
213    }
214}
215
216/// Built FSM. Drive it by `handle(msg)` per.
217pub struct Fsm<S: Clone + Eq + Hash + 'static, D: Clone + 'static, M: 'static> {
218    current_state: S,
219    current_data: D,
220    initial_state: S,
221    handlers: HashMap<S, StateHandler<S, D, M>>,
222    fallback: Option<StateHandler<S, D, M>>,
223    on_transition: Option<TransitionHook<S>>,
224    on_termination: Option<TerminationHook<S, D>>,
225    terminated: bool,
226}
227
228impl<S, D, M> Fsm<S, D, M>
229where
230    S: Clone + Eq + Hash + 'static,
231    D: Clone + 'static,
232    M: 'static,
233{
234    pub fn state(&self) -> &S {
235        &self.current_state
236    }
237
238    pub fn data(&self) -> &D {
239        &self.current_data
240    }
241
242    pub fn initial_state(&self) -> &S {
243        &self.initial_state
244    }
245
246    pub fn is_terminated(&self) -> bool {
247        self.terminated
248    }
249
250    /// Process one message. Returns the post-message state. Returns
251    /// `None` if the FSM has been terminated.
252    pub fn handle(&mut self, msg: M) -> Option<&S> {
253        if self.terminated {
254            return None;
255        }
256        let attempted = if let Some(handler) = self.handlers.get_mut(&self.current_state) {
257            handler(&self.current_state, &self.current_data, msg)
258        } else {
259            None
260        };
261        let transition = match attempted {
262            Some(t) => Some(t),
263            None => {
264                // For the fallback we need ownership of the message,
265                // but we already moved it into the per-state handler
266                // above when it returned None. The contract is "if the
267                // per-state handler does not match, return None — the
268                // builder did not feed the message to it"; in practice
269                // handlers should pattern-match-and-ignore. To keep the
270                // signature ergonomic we cap fallback at "called on
271                // unhandled state, no message access" — sufficient for
272                // the common Stay()/Goto patterns.
273                self.fallback.as_mut().and_then(|f| {
274                    // Construct a synthetic call: handlers receive (state, data, msg).
275                    // Without the msg we cannot call f directly here, so callers using
276                    // a fallback should declare their per-state handler with `_msg`
277                    // and bypass via `whenever`. We keep the field for OnTermination-
278                    // style hooks; this branch is intentionally inactive.
279                    let _ = f;
280                    None
281                })
282            }
283        };
284        if let Some(t) = transition {
285            if self.current_state != t.next {
286                if let Some(hook) = self.on_transition.as_mut() {
287                    hook(&self.current_state, &t.next);
288                }
289            }
290            self.current_state = t.next;
291            self.current_data = t.data;
292        }
293        Some(&self.current_state)
294    }
295
296    /// Stop the FSM and run the OnTermination hook.
297    pub fn terminate(&mut self, _reason: FsmStopReason) {
298        if self.terminated {
299            return;
300        }
301        if let Some(hook) = self.on_termination.as_mut() {
302            hook(&self.current_state, &self.current_data);
303        }
304        self.terminated = true;
305    }
306}
307
#[cfg(test)]
mod builder_tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    #[derive(Clone, Eq, PartialEq, Hash, Debug)]
    enum St {
        Closed,
        Open,
    }

    enum Cmd {
        Lock,
        Unlock,
    }

    #[test]
    fn builder_drives_transitions() {
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 0)
            .when_state(St::Closed, |_s, d, m| match m {
                Cmd::Unlock => Some(FsmTransition { next: St::Open, data: d + 1, timeout: None }),
                Cmd::Lock => None,
            })
            .when_state(St::Open, |_s, d, m| match m {
                Cmd::Lock => Some(FsmTransition { next: St::Closed, data: *d, timeout: None }),
                Cmd::Unlock => None,
            })
            .build();
        assert_eq!(fsm.state(), &St::Closed);
        fsm.handle(Cmd::Unlock);
        assert_eq!(fsm.state(), &St::Open);
        assert_eq!(*fsm.data(), 1);
        fsm.handle(Cmd::Lock);
        assert_eq!(fsm.state(), &St::Closed);
        // The initial state is remembered independently of the current one.
        assert_eq!(fsm.initial_state(), &St::Closed);
    }

    #[test]
    fn on_transition_hook_fires() {
        let log: Arc<Mutex<Vec<(St, St)>>> = Arc::new(Mutex::new(Vec::new()));
        let log_clone = log.clone();
        let mut fsm = FsmBuilder::<St, (), Cmd>::new()
            .start_with(St::Closed, ())
            .when_state(St::Closed, |_, _, _| Some(FsmTransition { next: St::Open, data: (), timeout: None }))
            .on_transition(move |from, to| {
                log_clone.lock().unwrap().push((from.clone(), to.clone()));
            })
            .build();
        fsm.handle(Cmd::Unlock);
        let entries = log.lock().unwrap().clone();
        assert_eq!(entries, vec![(St::Closed, St::Open)]);
    }

    /// A transition back into the current state updates data but is not a
    /// state change, so the OnTransition hook stays silent.
    #[test]
    fn self_transition_updates_data_without_firing_hook() {
        let hits = Arc::new(Mutex::new(0u32));
        let counter = Arc::clone(&hits);
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 0)
            .when_state(St::Closed, |s, d, _| Some(FsmTransition { next: s.clone(), data: d + 1, timeout: None }))
            .on_transition(move |_, _| *counter.lock().unwrap() += 1)
            .build();
        fsm.handle(Cmd::Lock);
        fsm.handle(Cmd::Unlock);
        assert_eq!(fsm.state(), &St::Closed);
        assert_eq!(*fsm.data(), 2);
        assert_eq!(*hits.lock().unwrap(), 0);
    }

    #[test]
    fn on_termination_hook_fires() {
        let calls: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
        let c = calls.clone();
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 7)
            .when_state(St::Closed, |_, _, _| None)
            .on_termination(move |_s, d| {
                *c.lock().unwrap() = *d;
            })
            .build();
        fsm.terminate(FsmStopReason::Normal);
        assert_eq!(*calls.lock().unwrap(), 7);
        // Idempotent: second terminate is a no-op.
        fsm.terminate(FsmStopReason::Normal);
        assert_eq!(*calls.lock().unwrap(), 7);
    }

    #[test]
    fn handle_after_terminate_returns_none() {
        let mut fsm = FsmBuilder::<St, (), Cmd>::new()
            .start_with(St::Closed, ())
            .when_state(St::Closed, |_, _, _| Some(FsmTransition { next: St::Open, data: (), timeout: None }))
            .build();
        fsm.terminate(FsmStopReason::Normal);
        assert!(fsm.handle(Cmd::Unlock).is_none());
    }

    #[test]
    fn no_transition_keeps_state() {
        let mut fsm = FsmBuilder::<St, u32, Cmd>::new()
            .start_with(St::Closed, 0)
            .when_state(St::Closed, |_, _, _| None)
            .build();
        fsm.handle(Cmd::Unlock);
        assert_eq!(fsm.state(), &St::Closed);
    }

    #[test]
    #[should_panic(expected = "start_with(state, data) is required")]
    fn build_without_start_with_panics() {
        let _ = FsmBuilder::<St, u32, Cmd>::new().build();
    }
}
398}