// dynamo_es/event_repository.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder;
5use aws_sdk_dynamodb::operation::query::QueryOutput;
6use aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder;
7use aws_sdk_dynamodb::primitives::Blob;
8use aws_sdk_dynamodb::types::{AttributeValue, Put, TransactWriteItem};
9use aws_sdk_dynamodb::Client;
10use cqrs_es::persist::{
11    PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
12};
13use cqrs_es::Aggregate;
14use serde_json::Value;
15
16use crate::error::DynamoAggregateError;
17use crate::helpers::{att_as_number, att_as_string, att_as_value, commit_transactions};
18
/// Default DynamoDB table name used for persisted events.
const DEFAULT_EVENT_TABLE: &str = "Events";
/// Default DynamoDB table name used for aggregate snapshots.
const DEFAULT_SNAPSHOT_TABLE: &str = "Snapshots";

/// Default buffer size (in events) for the replay-stream channel and the
/// DynamoDB page size used while streaming.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
23
/// An event repository relying on DynamoDb for persistence.
pub struct DynamoEventRepository {
    // Shared DynamoDB client used for all reads and writes.
    client: Client,
    // Name of the table holding serialized events.
    event_table: String,
    // Name of the table holding aggregate snapshots.
    snapshot_table: String,
    // Buffer size of the channel backing `stream_events`/`stream_all_events`,
    // also used as the DynamoDB query/scan page size while streaming.
    stream_channel_size: usize,
}
31
32impl DynamoEventRepository {
33    /// Creates a new `DynamoEventRepository` from the provided dynamo client using default
34    /// table names.
35    ///
36    /// ```
37    /// use aws_sdk_dynamodb::Client;
38    /// use dynamo_es::DynamoEventRepository;
39    ///
40    /// fn configure_repo(client: Client) -> DynamoEventRepository {
41    ///     DynamoEventRepository::new(client)
42    /// }
43    /// ```
44    pub fn new(client: Client) -> Self {
45        Self::use_table_names(client, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
46    }
47    /// Configures a `DynamoEventRepository` to use a streaming queue of the provided size.
48    ///
49    /// _Example: configure the repository to stream with a 1000 event buffer._
50    /// ```
51    /// use aws_sdk_dynamodb::Client;
52    /// use dynamo_es::DynamoEventRepository;
53    ///
54    /// fn configure_repo(client: Client) -> DynamoEventRepository {
55    ///     let store = DynamoEventRepository::new(client);
56    ///     store.with_streaming_channel_size(1000)
57    /// }
58    /// ```
59    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
60        Self {
61            client: self.client,
62            event_table: self.event_table,
63            snapshot_table: self.snapshot_table,
64            stream_channel_size,
65        }
66    }
67    /// Configures a `DynamoEventRepository` to use the provided table names.
68    ///
69    /// _Example: configure the repository to use "my_event_table" and "my_snapshot_table"
70    /// for the event and snapshot table names._
71    /// ```
72    /// use aws_sdk_dynamodb::Client;
73    /// use dynamo_es::DynamoEventRepository;
74    ///
75    /// fn configure_repo(client: Client) -> DynamoEventRepository {
76    ///     let store = DynamoEventRepository::new(client);
77    ///     store.with_tables("my_event_table", "my_snapshot_table")
78    /// }
79    /// ```
80    pub fn with_tables(self, event_table: &str, snapshot_table: &str) -> Self {
81        Self::use_table_names(self.client, event_table, snapshot_table)
82    }
83
84    fn use_table_names(client: Client, event_table: &str, snapshot_table: &str) -> Self {
85        Self {
86            client,
87            event_table: event_table.to_string(),
88            snapshot_table: snapshot_table.to_string(),
89            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
90        }
91    }
92
93    pub(crate) async fn insert_events(
94        &self,
95        events: &[SerializedEvent],
96    ) -> Result<(), DynamoAggregateError> {
97        if events.is_empty() {
98            return Ok(());
99        }
100        let (transactions, _) = Self::build_event_put_transactions(&self.event_table, events);
101        commit_transactions(&self.client, transactions).await?;
102        Ok(())
103    }
104
105    fn build_event_put_transactions(
106        table_name: &str,
107        events: &[SerializedEvent],
108    ) -> (Vec<TransactWriteItem>, usize) {
109        let mut current_sequence: usize = 0;
110        let mut transactions: Vec<TransactWriteItem> = Vec::default();
111        for event in events {
112            current_sequence = event.sequence;
113            let aggregate_type_and_id =
114                AttributeValue::S(format!("{}:{}", &event.aggregate_type, &event.aggregate_id));
115            let aggregate_type = AttributeValue::S(String::from(&event.aggregate_type));
116            let aggregate_id = AttributeValue::S(String::from(&event.aggregate_id));
117            let sequence = AttributeValue::N(String::from(&event.sequence.to_string()));
118            let event_version = AttributeValue::S(String::from(&event.event_version));
119            let event_type = AttributeValue::S(String::from(&event.event_type));
120            let payload_blob = serde_json::to_vec(&event.payload).unwrap();
121            let payload = AttributeValue::B(Blob::new(payload_blob));
122            let metadata_blob = serde_json::to_vec(&event.metadata).unwrap();
123            let metadata = AttributeValue::B(Blob::new(metadata_blob));
124
125            let put = Put::builder()
126                .table_name(table_name)
127                .item("AggregateTypeAndId", aggregate_type_and_id)
128                .item("AggregateIdSequence", sequence)
129                .item("AggregateType", aggregate_type)
130                .item("AggregateId", aggregate_id)
131                .item("EventVersion", event_version)
132                .item("EventType", event_type)
133                .item("Payload", payload)
134                .item("Metadata", metadata)
135                .condition_expression("attribute_not_exists( AggregateIdSequence )")
136                .build()
137                .unwrap();
138            let write_item = TransactWriteItem::builder().put(put).build();
139            transactions.push(write_item);
140        }
141        (transactions, current_sequence)
142    }
143
144    async fn query_events(
145        &self,
146        aggregate_type: &str,
147        aggregate_id: &str,
148    ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
149        let query_output = self
150            .query_table(aggregate_type, aggregate_id, &self.event_table)
151            .await?;
152        let mut result: Vec<SerializedEvent> = Default::default();
153        if let Some(entries) = query_output.items {
154            for entry in entries {
155                result.push(serialized_event(entry)?);
156            }
157        }
158        Ok(result)
159    }
160    async fn query_events_from(
161        &self,
162        aggregate_type: &str,
163        aggregate_id: &str,
164        last_sequence: usize,
165    ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
166        let query_output = self
167            .client
168            .query()
169            .table_name(&self.event_table)
170            .key_condition_expression("#agg_type_id = :agg_type_id AND #sequence > :sequence")
171            .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
172            .expression_attribute_names("#sequence", "AggregateIdSequence")
173            .expression_attribute_values(
174                ":agg_type_id",
175                AttributeValue::S(format!("{}:{}", aggregate_type, aggregate_id)),
176            )
177            .expression_attribute_values(":sequence", AttributeValue::N(last_sequence.to_string()))
178            .send()
179            .await?;
180        let mut result: Vec<SerializedEvent> = Default::default();
181        if let Some(entries) = query_output.items {
182            for entry in entries {
183                result.push(serialized_event(entry)?);
184            }
185        }
186        Ok(result)
187    }
188
189    pub(crate) async fn update_snapshot<A: Aggregate>(
190        &self,
191        aggregate_payload: Value,
192        aggregate_id: String,
193        current_snapshot: usize,
194        events: &[SerializedEvent],
195    ) -> Result<(), DynamoAggregateError> {
196        let expected_snapshot = current_snapshot - 1;
197        let (mut transactions, current_sequence) =
198            Self::build_event_put_transactions(&self.event_table, events);
199        let aggregate_type_and_id =
200            AttributeValue::S(format!("{}:{}", A::aggregate_type(), &aggregate_id));
201        let aggregate_type = AttributeValue::S(A::aggregate_type());
202        let aggregate_id = AttributeValue::S(aggregate_id);
203        let current_sequence = AttributeValue::N(current_sequence.to_string());
204        let current_snapshot = AttributeValue::N(current_snapshot.to_string());
205        let payload_blob = serde_json::to_vec(&aggregate_payload).unwrap();
206        let payload = AttributeValue::B(Blob::new(payload_blob));
207        let expected_snapshot = AttributeValue::N(expected_snapshot.to_string());
208        transactions.push(TransactWriteItem::builder()
209            .put(Put::builder()
210                .table_name(&self.snapshot_table)
211                .item("AggregateTypeAndId", aggregate_type_and_id)
212                .item("AggregateType", aggregate_type)
213                .item("AggregateId", aggregate_id)
214                .item("CurrentSequence", current_sequence)
215                .item("CurrentSnapshot", current_snapshot)
216                .item("Payload", payload)
217                .condition_expression("attribute_not_exists(CurrentSnapshot) OR (CurrentSnapshot  = :current_snapshot)")
218                .expression_attribute_values(":current_snapshot", expected_snapshot)
219                .build()?)
220            .build());
221        commit_transactions(&self.client, transactions).await?;
222        Ok(())
223    }
224
225    async fn query_table(
226        &self,
227        aggregate_type: &str,
228        aggregate_id: &str,
229        table: &str,
230    ) -> Result<QueryOutput, DynamoAggregateError> {
231        let query = self.create_query(table, aggregate_type, aggregate_id).await;
232        Ok(query.send().await?)
233    }
234
235    async fn create_query(
236        &self,
237        table: &str,
238        aggregate_type: &str,
239        aggregate_id: &str,
240    ) -> QueryFluentBuilder {
241        self.client
242            .query()
243            .table_name(table)
244            .consistent_read(true)
245            .key_condition_expression("#agg_type_id = :agg_type_id")
246            .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
247            .expression_attribute_values(
248                ":agg_type_id",
249                AttributeValue::S(format!("{}:{}", aggregate_type, aggregate_id)),
250            )
251    }
252}
253
254fn serialized_event(
255    entry: HashMap<String, AttributeValue>,
256) -> Result<SerializedEvent, DynamoAggregateError> {
257    let aggregate_id = att_as_string(&entry, "AggregateId")?;
258    let sequence = att_as_number(&entry, "AggregateIdSequence")?;
259    let aggregate_type = att_as_string(&entry, "AggregateType")?;
260    let event_type = att_as_string(&entry, "EventType")?;
261    let event_version = att_as_string(&entry, "EventVersion")?;
262    let payload = att_as_value(&entry, "Payload")?;
263    let metadata = att_as_value(&entry, "Metadata")?;
264    Ok(SerializedEvent {
265        aggregate_id,
266        sequence,
267        aggregate_type,
268        event_type,
269        event_version,
270        payload,
271        metadata,
272    })
273}
274
275#[async_trait]
276impl PersistedEventRepository for DynamoEventRepository {
277    async fn get_events<A: Aggregate>(
278        &self,
279        aggregate_id: &str,
280    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
281        let request = self
282            .query_events(&A::aggregate_type(), aggregate_id)
283            .await?;
284        Ok(request)
285    }
286
287    async fn get_last_events<A: Aggregate>(
288        &self,
289        aggregate_id: &str,
290        number_events: usize,
291    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
292        Ok(self
293            .query_events_from(&A::aggregate_type(), aggregate_id, number_events)
294            .await?)
295    }
296
297    async fn get_snapshot<A: Aggregate>(
298        &self,
299        aggregate_id: &str,
300    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
301        let query_output = self
302            .query_table(&A::aggregate_type(), aggregate_id, &self.snapshot_table)
303            .await?;
304        let query_items_vec = match query_output.items {
305            None => return Ok(None),
306            Some(items) => items,
307        };
308        if query_items_vec.is_empty() {
309            return Ok(None);
310        }
311        let query_item = query_items_vec.first().unwrap();
312        let aggregate = att_as_value(query_item, "Payload")?;
313        let current_sequence = att_as_number(query_item, "CurrentSequence")?;
314        let current_snapshot = att_as_number(query_item, "CurrentSnapshot")?;
315
316        Ok(Some(SerializedSnapshot {
317            aggregate_id: aggregate_id.to_string(),
318            aggregate,
319            current_sequence,
320            current_snapshot,
321        }))
322    }
323
324    async fn persist<A: Aggregate>(
325        &self,
326        events: &[SerializedEvent],
327        snapshot_update: Option<(String, Value, usize)>,
328    ) -> Result<(), PersistenceError> {
329        match snapshot_update {
330            None => {
331                self.insert_events(events).await?;
332            }
333            Some((aggregate_id, aggregate, current_snapshot)) => {
334                self.update_snapshot::<A>(aggregate, aggregate_id, current_snapshot, events)
335                    .await?;
336            }
337        }
338        Ok(())
339    }
340
341    async fn stream_events<A: Aggregate>(
342        &self,
343        aggregate_id: &str,
344    ) -> Result<ReplayStream, PersistenceError> {
345        let query = self
346            .create_query(&self.event_table, &A::aggregate_type(), aggregate_id)
347            .await
348            .limit(self.stream_channel_size as i32);
349        Ok(stream_events(query, self.stream_channel_size))
350    }
351
352    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
353        let scan = self
354            .client
355            .scan()
356            .table_name(&self.event_table)
357            .limit(self.stream_channel_size as i32);
358        Ok(stream_all_events(scan, self.stream_channel_size))
359    }
360}
361
362// TODO: combine these two methods
363fn stream_events(base_query: QueryFluentBuilder, channel_size: usize) -> ReplayStream {
364    let (mut feed, stream) = ReplayStream::new(channel_size);
365    tokio::spawn(async move {
366        let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
367        loop {
368            let query = match &last_evaluated_key {
369                None => base_query.clone(),
370                Some(last) => {
371                    let mut query = base_query.clone();
372                    for (key, value) in last {
373                        query = query.exclusive_start_key(key.to_string(), value.to_owned());
374                    }
375                    query
376                }
377            };
378            match query.send().await {
379                Ok(query_output) => {
380                    last_evaluated_key = query_output.last_evaluated_key;
381                    if let Some(entries) = query_output.items {
382                        for entry in entries {
383                            let event = match serialized_event(entry) {
384                                Ok(event) => event,
385                                Err(_) => return,
386                            };
387                            if feed.push(Ok(event)).await.is_err() {
388                                //         TODO: in the unlikely event of a broken channel this error should be reported.
389                                return;
390                            };
391                        }
392                    };
393                }
394                Err(err) => {
395                    let err: DynamoAggregateError = err.into();
396                    if feed.push(Err(err.into())).await.is_err() {};
397                }
398            }
399            if last_evaluated_key.is_none() {
400                return;
401            }
402        }
403    });
404    stream
405}
406fn stream_all_events(base_query: ScanFluentBuilder, channel_size: usize) -> ReplayStream {
407    let (mut feed, stream) = ReplayStream::new(channel_size);
408    tokio::spawn(async move {
409        let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
410        loop {
411            let query = match &last_evaluated_key {
412                None => base_query.clone(),
413                Some(last) => {
414                    let mut query = base_query.clone();
415                    for (key, value) in last {
416                        query = query.exclusive_start_key(key.to_string(), value.to_owned());
417                    }
418                    query
419                }
420            };
421            match query.send().await {
422                Ok(query_output) => {
423                    last_evaluated_key = query_output.last_evaluated_key;
424                    if let Some(entries) = query_output.items {
425                        for entry in entries {
426                            let event = match serialized_event(entry) {
427                                Ok(event) => event,
428                                Err(_) => return,
429                            };
430                            if feed.push(Ok(event)).await.is_err() {
431                                //         TODO: in the unlikely event of a broken channel this error should be reported.
432                                return;
433                            };
434                        }
435                    };
436                }
437                Err(err) => {
438                    let err: DynamoAggregateError = err.into();
439                    if feed.push(Err(err.into())).await.is_err() {};
440                }
441            }
442            if last_evaluated_key.is_none() {
443                return;
444            }
445        }
446    });
447    stream
448}
449
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::DynamoAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_dynamodb_client, test_event_envelope, Created, SomethingElse,
        TestAggregate, TestEvent, Tested,
    };
    use crate::DynamoEventRepository;

    // Integration test: exercises insert/read/optimistic-lock behavior against
    // a live DynamoDB endpoint supplied by `test_dynamodb_client` — TODO confirm
    // the endpoint setup (presumably DynamoDB Local) in crate::testing.
    #[tokio::test]
    async fn event_repositories() {
        let client = test_dynamodb_client().await;
        // Fresh aggregate id per run so tests do not collide on shared tables.
        let id = uuid::Uuid::new_v4().to_string();
        // Channel size of 1 forces the replay stream to exercise backpressure.
        let event_repo = DynamoEventRepository::new(client.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        event_repo
            .insert_events(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

        // Optimistic lock error
        // Sequence 2 already exists, so the whole transaction must fail and
        // neither event (including the valid sequence 3) may be persisted.
        let result = event_repo
            .insert_events(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        match result {
            DynamoAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {}", result),
        };

        // Still only the original two events — the failed transaction wrote nothing.
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        // `get_last_events` returns events with sequence > 1, i.e. just event 2.
        let events = event_repo
            .get_last_events::<TestAggregate>(&id, 1)
            .await
            .unwrap();
        assert_eq!(1, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    // Drains both replay-stream variants and checks the expected event counts.
    async fn verify_replay_stream(id: &str, event_repo: DynamoEventRepository) {
        let mut stream = event_repo
            .stream_events::<TestAggregate>(&id)
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while let Some(_) = stream.next::<TestAggregate>(&None).await {
            found_in_stream += 1;
        }
        assert_eq!(found_in_stream, 2);

        let mut stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while let Some(_) = stream.next::<TestAggregate>(&None).await {
            found_in_stream += 1;
        }
        // The scan covers the whole table, which may hold events from other
        // runs/aggregates, so only a lower bound can be asserted.
        assert!(found_in_stream >= 2);
    }

    // Integration test: verifies snapshot insert, update, and the
    // optimistic-lock guard on the snapshot version.
    #[tokio::test]
    async fn snapshot_repositories() {
        let client = test_dynamodb_client().await;
        let id = uuid::Uuid::new_v4().to_string();
        let repo = DynamoEventRepository::new(client.clone());
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        // First snapshot (version 1) with no accompanying events.
        repo.update_snapshot::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: test_description.clone(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            1,
            &vec![],
        )
        .await
        .unwrap();

        // Sequence is 0 because no events were written alongside the snapshot.
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );

        // sequence iterated, does update
        repo.update_snapshot::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: "a test description that should be saved".to_string(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            2,
            &vec![],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );

        // sequence out of order or not iterated, does not update
        // (stored version is already 2, so expecting previous version 1 fails
        // the condition expression).
        let result = repo
            .update_snapshot::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &vec![],
            )
            .await
            .unwrap_err();
        match result {
            DynamoAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {}", result),
        };

        // The stored snapshot is unchanged after the failed update.
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );
    }
}