use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use futures::{Stream, stream};
use crate::{
aggregate::{Aggregate, Context, SaveError},
event::{DomainEvent, EventStoreEvent},
repository::{Repository, RepositoryReader, RepositoryWriter, Snapshot},
};
/// Errors produced by the in-memory repository backend.
#[derive(Debug, thiserror::Error)]
pub enum InMemoryError {
    /// The requested aggregate could not be loaded.
    ///
    /// The loaders in this file also report failures to replay an aggregate's events
    /// (`RepositoryError::Apply`) with this variant.
    #[error("Aggregate not found")]
    AggregateNotFound,

    /// An event could not be found. Not constructed by the code in this file.
    #[error("Event not found")]
    EventNotFound,

    /// An append hit a version slot that is already taken.
    ///
    /// `expected` and `actual` describe the conflict as far as the producer knows it. Some
    /// producers can only report a single version number and then put it in both fields.
    #[error("Version conflict: expected {expected}, got {actual}")]
    VersionConflict {
        expected: u64,
        actual: u64,
    },

    /// An event id is already stored in a way that conflicts with the new write (for example
    /// under a different aggregate), or a save reported an idempotency violation.
    #[error("Event already exists with different content")]
    EventExists,
}
/// The state shared by every handle to one in-memory repository.
///
/// Holds all event streams, an index from event id to its location, the latest snapshot per
/// aggregate, and every side effect stored so far.
#[derive(Debug, Clone)]
struct InMemoryStorage<T: Aggregate>
where
    T::AggregateId: Hash + Eq,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq,
    T::SideEffect: Debug + Clone,
{
    /// Event streams: aggregate id -> (event version -> event).
    events: HashMap<T::AggregateId, HashMap<u64, EventStoreEvent<T::DomainEvent>>>,
    /// Reverse index: event id -> (owning aggregate id, event version).
    events_by_id:
        HashMap<<<T as Aggregate>::DomainEvent as DomainEvent>::EventId, (T::AggregateId, u64)>,
    /// Latest snapshot per aggregate; storing a new one replaces the old one.
    snapshots: HashMap<T::AggregateId, Snapshot<T>>,
    /// Side effects, in the order they were stored.
    side_effects: Vec<T::SideEffect>,
}

impl<T: Aggregate> Default for InMemoryStorage<T>
where
    T::AggregateId: Hash + Eq,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq,
    T::SideEffect: Debug + Clone,
{
    /// Returns a store with no events, no snapshots and no side effects.
    fn default() -> Self {
        InMemoryStorage {
            events: HashMap::default(),
            events_by_id: HashMap::default(),
            snapshots: HashMap::default(),
            side_effects: Vec::default(),
        }
    }
}
/// An event-sourcing repository that keeps all data in process memory.
///
/// The storage sits behind an `Arc<Mutex<..>>`. Cloning the repository therefore yields
/// another handle to the same events, snapshots and side effects rather than a copy.
#[derive(Debug, Clone)]
pub struct InMemoryRepository<T: Aggregate>
where
    T::AggregateId: Hash + Eq,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq,
    T::SideEffect: Debug + Clone,
{
    storage: Arc<Mutex<InMemoryStorage<T>>>,
}

impl<T: Aggregate> InMemoryRepository<T>
where
    T::AggregateId: Hash + Eq,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq,
    T::SideEffect: Debug + Clone,
{
    /// Creates a repository with empty storage.
    pub fn new() -> Self {
        let empty = Mutex::new(InMemoryStorage::default());
        Self {
            storage: Arc::new(empty),
        }
    }

    /// Returns a copy of event `event_id`, provided it was stored for `aggregate_id`.
    ///
    /// Returns `None` when the event id is unknown or belongs to a different aggregate.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    pub fn get_event(
        &self,
        aggregate_id: &T::AggregateId,
        event_id: &<<T as Aggregate>::DomainEvent as DomainEvent>::EventId,
    ) -> Option<EventStoreEvent<T::DomainEvent>> {
        let guard = self.storage.lock().unwrap();
        match guard.events_by_id.get(event_id) {
            Some((owner, version)) if owner == aggregate_id => guard
                .events
                .get(aggregate_id)
                .and_then(|by_version| by_version.get(version))
                .cloned(),
            _ => None,
        }
    }

    /// Returns copies of the events of `aggregate_id` whose version is `>= from_version`,
    /// sorted by ascending version.
    ///
    /// Returns an empty vector for an unknown aggregate.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    pub fn get_events_from(
        &self,
        aggregate_id: &T::AggregateId,
        from_version: u64,
    ) -> Vec<EventStoreEvent<T::DomainEvent>> {
        let guard = self.storage.lock().unwrap();
        // A missing stream flattens to nothing, which yields an empty result.
        let mut selected: Vec<EventStoreEvent<T::DomainEvent>> = guard
            .events
            .get(aggregate_id)
            .into_iter()
            .flatten()
            .filter(|&(version, _)| *version >= from_version)
            .map(|(_, event)| event.clone())
            .collect();
        // The per-aggregate map is a HashMap, so iteration order is arbitrary; restore it.
        selected.sort_by_key(|event| event.version);
        selected
    }

    /// Number of side effects stored so far.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    pub fn side_effects_count(&self) -> usize {
        self.storage.lock().unwrap().side_effects.len()
    }

    /// Returns a copy of every side effect stored so far, in storage order.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    pub fn get_all_side_effects(&self) -> Vec<T::SideEffect> {
        self.storage.lock().unwrap().side_effects.clone()
    }
}

impl<T: Aggregate> Default for InMemoryRepository<T>
where
    T::AggregateId: Hash + Eq,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq,
    T::SideEffect: Debug + Clone,
{
    /// Same as [`InMemoryRepository::new`]: a repository with empty storage.
    fn default() -> Self {
        InMemoryRepository::new()
    }
}
#[derive(Debug)]
pub struct InMemoryTransaction<T: Aggregate>
where
T::AggregateId: Hash + Eq,
<<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq,
T::SideEffect: Debug + Clone,
{
repository: InMemoryRepository<T>,
}
impl<T: Aggregate> InMemoryTransaction<T>
where
T::AggregateId: Hash + Eq + Send + Sync,
<<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq + Send + Sync,
T::SideEffect: Debug + Clone + Send + Sync,
T: Send + Sync,
T::DomainEvent: Send + Sync,
T::ApplyError: Send + Sync,
{
pub fn new(repository: InMemoryRepository<T>) -> Self {
Self { repository }
}
pub async fn get(
&mut self,
aggregate_id: &T::AggregateId,
) -> Result<Context<T>, InMemoryError> {
Context::load(self, aggregate_id)
.await
.map_err(|e| match e {
crate::repository::RepositoryError::AggregateNotFound => {
InMemoryError::AggregateNotFound
}
crate::repository::RepositoryError::Apply(_, _) => InMemoryError::AggregateNotFound,
crate::repository::RepositoryError::Repository(db_err) => db_err,
})
}
pub async fn store(&mut self, context: &mut Context<T>) -> Result<(), InMemoryError> {
context.save(self).await.map_err(|e| match e {
SaveError::IdempotencyError(_, _) => InMemoryError::EventExists,
SaveError::OptimisticConcurrency(_, version) => InMemoryError::VersionConflict {
expected: version,
actual: version,
},
SaveError::Repository(db_err) => db_err,
})
}
pub fn commit(self) -> Result<(), InMemoryError> {
Ok(())
}
pub fn rollback(self) -> Result<(), InMemoryError> {
Ok(())
}
}
#[async_trait]
impl<T: Aggregate> RepositoryReader<T> for InMemoryTransaction<T>
where
    T::AggregateId: Hash + Eq + Send + Sync,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq + Send + Sync,
    T::SideEffect: Debug + Clone + Send + Sync,
    T: Send + Sync,
    T::DomainEvent: Send + Sync,
{
    type DbError = InMemoryError;

    /// Streams the events of `id` with version `>= version`, in ascending version order.
    ///
    /// The events are copied out of storage when this is called, and the stream never
    /// yields an error.
    fn stream_from(
        &mut self,
        id: &T::AggregateId,
        version: u64,
    ) -> impl Stream<Item = Result<EventStoreEvent<T::DomainEvent>, Self::DbError>> {
        let copied = self.repository.get_events_from(id, version);
        stream::iter(copied.into_iter().map(Result::Ok))
    }

    /// Returns event `event_id` if it was stored for `aggregate_id`. Never fails.
    async fn get_event(
        &mut self,
        aggregate_id: &T::AggregateId,
        event_id: &<<T as Aggregate>::DomainEvent as DomainEvent>::EventId,
    ) -> Result<Option<EventStoreEvent<T::DomainEvent>>, Self::DbError> {
        let found = self.repository.get_event(aggregate_id, event_id);
        Ok(found)
    }

    /// Returns the stored snapshot of `id`, but only when its `snapshot_version` equals
    /// `T::SNAPSHOT_VERSION`. Snapshots with any other version are treated as absent.
    /// Never fails.
    async fn get_snapshot(
        &mut self,
        id: &T::AggregateId,
    ) -> Result<Option<Snapshot<T>>, Self::DbError> {
        let guard = self.repository.storage.lock().unwrap();
        let usable = guard
            .snapshots
            .get(id)
            .filter(|snapshot| snapshot.snapshot_version == T::SNAPSHOT_VERSION)
            .cloned();
        Ok(usable)
    }
}
#[async_trait]
impl<T: Aggregate> RepositoryWriter<T> for InMemoryTransaction<T>
where
    T::AggregateId: Hash + Eq + Send + Sync,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq + Send + Sync,
    T::SideEffect: Debug + Clone + Send + Sync,
    T: Send + Sync,
    T::DomainEvent: Send + Sync,
{
    /// Appends `events` to the stream of aggregate `id`.
    ///
    /// The whole batch is validated before anything is written. A rejected batch therefore
    /// leaves storage untouched; the transaction's `rollback` is a no-op and could not undo
    /// a partial write.
    ///
    /// Each event is handled in batch order:
    /// - If its id is already stored for this aggregate, or appeared earlier in this batch,
    ///   it is skipped and left out of the returned ids.
    /// - If its id is stored for a *different* aggregate, the batch fails with
    ///   [`InMemoryError::EventExists`].
    /// - If its version is already taken, either in storage or earlier in this batch, the
    ///   batch fails with [`InMemoryError::VersionConflict`].
    ///
    /// # Returns
    /// The ids of the events that were inserted, in batch order.
    ///
    /// # Errors
    /// [`InMemoryError::EventExists`] or [`InMemoryError::VersionConflict`], as described
    /// above. In a `VersionConflict`, `expected` is the next free version (one past the
    /// highest version known, or 0 for an empty stream) and `actual` is the version the
    /// rejected event carried.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    async fn store_events(
        &mut self,
        id: &T::AggregateId,
        events: Vec<EventStoreEvent<T::DomainEvent>>,
    ) -> Result<Vec<<<T as Aggregate>::DomainEvent as DomainEvent>::EventId>, Self::DbError> {
        let mut storage = self.repository.storage.lock().unwrap();

        // Phase 1: validate the whole batch without mutating storage.
        // `keep[i]` is true when `events[i]` must be inserted.
        let keep = {
            let stored_versions = storage.events.get(id);
            // Highest version seen so far, stored or accepted from this batch. It is used
            // only to report the next free slot in a `VersionConflict`.
            let mut head = stored_versions.and_then(|versions| versions.keys().max().copied());
            let mut batch_ids = std::collections::HashSet::new();
            let mut batch_versions = std::collections::HashSet::new();
            let mut keep = Vec::with_capacity(events.len());

            for event in &events {
                match storage.events_by_id.get(event.id()) {
                    // Already stored for this aggregate: treat as an idempotent retry.
                    Some((owner, _)) if owner == id => {
                        keep.push(false);
                        continue;
                    }
                    Some(_) => return Err(InMemoryError::EventExists),
                    None => {}
                }
                // Repeated id within this batch: the first occurrence wins.
                if !batch_ids.insert(event.id()) {
                    keep.push(false);
                    continue;
                }
                let version = event.version;
                let taken = stored_versions.is_some_and(|versions| versions.contains_key(&version))
                    || batch_versions.contains(&version);
                if taken {
                    return Err(InMemoryError::VersionConflict {
                        expected: head.map_or(0, |latest| latest.saturating_add(1)),
                        actual: version,
                    });
                }
                batch_versions.insert(version);
                head = Some(head.map_or(version, |latest| latest.max(version)));
                keep.push(true);
            }
            keep
        };

        // Phase 2: the batch is valid, so apply it. Events are moved in, not cloned.
        let mut stored_ids = Vec::new();
        for (event, insert) in events.into_iter().zip(keep) {
            if !insert {
                continue;
            }
            let event_id = event.id().clone();
            let version = event.version;
            storage
                .events_by_id
                .insert(event_id.clone(), (id.clone(), version));
            storage
                .events
                .entry(id.clone())
                .or_default()
                .insert(version, event);
            stored_ids.push(event_id);
        }
        Ok(stored_ids)
    }

    /// Stores `snapshot` under its aggregate's id, replacing any earlier snapshot.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    async fn store_snapshot(&mut self, snapshot: Snapshot<T>) -> Result<(), Self::DbError> {
        let mut storage = self.repository.storage.lock().unwrap();
        storage
            .snapshots
            .insert(snapshot.aggregate.aggregate_id().clone(), snapshot);
        Ok(())
    }

    /// Appends `side_effects` to the stored side effects, keeping their order.
    ///
    /// # Panics
    /// Panics if the storage mutex is poisoned.
    async fn store_side_effects(
        &mut self,
        side_effects: Vec<T::SideEffect>,
    ) -> Result<(), Self::DbError> {
        let mut storage = self.repository.storage.lock().unwrap();
        storage.side_effects.extend(side_effects);
        Ok(())
    }
}
#[async_trait]
impl<T: Aggregate> Repository<T> for InMemoryRepository<T>
where
    T::AggregateId: Hash + Eq + Send + Sync,
    <<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq + Send + Sync,
    T::SideEffect: Debug + Clone + Send + Sync,
    T: Send + Sync,
    T::DomainEvent: Send + Sync,
    T::ApplyError: Send + Sync,
{
    type Error = InMemoryError;

    /// Loads aggregate `aggregate_id` through a throwaway transaction on this repository's
    /// shared storage.
    ///
    /// # Errors
    /// Returns [`InMemoryError::AggregateNotFound`] when the aggregate cannot be loaded,
    /// including when replaying its events fails. Backend errors are passed through.
    async fn load(&self, aggregate_id: &T::AggregateId) -> Result<Context<T>, Self::Error> {
        let mut reader = InMemoryTransaction::new(self.clone());
        match Context::<T>::load(&mut reader, aggregate_id).await {
            Ok(context) => Ok(context),
            Err(crate::repository::RepositoryError::Repository(db_err)) => Err(db_err),
            Err(
                crate::repository::RepositoryError::AggregateNotFound
                | crate::repository::RepositoryError::Apply(_, _),
            ) => Err(InMemoryError::AggregateNotFound),
        }
    }
}
impl<T: Aggregate> InMemoryRepository<T>
where
T::AggregateId: Hash + Eq + Send + Sync,
<<T as Aggregate>::DomainEvent as DomainEvent>::EventId: Hash + Eq + Send + Sync,
T::SideEffect: Debug + Clone + Send + Sync,
T: Send + Sync,
T::DomainEvent: Send + Sync,
T::ApplyError: Send + Sync,
{
pub async fn begin_transaction(&self) -> Result<InMemoryTransaction<T>, InMemoryError> {
Ok(InMemoryTransaction::new(self.clone()))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        aggregate::{Context, Root, SaveError},
        repository::Repository,
        test_fixtures::*,
    };

    /// A committed stream can be reloaded through `Repository::load` with the same state.
    #[tokio::test]
    async fn test_repository_load_and_save() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        context.record_that(create_add_event("add-1", 10)).unwrap();
        context
            .record_that(create_subtract_event("sub-1", 3))
            .unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context).await.unwrap();
        transaction.commit().unwrap();
        let loaded_context = repository.load(&"calc-1".to_string()).await.unwrap();
        assert_eq!(loaded_context.state().result, 7);
        assert_eq!(loaded_context.state().operations_count, 2);
        assert_eq!(loaded_context.version(), 2);
    }

    /// Loading an aggregate with no stored events reports `AggregateNotFound`.
    #[tokio::test]
    async fn test_repository_aggregate_not_found() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let result = repository.load(&"non-existent".to_string()).await;
        assert!(matches!(result, Err(InMemoryError::AggregateNotFound)));
    }

    /// A context loaded through a transaction can be extended and stored again.
    #[tokio::test]
    async fn test_repository_transaction_operations() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context).await.unwrap();
        transaction.commit().unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let mut loaded_context = transaction.get(&"calc-1".to_string()).await.unwrap();
        loaded_context
            .record_that(create_add_event("add-1", 20))
            .unwrap();
        transaction.store(&mut loaded_context).await.unwrap();
        transaction.commit().unwrap();
        let final_context = repository.load(&"calc-1".to_string()).await.unwrap();
        assert_eq!(final_context.state().result, 20);
    }

    /// Re-storing an identical event for the same aggregate succeeds.
    #[tokio::test]
    async fn test_repository_idempotency() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context1 = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context1).await.unwrap();
        transaction.commit().unwrap();
        let same_reset_event = create_reset_event("reset-1", "calc-1");
        let mut context2 = TestCounter::record_new(same_reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let result = transaction.store(&mut context2).await;
        assert!(result.is_ok());
    }

    /// Reusing an event id under a different aggregate is rejected.
    #[tokio::test]
    async fn test_repository_idempotency_error() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context1 = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context1).await.unwrap();
        transaction.commit().unwrap();
        let different_reset_event = create_reset_event("reset-1", "different-calc");
        let mut context2 = TestCounter::record_new(different_reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let result = transaction.store(&mut context2).await;
        assert!(result.is_err());
    }

    /// After saving, a snapshot with the current `SNAPSHOT_VERSION` can be read back.
    #[tokio::test]
    async fn test_snapshot_version_checking() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        context.record_that(create_add_event("add-1", 100)).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context).await.unwrap();
        transaction.commit().unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let snapshot = transaction
            .get_snapshot(&"calc-1".to_string())
            .await
            .unwrap();
        assert!(snapshot.is_some());
        let snap = snapshot.unwrap();
        assert_eq!(snap.snapshot_version, TestCounter::SNAPSHOT_VERSION);
        assert_eq!(snap.version, 1);
        assert_eq!(snap.aggregate.result, 100);
    }

    /// A new repository has no side effects; storing events adds some.
    #[tokio::test]
    async fn test_in_memory_repository_basic_operations() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        assert_eq!(repository.side_effects_count(), 0);
        assert!(repository.get_all_side_effects().is_empty());
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        context.record_that(create_add_event("add-1", 50)).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context).await.unwrap();
        transaction.commit().unwrap();
        assert!(repository.side_effects_count() > 0);
        assert!(!repository.get_all_side_effects().is_empty());
    }

    /// Saving a freshly loaded context with no uncommitted events succeeds.
    #[tokio::test]
    async fn test_context_save_empty_events() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context).await.unwrap();
        transaction.commit().unwrap();
        let mut loaded_context = repository.load(&"calc-1".to_string()).await.unwrap();
        assert_eq!(loaded_context.take_uncommitted_events().len(), 0);
        let mut transaction = repository.begin_transaction().await.unwrap();
        let result = loaded_context.save(&mut transaction).await;
        assert!(result.is_ok());
        transaction.commit().unwrap();
    }

    /// Reusing an already-stored event id with different content yields `IdempotencyError`.
    #[tokio::test]
    async fn test_context_save_idempotency_error() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context1 = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        context1.save(&mut transaction).await.unwrap();
        transaction.commit().unwrap();
        let mut loaded_context = repository.load(&"calc-1".to_string()).await.unwrap();
        loaded_context
            .record_that(create_add_event("add-1", 50))
            .unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        loaded_context.save(&mut transaction).await.unwrap();
        transaction.commit().unwrap();
        let mut loaded_context2 = repository.load(&"calc-1".to_string()).await.unwrap();
        loaded_context2
            .record_that(create_add_event("add-1", 100))
            .unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let result = loaded_context2.save(&mut transaction).await;
        assert!(matches!(result, Err(SaveError::IdempotencyError(_, _))));
    }

    /// Two contexts loaded at the same version: the second save hits a version conflict.
    #[tokio::test]
    async fn test_context_save_optimistic_concurrency() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        context.save(&mut transaction).await.unwrap();
        transaction.commit().unwrap();
        let mut context1 = repository.load(&"calc-1".to_string()).await.unwrap();
        let mut context2 = repository.load(&"calc-1".to_string()).await.unwrap();
        context1.record_that(create_add_event("add-1", 10)).unwrap();
        context2.record_that(create_add_event("add-2", 20)).unwrap();
        let mut transaction1 = repository.begin_transaction().await.unwrap();
        let result1 = context1.save(&mut transaction1).await;
        assert!(result1.is_ok());
        transaction1.commit().unwrap();
        let mut transaction2 = repository.begin_transaction().await.unwrap();
        let result2 = context2.save(&mut transaction2).await;
        assert!(matches!(
            result2,
            Err(SaveError::Repository(InMemoryError::VersionConflict { .. }))
        ));
    }

    /// `Context::load` rebuilds state and version and has no uncommitted events.
    #[tokio::test]
    async fn test_context_load_with_snapshot() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let mut context = TestCounter::record_new(create_reset_event("reset-1", "calc-1")).unwrap();
        context.record_that(create_add_event("add-1", 100)).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        context.save(&mut transaction).await.unwrap();
        transaction.commit().unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let mut loaded_context = Context::load(&mut transaction, &"calc-1".to_string())
            .await
            .unwrap();
        assert_eq!(loaded_context.state().result, 100);
        assert_eq!(loaded_context.version(), 1);
        assert_eq!(loaded_context.take_uncommitted_events().len(), 0);
    }

    /// Side effects can be regenerated for a single stored event.
    #[tokio::test]
    async fn test_regenerate_side_effects_specific_event() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let mut context = TestCounter::record_new(create_reset_event("reset-1", "calc-1")).unwrap();
        context.record_that(create_add_event("add-1", 100)).unwrap();
        context
            .record_that(create_subtract_event("sub-1", 25))
            .unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        context.save(&mut transaction).await.unwrap();
        transaction.commit().unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let side_effects = Context::<TestCounter>::regenerate_side_effects(
            &mut transaction,
            &"calc-1".to_string(),
            &"add-1".to_string(),
        )
        .await
        .unwrap();
        assert!(side_effects.is_some());
        let effects = side_effects.unwrap();
        assert_eq!(effects.len(), 1);
        assert!(matches!(
            effects[0],
            TestSideEffect::LogOperation { ref operation, .. } if operation == "Add 100"
        ));
    }

    /// Regenerating side effects for an unknown event id yields `None`.
    #[tokio::test]
    async fn test_regenerate_side_effects_event_not_found() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let mut context = TestCounter::record_new(create_reset_event("reset-1", "calc-1")).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        context.save(&mut transaction).await.unwrap();
        transaction.commit().unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        let side_effects = Context::<TestCounter>::regenerate_side_effects(
            &mut transaction,
            &"calc-1".to_string(),
            &"non-existent-event".to_string(),
        )
        .await
        .unwrap();
        assert!(side_effects.is_none());
    }

    /// End-to-end: five recorded events round-trip and produce five side effects.
    #[tokio::test]
    async fn test_complete_counter_workflow() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let reset_event = create_reset_event("reset-1", "my-counter");
        let mut counter = TestCounter::record_new(reset_event).unwrap();
        counter.record_that(create_add_event("add-1", 100)).unwrap();
        counter
            .record_that(create_subtract_event("sub-1", 25))
            .unwrap();
        counter
            .record_that(create_multiply_event("mul-1", 2))
            .unwrap();
        counter.record_that(create_add_event("add-2", 50)).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut counter).await.unwrap();
        transaction.commit().unwrap();
        let loaded_counter = repository.load(&"my-counter".to_string()).await.unwrap();
        assert_eq!(loaded_counter.state().result, 200);
        assert_eq!(loaded_counter.state().operations_count, 4);
        assert_eq!(loaded_counter.version(), 4);
        let expected_side_effects = 5;
        assert_eq!(repository.side_effects_count(), expected_side_effects);
        let side_effects = repository.get_all_side_effects();
        assert_eq!(side_effects.len(), expected_side_effects);
    }

    /// A rejected event leaves state and version untouched, and recording can continue.
    #[tokio::test]
    async fn test_error_handling_and_recovery() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::new();
        let mut counter = TestCounter::record_new(create_reset_event("reset-1", "calc-1")).unwrap();
        counter.record_that(create_add_event("add-1", 10)).unwrap();
        let invalid_result = counter.record_that(create_multiply_event("mul-zero", 0));
        assert!(matches!(invalid_result, Err(TestError::DivisionByZero)));
        assert_eq!(counter.state().result, 10);
        assert_eq!(counter.version(), 1);
        counter
            .record_that(create_multiply_event("mul-valid", 3))
            .unwrap();
        assert_eq!(counter.state().result, 30);
        assert_eq!(counter.version(), 2);
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut counter).await.unwrap();
        transaction.commit().unwrap();
        let loaded = repository.load(&"calc-1".to_string()).await.unwrap();
        assert_eq!(loaded.state().result, 30);
        assert_eq!(loaded.version(), 2);
    }

    /// `InMemoryRepository::default()` starts empty and is fully usable.
    #[tokio::test]
    async fn test_in_memory_repository_default() {
        let repository: InMemoryRepository<TestCounter> = InMemoryRepository::default();
        assert_eq!(repository.side_effects_count(), 0);
        assert!(repository.get_all_side_effects().is_empty());
        let reset_event = create_reset_event("reset-1", "calc-1");
        let mut context = TestCounter::record_new(reset_event).unwrap();
        let mut transaction = repository.begin_transaction().await.unwrap();
        transaction.store(&mut context).await.unwrap();
        transaction.commit().unwrap();
        let loaded = repository.load(&"calc-1".to_string()).await.unwrap();
        assert_eq!(loaded.state().id, "calc-1");
    }
}