Skip to main content

atomr_persistence/
persistent_fsm.rs

//! `PersistentFSM` — event-sourced state machine on top of [`Eventsourced`].
//!
//! Two-shape model:
//!
//! * `S` — finite state (typically an enum: `Idle`, `Active`, …).
//! * `D` — state-data carried alongside `S`.
//!
//! The actor receives commands. For each command, the registered command handler inspects
//! the current `(S, D)` and returns the events to persist (possibly none). Once the journal
//! has accepted them, the registered event handler applies each event to `(S, D)`; this is
//! where state transitions happen. Recovery re-applies the journaled events through the same
//! event handler to rebuild the `(S, D)` pair.
//!
//! ```ignore
//! let fsm = PersistentFSM::<DoorState, DoorData, DoorCmd, DoorEvent, DoorError>::new("door-1")
//!     .with_initial(DoorState::Closed, DoorData::default())
//!     .on_command(|state, data, cmd| { … })
//!     .on_event(|state, data, evt| { … })
//!     .with_codec(encode, decode);
//! ```
18
19use std::sync::Arc;
20
21use crate::eventsourced::EventsourcedError;
22use crate::journal::{Journal, PersistentRepr};
23use crate::recovery_permitter::RecoveryPermitter;
24
/// Command handler: inspects the current `(S, D)` and returns the events to persist
/// (an empty `Vec` persists nothing), or a domain error `Err` that rejects the command.
type CmdFn<S, D, C, E, Err> =
    Box<dyn FnMut(&S, &D, C) -> Result<Vec<E>, Err> + Send + 'static>;

/// Event handler: applies a single event to the mutable `(S, D)` pair.
type EvtFn<S, D, E> = Box<dyn FnMut(&mut S, &mut D, &E) + Send + 'static>;

/// Encoder: turns an event into the journal payload bytes (`Err(msg)` on failure).
type EncodeFn<E> = Box<dyn Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static>;

/// Decoder: turns journal payload bytes back into an event (`Err(msg)` on failure).
type DecodeFn<E> = Box<dyn Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static>;
29
30pub struct PersistentFSM<S, D, C, E, Err>
31where
32    S: Clone + Send + 'static,
33    D: Send + 'static,
34    C: Send + 'static,
35    E: Clone + Send + 'static,
36    Err: std::error::Error + Send + 'static,
37{
38    persistence_id: String,
39    state: Option<S>,
40    data: Option<D>,
41    next_seq: u64,
42    on_command: Option<CmdFn<S, D, C, E, Err>>,
43    on_event: Option<EvtFn<S, D, E>>,
44    encode: Option<EncodeFn<E>>,
45    decode: Option<DecodeFn<E>>,
46    transitions: Vec<(S, S)>,
47}
48
49impl<S, D, C, E, Err> PersistentFSM<S, D, C, E, Err>
50where
51    S: Clone + PartialEq + std::fmt::Debug + Send + 'static,
52    D: Send + 'static,
53    C: Send + 'static,
54    E: Clone + Send + 'static,
55    Err: std::error::Error + Send + 'static,
56{
57    pub fn new(persistence_id: impl Into<String>) -> Self {
58        Self {
59            persistence_id: persistence_id.into(),
60            state: None,
61            data: None,
62            next_seq: 0,
63            on_command: None,
64            on_event: None,
65            encode: None,
66            decode: None,
67            transitions: Vec::new(),
68        }
69    }
70
71    pub fn with_initial(mut self, s: S, d: D) -> Self {
72        self.state = Some(s);
73        self.data = Some(d);
74        self
75    }
76
77    pub fn on_command<F>(mut self, f: F) -> Self
78    where
79        F: FnMut(&S, &D, C) -> Result<Vec<E>, Err> + Send + 'static,
80    {
81        self.on_command = Some(Box::new(f));
82        self
83    }
84
85    pub fn on_event<F>(mut self, f: F) -> Self
86    where
87        F: FnMut(&mut S, &mut D, &E) + Send + 'static,
88    {
89        self.on_event = Some(Box::new(f));
90        self
91    }
92
93    pub fn with_codec<EncF, DecF>(mut self, encode: EncF, decode: DecF) -> Self
94    where
95        EncF: Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static,
96        DecF: Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static,
97    {
98        self.encode = Some(Box::new(encode));
99        self.decode = Some(Box::new(decode));
100        self
101    }
102
103    pub fn state(&self) -> Option<&S> {
104        self.state.as_ref()
105    }
106
107    pub fn data(&self) -> Option<&D> {
108        self.data.as_ref()
109    }
110
111    /// History of state transitions seen since boot. Useful for tests.
112    pub fn transitions(&self) -> &[(S, S)] {
113        &self.transitions
114    }
115
116    pub async fn recover<J: Journal>(
117        &mut self,
118        journal: Arc<J>,
119        permitter: &RecoveryPermitter,
120    ) -> Result<u64, EventsourcedError<Err>> {
121        let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
122        let on_event = self
123            .on_event
124            .as_mut()
125            .ok_or_else(|| EventsourcedError::Codec("on_event not registered".into()))?;
126        let decode =
127            self.decode.as_ref().ok_or_else(|| EventsourcedError::Codec("decoder not registered".into()))?;
128        let highest = journal.highest_sequence_nr(&self.persistence_id, 0).await?;
129        let events = journal.replay_messages(&self.persistence_id, 1, highest, u64::MAX).await?;
130        for e in &events {
131            let evt = decode(&e.payload).map_err(EventsourcedError::Codec)?;
132            let prev = self.state.clone();
133            let (s, d) = (
134                self.state.as_mut().expect("initial state required before recover"),
135                self.data.as_mut().expect("initial data required before recover"),
136            );
137            on_event(s, d, &evt);
138            if let (Some(p), Some(now)) = (prev, self.state.as_ref()) {
139                if &p != now {
140                    self.transitions.push((p, now.clone()));
141                }
142            }
143        }
144        self.next_seq = highest;
145        Ok(highest)
146    }
147
148    pub async fn handle<J: Journal>(
149        &mut self,
150        journal: Arc<J>,
151        cmd: C,
152    ) -> Result<(), EventsourcedError<Err>> {
153        let on_cmd = self
154            .on_command
155            .as_mut()
156            .ok_or_else(|| EventsourcedError::Codec("on_command not registered".into()))?;
157        let on_event = self
158            .on_event
159            .as_mut()
160            .ok_or_else(|| EventsourcedError::Codec("on_event not registered".into()))?;
161        let encode =
162            self.encode.as_ref().ok_or_else(|| EventsourcedError::Codec("encoder not registered".into()))?;
163        let s =
164            self.state.as_ref().ok_or_else(|| EventsourcedError::Codec("initial state not set".into()))?;
165        let d = self.data.as_ref().ok_or_else(|| EventsourcedError::Codec("initial data not set".into()))?;
166        let events = on_cmd(s, d, cmd).map_err(EventsourcedError::Domain)?;
167        if events.is_empty() {
168            return Ok(());
169        }
170        let mut reprs = Vec::with_capacity(events.len());
171        for e in &events {
172            self.next_seq += 1;
173            let payload = encode(e).map_err(EventsourcedError::Codec)?;
174            reprs.push(PersistentRepr {
175                persistence_id: self.persistence_id.clone(),
176                sequence_nr: self.next_seq,
177                payload,
178                manifest: "fsm".into(),
179                writer_uuid: "fsm".into(),
180                deleted: false,
181                tags: Vec::new(),
182            });
183        }
184        journal.write_messages(reprs).await?;
185        for e in &events {
186            let prev = self.state.clone();
187            let s_mut = self.state.as_mut().expect("state present");
188            let d_mut = self.data.as_mut().expect("data present");
189            on_event(s_mut, d_mut, e);
190            if let (Some(p), Some(now)) = (prev, self.state.as_ref()) {
191                if &p != now {
192                    self.transitions.push((p, now.clone()));
193                }
194            }
195        }
196        Ok(())
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryJournal;

    /// Two-state door used as the FSM's finite state.
    #[derive(Clone, Debug, PartialEq)]
    enum DoorState {
        Closed,
        Open,
    }

    /// State-data: how many times the door went `Closed -> Open`.
    #[derive(Default)]
    struct DoorData {
        opens: u32,
    }

    #[derive(Clone, Debug)]
    enum DoorCmd {
        Toggle,
    }

    #[derive(Clone, Debug)]
    enum DoorEvent {
        Toggled,
    }

    /// Placeholder domain error; the door never rejects a command.
    #[derive(Debug, thiserror::Error)]
    #[error("dummy")]
    struct E;

    type DoorFsm = PersistentFSM<DoorState, DoorData, DoorCmd, DoorEvent, E>;

    /// Builds a door FSM that persists one `Toggled` event per command and flips the door
    /// on every applied event, counting openings.
    fn make_fsm(id: &str) -> DoorFsm {
        PersistentFSM::new(id)
            .with_initial(DoorState::Closed, DoorData::default())
            .on_command(|_s, _d, _c: DoorCmd| Ok(vec![DoorEvent::Toggled]))
            .on_event(|state, data, _evt: &DoorEvent| {
                *state = match *state {
                    DoorState::Closed => {
                        data.opens += 1;
                        DoorState::Open
                    }
                    DoorState::Open => DoorState::Closed,
                };
            })
            .with_codec(|_| Ok(vec![0u8]), |_| Ok(DoorEvent::Toggled))
    }

    #[tokio::test]
    async fn fsm_transitions_and_recovers() {
        let journal = Arc::new(InMemoryJournal::default());
        let permits = RecoveryPermitter::new(1);

        // Closed -> Open -> Closed -> Open: three transitions, two openings.
        let mut fsm = make_fsm("door-1");
        for _ in 0..3 {
            fsm.handle(Arc::clone(&journal), DoorCmd::Toggle).await.unwrap();
        }
        assert_eq!(fsm.state(), Some(&DoorState::Open));
        assert_eq!(fsm.data().unwrap().opens, 2);
        assert_eq!(fsm.transitions().len(), 3);

        // A fresh instance replaying the same journal must land in the same (S, D).
        let mut replayed = make_fsm("door-1");
        replayed.recover(Arc::clone(&journal), &permits).await.unwrap();
        assert_eq!(replayed.state(), Some(&DoorState::Open));
        assert_eq!(replayed.data().unwrap().opens, 2);
    }

    #[tokio::test]
    async fn missing_initial_state_is_typed_error() {
        let journal = Arc::new(InMemoryJournal::default());
        // `with_initial` is deliberately skipped.
        let mut fsm: DoorFsm = PersistentFSM::new("door-2")
            .on_command(|_, _, _| Ok(vec![DoorEvent::Toggled]))
            .on_event(|_, _, _| {})
            .with_codec(|_| Ok(vec![]), |_| Ok(DoorEvent::Toggled));
        let outcome = fsm.handle(journal, DoorCmd::Toggle).await;
        assert!(matches!(outcome, Err(EventsourcedError::Codec(_))));
    }
}