Skip to main content

atomr_persistence/
eventsourced.rs

//! `Eventsourced` — the modern command/event/state trait.
//!
//! Improves on the legacy [`crate::PersistentActor`] in three ways:
//!
//! 1. **Typed errors via `thiserror`** — handlers return
//!    `Result<Vec<Event>, Self::Error>` so domain rejections
//!    short-circuit without panicking.
//! 2. **`recovery_completed` lifecycle hook** so actors can warm
//!    caches / register subscriptions once replay is done.
//! 3. **Pluggable codec via the trait** — `encode_event` /
//!    `decode_event` return `Result`, and a configurable manifest
//!    tag (`event_manifest`, default `"evt"`) is written into each
//!    `PersistentRepr.manifest`.
//!
//! `PersistentActor` remains in place for back-compat; new actors
//! should target this trait.
16
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use thiserror::Error;
21
22use crate::journal::{Journal, JournalError, PersistentRepr};
23use crate::recovery_permitter::RecoveryPermitter;
24use crate::snapshot::{SnapshotMetadata, SnapshotStore};
25
26/// Recovery / handler errors that propagate out of [`Eventsourced`].
27#[derive(Debug, Error)]
28#[non_exhaustive]
29pub enum EventsourcedError<DomainErr> {
30    #[error("journal error: {0}")]
31    Journal(#[from] JournalError),
32    #[error("codec error: {0}")]
33    Codec(String),
34    #[error("recovery permit acquire failed")]
35    PermitDenied,
36    #[error(transparent)]
37    Domain(DomainErr),
38}
39
40/// Modern event-sourced actor.
41#[async_trait]
42pub trait Eventsourced: Send + 'static {
43    /// User commands received via `handle_command`.
44    type Command: Send + 'static;
45    /// Persisted events derived from commands by `command_to_events`.
46    type Event: Send + Clone + 'static;
47    /// In-memory state mutated by `apply_event`.
48    type State: Default + Send + 'static;
49    /// Domain-level errors a command handler can return.
50    type Error: std::error::Error + Send + 'static;
51
52    /// Stable journal key for this actor instance.
53    fn persistence_id(&self) -> String;
54
55    /// Manifest tag baked into each `PersistentRepr` so cross-version
56    /// replay can dispatch to the right decoder. Defaults to `"evt"`.
57    fn event_manifest(&self) -> &'static str {
58        "evt"
59    }
60
61    /// Pure projection of a command into 0..N events. Validation /
62    /// rejection lives here (`Err(_)` aborts the persist).
63    fn command_to_events(
64        &self,
65        state: &Self::State,
66        cmd: Self::Command,
67    ) -> Result<Vec<Self::Event>, Self::Error>;
68
69    /// Apply a persisted event to in-memory state. Called both during
70    /// recovery (per replayed event) and during normal operation
71    /// (after each successful persist).
72    fn apply_event(state: &mut Self::State, event: &Self::Event);
73
74    /// Encode an event for the journal. Errors short-circuit the
75    /// persist with [`EventsourcedError::Codec`].
76    fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String>;
77
78    /// Decode an event from journal bytes. Symmetric with
79    /// [`Self::encode_event`].
80    fn decode_event(bytes: &[u8]) -> Result<Self::Event, String>;
81
82    /// Lifecycle hook fired after recovery completes. Default no-op.
83    async fn recovery_completed(&mut self, _state: &Self::State, _highest_seq: u64) {}
84
85    // ---- Driver methods (default implementations) -----------------
86
87    /// Replay the journal under a [`RecoveryPermitter`], applying
88    /// each event to `state`. Returns the highest replayed
89    /// sequence number.
90    async fn recover<J: Journal>(
91        &mut self,
92        journal: Arc<J>,
93        state: &mut Self::State,
94        permitter: &RecoveryPermitter,
95    ) -> Result<u64, EventsourcedError<Self::Error>> {
96        let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
97        let pid = self.persistence_id();
98        let highest = journal.highest_sequence_nr(&pid, 0).await?;
99        let events = journal.replay_messages(&pid, 1, highest, u64::MAX).await?;
100        for e in &events {
101            let evt = Self::decode_event(&e.payload).map_err(EventsourcedError::Codec)?;
102            Self::apply_event(state, &evt);
103        }
104        // Permit dropped here, freeing a slot for the next recovering
105        // actor before we run the user-facing hook.
106        drop(_permit);
107        self.recovery_completed(state, highest).await;
108        Ok(highest)
109    }
110
111    /// Run a single command — derive events, persist, apply.
112    async fn handle_command<J: Journal>(
113        &self,
114        journal: Arc<J>,
115        state: &mut Self::State,
116        next_seq: &mut u64,
117        writer_uuid: &str,
118        cmd: Self::Command,
119    ) -> Result<(), EventsourcedError<Self::Error>> {
120        let events = self.command_to_events(state, cmd).map_err(EventsourcedError::Domain)?;
121        if events.is_empty() {
122            return Ok(());
123        }
124        let mut reprs = Vec::with_capacity(events.len());
125        for e in &events {
126            *next_seq += 1;
127            let payload = Self::encode_event(e).map_err(EventsourcedError::Codec)?;
128            reprs.push(PersistentRepr {
129                persistence_id: self.persistence_id(),
130                sequence_nr: *next_seq,
131                payload,
132                manifest: self.event_manifest().to_string(),
133                writer_uuid: writer_uuid.into(),
134                deleted: false,
135                tags: Vec::new(),
136            });
137        }
138        journal.write_messages(reprs).await?;
139        for e in &events {
140            Self::apply_event(state, e);
141        }
142        Ok(())
143    }
144
145    /// Save a snapshot of current state under `sequence_nr`.
146    async fn save_snapshot<S: SnapshotStore>(&self, store: Arc<S>, sequence_nr: u64, payload: Vec<u8>) {
147        store
148            .save(
149                SnapshotMetadata { persistence_id: self.persistence_id(), sequence_nr, timestamp: 0 },
150                payload,
151            )
152            .await;
153    }
154}
155
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;
    use crate::{InMemoryJournal, Journal};

    /// Running total rebuilt from `Adjusted` deltas.
    #[derive(Default, Debug, PartialEq)]
    struct CounterState {
        n: i64,
    }

    /// Single event kind: a signed delta applied to the total.
    #[derive(Clone, Debug)]
    enum CounterEvent {
        Adjusted(i64),
    }

    enum CounterCmd {
        Add(i64),
        Sub(i64),
    }

    #[derive(Debug, Error)]
    enum CounterErr {
        #[error("would underflow below 0")]
        Underflow,
    }

    struct Counter {
        id: String,
    }

    impl Counter {
        /// Test helper: build a counter keyed by `id`.
        fn named(id: &str) -> Self {
            Counter { id: id.to_string() }
        }
    }

    #[async_trait]
    impl Eventsourced for Counter {
        type Command = CounterCmd;
        type Event = CounterEvent;
        type State = CounterState;
        type Error = CounterErr;

        fn persistence_id(&self) -> String {
            self.id.clone()
        }

        fn command_to_events(
            &self,
            state: &Self::State,
            cmd: Self::Command,
        ) -> Result<Vec<Self::Event>, Self::Error> {
            let delta = match cmd {
                CounterCmd::Add(amount) => amount,
                CounterCmd::Sub(amount) => -amount,
            };
            // Accept only deltas that keep the running total non-negative.
            if state.n + delta >= 0 {
                Ok(vec![CounterEvent::Adjusted(delta)])
            } else {
                Err(CounterErr::Underflow)
            }
        }

        fn apply_event(state: &mut Self::State, event: &Self::Event) {
            let CounterEvent::Adjusted(delta) = event;
            state.n += delta;
        }

        fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String> {
            let CounterEvent::Adjusted(delta) = event;
            Ok(delta.to_le_bytes().to_vec())
        }

        fn decode_event(bytes: &[u8]) -> Result<Self::Event, String> {
            match bytes.len() {
                8 => {
                    let mut raw = [0u8; 8];
                    raw.copy_from_slice(bytes);
                    Ok(CounterEvent::Adjusted(i64::from_le_bytes(raw)))
                }
                other => Err(format!("bad len: {}", other)),
            }
        }
    }

    #[tokio::test]
    async fn happy_path_persist_and_recover() {
        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(2);

        // First incarnation: persist three commands.
        let writer = Counter::named("c-1");
        let mut live = CounterState::default();
        let mut seq = 0u64;
        for cmd in [CounterCmd::Add(5), CounterCmd::Add(3), CounterCmd::Sub(2)] {
            writer
                .handle_command(Arc::clone(&journal), &mut live, &mut seq, "w", cmd)
                .await
                .unwrap();
        }
        assert_eq!(live.n, 6);
        assert_eq!(seq, 3);
        assert_eq!(journal.highest_sequence_nr("c-1", 0).await.unwrap(), 3);

        // Second incarnation: replay must rebuild the same state.
        let mut reader = Counter::named("c-1");
        let mut replayed = CounterState::default();
        let highest = reader.recover(Arc::clone(&journal), &mut replayed, &permitter).await.unwrap();
        assert_eq!(highest, 3);
        assert_eq!(replayed.n, 6);
    }

    #[tokio::test]
    async fn domain_error_aborts_persist() {
        let journal = Arc::new(InMemoryJournal::default());
        let counter = Counter::named("c-2");
        let mut state = CounterState::default();
        let mut seq = 0u64;

        let outcome = counter
            .handle_command(Arc::clone(&journal), &mut state, &mut seq, "w", CounterCmd::Sub(5))
            .await;

        assert!(matches!(outcome, Err(EventsourcedError::Domain(CounterErr::Underflow))));
        // Neither the sequence counter nor the journal may have moved.
        assert_eq!(seq, 0);
        assert_eq!(journal.highest_sequence_nr("c-2", 0).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn recovery_completed_called_once() {
        struct HookCounter {
            id: String,
            hook_calls: Arc<AtomicU32>,
        }

        #[async_trait]
        impl Eventsourced for HookCounter {
            type Command = ();
            type Event = ();
            type State = ();
            type Error = std::io::Error;

            fn persistence_id(&self) -> String {
                self.id.clone()
            }
            fn command_to_events(&self, _: &(), _: ()) -> Result<Vec<()>, Self::Error> {
                Ok(Vec::new())
            }
            fn apply_event(_: &mut (), _: &()) {}
            fn encode_event(_: &()) -> Result<Vec<u8>, String> {
                Ok(Vec::new())
            }
            fn decode_event(_: &[u8]) -> Result<(), String> {
                Ok(())
            }
            async fn recovery_completed(&mut self, _: &(), _: u64) {
                self.hook_calls.fetch_add(1, Ordering::SeqCst);
            }
        }

        let calls = Arc::new(AtomicU32::new(0));
        let mut hooked = HookCounter { id: "h".into(), hook_calls: Arc::clone(&calls) };
        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(1);

        hooked.recover(journal, &mut (), &permitter).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}