1use async_trait::async_trait;
2use cqrs_es::persist::{
3 PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
4};
5use cqrs_es::Aggregate;
6use futures::TryStreamExt;
7use serde_json::Value;
8use sqlx::postgres::PgRow;
9use sqlx::{Pool, Postgres, Row, Transaction};
10
11use crate::error::PostgresAggregateError;
12use crate::sql_query::SqlQueryFactory;
13
/// Default name of the table holding serialized events.
const DEFAULT_EVENT_TABLE: &str = "events";
/// Default name of the table holding aggregate snapshots.
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

/// Default buffer size of the channel backing event replay streams.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
18
/// An event repository backed by a PostgreSQL database.
pub struct PostgresEventRepository {
    // Connection pool used for every query and transaction.
    pool: Pool<Postgres>,
    // Pre-built SQL statements for the configured event/snapshot tables.
    query_factory: SqlQueryFactory,
    // Buffer size of the channel used by `stream_events`/`stream_all_events`.
    stream_channel_size: usize,
}
25
26#[async_trait]
27impl PersistedEventRepository for PostgresEventRepository {
28 async fn get_events<A: Aggregate>(
29 &self,
30 aggregate_id: &str,
31 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
32 self.select_events::<A>(aggregate_id, self.query_factory.select_events())
33 .await
34 }
35
36 async fn get_last_events<A: Aggregate>(
37 &self,
38 aggregate_id: &str,
39 last_sequence: usize,
40 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
41 let query = self.query_factory.get_last_events(last_sequence);
42 self.select_events::<A>(aggregate_id, &query).await
43 }
44
45 async fn get_snapshot<A: Aggregate>(
46 &self,
47 aggregate_id: &str,
48 ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
49 let row: PgRow = match sqlx::query(self.query_factory.select_snapshot())
50 .bind(A::aggregate_type())
51 .bind(aggregate_id)
52 .fetch_optional(&self.pool)
53 .await
54 .map_err(PostgresAggregateError::from)?
55 {
56 Some(row) => row,
57 None => {
58 return Ok(None);
59 }
60 };
61 Ok(Some(self.deser_snapshot(row)?))
62 }
63
64 async fn persist<A: Aggregate>(
65 &self,
66 events: &[SerializedEvent],
67 snapshot_update: Option<(String, Value, usize)>,
68 ) -> Result<(), PersistenceError> {
69 match snapshot_update {
70 None => {
71 self.insert_events::<A>(events).await?;
72 }
73 Some((aggregate_id, aggregate, current_snapshot)) => {
74 if current_snapshot == 1 {
75 self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
76 .await?;
77 } else {
78 self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
79 .await?;
80 }
81 }
82 };
83 Ok(())
84 }
85
86 async fn stream_events<A: Aggregate>(
87 &self,
88 aggregate_id: &str,
89 ) -> Result<ReplayStream, PersistenceError> {
90 Ok(stream_events(
91 self.query_factory.select_events().to_string(),
92 A::aggregate_type(),
93 aggregate_id.to_string(),
94 self.pool.clone(),
95 self.stream_channel_size,
96 ))
97 }
98
99 async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
101 Ok(stream_events(
102 self.query_factory.all_events().to_string(),
103 A::aggregate_type(),
104 "".to_string(),
105 self.pool.clone(),
106 self.stream_channel_size,
107 ))
108 }
109}
110
111fn stream_events(
112 query: String,
113 aggregate_type: String,
114 aggregate_id: String,
115 pool: Pool<Postgres>,
116 channel_size: usize,
117) -> ReplayStream {
118 let (mut feed, stream) = ReplayStream::new(channel_size);
119 tokio::spawn(async move {
120 let query = sqlx::query(&query)
121 .bind(&aggregate_type)
122 .bind(&aggregate_id);
123 let mut rows = query.fetch(&pool);
124 while let Some(row) = rows.try_next().await.unwrap() {
125 let event_result: Result<SerializedEvent, PersistenceError> =
126 PostgresEventRepository::deser_event(row).map_err(Into::into);
127 if feed.push(event_result).await.is_err() {
128 return;
130 };
131 }
132 });
133 stream
134}
135
136impl PostgresEventRepository {
137 async fn select_events<A: Aggregate>(
138 &self,
139 aggregate_id: &str,
140 query: &str,
141 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
142 let mut rows = sqlx::query(query)
143 .bind(A::aggregate_type())
144 .bind(aggregate_id)
145 .fetch(&self.pool);
146 let mut result: Vec<SerializedEvent> = Default::default();
147 while let Some(row) = rows
148 .try_next()
149 .await
150 .map_err(PostgresAggregateError::from)?
151 {
152 result.push(PostgresEventRepository::deser_event(row)?);
153 }
154 Ok(result)
155 }
156}
157
158impl PostgresEventRepository {
159 pub fn new(pool: Pool<Postgres>) -> Self {
171 Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
172 }
173
174 pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
187 Self {
188 pool: self.pool,
189 query_factory: self.query_factory,
190 stream_channel_size,
191 }
192 }
193
194 pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
208 Self::use_tables(self.pool, events_table, snapshots_table)
209 }
210
211 fn use_tables(pool: Pool<Postgres>, events_table: &str, snapshots_table: &str) -> Self {
212 Self {
213 pool,
214 query_factory: SqlQueryFactory::new(events_table, snapshots_table),
215 stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
216 }
217 }
218
219 pub(crate) async fn insert_events<A: Aggregate>(
220 &self,
221 events: &[SerializedEvent],
222 ) -> Result<(), PostgresAggregateError> {
223 let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
224 self.persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
225 .await?;
226 tx.commit().await?;
227 Ok(())
228 }
229
230 pub(crate) async fn insert<A: Aggregate>(
231 &self,
232 aggregate_payload: Value,
233 aggregate_id: String,
234 current_snapshot: usize,
235 events: &[SerializedEvent],
236 ) -> Result<(), PostgresAggregateError> {
237 let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
238 let current_sequence = self
239 .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
240 .await?;
241 sqlx::query(self.query_factory.insert_snapshot())
242 .bind(A::aggregate_type())
243 .bind(aggregate_id.as_str())
244 .bind(current_sequence as i32)
245 .bind(current_snapshot as i32)
246 .bind(&aggregate_payload)
247 .execute(&mut *tx)
248 .await?;
249 tx.commit().await?;
250 Ok(())
251 }
252
253 pub(crate) async fn update<A: Aggregate>(
254 &self,
255 aggregate: Value,
256 aggregate_id: String,
257 current_snapshot: usize,
258 events: &[SerializedEvent],
259 ) -> Result<(), PostgresAggregateError> {
260 let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
261 let current_sequence = self
262 .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
263 .await?;
264
265 let aggregate_payload = serde_json::to_value(&aggregate)?;
266 let result = sqlx::query(self.query_factory.update_snapshot())
267 .bind(A::aggregate_type())
268 .bind(aggregate_id.as_str())
269 .bind(current_sequence as i32)
270 .bind(current_snapshot as i32)
271 .bind((current_snapshot - 1) as i32)
272 .bind(&aggregate_payload)
273 .execute(&mut *tx)
274 .await?;
275 tx.commit().await?;
276 match result.rows_affected() {
277 1 => Ok(()),
278 _ => Err(PostgresAggregateError::OptimisticLock),
279 }
280 }
281
282 fn deser_event(row: PgRow) -> Result<SerializedEvent, PostgresAggregateError> {
283 let aggregate_type: String = row.get("aggregate_type");
284 let aggregate_id: String = row.get("aggregate_id");
285 let sequence = {
286 let s: i64 = row.get("sequence");
287 s as usize
288 };
289 let event_type: String = row.get("event_type");
290 let event_version: String = row.get("event_version");
291 let payload: Value = row.get("payload");
292 let metadata: Value = row.get("metadata");
293 Ok(SerializedEvent::new(
294 aggregate_id,
295 sequence,
296 aggregate_type,
297 event_type,
298 event_version,
299 payload,
300 metadata,
301 ))
302 }
303
304 fn deser_snapshot(&self, row: PgRow) -> Result<SerializedSnapshot, PostgresAggregateError> {
305 let aggregate_id = row.get("aggregate_id");
306 let s: i64 = row.get("last_sequence");
307 let current_sequence = s as usize;
308 let s: i64 = row.get("current_snapshot");
309 let current_snapshot = s as usize;
310 let aggregate: Value = row.get("payload");
311 Ok(SerializedSnapshot {
312 aggregate_id,
313 aggregate,
314 current_sequence,
315 current_snapshot,
316 })
317 }
318
319 pub(crate) async fn persist_events<A: Aggregate>(
320 &self,
321 inser_event_query: &str,
322 tx: &mut Transaction<'_, Postgres>,
323 events: &[SerializedEvent],
324 ) -> Result<usize, PostgresAggregateError> {
325 let mut current_sequence: usize = 0;
326 for event in events {
327 current_sequence = event.sequence;
328 let event_type = &event.event_type;
329 let event_version = &event.event_version;
330 let payload = serde_json::to_value(&event.payload)?;
331 let metadata = serde_json::to_value(&event.metadata)?;
332 sqlx::query(inser_event_query)
333 .bind(A::aggregate_type())
334 .bind(event.aggregate_id.as_str())
335 .bind(event.sequence as i32)
336 .bind(event_type)
337 .bind(event_version)
338 .bind(&payload)
339 .bind(&metadata)
340 .execute(&mut **tx)
341 .await?;
342 }
343 Ok(current_sequence)
344 }
345}
346
// Integration tests: these require a live PostgreSQL instance reachable via
// TEST_CONNECTION_STRING (see crate::testing) and exercise the repository
// end-to-end, so they are ordering-sensitive — each assertion depends on the
// database state built up by the statements before it.
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::PostgresAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
        Tested, TEST_CONNECTION_STRING,
    };
    use crate::{default_postgress_pool, PostgresEventRepository};

    // Round-trips events through insert/select, verifies the optimistic-lock
    // error on a duplicate sequence number, then checks the replay streams.
    #[tokio::test]
    async fn event_repositories() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        // Fresh aggregate id per run so tests don't collide with prior data.
        let id = uuid::Uuid::new_v4().to_string();
        // Channel size 1 forces the streaming path to exercise backpressure.
        let event_repo: PostgresEventRepository =
            PostgresEventRepository::new(pool.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        // Two well-ordered events (sequences 1 and 2) should persist cleanly.
        event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

        // A batch containing an already-used sequence number (2) must fail
        // with an optimistic-lock error, and — because events are written in
        // one transaction — neither event in the batch may persist.
        let result = event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        match result {
            PostgresAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {}", result),
        };

        // Still only the two original events.
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    // Confirms both replay streams deliver the previously inserted events:
    // exactly 2 for the single-aggregate stream, at least 2 for the
    // all-aggregates stream (other test runs may have added more).
    async fn verify_replay_stream(id: &str, event_repo: PostgresEventRepository) {
        let mut stream = event_repo
            .stream_events::<TestAggregate>(&id)
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while let Some(_) = stream.next::<TestAggregate>(&None).await {
            found_in_stream += 1;
        }
        assert_eq!(found_in_stream, 2);

        let mut stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while let Some(_) = stream.next::<TestAggregate>(&None).await {
            found_in_stream += 1;
        }
        assert!(found_in_stream >= 2);
    }

    // Exercises the snapshot lifecycle: missing → insert (version 1) →
    // update (version 2) → rejected update with a stale version guard.
    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: PostgresEventRepository = PostgresEventRepository::new(pool.clone());
        // No snapshot exists yet for a fresh aggregate id.
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        // First snapshot (version 1) is written via `insert`.
        event_repo
            .insert::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                1,
                &vec![],
            )
            .await
            .unwrap();

        // Read-back matches: sequence 0 (no events were written), version 1.
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

        // Version 2 replaces version 1 via `update`.
        event_repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &vec![],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

        // A second update claiming version 2 again targets previous version 1,
        // which no longer matches — the optimistic-lock guard must reject it.
        let result = event_repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &vec![],
            )
            .await
            .unwrap_err();
        match result {
            PostgresAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {}", result),
        };

        // The stored snapshot is unchanged by the rejected update.
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );
    }
}