1use std::sync::Arc;
20
21use async_trait::async_trait;
22use thiserror::Error;
23
24use crate::journal::{Journal, JournalError, PersistentRepr};
25use crate::recovery_permitter::RecoveryPermitter;
26use crate::snapshot::{SnapshotMetadata, SnapshotStore};
27
28#[derive(Debug, Error)]
30#[non_exhaustive]
31pub enum EventsourcedError<DomainErr> {
32 #[error("journal error: {0}")]
33 Journal(#[from] JournalError),
34 #[error("codec error: {0}")]
35 Codec(String),
36 #[error("recovery permit acquire failed")]
37 PermitDenied,
38 #[error(transparent)]
39 Domain(DomainErr),
40}
41
42#[async_trait]
44pub trait Eventsourced: Send + 'static {
45 type Command: Send + 'static;
47 type Event: Send + Clone + 'static;
49 type State: Default + Send + 'static;
51 type Error: std::error::Error + Send + 'static;
53
54 fn persistence_id(&self) -> String;
56
57 fn event_manifest(&self) -> &'static str {
60 "evt"
61 }
62
63 fn command_to_events(
66 &self,
67 state: &Self::State,
68 cmd: Self::Command,
69 ) -> Result<Vec<Self::Event>, Self::Error>;
70
71 fn apply_event(state: &mut Self::State, event: &Self::Event);
75
76 fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String>;
79
80 fn decode_event(bytes: &[u8]) -> Result<Self::Event, String>;
83
84 async fn recovery_completed(&mut self, _state: &Self::State, _highest_seq: u64) {}
86
87 async fn recover<J: Journal>(
93 &mut self,
94 journal: Arc<J>,
95 state: &mut Self::State,
96 permitter: &RecoveryPermitter,
97 ) -> Result<u64, EventsourcedError<Self::Error>> {
98 let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
99 let pid = self.persistence_id();
100 let highest = journal.highest_sequence_nr(&pid, 0).await?;
101 let events = journal.replay_messages(&pid, 1, highest, u64::MAX).await?;
102 for e in &events {
103 let evt = Self::decode_event(&e.payload).map_err(EventsourcedError::Codec)?;
104 Self::apply_event(state, &evt);
105 }
106 drop(_permit);
109 self.recovery_completed(state, highest).await;
110 Ok(highest)
111 }
112
113 async fn handle_command<J: Journal>(
115 &self,
116 journal: Arc<J>,
117 state: &mut Self::State,
118 next_seq: &mut u64,
119 writer_uuid: &str,
120 cmd: Self::Command,
121 ) -> Result<(), EventsourcedError<Self::Error>> {
122 let events = self.command_to_events(state, cmd).map_err(EventsourcedError::Domain)?;
123 if events.is_empty() {
124 return Ok(());
125 }
126 let mut reprs = Vec::with_capacity(events.len());
127 for e in &events {
128 *next_seq += 1;
129 let payload = Self::encode_event(e).map_err(EventsourcedError::Codec)?;
130 reprs.push(PersistentRepr {
131 persistence_id: self.persistence_id(),
132 sequence_nr: *next_seq,
133 payload,
134 manifest: self.event_manifest().to_string(),
135 writer_uuid: writer_uuid.into(),
136 deleted: false,
137 tags: Vec::new(),
138 });
139 }
140 journal.write_messages(reprs).await?;
141 for e in &events {
142 Self::apply_event(state, e);
143 }
144 Ok(())
145 }
146
147 async fn save_snapshot<S: SnapshotStore>(&self, store: Arc<S>, sequence_nr: u64, payload: Vec<u8>) {
149 store
150 .save(
151 SnapshotMetadata { persistence_id: self.persistence_id(), sequence_nr, timestamp: 0 },
152 payload,
153 )
154 .await;
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{InMemoryJournal, Journal};

    #[derive(Default, Debug, PartialEq)]
    struct CounterState {
        n: i64,
    }

    #[derive(Clone, Debug)]
    enum CounterEvent {
        Adjusted(i64),
    }

    enum CounterCmd {
        Add(i64),
        Sub(i64),
    }

    #[derive(Debug, Error)]
    enum CounterErr {
        #[error("would underflow below 0")]
        Underflow,
    }

    /// Test aggregate: a counter that may never go below zero, persisted as
    /// signed deltas encoded as 8 little-endian bytes.
    struct Counter {
        id: String,
    }

    #[async_trait]
    impl Eventsourced for Counter {
        type Command = CounterCmd;
        type Event = CounterEvent;
        type State = CounterState;
        type Error = CounterErr;

        fn persistence_id(&self) -> String {
            self.id.clone()
        }

        fn command_to_events(
            &self,
            state: &Self::State,
            cmd: Self::Command,
        ) -> Result<Vec<Self::Event>, Self::Error> {
            let delta = match cmd {
                CounterCmd::Add(n) => n,
                CounterCmd::Sub(n) => -n,
            };
            if state.n + delta < 0 {
                return Err(CounterErr::Underflow);
            }
            Ok(vec![CounterEvent::Adjusted(delta)])
        }

        fn apply_event(state: &mut Self::State, event: &Self::Event) {
            match event {
                CounterEvent::Adjusted(d) => state.n += d,
            }
        }

        fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String> {
            match event {
                CounterEvent::Adjusted(d) => Ok(d.to_le_bytes().to_vec()),
            }
        }

        fn decode_event(bytes: &[u8]) -> Result<Self::Event, String> {
            if bytes.len() != 8 {
                return Err(format!("bad len: {}", bytes.len()));
            }
            let mut buf = [0u8; 8];
            buf.copy_from_slice(bytes);
            Ok(CounterEvent::Adjusted(i64::from_le_bytes(buf)))
        }
    }

    #[tokio::test]
    async fn happy_path_persist_and_recover() {
        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(2);

        // Write side: three accepted commands, one event each.
        let c = Counter { id: "c-1".into() };
        let mut state = CounterState::default();
        let mut seq = 0u64;
        c.handle_command(journal.clone(), &mut state, &mut seq, "w", CounterCmd::Add(5)).await.unwrap();
        c.handle_command(journal.clone(), &mut state, &mut seq, "w", CounterCmd::Add(3)).await.unwrap();
        c.handle_command(journal.clone(), &mut state, &mut seq, "w", CounterCmd::Sub(2)).await.unwrap();
        assert_eq!(state.n, 6);
        assert_eq!(seq, 3);
        let highest = journal.highest_sequence_nr("c-1", 0).await.unwrap();
        assert_eq!(highest, 3);

        // Read side: a fresh instance replays to the same state.
        let mut c2 = Counter { id: "c-1".into() };
        let mut state2 = CounterState::default();
        let h = c2.recover(journal.clone(), &mut state2, &permitter).await.unwrap();
        assert_eq!(h, 3);
        assert_eq!(state2.n, 6);
        assert_eq!(state2, state);
    }

    #[tokio::test]
    async fn domain_error_aborts_persist() {
        let journal = Arc::new(InMemoryJournal::default());
        let c = Counter { id: "c-2".into() };
        let mut state = CounterState::default();
        let mut seq = 0u64;
        let r = c.handle_command(journal.clone(), &mut state, &mut seq, "w", CounterCmd::Sub(5)).await;
        assert!(matches!(r, Err(EventsourcedError::Domain(CounterErr::Underflow))));
        assert_eq!(seq, 0);
        assert_eq!(state, CounterState::default());
        assert_eq!(journal.highest_sequence_nr("c-2", 0).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn corrupt_payload_fails_recovery_with_codec_error() {
        let journal = Arc::new(InMemoryJournal::default());
        // A 3-byte payload cannot be decoded by `Counter::decode_event`.
        journal
            .write_messages(vec![PersistentRepr {
                persistence_id: "c-3".into(),
                sequence_nr: 1,
                payload: vec![1, 2, 3],
                manifest: "evt".into(),
                writer_uuid: "w".into(),
                deleted: false,
                tags: Vec::new(),
            }])
            .await
            .unwrap();

        let permitter = RecoveryPermitter::new(1);
        let mut c = Counter { id: "c-3".into() };
        let mut state = CounterState::default();
        let r = c.recover(journal, &mut state, &permitter).await;
        assert!(matches!(r, Err(EventsourcedError::Codec(ref msg)) if msg == "bad len: 3"));
        assert_eq!(state, CounterState::default());
    }

    #[tokio::test]
    async fn recovery_completed_called_once() {
        struct HookCounter {
            id: String,
            hook_calls: Arc<std::sync::atomic::AtomicU32>,
        }
        #[async_trait]
        impl Eventsourced for HookCounter {
            type Command = ();
            type Event = ();
            type State = ();
            type Error = std::io::Error;
            fn persistence_id(&self) -> String {
                self.id.clone()
            }
            fn command_to_events(&self, _: &(), _: ()) -> Result<Vec<()>, Self::Error> {
                Ok(vec![])
            }
            fn apply_event(_: &mut (), _: &()) {}
            fn encode_event(_: &()) -> Result<Vec<u8>, String> {
                Ok(vec![])
            }
            fn decode_event(_: &[u8]) -> Result<(), String> {
                Ok(())
            }
            async fn recovery_completed(&mut self, _: &(), _: u64) {
                self.hook_calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            }
        }
        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(1);
        let calls = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let mut a = HookCounter { id: "h".into(), hook_calls: calls.clone() };
        let highest = a.recover(journal, &mut (), &permitter).await.unwrap();
        assert_eq!(highest, 0);
        assert_eq!(calls.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
}
309}