Skip to main content

atomr_persistence/
eventsourced.rs

1//! `Eventsourced` — the modern command/event/state trait.
2//!
3//! Phase 11 of `docs/full-port-plan.md`. Akka.NET parity:
4//! `Akka.Persistence.Eventsourced`. Improves on the legacy
5//! [`crate::PersistentActor`] in three ways:
6//!
7//! 1. **Typed errors via `thiserror`** — handlers return
8//!    `Result<Vec<Event>, Self::Error>` so domain rejections
9//!    short-circuit without panicking.
10//! 2. **`recovery_completed` lifecycle hook** so actors can warm
11//!    caches / register subscriptions once replay is done.
12//! 3. **Pluggable codec via the trait** — `encode_event` /
13//!    `decode_event` return `Result` and use a configurable codec
14//!    name baked into each `PersistentRepr.manifest`.
15//!
16//! `PersistentActor` remains in place for back-compat; new actors
17//! should target this trait.
18
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use thiserror::Error;
23
24use crate::journal::{Journal, JournalError, PersistentRepr};
25use crate::recovery_permitter::RecoveryPermitter;
26use crate::snapshot::{SnapshotMetadata, SnapshotStore};
27
28/// Recovery / handler errors that propagate out of [`Eventsourced`].
29#[derive(Debug, Error)]
30#[non_exhaustive]
31pub enum EventsourcedError<DomainErr> {
32    #[error("journal error: {0}")]
33    Journal(#[from] JournalError),
34    #[error("codec error: {0}")]
35    Codec(String),
36    #[error("recovery permit acquire failed")]
37    PermitDenied,
38    #[error(transparent)]
39    Domain(DomainErr),
40}
41
42/// Modern event-sourced actor.
43#[async_trait]
44pub trait Eventsourced: Send + 'static {
45    /// User commands received via `handle_command`.
46    type Command: Send + 'static;
47    /// Persisted events derived from commands by `command_to_events`.
48    type Event: Send + Clone + 'static;
49    /// In-memory state mutated by `apply_event`.
50    type State: Default + Send + 'static;
51    /// Domain-level errors a command handler can return.
52    type Error: std::error::Error + Send + 'static;
53
54    /// Stable journal key for this actor instance.
55    fn persistence_id(&self) -> String;
56
57    /// Manifest tag baked into each `PersistentRepr` so cross-version
58    /// replay can dispatch to the right decoder. Defaults to `"evt"`.
59    fn event_manifest(&self) -> &'static str {
60        "evt"
61    }
62
63    /// Pure projection of a command into 0..N events. Validation /
64    /// rejection lives here (`Err(_)` aborts the persist).
65    fn command_to_events(
66        &self,
67        state: &Self::State,
68        cmd: Self::Command,
69    ) -> Result<Vec<Self::Event>, Self::Error>;
70
71    /// Apply a persisted event to in-memory state. Called both during
72    /// recovery (per replayed event) and during normal operation
73    /// (after each successful persist).
74    fn apply_event(state: &mut Self::State, event: &Self::Event);
75
76    /// Encode an event for the journal. Errors short-circuit the
77    /// persist with [`EventsourcedError::Codec`].
78    fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String>;
79
80    /// Decode an event from journal bytes. Symmetric with
81    /// [`Self::encode_event`].
82    fn decode_event(bytes: &[u8]) -> Result<Self::Event, String>;
83
84    /// Lifecycle hook fired after recovery completes. Default no-op.
85    async fn recovery_completed(&mut self, _state: &Self::State, _highest_seq: u64) {}
86
87    // ---- Driver methods (default implementations) -----------------
88
89    /// Replay the journal under a [`RecoveryPermitter`], applying
90    /// each event to `state`. Returns the highest replayed
91    /// sequence number.
92    async fn recover<J: Journal>(
93        &mut self,
94        journal: Arc<J>,
95        state: &mut Self::State,
96        permitter: &RecoveryPermitter,
97    ) -> Result<u64, EventsourcedError<Self::Error>> {
98        let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
99        let pid = self.persistence_id();
100        let highest = journal.highest_sequence_nr(&pid, 0).await?;
101        let events = journal.replay_messages(&pid, 1, highest, u64::MAX).await?;
102        for e in &events {
103            let evt = Self::decode_event(&e.payload).map_err(EventsourcedError::Codec)?;
104            Self::apply_event(state, &evt);
105        }
106        // Permit dropped here, freeing a slot for the next recovering
107        // actor before we run the user-facing hook.
108        drop(_permit);
109        self.recovery_completed(state, highest).await;
110        Ok(highest)
111    }
112
113    /// Run a single command — derive events, persist, apply.
114    async fn handle_command<J: Journal>(
115        &self,
116        journal: Arc<J>,
117        state: &mut Self::State,
118        next_seq: &mut u64,
119        writer_uuid: &str,
120        cmd: Self::Command,
121    ) -> Result<(), EventsourcedError<Self::Error>> {
122        let events = self.command_to_events(state, cmd).map_err(EventsourcedError::Domain)?;
123        if events.is_empty() {
124            return Ok(());
125        }
126        let mut reprs = Vec::with_capacity(events.len());
127        for e in &events {
128            *next_seq += 1;
129            let payload = Self::encode_event(e).map_err(EventsourcedError::Codec)?;
130            reprs.push(PersistentRepr {
131                persistence_id: self.persistence_id(),
132                sequence_nr: *next_seq,
133                payload,
134                manifest: self.event_manifest().to_string(),
135                writer_uuid: writer_uuid.into(),
136                deleted: false,
137                tags: Vec::new(),
138            });
139        }
140        journal.write_messages(reprs).await?;
141        for e in &events {
142            Self::apply_event(state, e);
143        }
144        Ok(())
145    }
146
147    /// Save a snapshot of current state under `sequence_nr`.
148    async fn save_snapshot<S: SnapshotStore>(&self, store: Arc<S>, sequence_nr: u64, payload: Vec<u8>) {
149        store
150            .save(
151                SnapshotMetadata { persistence_id: self.persistence_id(), sequence_nr, timestamp: 0 },
152                payload,
153            )
154            .await;
155    }
156}
157
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;
    use crate::{InMemoryJournal, Journal};

    /// Running total kept by the test counter actor.
    #[derive(Default, Debug, PartialEq)]
    struct CounterState {
        n: i64,
    }

    /// The only event kind: a signed delta applied to the total.
    #[derive(Clone, Debug)]
    enum CounterEvent {
        Adjusted(i64),
    }

    /// Commands accepted by the counter.
    enum CounterCmd {
        Add(i64),
        Sub(i64),
    }

    /// Domain rejection: the total may never drop below zero.
    #[derive(Debug, Error)]
    enum CounterErr {
        #[error("would underflow below 0")]
        Underflow,
    }

    /// Minimal event-sourced counter keyed by `id`.
    struct Counter {
        id: String,
    }

    #[async_trait]
    impl Eventsourced for Counter {
        type Command = CounterCmd;
        type Event = CounterEvent;
        type State = CounterState;
        type Error = CounterErr;

        fn persistence_id(&self) -> String {
            self.id.clone()
        }

        fn command_to_events(
            &self,
            state: &Self::State,
            cmd: Self::Command,
        ) -> Result<Vec<Self::Event>, Self::Error> {
            let delta = match cmd {
                CounterCmd::Add(amount) => amount,
                CounterCmd::Sub(amount) => -amount,
            };
            let projected = state.n + delta;
            if projected < 0 {
                Err(CounterErr::Underflow)
            } else {
                Ok(vec![CounterEvent::Adjusted(delta)])
            }
        }

        fn apply_event(state: &mut Self::State, event: &Self::Event) {
            let CounterEvent::Adjusted(delta) = event;
            state.n += delta;
        }

        fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String> {
            let CounterEvent::Adjusted(delta) = event;
            Ok(delta.to_le_bytes().to_vec())
        }

        fn decode_event(bytes: &[u8]) -> Result<Self::Event, String> {
            let len = bytes.len();
            if len == 8 {
                let mut raw = [0u8; 8];
                raw.copy_from_slice(bytes);
                Ok(CounterEvent::Adjusted(i64::from_le_bytes(raw)))
            } else {
                Err(format!("bad len: {}", len))
            }
        }
    }

    #[tokio::test]
    async fn happy_path_persist_and_recover() {
        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(2);

        // First incarnation: persist +5, +3, -2 in order.
        let writer = Counter { id: "c-1".into() };
        let mut live = CounterState::default();
        let mut seq = 0u64;
        for cmd in [CounterCmd::Add(5), CounterCmd::Add(3), CounterCmd::Sub(2)] {
            writer
                .handle_command(Arc::clone(&journal), &mut live, &mut seq, "w", cmd)
                .await
                .unwrap();
        }
        assert_eq!(live.n, 6);
        assert_eq!(seq, 3);
        assert_eq!(journal.highest_sequence_nr("c-1", 0).await.unwrap(), 3);

        // Second incarnation: replaying the journal must rebuild the same state.
        let mut reader = Counter { id: "c-1".into() };
        let mut replayed = CounterState::default();
        let highest = reader
            .recover(Arc::clone(&journal), &mut replayed, &permitter)
            .await
            .unwrap();
        assert_eq!(highest, 3);
        assert_eq!(replayed.n, 6);
    }

    #[tokio::test]
    async fn domain_error_aborts_persist() {
        let journal = Arc::new(InMemoryJournal::default());
        let counter = Counter { id: "c-2".into() };
        let mut state = CounterState::default();
        let mut seq = 0u64;

        let outcome = counter
            .handle_command(Arc::clone(&journal), &mut state, &mut seq, "w", CounterCmd::Sub(5))
            .await;

        assert!(matches!(outcome, Err(EventsourcedError::Domain(CounterErr::Underflow))));
        assert_eq!(seq, 0);
        assert_eq!(journal.highest_sequence_nr("c-2", 0).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn recovery_completed_called_once() {
        /// Actor whose only behavior is counting `recovery_completed` calls.
        struct HookCounter {
            id: String,
            hook_calls: Arc<AtomicU32>,
        }

        #[async_trait]
        impl Eventsourced for HookCounter {
            type Command = ();
            type Event = ();
            type State = ();
            type Error = std::io::Error;

            fn persistence_id(&self) -> String {
                self.id.clone()
            }

            fn command_to_events(&self, _: &(), _: ()) -> Result<Vec<()>, Self::Error> {
                Ok(vec![])
            }

            fn apply_event(_: &mut (), _: &()) {}

            fn encode_event(_: &()) -> Result<Vec<u8>, String> {
                Ok(vec![])
            }

            fn decode_event(_: &[u8]) -> Result<(), String> {
                Ok(())
            }

            async fn recovery_completed(&mut self, _: &(), _: u64) {
                self.hook_calls.fetch_add(1, Ordering::SeqCst);
            }
        }

        let journal = Arc::new(InMemoryJournal::default());
        let permitter = RecoveryPermitter::new(1);
        let calls = Arc::new(AtomicU32::new(0));
        let mut actor = HookCounter { id: "h".into(), hook_calls: Arc::clone(&calls) };

        actor.recover(journal, &mut (), &permitter).await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}