// eventsourced_postgres/event_log.rs

//! An [EventLog] implementation based on [PostgreSQL](https://www.postgresql.org/).
2
3use crate::{Cnn, CnnPool, Error};
4use async_stream::stream;
5use bb8_postgres::{bb8::Pool, PostgresConnectionManager};
6use bytes::Bytes;
7use eventsourced::event_log::EventLog;
8use futures::{Stream, StreamExt, TryStreamExt};
9use serde::{Deserialize, Serialize};
10use std::{
11    error::Error as StdError,
12    fmt::{self, Debug, Formatter},
13    marker::PhantomData,
14    num::{NonZeroU64, NonZeroUsize},
15    time::Duration,
16};
17use tokio::time::sleep;
18use tokio_postgres::{types::ToSql, NoTls};
19use tracing::{debug, instrument};
20
21/// An [EventLog] implementation based on [PostgreSQL](https://www.postgresql.org/).
22#[derive(Clone)]
23pub struct PostgresEventLog<I> {
24    poll_interval: Duration,
25    cnn_pool: CnnPool<NoTls>,
26    _id: PhantomData<I>,
27}
28
29impl<I> PostgresEventLog<I>
30where
31    I: ToSql + Sync,
32{
33    #[allow(missing_docs)]
34    pub async fn new(config: Config) -> Result<Self, Error> {
35        debug!(?config, "creating PostgresEventLog");
36
37        // Create connection pool.
38        let tls = NoTls;
39        let cnn_manager = PostgresConnectionManager::new_from_stringlike(config.cnn_config(), tls)
40            .map_err(|error| {
41                Error::Postgres("cannot create connection manager".to_string(), error)
42            })?;
43        let cnn_pool = Pool::builder()
44            .build(cnn_manager)
45            .await
46            .map_err(|error| Error::Postgres("cannot create connection pool".to_string(), error))?;
47
48        // Setup tables.
49        if config.setup {
50            cnn_pool
51                .get()
52                .await
53                .map_err(Error::GetConnection)?
54                .batch_execute(
55                    &include_str!("create_event_log.sql").replace("events", &config.events_table),
56                )
57                .await
58                .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?;
59        }
60
61        Ok(Self {
62            poll_interval: config.poll_interval,
63            cnn_pool,
64            _id: PhantomData,
65        })
66    }
67
68    async fn cnn(&self) -> Result<Cnn<NoTls>, Error> {
69        self.cnn_pool.get().await.map_err(Error::GetConnection)
70    }
71
72    async fn next_events_by_id<E, FromBytes, FromBytesError>(
73        &self,
74        id: &I,
75        seq_no: i64,
76        from_bytes: FromBytes,
77    ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Error>> + Send, Error>
78    where
79        E: Send,
80        FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Send,
81        FromBytesError: StdError + Send + Sync + 'static,
82    {
83        debug!(?id, ?seq_no, "querying events");
84        let params: [&(dyn ToSql + Sync); 2] = [&id, &seq_no];
85        let events = self
86            .cnn()
87            .await?
88            .query_raw(
89                "SELECT seq_no, event FROM events WHERE id = $1 AND seq_no >= $2",
90                params,
91            )
92            .await
93            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?
94            .map_err(|error| Error::Postgres("cannot get next row".to_string(), error))
95            .map(move |row| {
96                row.and_then(|row| {
97                    let seq_no = (row.get::<_, i64>(0) as u64)
98                        .try_into()
99                        .map_err(|_| Error::ZeroNonZeroU64)?;
100                    let bytes = row.get::<_, &[u8]>(1);
101                    let bytes = Bytes::copy_from_slice(bytes);
102                    from_bytes(bytes)
103                        .map_err(|source| Error::FromBytes(Box::new(source)))
104                        .map(|event| (seq_no, event))
105                })
106            });
107
108        Ok(events)
109    }
110
111    async fn next_events_by_type<E, FromBytes, FromBytesError>(
112        &self,
113        type_name: &str,
114        seq_no: i64,
115        from_bytes: FromBytes,
116    ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Error>> + Send, Error>
117    where
118        E: Send,
119        FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Send,
120        FromBytesError: StdError + Send + Sync + 'static,
121    {
122        debug!(%type_name, seq_no, "querying events");
123
124        let params: [&(dyn ToSql + Sync); 2] = [&type_name, &seq_no];
125        let events = self
126            .cnn()
127            .await?
128            .query_raw(
129                "SELECT seq_no, event FROM events WHERE type = $1 AND seq_no >= $2",
130                params,
131            )
132            .await
133            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?
134            .map_err(|error| Error::Postgres("cannot get next row".to_string(), error))
135            .map(move |row| {
136                row.and_then(|row| {
137                    let seq_no = (row.get::<_, i64>(0) as u64)
138                        .try_into()
139                        .map_err(|_| Error::ZeroNonZeroU64)?;
140                    let bytes = row.get::<_, &[u8]>(1);
141                    let bytes = Bytes::copy_from_slice(bytes);
142                    from_bytes(bytes)
143                        .map_err(|source| Error::FromBytes(Box::new(source)))
144                        .map(|event| (seq_no, event))
145                })
146            });
147
148        Ok(events)
149    }
150
151    async fn last_seq_no_by_type(&self, type_name: &str) -> Result<Option<NonZeroU64>, Error> {
152        self.cnn()
153            .await?
154            .query_one(
155                "SELECT MAX(seq_no) FROM events WHERE type = $1",
156                &[&type_name],
157            )
158            .await
159            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
160            .and_then(|row| {
161                // If there is no seq_no there is one row with a NULL column, hence use `try_get`.
162                row.try_get::<_, i64>(0)
163                    .ok()
164                    .map(|seq_no| {
165                        (seq_no as u64)
166                            .try_into()
167                            .map_err(|_| Error::ZeroNonZeroU64)
168                    })
169                    .transpose()
170            })
171    }
172}
173
174impl<I> Debug for PostgresEventLog<I> {
175    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
176        f.debug_struct("PostgresEventLog").finish()
177    }
178}
179
180impl<I> EventLog for PostgresEventLog<I>
181where
182    I: Clone + ToSql + Send + Sync + 'static,
183{
184    type Id = I;
185
186    type Error = Error;
187
188    /// The maximum value for sequence numbers. As PostgreSQL does not support unsigned integers,
189    /// this is `i64::MAX` or `9_223_372_036_854_775_807`.
190    const MAX_SEQ_NO: NonZeroU64 = unsafe { NonZeroU64::new_unchecked(i64::MAX as u64) };
191
192    #[instrument(skip(self, event, to_bytes))]
193    async fn persist<E, ToBytes, ToBytesError>(
194        &mut self,
195        type_name: &'static str,
196        id: &Self::Id,
197        last_seq_no: Option<NonZeroU64>,
198        event: &E,
199        to_bytes: &ToBytes,
200    ) -> Result<NonZeroU64, Self::Error>
201    where
202        ToBytes: Fn(&E) -> Result<Bytes, ToBytesError> + Sync,
203        ToBytesError: StdError + Send + Sync + 'static,
204    {
205        let seq_no = last_seq_no.map(|n| n.get() as i64).unwrap_or_default() + 1;
206
207        let bytes = to_bytes(event).map_err(|error| Error::ToBytes(Box::new(error)))?;
208
209        self.cnn()
210            .await?
211            .query_one(
212                "INSERT INTO events (seq_no, type, id, event) VALUES ($1, $2, $3, $4) RETURNING seq_no",
213                &[&seq_no, &type_name, &id, &bytes.as_ref()],
214            )
215            .await
216            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
217            .and_then(|row| {
218                (row.get::<_, i64>(0) as u64)
219                    .try_into()
220                    .map_err(|_| Error::ZeroNonZeroU64)
221            })
222    }
223
224    #[instrument(skip(self))]
225    async fn last_seq_no(
226        &self,
227        type_name: &'static str,
228        id: &Self::Id,
229    ) -> Result<Option<NonZeroU64>, Self::Error> {
230        self.cnn()
231            .await?
232            .query_one("SELECT MAX(seq_no) FROM events WHERE id = $1", &[&id])
233            .await
234            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
235            .and_then(|row| {
236                // If there is no seq_no there is one row with a NULL column, hence use `try_get`.
237                row.try_get::<_, i64>(0)
238                    .ok()
239                    .map(|seq_no| {
240                        (seq_no as u64)
241                            .try_into()
242                            .map_err(|_| Error::ZeroNonZeroU64)
243                    })
244                    .transpose()
245            })
246    }
247
248    #[instrument(skip(self, from_bytes))]
249    async fn events_by_id<E, FromBytes, FromBytesError>(
250        &self,
251        type_name: &'static str,
252        id: &Self::Id,
253        seq_no: NonZeroU64,
254        from_bytes: FromBytes,
255    ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Self::Error>> + Send, Self::Error>
256    where
257        E: Send,
258        FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Copy + Send + Sync + 'static,
259        FromBytesError: StdError + Send + Sync + 'static,
260    {
261        let last_seq_no = self
262            .last_seq_no(type_name, id)
263            .await?
264            .map(|n| n.get() as i64)
265            .unwrap_or_default();
266
267        let mut current_seq_no = seq_no.get() as i64;
268        let events = stream! {
269            'outer: loop {
270                let events = self
271                    .next_events_by_id(id, current_seq_no, from_bytes)
272                    .await?;
273
274                for await event in events {
275                    match event {
276                        Ok(event @ (seq_no, _)) => {
277                            current_seq_no += seq_no.get() as i64 + 1;
278                            yield Ok(event);
279                        }
280
281                        Err(error) => {
282                            yield Err(error);
283                            break 'outer;
284                        }
285                    }
286                }
287
288                // Only sleep if requesting future events.
289                if current_seq_no >= last_seq_no {
290                    sleep(self.poll_interval).await;
291                }
292            }
293        };
294
295        Ok(events)
296    }
297
298    #[instrument(skip(self, from_bytes))]
299    async fn events_by_type<E, FromBytes, FromBytesError>(
300        &self,
301        type_name: &'static str,
302        seq_no: NonZeroU64,
303        from_bytes: FromBytes,
304    ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Self::Error>> + Send, Self::Error>
305    where
306        E: Send,
307        FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Copy + Send + Sync + 'static,
308        FromBytesError: StdError + Send + Sync + 'static,
309    {
310        debug!(type_name, seq_no, "building events by type stream");
311
312        let last_seq_no = self
313            .last_seq_no_by_type(type_name)
314            .await?
315            .map(|n| n.get() as i64)
316            .unwrap_or_default();
317
318        let mut current_seq_no = seq_no.get() as i64;
319        let events = stream! {
320            'outer: loop {
321                let events = self
322                    .next_events_by_type(type_name, current_seq_no, from_bytes)
323                    .await?;
324
325                for await event in events {
326                    match event {
327                        Ok(event @ (seq_no, _)) => {
328                            current_seq_no = seq_no.get() as i64 + 1;
329                            yield Ok(event);
330                        }
331
332                        Err(error) => {
333                            yield Err(error);
334                            break 'outer;
335                        }
336                    }
337                }
338
339                // Only sleep if requesting future events.
340                if current_seq_no >= last_seq_no {
341                    sleep(self.poll_interval).await;
342                }
343            }
344        };
345
346        Ok(events)
347    }
348}
349
350/// Configuration for the [PostgresEventLog].
351#[derive(Debug, Clone, Serialize, Deserialize)]
352#[serde(rename_all = "kebab-case")]
353pub struct Config {
354    pub host: String,
355
356    pub port: u16,
357
358    pub user: String,
359
360    pub password: String,
361
362    pub dbname: String,
363
364    pub sslmode: String,
365
366    #[serde(default = "events_table_default")]
367    pub events_table: String,
368
369    #[serde(default = "poll_interval_default", with = "humantime_serde")]
370    pub poll_interval: Duration,
371
372    #[serde(default = "id_broadcast_capacity_default")]
373    pub id_broadcast_capacity: NonZeroUsize,
374
375    #[serde(default)]
376    pub setup: bool,
377}
378
379impl Config {
380    fn cnn_config(&self) -> String {
381        format!(
382            "host={} port={} user={} password={} dbname={} sslmode={}",
383            self.host, self.port, self.user, self.password, self.dbname, self.sslmode
384        )
385    }
386}
387
388impl Default for Config {
389    /// Default values suitable for local testing only.
390    fn default() -> Self {
391        Self {
392            host: "localhost".to_string(),
393            port: 5432,
394            user: "postgres".to_string(),
395            password: "".to_string(),
396            dbname: "postgres".to_string(),
397            sslmode: "prefer".to_string(),
398            events_table: events_table_default(),
399            poll_interval: poll_interval_default(),
400            id_broadcast_capacity: id_broadcast_capacity_default(),
401            setup: false,
402        }
403    }
404}
405
/// Default for [Config::events_table]: a table named `events`.
fn events_table_default() -> String {
    String::from("events")
}

/// Default for [Config::poll_interval]: two seconds.
const fn poll_interval_default() -> Duration {
    Duration::from_millis(2_000)
}

/// Default for [Config::id_broadcast_capacity]: the smallest non-zero capacity, i.e. `1`.
const fn id_broadcast_capacity_default() -> NonZeroUsize {
    NonZeroUsize::MIN
}
417
#[cfg(test)]
mod tests {
    use crate::{PostgresEventLog, PostgresEventLogConfig};
    use error_ext::BoxError;
    use eventsourced::{binarize, event_log::EventLog};
    use futures::{StreamExt, TryStreamExt};
    use std::{future, num::NonZeroU64};
    use testcontainers::clients::Cli;
    use testcontainers_modules::postgres::Postgres;
    use uuid::Uuid;

    /// Exercises persisting, stale-sequence-number rejection and both event queries against a
    /// PostgreSQL test container.
    #[tokio::test]
    async fn test_event_log() -> Result<(), BoxError> {
        // `docker` and `postgres` must stay alive for the whole test to keep the container up.
        let docker = Cli::default();
        let postgres = docker.run(Postgres::default().with_host_auth());
        let config = PostgresEventLogConfig {
            port: postgres.get_host_port_ipv4(5432),
            setup: true,
            ..Default::default()
        };
        let mut log = PostgresEventLog::<Uuid>::new(config).await?;
        let id = Uuid::now_v7();

        // Nothing has been persisted for this id yet.
        let none_yet = log.last_seq_no("counter", &id).await?;
        assert_eq!(none_yet, None);

        // The first event gets sequence number 1.
        let first = log
            .persist("counter", &id, None, &1, &binarize::serde_json::to_bytes)
            .await?;
        assert_eq!(first.get(), 1);

        log.persist(
            "counter",
            &id,
            Some(first),
            &2,
            &binarize::serde_json::to_bytes,
        )
        .await?;

        // Persisting on top of the outdated `first` again must be rejected.
        let outdated = log
            .persist(
                "counter",
                &id,
                Some(first),
                &3,
                &binarize::serde_json::to_bytes,
            )
            .await;
        assert!(outdated.is_err());

        let second = first.checked_add(1).expect("overflow");
        log.persist(
            "counter",
            &id,
            Some(second),
            &3,
            &binarize::serde_json::to_bytes,
        )
        .await?;

        let third = log.last_seq_no("counter", &id).await?;
        assert_eq!(third, Some(3.try_into()?));

        // Events 2 and 3 of the id: 2 + 3 = 5.
        let by_id = log
            .events_by_id::<u32, _, _>(
                "counter",
                &id,
                2.try_into()?,
                binarize::serde_json::from_bytes,
            )
            .await?;
        let sum_by_id = by_id
            .take(2)
            .try_fold(0u32, |total, (_, n)| future::ready(Ok(total + n)))
            .await?;
        assert_eq!(sum_by_id, 5);

        // Create the by-type stream before persisting events 4 and 5.
        let by_type = log
            .events_by_type::<u32, _, _>(
                "counter",
                NonZeroU64::MIN,
                binarize::serde_json::from_bytes,
            )
            .await?;

        // `by_type` borrows `log`, so persist through clones.
        let fourth = log
            .clone()
            .persist(
                "counter",
                &id,
                third,
                &4,
                &binarize::serde_json::to_bytes,
            )
            .await?;
        log.clone()
            .persist(
                "counter",
                &id,
                Some(fourth),
                &5,
                &binarize::serde_json::to_bytes,
            )
            .await?;
        let fifth = log.last_seq_no("counter", &id).await?;
        assert_eq!(fifth, Some(5.try_into()?));

        // Events 1 through 5 of the type: 1 + 2 + 3 + 4 + 5 = 15.
        let sum_by_type = by_type
            .take(5)
            .try_fold(0u32, |total, (_, n)| future::ready(Ok(total + n)))
            .await?;
        assert_eq!(sum_by_type, 15);

        Ok(())
    }
}