1use cqrs_es::persist::{
2 PersistedEventRepository, PersistenceError, ReplayFeed, ReplayStream, SerializedEvent,
3 SerializedSnapshot,
4};
5use cqrs_es::Aggregate;
6use futures::stream::BoxStream;
7use futures::TryStreamExt;
8use serde_json::Value;
9use sqlx::mysql::MySqlRow;
10use sqlx::{MySql, Pool, Row, Transaction};
11
12use crate::error::MysqlAggregateError;
13use crate::sql_query::SqlQueryFactory;
14
/// Default table name used to store serialized events.
const DEFAULT_EVENT_TABLE: &str = "events";
/// Default table name used to store aggregate snapshots.
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

/// Default buffer size of the channel backing a `ReplayStream`.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
19
/// A MySQL-backed event repository for the `cqrs-es` framework, persisting
/// serialized events and aggregate snapshots through a shared `sqlx` pool.
pub struct MysqlEventRepository {
    // Shared connection pool; cloned cheaply for background streaming tasks.
    pool: Pool<MySql>,
    // Produces the SQL statements for the configured event/snapshot tables.
    query_factory: SqlQueryFactory,
    // Buffer size for channels created by `stream_events`/`stream_all_events`.
    stream_channel_size: usize,
}
26
27impl PersistedEventRepository for MysqlEventRepository {
28 async fn get_events<A: Aggregate>(
29 &self,
30 aggregate_id: &str,
31 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
32 self.select_events::<A>(aggregate_id, self.query_factory.select_events())
33 .await
34 }
35
36 async fn get_last_events<A: Aggregate>(
37 &self,
38 aggregate_id: &str,
39 last_sequence: usize,
40 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
41 let query = self.query_factory.get_last_events(last_sequence);
42 self.select_events::<A>(aggregate_id, &query).await
43 }
44
45 async fn get_snapshot<A: Aggregate>(
46 &self,
47 aggregate_id: &str,
48 ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
49 let Some(row) = sqlx::query(self.query_factory.select_snapshot())
50 .bind(A::TYPE)
51 .bind(aggregate_id)
52 .fetch_optional(&self.pool)
53 .await
54 .map_err(MysqlAggregateError::from)?
55 else {
56 return Ok(None);
57 };
58 Ok(Some(self.deser_snapshot(&row)))
59 }
60
61 async fn persist<A: Aggregate>(
62 &self,
63 events: &[SerializedEvent],
64 snapshot_update: Option<(String, Value, usize)>,
65 ) -> Result<(), PersistenceError> {
66 match snapshot_update {
67 None => {
68 self.insert_events::<A>(events).await?;
69 }
70 Some((aggregate_id, aggregate, current_snapshot)) => {
71 if current_snapshot == 1 {
72 self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
73 .await?;
74 } else {
75 self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
76 .await?;
77 }
78 }
79 }
80 Ok(())
81 }
82
83 async fn stream_events<A: Aggregate>(
84 &self,
85 aggregate_id: &str,
86 ) -> Result<ReplayStream, PersistenceError> {
87 Ok(stream_events(
88 self.query_factory.select_events().to_string(),
89 A::TYPE.to_string(),
90 aggregate_id.to_string(),
91 self.pool.clone(),
92 self.stream_channel_size,
93 ))
94 }
95
96 async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
97 Ok(stream_all_events(
98 self.query_factory.all_events().to_string(),
99 A::TYPE.to_string(),
100 self.pool.clone(),
101 self.stream_channel_size,
102 ))
103 }
104}
105
106fn stream_events(
107 query: String,
108 aggregate_type: String,
109 aggregate_id: String,
110 pool: Pool<MySql>,
111 channel_size: usize,
112) -> ReplayStream {
113 let (feed, stream) = ReplayStream::new(channel_size);
114 tokio::spawn(async move {
115 let query = sqlx::query(&query)
116 .bind(&aggregate_type)
117 .bind(&aggregate_id);
118 let rows = query.fetch(&pool);
119 process_rows(feed, rows).await;
120 });
121 stream
122}
123fn stream_all_events(
124 query: String,
125 aggregate_type: String,
126 pool: Pool<MySql>,
127 channel_size: usize,
128) -> ReplayStream {
129 let (feed, stream) = ReplayStream::new(channel_size);
130 tokio::spawn(async move {
131 let query = sqlx::query(&query).bind(&aggregate_type);
132 let rows = query.fetch(&pool);
133 process_rows(feed, rows).await;
134 });
135 stream
136}
137
138async fn process_rows(
139 mut feed: ReplayFeed,
140 mut rows: BoxStream<'_, Result<MySqlRow, sqlx::Error>>,
141) {
142 while let Some(row) = rows.try_next().await.unwrap() {
143 let event_result: Result<SerializedEvent, PersistenceError> =
144 MysqlEventRepository::deser_event(row).map_err(Into::into);
145 if feed.push(event_result).await.is_err() {
146 break;
148 }
149 }
150}
151
152impl MysqlEventRepository {
153 async fn select_events<A: Aggregate>(
154 &self,
155 aggregate_id: &str,
156 query: &str,
157 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
158 let mut rows = sqlx::query(query)
159 .bind(A::TYPE)
160 .bind(aggregate_id)
161 .fetch(&self.pool);
162 let mut result: Vec<SerializedEvent> = Default::default();
163 while let Some(row) = rows.try_next().await.map_err(MysqlAggregateError::from)? {
164 result.push(Self::deser_event(row)?);
165 }
166 Ok(result)
167 }
168}
169
170impl MysqlEventRepository {
171 pub fn new(pool: Pool<MySql>) -> Self {
184 Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
185 }
186
187 pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
200 Self {
201 pool: self.pool,
202 query_factory: self.query_factory,
203 stream_channel_size,
204 }
205 }
206 pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
220 Self::use_tables(self.pool, events_table, snapshots_table)
221 }
222
223 fn use_tables(pool: Pool<MySql>, events_table: &str, snapshots_table: &str) -> Self {
224 Self {
225 pool,
226 query_factory: SqlQueryFactory::new(events_table, snapshots_table),
227 stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
228 }
229 }
230
231 pub(crate) async fn insert_events<A: Aggregate>(
232 &self,
233 events: &[SerializedEvent],
234 ) -> Result<(), MysqlAggregateError> {
235 let mut tx: Transaction<'_, MySql> = sqlx::Acquire::begin(&self.pool).await?;
236 self.persist_events::<A>(&mut tx, events).await?;
237 tx.commit().await?;
238 Ok(())
239 }
240
241 pub(crate) async fn insert<A: Aggregate>(
242 &self,
243 aggregate_payload: Value,
244 aggregate_id: String,
245 current_snapshot: usize,
246 events: &[SerializedEvent],
247 ) -> Result<(), MysqlAggregateError> {
248 let mut tx: Transaction<'_, MySql> = sqlx::Acquire::begin(&self.pool).await?;
249 let current_sequence = self.persist_events::<A>(&mut tx, events).await?;
250 sqlx::query(self.query_factory.insert_snapshot())
251 .bind(A::TYPE)
252 .bind(aggregate_id.as_str())
253 .bind(current_sequence as u32)
254 .bind(current_snapshot as u32)
255 .bind(&aggregate_payload)
256 .execute(&mut *tx)
257 .await?;
258 tx.commit().await?;
259 Ok(())
260 }
261
262 pub(crate) async fn update<A: Aggregate>(
263 &self,
264 aggregate: Value,
265 aggregate_id: String,
266 current_snapshot: usize,
267 events: &[SerializedEvent],
268 ) -> Result<(), MysqlAggregateError> {
269 let mut tx: Transaction<'_, MySql> = sqlx::Acquire::begin(&self.pool).await?;
270 let current_sequence = self.persist_events::<A>(&mut tx, events).await?;
271
272 let aggregate_payload = serde_json::to_value(&aggregate)?;
273 let result = sqlx::query(self.query_factory.update_snapshot())
274 .bind(current_sequence as u32)
275 .bind(&aggregate_payload)
276 .bind(current_snapshot as u32)
277 .bind(A::TYPE)
278 .bind(aggregate_id.as_str())
279 .bind((current_snapshot - 1) as u32)
280 .execute(&mut *tx)
281 .await?;
282 tx.commit().await?;
283 match result.rows_affected() {
284 1 => Ok(()),
285 _ => Err(MysqlAggregateError::OptimisticLock),
286 }
287 }
288
289 fn deser_event(row: MySqlRow) -> Result<SerializedEvent, MysqlAggregateError> {
290 let aggregate_type: String = row.get("aggregate_type");
291 let aggregate_id: String = row.get("aggregate_id");
292 let sequence = {
293 let s: i64 = row.get("sequence");
294 s as usize
295 };
296 let event_type: String = row.get("event_type");
297 let event_version: String = row.get("event_version");
298 let payload: Value = row.get("payload");
299 let metadata: Value = row.get("metadata");
300 Ok(SerializedEvent::new(
301 aggregate_id,
302 sequence,
303 aggregate_type,
304 event_type,
305 event_version,
306 payload,
307 metadata,
308 ))
309 }
310
311 fn deser_snapshot(&self, row: &MySqlRow) -> SerializedSnapshot {
312 let aggregate_id = row.get("aggregate_id");
313 let s: i64 = row.get("last_sequence");
314 let current_sequence = s as usize;
315 let s: i64 = row.get("current_snapshot");
316 let current_snapshot = s as usize;
317 let aggregate: Value = row.get("payload");
318 SerializedSnapshot {
319 aggregate_id,
320 aggregate,
321 current_sequence,
322 current_snapshot,
323 }
324 }
325
326 pub(crate) async fn persist_events<A: Aggregate>(
327 &self,
328 tx: &mut Transaction<'_, MySql>,
329 events: &[SerializedEvent],
330 ) -> Result<usize, MysqlAggregateError> {
331 let mut current_sequence: usize = 0;
332 for event in events {
333 current_sequence = event.sequence;
334 let event_type = &event.event_type;
335 let event_version = &event.event_version;
336 let payload = serde_json::to_value(&event.payload)?;
337 let metadata = serde_json::to_value(&event.metadata)?;
338 sqlx::query(self.query_factory.insert_event())
339 .bind(A::TYPE)
340 .bind(event.aggregate_id.as_str())
341 .bind(event.sequence as u32)
342 .bind(event_type)
343 .bind(event_version)
344 .bind(&payload)
345 .bind(&metadata)
346 .execute(&mut **tx)
347 .await?;
348 }
349 Ok(current_sequence)
350 }
351}
352
353#[cfg(test)]
354mod test {
355 use cqrs_es::persist::PersistedEventRepository;
356
357 use crate::error::MysqlAggregateError;
358 use crate::testing::tests::{
359 snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
360 Tested, TEST_CONNECTION_STRING,
361 };
362 use crate::{default_mysql_pool, MysqlEventRepository};
363
    /// Integration test: round-trips events through a live MySQL instance,
    /// verifies duplicate-sequence rejection, then drains both replay streams.
    /// Requires a database reachable at `TEST_CONNECTION_STRING`.
    #[tokio::test]
    async fn event_repositories() {
        let pool = default_mysql_pool(TEST_CONNECTION_STRING).await;
        // Fresh UUID per run so the test never collides with existing rows.
        let id = uuid::Uuid::new_v4().to_string();
        // Channel size 1 exercises backpressure in verify_replay_stream below.
        let event_repo = MysqlEventRepository::new(pool.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        // Persist two events with consecutive sequence numbers.
        event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

        // Sequence 2 already exists, so the whole batch (including the valid
        // sequence 3) must be rejected as an optimistic-lock conflict.
        let result = event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        match result {
            MysqlAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {result}"),
        }

        // The failed batch must not have persisted either event.
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        verify_replay_stream(&id, event_repo).await;
    }
418
419 async fn verify_replay_stream(id: &str, event_repo: MysqlEventRepository) {
420 let mut stream = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
421 let mut found_in_stream = 0;
422 while (stream.next::<TestAggregate>(&[]).await).is_some() {
423 found_in_stream += 1;
424 }
425 assert_eq!(found_in_stream, 2);
426
427 let mut stream = event_repo
428 .stream_all_events::<TestAggregate>()
429 .await
430 .unwrap();
431 let mut found_in_stream = 0;
432 while (stream.next::<TestAggregate>(&[]).await).is_some() {
433 found_in_stream += 1;
434 }
435 assert!(found_in_stream >= 2);
436 }
437
    /// Integration test: inserts, updates, and optimistically-lock-rejects
    /// snapshots against a live MySQL instance. Requires a database reachable
    /// at `TEST_CONNECTION_STRING`.
    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_mysql_pool(TEST_CONNECTION_STRING).await;
        // Fresh UUID per run so the test never collides with existing rows.
        let id = uuid::Uuid::new_v4().to_string();
        let repo = MysqlEventRepository::new(pool.clone());
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        // First snapshot (version 1) goes through `insert`; no events, so the
        // last event sequence stays 0.
        repo.insert::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: test_description.clone(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            1,
            &[],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

        // Version 2 updates the existing row (expects previous version 1).
        repo.update::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: "a test description that should be saved".to_string(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            2,
            &[],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

        // Repeating version 2 expects previous version 1 again, which no
        // longer matches — the optimistic lock must reject the write.
        let result = repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(result, MysqlAggregateError::OptimisticLock),
            "invalid error result found during insert: {result}"
        );

        // The rejected update must not have changed the stored snapshot.
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );
    }
545}