Skip to main content

eventsourced_projection/
postgres.rs

1use error_ext::StdErrorExt;
2use eventsourced::{binarize, event_log::EventLog};
3use futures::StreamExt;
4use serde::{Deserialize, Serialize};
5use sqlx::{Pool, Postgres, Row, Transaction};
6use std::{
7    error::Error as StdError,
8    fmt::Debug,
9    num::{NonZeroU64, TryFromIntError},
10    pin::pin,
11    sync::Arc,
12    time::Duration,
13};
14use thiserror::Error;
15use tokio::{
16    sync::{mpsc, oneshot, RwLock},
17    task,
18    time::sleep,
19};
20use tracing::{debug, error, info};
21
22/// A projection of events of an event sourced entity to a Postgres database.
23#[derive(Debug, Clone)]
24pub struct Projection {
25    name: String,
26    command_in: mpsc::Sender<(Command, oneshot::Sender<State>)>,
27}
28
29impl Projection {
30    pub async fn new<E, L, H>(
31        type_name: &'static str,
32        name: String,
33        event_log: L,
34        event_handler: H,
35        error_strategy: ErrorStrategy,
36        pool: Pool<Postgres>,
37    ) -> Result<Self, Error>
38    where
39        E: for<'de> Deserialize<'de> + Send + 'static,
40        L: EventLog + Sync,
41        H: EventHandler<E> + Clone + Send + Sync + 'static,
42    {
43        sqlx::query(include_str!("create_projection.sql"))
44            .execute(&pool)
45            .await
46            .expect("create projection table");
47
48        let seq_no = load_seq_no(&name, &pool).await?;
49
50        let state = Arc::new(RwLock::new(State {
51            seq_no,
52            running: false,
53            error: None,
54        }));
55
56        let (command_in, mut command_out) = mpsc::channel::<(Command, oneshot::Sender<State>)>(1);
57
58        task::spawn({
59            let name = name.clone();
60            let state = state.clone();
61
62            async move {
63                while let Some((command, reply_in)) = command_out.recv().await {
64                    match command {
65                        Command::Run => {
66                            // Do not remove braces, dead-lock is waiting for you!
67                            let running = { state.read().await.running };
68                            if running {
69                                info!(type_name, name, "projection already running");
70                            } else {
71                                info!(type_name, name, "running projection");
72
73                                // Do not remove braces, dead-lock is waiting for you!
74                                {
75                                    let mut state = state.write().await;
76                                    state.running = true;
77                                    state.error = None;
78                                }
79
80                                run_projection_loop(
81                                    type_name,
82                                    name.clone(),
83                                    state.clone(),
84                                    event_log.clone(),
85                                    event_handler.clone(),
86                                    pool.clone(),
87                                    error_strategy,
88                                )
89                                .await;
90                            }
91
92                            if reply_in.send(state.read().await.clone()).is_err() {
93                                error!(type_name, name, "cannot send state");
94                            }
95                        }
96
97                        Command::Stop => {
98                            // Do not remove braces, dead-lock is waiting for you!
99                            let running = { state.read().await.running };
100                            if running {
101                                info!(type_name, name, "stopping projection");
102                                let mut state = state.write().await;
103                                state.running = false;
104                            } else {
105                                info!(type_name, name, "projection already stopped");
106                            }
107
108                            if reply_in.send(state.read().await.clone()).is_err() {
109                                error!(type_name, name, "cannot send state");
110                            }
111                        }
112
113                        Command::GetState => {
114                            if reply_in.send(state.read().await.clone()).is_err() {
115                                error!(type_name, name, "cannot send state");
116                            }
117                        }
118                    }
119                }
120            }
121        });
122
123        Ok(Projection { name, command_in })
124    }
125
126    pub async fn run(&self) -> Result<State, CommandError> {
127        self.dispatch_command(Command::Run).await
128    }
129
130    pub async fn stop(&self) -> Result<State, CommandError> {
131        self.dispatch_command(Command::Stop).await
132    }
133
134    pub async fn get_state(&self) -> Result<State, CommandError> {
135        self.dispatch_command(Command::GetState).await
136    }
137
138    async fn dispatch_command(&self, command: Command) -> Result<State, CommandError> {
139        let (reply_in, reply_out) = oneshot::channel();
140        self.command_in
141            .send((command, reply_in))
142            .await
143            .map_err(|_| CommandError::SendCommand(command, self.name.clone()))?;
144        let state = reply_out
145            .await
146            .map_err(|_| CommandError::ReceiveResponse(command, self.name.clone()))?;
147        Ok(state)
148    }
149}
150
151#[trait_variant::make(Send)]
152pub trait EventHandler<E> {
153    type Error: StdError + Send + Sync + 'static;
154
155    async fn handle_event(
156        &self,
157        event: E,
158        tx: &mut Transaction<'static, Postgres>,
159    ) -> Result<(), Self::Error>;
160}
161
162#[derive(Debug, Error)]
163pub enum Error {
164    #[error("cannot create Projection, b/c cannot load state from database")]
165    Sqlx(#[from] sqlx::Error),
166
167    #[error("cannot create Projection, b/c cannot convert loaded seq_no into non zero value")]
168    TryFromInt(#[from] TryFromIntError),
169}
170
171#[derive(Debug, Error, Serialize, Deserialize)]
172pub enum CommandError {
173    /// The command cannot be sent from this [Projection] to its projection.
174    #[error("cannot send command {0:?} to projection {1}")]
175    SendCommand(Command, String),
176
177    /// A response for the command cannot be received from this [Projection]'s projection.
178    #[error("cannot receive reply for command {0:?} from projection {1}")]
179    ReceiveResponse(Command, String),
180}
181
/// What a running projection does when a run fails.
#[derive(Clone, Copy, Debug)]
pub enum ErrorStrategy {
    /// Record the error in the state, wait for the given delay and run again.
    Retry(Duration),
    /// Record the error in the state and stop the projection.
    Stop,
}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct State {
190    pub seq_no: Option<NonZeroU64>,
191    pub running: bool,
192    pub error: Option<String>,
193}
194
195#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
196pub enum Command {
197    Run,
198    Stop,
199    GetState,
200}
201
202#[derive(Debug, Error)]
203enum IntenalRunError<E, H> {
204    #[error(transparent)]
205    Events(E),
206
207    #[error(transparent)]
208    Handler(H),
209
210    #[error(transparent)]
211    Sqlx(#[from] sqlx::Error),
212
213    #[error(transparent)]
214    LoadStateError(#[from] Error),
215}
216
217async fn load_seq_no(name: &str, pool: &Pool<Postgres>) -> Result<Option<NonZeroU64>, Error> {
218    let seq_no = sqlx::query("SELECT seq_no FROM projection WHERE name=$1")
219        .bind(name)
220        .fetch_optional(pool)
221        .await?
222        .map(|row| row.try_get::<i64, _>(0))
223        .transpose()?
224        .map(|seq_no| (seq_no as u64).try_into())
225        .transpose()?;
226    Ok(seq_no)
227}
228
229async fn run_projection_loop<E, L, H>(
230    type_name: &'static str,
231    name: String,
232    state: Arc<RwLock<State>>,
233    event_log: L,
234    event_handler: H,
235    pool: Pool<Postgres>,
236    error_strategy: ErrorStrategy,
237) where
238    E: for<'de> Deserialize<'de> + Send + 'static,
239    L: EventLog + Sync,
240    H: EventHandler<E> + Sync + 'static,
241{
242    task::spawn({
243        async move {
244            loop {
245                let result =
246                    run_projection(type_name, &name, &event_log, &event_handler, &pool, &state)
247                        .await;
248                match result {
249                    Ok(_) => {
250                        info!(type_name, name, "projection stopped");
251                        {
252                            let mut state = state.write().await;
253                            state.running = false;
254                        }
255                        break;
256                    }
257
258                    Err(error) => {
259                        error!(
260                            error = error.as_chain(),
261                            type_name, name, "projection error"
262                        );
263
264                        match error_strategy {
265                            ErrorStrategy::Retry(delay) => {
266                                info!(type_name, name, ?delay, "projection retrying after error");
267                                {
268                                    let mut state = state.write().await;
269                                    state.error = Some(error.to_string());
270                                }
271                                sleep(delay).await
272                            }
273
274                            ErrorStrategy::Stop => {
275                                info!(type_name, name, "projection stopped after error");
276                                {
277                                    let mut state = state.write().await;
278                                    state.running = false;
279                                    state.error = Some(error.to_string());
280                                }
281                                break;
282                            }
283                        }
284                    }
285                }
286            }
287        }
288    });
289}
290
291async fn run_projection<E, L, H>(
292    type_name: &'static str,
293    name: &str,
294    event_log: &L,
295    handler: &H,
296    pool: &Pool<Postgres>,
297    state: &Arc<RwLock<State>>,
298) -> Result<(), IntenalRunError<L::Error, H::Error>>
299where
300    E: for<'de> Deserialize<'de> + Send + 'static,
301    L: EventLog,
302    H: EventHandler<E>,
303{
304    let seq_no = load_seq_no(name, pool)
305        .await?
306        .map(|n| n.saturating_add(1))
307        .unwrap_or(NonZeroU64::MIN);
308    let events = event_log
309        .events_by_type::<E, _, _>(type_name, seq_no, binarize::serde_json::from_bytes)
310        .await
311        .map_err(IntenalRunError::Events)?;
312    let mut events = pin!(events);
313
314    while let Some(event) = events.next().await {
315        if !state.read().await.running {
316            break;
317        };
318
319        let (seq_no, event) = event.map_err(IntenalRunError::Events)?;
320
321        let mut tx = pool.begin().await?;
322        handler
323            .handle_event(event, &mut tx)
324            .await
325            .map_err(IntenalRunError::Handler)?;
326        debug!(type_name, name, seq_no, "projection handled event");
327        save_seq_no(seq_no, name, &mut tx).await?;
328        tx.commit().await?;
329
330        state.write().await.seq_no = Some(seq_no);
331    }
332
333    Ok(())
334}
335
336async fn save_seq_no(
337    seq_no: NonZeroU64,
338    name: &str,
339    tx: &mut Transaction<'_, Postgres>,
340) -> Result<(), sqlx::Error> {
341    let query = r#"INSERT INTO projection (name, seq_no)
342                   VALUES ($1, $2)
343                   ON CONFLICT (name) DO UPDATE SET seq_no = $2"#;
344    sqlx::query(query)
345        .bind(name)
346        .bind(seq_no.get() as i64)
347        .execute(&mut **tx)
348        .await?;
349    Ok(())
350}
351
#[cfg(test)]
mod tests {
    use crate::postgres::{ErrorStrategy, EventHandler, Projection};
    use error_ext::BoxError;
    use eventsourced::{
        binarize::serde_json::to_bytes,
        event_log::{test::TestEventLog, EventLog},
    };
    use sqlx::{
        postgres::{PgConnectOptions, PgPoolOptions},
        Postgres, QueryBuilder, Row, Transaction,
    };
    use std::{iter::once, time::Duration};
    use testcontainers::{clients::Cli, RunnableImage};
    use testcontainers_modules::postgres::Postgres as TCPostgres;
    use tokio::time::sleep;

    /// Event handler which inserts each `i32` event as a row into the `test` table.
    #[derive(Clone)]
    struct TestHandler;

    impl EventHandler<i32> for TestHandler {
        type Error = sqlx::Error;

        async fn handle_event(
            &self,
            event: i32,
            tx: &mut Transaction<'static, Postgres>,
        ) -> Result<(), Self::Error> {
            QueryBuilder::new("INSERT INTO test (n) ")
                .push_values(once(event), |mut q, event| {
                    q.push_bind(event);
                })
                .build()
                .execute(&mut **tx)
                .await?;
            Ok(())
        }
    }

    /// End-to-end test against a Postgres container.
    ///
    /// The projection resumes after a pre-seeded `seq_no` of 10, handles events 11..=100 and
    /// then no longer advances after being stopped.
    #[tokio::test]
    async fn test() -> Result<(), BoxError> {
        let containers = Cli::default();

        let container =
            containers.run(RunnableImage::from(TCPostgres::default()).with_tag("16-alpine"));
        let port = container.get_host_port_ipv4(5432);

        let cnn_url = format!("postgresql://postgres:postgres@localhost:{port}");
        let cnn_options = cnn_url.parse::<PgConnectOptions>()?;
        let pool = PgPoolOptions::new().connect_with(cnn_options).await?;

        let mut event_log = TestEventLog::<u64>::default();
        for n in 1..=100 {
            event_log.persist("test", &0, None, &n, &to_bytes).await?;
        }

        sqlx::query("CREATE TABLE test (n bigint);")
            .execute(&pool)
            .await?;

        let projection = Projection::new(
            "test",
            "test-projection".to_string(),
            event_log.clone(),
            TestHandler,
            ErrorStrategy::Stop,
            pool.clone(),
        )
        .await?;

        // Pretend events 1..=10 have already been projected.
        QueryBuilder::new("INSERT INTO projection ")
            .push_values(once(("test-projection", 10)), |mut q, (name, seq_no)| {
                q.push_bind(name).push_bind(seq_no);
            })
            .build()
            .execute(&pool)
            .await?;

        projection.run().await?;

        // Wait until all events have been handled. With `ErrorStrategy::Stop` an error stops
        // the projection for good, so fail fast instead of polling forever.
        let mut state = projection.get_state().await?;
        let max = Some(100.try_into()?);
        while state.seq_no < max {
            assert!(
                state.error.is_none(),
                "projection failed: {:?}",
                state.error
            );
            sleep(Duration::from_millis(100)).await;
            state = projection.get_state().await?;
        }
        assert_eq!(state.seq_no, max);

        let sum = sqlx::query("SELECT * FROM test;")
            .fetch_all(&pool)
            .await?
            .into_iter()
            .map(|row| row.try_get::<i64, _>(0))
            .try_fold(0i64, |acc, n| n.map(|n| acc + n))?;
        // sum(11..=100) = sum(1..=100) - sum(1..=10) = 5_050 - 55
        assert_eq!(sum, 4_995);

        // After stopping, the seq_no must no longer change.
        projection.stop().await?;
        sleep(Duration::from_millis(100)).await;
        let state = projection.get_state().await?;
        sleep(Duration::from_millis(100)).await;
        let state_2 = projection.get_state().await?;
        assert_eq!(state.seq_no, state_2.seq_no);

        Ok(())
    }
}