use futures::TryStreamExt;
use crate::repository::{RepositoryError, RepositoryReader, RepositoryWriter, Snapshot};
use crate::{
aggregate::Aggregate,
event::{DomainEvent, EventStoreEvent},
};
use std::fmt::Debug;
/// An aggregate paired with its current version and the events / side effects that were
/// recorded on it but have not been persisted yet.
#[derive(Debug, Clone)]
#[must_use]
pub struct Context<T: Aggregate> {
    /// Current aggregate state.
    aggregate: T,
    /// Version of the last event applied to `aggregate` (the creating event is version 0).
    version: u64,
    /// Events recorded via `record_new` / `record_that` that `save` has not taken yet.
    uncommitted_events: Vec<EventStoreEvent<T::DomainEvent>>,
    /// Side effects reported by the aggregate for `uncommitted_events`.
    uncommitted_side_effects: Vec<T::SideEffect>,
}
impl<T> Context<T>
where
    T: Aggregate,
{
    /// Returns the identifier of the wrapped aggregate.
    pub fn aggregate_id(&self) -> &T::AggregateId {
        self.aggregate.aggregate_id()
    }

    /// Returns the version of the most recent event applied to the aggregate.
    ///
    /// The creating event has version `0`; every further event increments it by one.
    pub fn version(&self) -> u64 {
        self.version
    }

    /// Returns the aggregate's snapshot schema version, `T::SNAPSHOT_VERSION`.
    pub fn snapshot_version(&self) -> u64 {
        T::SNAPSHOT_VERSION
    }

    /// Removes and returns every event recorded since the queue was last drained.
    #[doc(hidden)]
    pub fn take_uncommitted_events(&mut self) -> Vec<EventStoreEvent<T::DomainEvent>> {
        std::mem::take(&mut self.uncommitted_events)
    }

    /// Removes and returns every side effect collected since the queue was last drained.
    #[doc(hidden)]
    pub fn take_uncommitted_side_effects(&mut self) -> Vec<T::SideEffect> {
        std::mem::take(&mut self.uncommitted_side_effects)
    }

    /// Starts a context from the first persisted event of a stream (replay path).
    ///
    /// The context adopts the event's stored version and has no uncommitted events or
    /// side effects.
    ///
    /// # Errors
    ///
    /// Returns the aggregate's `ApplyError` if `apply_new` rejects the event.
    #[doc(hidden)]
    pub fn rehydrate_from(
        event: &EventStoreEvent<T::DomainEvent>,
    ) -> Result<Context<T>, T::ApplyError> {
        Ok(Context {
            version: event.version,
            aggregate: T::apply_new(&event.event)?,
            uncommitted_events: Vec::default(),
            uncommitted_side_effects: Vec::default(),
        })
    }

    /// Applies the next persisted event to a rehydrated context (replay path).
    ///
    /// Replayed events are not queued as uncommitted and produce no side effects. In debug
    /// builds the event's version must be exactly one past the current version.
    ///
    /// # Errors
    ///
    /// Returns the aggregate's `ApplyError` if `apply` rejects the event.
    #[doc(hidden)]
    pub fn apply_rehydrated_event(
        mut self,
        event: &EventStoreEvent<T::DomainEvent>,
    ) -> Result<Context<T>, T::ApplyError> {
        self.version += 1;
        // Replay expects consecutive versions; this is only checked in debug builds.
        debug_assert_eq!(
            self.version, event.version,
            "rehydrated event version does not follow the context version"
        );
        self.aggregate.apply(&event.event)?;
        Ok(self)
    }

    /// Creates a new context from the aggregate's creating event.
    ///
    /// The event is queued as uncommitted with version `0`. Any side effects the aggregate
    /// reports for it, evaluated after applying it, are queued as well.
    ///
    /// # Errors
    ///
    /// Returns the aggregate's `ApplyError` if `apply_new` rejects the event.
    pub(crate) fn record_new(event: T::DomainEvent) -> Result<Context<T>, T::ApplyError> {
        let aggregate = T::apply_new(&event)?;
        let uncommitted_side_effects = aggregate.side_effects(&event).unwrap_or_default();
        let event_store_event = EventStoreEvent {
            id: event.id().clone(),
            version: 0,
            event,
        };
        Ok(Context {
            version: 0,
            aggregate,
            uncommitted_events: vec![event_store_event],
            uncommitted_side_effects,
        })
    }

    /// Returns a reference to the current aggregate state.
    pub fn state(&self) -> &T {
        &self.aggregate
    }

    /// Applies `event` to the aggregate and queues it for the next `save`.
    ///
    /// On success the version is incremented. The side effects the aggregate reports for
    /// the event are queued, and they are evaluated after the event has been applied.
    ///
    /// # Errors
    ///
    /// Returns the aggregate's `ApplyError` if `apply` rejects the event. In that case the
    /// version is not incremented and nothing is queued. Whether the aggregate state was
    /// partially modified depends on the aggregate's `apply` implementation.
    pub fn record_that(&mut self, event: T::DomainEvent) -> Result<(), T::ApplyError> {
        self.aggregate.apply(&event)?;
        self.version += 1;
        if let Some(side_effects) = self.aggregate.side_effects(&event) {
            self.uncommitted_side_effects.extend(side_effects);
        }
        self.uncommitted_events.push(EventStoreEvent {
            id: event.id().clone(),
            version: self.version,
            event,
        });
        Ok(())
    }

    /// Persists the uncommitted events, a snapshot of the current state, and the collected
    /// side effects through `transaction`.
    ///
    /// Does nothing when there are no uncommitted events. The pending events and side
    /// effects are drained from the context before anything is written. After an error they
    /// are therefore no longer held by this context.
    ///
    /// The repository may report fewer inserted event ids than events submitted. In that
    /// case each missing event is looked up by id:
    ///
    /// - an identical stored event is accepted as an idempotent retry;
    /// - a differing one yields [`SaveError::IdempotencyError`];
    /// - an absent one yields [`SaveError::OptimisticConcurrency`].
    ///
    /// # Errors
    ///
    /// Returns the errors described above. Repository failures while storing or looking up
    /// events, or while storing the snapshot, are returned as [`SaveError::Repository`].
    pub async fn save<R>(&mut self, transaction: &mut R) -> Result<(), SaveError<T, R::DbError>>
    where
        R: RepositoryWriter<T>,
    {
        let events_to_commit = self.take_uncommitted_events();
        if events_to_commit.is_empty() {
            return Ok(());
        }
        let side_effects_to_commit = self.take_uncommitted_side_effects();
        let aggregate_id = self.aggregate_id();

        let inserted_event_ids = transaction
            .store_events(aggregate_id, events_to_commit.clone())
            .await
            .map_err(SaveError::Repository)?;

        // Fewer ids than events means some inserts were rejected. Decide per event whether
        // it is a harmless retry of an identical event or a real conflict.
        if inserted_event_ids.len() != events_to_commit.len() {
            for event in events_to_commit {
                if inserted_event_ids.contains(&event.id) {
                    continue;
                }
                let saved = transaction
                    .get_event(aggregate_id, event.id())
                    .await
                    .map_err(SaveError::Repository)?;
                match saved {
                    Some(saved_event) if saved_event.event != event.event => {
                        return Err(SaveError::IdempotencyError(saved_event.event, event.event));
                    }
                    Some(_) => {}
                    None => {
                        return Err(SaveError::OptimisticConcurrency(
                            aggregate_id.clone(),
                            event.version,
                        ));
                    }
                }
            }
        }

        let snapshot = Snapshot {
            snapshot_version: self.snapshot_version(),
            aggregate: self.state().clone(),
            version: self.version(),
        };
        transaction
            .store_snapshot(snapshot)
            .await
            .map_err(SaveError::Repository)?;
        transaction
            .store_side_effects(side_effects_to_commit)
            .await?;
        Ok(())
    }

    /// Loads an aggregate from its latest usable snapshot plus the events stored after it.
    ///
    /// A snapshot is only used when its `snapshot_version` equals `T::SNAPSHOT_VERSION`.
    /// Otherwise it is ignored and the aggregate is replayed from version `0`.
    ///
    /// # Errors
    ///
    /// - Repository failures are propagated.
    /// - An event the aggregate cannot apply yields `RepositoryError::Apply` with that
    ///   event's id.
    /// - `RepositoryError::AggregateNotFound` is returned when there is neither a usable
    ///   snapshot nor any event.
    pub async fn load<R>(
        reader: &mut R,
        aggregate_id: &T::AggregateId,
    ) -> Result<
        Context<T>,
        RepositoryError<
            T::ApplyError,
            <<T as Aggregate>::DomainEvent as DomainEvent>::EventId,
            R::DbError,
        >,
    >
    where
        R: RepositoryReader<T>,
    {
        // A snapshot written under a different snapshot schema does not describe the current
        // aggregate type. Discard it and rebuild from the events instead.
        let snapshot = reader
            .get_snapshot(aggregate_id)
            .await?
            .filter(|s| s.snapshot_version == T::SNAPSHOT_VERSION);

        // Resume right after the snapshot's version, or from the first event without one.
        let (context, from_version) = snapshot
            .map(|s| {
                let next_version = s.version + 1;
                (
                    Some(Context {
                        aggregate: s.aggregate,
                        version: s.version,
                        uncommitted_events: Vec::new(),
                        uncommitted_side_effects: Vec::new(),
                    }),
                    next_version,
                )
            })
            .unwrap_or((None, 0));

        let ctx = reader
            .stream_from(aggregate_id, from_version)
            .map_err(RepositoryError::Repository)
            .try_fold(context, |ctx: Option<Context<T>>, event| async move {
                match ctx {
                    None => Context::rehydrate_from(&event).map(Some),
                    Some(ctx) => ctx.apply_rehydrated_event(&event).map(Some),
                }
                .map_err(|e| RepositoryError::Apply(event.id, e))
            })
            .await?;
        ctx.ok_or(RepositoryError::AggregateNotFound)
    }

    /// Recomputes the side effects of an already-stored event.
    ///
    /// The aggregate is replayed from version `0` up to and including the target event's
    /// version. `side_effects` is then evaluated for the target event against that state.
    ///
    /// Returns `Ok(None)` when the event does not exist or no events were replayed.
    /// Otherwise it returns whatever the aggregate's `side_effects` returns.
    ///
    /// # Errors
    ///
    /// Repository failures are propagated, and an event the aggregate cannot apply yields
    /// `RepositoryError::Apply` with that event's id.
    pub async fn regenerate_side_effects<R>(
        reader: &mut R,
        aggregate_id: &T::AggregateId,
        event_id: &<<T as Aggregate>::DomainEvent as DomainEvent>::EventId,
    ) -> Result<
        Option<Vec<T::SideEffect>>,
        RepositoryError<
            T::ApplyError,
            <<T as Aggregate>::DomainEvent as DomainEvent>::EventId,
            R::DbError,
        >,
    >
    where
        R: RepositoryReader<T>,
    {
        // `TryStreamExt` is already imported at module level; only `take_while` is needed here.
        use futures::StreamExt;

        let Some(target_event) = reader.get_event(aggregate_id, event_id).await? else {
            return Ok(None);
        };
        let target_version = target_event.version;

        let context = reader
            .stream_from(aggregate_id, 0)
            // Stop after the target event. Errors are let through so `try_fold` surfaces them.
            .take_while(move |result| {
                futures::future::ready(match result {
                    Ok(e) => e.version <= target_version,
                    Err(_) => true,
                })
            })
            .map_err(RepositoryError::Repository)
            .try_fold(None, |ctx: Option<Context<T>>, event| async move {
                match ctx {
                    None => Context::rehydrate_from(&event).map(Some),
                    Some(c) => c.apply_rehydrated_event(&event).map(Some),
                }
                .map_err(|e| RepositoryError::Apply(event.id, e))
            })
            .await?;

        let Some(aggregate) = context else {
            return Ok(None);
        };
        Ok(aggregate.aggregate.side_effects(&target_event.event))
    }
}
#[derive(Debug, thiserror::Error)]
pub enum SaveError<T, DE>
where
T: Aggregate,
{
#[error("Idempotency Error. Saved event {0:?} does not equal {1:?}")]
IdempotencyError(T::DomainEvent, T::DomainEvent),
#[error("Event store failed while streaming events: {0}")]
Repository(#[from] DE),
#[error("Optimistic Concurrency Error")]
OptimisticConcurrency(T::AggregateId, u64),
}
/// Entry point for starting a brand-new aggregate, e.g. `MyAggregate::record_new(event)`.
pub trait Root<T: Aggregate> {
    /// Builds a new [`Context`] whose creating (version `0`) event is `event`. The event
    /// is queued as uncommitted.
    ///
    /// # Errors
    ///
    /// Propagates the `T::ApplyError` produced when the aggregate rejects the event.
    fn record_new(event: T::DomainEvent) -> Result<Context<T>, T::ApplyError> {
        Context::<T>::record_new(event)
    }
}

// Blanket impl: every aggregate can be started through `Root` without opting in.
impl<T: Aggregate> Root<T> for T {}
// Unit tests for `Context`, driven mostly by the `TestCounter` aggregate and the event /
// side-effect builders from `crate::test_fixtures`.
#[cfg(test)]
mod tests {
use super::*;
use crate::test_fixtures::*;
// `record_new` builds the state from the reset event and starts at version 0.
#[test]
fn test_context_record_new() {
let reset_event = create_reset_event("reset-1", "calc-1");
let context = TestCounter::record_new(reset_event.clone()).unwrap();
assert_eq!(context.state().id, "calc-1");
assert_eq!(context.state().result, 0);
assert_eq!(context.version(), 0);
assert_eq!(context.aggregate_id(), "calc-1");
assert_eq!(context.snapshot_version(), TestCounter::SNAPSHOT_VERSION);
}
// Each `record_that` applies the event to the state and bumps the version by one.
#[test]
fn test_context_record_that() {
let reset_event = create_reset_event("reset-1", "calc-1");
let mut context = TestCounter::record_new(reset_event).unwrap();
let add_event = create_add_event("add-1", 15);
context.record_that(add_event).unwrap();
assert_eq!(context.state().result, 15);
assert_eq!(context.version(), 1);
let subtract_event = create_subtract_event("sub-1", 5);
context.record_that(subtract_event).unwrap();
assert_eq!(context.state().result, 10);
assert_eq!(context.version(), 2);
}
// A rejected event (multiply by 0 -> `TestError::DivisionByZero` in the fixture) surfaces
// the error and leaves both state and version unchanged.
#[test]
fn test_context_record_that_error() {
let reset_event = create_reset_event("reset-1", "calc-1");
let mut context = TestCounter::record_new(reset_event).unwrap();
let multiply_zero_event = create_multiply_event("mul-zero", 0);
let result = context.record_that(multiply_zero_event);
assert!(matches!(result, Err(TestError::DivisionByZero)));
assert_eq!(context.state().result, 0);
assert_eq!(context.version(), 0);
}
// The version increases by exactly one per recorded event.
#[test]
fn test_context_version_tracking() {
let reset_event = create_reset_event("reset-1", "calc-1");
let mut context = TestCounter::record_new(reset_event).unwrap();
assert_eq!(context.version(), 0);
for i in 1..=5 {
let add_event = create_add_event(&format!("add-{i}"), i);
context.record_that(add_event).unwrap();
assert_eq!(context.version(), i as u64);
}
}
// Both the creating event and later events are queued; taking them drains the queue.
#[test]
fn test_context_uncommitted_events() {
let reset_event = create_reset_event("reset-1", "calc-1");
let mut context = TestCounter::record_new(reset_event).unwrap();
let add_event = create_add_event("add-1", 10);
context.record_that(add_event).unwrap();
let events = context.take_uncommitted_events();
assert_eq!(events.len(), 2);
let empty_events = context.take_uncommitted_events();
assert_eq!(empty_events.len(), 0);
}
// Side effects accumulate across `record_new` and `record_that` (3 in total for reset + add
// with this fixture) and are drained by `take_uncommitted_side_effects`.
#[test]
fn test_context_side_effects() {
let reset_event = create_reset_event("reset-1", "calc-1");
let mut context = TestCounter::record_new(reset_event).unwrap();
let add_event = create_add_event("add-1", 5);
context.record_that(add_event).unwrap();
let side_effects = context.take_uncommitted_side_effects();
assert_eq!(side_effects.len(), 3);
let empty_side_effects = context.take_uncommitted_side_effects();
assert_eq!(empty_side_effects.len(), 0);
}
// `state()` exposes the aggregate exactly as built by the creating event.
#[test]
fn test_context_state_access() {
let reset_event = create_reset_event("reset-1", "calc-1");
let context = TestCounter::record_new(reset_event).unwrap();
let state = context.state();
assert_eq!(state.id, "calc-1");
assert_eq!(state.result, 0);
assert_eq!(state.operations_count, 0);
}
// `snapshot_version()` mirrors `TestCounter::SNAPSHOT_VERSION` (1 in the fixture).
#[test]
fn test_context_snapshot_version() {
let reset_event = create_reset_event("reset-1", "calc-1");
let context = TestCounter::record_new(reset_event).unwrap();
assert_eq!(context.snapshot_version(), TestCounter::SNAPSHOT_VERSION);
assert_eq!(context.snapshot_version(), 1);
let version = context.snapshot_version();
assert_eq!(version, 1u64);
assert_eq!(version, TestCounter::SNAPSHOT_VERSION);
}
// Rehydrating from a persisted event builds state at the stored version and queues nothing.
#[test]
fn test_context_rehydrate_from() {
use crate::event::EventStoreEvent;
let reset_event = create_reset_event("reset-1", "calc-1");
let store_event = EventStoreEvent::new("reset-1".to_string(), 0, reset_event);
let context = Context::<TestCounter>::rehydrate_from(&store_event).unwrap();
assert_eq!(context.state().id, "calc-1");
assert_eq!(context.state().result, 0);
assert_eq!(context.version(), 0);
assert!(context.uncommitted_events.is_empty());
assert!(context.uncommitted_side_effects.is_empty());
}
// Replaying a persisted event updates state and version but queues no events/side effects.
#[test]
fn test_context_apply_rehydrated_event() {
use crate::event::EventStoreEvent;
let reset_event = create_reset_event("reset-1", "calc-1");
let store_event = EventStoreEvent::new("reset-1".to_string(), 0, reset_event);
let context = Context::<TestCounter>::rehydrate_from(&store_event).unwrap();
let add_event = create_add_event("add-1", 10);
let add_store_event = EventStoreEvent::new("add-1".to_string(), 1, add_event);
let updated_context = context.apply_rehydrated_event(&add_store_event).unwrap();
assert_eq!(updated_context.state().result, 10);
assert_eq!(updated_context.version(), 1);
assert!(updated_context.uncommitted_events.is_empty());
assert!(updated_context.uncommitted_side_effects.is_empty());
}
// Apply errors are propagated during replay.
#[test]
fn test_context_apply_rehydrated_event_error() {
use crate::event::EventStoreEvent;
let reset_event = create_reset_event("reset-1", "calc-1");
let store_event = EventStoreEvent::new("reset-1".to_string(), 0, reset_event);
let context = Context::<TestCounter>::rehydrate_from(&store_event).unwrap();
let invalid_event = create_multiply_event("mul-zero", 0);
let invalid_store_event = EventStoreEvent::new("mul-zero".to_string(), 1, invalid_event);
let result = context.apply_rehydrated_event(&invalid_store_event);
assert!(matches!(result, Err(TestError::DivisionByZero)));
}
// Rehydrating from an add event is rejected by the fixture's `apply_new` with
// `TestError::InvalidOperation`.
#[test]
fn test_context_rehydrate_from_error() {
use crate::event::EventStoreEvent;
let add_event = create_add_event("add-1", 10);
let store_event = EventStoreEvent::new("add-1".to_string(), 0, add_event);
let result = Context::<TestCounter>::rehydrate_from(&store_event);
assert!(matches!(result, Err(TestError::InvalidOperation)));
}
// The reset event yields two side effects: a `LogOperation` and a `NotifyUser`.
#[test]
fn test_context_record_new_with_side_effects() {
let reset_event = create_reset_event("reset-1", "calc-1");
let context = Context::<TestCounter>::record_new(reset_event).unwrap();
assert_eq!(context.uncommitted_side_effects.len(), 2);
let side_effects = &context.uncommitted_side_effects;
assert!(side_effects.iter().any(|se| matches!(
se,
crate::test_fixtures::TestSideEffect::LogOperation { .. }
)));
assert!(
side_effects
.iter()
.any(|se| matches!(se, crate::test_fixtures::TestSideEffect::NotifyUser { .. }))
);
}
// An aggregate whose `side_effects` returns `None` ends up with an empty side-effect queue.
#[test]
fn test_context_record_new_no_side_effects() {
use crate::test_fixtures::*;
// Minimal local aggregate: one creation event, a no-op `apply`, and no side effects.
#[derive(Clone, Debug, PartialEq, Eq)]
struct SimpleCounts {
id: String,
count: i32,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum SimpleEvent {
Created { event_id: String, id: String },
}
impl crate::event::DomainEvent for SimpleEvent {
type EventId = String;
fn id(&self) -> &Self::EventId {
match self {
SimpleEvent::Created { event_id, .. } => event_id,
}
}
}
impl crate::aggregate::Aggregate for SimpleCounts {
const SNAPSHOT_VERSION: u64 = 1;
type AggregateId = String;
type DomainEvent = SimpleEvent;
type ApplyError = String;
type SideEffect = TestSideEffect;
fn aggregate_id(&self) -> &Self::AggregateId {
&self.id
}
fn apply_new(event: &Self::DomainEvent) -> Result<Self, Self::ApplyError> {
match event {
SimpleEvent::Created { id, .. } => Ok(SimpleCounts {
id: id.clone(),
count: 0,
}),
}
}
fn apply(&mut self, _event: &Self::DomainEvent) -> Result<(), Self::ApplyError> {
Ok(())
}
// Deliberately reports no side effects for any event.
fn side_effects(&self, _event: &Self::DomainEvent) -> Option<Vec<Self::SideEffect>> {
None }
}
let create_event = SimpleEvent::Created {
event_id: "evt-1".to_string(),
id: "simple-1".to_string(),
};
let context = Context::<SimpleCounts>::record_new(create_event).unwrap();
assert_eq!(context.uncommitted_side_effects.len(), 0);
assert_eq!(context.version(), 0);
assert_eq!(context.state().id, "simple-1");
assert_eq!(context.aggregate_id(), &"simple-1".to_string());
let state_ref = context.state();
assert_eq!(state_ref.count, 0);
}
}