1use std::sync::Arc;
26
27use crate::eventsourced::EventsourcedError;
28use crate::journal::{Journal, PersistentRepr};
29use crate::recovery_permitter::RecoveryPermitter;
30
/// Boxed command handler: reads the current state `S` and a command `C`, and
/// yields the events `E` to persist or a domain error `Err`.
type CommandFn<S, C, E, Err> = Box<dyn for<'s> FnMut(&'s S, C) -> Result<Vec<E>, Err> + Send>;

/// Boxed event handler: folds one event `E` into the mutable state `S`.
type EventFn<S, E> = Box<dyn for<'s, 'e> FnMut(&'s mut S, &'e E) + Send>;

/// Boxed encoder: turns an event into its journal payload bytes, or a
/// human-readable failure message.
type EncodeFn<E> = Box<dyn for<'e> Fn(&'e E) -> Result<Vec<u8>, String> + Send + Sync>;

/// Boxed decoder: rebuilds an event from journal payload bytes, or a
/// human-readable failure message.
type DecodeFn<E> = Box<dyn for<'b> Fn(&'b [u8]) -> Result<E, String> + Send + Sync>;
35
36pub struct ReceivePersistent<S, E, Err>
38where
39 S: Default + Send + 'static,
40 E: Clone + Send + 'static,
41 Err: std::error::Error + Send + 'static,
42{
43 persistence_id: String,
44 state: S,
45 next_seq: u64,
46 writer_uuid: String,
47 on_command: Option<CommandFn<S, E, E, Err>>,
48 on_event: Option<EventFn<S, E>>,
49 encode: Option<EncodeFn<E>>,
50 decode: Option<DecodeFn<E>>,
51}
52
53impl<S, E, Err> ReceivePersistent<S, E, Err>
54where
55 S: Default + Send + 'static,
56 E: Clone + Send + 'static,
57 Err: std::error::Error + Send + 'static,
58{
59 pub fn new(persistence_id: impl Into<String>) -> Self {
60 Self {
61 persistence_id: persistence_id.into(),
62 state: S::default(),
63 next_seq: 0,
64 writer_uuid: format!("{}-{}", std::process::id(), uuid_v4_simple()),
65 on_command: None,
66 on_event: None,
67 encode: None,
68 decode: None,
69 }
70 }
71
72 pub fn on_command<F>(mut self, f: F) -> Self
77 where
78 F: FnMut(&S, E) -> Result<Vec<E>, Err> + Send + 'static,
79 {
80 self.on_command = Some(Box::new(f));
81 self
82 }
83
84 pub fn on_event<F>(mut self, f: F) -> Self
86 where
87 F: FnMut(&mut S, &E) + Send + 'static,
88 {
89 self.on_event = Some(Box::new(f));
90 self
91 }
92
93 pub fn with_codec<EncF, DecF>(mut self, encode: EncF, decode: DecF) -> Self
95 where
96 EncF: Fn(&E) -> Result<Vec<u8>, String> + Send + Sync + 'static,
97 DecF: Fn(&[u8]) -> Result<E, String> + Send + Sync + 'static,
98 {
99 self.encode = Some(Box::new(encode));
100 self.decode = Some(Box::new(decode));
101 self
102 }
103
104 pub fn state(&self) -> &S {
105 &self.state
106 }
107
108 pub fn persistence_id(&self) -> &str {
109 &self.persistence_id
110 }
111
112 pub async fn recover<J: Journal>(
114 &mut self,
115 journal: Arc<J>,
116 permitter: &RecoveryPermitter,
117 ) -> Result<u64, EventsourcedError<Err>> {
118 let _permit = permitter.acquire().await.ok_or(EventsourcedError::PermitDenied)?;
119 let on_event = self
120 .on_event
121 .as_mut()
122 .ok_or_else(|| EventsourcedError::Codec("on_event handler not registered".into()))?;
123 let decode =
124 self.decode.as_ref().ok_or_else(|| EventsourcedError::Codec("decoder not registered".into()))?;
125 let highest = journal.highest_sequence_nr(&self.persistence_id, 0).await?;
126 let events = journal.replay_messages(&self.persistence_id, 1, highest, u64::MAX).await?;
127 for e in &events {
128 let evt = decode(&e.payload).map_err(EventsourcedError::Codec)?;
129 on_event(&mut self.state, &evt);
130 }
131 self.next_seq = highest;
132 Ok(highest)
133 }
134
135 pub async fn handle<J: Journal>(
137 &mut self,
138 journal: Arc<J>,
139 cmd: E,
140 ) -> Result<(), EventsourcedError<Err>> {
141 let on_cmd = self
142 .on_command
143 .as_mut()
144 .ok_or_else(|| EventsourcedError::Codec("on_command handler not registered".into()))?;
145 let events = on_cmd(&self.state, cmd).map_err(EventsourcedError::Domain)?;
146 if events.is_empty() {
147 return Ok(());
148 }
149 let on_event = self
150 .on_event
151 .as_mut()
152 .ok_or_else(|| EventsourcedError::Codec("on_event handler not registered".into()))?;
153 let encode =
154 self.encode.as_ref().ok_or_else(|| EventsourcedError::Codec("encoder not registered".into()))?;
155 let mut reprs = Vec::with_capacity(events.len());
156 for e in &events {
157 self.next_seq += 1;
158 let payload = encode(e).map_err(EventsourcedError::Codec)?;
159 reprs.push(PersistentRepr {
160 persistence_id: self.persistence_id.clone(),
161 sequence_nr: self.next_seq,
162 payload,
163 manifest: "evt".into(),
164 writer_uuid: self.writer_uuid.clone(),
165 deleted: false,
166 tags: Vec::new(),
167 });
168 }
169 journal.write_messages(reprs).await?;
170 for e in &events {
171 on_event(&mut self.state, e);
172 }
173 Ok(())
174 }
175}
176
/// Produces a token that is unique within this process, for use as part of a
/// writer id.
///
/// Despite the name this is not an RFC 4122 UUID. The token is the current
/// wall-clock time in nanoseconds since the Unix epoch (0 if the clock reads
/// before the epoch) followed by a per-process call counter, both in
/// lowercase hex and separated by `-`. The counter keeps tokens distinct even
/// when two calls observe the same clock value.
fn uuid_v4_simple() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Shared by every call in the process; each call takes the next value.
    static NEXT_TOKEN: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now().duration_since(UNIX_EPOCH).map_or(0, |d| d.as_nanos());
    // Relaxed is enough: only the atomicity of the increment matters here,
    // not its ordering relative to other memory operations.
    let seq = NEXT_TOKEN.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:x}-{seq:x}")
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryJournal;

    #[derive(Debug, thiserror::Error)]
    #[error("dummy")]
    struct DummyErr;

    /// Actor whose state is the running sum of the `i64` events it has seen.
    type SumActor = ReceivePersistent<i64, i64, DummyErr>;

    /// Little-endian encoder for `i64` events.
    fn encode_i64(value: &i64) -> Result<Vec<u8>, String> {
        Ok(value.to_le_bytes().to_vec())
    }

    /// Little-endian decoder for `i64` events; rejects payloads that are not
    /// exactly eight bytes long.
    fn decode_i64(bytes: &[u8]) -> Result<i64, String> {
        let raw: [u8; 8] = bytes.try_into().map_err(|_| "len".to_string())?;
        Ok(i64::from_le_bytes(raw))
    }

    /// Builds a fully configured summing actor: every command is persisted
    /// as-is and every event is added to the state.
    fn summing_actor(pid: &str) -> SumActor {
        SumActor::new(pid)
            .on_command(|_state, cmd| Ok(vec![cmd]))
            .on_event(|total, delta| {
                *total += delta;
            })
            .with_codec(encode_i64, decode_i64)
    }

    #[tokio::test]
    async fn closure_actor_persists_and_recovers() {
        let journal = Arc::new(InMemoryJournal::default());
        let permits = RecoveryPermitter::new(1);

        // Write three events through the first instance.
        let mut writer = summing_actor("pid-1");
        for cmd in [5, 3, -2] {
            writer.handle(Arc::clone(&journal), cmd).await.unwrap();
        }
        assert_eq!(writer.state(), &6);

        // A second instance with the same id must rebuild the same state.
        let mut reader = summing_actor("pid-1");
        reader.recover(Arc::clone(&journal), &permits).await.unwrap();
        assert_eq!(reader.state(), &6);
    }

    #[tokio::test]
    async fn missing_codec_is_a_typed_error() {
        let journal = Arc::new(InMemoryJournal::default());
        // Handlers are present, but no codec was registered.
        let mut actor = SumActor::new("pid-2")
            .on_command(|_, cmd| Ok(vec![cmd]))
            .on_event(|total, delta| {
                *total += delta;
            });

        let outcome = actor.handle(journal, 1).await;
        assert!(matches!(outcome, Err(EventsourcedError::Codec(_))));
    }
}