use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use atomr_core::actor::ActorSystem;
use tokio::sync::mpsc::UnboundedReceiver;
use crate::saga::state_store::{InMemorySagaStateStore, SagaStateStore};
use crate::topology::Topology;
use crate::PatternError;
/// An instruction returned by [`Saga::handle`] for the saga runtime to carry out.
///
/// Actions are executed in the order they appear in the returned `Vec`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaAction<C> {
    /// Dispatch the command now; it is awaited before the next action runs.
    Send(C),
    /// Dispatch the command after the given delay. The wait happens in a
    /// separate spawned task, so later actions are not held up.
    Schedule(C, Duration),
    /// Dispatch each command in order, awaiting each before the next.
    Compensate(Vec<C>),
    /// Finish the saga for this correlation id: its state is dropped from memory
    /// and deleted from the state store. Actions listed after this one are not
    /// executed.
    Complete,
}
/// An event-driven process whose state is tracked separately for each
/// correlation id.
///
/// The saga runtime routes every event to the state belonging to its
/// correlation id. State for an id it has not seen yet is `State::default()`.
#[async_trait]
pub trait Saga: Send + 'static {
    /// Events consumed by the saga.
    type Event: Send + Clone + 'static;
    /// Commands produced by the saga and handed to the dispatcher.
    type Command: Send + 'static;
    /// Per-correlation-id state; unseen ids start from `Default::default()`.
    type State: Default + Send + 'static;
    /// Error returned by [`Saga::handle`]; the runtime logs it and moves on.
    type Error: std::error::Error + Send + 'static;

    /// Extracts the correlation id from `event`.
    ///
    /// Returning `None` makes the runtime skip the event entirely.
    fn correlation_id(event: &Self::Event) -> Option<String>;

    /// Applies one event to the state of its correlation id and returns the
    /// actions the runtime should execute, in order.
    async fn handle(
        &mut self,
        state: &mut Self::State,
        event: Self::Event,
    ) -> Result<Vec<SagaAction<Self::Command>>, Self::Error>;

    /// Serializes a state value so it can be persisted.
    ///
    /// The default returns `None`, which opts the saga out of persistence. In
    /// that case nothing is saved and nothing is restored at startup.
    fn encode_state(_state: &Self::State) -> Option<Result<Vec<u8>, String>> {
        None
    }

    /// Rebuilds a state value from bytes produced by [`Saga::encode_state`].
    ///
    /// The default always fails; override it together with `encode_state`.
    fn decode_state(_bytes: &[u8]) -> Result<Self::State, String> {
        Err(String::from("decode_state not implemented"))
    }
}
/// Entry point for configuring a saga; `S` is only carried at the type level.
pub struct SagaPattern<S>(PhantomData<S>);

impl<S: Saga> SagaPattern<S> {
    /// Returns a fresh [`SagaBuilder`] with nothing configured yet.
    pub fn builder() -> SagaBuilder<S> {
        <SagaBuilder<S> as Default>::default()
    }
}
/// Type-erased, shareable command sink.
///
/// It takes one command and returns a boxed future that resolves to the `bool`
/// produced by the user-supplied dispatcher function.
type SagaDispatcher<C> =
    Arc<dyn Fn(C) -> futures::future::BoxFuture<'static, bool> + Send + Sync>;
/// Builder for a [`SagaTopology`]; obtain one with `SagaPattern::builder()`.
///
/// `saga`, `events` and `dispatcher` are required. `name` falls back to
/// `"saga"` and `state_store` to a new `InMemorySagaStateStore` when unset.
pub struct SagaBuilder<S: Saga> {
    /// Topology name, used in log fields and returned in `SagaHandles`.
    name: Option<String>,
    /// The saga implementation that handles events.
    saga: Option<S>,
    /// Channel the saga task reads events from.
    events: Option<UnboundedReceiver<S::Event>>,
    /// Type-erased function used to deliver commands.
    dispatcher: Option<SagaDispatcher<S::Command>>,
    /// Where per-correlation-id state is persisted.
    state_store: Option<Arc<dyn SagaStateStore>>,
}
impl<S: Saga> Default for SagaBuilder<S> {
fn default() -> Self {
Self { name: None, saga: None, events: None, dispatcher: None, state_store: None }
}
}
impl<S: Saga> SagaBuilder<S> {
pub fn name(mut self, n: impl Into<String>) -> Self {
self.name = Some(n.into());
self
}
pub fn saga(mut self, s: S) -> Self {
self.saga = Some(s);
self
}
pub fn events(mut self, rx: UnboundedReceiver<S::Event>) -> Self {
self.events = Some(rx);
self
}
pub fn dispatcher<F, Fut>(mut self, f: F) -> Self
where
F: Fn(S::Command) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = bool> + Send + 'static,
{
let f = Arc::new(f);
self.dispatcher = Some(Arc::new(move |cmd| {
let f = f.clone();
Box::pin(async move { f(cmd).await })
}));
self
}
pub fn state_store<T: SagaStateStore>(mut self, store: Arc<T>) -> Self {
self.state_store = Some(store);
self
}
pub fn build(self) -> Result<SagaTopology<S>, PatternError<S::Error>> {
let state_store: Arc<dyn SagaStateStore> =
self.state_store.unwrap_or_else(|| Arc::new(InMemorySagaStateStore::new()));
Ok(SagaTopology {
name: self.name.unwrap_or_else(|| "saga".into()),
saga: self.saga.ok_or(PatternError::NotConfigured("saga"))?,
events: self.events.ok_or(PatternError::NotConfigured("events"))?,
dispatcher: self.dispatcher.ok_or(PatternError::NotConfigured("dispatcher"))?,
state_store,
})
}
}
/// A fully configured saga, produced by `SagaBuilder::build` and started with
/// `Topology::materialize`.
pub struct SagaTopology<S: Saga> {
    /// Name used in log fields and returned in `SagaHandles`.
    name: String,
    /// Saga implementation; moved into the spawned task on materialization.
    saga: S,
    /// Event source; the spawned task loops until this yields `None`.
    events: UnboundedReceiver<S::Event>,
    /// Delivers the commands produced by the saga.
    dispatcher: SagaDispatcher<S::Command>,
    /// Persists per-correlation-id state; it is read back when the topology is
    /// materialized, provided the saga implements `encode_state`.
    state_store: Arc<dyn SagaStateStore>,
}
/// Handle returned after a saga topology has been materialized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SagaHandles {
    /// The topology name configured on the builder (default `"saga"`).
    pub name: String,
}
#[async_trait]
impl<S: Saga> Topology for SagaTopology<S> {
type Handles = SagaHandles;
async fn materialize(self, _system: &ActorSystem) -> Result<SagaHandles, PatternError<()>> {
let SagaTopology { name, mut saga, mut events, dispatcher, state_store } = self;
let task_name = name.clone();
tokio::spawn(async move {
let mut states: HashMap<String, S::State> = HashMap::new();
if S::encode_state(&S::State::default()).is_some() {
for corr in state_store.keys().await {
if let Some(payload) = state_store.load(&corr).await {
match S::decode_state(&payload) {
Ok(state) => {
states.insert(corr, state);
}
Err(e) => {
tracing::warn!(
saga = %task_name,
error = %e,
"decode saga state failed; dropping"
);
}
}
}
}
}
while let Some(event) = events.recv().await {
let Some(corr) = S::correlation_id(&event) else {
continue;
};
let state = states.entry(corr.clone()).or_default();
match saga.handle(state, event).await {
Ok(actions) => {
if let Some(Ok(payload)) = S::encode_state(state) {
state_store.save(&corr, payload).await;
}
let mut completed = false;
for action in actions {
match action {
SagaAction::Send(c) => {
let _ = (dispatcher)(c).await;
}
SagaAction::Schedule(c, delay) => {
let dispatcher = dispatcher.clone();
tokio::spawn(async move {
tokio::time::sleep(delay).await;
let _ = (dispatcher)(c).await;
});
}
SagaAction::Compensate(cs) => {
for c in cs {
let _ = (dispatcher)(c).await;
}
}
SagaAction::Complete => {
completed = true;
break;
}
}
}
if completed {
states.remove(&corr);
state_store.delete(&corr).await;
}
}
Err(e) => {
tracing::warn!(saga = %task_name, error = %e, "saga handle failed");
}
}
}
});
Ok(SagaHandles { name })
}
}