use cqrs_es::persist::{
    PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
};
use cqrs_es::Aggregate;
use futures::TryStreamExt;
use serde_json::Value;
use sqlx::postgres::PgRow;
use sqlx::{Pool, Postgres, Row, Transaction};

use crate::error::PostgresAggregateError;
use crate::sql_query::SqlQueryFactory;

const DEFAULT_EVENT_TABLE: &str = "events";
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;

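/// A Postgres-backed event repository for use with the `cqrs-es` framework,
/// persisting and loading serialized events and aggregate snapshots.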
pub struct PostgresEventRepository {
    pool: Pool<Postgres>,
    query_factory: SqlQueryFactory,
    stream_channel_size: usize,
}

impl PersistedEventRepository for PostgresEventRepository {
    async fn get_events<A: Aggregate>(
        &self,
        aggregate_id: &str,
    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
        self.select_events::<A>(aggregate_id, self.query_factory.select_events())
            .await
    }

    async fn get_last_events<A: Aggregate>(
        &self,
        aggregate_id: &str,
        last_sequence: usize,
    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
        let query = self.query_factory.get_last_events(last_sequence);
        self.select_events::<A>(aggregate_id, &query).await
    }

    async fn get_snapshot<A: Aggregate>(
        &self,
        aggregate_id: &str,
    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
        let Some(row) = sqlx::query(self.query_factory.select_snapshot())
            .bind(A::TYPE)
            .bind(aggregate_id)
            .fetch_optional(&self.pool)
            .await
            .map_err(PostgresAggregateError::from)?
        else {
            return Ok(None);
        };
        Ok(Some(Self::deser_snapshot(&row)))
    }

    async fn persist<A: Aggregate>(
        &self,
        events: &[SerializedEvent],
        snapshot_update: Option<(String, Value, usize)>,
    ) -> Result<(), PersistenceError> {
        match snapshot_update {
            None => {
                self.insert_events::<A>(events).await?;
            }
            Some((aggregate_id, aggregate, current_snapshot)) => {
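                // A snapshot version of 1 indicates the first snapshot taken
                // for this aggregate instance, so a new snapshot row is
                // inserted; later versions update the existing row under an
                // optimistic lock.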
                if current_snapshot == 1 {
                    self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
                        .await?;
                } else {
                    self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
                        .await?;
                }
            }
        }
        Ok(())
    }

    async fn stream_events<A: Aggregate>(
        &self,
        aggregate_id: &str,
    ) -> Result<ReplayStream, PersistenceError> {
        Ok(stream_events(
            self.query_factory.select_events().to_string(),
            A::TYPE.to_string(),
            aggregate_id.to_string(),
            self.pool.clone(),
            self.stream_channel_size,
        ))
    }

    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
        Ok(stream_events(
            self.query_factory.all_events().to_string(),
            A::TYPE.to_string(),
            String::new(),
            self.pool.clone(),
            self.stream_channel_size,
        ))
    }
}

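/// Spawns a background task that executes `query` against the pool and feeds
/// each row, deserialized into a `SerializedEvent`, through the returned
/// `ReplayStream`.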
fn stream_events(
    query: String,
    aggregate_type: String,
    aggregate_id: String,
    pool: Pool<Postgres>,
    channel_size: usize,
) -> ReplayStream {
    let (mut feed, stream) = ReplayStream::new(channel_size);
    tokio::spawn(async move {
        let query = sqlx::query(&query)
            .bind(&aggregate_type)
            .bind(&aggregate_id);
        let mut rows = query.fetch(&pool);
        loop {
            let row = match rows.try_next().await {
                Ok(Some(row)) => row,
                Ok(None) => break,
                Err(err) => {
                    // Report the database error through the stream rather than
                    // panicking inside the spawned task.
                    let _ = feed
                        .push(Err(PostgresAggregateError::from(err).into()))
                        .await;
                    break;
                }
            };
            let event = PostgresEventRepository::deser_event(&row);
            if feed.push(Ok(event)).await.is_err() {
                // The receiving half of the stream was dropped; stop replaying.
                break;
            }
        }
    });
    stream
}

impl PostgresEventRepository {
    async fn select_events<A: Aggregate>(
        &self,
        aggregate_id: &str,
        query: &str,
    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
        let mut rows = sqlx::query(query)
            .bind(A::TYPE)
            .bind(aggregate_id)
            .fetch(&self.pool);
        let mut result = Vec::default();
        while let Some(row) = rows
            .try_next()
            .await
            .map_err(PostgresAggregateError::from)?
        {
            result.push(Self::deser_event(&row));
        }
        Ok(result)
    }
}

impl PostgresEventRepository {
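    /// Creates a new `PostgresEventRepository` that persists events and
    /// snapshots in the default `events` and `snapshots` tables.
    ///
    /// A minimal usage sketch; the connection string is illustrative only.
    ///
    /// ```ignore
    /// let pool = default_postgress_pool("postgresql://user:pass@localhost:5432/db").await;
    /// let repo = PostgresEventRepository::new(pool);
    /// ```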
    pub fn new(pool: Pool<Postgres>) -> Self {
        Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
    }

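    /// Configures the size of the channel buffer used while streaming events,
    /// replacing the default of 200.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let repo = PostgresEventRepository::new(pool).with_streaming_channel_size(500);
    /// ```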
    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
        Self {
            pool: self.pool,
            query_factory: self.query_factory,
            stream_channel_size,
        }
    }

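    /// Configures the repository to use custom table names for events and
    /// snapshots in place of the defaults.
    ///
    /// A usage sketch with hypothetical table names:
    ///
    /// ```ignore
    /// let repo = PostgresEventRepository::new(pool).with_tables("my_events", "my_snapshots");
    /// ```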
    pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
        Self::use_tables(self.pool, events_table, snapshots_table)
    }

    fn use_tables(pool: Pool<Postgres>, events_table: &str, snapshots_table: &str) -> Self {
        Self {
            pool,
            query_factory: SqlQueryFactory::new(events_table, snapshots_table),
            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
        }
    }

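    /// Inserts the serialized events within a single transaction.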
    pub(crate) async fn insert_events<A: Aggregate>(
        &self,
        events: &[SerializedEvent],
    ) -> Result<(), PostgresAggregateError> {
        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
        self.persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
            .await?;
        tx.commit().await?;
        Ok(())
    }

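    /// Inserts the serialized events and the first snapshot for an aggregate
    /// instance within a single transaction.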
    pub(crate) async fn insert<A: Aggregate>(
        &self,
        aggregate_payload: Value,
        aggregate_id: String,
        current_snapshot: usize,
        events: &[SerializedEvent],
    ) -> Result<(), PostgresAggregateError> {
        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
        let current_sequence = self
            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
            .await?;
        sqlx::query(self.query_factory.insert_snapshot())
            .bind(A::TYPE)
            .bind(aggregate_id.as_str())
            .bind(current_sequence as i32)
            .bind(current_snapshot as i32)
            .bind(&aggregate_payload)
            .execute(&mut *tx)
            .await?;
        tx.commit().await?;
        Ok(())
    }

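    /// Inserts the serialized events and updates the existing snapshot for an
    /// aggregate instance within a single transaction, guarded by an
    /// optimistic lock on the previous snapshot version.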
    pub(crate) async fn update<A: Aggregate>(
        &self,
        aggregate_payload: Value,
        aggregate_id: String,
        current_snapshot: usize,
        events: &[SerializedEvent],
    ) -> Result<(), PostgresAggregateError> {
        let mut tx: Transaction<'_, Postgres> = sqlx::Acquire::begin(&self.pool).await?;
        let current_sequence = self
            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
            .await?;

        let result = sqlx::query(self.query_factory.update_snapshot())
            .bind(A::TYPE)
            .bind(aggregate_id.as_str())
            .bind(current_sequence as i32)
            .bind(current_snapshot as i32)
            .bind((current_snapshot - 1) as i32)
            .bind(&aggregate_payload)
            .execute(&mut *tx)
            .await?;
        // Check the optimistic lock before committing so that a stale snapshot
        // version rolls back the entire transaction, events included.
        if result.rows_affected() != 1 {
            return Err(PostgresAggregateError::OptimisticLock);
        }
        tx.commit().await?;
        Ok(())
    }

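    /// Deserializes a single database row into a `SerializedEvent`.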
    fn deser_event(row: &PgRow) -> SerializedEvent {
        let aggregate_type: String = row.get("aggregate_type");
        let aggregate_id: String = row.get("aggregate_id");
        let sequence = {
            let s: i64 = row.get("sequence");
            s as usize
        };
        let event_type: String = row.get("event_type");
        let event_version: String = row.get("event_version");
        let payload: Value = row.get("payload");
        let metadata: Value = row.get("metadata");
        SerializedEvent::new(
            aggregate_id,
            sequence,
            aggregate_type,
            event_type,
            event_version,
            payload,
            metadata,
        )
    }

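    /// Deserializes a single database row into a `SerializedSnapshot`.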
    fn deser_snapshot(row: &PgRow) -> SerializedSnapshot {
        let aggregate_id = row.get("aggregate_id");
        let s: i64 = row.get("last_sequence");
        let current_sequence = s as usize;
        let s: i64 = row.get("current_snapshot");
        let current_snapshot = s as usize;
        let aggregate: Value = row.get("payload");
        SerializedSnapshot {
            aggregate_id,
            aggregate,
            current_sequence,
            current_snapshot,
        }
    }

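    /// Inserts the serialized events using the provided query and open
    /// transaction, returning the sequence number of the last event inserted.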
    pub(crate) async fn persist_events<A: Aggregate>(
        &self,
        insert_event_query: &str,
        tx: &mut Transaction<'_, Postgres>,
        events: &[SerializedEvent],
    ) -> Result<usize, PostgresAggregateError> {
        let mut current_sequence: usize = 0;
        for event in events {
            current_sequence = event.sequence;
            sqlx::query(insert_event_query)
                .bind(A::TYPE)
                .bind(event.aggregate_id.as_str())
                .bind(event.sequence as i32)
                .bind(&event.event_type)
                .bind(&event.event_version)
                .bind(&event.payload)
                .bind(&event.metadata)
                .execute(&mut **tx)
                .await?;
        }
        Ok(current_sequence)
    }
}

#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::PostgresAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
        Tested, TEST_CONNECTION_STRING,
    };
    use crate::{default_postgress_pool, PostgresEventRepository};

    #[tokio::test]
    async fn event_repositories() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: PostgresEventRepository =
            PostgresEventRepository::new(pool.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

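        // An out-of-order sequence number should trigger an optimistic lock
        // error, and neither event should be persisted.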
        let result = event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        assert!(
            matches!(result, PostgresAggregateError::OptimisticLock),
            "invalid error result found during insert: {result}"
        );

        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    async fn verify_replay_stream(id: &str, event_repo: PostgresEventRepository) {
        let mut stream = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        let mut found_in_stream = 0;
        while (stream.next::<TestAggregate>(&[]).await).is_some() {
            found_in_stream += 1;
        }
        assert_eq!(found_in_stream, 2);

        let mut stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while (stream.next::<TestAggregate>(&[]).await).is_some() {
            found_in_stream += 1;
        }
        assert!(found_in_stream >= 2);
    }

    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: PostgresEventRepository = PostgresEventRepository::new(pool.clone());
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        event_repo
            .insert::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                1,
                &[],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

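        // Updating with the next snapshot version should replace the existing
        // snapshot payload.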
        event_repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

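        // Repeating the same snapshot version should fail the optimistic lock
        // and leave the stored snapshot unchanged.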
        let result = event_repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(result, PostgresAggregateError::OptimisticLock),
            "invalid error result found during update: {result}"
        );

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );
    }
}