Skip to main content

postgres_es/
event_repository.rs

1use cqrs_es::persist::{
2    PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
3};
4use cqrs_es::Aggregate;
5use futures::TryStreamExt;
6use serde_json::Value;
7use sqlx::postgres::PgRow;
8use sqlx::{Pool, Postgres, Row, Transaction};
9
10use crate::error::PostgresAggregateError;
11use crate::sql_query::SqlQueryFactory;
12
/// Events table name used when the repository is built with `PostgresEventRepository::new`.
const DEFAULT_EVENT_TABLE: &str = "events";
/// Snapshots table name used when the repository is built with `PostgresEventRepository::new`.
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

/// Channel size handed to `ReplayStream::new` unless overridden via
/// `PostgresEventRepository::with_streaming_channel_size`.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
17
/// An event repository relying on a Postgres database for persistence.
///
/// Events and snapshots are read and written through the held connection pool using
/// the SQL statements produced by the configured `SqlQueryFactory`.
pub struct PostgresEventRepository {
    // Connection pool used for every query and transaction issued by this repository.
    pool: Pool<Postgres>,
    // Source of the SQL statements, built from the events/snapshots table names.
    query_factory: SqlQueryFactory,
    // Channel size passed to `ReplayStream::new` when streaming events.
    stream_channel_size: usize,
}
24
25impl PersistedEventRepository for PostgresEventRepository {
26    async fn get_events<A: Aggregate>(
27        &self,
28        aggregate_id: &str,
29    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
30        self.select_events::<A>(aggregate_id, self.query_factory.select_events())
31            .await
32    }
33
34    async fn get_last_events<A: Aggregate>(
35        &self,
36        aggregate_id: &str,
37        last_sequence: usize,
38    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
39        let query = self.query_factory.get_last_events(last_sequence);
40        self.select_events::<A>(aggregate_id, &query).await
41    }
42
43    async fn get_snapshot<A: Aggregate>(
44        &self,
45        aggregate_id: &str,
46    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
47        let Some(row) = sqlx::query(self.query_factory.select_snapshot())
48            .bind(A::TYPE)
49            .bind(aggregate_id)
50            .fetch_optional(&self.pool)
51            .await
52            .map_err(PostgresAggregateError::from)?
53        else {
54            return Ok(None);
55        };
56        Ok(Some(Self::deser_snapshot(&row)))
57    }
58
59    async fn persist<A: Aggregate>(
60        &self,
61        events: &[SerializedEvent],
62        snapshot_update: Option<(String, Value, usize)>,
63    ) -> Result<(), PersistenceError> {
64        match snapshot_update {
65            None => {
66                self.insert_events::<A>(events).await?;
67            }
68            Some((aggregate_id, aggregate, current_snapshot)) => {
69                if current_snapshot == 1 {
70                    self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
71                        .await?;
72                } else {
73                    self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
74                        .await?;
75                }
76            }
77        }
78        Ok(())
79    }
80
81    async fn stream_events<A: Aggregate>(
82        &self,
83        aggregate_id: &str,
84    ) -> Result<ReplayStream, PersistenceError> {
85        Ok(stream_events(
86            self.query_factory.select_events().to_string(),
87            A::TYPE.to_string(),
88            aggregate_id.to_string(),
89            self.pool.clone(),
90            self.stream_channel_size,
91        ))
92    }
93
94    // TODO: aggregate id is unused here, `stream_events` function needs to be broken up
95    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
96        Ok(stream_events(
97            self.query_factory.all_events().to_string(),
98            A::TYPE.to_string(),
99            String::new(),
100            self.pool.clone(),
101            self.stream_channel_size,
102        ))
103    }
104}
105
106fn stream_events(
107    query: String,
108    aggregate_type: String,
109    aggregate_id: String,
110    pool: Pool<Postgres>,
111    channel_size: usize,
112) -> ReplayStream {
113    let (mut feed, stream) = ReplayStream::new(channel_size);
114    tokio::spawn(async move {
115        let query = sqlx::query(&query)
116            .bind(&aggregate_type)
117            .bind(&aggregate_id);
118        let mut rows = query.fetch(&pool);
119        while let Some(row) = rows.try_next().await.unwrap() {
120            let event = PostgresEventRepository::deser_event(&row);
121            if feed.push(Ok(event)).await.is_err() {
122                // TODO: in the unlikely event of a broken channel this error should be reported.
123                return;
124            }
125        }
126    });
127    stream
128}
129
130impl PostgresEventRepository {
131    async fn select_events<A: Aggregate>(
132        &self,
133        aggregate_id: &str,
134        query: &str,
135    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
136        let mut rows = sqlx::query(query)
137            .bind(A::TYPE)
138            .bind(aggregate_id)
139            .fetch(&self.pool);
140        let mut result = Vec::default();
141        while let Some(row) = rows
142            .try_next()
143            .await
144            .map_err(PostgresAggregateError::from)?
145        {
146            result.push(Self::deser_event(&row));
147        }
148        Ok(result)
149    }
150}
151
152impl PostgresEventRepository {
153    /// Creates a new `PostgresEventRepository` from the provided database connection.
154    /// This uses the default tables 'events' and 'snapshots'.
155    ///
156    /// ```
157    /// use sqlx::{Pool, Postgres};
158    /// use postgres_es::PostgresEventRepository;
159    ///
160    /// fn configure_repo(pool: Pool<Postgres>) -> PostgresEventRepository {
161    ///     PostgresEventRepository::new(pool)
162    /// }
163    /// ```
164    pub fn new(pool: Pool<Postgres>) -> Self {
165        Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
166    }
167
168    /// Configures a `PostgresEventRepository` to use a streaming queue of the provided size.
169    ///
170    /// _Example: configure the repository to stream with a 1000 event buffer._
171    /// ```
172    /// use sqlx::{Pool, Postgres};
173    /// use postgres_es::PostgresEventRepository;
174    ///
175    /// fn configure_repo(pool: Pool<Postgres>) -> PostgresEventRepository {
176    ///     let store = PostgresEventRepository::new(pool);
177    ///     store.with_streaming_channel_size(1000)
178    /// }
179    /// ```
180    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
181        Self {
182            pool: self.pool,
183            query_factory: self.query_factory,
184            stream_channel_size,
185        }
186    }
187
188    /// Configures a `PostgresEventRepository` to use the provided table names.
189    ///
190    /// _Example: configure the repository to use "my_event_table" and "my_snapshot_table"
191    /// for the event and snapshot table names._
192    /// ```
193    /// use sqlx::{Pool, Postgres};
194    /// use postgres_es::PostgresEventRepository;
195    ///
196    /// fn configure_repo(pool: Pool<Postgres>) -> PostgresEventRepository {
197    ///     let store = PostgresEventRepository::new(pool);
198    ///     store.with_tables("my_event_table", "my_snapshot_table")
199    /// }
200    /// ```
201    pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
202        Self::use_tables(self.pool, events_table, snapshots_table)
203    }
204
205    fn use_tables(pool: Pool<Postgres>, events_table: &str, snapshots_table: &str) -> Self {
206        Self {
207            pool,
208            query_factory: SqlQueryFactory::new(events_table, snapshots_table),
209            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
210        }
211    }
212
213    pub(crate) async fn insert_events<A: Aggregate>(
214        &self,
215        events: &[SerializedEvent],
216    ) -> Result<(), PostgresAggregateError> {
217        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
218        self.persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
219            .await?;
220        tx.commit().await?;
221        Ok(())
222    }
223
224    pub(crate) async fn insert<A: Aggregate>(
225        &self,
226        aggregate_payload: Value,
227        aggregate_id: String,
228        current_snapshot: usize,
229        events: &[SerializedEvent],
230    ) -> Result<(), PostgresAggregateError> {
231        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
232        let current_sequence = self
233            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
234            .await?;
235        sqlx::query(self.query_factory.insert_snapshot())
236            .bind(A::TYPE)
237            .bind(aggregate_id.as_str())
238            .bind(current_sequence as i32)
239            .bind(current_snapshot as i32)
240            .bind(&aggregate_payload)
241            .execute(&mut *tx)
242            .await?;
243        tx.commit().await?;
244        Ok(())
245    }
246
247    pub(crate) async fn update<A: Aggregate>(
248        &self,
249        aggregate: Value,
250        aggregate_id: String,
251        current_snapshot: usize,
252        events: &[SerializedEvent],
253    ) -> Result<(), PostgresAggregateError> {
254        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
255        let current_sequence = self
256            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
257            .await?;
258
259        let aggregate_payload = serde_json::to_value(&aggregate)?;
260        let result = sqlx::query(self.query_factory.update_snapshot())
261            .bind(A::TYPE)
262            .bind(aggregate_id.as_str())
263            .bind(current_sequence as i32)
264            .bind(current_snapshot as i32)
265            .bind((current_snapshot - 1) as i32)
266            .bind(&aggregate_payload)
267            .execute(&mut *tx)
268            .await?;
269        tx.commit().await?;
270        match result.rows_affected() {
271            1 => Ok(()),
272            _ => Err(PostgresAggregateError::OptimisticLock),
273        }
274    }
275
276    fn deser_event(row: &PgRow) -> SerializedEvent {
277        let aggregate_type: String = row.get("aggregate_type");
278        let aggregate_id: String = row.get("aggregate_id");
279        let sequence = {
280            let s: i64 = row.get("sequence");
281            s as usize
282        };
283        let event_type: String = row.get("event_type");
284        let event_version: String = row.get("event_version");
285        let payload: Value = row.get("payload");
286        let metadata: Value = row.get("metadata");
287        SerializedEvent::new(
288            aggregate_id,
289            sequence,
290            aggregate_type,
291            event_type,
292            event_version,
293            payload,
294            metadata,
295        )
296    }
297
298    fn deser_snapshot(row: &PgRow) -> SerializedSnapshot {
299        let aggregate_id = row.get("aggregate_id");
300        let s: i64 = row.get("last_sequence");
301        let current_sequence = s as usize;
302        let s: i64 = row.get("current_snapshot");
303        let current_snapshot = s as usize;
304        let aggregate: Value = row.get("payload");
305        SerializedSnapshot {
306            aggregate_id,
307            aggregate,
308            current_sequence,
309            current_snapshot,
310        }
311    }
312
313    pub(crate) async fn persist_events<A: Aggregate>(
314        &self,
315        inser_event_query: &str,
316        tx: &mut Transaction<'_, Postgres>,
317        events: &[SerializedEvent],
318    ) -> Result<usize, PostgresAggregateError> {
319        let mut current_sequence: usize = 0;
320        for event in events {
321            current_sequence = event.sequence;
322            let event_type = &event.event_type;
323            let event_version = &event.event_version;
324            let payload = serde_json::to_value(&event.payload)?;
325            let metadata = serde_json::to_value(&event.metadata)?;
326            sqlx::query(inser_event_query)
327                .bind(A::TYPE)
328                .bind(event.aggregate_id.as_str())
329                .bind(event.sequence as i32)
330                .bind(event_type)
331                .bind(event_version)
332                .bind(&payload)
333                .bind(&metadata)
334                .execute(&mut **tx)
335                .await?;
336        }
337        Ok(current_sequence)
338    }
339}
340
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::PostgresAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
        Tested, TEST_CONNECTION_STRING,
    };
    use crate::{default_postgress_pool, PostgresEventRepository};

    /// Serializes a `TestAggregate` with the given field values to a JSON value.
    fn aggregate_payload(id: &str, description: &str, tests: &[String]) -> serde_json::Value {
        serde_json::to_value(TestAggregate {
            id: id.to_string(),
            description: description.to_string(),
            tests: tests.to_vec(),
        })
        .unwrap()
    }

    /// Drains `stream` and returns how many items it yielded.
    async fn count_streamed(mut stream: cqrs_es::persist::ReplayStream) -> usize {
        let mut found_in_stream = 0;
        while stream.next::<TestAggregate>(&[]).await.is_some() {
            found_in_stream += 1;
        }
        found_in_stream
    }

    /// Exercises event insertion, the rejection of a batch containing an already used
    /// sequence number, and event streaming. Requires the test database.
    #[tokio::test]
    async fn event_repositories() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: PostgresEventRepository =
            PostgresEventRepository::new(pool.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        let created = test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() }));
        let tested = test_event_envelope(
            &id,
            2,
            TestEvent::Tested(Tested {
                test_name: "a test was run".to_string(),
            }),
        );
        event_repo
            .insert_events::<TestAggregate>(&[created, tested])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        for event in &events {
            assert_eq!(&id, &event.aggregate_id);
        }

        // Optimistic lock error: the second event reuses sequence number 2, so the
        // whole batch must be rejected.
        let not_persisted = test_event_envelope(
            &id,
            3,
            TestEvent::SomethingElse(SomethingElse {
                description: "this should not persist".to_string(),
            }),
        );
        let bad_sequence = test_event_envelope(
            &id,
            2,
            TestEvent::SomethingElse(SomethingElse {
                description: "bad sequence number".to_string(),
            }),
        );
        let result = event_repo
            .insert_events::<TestAggregate>(&[not_persisted, bad_sequence])
            .await
            .unwrap_err();
        assert!(
            matches!(result, PostgresAggregateError::OptimisticLock),
            "invalid error result found during insert: {result}"
        );

        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    /// Checks that streaming one aggregate yields exactly its two events and that
    /// streaming all events yields at least those two.
    async fn verify_replay_stream(id: &str, event_repo: PostgresEventRepository) {
        let stream = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        assert_eq!(count_streamed(stream).await, 2);

        let stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        assert!(count_streamed(stream).await >= 2);
    }

    /// Exercises snapshot insertion, a successful update and a rejected update.
    /// Requires the test database.
    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: PostgresEventRepository = PostgresEventRepository::new(pool.clone());
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here";
        let saved_description = "a test description that should be saved";
        let test_tests = vec!["testA".to_string(), "testB".to_string()];

        event_repo
            .insert::<TestAggregate>(
                aggregate_payload(&id, test_description, &test_tests),
                id.clone(),
                1,
                &[],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                aggregate_payload(&id, test_description, &test_tests)
            )),
            snapshot
        );

        // sequence iterated, does update
        event_repo
            .update::<TestAggregate>(
                aggregate_payload(&id, saved_description, &test_tests),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                aggregate_payload(&id, saved_description, &test_tests)
            )),
            snapshot
        );

        // sequence out of order or not iterated, does not update
        let result = event_repo
            .update::<TestAggregate>(
                aggregate_payload(
                    &id,
                    "a test description that should not be saved",
                    &test_tests,
                ),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(result, PostgresAggregateError::OptimisticLock),
            "invalid error result found during update: {result}"
        );

        // The stored snapshot must still be the one written by the successful update.
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                aggregate_payload(&id, saved_description, &test_tests)
            )),
            snapshot
        );
    }
}
537}