1use std::sync::Arc;
20
21use crate::eventsourced::EventsourcedError;
22use crate::journal::{Journal, PersistentRepr};
23use crate::recovery_permitter::RecoveryPermitter;
24
/// Command handler: given the current state, data and a command, returns the events to
/// persist, or a domain error (surfaced by the FSM as `EventsourcedError::Domain`).
type CmdFn<S, D, C, E, Err> =
    Box<dyn FnMut(&S, &D, C) -> Result<Vec<E>, Err> + Send + 'static>;

/// Event handler: applies a single event to the mutable state and data.
type EvtFn<S, D, E> = Box<dyn FnMut(&mut S, &mut D, &E) + Send + 'static>;

/// Event encoder; a `String` error is surfaced by the FSM as `EventsourcedError::Codec`.
type EncodeFn<E> = Box<dyn Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static>;

/// Event decoder; a `String` error is surfaced by the FSM as `EventsourcedError::Codec`.
type DecodeFn<E> = Box<dyn Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static>;
29
30pub struct PersistentFSM<S, D, C, E, Err>
31where
32 S: Clone + Send + 'static,
33 D: Send + 'static,
34 C: Send + 'static,
35 E: Clone + Send + 'static,
36 Err: std::error::Error + Send + 'static,
37{
38 persistence_id: String,
39 state: Option<S>,
40 data: Option<D>,
41 next_seq: u64,
42 on_command: Option<CmdFn<S, D, C, E, Err>>,
43 on_event: Option<EvtFn<S, D, E>>,
44 encode: Option<EncodeFn<E>>,
45 decode: Option<DecodeFn<E>>,
46 transitions: Vec<(S, S)>,
47}
48
49impl<S, D, C, E, Err> PersistentFSM<S, D, C, E, Err>
50where
51 S: Clone + PartialEq + std::fmt::Debug + Send + 'static,
52 D: Send + 'static,
53 C: Send + 'static,
54 E: Clone + Send + 'static,
55 Err: std::error::Error + Send + 'static,
56{
57 pub fn new(persistence_id: impl Into<String>) -> Self {
58 Self {
59 persistence_id: persistence_id.into(),
60 state: None,
61 data: None,
62 next_seq: 0,
63 on_command: None,
64 on_event: None,
65 encode: None,
66 decode: None,
67 transitions: Vec::new(),
68 }
69 }
70
71 pub fn with_initial(mut self, s: S, d: D) -> Self {
72 self.state = Some(s);
73 self.data = Some(d);
74 self
75 }
76
77 pub fn on_command<F>(mut self, f: F) -> Self
78 where
79 F: FnMut(&S, &D, C) -> Result<Vec<E>, Err> + Send + 'static,
80 {
81 self.on_command = Some(Box::new(f));
82 self
83 }
84
85 pub fn on_event<F>(mut self, f: F) -> Self
86 where
87 F: FnMut(&mut S, &mut D, &E) + Send + 'static,
88 {
89 self.on_event = Some(Box::new(f));
90 self
91 }
92
93 pub fn with_codec<EncF, DecF>(mut self, encode: EncF, decode: DecF) -> Self
94 where
95 EncF: Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static,
96 DecF: Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static,
97 {
98 self.encode = Some(Box::new(encode));
99 self.decode = Some(Box::new(decode));
100 self
101 }
102
103 pub fn state(&self) -> Option<&S> {
104 self.state.as_ref()
105 }
106
107 pub fn data(&self) -> Option<&D> {
108 self.data.as_ref()
109 }
110
111 pub fn transitions(&self) -> &[(S, S)] {
113 &self.transitions
114 }
115
116 pub async fn recover<J: Journal>(
117 &mut self,
118 journal: Arc<J>,
119 permitter: &RecoveryPermitter,
120 ) -> Result<u64, EventsourcedError<Err>> {
121 let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
122 let on_event = self
123 .on_event
124 .as_mut()
125 .ok_or_else(|| EventsourcedError::Codec("on_event not registered".into()))?;
126 let decode =
127 self.decode.as_ref().ok_or_else(|| EventsourcedError::Codec("decoder not registered".into()))?;
128 let highest = journal.highest_sequence_nr(&self.persistence_id, 0).await?;
129 let events = journal.replay_messages(&self.persistence_id, 1, highest, u64::MAX).await?;
130 for e in &events {
131 let evt = decode(&e.payload).map_err(EventsourcedError::Codec)?;
132 let prev = self.state.clone();
133 let (s, d) = (
134 self.state.as_mut().expect("initial state required before recover"),
135 self.data.as_mut().expect("initial data required before recover"),
136 );
137 on_event(s, d, &evt);
138 if let (Some(p), Some(now)) = (prev, self.state.as_ref()) {
139 if &p != now {
140 self.transitions.push((p, now.clone()));
141 }
142 }
143 }
144 self.next_seq = highest;
145 Ok(highest)
146 }
147
148 pub async fn handle<J: Journal>(
149 &mut self,
150 journal: Arc<J>,
151 cmd: C,
152 ) -> Result<(), EventsourcedError<Err>> {
153 let on_cmd = self
154 .on_command
155 .as_mut()
156 .ok_or_else(|| EventsourcedError::Codec("on_command not registered".into()))?;
157 let on_event = self
158 .on_event
159 .as_mut()
160 .ok_or_else(|| EventsourcedError::Codec("on_event not registered".into()))?;
161 let encode =
162 self.encode.as_ref().ok_or_else(|| EventsourcedError::Codec("encoder not registered".into()))?;
163 let s =
164 self.state.as_ref().ok_or_else(|| EventsourcedError::Codec("initial state not set".into()))?;
165 let d = self.data.as_ref().ok_or_else(|| EventsourcedError::Codec("initial data not set".into()))?;
166 let events = on_cmd(s, d, cmd).map_err(EventsourcedError::Domain)?;
167 if events.is_empty() {
168 return Ok(());
169 }
170 let mut reprs = Vec::with_capacity(events.len());
171 for e in &events {
172 self.next_seq += 1;
173 let payload = encode(e).map_err(EventsourcedError::Codec)?;
174 reprs.push(PersistentRepr {
175 persistence_id: self.persistence_id.clone(),
176 sequence_nr: self.next_seq,
177 payload,
178 manifest: "fsm".into(),
179 writer_uuid: "fsm".into(),
180 deleted: false,
181 tags: Vec::new(),
182 });
183 }
184 journal.write_messages(reprs).await?;
185 for e in &events {
186 let prev = self.state.clone();
187 let s_mut = self.state.as_mut().expect("state present");
188 let d_mut = self.data.as_mut().expect("data present");
189 on_event(s_mut, d_mut, e);
190 if let (Some(p), Some(now)) = (prev, self.state.as_ref()) {
191 if &p != now {
192 self.transitions.push((p, now.clone()));
193 }
194 }
195 }
196 Ok(())
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryJournal;

    /// Two-state door driven by the FSM under test.
    #[derive(Clone, Debug, PartialEq)]
    enum DoorState {
        Closed,
        Open,
    }

    /// Counts how many times the door has been opened.
    #[derive(Default)]
    struct DoorData {
        opens: u32,
    }

    #[derive(Clone, Debug)]
    enum DoorCmd {
        Toggle,
    }

    #[derive(Clone, Debug)]
    enum DoorEvent {
        Toggled,
    }

    /// Placeholder domain error; the door handlers never fail.
    #[derive(Debug, thiserror::Error)]
    #[error("dummy")]
    struct E;

    type DoorFsm = PersistentFSM<DoorState, DoorData, DoorCmd, DoorEvent, E>;

    /// Builds a fully configured door FSM: every command emits one `Toggled` event, which
    /// flips the state and counts transitions into `Open`.
    fn make_fsm(id: &str) -> DoorFsm {
        PersistentFSM::new(id)
            .with_initial(DoorState::Closed, DoorData::default())
            .on_command(|_state, _data, _cmd: DoorCmd| Ok(vec![DoorEvent::Toggled]))
            .on_event(|state, data, _evt: &DoorEvent| {
                *state = match *state {
                    DoorState::Closed => {
                        data.opens += 1;
                        DoorState::Open
                    }
                    DoorState::Open => DoorState::Closed,
                };
            })
            .with_codec(|_| Ok(vec![0u8]), |_| Ok(DoorEvent::Toggled))
    }

    #[tokio::test]
    async fn fsm_transitions_and_recovers() {
        let journal = Arc::new(InMemoryJournal::default());
        let permits = RecoveryPermitter::new(1);

        let mut writer = make_fsm("door-1");
        for _ in 0..3 {
            writer.handle(Arc::clone(&journal), DoorCmd::Toggle).await.unwrap();
        }
        assert_eq!(writer.state(), Some(&DoorState::Open));
        assert_eq!(writer.data().map(|d| d.opens), Some(2));
        assert_eq!(writer.transitions().len(), 3);

        let mut reader = make_fsm("door-1");
        reader.recover(Arc::clone(&journal), &permits).await.unwrap();
        assert_eq!(reader.state(), Some(&DoorState::Open));
        assert_eq!(reader.data().map(|d| d.opens), Some(2));
    }

    #[tokio::test]
    async fn missing_initial_state_is_typed_error() {
        let journal = Arc::new(InMemoryJournal::default());
        let mut fsm: DoorFsm = PersistentFSM::new("door-2")
            .on_command(|_, _, _| Ok(vec![DoorEvent::Toggled]))
            .on_event(|_, _, _| {})
            .with_codec(|_| Ok(vec![]), |_| Ok(DoorEvent::Toggled));
        let outcome = fsm.handle(journal, DoorCmd::Toggle).await;
        assert!(matches!(outcome, Err(EventsourcedError::Codec(_))));
    }
}