1use std::{collections::HashMap, fmt, marker::PhantomData, sync::Arc, time::Duration};
8
9use tokio::{sync::RwLock, time::timeout};
10
11use super::Event;
12
/// Convenience alias for the result of saga operations; the error type is
/// always [`SagaError`].
pub type SagaResult<T> = Result<T, SagaError>;
15
/// Errors reported while running a saga.
///
/// `PartialEq`/`Eq` are derived so that errors (and `Result`s carrying them)
/// can be compared directly, e.g. with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaError {
    /// A step's `execute` returned an error.
    StepFailed {
        /// Zero-based position of the failing step within the saga.
        step_index: usize,
        /// Name reported by the failing step.
        step_name: String,
        /// Error message returned by the step.
        error: String,
    },
    /// Compensation of a step failed.
    ///
    /// NOTE(review): no code visible in this file constructs this variant;
    /// presumably reserved for callers or future use — confirm.
    CompensationFailed {
        /// Zero-based position of the step whose compensation failed.
        step_index: usize,
        /// Error message describing the compensation failure.
        error: String,
    },
    /// A step did not finish within its allotted time.
    Timeout {
        /// Zero-based position of the step that timed out.
        step_index: usize,
        /// The time limit that was exceeded.
        duration: Duration,
    },
    /// A step index was out of range.
    ///
    /// NOTE(review): not constructed by any code visible in this file.
    InvalidStep(usize),
    /// A saga with the same id is already registered as running.
    AlreadyExecuting,
}
47
48impl fmt::Display for SagaError {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 SagaError::StepFailed {
52 step_index,
53 step_name,
54 error,
55 } => {
56 write!(f, "Step {} ({}) failed: {}", step_index, step_name, error)
57 }
58 SagaError::CompensationFailed { step_index, error } => {
59 write!(f, "Compensation for step {} failed: {}", step_index, error)
60 }
61 SagaError::Timeout {
62 step_index,
63 duration,
64 } => {
65 write!(f, "Step {} timed out after {:?}", step_index, duration)
66 }
67 SagaError::InvalidStep(index) => write!(f, "Invalid step index: {}", index),
68 SagaError::AlreadyExecuting => write!(f, "Saga is already executing"),
69 }
70 }
71}
72
// Marker impl: `SagaError` wraps no underlying error, so the default
// `std::error::Error` method bodies are used as-is.
impl std::error::Error for SagaError {}
74
/// Lifecycle state of a saga as tracked in its [`SagaMetadata`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaStatus {
    /// The saga has been defined but not handed to an orchestrator yet.
    NotStarted,
    /// The orchestrator is currently running the saga's steps.
    Executing,
    /// Every step executed successfully.
    Completed,
    /// A step failed and compensation of the earlier steps succeeded.
    Compensated,
    /// A step failed or timed out; also the final state when compensation
    /// itself returned an error.
    Failed,
}
89
/// Snapshot of a saga's progress, stored by the orchestrator for running
/// sagas and in its history.
#[derive(Debug, Clone)]
pub struct SagaMetadata {
    /// Identifier of the saga this metadata describes.
    pub id: String,
    /// Current lifecycle state.
    pub status: SagaStatus,
    /// Number of steps whose `execute` has completed successfully so far.
    pub steps_executed: usize,
    /// Number of steps registered on the saga definition.
    pub total_steps: usize,
    /// Wall-clock time at which this metadata was last refreshed.
    pub updated_at: std::time::SystemTime,
}
104
/// A single unit of work in a saga, together with the action that undoes it.
///
/// Implementations must be `Send + Sync` because steps are stored as boxed
/// trait objects inside a saga definition.
#[async_trait::async_trait]
pub trait SagaStep<E: Event>: Send + Sync {
    /// Performs the step's forward action.
    ///
    /// Returns the events produced on success, or an error message on failure.
    async fn execute(&self) -> Result<Vec<E>, String>;

    /// Undoes the effect of a previously successful [`execute`](Self::execute).
    ///
    /// The orchestrator calls this, in reverse step order, for steps that
    /// completed before a later step failed or timed out.
    async fn compensate(&self) -> Result<Vec<E>, String>;

    /// Name of the step, used in error reports.
    fn name(&self) -> &str;

    /// Maximum time the orchestrator waits for [`execute`](Self::execute)
    /// before treating the step as timed out. Defaults to 30 seconds.
    fn timeout_duration(&self) -> Duration {
        Duration::from_secs(30)
    }
}
122
/// An ordered list of steps plus the bookkeeping needed to run them as a saga.
///
/// Built with the consuming builder methods on this type and then handed to
/// [`SagaOrchestrator::execute`].
pub struct SagaDefinition<E: Event> {
    /// Identifier of the saga; duplicated in `metadata.id`.
    id: String,
    /// Steps in execution order.
    steps: Vec<Box<dyn SagaStep<E>>>,
    /// Progress/status record that the orchestrator updates while running.
    metadata: SagaMetadata,
    /// Optional compensation strategy.
    ///
    /// NOTE(review): stored by `with_compensation`, but no code visible in
    /// this file reads it — confirm where it is consumed.
    compensation_strategy: Option<super::saga::CompensationStrategy>,
    /// Optional directory for snapshots.
    ///
    /// NOTE(review): stored by `with_snapshot_dir`, but no code visible in
    /// this file reads it — confirm where it is consumed.
    snapshot_dir: Option<std::path::PathBuf>,
}
136
137impl<E: Event> SagaDefinition<E> {
138 pub fn new(id: impl Into<String>) -> Self {
140 let id = id.into();
141 Self {
142 metadata: SagaMetadata {
143 id: id.clone(),
144 status: SagaStatus::NotStarted,
145 steps_executed: 0,
146 total_steps: 0,
147 updated_at: std::time::SystemTime::now(),
148 },
149 id,
150 steps: Vec::new(),
151 compensation_strategy: None,
152 snapshot_dir: None,
153 }
154 }
155
156 pub fn add_step<S: SagaStep<E> + 'static>(mut self, step: S) -> Self {
158 self.steps.push(Box::new(step));
159 self.metadata.total_steps = self.steps.len();
160 self
161 }
162
163 pub fn id(&self) -> &str {
165 &self.id
166 }
167
168 pub fn status(&self) -> SagaStatus {
170 self.metadata.status.clone()
171 }
172
173 pub fn metadata(&self) -> &SagaMetadata {
175 &self.metadata
176 }
177
178 pub fn with_compensation(mut self, strategy: super::saga::CompensationStrategy) -> Self {
180 self.compensation_strategy = Some(strategy);
181 self
182 }
183
184 pub fn with_snapshot_dir(mut self, dir: &std::path::Path) -> Self {
186 self.snapshot_dir = Some(dir.to_path_buf());
187 self
188 }
189}
190
/// Runs saga definitions and keeps track of running and finished sagas.
///
/// Both collections live behind `Arc<RwLock<..>>`, so clones of an
/// orchestrator share the same state.
pub struct SagaOrchestrator<E: Event> {
    /// Metadata of sagas currently being executed, keyed by saga id.
    sagas: Arc<RwLock<HashMap<String, SagaMetadata>>>,
    /// Metadata of sagas that have finished, in the order they were recorded.
    history: Arc<RwLock<Vec<SagaMetadata>>>,
    /// Ties the orchestrator to its event type without storing an `E`.
    _phantom: PhantomData<E>,
}
199
200impl<E: Event> SagaOrchestrator<E> {
201 pub fn new() -> Self {
203 Self {
204 sagas: Arc::new(RwLock::new(HashMap::new())),
205 history: Arc::new(RwLock::new(Vec::new())),
206 _phantom: PhantomData,
207 }
208 }
209
210 pub async fn execute(&self, mut saga: SagaDefinition<E>) -> SagaResult<Vec<E>> {
212 {
214 let sagas = self.sagas.read().await;
215 if sagas.contains_key(&saga.id) {
216 return Err(SagaError::AlreadyExecuting);
217 }
218 }
219
220 saga.metadata.status = SagaStatus::Executing;
222 saga.metadata.updated_at = std::time::SystemTime::now();
223 {
224 let mut sagas = self.sagas.write().await;
225 sagas.insert(saga.id.clone(), saga.metadata.clone());
226 }
227
228 let mut all_events = Vec::new();
229 let mut executed_steps = 0;
230
231 for (index, step) in saga.steps.iter().enumerate() {
233 let step_timeout = step.timeout_duration();
235 let result = timeout(step_timeout, step.execute()).await;
236
237 match result {
238 Ok(Ok(events)) => {
239 all_events.extend(events);
241 executed_steps += 1;
242 saga.metadata.steps_executed = executed_steps;
243 saga.metadata.updated_at = std::time::SystemTime::now();
244 }
245 Ok(Err(error)) => {
246 saga.metadata.status = SagaStatus::Failed;
248 let compensation_result = self.compensate_steps(&saga.steps[0..index]).await;
249
250 {
252 let mut sagas = self.sagas.write().await;
253 sagas.remove(&saga.id);
254 }
255
256 {
258 let mut history = self.history.write().await;
259 saga.metadata.status = if compensation_result.is_ok() {
260 SagaStatus::Compensated
261 } else {
262 SagaStatus::Failed
263 };
264 history.push(saga.metadata.clone());
265 }
266
267 return Err(SagaError::StepFailed {
268 step_index: index,
269 step_name: step.name().to_string(),
270 error,
271 });
272 }
273 Err(_) => {
274 saga.metadata.status = SagaStatus::Failed;
276 let _ = self.compensate_steps(&saga.steps[0..index]).await;
277
278 {
279 let mut sagas = self.sagas.write().await;
280 sagas.remove(&saga.id);
281 }
282
283 return Err(SagaError::Timeout {
284 step_index: index,
285 duration: step_timeout,
286 });
287 }
288 }
289 }
290
291 saga.metadata.status = SagaStatus::Completed;
293 saga.metadata.updated_at = std::time::SystemTime::now();
294
295 {
297 let mut sagas = self.sagas.write().await;
298 sagas.remove(&saga.id);
299 }
300 {
301 let mut history = self.history.write().await;
302 history.push(saga.metadata);
303 }
304
305 Ok(all_events)
306 }
307
308 async fn compensate_steps(&self, steps: &[Box<dyn SagaStep<E>>]) -> Result<(), String> {
310 for step in steps.iter().rev() {
312 step.compensate().await?;
313 }
314 Ok(())
315 }
316
317 pub async fn get_saga(&self, id: &str) -> Option<SagaMetadata> {
319 let sagas = self.sagas.read().await;
320 sagas.get(id).cloned()
321 }
322
323 pub async fn get_running_sagas(&self) -> Vec<SagaMetadata> {
325 let sagas = self.sagas.read().await;
326 sagas.values().cloned().collect()
327 }
328
329 pub async fn get_history(&self) -> Vec<SagaMetadata> {
331 let history = self.history.read().await;
332 history.clone()
333 }
334
335 pub async fn running_count(&self) -> usize {
337 self.sagas.read().await.len()
338 }
339
340 pub async fn history_count(&self) -> usize {
342 self.history.read().await.len()
343 }
344}
345
/// `Default` is equivalent to [`SagaOrchestrator::new`]: no running sagas and
/// an empty history.
impl<E: Event> Default for SagaOrchestrator<E> {
    fn default() -> Self {
        Self::new()
    }
}
351
352impl<E: Event> Clone for SagaOrchestrator<E> {
353 fn clone(&self) -> Self {
354 Self {
355 sagas: Arc::clone(&self.sagas),
356 history: Arc::clone(&self.history),
357 _phantom: PhantomData,
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cqrs::EventTypeName;

    /// Minimal event type used to exercise the orchestrator.
    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        Debited { account: String, amount: f64 },
        Credited { account: String, amount: f64 },
    }

    impl EventTypeName for TestEvent {}
    impl Event for TestEvent {}

    /// Step that debits an account; compensation credits the amount back.
    struct DebitStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for DebitStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "DebitStep"
        }
    }

    /// Step that credits an account; compensation debits the amount back.
    struct CreditStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for CreditStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "CreditStep"
        }
    }

    /// Step whose `execute` always fails, used to drive the compensation path.
    struct FailingStep;

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for FailingStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Err("insufficient funds".to_string())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "FailingStep"
        }
    }

    #[tokio::test]
    async fn test_successful_saga() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-1")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 100.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 100.0,
            });

        let events = orchestrator.execute(saga).await.unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(orchestrator.running_count().await, 0);
        assert_eq!(orchestrator.history_count().await, 1);
    }

    #[tokio::test]
    async fn test_saga_metadata() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "A".to_string(),
            amount: 50.0,
        });

        assert_eq!(saga.id(), "transfer-2");
        assert_eq!(saga.status(), SagaStatus::NotStarted);
        assert_eq!(saga.metadata().total_steps, 1);

        orchestrator.execute(saga).await.unwrap();

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Completed);
    }

    #[tokio::test]
    async fn test_saga_definition_builder() {
        let saga = SagaDefinition::<TestEvent>::new("test-saga")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 10.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 10.0,
            });

        assert_eq!(saga.metadata().total_steps, 2);
        assert_eq!(saga.status(), SagaStatus::NotStarted);
    }

    #[tokio::test]
    async fn test_multiple_sagas() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga1 = SagaDefinition::new("transfer-1").add_step(DebitStep {
            account: "A".to_string(),
            amount: 100.0,
        });

        let saga2 = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "B".to_string(),
            amount: 200.0,
        });

        orchestrator.execute(saga1).await.unwrap();
        orchestrator.execute(saga2).await.unwrap();

        assert_eq!(orchestrator.history_count().await, 2);
    }

    /// A failing second step must surface as `StepFailed`, trigger
    /// compensation of the first step, and leave the saga in the history
    /// (not in the running set) with status `Compensated`.
    #[tokio::test]
    async fn test_failed_step_is_compensated() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-fail")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 100.0,
            })
            .add_step(FailingStep);

        let err = orchestrator.execute(saga).await.unwrap_err();

        match err {
            SagaError::StepFailed {
                step_index,
                step_name,
                error,
            } => {
                assert_eq!(step_index, 1);
                assert_eq!(step_name, "FailingStep");
                assert_eq!(error, "insufficient funds");
            }
            other => panic!("unexpected error: {}", other),
        }

        assert_eq!(orchestrator.running_count().await, 0);

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Compensated);
        assert_eq!(history[0].steps_executed, 1);
    }
}