persist-es 0.2.2

Backing logic for an RDBMS implementation of an event store for cqrs-es.
Documentation
use std::collections::HashMap;
use std::marker::PhantomData;

use async_trait::async_trait;
use cqrs_es::{Aggregate, AggregateError, EventEnvelope, EventStore};

use crate::{PersistedEventRepository, SnapshotStoreAggregateContext};

/// Snapshot-sourced `EventStore` backed by a database repository.
///
/// The serialized aggregate (the snapshot) is treated as the primary source of truth for the
/// aggregate's state. Individual events are still persisted, but they are used only to update
/// queries.
///
/// For an event-sourced `EventStore`, see
/// [`PersistedEventStore`](struct.PersistedEventStore.html).
pub struct PersistedSnapshotStore<R: PersistedEventRepository<A>, A: Aggregate + Send + Sync> {
    /// Repository through which events and snapshots are read and written.
    repo: R,
    /// Binds the store to aggregate type `A` without holding a value of it.
    _phantom: PhantomData<A>,
}

impl<R, A> PersistedSnapshotStore<R, A>
where
    R: PersistedEventRepository<A>,
    A: Aggregate + Send + Sync,
{
    /// Creates a new `PersistedSnapshotStore` that reads and writes through the provided
    /// repository. The result is an `EventStore` that can be used to configure a new cqrs
    /// framework.
    ///
    /// ```ignore
    /// # use persist_es::PersistedSnapshotStore;
    /// # use cqrs_es::CqrsFramework;
    /// let store = PersistedSnapshotStore::<MyRepo, MyAggregate>::new(repo);
    /// let cqrs = CqrsFramework::new(store, vec![]);
    /// ```
    pub fn new(snapshot_repo: R) -> Self {
        Self {
            repo: snapshot_repo,
            _phantom: PhantomData,
        }
    }
}

#[async_trait]
impl<R, A> EventStore<A> for PersistedSnapshotStore<R, A>
where
    R: PersistedEventRepository<A>,
    A: Aggregate + Send + Sync,
{
    type AC = SnapshotStoreAggregateContext<A>;

    /// Loads every persisted event for `aggregate_id`.
    ///
    /// The trait signature offers no error channel, so a repository failure is discarded and
    /// an empty list is returned in its place.
    // TODO: improved error handling
    async fn load(&self, aggregate_id: &str) -> Vec<EventEnvelope<A>> {
        self.repo
            .get_events(aggregate_id)
            .await
            .unwrap_or_default()
    }

    /// Loads the stored snapshot for `aggregate_id`.
    ///
    /// When the repository has no snapshot yet, a fresh context is returned that wraps
    /// `A::default()` with sequence 0 and snapshot 0.
    ///
    /// # Panics
    ///
    /// Panics with the repository error's message if the snapshot lookup fails, since the
    /// trait signature cannot report it.
    async fn load_aggregate(&self, aggregate_id: &str) -> SnapshotStoreAggregateContext<A> {
        let stored = match self.repo.get_snapshot(aggregate_id).await {
            Ok(stored) => stored,
            Err(e) => panic!("{}", e),
        };
        stored.unwrap_or_else(|| SnapshotStoreAggregateContext {
            aggregate_id: aggregate_id.to_string(),
            aggregate: A::default(),
            current_sequence: 0,
            current_snapshot: 0,
        })
    }

    /// Applies `events` to the aggregate held in `context`. It then persists the wrapped
    /// events together with the updated aggregate as the next snapshot
    /// (`current_snapshot + 1`).
    ///
    /// Returns the wrapped event envelopes. A repository failure is propagated with `?`.
    async fn commit(
        &self,
        events: Vec<A::Event>,
        context: SnapshotStoreAggregateContext<A>,
        metadata: HashMap<String, String>,
    ) -> Result<Vec<EventEnvelope<A>>, AggregateError> {
        let SnapshotStoreAggregateContext {
            aggregate_id,
            mut aggregate,
            current_sequence,
            current_snapshot,
        } = context;
        // `wrap_events` takes the events by value, so the aggregate is fed clones of them.
        for event in events.iter().cloned() {
            aggregate.apply(event);
        }
        let envelopes = self.wrap_events(&aggregate_id, current_sequence, events, metadata);
        let snapshot_update = (aggregate_id, aggregate, current_snapshot + 1);
        self.repo.persist(&envelopes, Some(snapshot_update)).await?;
        Ok(envelopes)
    }
}

#[cfg(test)]
pub(crate) mod test {
    use std::collections::HashMap;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use cqrs_es::{Aggregate, AggregateError, DomainEvent, EventEnvelope, EventStore};
    use serde::{Deserialize, Serialize};

    use crate::{
        PersistedEventRepository, PersistedSnapshotStore, PersistenceError,
        SnapshotStoreAggregateContext,
    };

    /// Events produced by [`TestAggregate`].
    #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
    pub(crate) enum TestEvents {
        Started,
        SomethingWasDone,
    }

    impl DomainEvent for TestEvents {
        fn event_type(&self) -> &'static str {
            match self {
                TestEvents::Started => "Started",
                TestEvents::SomethingWasDone => "SomethingWasDone",
            }
        }
        fn event_version(&self) -> &'static str {
            EVENT_VERSION
        }
    }

    /// Commands accepted by [`TestAggregate`]; `BadCommand` is always rejected.
    #[derive(Debug, Serialize, Deserialize)]
    pub(crate) enum TestCommands {
        DoSomething,
        BadCommand,
    }

    /// Minimal aggregate that counts how many `SomethingWasDone` events it has applied.
    #[derive(Debug, Default, Serialize, Deserialize, PartialEq)]
    pub(crate) struct TestAggregate {
        pub(crate) something_happened: usize,
    }

    impl Aggregate for TestAggregate {
        type Command = TestCommands;
        type Event = TestEvents;

        fn aggregate_type() -> &'static str {
            "TestAggregate"
        }
        fn handle(&self, command: Self::Command) -> Result<Vec<Self::Event>, AggregateError> {
            match command {
                TestCommands::DoSomething => Ok(vec![TestEvents::SomethingWasDone]),
                TestCommands::BadCommand => Err(AggregateError::new("the expected error message")),
            }
        }
        fn apply(&mut self, event: Self::Event) {
            match event {
                TestEvents::Started => {}
                TestEvents::SomethingWasDone => {
                    self.something_happened += 1;
                }
            }
        }
    }

    /// Assertion closure that [`MockRepo::persist`] runs against the events and snapshot
    /// update it receives.
    pub(crate) type PersistCheck = Box<
        dyn FnOnce(&[EventEnvelope<TestAggregate>], Option<(String, TestAggregate, usize)>) + Send,
    >;

    /// In-memory stand-in for a `PersistedEventRepository`.
    ///
    /// Each canned result (or assertion closure) is taken out by the first call to the
    /// matching method. Calling a method that was not configured, or calling it a second
    /// time, panics on `unwrap`.
    pub(crate) struct MockRepo {
        events_result: Mutex<Option<Result<Vec<EventEnvelope<TestAggregate>>, PersistenceError>>>,
        snapshot_result: Mutex<
            Option<Result<Option<SnapshotStoreAggregateContext<TestAggregate>>, PersistenceError>>,
        >,
        persist_check: Mutex<Option<PersistCheck>>,
    }

    impl MockRepo {
        /// Mock whose `get_events` returns `result`.
        pub(crate) fn with_events(
            result: Result<Vec<EventEnvelope<TestAggregate>>, PersistenceError>,
        ) -> Self {
            Self {
                events_result: Mutex::new(Some(result)),
                snapshot_result: Mutex::new(None),
                persist_check: Mutex::new(None),
            }
        }
        /// Mock whose `get_snapshot` returns `result`.
        pub(crate) fn with_snapshot(
            result: Result<Option<SnapshotStoreAggregateContext<TestAggregate>>, PersistenceError>,
        ) -> Self {
            Self {
                events_result: Mutex::new(None),
                snapshot_result: Mutex::new(Some(result)),
                persist_check: Mutex::new(None),
            }
        }
        /// Mock whose `persist` runs `test_function` on its arguments and then returns `Ok(())`.
        pub(crate) fn with_commit(test_function: PersistCheck) -> Self {
            Self {
                events_result: Mutex::new(None),
                snapshot_result: Mutex::new(None),
                persist_check: Mutex::new(Some(test_function)),
            }
        }
    }

    #[async_trait]
    impl PersistedEventRepository<TestAggregate> for MockRepo {
        async fn get_events(
            &self,
            _aggregate_id: &str,
        ) -> Result<Vec<EventEnvelope<TestAggregate>>, PersistenceError> {
            self.events_result.lock().unwrap().take().unwrap()
        }
        async fn get_last_events(
            &self,
            _aggregate_id: &str,
            _number_events: usize,
        ) -> Result<Vec<EventEnvelope<TestAggregate>>, PersistenceError> {
            // Not called by `PersistedSnapshotStore` in this module.
            todo!()
        }
        async fn get_snapshot(
            &self,
            _aggregate_id: &str,
        ) -> Result<Option<SnapshotStoreAggregateContext<TestAggregate>>, PersistenceError>
        {
            self.snapshot_result.lock().unwrap().take().unwrap()
        }
        async fn persist(
            &self,
            events: &[EventEnvelope<TestAggregate>],
            snapshot_update: Option<(String, TestAggregate, usize)>,
        ) -> Result<(), PersistenceError> {
            let test = self.persist_check.lock().unwrap().take().unwrap();
            test(events, snapshot_update);
            Ok(())
        }
    }

    pub(crate) const TEST_AGGREGATE_ID: &str = "test-aggregate-C";
    pub(crate) const EVENT_VERSION: &str = "1.0";

    #[tokio::test]
    async fn load() {
        let repo = MockRepo::with_events(Ok(vec![test_event_envelope(
            1,
            TestEvents::SomethingWasDone,
        )]));
        let store = PersistedSnapshotStore::new(repo);
        let events = store.load(TEST_AGGREGATE_ID).await;
        let event = events.first().unwrap();
        assert_eq!(1, event.sequence);
        assert_eq!("SomethingWasDone", event.event_type);
        assert_eq!(EVENT_VERSION, event.event_version);
    }

    #[tokio::test]
    async fn load_error() {
        let repo = MockRepo::with_events(Err(PersistenceError::OptimisticLockError));
        let store = PersistedSnapshotStore::new(repo);
        let events = store.load(TEST_AGGREGATE_ID).await;
        assert_eq!(0, events.len())
    }

    #[tokio::test]
    async fn load_aggregate_new() {
        let repo = MockRepo::with_snapshot(Ok(None));
        let store = PersistedSnapshotStore::new(repo);
        let snapshot_context = store.load_aggregate(TEST_AGGREGATE_ID).await;
        assert_eq!(0, snapshot_context.current_snapshot);
        assert_eq!(0, snapshot_context.current_sequence);
        assert_eq!(TEST_AGGREGATE_ID, snapshot_context.aggregate_id);
        assert_eq!(TestAggregate::default(), snapshot_context.aggregate);
    }

    #[tokio::test]
    async fn load_aggregate_existing() {
        let repo =
            MockRepo::with_snapshot(Ok(Some(SnapshotStoreAggregateContext::<TestAggregate> {
                aggregate_id: TEST_AGGREGATE_ID.to_string(),
                aggregate: TestAggregate {
                    something_happened: 3,
                },
                current_sequence: 3,
                current_snapshot: 2,
            })));
        let store = PersistedSnapshotStore::new(repo);
        let snapshot_context = store.load_aggregate(TEST_AGGREGATE_ID).await;
        assert_eq!(2, snapshot_context.current_snapshot);
        assert_eq!(3, snapshot_context.current_sequence);
        assert_eq!(TEST_AGGREGATE_ID, snapshot_context.aggregate_id);
        assert_eq!(
            TestAggregate {
                something_happened: 3
            },
            snapshot_context.aggregate
        );
    }

    // TODO: better error handling needed, this panic could cause problems with non-serverless systems
    #[tokio::test]
    #[should_panic]
    async fn load_aggregate_error() {
        let repo = MockRepo::with_snapshot(Err(PersistenceError::OptimisticLockError));
        let store = PersistedSnapshotStore::new(repo);
        store.load_aggregate(TEST_AGGREGATE_ID).await;
    }

    #[tokio::test]
    async fn commit() {
        let repo = MockRepo::with_commit(Box::new(|events, snapshot_update| {
            assert_eq!(3, events.len());
            let event = events.get(2).unwrap();
            assert_eq!(TEST_AGGREGATE_ID, event.aggregate_id);
            assert_eq!(3, event.sequence);

            let snapshot_update = snapshot_update.unwrap();
            let aggregate_id = snapshot_update.0;
            let aggregate = snapshot_update.1;
            let snapshot_version = snapshot_update.2;
            assert_eq!(TEST_AGGREGATE_ID, aggregate_id.as_str());
            assert_eq!(1, snapshot_version);
            assert_eq!(
                TestAggregate {
                    something_happened: 2
                },
                aggregate
            );
        }));
        let store = PersistedSnapshotStore::new(repo);
        let context = SnapshotStoreAggregateContext {
            aggregate_id: TEST_AGGREGATE_ID.to_string(),
            aggregate: TestAggregate::default(),
            current_sequence: 0,
            current_snapshot: 0,
        };
        let event_envelopes = store
            .commit(
                vec![
                    TestEvents::Started,
                    TestEvents::SomethingWasDone,
                    TestEvents::SomethingWasDone,
                ],
                context,
                HashMap::default(),
            )
            .await
            .unwrap();
        assert_eq!(3, event_envelopes.len());
        let event = event_envelopes.first().unwrap();
        assert_eq!(TEST_AGGREGATE_ID, event.aggregate_id);
        assert_eq!(TestEvents::Started, event.payload);
        assert_eq!(
            TestEvents::SomethingWasDone,
            event_envelopes.get(2).unwrap().payload
        );
    }

    /// Builds an envelope for `TEST_AGGREGATE_ID` at sequence `seq` carrying `payload`.
    pub(crate) fn test_event_envelope(
        seq: usize,
        payload: TestEvents,
    ) -> EventEnvelope<TestAggregate> {
        EventEnvelope::<TestAggregate>::new(
            TEST_AGGREGATE_ID.to_string(),
            seq,
            TestAggregate::aggregate_type().to_string(),
            payload,
        )
    }
}