1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder;
5use aws_sdk_dynamodb::operation::query::QueryOutput;
6use aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder;
7use aws_sdk_dynamodb::primitives::Blob;
8use aws_sdk_dynamodb::types::{AttributeValue, Put, TransactWriteItem};
9use aws_sdk_dynamodb::Client;
10use cqrs_es::persist::{
11 PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
12};
13use cqrs_es::Aggregate;
14use serde_json::Value;
15
16use crate::error::DynamoAggregateError;
17use crate::helpers::{att_as_number, att_as_string, att_as_value, commit_transactions};
18
19const DEFAULT_EVENT_TABLE: &str = "Events";
20const DEFAULT_SNAPSHOT_TABLE: &str = "Snapshots";
21
22const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
23
24pub struct DynamoEventRepository {
26 client: Client,
27 event_table: String,
28 snapshot_table: String,
29 stream_channel_size: usize,
30}
31
32impl DynamoEventRepository {
33 pub fn new(client: Client) -> Self {
45 Self::use_table_names(client, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
46 }
47 pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
60 Self {
61 client: self.client,
62 event_table: self.event_table,
63 snapshot_table: self.snapshot_table,
64 stream_channel_size,
65 }
66 }
67 pub fn with_tables(self, event_table: &str, snapshot_table: &str) -> Self {
81 Self::use_table_names(self.client, event_table, snapshot_table)
82 }
83
84 fn use_table_names(client: Client, event_table: &str, snapshot_table: &str) -> Self {
85 Self {
86 client,
87 event_table: event_table.to_string(),
88 snapshot_table: snapshot_table.to_string(),
89 stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
90 }
91 }
92
93 pub(crate) async fn insert_events(
94 &self,
95 events: &[SerializedEvent],
96 ) -> Result<(), DynamoAggregateError> {
97 if events.is_empty() {
98 return Ok(());
99 }
100 let (transactions, _) = Self::build_event_put_transactions(&self.event_table, events);
101 commit_transactions(&self.client, transactions).await?;
102 Ok(())
103 }
104
105 fn build_event_put_transactions(
106 table_name: &str,
107 events: &[SerializedEvent],
108 ) -> (Vec<TransactWriteItem>, usize) {
109 let mut current_sequence: usize = 0;
110 let mut transactions: Vec<TransactWriteItem> = Vec::default();
111 for event in events {
112 current_sequence = event.sequence;
113 let aggregate_type_and_id =
114 AttributeValue::S(format!("{}:{}", &event.aggregate_type, &event.aggregate_id));
115 let aggregate_type = AttributeValue::S(String::from(&event.aggregate_type));
116 let aggregate_id = AttributeValue::S(String::from(&event.aggregate_id));
117 let sequence = AttributeValue::N(String::from(&event.sequence.to_string()));
118 let event_version = AttributeValue::S(String::from(&event.event_version));
119 let event_type = AttributeValue::S(String::from(&event.event_type));
120 let payload_blob = serde_json::to_vec(&event.payload).unwrap();
121 let payload = AttributeValue::B(Blob::new(payload_blob));
122 let metadata_blob = serde_json::to_vec(&event.metadata).unwrap();
123 let metadata = AttributeValue::B(Blob::new(metadata_blob));
124
125 let put = Put::builder()
126 .table_name(table_name)
127 .item("AggregateTypeAndId", aggregate_type_and_id)
128 .item("AggregateIdSequence", sequence)
129 .item("AggregateType", aggregate_type)
130 .item("AggregateId", aggregate_id)
131 .item("EventVersion", event_version)
132 .item("EventType", event_type)
133 .item("Payload", payload)
134 .item("Metadata", metadata)
135 .condition_expression("attribute_not_exists( AggregateIdSequence )")
136 .build()
137 .unwrap();
138 let write_item = TransactWriteItem::builder().put(put).build();
139 transactions.push(write_item);
140 }
141 (transactions, current_sequence)
142 }
143
144 async fn query_events(
145 &self,
146 aggregate_type: &str,
147 aggregate_id: &str,
148 ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
149 let query_output = self
150 .query_table(aggregate_type, aggregate_id, &self.event_table)
151 .await?;
152 let mut result: Vec<SerializedEvent> = Default::default();
153 if let Some(entries) = query_output.items {
154 for entry in entries {
155 result.push(serialized_event(entry)?);
156 }
157 }
158 Ok(result)
159 }
160 async fn query_events_from(
161 &self,
162 aggregate_type: &str,
163 aggregate_id: &str,
164 last_sequence: usize,
165 ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
166 let query_output = self
167 .client
168 .query()
169 .table_name(&self.event_table)
170 .key_condition_expression("#agg_type_id = :agg_type_id AND #sequence > :sequence")
171 .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
172 .expression_attribute_names("#sequence", "AggregateIdSequence")
173 .expression_attribute_values(
174 ":agg_type_id",
175 AttributeValue::S(format!("{}:{}", aggregate_type, aggregate_id)),
176 )
177 .expression_attribute_values(":sequence", AttributeValue::N(last_sequence.to_string()))
178 .send()
179 .await?;
180 let mut result: Vec<SerializedEvent> = Default::default();
181 if let Some(entries) = query_output.items {
182 for entry in entries {
183 result.push(serialized_event(entry)?);
184 }
185 }
186 Ok(result)
187 }
188
189 pub(crate) async fn update_snapshot<A: Aggregate>(
190 &self,
191 aggregate_payload: Value,
192 aggregate_id: String,
193 current_snapshot: usize,
194 events: &[SerializedEvent],
195 ) -> Result<(), DynamoAggregateError> {
196 let expected_snapshot = current_snapshot - 1;
197 let (mut transactions, current_sequence) =
198 Self::build_event_put_transactions(&self.event_table, events);
199 let aggregate_type_and_id =
200 AttributeValue::S(format!("{}:{}", A::aggregate_type(), &aggregate_id));
201 let aggregate_type = AttributeValue::S(A::aggregate_type());
202 let aggregate_id = AttributeValue::S(aggregate_id);
203 let current_sequence = AttributeValue::N(current_sequence.to_string());
204 let current_snapshot = AttributeValue::N(current_snapshot.to_string());
205 let payload_blob = serde_json::to_vec(&aggregate_payload).unwrap();
206 let payload = AttributeValue::B(Blob::new(payload_blob));
207 let expected_snapshot = AttributeValue::N(expected_snapshot.to_string());
208 transactions.push(TransactWriteItem::builder()
209 .put(Put::builder()
210 .table_name(&self.snapshot_table)
211 .item("AggregateTypeAndId", aggregate_type_and_id)
212 .item("AggregateType", aggregate_type)
213 .item("AggregateId", aggregate_id)
214 .item("CurrentSequence", current_sequence)
215 .item("CurrentSnapshot", current_snapshot)
216 .item("Payload", payload)
217 .condition_expression("attribute_not_exists(CurrentSnapshot) OR (CurrentSnapshot = :current_snapshot)")
218 .expression_attribute_values(":current_snapshot", expected_snapshot)
219 .build()?)
220 .build());
221 commit_transactions(&self.client, transactions).await?;
222 Ok(())
223 }
224
225 async fn query_table(
226 &self,
227 aggregate_type: &str,
228 aggregate_id: &str,
229 table: &str,
230 ) -> Result<QueryOutput, DynamoAggregateError> {
231 let query = self.create_query(table, aggregate_type, aggregate_id).await;
232 Ok(query.send().await?)
233 }
234
235 async fn create_query(
236 &self,
237 table: &str,
238 aggregate_type: &str,
239 aggregate_id: &str,
240 ) -> QueryFluentBuilder {
241 self.client
242 .query()
243 .table_name(table)
244 .consistent_read(true)
245 .key_condition_expression("#agg_type_id = :agg_type_id")
246 .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
247 .expression_attribute_values(
248 ":agg_type_id",
249 AttributeValue::S(format!("{}:{}", aggregate_type, aggregate_id)),
250 )
251 }
252}
253
254fn serialized_event(
255 entry: HashMap<String, AttributeValue>,
256) -> Result<SerializedEvent, DynamoAggregateError> {
257 let aggregate_id = att_as_string(&entry, "AggregateId")?;
258 let sequence = att_as_number(&entry, "AggregateIdSequence")?;
259 let aggregate_type = att_as_string(&entry, "AggregateType")?;
260 let event_type = att_as_string(&entry, "EventType")?;
261 let event_version = att_as_string(&entry, "EventVersion")?;
262 let payload = att_as_value(&entry, "Payload")?;
263 let metadata = att_as_value(&entry, "Metadata")?;
264 Ok(SerializedEvent {
265 aggregate_id,
266 sequence,
267 aggregate_type,
268 event_type,
269 event_version,
270 payload,
271 metadata,
272 })
273}
274
275#[async_trait]
276impl PersistedEventRepository for DynamoEventRepository {
277 async fn get_events<A: Aggregate>(
278 &self,
279 aggregate_id: &str,
280 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
281 let request = self
282 .query_events(&A::aggregate_type(), aggregate_id)
283 .await?;
284 Ok(request)
285 }
286
287 async fn get_last_events<A: Aggregate>(
288 &self,
289 aggregate_id: &str,
290 number_events: usize,
291 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
292 Ok(self
293 .query_events_from(&A::aggregate_type(), aggregate_id, number_events)
294 .await?)
295 }
296
297 async fn get_snapshot<A: Aggregate>(
298 &self,
299 aggregate_id: &str,
300 ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
301 let query_output = self
302 .query_table(&A::aggregate_type(), aggregate_id, &self.snapshot_table)
303 .await?;
304 let query_items_vec = match query_output.items {
305 None => return Ok(None),
306 Some(items) => items,
307 };
308 if query_items_vec.is_empty() {
309 return Ok(None);
310 }
311 let query_item = query_items_vec.first().unwrap();
312 let aggregate = att_as_value(query_item, "Payload")?;
313 let current_sequence = att_as_number(query_item, "CurrentSequence")?;
314 let current_snapshot = att_as_number(query_item, "CurrentSnapshot")?;
315
316 Ok(Some(SerializedSnapshot {
317 aggregate_id: aggregate_id.to_string(),
318 aggregate,
319 current_sequence,
320 current_snapshot,
321 }))
322 }
323
324 async fn persist<A: Aggregate>(
325 &self,
326 events: &[SerializedEvent],
327 snapshot_update: Option<(String, Value, usize)>,
328 ) -> Result<(), PersistenceError> {
329 match snapshot_update {
330 None => {
331 self.insert_events(events).await?;
332 }
333 Some((aggregate_id, aggregate, current_snapshot)) => {
334 self.update_snapshot::<A>(aggregate, aggregate_id, current_snapshot, events)
335 .await?;
336 }
337 }
338 Ok(())
339 }
340
341 async fn stream_events<A: Aggregate>(
342 &self,
343 aggregate_id: &str,
344 ) -> Result<ReplayStream, PersistenceError> {
345 let query = self
346 .create_query(&self.event_table, &A::aggregate_type(), aggregate_id)
347 .await
348 .limit(self.stream_channel_size as i32);
349 Ok(stream_events(query, self.stream_channel_size))
350 }
351
352 async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
353 let scan = self
354 .client
355 .scan()
356 .table_name(&self.event_table)
357 .limit(self.stream_channel_size as i32);
358 Ok(stream_all_events(scan, self.stream_channel_size))
359 }
360}
361
362fn stream_events(base_query: QueryFluentBuilder, channel_size: usize) -> ReplayStream {
364 let (mut feed, stream) = ReplayStream::new(channel_size);
365 tokio::spawn(async move {
366 let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
367 loop {
368 let query = match &last_evaluated_key {
369 None => base_query.clone(),
370 Some(last) => {
371 let mut query = base_query.clone();
372 for (key, value) in last {
373 query = query.exclusive_start_key(key.to_string(), value.to_owned());
374 }
375 query
376 }
377 };
378 match query.send().await {
379 Ok(query_output) => {
380 last_evaluated_key = query_output.last_evaluated_key;
381 if let Some(entries) = query_output.items {
382 for entry in entries {
383 let event = match serialized_event(entry) {
384 Ok(event) => event,
385 Err(_) => return,
386 };
387 if feed.push(Ok(event)).await.is_err() {
388 return;
390 };
391 }
392 };
393 }
394 Err(err) => {
395 let err: DynamoAggregateError = err.into();
396 if feed.push(Err(err.into())).await.is_err() {};
397 }
398 }
399 if last_evaluated_key.is_none() {
400 return;
401 }
402 }
403 });
404 stream
405}
406fn stream_all_events(base_query: ScanFluentBuilder, channel_size: usize) -> ReplayStream {
407 let (mut feed, stream) = ReplayStream::new(channel_size);
408 tokio::spawn(async move {
409 let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
410 loop {
411 let query = match &last_evaluated_key {
412 None => base_query.clone(),
413 Some(last) => {
414 let mut query = base_query.clone();
415 for (key, value) in last {
416 query = query.exclusive_start_key(key.to_string(), value.to_owned());
417 }
418 query
419 }
420 };
421 match query.send().await {
422 Ok(query_output) => {
423 last_evaluated_key = query_output.last_evaluated_key;
424 if let Some(entries) = query_output.items {
425 for entry in entries {
426 let event = match serialized_event(entry) {
427 Ok(event) => event,
428 Err(_) => return,
429 };
430 if feed.push(Ok(event)).await.is_err() {
431 return;
433 };
434 }
435 };
436 }
437 Err(err) => {
438 let err: DynamoAggregateError = err.into();
439 if feed.push(Err(err.into())).await.is_err() {};
440 }
441 }
442 if last_evaluated_key.is_none() {
443 return;
444 }
445 }
446 });
447 stream
448}
449
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::DynamoAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_dynamodb_client, test_event_envelope, Created, SomethingElse,
        TestAggregate, TestEvent, Tested,
    };
    use crate::DynamoEventRepository;

    /// Inserts events, checks that a batch reusing a sequence number is rejected as an
    /// optimistic-lock failure, then reads the events back through queries and replay streams.
    #[tokio::test]
    async fn event_repositories() {
        let client = test_dynamodb_client().await;
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo = DynamoEventRepository::new(client.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        event_repo
            .insert_events(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

        // Sequence 2 already exists, so the whole batch (including sequence 3) must fail.
        let result = event_repo
            .insert_events(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        assert!(
            matches!(result, DynamoAggregateError::OptimisticLock),
            "invalid error result found during insert: {}",
            result
        );

        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        let events = event_repo
            .get_last_events::<TestAggregate>(&id, 1)
            .await
            .unwrap();
        assert_eq!(1, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    /// Drains the per-aggregate and whole-table replay streams and checks the item counts.
    async fn verify_replay_stream(id: &str, event_repo: DynamoEventRepository) {
        let mut stream = event_repo
            .stream_events::<TestAggregate>(id)
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while stream.next::<TestAggregate>(&None).await.is_some() {
            found_in_stream += 1;
        }
        assert_eq!(found_in_stream, 2);

        let mut stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while stream.next::<TestAggregate>(&None).await.is_some() {
            found_in_stream += 1;
        }
        // The events table is shared with other tests, so only a lower bound is known.
        assert!(found_in_stream >= 2);
    }

    /// Writes successive snapshots and checks that a stale snapshot number is rejected.
    #[tokio::test]
    async fn snapshot_repositories() {
        let client = test_dynamodb_client().await;
        let id = uuid::Uuid::new_v4().to_string();
        let repo = DynamoEventRepository::new(client.clone());
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        repo.update_snapshot::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: test_description.clone(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            1,
            &[],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );

        repo.update_snapshot::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: "a test description that should be saved".to_string(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            2,
            &[],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );

        // Snapshot 2 is already stored, so writing snapshot 2 again must fail.
        let result = repo
            .update_snapshot::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(result, DynamoAggregateError::OptimisticLock),
            "invalid error result found during insert: {}",
            result
        );

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );
    }
}