postgres_es/
event_repository.rs

1use async_trait::async_trait;
2use cqrs_es::persist::{
3    PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
4};
5use cqrs_es::Aggregate;
6use futures::TryStreamExt;
7use serde_json::Value;
8use sqlx::postgres::PgRow;
9use sqlx::{Pool, Postgres, Row, Transaction};
10
11use crate::error::PostgresAggregateError;
12use crate::sql_query::SqlQueryFactory;
13
/// Table name used for events when the repository is built with `new`.
const DEFAULT_EVENT_TABLE: &str = "events";
/// Table name used for snapshots when the repository is built with `new`.
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

/// Streaming channel size used unless `with_streaming_channel_size` overrides it.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
18
19/// An event repository relying on a Postgres database for persistence.
20pub struct PostgresEventRepository {
21    pool: Pool<Postgres>,
22    query_factory: SqlQueryFactory,
23    stream_channel_size: usize,
24}
25
26#[async_trait]
27impl PersistedEventRepository for PostgresEventRepository {
28    async fn get_events<A: Aggregate>(
29        &self,
30        aggregate_id: &str,
31    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
32        self.select_events::<A>(aggregate_id, self.query_factory.select_events())
33            .await
34    }
35
36    async fn get_last_events<A: Aggregate>(
37        &self,
38        aggregate_id: &str,
39        last_sequence: usize,
40    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
41        let query = self.query_factory.get_last_events(last_sequence);
42        self.select_events::<A>(aggregate_id, &query).await
43    }
44
45    async fn get_snapshot<A: Aggregate>(
46        &self,
47        aggregate_id: &str,
48    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
49        let row: PgRow = match sqlx::query(self.query_factory.select_snapshot())
50            .bind(A::aggregate_type())
51            .bind(aggregate_id)
52            .fetch_optional(&self.pool)
53            .await
54            .map_err(PostgresAggregateError::from)?
55        {
56            Some(row) => row,
57            None => {
58                return Ok(None);
59            }
60        };
61        Ok(Some(self.deser_snapshot(row)?))
62    }
63
64    async fn persist<A: Aggregate>(
65        &self,
66        events: &[SerializedEvent],
67        snapshot_update: Option<(String, Value, usize)>,
68    ) -> Result<(), PersistenceError> {
69        match snapshot_update {
70            None => {
71                self.insert_events::<A>(events).await?;
72            }
73            Some((aggregate_id, aggregate, current_snapshot)) => {
74                if current_snapshot == 1 {
75                    self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
76                        .await?;
77                } else {
78                    self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
79                        .await?;
80                }
81            }
82        };
83        Ok(())
84    }
85
86    async fn stream_events<A: Aggregate>(
87        &self,
88        aggregate_id: &str,
89    ) -> Result<ReplayStream, PersistenceError> {
90        Ok(stream_events(
91            self.query_factory.select_events().to_string(),
92            A::aggregate_type(),
93            aggregate_id.to_string(),
94            self.pool.clone(),
95            self.stream_channel_size,
96        ))
97    }
98
99    // TODO: aggregate id is unused here, `stream_events` function needs to be broken up
100    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
101        Ok(stream_events(
102            self.query_factory.all_events().to_string(),
103            A::aggregate_type(),
104            "".to_string(),
105            self.pool.clone(),
106            self.stream_channel_size,
107        ))
108    }
109}
110
111fn stream_events(
112    query: String,
113    aggregate_type: String,
114    aggregate_id: String,
115    pool: Pool<Postgres>,
116    channel_size: usize,
117) -> ReplayStream {
118    let (mut feed, stream) = ReplayStream::new(channel_size);
119    tokio::spawn(async move {
120        let query = sqlx::query(&query)
121            .bind(&aggregate_type)
122            .bind(&aggregate_id);
123        let mut rows = query.fetch(&pool);
124        while let Some(row) = rows.try_next().await.unwrap() {
125            let event_result: Result<SerializedEvent, PersistenceError> =
126                PostgresEventRepository::deser_event(row).map_err(Into::into);
127            if feed.push(event_result).await.is_err() {
128                // TODO: in the unlikely event of a broken channel this error should be reported.
129                return;
130            };
131        }
132    });
133    stream
134}
135
136impl PostgresEventRepository {
137    async fn select_events<A: Aggregate>(
138        &self,
139        aggregate_id: &str,
140        query: &str,
141    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
142        let mut rows = sqlx::query(query)
143            .bind(A::aggregate_type())
144            .bind(aggregate_id)
145            .fetch(&self.pool);
146        let mut result: Vec<SerializedEvent> = Default::default();
147        while let Some(row) = rows
148            .try_next()
149            .await
150            .map_err(PostgresAggregateError::from)?
151        {
152            result.push(PostgresEventRepository::deser_event(row)?);
153        }
154        Ok(result)
155    }
156}
157
158impl PostgresEventRepository {
159    /// Creates a new `PostgresEventRepository` from the provided database connection.
160    /// This uses the default tables 'events' and 'snapshots'.
161    ///
162    /// ```
163    /// use sqlx::{Pool, Postgres};
164    /// use postgres_es::PostgresEventRepository;
165    ///
166    /// fn configure_repo(pool: Pool<Postgres>) -> PostgresEventRepository {
167    ///     PostgresEventRepository::new(pool)
168    /// }
169    /// ```
170    pub fn new(pool: Pool<Postgres>) -> Self {
171        Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
172    }
173
174    /// Configures a `PostgresEventRepository` to use a streaming queue of the provided size.
175    ///
176    /// _Example: configure the repository to stream with a 1000 event buffer._
177    /// ```
178    /// use sqlx::{Pool, Postgres};
179    /// use postgres_es::PostgresEventRepository;
180    ///
181    /// fn configure_repo(pool: Pool<Postgres>) -> PostgresEventRepository {
182    ///     let store = PostgresEventRepository::new(pool);
183    ///     store.with_streaming_channel_size(1000)
184    /// }
185    /// ```
186    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
187        Self {
188            pool: self.pool,
189            query_factory: self.query_factory,
190            stream_channel_size,
191        }
192    }
193
194    /// Configures a `PostgresEventRepository` to use the provided table names.
195    ///
196    /// _Example: configure the repository to use "my_event_table" and "my_snapshot_table"
197    /// for the event and snapshot table names._
198    /// ```
199    /// use sqlx::{Pool, Postgres};
200    /// use postgres_es::PostgresEventRepository;
201    ///
202    /// fn configure_repo(pool: Pool<Postgres>) -> PostgresEventRepository {
203    ///     let store = PostgresEventRepository::new(pool);
204    ///     store.with_tables("my_event_table", "my_snapshot_table")
205    /// }
206    /// ```
207    pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
208        Self::use_tables(self.pool, events_table, snapshots_table)
209    }
210
211    fn use_tables(pool: Pool<Postgres>, events_table: &str, snapshots_table: &str) -> Self {
212        Self {
213            pool,
214            query_factory: SqlQueryFactory::new(events_table, snapshots_table),
215            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
216        }
217    }
218
219    pub(crate) async fn insert_events<A: Aggregate>(
220        &self,
221        events: &[SerializedEvent],
222    ) -> Result<(), PostgresAggregateError> {
223        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
224        self.persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
225            .await?;
226        tx.commit().await?;
227        Ok(())
228    }
229
230    pub(crate) async fn insert<A: Aggregate>(
231        &self,
232        aggregate_payload: Value,
233        aggregate_id: String,
234        current_snapshot: usize,
235        events: &[SerializedEvent],
236    ) -> Result<(), PostgresAggregateError> {
237        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
238        let current_sequence = self
239            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
240            .await?;
241        sqlx::query(self.query_factory.insert_snapshot())
242            .bind(A::aggregate_type())
243            .bind(aggregate_id.as_str())
244            .bind(current_sequence as i32)
245            .bind(current_snapshot as i32)
246            .bind(&aggregate_payload)
247            .execute(&mut *tx)
248            .await?;
249        tx.commit().await?;
250        Ok(())
251    }
252
253    pub(crate) async fn update<A: Aggregate>(
254        &self,
255        aggregate: Value,
256        aggregate_id: String,
257        current_snapshot: usize,
258        events: &[SerializedEvent],
259    ) -> Result<(), PostgresAggregateError> {
260        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
261        let current_sequence = self
262            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
263            .await?;
264
265        let aggregate_payload = serde_json::to_value(&aggregate)?;
266        let result = sqlx::query(self.query_factory.update_snapshot())
267            .bind(A::aggregate_type())
268            .bind(aggregate_id.as_str())
269            .bind(current_sequence as i32)
270            .bind(current_snapshot as i32)
271            .bind((current_snapshot - 1) as i32)
272            .bind(&aggregate_payload)
273            .execute(&mut *tx)
274            .await?;
275        tx.commit().await?;
276        match result.rows_affected() {
277            1 => Ok(()),
278            _ => Err(PostgresAggregateError::OptimisticLock),
279        }
280    }
281
282    fn deser_event(row: PgRow) -> Result<SerializedEvent, PostgresAggregateError> {
283        let aggregate_type: String = row.get("aggregate_type");
284        let aggregate_id: String = row.get("aggregate_id");
285        let sequence = {
286            let s: i64 = row.get("sequence");
287            s as usize
288        };
289        let event_type: String = row.get("event_type");
290        let event_version: String = row.get("event_version");
291        let payload: Value = row.get("payload");
292        let metadata: Value = row.get("metadata");
293        Ok(SerializedEvent::new(
294            aggregate_id,
295            sequence,
296            aggregate_type,
297            event_type,
298            event_version,
299            payload,
300            metadata,
301        ))
302    }
303
304    fn deser_snapshot(&self, row: PgRow) -> Result<SerializedSnapshot, PostgresAggregateError> {
305        let aggregate_id = row.get("aggregate_id");
306        let s: i64 = row.get("last_sequence");
307        let current_sequence = s as usize;
308        let s: i64 = row.get("current_snapshot");
309        let current_snapshot = s as usize;
310        let aggregate: Value = row.get("payload");
311        Ok(SerializedSnapshot {
312            aggregate_id,
313            aggregate,
314            current_sequence,
315            current_snapshot,
316        })
317    }
318
319    pub(crate) async fn persist_events<A: Aggregate>(
320        &self,
321        inser_event_query: &str,
322        tx: &mut Transaction<'_, Postgres>,
323        events: &[SerializedEvent],
324    ) -> Result<usize, PostgresAggregateError> {
325        let mut current_sequence: usize = 0;
326        for event in events {
327            current_sequence = event.sequence;
328            let event_type = &event.event_type;
329            let event_version = &event.event_version;
330            let payload = serde_json::to_value(&event.payload)?;
331            let metadata = serde_json::to_value(&event.metadata)?;
332            sqlx::query(inser_event_query)
333                .bind(A::aggregate_type())
334                .bind(event.aggregate_id.as_str())
335                .bind(event.sequence as i32)
336                .bind(event_type)
337                .bind(event_version)
338                .bind(&payload)
339                .bind(&metadata)
340                .execute(&mut **tx)
341                .await?;
342        }
343        Ok(current_sequence)
344    }
345}
346
347#[cfg(test)]
348mod test {
349    use cqrs_es::persist::PersistedEventRepository;
350
351    use crate::error::PostgresAggregateError;
352    use crate::testing::tests::{
353        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
354        Tested, TEST_CONNECTION_STRING,
355    };
356    use crate::{default_postgress_pool, PostgresEventRepository};
357
358    #[tokio::test]
359    async fn event_repositories() {
360        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
361        let id = uuid::Uuid::new_v4().to_string();
362        let event_repo: PostgresEventRepository =
363            PostgresEventRepository::new(pool.clone()).with_streaming_channel_size(1);
364        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
365        assert!(events.is_empty());
366
367        event_repo
368            .insert_events::<TestAggregate>(&[
369                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
370                test_event_envelope(
371                    &id,
372                    2,
373                    TestEvent::Tested(Tested {
374                        test_name: "a test was run".to_string(),
375                    }),
376                ),
377            ])
378            .await
379            .unwrap();
380        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
381        assert_eq!(2, events.len());
382        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));
383
384        // Optimistic lock error
385        let result = event_repo
386            .insert_events::<TestAggregate>(&[
387                test_event_envelope(
388                    &id,
389                    3,
390                    TestEvent::SomethingElse(SomethingElse {
391                        description: "this should not persist".to_string(),
392                    }),
393                ),
394                test_event_envelope(
395                    &id,
396                    2,
397                    TestEvent::SomethingElse(SomethingElse {
398                        description: "bad sequence number".to_string(),
399                    }),
400                ),
401            ])
402            .await
403            .unwrap_err();
404        match result {
405            PostgresAggregateError::OptimisticLock => {}
406            _ => panic!("invalid error result found during insert: {}", result),
407        };
408
409        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
410        assert_eq!(2, events.len());
411
412        verify_replay_stream(&id, event_repo).await;
413    }
414
415    async fn verify_replay_stream(id: &str, event_repo: PostgresEventRepository) {
416        let mut stream = event_repo
417            .stream_events::<TestAggregate>(&id)
418            .await
419            .unwrap();
420        let mut found_in_stream = 0;
421        while let Some(_) = stream.next::<TestAggregate>(&None).await {
422            found_in_stream += 1;
423        }
424        assert_eq!(found_in_stream, 2);
425
426        let mut stream = event_repo
427            .stream_all_events::<TestAggregate>()
428            .await
429            .unwrap();
430        let mut found_in_stream = 0;
431        while let Some(_) = stream.next::<TestAggregate>(&None).await {
432            found_in_stream += 1;
433        }
434        assert!(found_in_stream >= 2);
435    }
436
437    #[tokio::test]
438    async fn snapshot_repositories() {
439        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
440        let id = uuid::Uuid::new_v4().to_string();
441        let event_repo: PostgresEventRepository = PostgresEventRepository::new(pool.clone());
442        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
443        assert_eq!(None, snapshot);
444
445        let test_description = "some test snapshot here".to_string();
446        let test_tests = vec!["testA".to_string(), "testB".to_string()];
447        event_repo
448            .insert::<TestAggregate>(
449                serde_json::to_value(TestAggregate {
450                    id: id.clone(),
451                    description: test_description.clone(),
452                    tests: test_tests.clone(),
453                })
454                .unwrap(),
455                id.clone(),
456                1,
457                &vec![],
458            )
459            .await
460            .unwrap();
461
462        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
463        assert_eq!(
464            Some(snapshot_context(
465                id.clone(),
466                0,
467                1,
468                serde_json::to_value(TestAggregate {
469                    id: id.clone(),
470                    description: test_description.clone(),
471                    tests: test_tests.clone(),
472                })
473                .unwrap()
474            )),
475            snapshot
476        );
477
478        // sequence iterated, does update
479        event_repo
480            .update::<TestAggregate>(
481                serde_json::to_value(TestAggregate {
482                    id: id.clone(),
483                    description: "a test description that should be saved".to_string(),
484                    tests: test_tests.clone(),
485                })
486                .unwrap(),
487                id.clone(),
488                2,
489                &vec![],
490            )
491            .await
492            .unwrap();
493
494        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
495        assert_eq!(
496            Some(snapshot_context(
497                id.clone(),
498                0,
499                2,
500                serde_json::to_value(TestAggregate {
501                    id: id.clone(),
502                    description: "a test description that should be saved".to_string(),
503                    tests: test_tests.clone(),
504                })
505                .unwrap()
506            )),
507            snapshot
508        );
509
510        // sequence out of order or not iterated, does not update
511        let result = event_repo
512            .update::<TestAggregate>(
513                serde_json::to_value(TestAggregate {
514                    id: id.clone(),
515                    description: "a test description that should not be saved".to_string(),
516                    tests: test_tests.clone(),
517                })
518                .unwrap(),
519                id.clone(),
520                2,
521                &vec![],
522            )
523            .await
524            .unwrap_err();
525        match result {
526            PostgresAggregateError::OptimisticLock => {}
527            _ => panic!("invalid error result found during insert: {}", result),
528        };
529
530        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
531        assert_eq!(
532            Some(snapshot_context(
533                id.clone(),
534                0,
535                2,
536                serde_json::to_value(TestAggregate {
537                    id: id.clone(),
538                    description: "a test description that should be saved".to_string(),
539                    tests: test_tests.clone(),
540                })
541                .unwrap()
542            )),
543            snapshot
544        );
545    }
546}