use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context};
use atomr_persistence::{
AsyncSnapshotter, Journal, PersistentRepr, RecoveryPermitter, SnapshotPolicy, SnapshotStore,
};
use tokio::sync::oneshot;
use crate::extensions::ExtensionSlots;
use crate::{AggregateRoot, Command, DomainEvent, PatternError};
/// Records `events` under `key` in a bounded, insertion-ordered dedupe ring.
///
/// - `cap == 0` disables recording entirely (the ring is left untouched).
/// - A key that is already present is left as-is: the first recorded result wins.
/// - Oldest entries are evicted from the front until there is room for the new one, so
///   the ring never holds more than `cap` entries, even if `cap` is smaller than it was
///   on earlier calls.
fn push_dedupe<E: Clone>(
    ring: &mut std::collections::VecDeque<(String, Vec<E>)>,
    key: String,
    events: Vec<E>,
    cap: usize,
) {
    if cap == 0 || ring.iter().any(|(k, _)| *k == key) {
        return;
    }
    // `while`, not `if`: a single pop is not enough when the ring is already over `cap`.
    while ring.len() >= cap {
        ring.pop_front();
    }
    ring.push_back((key, events));
}
/// Snapshot settings used by the gateway: the store to load from and save to, the policy
/// deciding when to snapshot, and the retention count handed to the snapshotter.
#[derive(Clone)]
pub(crate) struct SnapshotConfig {
    /// Store that snapshots are loaded from during recovery and saved to after writes.
    pub store: Arc<dyn SnapshotStore>,
    /// Policy consulted (through `AsyncSnapshotter`) to decide whether a sequence number
    /// warrants a snapshot.
    pub policy: SnapshotPolicy,
    /// Value forwarded to `AsyncSnapshotter::with_keep_last` when saving.
    pub keep_last: usize,
}

impl SnapshotConfig {
    /// Returns whether `policy` asks for a snapshot at sequence number `seq`.
    fn should_snapshot(&self, seq: u64) -> bool {
        AsyncSnapshotter::new(Arc::clone(&self.store), self.policy).should_snapshot(seq)
    }

    /// Saves `payload` as the snapshot of `pid` at sequence number `seq`, applying the
    /// configured `keep_last` retention setting.
    async fn save(&self, pid: String, seq: u64, payload: Vec<u8>) {
        AsyncSnapshotter::new(Arc::clone(&self.store), self.policy)
            .with_keep_last(self.keep_last)
            .save(pid, seq, payload)
            .await
    }
}
/// One-shot channel on which the gateway answers a [`CommandEnvelope`]: the events
/// produced for the command on success, or the [`PatternError`] that stopped it.
///
/// The `Ok` vector may be empty (the aggregate emitted no events). For a command whose id
/// is still in the entity's dedupe cache, it holds the events recorded when it first ran.
pub(crate) type CommandReply<A> = oneshot::Sender<
Result<
Vec<<A as atomr_persistence::Eventsourced>::Event>,
PatternError<<A as atomr_persistence::Eventsourced>::Error>,
>,
>;
/// Message handled by the `CommandGateway` actor: a command plus the one-shot sender on
/// which its processing result is returned.
pub(crate) struct CommandEnvelope<A: AggregateRoot>
where
A::Command: Command<AggregateId = <A as AggregateRoot>::Id>,
A::Event: DomainEvent,
{
/// Command to process; routed to the aggregate named by its `aggregate_id()`.
pub cmd: A::Command,
/// Receives the processing result. The gateway ignores a failed send (dropped receiver).
pub reply: CommandReply<A>,
}
/// Everything the gateway tracks for one aggregate instance between commands.
pub(crate) struct EntityState<A: AggregateRoot> {
    /// Aggregate instance produced by the gateway's factory.
    pub aggregate: A,
    /// Current state: set during recovery and advanced by each applied event.
    pub state: A::State,
    /// Sequence number of the last event reflected in `state`; starts at 0.
    pub seq: u64,
    /// Set once recovery from snapshot/journal has completed.
    pub recovered: bool,
    /// Bounded cache of `(command id, events)` pairs used to answer repeated commands.
    pub dedupe: std::collections::VecDeque<(String, Vec<A::Event>)>,
}

impl<A: AggregateRoot> EntityState<A> {
    /// Wraps a freshly constructed aggregate in an entry that is not yet recovered. The
    /// entry has default state, sequence number 0 and an empty dedupe cache.
    pub(crate) fn new(aggregate: A) -> Self {
        let state: A::State = Default::default();
        let dedupe = std::collections::VecDeque::default();
        Self {
            aggregate,
            state,
            seq: 0,
            recovered: false,
            dedupe,
        }
    }
}
/// Actor that routes commands to event-sourced aggregates. It persists the resulting
/// events to a journal and keeps per-aggregate state in memory.
pub(crate) struct CommandGateway<A, J>
where
A: AggregateRoot,
A::Command: Command<AggregateId = <A as AggregateRoot>::Id>,
A::Event: DomainEvent,
J: Journal,
{
/// Builds a new aggregate for an id that has no entry in `entities` yet.
pub factory: Arc<dyn Fn(<A as AggregateRoot>::Id) -> A + Send + Sync>,
/// Journal that events are written to, and replayed from during recovery.
pub journal: Arc<J>,
/// A permit is acquired from this for the duration of each entity recovery; if none is
/// granted, recovery fails. (Presumably bounds concurrent recoveries — confirm.)
pub permits: Arc<RecoveryPermitter>,
/// Stamped into the `writer_uuid` of every event this gateway writes.
pub writer_uuid: String,
/// In-memory entity state keyed by aggregate id.
pub entities: HashMap<<A as AggregateRoot>::Id, EntityState<A>>,
/// Interceptors run before a command is processed. Listeners and event taps are
/// notified after its events are persisted.
pub extensions: ExtensionSlots<A::Command, A::Event, A::Error>,
/// Snapshot settings; `None` disables both snapshot loading and saving.
pub snapshot: Option<SnapshotConfig>,
/// Maximum number of command-id results cached per entity for deduplication; 0 disables it.
pub dedupe_window: usize,
}
#[async_trait]
impl<A, J> Actor for CommandGateway<A, J>
where
A: AggregateRoot,
A::Command: Command<AggregateId = <A as AggregateRoot>::Id>,
A::Event: DomainEvent,
J: Journal,
{
type Msg = CommandEnvelope<A>;
async fn handle(&mut self, _ctx: &mut Context<Self>, env: Self::Msg) {
let result = self.process(env.cmd).await;
let _ = env.reply.send(result);
}
}
impl<A, J> CommandGateway<A, J>
where
    A: AggregateRoot,
    A::Command: Command<AggregateId = <A as AggregateRoot>::Id>,
    A::Event: DomainEvent,
    J: Journal,
{
    /// Handles one command end to end for the aggregate it targets.
    ///
    /// The steps are:
    /// 1. Run the interceptors.
    /// 2. Take the target entity out of `entities`, creating it via `factory` if absent.
    /// 3. If dedupe is enabled and the command id is still cached, answer with the cached
    ///    events.
    /// 4. Otherwise delegate to [`Self::process_entity`].
    ///
    /// The entity is always put back into `entities`, whether processing succeeded or not.
    ///
    /// # Errors
    /// Whatever the interceptors or `process_entity` return.
    async fn process(&mut self, cmd: A::Command) -> Result<Vec<A::Event>, PatternError<A::Error>> {
        self.extensions.run_interceptors(&cmd)?;
        let id = cmd.aggregate_id();
        let mut entity =
            self.entities.remove(&id).unwrap_or_else(|| EntityState::new((self.factory)(id.clone())));

        // Short-circuit a repeated command id with the events recorded the first time.
        if self.dedupe_window > 0 {
            if let Some(key) = cmd.command_id().map(|s| s.to_string()) {
                if let Some((_, cached)) = entity.dedupe.iter().find(|(k, _)| *k == key) {
                    let cached = cached.clone();
                    self.entities.insert(id, entity);
                    return Ok(cached);
                }
            }
        }

        let result = self.process_entity(&mut entity, cmd).await;
        self.entities.insert(id, entity);
        result
    }

    /// Recovers `entity` if needed, checks the expected version, then turns `cmd` into
    /// events.
    ///
    /// After that it persists the events and applies them to the state. It then checks
    /// invariants and optionally snapshots. Finally it notifies listeners and event taps
    /// and records the result in the dedupe cache.
    ///
    /// `entity.seq` is advanced only after the journal write succeeds, so an encoding or
    /// journal failure leaves it unchanged.
    ///
    /// # Errors
    /// - `ConcurrencyConflict` when the command's expected version differs from `entity.seq`.
    /// - `Domain` for command rejection or an invariant violation.
    /// - `Codec` / `Journal` for encoding or persistence failures.
    /// - Any error from recovery.
    async fn process_entity(
        &mut self,
        entity: &mut EntityState<A>,
        cmd: A::Command,
    ) -> Result<Vec<A::Event>, PatternError<A::Error>> {
        // Read what we need from `cmd` before `command_to_events` consumes it.
        let dedupe_key = if self.dedupe_window > 0 { cmd.command_id().map(|s| s.to_string()) } else { None };
        let expected = cmd.expected_version();

        if !entity.recovered {
            self.recover_entity(entity).await?;
        }

        if let Some(expected) = expected {
            if entity.seq != expected {
                return Err(PatternError::ConcurrencyConflict { expected, actual: entity.seq });
            }
        }

        let events = entity.aggregate.command_to_events(&entity.state, cmd).map_err(PatternError::Domain)?;
        if events.is_empty() {
            return Ok(events);
        }

        // Assign tentative sequence numbers in a local counter. Mutating `entity.seq` here
        // would leave it advanced if an encode below fails.
        let manifest = entity.aggregate.event_manifest().to_string();
        let pid = entity.aggregate.persistence_id();
        let mut last_seq = entity.seq;
        let mut reprs = Vec::with_capacity(events.len());
        for e in &events {
            last_seq += 1;
            let payload = A::encode_event(e).map_err(PatternError::Codec)?;
            reprs.push(PersistentRepr {
                persistence_id: pid.clone(),
                sequence_nr: last_seq,
                payload,
                manifest: manifest.clone(),
                writer_uuid: self.writer_uuid.clone(),
                deleted: false,
                tags: e.tags(),
            });
        }
        if let Err(e) = self.journal.write_messages(reprs).await {
            return Err(PatternError::Journal(e));
        }
        entity.seq = last_seq;

        for e in &events {
            A::apply_event(&mut entity.state, e);
        }
        // NOTE(review): the events are already persisted and applied here. An invariant
        // failure is reported to the caller but not rolled back, and snapshot, listeners,
        // taps and dedupe recording are skipped for this batch.
        A::check_invariants(&entity.state).map_err(PatternError::Domain)?;

        if let Some(sc) = &self.snapshot {
            if sc.should_snapshot(entity.seq) {
                match A::encode_state(&entity.state) {
                    Some(Ok(payload)) => {
                        sc.save(pid, entity.seq, payload).await;
                    }
                    Some(Err(e)) => {
                        tracing::warn!(error = %e, "snapshot encode failed; skipping");
                    }
                    None => {}
                }
            }
        }

        // All listeners see the batch before any event tap does.
        for e in &events {
            self.extensions.notify_listeners(e);
        }
        for e in &events {
            self.extensions.push_event_taps(e);
        }
        if let Some(key) = dedupe_key {
            push_dedupe(&mut entity.dedupe, key, events.clone(), self.dedupe_window);
        }
        Ok(events)
    }

    /// Rebuilds `entity` from its latest snapshot (when configured and decodable) plus
    /// the journal events after it, then marks it recovered and calls
    /// `recovery_completed`.
    ///
    /// The state is rebuilt in a local variable and written back only on success. A
    /// failed attempt therefore leaves `entity` untouched, and a retry starts from a clean
    /// default state rather than replaying onto partially applied events.
    ///
    /// # Errors
    /// - `Invariant` if no recovery permit is granted.
    /// - `Journal` / `Codec` if reading or decoding the journal fails.
    async fn recover_entity(&mut self, entity: &mut EntityState<A>) -> Result<(), PatternError<A::Error>> {
        let permit = self
            .permits
            .acquire()
            .await
            .ok_or_else(|| PatternError::Invariant("recovery permit denied".into()))?;
        let pid = entity.aggregate.persistence_id();
        let mut state: A::State = Default::default();

        let snapshot_seq: Option<u64> = match &self.snapshot {
            Some(sc) => match sc.store.load(&pid).await {
                Some((meta, payload)) => match A::decode_state(&payload) {
                    Ok(decoded) => {
                        state = decoded;
                        Some(meta.sequence_nr)
                    }
                    Err(e) => {
                        tracing::warn!(
                            pid = %pid,
                            error = %e,
                            "snapshot decode failed; falling back to full journal replay"
                        );
                        None
                    }
                },
                None => None,
            },
            None => None,
        };

        let highest = self.journal.highest_sequence_nr(&pid, 0).await.map_err(PatternError::Journal)?;
        let from = snapshot_seq.map_or(1, |s| s + 1);
        if highest >= from {
            let stored = self
                .journal
                .replay_messages(&pid, from, highest, u64::MAX)
                .await
                .map_err(PatternError::Journal)?;
            for repr in &stored {
                let evt = A::decode_event(&repr.payload).map_err(PatternError::Codec)?;
                A::apply_event(&mut state, &evt);
            }
        }

        // Never resume below a loaded snapshot. If the journal reports a lower highest
        // sequence number, new events would otherwise reuse numbers the snapshot already
        // covers and be skipped by the next recovery.
        let recovered_seq = highest.max(snapshot_seq.unwrap_or(0));
        entity.state = state;
        entity.seq = recovered_seq;
        entity.recovered = true;
        drop(permit);
        entity.aggregate.recovery_completed(&entity.state, recovered_seq).await;
        Ok(())
    }
}