1use std::sync::Arc;
18
19use async_trait::async_trait;
20use thiserror::Error;
21
22use crate::journal::{Journal, JournalError, PersistentRepr};
23use crate::recovery_permitter::RecoveryPermitter;
24use crate::snapshot::{SnapshotMetadata, SnapshotStore};
25
26#[derive(Debug, Error)]
28#[non_exhaustive]
29pub enum EventsourcedError<DomainErr> {
30 #[error("journal error: {0}")]
31 Journal(#[from] JournalError),
32 #[error("codec error: {0}")]
33 Codec(String),
34 #[error("recovery permit acquire failed")]
35 PermitDenied,
36 #[error(transparent)]
37 Domain(DomainErr),
38}
39
40#[async_trait]
42pub trait Eventsourced: Send + 'static {
43 type Command: Send + 'static;
45 type Event: Send + Clone + 'static;
47 type State: Default + Send + 'static;
49 type Error: std::error::Error + Send + 'static;
51
52 fn persistence_id(&self) -> String;
54
55 fn event_manifest(&self) -> &'static str {
58 "evt"
59 }
60
61 fn command_to_events(
64 &self,
65 state: &Self::State,
66 cmd: Self::Command,
67 ) -> Result<Vec<Self::Event>, Self::Error>;
68
69 fn apply_event(state: &mut Self::State, event: &Self::Event);
73
74 fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String>;
77
78 fn decode_event(bytes: &[u8]) -> Result<Self::Event, String>;
81
82 async fn recovery_completed(&mut self, _state: &Self::State, _highest_seq: u64) {}
84
85 async fn recover<J: Journal>(
91 &mut self,
92 journal: Arc<J>,
93 state: &mut Self::State,
94 permitter: &RecoveryPermitter,
95 ) -> Result<u64, EventsourcedError<Self::Error>> {
96 let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
97 let pid = self.persistence_id();
98 let highest = journal.highest_sequence_nr(&pid, 0).await?;
99 let events = journal.replay_messages(&pid, 1, highest, u64::MAX).await?;
100 for e in &events {
101 let evt = Self::decode_event(&e.payload).map_err(EventsourcedError::Codec)?;
102 Self::apply_event(state, &evt);
103 }
104 drop(_permit);
107 self.recovery_completed(state, highest).await;
108 Ok(highest)
109 }
110
111 async fn handle_command<J: Journal>(
113 &self,
114 journal: Arc<J>,
115 state: &mut Self::State,
116 next_seq: &mut u64,
117 writer_uuid: &str,
118 cmd: Self::Command,
119 ) -> Result<(), EventsourcedError<Self::Error>> {
120 let events = self.command_to_events(state, cmd).map_err(EventsourcedError::Domain)?;
121 if events.is_empty() {
122 return Ok(());
123 }
124 let mut reprs = Vec::with_capacity(events.len());
125 for e in &events {
126 *next_seq += 1;
127 let payload = Self::encode_event(e).map_err(EventsourcedError::Codec)?;
128 reprs.push(PersistentRepr {
129 persistence_id: self.persistence_id(),
130 sequence_nr: *next_seq,
131 payload,
132 manifest: self.event_manifest().to_string(),
133 writer_uuid: writer_uuid.into(),
134 deleted: false,
135 tags: Vec::new(),
136 });
137 }
138 journal.write_messages(reprs).await?;
139 for e in &events {
140 Self::apply_event(state, e);
141 }
142 Ok(())
143 }
144
145 async fn save_snapshot<S: SnapshotStore>(&self, store: Arc<S>, sequence_nr: u64, payload: Vec<u8>) {
147 store
148 .save(
149 SnapshotMetadata { persistence_id: self.persistence_id(), sequence_nr, timestamp: 0 },
150 payload,
151 )
152 .await;
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{InMemoryJournal, Journal};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Running total rebuilt from `CounterEvent`s.
    #[derive(Default, Debug, PartialEq)]
    struct CounterState {
        n: i64,
    }

    /// Single event kind: a signed adjustment to the total.
    #[derive(Clone, Debug)]
    enum CounterEvent {
        Adjusted(i64),
    }

    enum CounterCmd {
        Add(i64),
        Sub(i64),
    }

    #[derive(Debug, Error)]
    enum CounterErr {
        #[error("would underflow below 0")]
        Underflow,
    }

    /// Minimal entity used to exercise persist + recover.
    struct Counter {
        id: String,
    }

    #[async_trait]
    impl Eventsourced for Counter {
        type Command = CounterCmd;
        type Event = CounterEvent;
        type State = CounterState;
        type Error = CounterErr;

        fn persistence_id(&self) -> String {
            self.id.clone()
        }

        /// Rejects any command that would take the total below zero.
        fn command_to_events(
            &self,
            state: &Self::State,
            cmd: Self::Command,
        ) -> Result<Vec<Self::Event>, Self::Error> {
            let change = match cmd {
                CounterCmd::Add(amount) => amount,
                CounterCmd::Sub(amount) => -amount,
            };
            if state.n + change >= 0 {
                Ok(vec![CounterEvent::Adjusted(change)])
            } else {
                Err(CounterErr::Underflow)
            }
        }

        fn apply_event(state: &mut Self::State, event: &Self::Event) {
            let CounterEvent::Adjusted(change) = event;
            state.n += *change;
        }

        /// Little-endian 8-byte encoding of the adjustment.
        fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String> {
            let CounterEvent::Adjusted(change) = event;
            Ok(change.to_le_bytes().to_vec())
        }

        fn decode_event(bytes: &[u8]) -> Result<Self::Event, String> {
            match bytes.len() {
                8 => {
                    let mut raw = [0u8; 8];
                    raw.copy_from_slice(bytes);
                    Ok(CounterEvent::Adjusted(i64::from_le_bytes(raw)))
                }
                len => Err(format!("bad len: {}", len)),
            }
        }
    }

    #[tokio::test]
    async fn happy_path_persist_and_recover() {
        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(2);

        // Write side: three accepted commands, 5 + 3 - 2 = 6.
        let writer = Counter { id: "c-1".into() };
        let mut live = CounterState::default();
        let mut seq = 0u64;
        for cmd in vec![CounterCmd::Add(5), CounterCmd::Add(3), CounterCmd::Sub(2)] {
            writer
                .handle_command(journal.clone(), &mut live, &mut seq, "w", cmd)
                .await
                .unwrap();
        }
        assert_eq!(live.n, 6);
        assert_eq!(seq, 3);
        assert_eq!(journal.highest_sequence_nr("c-1", 0).await.unwrap(), 3);

        // Read side: a fresh instance replays the same history.
        let mut reader = Counter { id: "c-1".into() };
        let mut recovered = CounterState::default();
        let highest = reader.recover(journal.clone(), &mut recovered, &permitter).await.unwrap();
        assert_eq!(highest, 3);
        assert_eq!(recovered.n, 6);
    }

    #[tokio::test]
    async fn domain_error_aborts_persist() {
        let journal = Arc::new(InMemoryJournal::default());
        let counter = Counter { id: "c-2".into() };
        let mut state = CounterState::default();
        let mut seq = 0u64;
        let outcome = counter
            .handle_command(journal.clone(), &mut state, &mut seq, "w", CounterCmd::Sub(5))
            .await;
        assert!(matches!(outcome, Err(EventsourcedError::Domain(CounterErr::Underflow))));
        assert_eq!(seq, 0);
        assert_eq!(journal.highest_sequence_nr("c-2", 0).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn recovery_completed_called_once() {
        /// Entity with unit types whose only job is to count hook invocations.
        struct HookCounter {
            id: String,
            hook_calls: Arc<AtomicU32>,
        }

        #[async_trait]
        impl Eventsourced for HookCounter {
            type Command = ();
            type Event = ();
            type State = ();
            type Error = std::io::Error;

            fn persistence_id(&self) -> String {
                self.id.clone()
            }

            fn command_to_events(&self, _: &(), _: ()) -> Result<Vec<()>, Self::Error> {
                Ok(Vec::new())
            }

            fn apply_event(_: &mut (), _: &()) {}

            fn encode_event(_: &()) -> Result<Vec<u8>, String> {
                Ok(Vec::new())
            }

            fn decode_event(_: &[u8]) -> Result<(), String> {
                Ok(())
            }

            async fn recovery_completed(&mut self, _: &(), _: u64) {
                self.hook_calls.fetch_add(1, Ordering::SeqCst);
            }
        }

        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(1);
        let calls = Arc::new(AtomicU32::new(0));
        let mut entity = HookCounter { id: "h".into(), hook_calls: Arc::clone(&calls) };
        entity.recover(journal, &mut (), &permitter).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}