1use std::collections::HashMap;
2
3use aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder;
4use aws_sdk_dynamodb::operation::query::QueryOutput;
5use aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder;
6use aws_sdk_dynamodb::primitives::Blob;
7use aws_sdk_dynamodb::types::{AttributeValue, Put, TransactWriteItem};
8use aws_sdk_dynamodb::Client;
9use cqrs_es::persist::{
10 PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
11};
12use cqrs_es::Aggregate;
13use serde_json::Value;
14
15use crate::error::DynamoAggregateError;
16use crate::helpers::{att_as_number, att_as_string, att_as_value, commit_transactions};
17
/// Name of the table that stores serialized events unless `with_tables` overrides it.
const DEFAULT_EVENT_TABLE: &str = "Events";
/// Name of the table that stores aggregate snapshots unless `with_tables` overrides it.
const DEFAULT_SNAPSHOT_TABLE: &str = "Snapshots";

/// Replay-stream channel size (also the streaming page limit) used unless
/// `with_streaming_channel_size` overrides it.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
22
23pub struct DynamoEventRepository {
25 client: Client,
26 event_table: String,
27 snapshot_table: String,
28 stream_channel_size: usize,
29}
30
31impl DynamoEventRepository {
32 pub fn new(client: Client) -> Self {
44 Self::use_table_names(client, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
45 }
46 pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
59 Self {
60 client: self.client,
61 event_table: self.event_table,
62 snapshot_table: self.snapshot_table,
63 stream_channel_size,
64 }
65 }
66 pub fn with_tables(self, event_table: &str, snapshot_table: &str) -> Self {
80 Self::use_table_names(self.client, event_table, snapshot_table)
81 }
82
83 fn use_table_names(client: Client, event_table: &str, snapshot_table: &str) -> Self {
84 Self {
85 client,
86 event_table: event_table.to_string(),
87 snapshot_table: snapshot_table.to_string(),
88 stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
89 }
90 }
91
92 pub(crate) async fn insert_events(
93 &self,
94 events: &[SerializedEvent],
95 ) -> Result<(), DynamoAggregateError> {
96 if events.is_empty() {
97 return Ok(());
98 }
99 let (transactions, _) = Self::build_event_put_transactions(&self.event_table, events);
100 commit_transactions(&self.client, transactions).await?;
101 Ok(())
102 }
103
104 fn build_event_put_transactions(
105 table_name: &str,
106 events: &[SerializedEvent],
107 ) -> (Vec<TransactWriteItem>, usize) {
108 let mut current_sequence: usize = 0;
109 let mut transactions: Vec<TransactWriteItem> = Vec::default();
110 for event in events {
111 current_sequence = event.sequence;
112 let aggregate_type_and_id =
113 AttributeValue::S(format!("{}:{}", &event.aggregate_type, &event.aggregate_id));
114 let aggregate_type = AttributeValue::S(String::from(&event.aggregate_type));
115 let aggregate_id = AttributeValue::S(String::from(&event.aggregate_id));
116 let sequence = AttributeValue::N(String::from(&event.sequence.to_string()));
117 let event_version = AttributeValue::S(String::from(&event.event_version));
118 let event_type = AttributeValue::S(String::from(&event.event_type));
119 let payload_blob = serde_json::to_vec(&event.payload).unwrap();
120 let payload = AttributeValue::B(Blob::new(payload_blob));
121 let metadata_blob = serde_json::to_vec(&event.metadata).unwrap();
122 let metadata = AttributeValue::B(Blob::new(metadata_blob));
123
124 let put = Put::builder()
125 .table_name(table_name)
126 .item("AggregateTypeAndId", aggregate_type_and_id)
127 .item("AggregateIdSequence", sequence)
128 .item("AggregateType", aggregate_type)
129 .item("AggregateId", aggregate_id)
130 .item("EventVersion", event_version)
131 .item("EventType", event_type)
132 .item("Payload", payload)
133 .item("Metadata", metadata)
134 .condition_expression("attribute_not_exists( AggregateIdSequence )")
135 .build()
136 .unwrap();
137 let write_item = TransactWriteItem::builder().put(put).build();
138 transactions.push(write_item);
139 }
140 (transactions, current_sequence)
141 }
142
143 async fn query_events(
144 &self,
145 aggregate_type: &str,
146 aggregate_id: &str,
147 ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
148 let query_output = self
149 .query_table(aggregate_type, aggregate_id, &self.event_table)
150 .await?;
151 let mut result = Vec::default();
152 for entry in query_output.items.into_iter().flatten() {
153 result.push(serialized_event(entry)?);
154 }
155 Ok(result)
156 }
157 async fn query_events_from(
158 &self,
159 aggregate_type: &str,
160 aggregate_id: &str,
161 last_sequence: usize,
162 ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
163 let query_output = self
164 .client
165 .query()
166 .table_name(&self.event_table)
167 .key_condition_expression("#agg_type_id = :agg_type_id AND #sequence > :sequence")
168 .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
169 .expression_attribute_names("#sequence", "AggregateIdSequence")
170 .expression_attribute_values(
171 ":agg_type_id",
172 AttributeValue::S(format!("{aggregate_type}:{aggregate_id}")),
173 )
174 .expression_attribute_values(":sequence", AttributeValue::N(last_sequence.to_string()))
175 .send()
176 .await?;
177 let mut result = Vec::default();
178 for entry in query_output.items.into_iter().flatten() {
179 result.push(serialized_event(entry)?);
180 }
181 Ok(result)
182 }
183
184 pub(crate) async fn update_snapshot<A: Aggregate>(
185 &self,
186 aggregate_payload: Value,
187 aggregate_id: String,
188 current_snapshot: usize,
189 events: &[SerializedEvent],
190 ) -> Result<(), DynamoAggregateError> {
191 let expected_snapshot = current_snapshot - 1;
192 let (mut transactions, current_sequence) =
193 Self::build_event_put_transactions(&self.event_table, events);
194 let aggregate_type_and_id = AttributeValue::S(format!("{}:{}", A::TYPE, &aggregate_id));
195 let aggregate_type = AttributeValue::S(A::TYPE.to_string());
196 let aggregate_id = AttributeValue::S(aggregate_id);
197 let current_sequence = AttributeValue::N(current_sequence.to_string());
198 let current_snapshot = AttributeValue::N(current_snapshot.to_string());
199 let payload_blob = serde_json::to_vec(&aggregate_payload).unwrap();
200 let payload = AttributeValue::B(Blob::new(payload_blob));
201 let expected_snapshot = AttributeValue::N(expected_snapshot.to_string());
202 transactions.push(TransactWriteItem::builder()
203 .put(Put::builder()
204 .table_name(&self.snapshot_table)
205 .item("AggregateTypeAndId", aggregate_type_and_id)
206 .item("AggregateType", aggregate_type)
207 .item("AggregateId", aggregate_id)
208 .item("CurrentSequence", current_sequence)
209 .item("CurrentSnapshot", current_snapshot)
210 .item("Payload", payload)
211 .condition_expression("attribute_not_exists(CurrentSnapshot) OR (CurrentSnapshot = :current_snapshot)")
212 .expression_attribute_values(":current_snapshot", expected_snapshot)
213 .build()?)
214 .build());
215 commit_transactions(&self.client, transactions).await?;
216 Ok(())
217 }
218
219 async fn query_table(
220 &self,
221 aggregate_type: &str,
222 aggregate_id: &str,
223 table: &str,
224 ) -> Result<QueryOutput, DynamoAggregateError> {
225 let output = self
226 .create_query(table, aggregate_type, aggregate_id)
227 .send()
228 .await?;
229 Ok(output)
230 }
231
232 fn create_query(
233 &self,
234 table: &str,
235 aggregate_type: &str,
236 aggregate_id: &str,
237 ) -> QueryFluentBuilder {
238 self.client
239 .query()
240 .table_name(table)
241 .consistent_read(true)
242 .key_condition_expression("#agg_type_id = :agg_type_id")
243 .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
244 .expression_attribute_values(
245 ":agg_type_id",
246 AttributeValue::S(format!("{aggregate_type}:{aggregate_id}")),
247 )
248 }
249}
250
251fn serialized_event(
252 entry: HashMap<String, AttributeValue>,
253) -> Result<SerializedEvent, DynamoAggregateError> {
254 let aggregate_id = att_as_string(&entry, "AggregateId")?;
255 let sequence = att_as_number(&entry, "AggregateIdSequence")?;
256 let aggregate_type = att_as_string(&entry, "AggregateType")?;
257 let event_type = att_as_string(&entry, "EventType")?;
258 let event_version = att_as_string(&entry, "EventVersion")?;
259 let payload = att_as_value(&entry, "Payload")?;
260 let metadata = att_as_value(&entry, "Metadata")?;
261 Ok(SerializedEvent {
262 aggregate_id,
263 sequence,
264 aggregate_type,
265 event_type,
266 event_version,
267 payload,
268 metadata,
269 })
270}
271
272impl PersistedEventRepository for DynamoEventRepository {
273 async fn get_events<A: Aggregate>(
274 &self,
275 aggregate_id: &str,
276 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
277 Ok(self.query_events(A::TYPE, aggregate_id).await?)
278 }
279
280 async fn get_last_events<A: Aggregate>(
281 &self,
282 aggregate_id: &str,
283 number_events: usize,
284 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
285 Ok(self
286 .query_events_from(A::TYPE, aggregate_id, number_events)
287 .await?)
288 }
289
290 async fn get_snapshot<A: Aggregate>(
291 &self,
292 aggregate_id: &str,
293 ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
294 let query_output = self
295 .query_table(A::TYPE, aggregate_id, &self.snapshot_table)
296 .await?;
297 let Some(query_items_vec) = query_output.items else {
298 return Ok(None);
299 };
300 if query_items_vec.is_empty() {
301 return Ok(None);
302 }
303 let query_item = query_items_vec.first().unwrap();
304 let aggregate = att_as_value(query_item, "Payload")?;
305 let current_sequence = att_as_number(query_item, "CurrentSequence")?;
306 let current_snapshot = att_as_number(query_item, "CurrentSnapshot")?;
307
308 Ok(Some(SerializedSnapshot {
309 aggregate_id: aggregate_id.to_string(),
310 aggregate,
311 current_sequence,
312 current_snapshot,
313 }))
314 }
315
316 async fn persist<A: Aggregate>(
317 &self,
318 events: &[SerializedEvent],
319 snapshot_update: Option<(String, Value, usize)>,
320 ) -> Result<(), PersistenceError> {
321 match snapshot_update {
322 None => {
323 self.insert_events(events).await?;
324 }
325 Some((aggregate_id, aggregate, current_snapshot)) => {
326 self.update_snapshot::<A>(aggregate, aggregate_id, current_snapshot, events)
327 .await?;
328 }
329 }
330 Ok(())
331 }
332
333 async fn stream_events<A: Aggregate>(
334 &self,
335 aggregate_id: &str,
336 ) -> Result<ReplayStream, PersistenceError> {
337 let query = self
338 .create_query(&self.event_table, A::TYPE, aggregate_id)
339 .limit(self.stream_channel_size as i32);
340 Ok(stream_events(query, self.stream_channel_size))
341 }
342
343 async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
344 let scan = self
345 .client
346 .scan()
347 .table_name(&self.event_table)
348 .limit(self.stream_channel_size as i32);
349 Ok(stream_all_events(scan, self.stream_channel_size))
350 }
351}
352
353fn stream_events(base_query: QueryFluentBuilder, channel_size: usize) -> ReplayStream {
355 let (mut feed, stream) = ReplayStream::new(channel_size);
356 tokio::spawn(async move {
357 let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
358 loop {
359 let query = match &last_evaluated_key {
360 None => base_query.clone(),
361 Some(last) => last.iter().fold(base_query.clone(), |query, (key, value)| {
362 query.exclusive_start_key(key.to_string(), value.to_owned())
363 }),
364 };
365 match query.send().await {
366 Ok(query_output) => {
367 last_evaluated_key = query_output.last_evaluated_key;
368 if let Some(entries) = query_output.items {
369 for entry in entries {
370 let Ok(event) = serialized_event(entry) else {
371 return;
372 };
373 if feed.push(Ok(event)).await.is_err() {
374 return;
376 }
377 }
378 }
379 }
380 Err(err) => {
381 let err: DynamoAggregateError = err.into();
382 if feed.push(Err(err.into())).await.is_err() {}
383 }
384 }
385 if last_evaluated_key.is_none() {
386 return;
387 }
388 }
389 });
390 stream
391}
392fn stream_all_events(base_query: ScanFluentBuilder, channel_size: usize) -> ReplayStream {
393 let (mut feed, stream) = ReplayStream::new(channel_size);
394 tokio::spawn(async move {
395 let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
396 loop {
397 let query = match &last_evaluated_key {
398 None => base_query.clone(),
399 Some(last) => last.iter().fold(base_query.clone(), |query, (key, value)| {
400 query.exclusive_start_key(key.to_string(), value.to_owned())
401 }),
402 };
403 match query.send().await {
404 Ok(query_output) => {
405 last_evaluated_key = query_output.last_evaluated_key;
406 if let Some(entries) = query_output.items {
407 for entry in entries {
408 let Ok(event) = serialized_event(entry) else {
409 return;
410 };
411 if feed.push(Ok(event)).await.is_err() {
412 return;
414 }
415 }
416 }
417 }
418 Err(err) => {
419 let err: DynamoAggregateError = err.into();
420 if feed.push(Err(err.into())).await.is_err() {}
421 }
422 }
423 if last_evaluated_key.is_none() {
424 return;
425 }
426 }
427 });
428 stream
429}
430
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::DynamoAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_dynamodb_client, test_event_envelope, Created, SomethingElse,
        TestAggregate, TestEvent, Tested,
    };
    use crate::DynamoEventRepository;

    /// Serializes a `TestAggregate` assembled from the given parts.
    fn aggregate_value(id: &str, description: &str, tests: &[String]) -> serde_json::Value {
        serde_json::to_value(TestAggregate {
            id: id.to_string(),
            description: description.to_string(),
            tests: tests.to_vec(),
        })
        .unwrap()
    }

    /// Panics unless `err` is an optimistic-lock failure.
    fn expect_optimistic_lock(err: DynamoAggregateError) {
        if !matches!(err, DynamoAggregateError::OptimisticLock) {
            panic!("invalid error result found during insert: {err}");
        }
    }

    #[tokio::test]
    async fn event_repositories() {
        let client = test_dynamodb_client().await;
        let id = uuid::Uuid::new_v4().to_string();
        let repo = DynamoEventRepository::new(client.clone()).with_streaming_channel_size(1);

        let initial = repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(initial.is_empty());

        let first_batch = [
            test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
            test_event_envelope(
                &id,
                2,
                TestEvent::Tested(Tested {
                    test_name: "a test was run".to_string(),
                }),
            ),
        ];
        repo.insert_events(&first_batch).await.unwrap();

        let stored = repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, stored.len());
        for event in &stored {
            assert_eq!(&id, &event.aggregate_id);
        }

        // Sequence 2 is already taken, so the whole batch must be rejected.
        let conflicting_batch = [
            test_event_envelope(
                &id,
                3,
                TestEvent::SomethingElse(SomethingElse {
                    description: "this should not persist".to_string(),
                }),
            ),
            test_event_envelope(
                &id,
                2,
                TestEvent::SomethingElse(SomethingElse {
                    description: "bad sequence number".to_string(),
                }),
            ),
        ];
        let err = repo.insert_events(&conflicting_batch).await.unwrap_err();
        expect_optimistic_lock(err);

        let stored = repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, stored.len());

        let after_first = repo.get_last_events::<TestAggregate>(&id, 1).await.unwrap();
        assert_eq!(1, after_first.len());

        verify_replay_stream(&id, repo).await;
    }

    async fn verify_replay_stream(id: &str, event_repo: DynamoEventRepository) {
        let mut single = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        let mut single_count = 0;
        loop {
            if single.next::<TestAggregate>(&[]).await.is_none() {
                break;
            }
            single_count += 1;
        }
        assert_eq!(single_count, 2);

        let mut everything = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut everything_count = 0;
        loop {
            if everything.next::<TestAggregate>(&[]).await.is_none() {
                break;
            }
            everything_count += 1;
        }
        assert!(everything_count >= 2);
    }

    #[tokio::test]
    async fn snapshot_repositories() {
        let client = test_dynamodb_client().await;
        let id = uuid::Uuid::new_v4().to_string();
        let repo = DynamoEventRepository::new(client.clone());

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        let saved_description = "a test description that should be saved";

        // First snapshot: nothing stored yet, so the write must succeed.
        repo.update_snapshot::<TestAggregate>(
            aggregate_value(&id, &test_description, &test_tests),
            id.clone(),
            1,
            &[],
        )
        .await
        .unwrap();
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        let expected = snapshot_context(
            id.clone(),
            0,
            1,
            aggregate_value(&id, &test_description, &test_tests),
        );
        assert_eq!(Some(expected), snapshot);

        // Second snapshot follows version 1, so it replaces the stored one.
        repo.update_snapshot::<TestAggregate>(
            aggregate_value(&id, saved_description, &test_tests),
            id.clone(),
            2,
            &[],
        )
        .await
        .unwrap();
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        let expected = snapshot_context(
            id.clone(),
            0,
            2,
            aggregate_value(&id, saved_description, &test_tests),
        );
        assert_eq!(Some(expected), snapshot);

        // Re-writing version 2 conflicts with the stored snapshot and must be rejected.
        let err = repo
            .update_snapshot::<TestAggregate>(
                aggregate_value(
                    &id,
                    "a test description that should not be saved",
                    &test_tests,
                ),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        expect_optimistic_lock(err);

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        let expected = snapshot_context(
            id.clone(),
            0,
            2,
            aggregate_value(&id, saved_description, &test_tests),
        );
        assert_eq!(Some(expected), snapshot);
    }
}