1use std::sync::Arc;
21
22use crate::eventsourced::EventsourcedError;
23use crate::journal::{Journal, PersistentRepr};
24use crate::recovery_permitter::RecoveryPermitter;
25
/// Command handler: reads the current state and data and returns the events
/// to persist, or a domain error.
type CmdFn<S, D, C, E, Err> =
    Box<dyn FnMut(&S, &D, C) -> Result<Vec<E>, Err> + Send + 'static>;

/// Event handler: applies a single event to the state and data through
/// mutable references.
type EvtFn<S, D, E> = Box<dyn FnMut(&mut S, &mut D, &E) + Send + 'static>;

/// Turns an event into journal payload bytes.
type EncodeFn<E> = Box<dyn Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static>;

/// Turns journal payload bytes back into an event.
type DecodeFn<E> = Box<dyn Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static>;
30
31pub struct PersistentFSM<S, D, C, E, Err>
32where
33 S: Clone + Send + 'static,
34 D: Send + 'static,
35 C: Send + 'static,
36 E: Clone + Send + 'static,
37 Err: std::error::Error + Send + 'static,
38{
39 persistence_id: String,
40 state: Option<S>,
41 data: Option<D>,
42 next_seq: u64,
43 on_command: Option<CmdFn<S, D, C, E, Err>>,
44 on_event: Option<EvtFn<S, D, E>>,
45 encode: Option<EncodeFn<E>>,
46 decode: Option<DecodeFn<E>>,
47 transitions: Vec<(S, S)>,
48}
49
50impl<S, D, C, E, Err> PersistentFSM<S, D, C, E, Err>
51where
52 S: Clone + PartialEq + std::fmt::Debug + Send + 'static,
53 D: Send + 'static,
54 C: Send + 'static,
55 E: Clone + Send + 'static,
56 Err: std::error::Error + Send + 'static,
57{
58 pub fn new(persistence_id: impl Into<String>) -> Self {
59 Self {
60 persistence_id: persistence_id.into(),
61 state: None,
62 data: None,
63 next_seq: 0,
64 on_command: None,
65 on_event: None,
66 encode: None,
67 decode: None,
68 transitions: Vec::new(),
69 }
70 }
71
72 pub fn with_initial(mut self, s: S, d: D) -> Self {
73 self.state = Some(s);
74 self.data = Some(d);
75 self
76 }
77
78 pub fn on_command<F>(mut self, f: F) -> Self
79 where
80 F: FnMut(&S, &D, C) -> Result<Vec<E>, Err> + Send + 'static,
81 {
82 self.on_command = Some(Box::new(f));
83 self
84 }
85
86 pub fn on_event<F>(mut self, f: F) -> Self
87 where
88 F: FnMut(&mut S, &mut D, &E) + Send + 'static,
89 {
90 self.on_event = Some(Box::new(f));
91 self
92 }
93
94 pub fn with_codec<EncF, DecF>(mut self, encode: EncF, decode: DecF) -> Self
95 where
96 EncF: Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static,
97 DecF: Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static,
98 {
99 self.encode = Some(Box::new(encode));
100 self.decode = Some(Box::new(decode));
101 self
102 }
103
104 pub fn state(&self) -> Option<&S> {
105 self.state.as_ref()
106 }
107
108 pub fn data(&self) -> Option<&D> {
109 self.data.as_ref()
110 }
111
112 pub fn transitions(&self) -> &[(S, S)] {
114 &self.transitions
115 }
116
117 pub async fn recover<J: Journal>(
118 &mut self,
119 journal: Arc<J>,
120 permitter: &RecoveryPermitter,
121 ) -> Result<u64, EventsourcedError<Err>> {
122 let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
123 let on_event = self
124 .on_event
125 .as_mut()
126 .ok_or_else(|| EventsourcedError::Codec("on_event not registered".into()))?;
127 let decode =
128 self.decode.as_ref().ok_or_else(|| EventsourcedError::Codec("decoder not registered".into()))?;
129 let highest = journal.highest_sequence_nr(&self.persistence_id, 0).await?;
130 let events = journal.replay_messages(&self.persistence_id, 1, highest, u64::MAX).await?;
131 for e in &events {
132 let evt = decode(&e.payload).map_err(EventsourcedError::Codec)?;
133 let prev = self.state.clone();
134 let (s, d) = (
135 self.state.as_mut().expect("initial state required before recover"),
136 self.data.as_mut().expect("initial data required before recover"),
137 );
138 on_event(s, d, &evt);
139 if let (Some(p), Some(now)) = (prev, self.state.as_ref()) {
140 if &p != now {
141 self.transitions.push((p, now.clone()));
142 }
143 }
144 }
145 self.next_seq = highest;
146 Ok(highest)
147 }
148
149 pub async fn handle<J: Journal>(
150 &mut self,
151 journal: Arc<J>,
152 cmd: C,
153 ) -> Result<(), EventsourcedError<Err>> {
154 let on_cmd = self
155 .on_command
156 .as_mut()
157 .ok_or_else(|| EventsourcedError::Codec("on_command not registered".into()))?;
158 let on_event = self
159 .on_event
160 .as_mut()
161 .ok_or_else(|| EventsourcedError::Codec("on_event not registered".into()))?;
162 let encode =
163 self.encode.as_ref().ok_or_else(|| EventsourcedError::Codec("encoder not registered".into()))?;
164 let s =
165 self.state.as_ref().ok_or_else(|| EventsourcedError::Codec("initial state not set".into()))?;
166 let d = self.data.as_ref().ok_or_else(|| EventsourcedError::Codec("initial data not set".into()))?;
167 let events = on_cmd(s, d, cmd).map_err(EventsourcedError::Domain)?;
168 if events.is_empty() {
169 return Ok(());
170 }
171 let mut reprs = Vec::with_capacity(events.len());
172 for e in &events {
173 self.next_seq += 1;
174 let payload = encode(e).map_err(EventsourcedError::Codec)?;
175 reprs.push(PersistentRepr {
176 persistence_id: self.persistence_id.clone(),
177 sequence_nr: self.next_seq,
178 payload,
179 manifest: "fsm".into(),
180 writer_uuid: "fsm".into(),
181 deleted: false,
182 tags: Vec::new(),
183 });
184 }
185 journal.write_messages(reprs).await?;
186 for e in &events {
187 let prev = self.state.clone();
188 let s_mut = self.state.as_mut().expect("state present");
189 let d_mut = self.data.as_mut().expect("data present");
190 on_event(s_mut, d_mut, e);
191 if let (Some(p), Some(now)) = (prev, self.state.as_ref()) {
192 if &p != now {
193 self.transitions.push((p, now.clone()));
194 }
195 }
196 }
197 Ok(())
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryJournal;

    /// States of a simple door.
    #[derive(Clone, Debug, PartialEq)]
    enum DoorState {
        Closed,
        Open,
    }

    /// Counts how many times the door has been opened.
    #[derive(Default)]
    struct DoorData {
        opens: u32,
    }

    /// The only command: flip the door.
    #[derive(Clone, Debug)]
    enum DoorCmd {
        Toggle,
    }

    /// The only event: the door was flipped.
    #[derive(Clone, Debug)]
    enum DoorEvent {
        Toggled,
    }

    /// Placeholder domain error; the door handler never rejects a command.
    #[derive(Debug, thiserror::Error)]
    #[error("dummy")]
    struct E;

    /// Builds a door FSM that starts `Closed` and emits `Toggled` for every
    /// command. Each event flips the door between `Closed` and `Open` and counts
    /// openings.
    fn make_fsm(id: &str) -> PersistentFSM<DoorState, DoorData, DoorCmd, DoorEvent, E> {
        PersistentFSM::new(id)
            .with_initial(DoorState::Closed, DoorData::default())
            .on_command(|_state, _data, _cmd: DoorCmd| Ok(vec![DoorEvent::Toggled]))
            .on_event(|state: &mut DoorState, data: &mut DoorData, _evt: &DoorEvent| {
                *state = match *state {
                    DoorState::Closed => {
                        data.opens += 1;
                        DoorState::Open
                    }
                    DoorState::Open => DoorState::Closed,
                };
            })
            .with_codec(|_| Ok(vec![0u8]), |_| Ok(DoorEvent::Toggled))
    }

    #[tokio::test]
    async fn fsm_transitions_and_recovers() {
        let journal = Arc::new(InMemoryJournal::default());
        let permits = RecoveryPermitter::new(1);

        // Closed -> Open -> Closed -> Open: three transitions, two openings.
        let mut live = make_fsm("door-1");
        for _ in 0..3 {
            live.handle(Arc::clone(&journal), DoorCmd::Toggle).await.unwrap();
        }
        assert_eq!(live.state(), Some(&DoorState::Open));
        assert_eq!(live.data().map(|d| d.opens), Some(2));
        assert_eq!(live.transitions().len(), 3);

        // A fresh instance with the same id must reach the same state by replay.
        let mut replayed = make_fsm("door-1");
        replayed.recover(Arc::clone(&journal), &permits).await.unwrap();
        assert_eq!(replayed.state(), Some(&DoorState::Open));
        assert_eq!(replayed.data().map(|d| d.opens), Some(2));
    }

    #[tokio::test]
    async fn missing_initial_state_is_typed_error() {
        let journal = Arc::new(InMemoryJournal::default());
        let mut fsm: PersistentFSM<DoorState, DoorData, DoorCmd, DoorEvent, E> =
            PersistentFSM::new("door-2")
                .on_command(|_, _, _| Ok(vec![DoorEvent::Toggled]))
                .on_event(|_, _, _| {})
                .with_codec(|_| Ok(vec![]), |_| Ok(DoorEvent::Toggled));
        let err = fsm
            .handle(journal, DoorCmd::Toggle)
            .await
            .expect_err("handle must fail without an initial state");
        assert!(matches!(err, EventsourcedError::Codec(_)));
    }
}