Skip to main content

dynamo_es/
event_repository.rs

1use std::collections::HashMap;
2
3use aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder;
4use aws_sdk_dynamodb::operation::query::QueryOutput;
5use aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder;
6use aws_sdk_dynamodb::primitives::Blob;
7use aws_sdk_dynamodb::types::{AttributeValue, Put, TransactWriteItem};
8use aws_sdk_dynamodb::Client;
9use cqrs_es::persist::{
10    PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
11};
12use cqrs_es::Aggregate;
13use serde_json::Value;
14
15use crate::error::DynamoAggregateError;
16use crate::helpers::{att_as_number, att_as_string, att_as_value, commit_transactions};
17
/// Default DynamoDB table name used for persisted events.
const DEFAULT_EVENT_TABLE: &str = "Events";
/// Default DynamoDB table name used for aggregate snapshots.
const DEFAULT_SNAPSHOT_TABLE: &str = "Snapshots";

/// Default buffer size of the channel backing event replay streams.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
22
/// An event repository relying on DynamoDb for persistence.
pub struct DynamoEventRepository {
    // Shared AWS DynamoDB client used for all reads and writes.
    client: Client,
    // Name of the table holding serialized events.
    event_table: String,
    // Name of the table holding aggregate snapshots.
    snapshot_table: String,
    // Buffer size of the channel used by `stream_events`/`stream_all_events`,
    // also used as the per-page query limit when streaming.
    stream_channel_size: usize,
}
30
31impl DynamoEventRepository {
32    /// Creates a new `DynamoEventRepository` from the provided dynamo client using default
33    /// table names.
34    ///
35    /// ```
36    /// use aws_sdk_dynamodb::Client;
37    /// use dynamo_es::DynamoEventRepository;
38    ///
39    /// fn configure_repo(client: Client) -> DynamoEventRepository {
40    ///     DynamoEventRepository::new(client)
41    /// }
42    /// ```
43    pub fn new(client: Client) -> Self {
44        Self::use_table_names(client, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
45    }
46    /// Configures a `DynamoEventRepository` to use a streaming queue of the provided size.
47    ///
48    /// _Example: configure the repository to stream with a 1000 event buffer._
49    /// ```
50    /// use aws_sdk_dynamodb::Client;
51    /// use dynamo_es::DynamoEventRepository;
52    ///
53    /// fn configure_repo(client: Client) -> DynamoEventRepository {
54    ///     let store = DynamoEventRepository::new(client);
55    ///     store.with_streaming_channel_size(1000)
56    /// }
57    /// ```
58    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
59        Self {
60            client: self.client,
61            event_table: self.event_table,
62            snapshot_table: self.snapshot_table,
63            stream_channel_size,
64        }
65    }
66    /// Configures a `DynamoEventRepository` to use the provided table names.
67    ///
68    /// _Example: configure the repository to use "my_event_table" and "my_snapshot_table"
69    /// for the event and snapshot table names._
70    /// ```
71    /// use aws_sdk_dynamodb::Client;
72    /// use dynamo_es::DynamoEventRepository;
73    ///
74    /// fn configure_repo(client: Client) -> DynamoEventRepository {
75    ///     let store = DynamoEventRepository::new(client);
76    ///     store.with_tables("my_event_table", "my_snapshot_table")
77    /// }
78    /// ```
79    pub fn with_tables(self, event_table: &str, snapshot_table: &str) -> Self {
80        Self::use_table_names(self.client, event_table, snapshot_table)
81    }
82
83    fn use_table_names(client: Client, event_table: &str, snapshot_table: &str) -> Self {
84        Self {
85            client,
86            event_table: event_table.to_string(),
87            snapshot_table: snapshot_table.to_string(),
88            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
89        }
90    }
91
92    pub(crate) async fn insert_events(
93        &self,
94        events: &[SerializedEvent],
95    ) -> Result<(), DynamoAggregateError> {
96        if events.is_empty() {
97            return Ok(());
98        }
99        let (transactions, _) = Self::build_event_put_transactions(&self.event_table, events);
100        commit_transactions(&self.client, transactions).await?;
101        Ok(())
102    }
103
104    fn build_event_put_transactions(
105        table_name: &str,
106        events: &[SerializedEvent],
107    ) -> (Vec<TransactWriteItem>, usize) {
108        let mut current_sequence: usize = 0;
109        let mut transactions: Vec<TransactWriteItem> = Vec::default();
110        for event in events {
111            current_sequence = event.sequence;
112            let aggregate_type_and_id =
113                AttributeValue::S(format!("{}:{}", &event.aggregate_type, &event.aggregate_id));
114            let aggregate_type = AttributeValue::S(String::from(&event.aggregate_type));
115            let aggregate_id = AttributeValue::S(String::from(&event.aggregate_id));
116            let sequence = AttributeValue::N(String::from(&event.sequence.to_string()));
117            let event_version = AttributeValue::S(String::from(&event.event_version));
118            let event_type = AttributeValue::S(String::from(&event.event_type));
119            let payload_blob = serde_json::to_vec(&event.payload).unwrap();
120            let payload = AttributeValue::B(Blob::new(payload_blob));
121            let metadata_blob = serde_json::to_vec(&event.metadata).unwrap();
122            let metadata = AttributeValue::B(Blob::new(metadata_blob));
123
124            let put = Put::builder()
125                .table_name(table_name)
126                .item("AggregateTypeAndId", aggregate_type_and_id)
127                .item("AggregateIdSequence", sequence)
128                .item("AggregateType", aggregate_type)
129                .item("AggregateId", aggregate_id)
130                .item("EventVersion", event_version)
131                .item("EventType", event_type)
132                .item("Payload", payload)
133                .item("Metadata", metadata)
134                .condition_expression("attribute_not_exists( AggregateIdSequence )")
135                .build()
136                .unwrap();
137            let write_item = TransactWriteItem::builder().put(put).build();
138            transactions.push(write_item);
139        }
140        (transactions, current_sequence)
141    }
142
143    async fn query_events(
144        &self,
145        aggregate_type: &str,
146        aggregate_id: &str,
147    ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
148        let query_output = self
149            .query_table(aggregate_type, aggregate_id, &self.event_table)
150            .await?;
151        let mut result = Vec::default();
152        for entry in query_output.items.into_iter().flatten() {
153            result.push(serialized_event(entry)?);
154        }
155        Ok(result)
156    }
157    async fn query_events_from(
158        &self,
159        aggregate_type: &str,
160        aggregate_id: &str,
161        last_sequence: usize,
162    ) -> Result<Vec<SerializedEvent>, DynamoAggregateError> {
163        let query_output = self
164            .client
165            .query()
166            .table_name(&self.event_table)
167            .key_condition_expression("#agg_type_id = :agg_type_id AND #sequence > :sequence")
168            .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
169            .expression_attribute_names("#sequence", "AggregateIdSequence")
170            .expression_attribute_values(
171                ":agg_type_id",
172                AttributeValue::S(format!("{aggregate_type}:{aggregate_id}")),
173            )
174            .expression_attribute_values(":sequence", AttributeValue::N(last_sequence.to_string()))
175            .send()
176            .await?;
177        let mut result = Vec::default();
178        for entry in query_output.items.into_iter().flatten() {
179            result.push(serialized_event(entry)?);
180        }
181        Ok(result)
182    }
183
184    pub(crate) async fn update_snapshot<A: Aggregate>(
185        &self,
186        aggregate_payload: Value,
187        aggregate_id: String,
188        current_snapshot: usize,
189        events: &[SerializedEvent],
190    ) -> Result<(), DynamoAggregateError> {
191        let expected_snapshot = current_snapshot - 1;
192        let (mut transactions, current_sequence) =
193            Self::build_event_put_transactions(&self.event_table, events);
194        let aggregate_type_and_id = AttributeValue::S(format!("{}:{}", A::TYPE, &aggregate_id));
195        let aggregate_type = AttributeValue::S(A::TYPE.to_string());
196        let aggregate_id = AttributeValue::S(aggregate_id);
197        let current_sequence = AttributeValue::N(current_sequence.to_string());
198        let current_snapshot = AttributeValue::N(current_snapshot.to_string());
199        let payload_blob = serde_json::to_vec(&aggregate_payload).unwrap();
200        let payload = AttributeValue::B(Blob::new(payload_blob));
201        let expected_snapshot = AttributeValue::N(expected_snapshot.to_string());
202        transactions.push(TransactWriteItem::builder()
203            .put(Put::builder()
204                .table_name(&self.snapshot_table)
205                .item("AggregateTypeAndId", aggregate_type_and_id)
206                .item("AggregateType", aggregate_type)
207                .item("AggregateId", aggregate_id)
208                .item("CurrentSequence", current_sequence)
209                .item("CurrentSnapshot", current_snapshot)
210                .item("Payload", payload)
211                .condition_expression("attribute_not_exists(CurrentSnapshot) OR (CurrentSnapshot  = :current_snapshot)")
212                .expression_attribute_values(":current_snapshot", expected_snapshot)
213                .build()?)
214            .build());
215        commit_transactions(&self.client, transactions).await?;
216        Ok(())
217    }
218
219    async fn query_table(
220        &self,
221        aggregate_type: &str,
222        aggregate_id: &str,
223        table: &str,
224    ) -> Result<QueryOutput, DynamoAggregateError> {
225        let output = self
226            .create_query(table, aggregate_type, aggregate_id)
227            .send()
228            .await?;
229        Ok(output)
230    }
231
232    fn create_query(
233        &self,
234        table: &str,
235        aggregate_type: &str,
236        aggregate_id: &str,
237    ) -> QueryFluentBuilder {
238        self.client
239            .query()
240            .table_name(table)
241            .consistent_read(true)
242            .key_condition_expression("#agg_type_id = :agg_type_id")
243            .expression_attribute_names("#agg_type_id", "AggregateTypeAndId")
244            .expression_attribute_values(
245                ":agg_type_id",
246                AttributeValue::S(format!("{aggregate_type}:{aggregate_id}")),
247            )
248    }
249}
250
251fn serialized_event(
252    entry: HashMap<String, AttributeValue>,
253) -> Result<SerializedEvent, DynamoAggregateError> {
254    let aggregate_id = att_as_string(&entry, "AggregateId")?;
255    let sequence = att_as_number(&entry, "AggregateIdSequence")?;
256    let aggregate_type = att_as_string(&entry, "AggregateType")?;
257    let event_type = att_as_string(&entry, "EventType")?;
258    let event_version = att_as_string(&entry, "EventVersion")?;
259    let payload = att_as_value(&entry, "Payload")?;
260    let metadata = att_as_value(&entry, "Metadata")?;
261    Ok(SerializedEvent {
262        aggregate_id,
263        sequence,
264        aggregate_type,
265        event_type,
266        event_version,
267        payload,
268        metadata,
269    })
270}
271
272impl PersistedEventRepository for DynamoEventRepository {
273    async fn get_events<A: Aggregate>(
274        &self,
275        aggregate_id: &str,
276    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
277        Ok(self.query_events(A::TYPE, aggregate_id).await?)
278    }
279
280    async fn get_last_events<A: Aggregate>(
281        &self,
282        aggregate_id: &str,
283        number_events: usize,
284    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
285        Ok(self
286            .query_events_from(A::TYPE, aggregate_id, number_events)
287            .await?)
288    }
289
290    async fn get_snapshot<A: Aggregate>(
291        &self,
292        aggregate_id: &str,
293    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
294        let query_output = self
295            .query_table(A::TYPE, aggregate_id, &self.snapshot_table)
296            .await?;
297        let Some(query_items_vec) = query_output.items else {
298            return Ok(None);
299        };
300        if query_items_vec.is_empty() {
301            return Ok(None);
302        }
303        let query_item = query_items_vec.first().unwrap();
304        let aggregate = att_as_value(query_item, "Payload")?;
305        let current_sequence = att_as_number(query_item, "CurrentSequence")?;
306        let current_snapshot = att_as_number(query_item, "CurrentSnapshot")?;
307
308        Ok(Some(SerializedSnapshot {
309            aggregate_id: aggregate_id.to_string(),
310            aggregate,
311            current_sequence,
312            current_snapshot,
313        }))
314    }
315
316    async fn persist<A: Aggregate>(
317        &self,
318        events: &[SerializedEvent],
319        snapshot_update: Option<(String, Value, usize)>,
320    ) -> Result<(), PersistenceError> {
321        match snapshot_update {
322            None => {
323                self.insert_events(events).await?;
324            }
325            Some((aggregate_id, aggregate, current_snapshot)) => {
326                self.update_snapshot::<A>(aggregate, aggregate_id, current_snapshot, events)
327                    .await?;
328            }
329        }
330        Ok(())
331    }
332
333    async fn stream_events<A: Aggregate>(
334        &self,
335        aggregate_id: &str,
336    ) -> Result<ReplayStream, PersistenceError> {
337        let query = self
338            .create_query(&self.event_table, A::TYPE, aggregate_id)
339            .limit(self.stream_channel_size as i32);
340        Ok(stream_events(query, self.stream_channel_size))
341    }
342
343    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
344        let scan = self
345            .client
346            .scan()
347            .table_name(&self.event_table)
348            .limit(self.stream_channel_size as i32);
349        Ok(stream_all_events(scan, self.stream_channel_size))
350    }
351}
352
353// TODO: combine these two methods
354fn stream_events(base_query: QueryFluentBuilder, channel_size: usize) -> ReplayStream {
355    let (mut feed, stream) = ReplayStream::new(channel_size);
356    tokio::spawn(async move {
357        let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
358        loop {
359            let query = match &last_evaluated_key {
360                None => base_query.clone(),
361                Some(last) => last.iter().fold(base_query.clone(), |query, (key, value)| {
362                    query.exclusive_start_key(key.to_string(), value.to_owned())
363                }),
364            };
365            match query.send().await {
366                Ok(query_output) => {
367                    last_evaluated_key = query_output.last_evaluated_key;
368                    if let Some(entries) = query_output.items {
369                        for entry in entries {
370                            let Ok(event) = serialized_event(entry) else {
371                                return;
372                            };
373                            if feed.push(Ok(event)).await.is_err() {
374                                //         TODO: in the unlikely event of a broken channel this error should be reported.
375                                return;
376                            }
377                        }
378                    }
379                }
380                Err(err) => {
381                    let err: DynamoAggregateError = err.into();
382                    if feed.push(Err(err.into())).await.is_err() {}
383                }
384            }
385            if last_evaluated_key.is_none() {
386                return;
387            }
388        }
389    });
390    stream
391}
392fn stream_all_events(base_query: ScanFluentBuilder, channel_size: usize) -> ReplayStream {
393    let (mut feed, stream) = ReplayStream::new(channel_size);
394    tokio::spawn(async move {
395        let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
396        loop {
397            let query = match &last_evaluated_key {
398                None => base_query.clone(),
399                Some(last) => last.iter().fold(base_query.clone(), |query, (key, value)| {
400                    query.exclusive_start_key(key.to_string(), value.to_owned())
401                }),
402            };
403            match query.send().await {
404                Ok(query_output) => {
405                    last_evaluated_key = query_output.last_evaluated_key;
406                    if let Some(entries) = query_output.items {
407                        for entry in entries {
408                            let Ok(event) = serialized_event(entry) else {
409                                return;
410                            };
411                            if feed.push(Ok(event)).await.is_err() {
412                                //         TODO: in the unlikely event of a broken channel this error should be reported.
413                                return;
414                            }
415                        }
416                    }
417                }
418                Err(err) => {
419                    let err: DynamoAggregateError = err.into();
420                    if feed.push(Err(err.into())).await.is_err() {}
421                }
422            }
423            if last_evaluated_key.is_none() {
424                return;
425            }
426        }
427    });
428    stream
429}
430
// Integration tests: these require a reachable DynamoDB instance (see
// `test_dynamodb_client`) and exercise the repository end to end.
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::DynamoAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_dynamodb_client, test_event_envelope, Created, SomethingElse,
        TestAggregate, TestEvent, Tested,
    };
    use crate::DynamoEventRepository;

    // Covers insert, read-back, optimistic locking, partial reads, and the
    // two replay streams for a single aggregate instance.
    #[tokio::test]
    async fn event_repositories() {
        let client = test_dynamodb_client().await;
        // Fresh aggregate id per run so tests don't collide on shared tables.
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo = DynamoEventRepository::new(client.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        event_repo
            .insert_events(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

        // Optimistic lock error
        // Sequence 2 already exists, so the whole transaction must fail —
        // including the otherwise-valid sequence 3 event.
        let result = event_repo
            .insert_events(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        match result {
            DynamoAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {result}"),
        }

        // Verify the failed transaction persisted nothing.
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        // Only the event with sequence > 1 should be returned.
        let events = event_repo
            .get_last_events::<TestAggregate>(&id, 1)
            .await
            .unwrap();
        assert_eq!(1, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    // Drains both replay streams; the channel size of 1 (set above) forces
    // the streaming tasks to page/backpressure.
    async fn verify_replay_stream(id: &str, event_repo: DynamoEventRepository) {
        let mut stream = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        let mut found_in_stream = 0;
        while (stream.next::<TestAggregate>(&[]).await).is_some() {
            found_in_stream += 1;
        }
        assert_eq!(found_in_stream, 2);

        let mut stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while (stream.next::<TestAggregate>(&[]).await).is_some() {
            found_in_stream += 1;
        }
        // The scan sees events from other tests/aggregates too, so only a
        // lower bound can be asserted.
        assert!(found_in_stream >= 2);
    }

    // Covers snapshot creation, sequential update, and the optimistic-lock
    // rejection of a stale snapshot version.
    #[tokio::test]
    async fn snapshot_repositories() {
        let client = test_dynamodb_client().await;
        let id = uuid::Uuid::new_v4().to_string();
        let repo = DynamoEventRepository::new(client.clone());
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        // First snapshot (version 1) — succeeds because none exists yet.
        repo.update_snapshot::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: test_description.clone(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            1,
            &[],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );

        // sequence iterated, does update
        repo.update_snapshot::<TestAggregate>(
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: "a test description that should be saved".to_string(),
                tests: test_tests.clone(),
            })
            .unwrap(),
            id.clone(),
            2,
            &[],
        )
        .await
        .unwrap();

        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );

        // sequence out of order or not iterated, does not update
        let result = repo
            .update_snapshot::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        match result {
            DynamoAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {result}"),
        }

        // The stale write must not have changed the stored snapshot.
        let snapshot = repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
            )),
            snapshot
        );
    }
}
630}