Skip to main content

mysql_es/
event_repository.rs

1use cqrs_es::persist::{
2    PersistedEventRepository, PersistenceError, ReplayFeed, ReplayStream, SerializedEvent,
3    SerializedSnapshot,
4};
5use cqrs_es::Aggregate;
6use futures::stream::BoxStream;
7use futures::TryStreamExt;
8use serde_json::Value;
9use sqlx::mysql::MySqlRow;
10use sqlx::{MySql, Pool, Row, Transaction};
11
12use crate::error::MysqlAggregateError;
13use crate::sql_query::SqlQueryFactory;
14
/// Name of the events table used by [`MysqlEventRepository::new`].
const DEFAULT_EVENT_TABLE: &str = "events";
/// Name of the snapshots table used by [`MysqlEventRepository::new`].
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

/// Channel size handed to `ReplayStream::new` unless overridden with
/// `with_streaming_channel_size`.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
19
20/// An event repository relying on a MySql database for persistence.
21pub struct MysqlEventRepository {
22    pool: Pool<MySql>,
23    query_factory: SqlQueryFactory,
24    stream_channel_size: usize,
25}
26
27impl PersistedEventRepository for MysqlEventRepository {
28    async fn get_events<A: Aggregate>(
29        &self,
30        aggregate_id: &str,
31    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
32        self.select_events::<A>(aggregate_id, self.query_factory.select_events())
33            .await
34    }
35
36    async fn get_last_events<A: Aggregate>(
37        &self,
38        aggregate_id: &str,
39        last_sequence: usize,
40    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
41        let query = self.query_factory.get_last_events(last_sequence);
42        self.select_events::<A>(aggregate_id, &query).await
43    }
44
45    async fn get_snapshot<A: Aggregate>(
46        &self,
47        aggregate_id: &str,
48    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
49        let Some(row) = sqlx::query(self.query_factory.select_snapshot())
50            .bind(A::TYPE)
51            .bind(aggregate_id)
52            .fetch_optional(&self.pool)
53            .await
54            .map_err(MysqlAggregateError::from)?
55        else {
56            return Ok(None);
57        };
58        Ok(Some(self.deser_snapshot(&row)))
59    }
60
61    async fn persist<A: Aggregate>(
62        &self,
63        events: &[SerializedEvent],
64        snapshot_update: Option<(String, Value, usize)>,
65    ) -> Result<(), PersistenceError> {
66        match snapshot_update {
67            None => {
68                self.insert_events::<A>(events).await?;
69            }
70            Some((aggregate_id, aggregate, current_snapshot)) => {
71                if current_snapshot == 1 {
72                    self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
73                        .await?;
74                } else {
75                    self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
76                        .await?;
77                }
78            }
79        }
80        Ok(())
81    }
82
83    async fn stream_events<A: Aggregate>(
84        &self,
85        aggregate_id: &str,
86    ) -> Result<ReplayStream, PersistenceError> {
87        Ok(stream_events(
88            self.query_factory.select_events().to_string(),
89            A::TYPE.to_string(),
90            aggregate_id.to_string(),
91            self.pool.clone(),
92            self.stream_channel_size,
93        ))
94    }
95
96    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
97        Ok(stream_all_events(
98            self.query_factory.all_events().to_string(),
99            A::TYPE.to_string(),
100            self.pool.clone(),
101            self.stream_channel_size,
102        ))
103    }
104}
105
106fn stream_events(
107    query: String,
108    aggregate_type: String,
109    aggregate_id: String,
110    pool: Pool<MySql>,
111    channel_size: usize,
112) -> ReplayStream {
113    let (feed, stream) = ReplayStream::new(channel_size);
114    tokio::spawn(async move {
115        let query = sqlx::query(&query)
116            .bind(&aggregate_type)
117            .bind(&aggregate_id);
118        let rows = query.fetch(&pool);
119        process_rows(feed, rows).await;
120    });
121    stream
122}
123fn stream_all_events(
124    query: String,
125    aggregate_type: String,
126    pool: Pool<MySql>,
127    channel_size: usize,
128) -> ReplayStream {
129    let (feed, stream) = ReplayStream::new(channel_size);
130    tokio::spawn(async move {
131        let query = sqlx::query(&query).bind(&aggregate_type);
132        let rows = query.fetch(&pool);
133        process_rows(feed, rows).await;
134    });
135    stream
136}
137
138async fn process_rows(
139    mut feed: ReplayFeed,
140    mut rows: BoxStream<'_, Result<MySqlRow, sqlx::Error>>,
141) {
142    while let Some(row) = rows.try_next().await.unwrap() {
143        let event_result: Result<SerializedEvent, PersistenceError> =
144            MysqlEventRepository::deser_event(row).map_err(Into::into);
145        if feed.push(event_result).await.is_err() {
146            // TODO: in the unlikely event of a broken channel this error should be reported.
147            break;
148        }
149    }
150}
151
152impl MysqlEventRepository {
153    async fn select_events<A: Aggregate>(
154        &self,
155        aggregate_id: &str,
156        query: &str,
157    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
158        let mut rows = sqlx::query(query)
159            .bind(A::TYPE)
160            .bind(aggregate_id)
161            .fetch(&self.pool);
162        let mut result: Vec<SerializedEvent> = Default::default();
163        while let Some(row) = rows.try_next().await.map_err(MysqlAggregateError::from)? {
164            result.push(Self::deser_event(row)?);
165        }
166        Ok(result)
167    }
168}
169
170impl MysqlEventRepository {
171    /// Creates a new `MysqlEventRepository` from the provided database connection
172    /// used for backing a `MysqlSnapshotStore`. This uses the default tables 'events'
173    /// and 'snapshots'.
174    ///
175    /// ```
176    /// use sqlx::{MySql, Pool};
177    /// use mysql_es::MysqlEventRepository;
178    ///
179    /// fn configure_repo(pool: Pool<MySql>) -> MysqlEventRepository {
180    ///     MysqlEventRepository::new(pool)
181    /// }
182    /// ```
183    pub fn new(pool: Pool<MySql>) -> Self {
184        Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
185    }
186
187    /// Configures a `MysqlEventRepository` to use a streaming queue of the provided size.
188    ///
189    /// _Example: configure the repository to stream with a 1000 event buffer._
190    /// ```
191    /// use sqlx::{MySql, Pool};
192    /// use mysql_es::MysqlEventRepository;
193    ///
194    /// fn configure_repo(pool: Pool<MySql>) -> MysqlEventRepository {
195    ///     let store = MysqlEventRepository::new(pool);
196    ///     store.with_streaming_channel_size(1000)
197    /// }
198    /// ```
199    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
200        Self {
201            pool: self.pool,
202            query_factory: self.query_factory,
203            stream_channel_size,
204        }
205    }
206    /// Configures a `MysqlEventRepository` to use the provided table names.
207    ///
208    /// _Example: configure the repository to use "my_event_table" and "my_snapshot_table"
209    /// for the event and snapshot table names._
210    /// ```
211    /// use sqlx::{MySql, Pool};
212    /// use mysql_es::MysqlEventRepository;
213    ///
214    /// fn configure_repo(pool: Pool<MySql>) -> MysqlEventRepository {
215    ///     let store = MysqlEventRepository::new(pool);
216    ///     store.with_tables("my_event_table", "my_snapshot_table")
217    /// }
218    /// ```
219    pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
220        Self::use_tables(self.pool, events_table, snapshots_table)
221    }
222
223    fn use_tables(pool: Pool<MySql>, events_table: &str, snapshots_table: &str) -> Self {
224        Self {
225            pool,
226            query_factory: SqlQueryFactory::new(events_table, snapshots_table),
227            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
228        }
229    }
230
231    pub(crate) async fn insert_events<A: Aggregate>(
232        &self,
233        events: &[SerializedEvent],
234    ) -> Result<(), MysqlAggregateError> {
235        let mut tx: Transaction<'_, MySql> = sqlx::Acquire::begin(&self.pool).await?;
236        self.persist_events::<A>(&mut tx, events).await?;
237        tx.commit().await?;
238        Ok(())
239    }
240
241    pub(crate) async fn insert<A: Aggregate>(
242        &self,
243        aggregate_payload: Value,
244        aggregate_id: String,
245        current_snapshot: usize,
246        events: &[SerializedEvent],
247    ) -> Result<(), MysqlAggregateError> {
248        let mut tx: Transaction<'_, MySql> = sqlx::Acquire::begin(&self.pool).await?;
249        let current_sequence = self.persist_events::<A>(&mut tx, events).await?;
250        sqlx::query(self.query_factory.insert_snapshot())
251            .bind(A::TYPE)
252            .bind(aggregate_id.as_str())
253            .bind(current_sequence as u32)
254            .bind(current_snapshot as u32)
255            .bind(&aggregate_payload)
256            .execute(&mut *tx)
257            .await?;
258        tx.commit().await?;
259        Ok(())
260    }
261
262    pub(crate) async fn update<A: Aggregate>(
263        &self,
264        aggregate: Value,
265        aggregate_id: String,
266        current_snapshot: usize,
267        events: &[SerializedEvent],
268    ) -> Result<(), MysqlAggregateError> {
269        let mut tx: Transaction<'_, MySql> = sqlx::Acquire::begin(&self.pool).await?;
270        let current_sequence = self.persist_events::<A>(&mut tx, events).await?;
271
272        let aggregate_payload = serde_json::to_value(&aggregate)?;
273        let result = sqlx::query(self.query_factory.update_snapshot())
274            .bind(current_sequence as u32)
275            .bind(&aggregate_payload)
276            .bind(current_snapshot as u32)
277            .bind(A::TYPE)
278            .bind(aggregate_id.as_str())
279            .bind((current_snapshot - 1) as u32)
280            .execute(&mut *tx)
281            .await?;
282        tx.commit().await?;
283        match result.rows_affected() {
284            1 => Ok(()),
285            _ => Err(MysqlAggregateError::OptimisticLock),
286        }
287    }
288
289    fn deser_event(row: MySqlRow) -> Result<SerializedEvent, MysqlAggregateError> {
290        let aggregate_type: String = row.get("aggregate_type");
291        let aggregate_id: String = row.get("aggregate_id");
292        let sequence = {
293            let s: i64 = row.get("sequence");
294            s as usize
295        };
296        let event_type: String = row.get("event_type");
297        let event_version: String = row.get("event_version");
298        let payload: Value = row.get("payload");
299        let metadata: Value = row.get("metadata");
300        Ok(SerializedEvent::new(
301            aggregate_id,
302            sequence,
303            aggregate_type,
304            event_type,
305            event_version,
306            payload,
307            metadata,
308        ))
309    }
310
311    fn deser_snapshot(&self, row: &MySqlRow) -> SerializedSnapshot {
312        let aggregate_id = row.get("aggregate_id");
313        let s: i64 = row.get("last_sequence");
314        let current_sequence = s as usize;
315        let s: i64 = row.get("current_snapshot");
316        let current_snapshot = s as usize;
317        let aggregate: Value = row.get("payload");
318        SerializedSnapshot {
319            aggregate_id,
320            aggregate,
321            current_sequence,
322            current_snapshot,
323        }
324    }
325
326    pub(crate) async fn persist_events<A: Aggregate>(
327        &self,
328        tx: &mut Transaction<'_, MySql>,
329        events: &[SerializedEvent],
330    ) -> Result<usize, MysqlAggregateError> {
331        let mut current_sequence: usize = 0;
332        for event in events {
333            current_sequence = event.sequence;
334            let event_type = &event.event_type;
335            let event_version = &event.event_version;
336            let payload = serde_json::to_value(&event.payload)?;
337            let metadata = serde_json::to_value(&event.metadata)?;
338            sqlx::query(self.query_factory.insert_event())
339                .bind(A::TYPE)
340                .bind(event.aggregate_id.as_str())
341                .bind(event.sequence as u32)
342                .bind(event_type)
343                .bind(event_version)
344                .bind(&payload)
345                .bind(&metadata)
346                .execute(&mut **tx)
347                .await?;
348        }
349        Ok(current_sequence)
350    }
351}
352
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::MysqlAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
        Tested, TEST_CONNECTION_STRING,
    };
    use crate::{default_mysql_pool, MysqlEventRepository};

    /// Serializes a `TestAggregate` with the given fields into a JSON value.
    fn aggregate_json(id: &str, description: &str, tests: &[String]) -> serde_json::Value {
        serde_json::to_value(TestAggregate {
            id: id.to_string(),
            description: description.to_string(),
            tests: tests.to_vec(),
        })
        .unwrap()
    }

    /// Drains `stream` and returns how many items it yielded.
    async fn count_streamed(mut stream: cqrs_es::persist::ReplayStream) -> usize {
        let mut seen = 0;
        while stream.next::<TestAggregate>(&[]).await.is_some() {
            seen += 1;
        }
        seen
    }

    #[tokio::test]
    async fn event_repositories() {
        let pool = default_mysql_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo = MysqlEventRepository::new(pool.clone()).with_streaming_channel_size(1);

        let initial = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(initial.is_empty());

        let first_batch = [
            test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
            test_event_envelope(
                &id,
                2,
                TestEvent::Tested(Tested {
                    test_name: "a test was run".to_string(),
                }),
            ),
        ];
        event_repo
            .insert_events::<TestAggregate>(&first_batch)
            .await
            .unwrap();

        let stored = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, stored.len());
        for event in &stored {
            assert_eq!(&id, &event.aggregate_id);
        }

        let conflicting_batch = [
            test_event_envelope(
                &id,
                3,
                TestEvent::SomethingElse(SomethingElse {
                    description: "this should not persist".to_string(),
                }),
            ),
            test_event_envelope(
                &id,
                2,
                TestEvent::SomethingElse(SomethingElse {
                    description: "bad sequence number".to_string(),
                }),
            ),
        ];
        let err = event_repo
            .insert_events::<TestAggregate>(&conflicting_batch)
            .await
            .unwrap_err();
        assert!(
            matches!(err, MysqlAggregateError::OptimisticLock),
            "invalid error result found during insert: {err}"
        );

        let after_conflict = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, after_conflict.len());

        verify_replay_stream(&id, event_repo).await;
    }

    async fn verify_replay_stream(id: &str, event_repo: MysqlEventRepository) {
        let single = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        assert_eq!(count_streamed(single).await, 2);

        let everything = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        assert!(count_streamed(everything).await >= 2);
    }

    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_mysql_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let repo = MysqlEventRepository::new(pool.clone());
        let missing = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, missing);

        let test_description = "some test snapshot here";
        let saved_description = "a test description that should be saved";
        let test_tests = vec!["testA".to_string(), "testB".to_string()];

        repo.insert::<TestAggregate>(
            aggregate_json(&id, test_description, &test_tests),
            id.clone(),
            1,
            &[],
        )
        .await
        .unwrap();

        let first = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        let expected_first = Some(snapshot_context(
            id.clone(),
            0,
            1,
            aggregate_json(&id, test_description, &test_tests),
        ));
        assert_eq!(expected_first, first);

        // sequence iterated, does update
        repo.update::<TestAggregate>(
            aggregate_json(&id, saved_description, &test_tests),
            id.clone(),
            2,
            &[],
        )
        .await
        .unwrap();

        let expected_saved = Some(snapshot_context(
            id.clone(),
            0,
            2,
            aggregate_json(&id, saved_description, &test_tests),
        ));
        let after_update = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(expected_saved, after_update);

        // sequence out of order or not iterated, does not update
        let err = repo
            .update::<TestAggregate>(
                aggregate_json(
                    &id,
                    "a test description that should not be saved",
                    &test_tests,
                ),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(err, MysqlAggregateError::OptimisticLock),
            "invalid error result found during insert: {err}"
        );

        let after_rejected = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(expected_saved, after_rejected);
    }
}