use std::{collections::HashMap, fmt, marker::PhantomData, sync::Arc, time::Duration};
use tokio::{sync::RwLock, time::timeout};
use super::Event;
/// Result type returned by saga execution.
pub type SagaResult<T> = Result<T, SagaError>;

/// Reasons a saga run can fail.
#[derive(Debug, Clone)]
pub enum SagaError {
    /// A step's forward action returned an error.
    StepFailed {
        /// Zero-based position of the failing step.
        step_index: usize,
        /// The failing step's `name()`.
        step_name: String,
        /// Error message produced by the step.
        error: String,
    },
    /// Compensating (undoing) the step at `step_index` failed.
    CompensationFailed { step_index: usize, error: String },
    /// The step at `step_index` did not finish within `duration`.
    Timeout { step_index: usize, duration: Duration },
    /// An invalid step index was supplied.
    InvalidStep(usize),
    /// A saga with the same id is already running.
    AlreadyExecuting,
}

impl fmt::Display for SagaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::StepFailed { step_index, step_name, error } => {
                write!(f, "Step {} ({}) failed: {}", step_index, step_name, error)
            }
            Self::CompensationFailed { step_index, error } => {
                write!(f, "Compensation for step {} failed: {}", step_index, error)
            }
            Self::Timeout { step_index, duration } => {
                write!(f, "Step {} timed out after {:?}", step_index, duration)
            }
            Self::InvalidStep(index) => write!(f, "Invalid step index: {}", index),
            Self::AlreadyExecuting => write!(f, "Saga is already executing"),
        }
    }
}

/// Lets `SagaError` be boxed as `dyn std::error::Error` and used with `?` in generic code.
impl std::error::Error for SagaError {}
/// Lifecycle state of a saga.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaStatus {
    /// Defined but not yet handed to an orchestrator.
    NotStarted,
    /// Steps are currently being run.
    Executing,
    /// Every step's forward action succeeded.
    Completed,
    /// A step failed and all earlier steps were compensated successfully.
    Compensated,
    /// A step failed and compensation did not succeed.
    Failed,
}

/// Progress snapshot of one saga, as tracked by the orchestrator.
#[derive(Debug, Clone)]
pub struct SagaMetadata {
    /// Saga identifier.
    pub id: String,
    /// Current lifecycle state.
    pub status: SagaStatus,
    /// How many steps have executed successfully so far.
    pub steps_executed: usize,
    /// How many steps the saga definition contains.
    pub total_steps: usize,
    /// Wall-clock time of the most recent recorded change.
    pub updated_at: std::time::SystemTime,
}
/// One unit of work in a saga, together with the action that undoes it.
///
/// Steps are stored as `Box<dyn SagaStep<E>>`, so the trait must remain object safe.
#[async_trait::async_trait]
pub trait SagaStep<E: Event>: Send + Sync {
    /// Performs the forward action, returning the events it produced or an error message.
    async fn execute(&self) -> Result<Vec<E>, String>;

    /// Undoes a previously successful `execute`, returning the events it produced.
    async fn compensate(&self) -> Result<Vec<E>, String>;

    /// Human-readable step name, reported in `SagaError::StepFailed`.
    fn name(&self) -> &str;

    /// Time limit the orchestrator applies to `execute`; 30 seconds unless overridden.
    fn timeout_duration(&self) -> Duration {
        Duration::from_secs(30)
    }
}
/// An ordered sequence of saga steps plus the progress metadata for one saga instance.
///
/// Built with [`SagaDefinition::new`] and the `add_step` / `with_*` builder methods, then
/// handed to an orchestrator for execution.
pub struct SagaDefinition<E: Event> {
    /// Saga identifier; the orchestrator keys its running set by this value.
    id: String,
    /// Steps in the order they will be executed.
    steps: Vec<Box<dyn SagaStep<E>>>,
    /// Progress snapshot; `total_steps` is kept equal to `steps.len()` by `add_step`.
    metadata: SagaMetadata,
    /// Optional compensation strategy.
    /// NOTE(review): not read by `SagaOrchestrator::execute` — confirm where it is consumed.
    compensation_strategy: Option<super::saga::CompensationStrategy>,
    /// Optional snapshot directory.
    /// NOTE(review): likewise not read by `SagaOrchestrator::execute`.
    snapshot_dir: Option<std::path::PathBuf>,
}
impl<E: Event> SagaDefinition<E> {
    /// Creates an empty saga called `id` in the `NotStarted` state.
    pub fn new(id: impl Into<String>) -> Self {
        let saga_id: String = id.into();
        let metadata = SagaMetadata {
            id: saga_id.clone(),
            status: SagaStatus::NotStarted,
            steps_executed: 0,
            total_steps: 0,
            updated_at: std::time::SystemTime::now(),
        };
        Self {
            id: saga_id,
            steps: Vec::new(),
            metadata,
            compensation_strategy: None,
            snapshot_dir: None,
        }
    }

    /// Appends `step` to the end of the saga and updates `total_steps` accordingly.
    pub fn add_step<S: SagaStep<E> + 'static>(mut self, step: S) -> Self {
        let boxed: Box<dyn SagaStep<E>> = Box::new(step);
        self.steps.push(boxed);
        self.metadata.total_steps = self.steps.len();
        self
    }

    /// Returns the saga identifier.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns a copy of the saga's current status.
    pub fn status(&self) -> SagaStatus {
        self.metadata.status.clone()
    }

    /// Returns the saga's progress metadata.
    pub fn metadata(&self) -> &SagaMetadata {
        &self.metadata
    }

    /// Sets the compensation strategy, replacing any previous one.
    pub fn with_compensation(mut self, strategy: super::saga::CompensationStrategy) -> Self {
        self.compensation_strategy.replace(strategy);
        self
    }

    /// Sets the snapshot directory (stored as an owned copy of `dir`).
    pub fn with_snapshot_dir(mut self, dir: &std::path::Path) -> Self {
        self.snapshot_dir.replace(dir.to_owned());
        self
    }
}
/// Executes sagas and tracks both the ones in flight and the ones that have finished.
///
/// Clones share the same underlying running-saga table and history.
pub struct SagaOrchestrator<E: Event> {
    /// Metadata of sagas currently executing, keyed by saga id.
    sagas: Arc<RwLock<HashMap<String, SagaMetadata>>>,
    /// Terminal metadata of finished sagas, appended in the order they were recorded.
    history: Arc<RwLock<Vec<SagaMetadata>>>,
    /// Binds the orchestrator to its event type without storing any `E`.
    _phantom: PhantomData<E>,
}
impl<E: Event> SagaOrchestrator<E> {
pub fn new() -> Self {
Self {
sagas: Arc::new(RwLock::new(HashMap::new())),
history: Arc::new(RwLock::new(Vec::new())),
_phantom: PhantomData,
}
}
pub async fn execute(&self, mut saga: SagaDefinition<E>) -> SagaResult<Vec<E>> {
{
let sagas = self.sagas.read().await;
if sagas.contains_key(&saga.id) {
return Err(SagaError::AlreadyExecuting);
}
}
saga.metadata.status = SagaStatus::Executing;
saga.metadata.updated_at = std::time::SystemTime::now();
{
let mut sagas = self.sagas.write().await;
sagas.insert(saga.id.clone(), saga.metadata.clone());
}
let mut all_events = Vec::new();
let mut executed_steps = 0;
for (index, step) in saga.steps.iter().enumerate() {
let step_timeout = step.timeout_duration();
let result = timeout(step_timeout, step.execute()).await;
match result {
Ok(Ok(events)) => {
all_events.extend(events);
executed_steps += 1;
saga.metadata.steps_executed = executed_steps;
saga.metadata.updated_at = std::time::SystemTime::now();
}
Ok(Err(error)) => {
saga.metadata.status = SagaStatus::Failed;
let compensation_result = self.compensate_steps(&saga.steps[0..index]).await;
{
let mut sagas = self.sagas.write().await;
sagas.remove(&saga.id);
}
{
let mut history = self.history.write().await;
saga.metadata.status = if compensation_result.is_ok() {
SagaStatus::Compensated
} else {
SagaStatus::Failed
};
history.push(saga.metadata.clone());
}
return Err(SagaError::StepFailed {
step_index: index,
step_name: step.name().to_string(),
error,
});
}
Err(_) => {
saga.metadata.status = SagaStatus::Failed;
let _ = self.compensate_steps(&saga.steps[0..index]).await;
{
let mut sagas = self.sagas.write().await;
sagas.remove(&saga.id);
}
return Err(SagaError::Timeout {
step_index: index,
duration: step_timeout,
});
}
}
}
saga.metadata.status = SagaStatus::Completed;
saga.metadata.updated_at = std::time::SystemTime::now();
{
let mut sagas = self.sagas.write().await;
sagas.remove(&saga.id);
}
{
let mut history = self.history.write().await;
history.push(saga.metadata);
}
Ok(all_events)
}
async fn compensate_steps(&self, steps: &[Box<dyn SagaStep<E>>]) -> Result<(), String> {
for step in steps.iter().rev() {
step.compensate().await?;
}
Ok(())
}
pub async fn get_saga(&self, id: &str) -> Option<SagaMetadata> {
let sagas = self.sagas.read().await;
sagas.get(id).cloned()
}
pub async fn get_running_sagas(&self) -> Vec<SagaMetadata> {
let sagas = self.sagas.read().await;
sagas.values().cloned().collect()
}
pub async fn get_history(&self) -> Vec<SagaMetadata> {
let history = self.history.read().await;
history.clone()
}
pub async fn running_count(&self) -> usize {
self.sagas.read().await.len()
}
pub async fn history_count(&self) -> usize {
self.history.read().await.len()
}
}
impl<E: Event> Default for SagaOrchestrator<E> {
fn default() -> Self {
Self::new()
}
}
/// Cloning produces another handle onto the SAME running set and history (the `Arc`s are
/// shared, not deep-copied).
///
/// Written by hand because `#[derive(Clone)]` would add an unnecessary `E: Clone` bound.
impl<E: Event> Clone for SagaOrchestrator<E> {
    fn clone(&self) -> Self {
        let sagas = Arc::clone(&self.sagas);
        let history = Arc::clone(&self.history);
        Self {
            sagas,
            history,
            _phantom: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cqrs::EventTypeName;

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        Debited { account: String, amount: f64 },
        Credited { account: String, amount: f64 },
    }
    impl EventTypeName for TestEvent {}
    impl Event for TestEvent {}

    struct DebitStep {
        account: String,
        amount: f64,
    }
    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for DebitStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }
        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }
        fn name(&self) -> &str {
            "DebitStep"
        }
    }

    struct CreditStep {
        account: String,
        amount: f64,
    }
    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for CreditStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }
        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }
        fn name(&self) -> &str {
            "CreditStep"
        }
    }

    /// Shared log of the labels of compensated steps, in the order they were compensated.
    type CompensationLog = std::sync::Arc<std::sync::Mutex<Vec<&'static str>>>;

    /// Step that always succeeds and records its label when compensated.
    struct RecordingStep {
        label: &'static str,
        log: CompensationLog,
    }
    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for RecordingStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }
        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            // Guard is dropped at the end of this statement, never held across an await.
            self.log.lock().unwrap().push(self.label);
            Ok(Vec::new())
        }
        fn name(&self) -> &str {
            self.label
        }
    }

    /// Step whose forward action always fails.
    struct FailStep;
    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for FailStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Err("insufficient funds".to_string())
        }
        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }
        fn name(&self) -> &str {
            "FailStep"
        }
    }

    /// Step whose forward action sleeps far past its own very short timeout.
    struct SlowStep;
    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for SlowStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok(Vec::new())
        }
        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }
        fn name(&self) -> &str {
            "SlowStep"
        }
        fn timeout_duration(&self) -> Duration {
            Duration::from_millis(10)
        }
    }

    #[tokio::test]
    async fn test_successful_saga() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let saga = SagaDefinition::new("transfer-1")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 100.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 100.0,
            });
        let events = orchestrator.execute(saga).await.unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(orchestrator.running_count().await, 0);
        assert_eq!(orchestrator.history_count().await, 1);
    }

    #[tokio::test]
    async fn test_saga_metadata() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let saga = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "A".to_string(),
            amount: 50.0,
        });
        assert_eq!(saga.id(), "transfer-2");
        assert_eq!(saga.status(), SagaStatus::NotStarted);
        assert_eq!(saga.metadata().total_steps, 1);
        orchestrator.execute(saga).await.unwrap();
        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Completed);
    }

    #[tokio::test]
    async fn test_saga_definition_builder() {
        let saga = SagaDefinition::<TestEvent>::new("test-saga")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 10.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 10.0,
            });
        assert_eq!(saga.metadata().total_steps, 2);
        assert_eq!(saga.status(), SagaStatus::NotStarted);
    }

    #[tokio::test]
    async fn test_multiple_sagas() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let saga1 = SagaDefinition::new("transfer-1").add_step(DebitStep {
            account: "A".to_string(),
            amount: 100.0,
        });
        let saga2 = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "B".to_string(),
            amount: 200.0,
        });
        orchestrator.execute(saga1).await.unwrap();
        orchestrator.execute(saga2).await.unwrap();
        assert_eq!(orchestrator.history_count().await, 2);
    }

    #[tokio::test]
    async fn test_failed_step_compensates_previous_steps_in_reverse() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let log = CompensationLog::default();
        let saga = SagaDefinition::new("transfer-fail")
            .add_step(RecordingStep {
                label: "first",
                log: log.clone(),
            })
            .add_step(RecordingStep {
                label: "second",
                log: log.clone(),
            })
            .add_step(FailStep);
        let err = orchestrator.execute(saga).await.unwrap_err();
        match err {
            SagaError::StepFailed {
                step_index,
                step_name,
                error,
            } => {
                assert_eq!(step_index, 2);
                assert_eq!(step_name, "FailStep");
                assert_eq!(error, "insufficient funds");
            }
            other => panic!("expected StepFailed, got {}", other),
        }
        assert_eq!(*log.lock().unwrap(), vec!["second", "first"]);
        assert_eq!(orchestrator.running_count().await, 0);
        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Compensated);
        assert_eq!(history[0].steps_executed, 2);
    }

    #[tokio::test]
    async fn test_step_timeout_reports_error_and_clears_running_set() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let saga = SagaDefinition::new("slow").add_step(SlowStep);
        let err = orchestrator.execute(saga).await.unwrap_err();
        assert!(matches!(
            err,
            SagaError::Timeout { step_index: 0, duration } if duration == Duration::from_millis(10)
        ));
        assert_eq!(orchestrator.running_count().await, 0);
    }
}