1use std::{collections::HashMap, fmt, marker::PhantomData, sync::Arc, time::Duration};
8
9use tokio::{sync::RwLock, time::timeout};
10
11use super::Event;
12
/// Result type returned by saga operations.
pub type SagaResult<T> = Result<T, SagaError>;

/// Reasons a saga can fail to run to completion.
#[derive(Clone, Debug)]
pub enum SagaError {
    /// A step's `execute` returned an error.
    StepFailed {
        /// Zero-based position of the failing step.
        step_index: usize,
        /// Name reported by the failing step.
        step_name: String,
        /// Error message produced by the step.
        error: String,
    },
    /// Compensating the step at `step_index` failed.
    CompensationFailed {
        /// Zero-based position of the step whose compensation failed.
        step_index: usize,
        /// Error message produced by the compensation.
        error: String,
    },
    /// A step did not finish within its allotted time.
    Timeout {
        /// Zero-based position of the step that timed out.
        step_index: usize,
        /// The timeout that was exceeded.
        duration: Duration,
    },
    /// An invalid step index was supplied.
    InvalidStep(usize),
    /// A saga with the same id is already running.
    AlreadyExecuting,
}

impl fmt::Display for SagaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::StepFailed {
                step_index,
                step_name,
                error,
            } => write!(f, "Step {step_index} ({step_name}) failed: {error}"),
            Self::CompensationFailed { step_index, error } => {
                write!(f, "Compensation for step {step_index} failed: {error}")
            }
            Self::Timeout {
                step_index,
                duration,
            } => write!(f, "Step {step_index} timed out after {duration:?}"),
            Self::InvalidStep(index) => write!(f, "Invalid step index: {index}"),
            Self::AlreadyExecuting => f.write_str("Saga is already executing"),
        }
    }
}

impl std::error::Error for SagaError {}
74
/// Lifecycle state of a saga.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SagaStatus {
    /// Defined but not yet run.
    NotStarted,
    /// Steps are currently being run.
    Executing,
    /// Every step finished successfully.
    Completed,
    /// A step failed and the previously completed steps were compensated.
    Compensated,
    /// The saga ended without completing and without being fully compensated.
    Failed,
}
89
90#[derive(Debug, Clone)]
92pub struct SagaMetadata {
93 pub id: String,
95 pub status: SagaStatus,
97 pub steps_executed: usize,
99 pub total_steps: usize,
101 pub updated_at: std::time::SystemTime,
103}
104
105#[async_trait::async_trait]
107pub trait SagaStep<E: Event>: Send + Sync {
108 async fn execute(&self) -> Result<Vec<E>, String>;
110
111 async fn compensate(&self) -> Result<Vec<E>, String>;
113
114 fn name(&self) -> &str;
116
117 fn timeout_duration(&self) -> Duration {
119 Duration::from_secs(30) }
121}
122
123pub struct SagaDefinition<E: Event> {
125 id: String,
127 steps: Vec<Box<dyn SagaStep<E>>>,
129 metadata: SagaMetadata,
131}
132
133impl<E: Event> SagaDefinition<E> {
134 pub fn new(id: impl Into<String>) -> Self {
136 let id = id.into();
137 Self {
138 metadata: SagaMetadata {
139 id: id.clone(),
140 status: SagaStatus::NotStarted,
141 steps_executed: 0,
142 total_steps: 0,
143 updated_at: std::time::SystemTime::now(),
144 },
145 id,
146 steps: Vec::new(),
147 }
148 }
149
150 pub fn add_step<S: SagaStep<E> + 'static>(mut self, step: S) -> Self {
152 self.steps.push(Box::new(step));
153 self.metadata.total_steps = self.steps.len();
154 self
155 }
156
157 pub fn id(&self) -> &str {
159 &self.id
160 }
161
162 pub fn status(&self) -> SagaStatus {
164 self.metadata.status.clone()
165 }
166
167 pub fn metadata(&self) -> &SagaMetadata {
169 &self.metadata
170 }
171}
172
173pub struct SagaOrchestrator<E: Event> {
175 sagas: Arc<RwLock<HashMap<String, SagaMetadata>>>,
177 history: Arc<RwLock<Vec<SagaMetadata>>>,
179 _phantom: PhantomData<E>,
180}
181
182impl<E: Event> SagaOrchestrator<E> {
183 pub fn new() -> Self {
185 Self {
186 sagas: Arc::new(RwLock::new(HashMap::new())),
187 history: Arc::new(RwLock::new(Vec::new())),
188 _phantom: PhantomData,
189 }
190 }
191
192 pub async fn execute(&self, mut saga: SagaDefinition<E>) -> SagaResult<Vec<E>> {
194 {
196 let sagas = self.sagas.read().await;
197 if sagas.contains_key(&saga.id) {
198 return Err(SagaError::AlreadyExecuting);
199 }
200 }
201
202 saga.metadata.status = SagaStatus::Executing;
204 saga.metadata.updated_at = std::time::SystemTime::now();
205 {
206 let mut sagas = self.sagas.write().await;
207 sagas.insert(saga.id.clone(), saga.metadata.clone());
208 }
209
210 let mut all_events = Vec::new();
211 let mut executed_steps = 0;
212
213 for (index, step) in saga.steps.iter().enumerate() {
215 let step_timeout = step.timeout_duration();
217 let result = timeout(step_timeout, step.execute()).await;
218
219 match result {
220 Ok(Ok(events)) => {
221 all_events.extend(events);
223 executed_steps += 1;
224 saga.metadata.steps_executed = executed_steps;
225 saga.metadata.updated_at = std::time::SystemTime::now();
226 }
227 Ok(Err(error)) => {
228 saga.metadata.status = SagaStatus::Failed;
230 let compensation_result = self.compensate_steps(&saga.steps[0..index]).await;
231
232 {
234 let mut sagas = self.sagas.write().await;
235 sagas.remove(&saga.id);
236 }
237
238 {
240 let mut history = self.history.write().await;
241 saga.metadata.status = if compensation_result.is_ok() {
242 SagaStatus::Compensated
243 } else {
244 SagaStatus::Failed
245 };
246 history.push(saga.metadata.clone());
247 }
248
249 return Err(SagaError::StepFailed {
250 step_index: index,
251 step_name: step.name().to_string(),
252 error,
253 });
254 }
255 Err(_) => {
256 saga.metadata.status = SagaStatus::Failed;
258 let _ = self.compensate_steps(&saga.steps[0..index]).await;
259
260 {
261 let mut sagas = self.sagas.write().await;
262 sagas.remove(&saga.id);
263 }
264
265 return Err(SagaError::Timeout {
266 step_index: index,
267 duration: step_timeout,
268 });
269 }
270 }
271 }
272
273 saga.metadata.status = SagaStatus::Completed;
275 saga.metadata.updated_at = std::time::SystemTime::now();
276
277 {
279 let mut sagas = self.sagas.write().await;
280 sagas.remove(&saga.id);
281 }
282 {
283 let mut history = self.history.write().await;
284 history.push(saga.metadata);
285 }
286
287 Ok(all_events)
288 }
289
290 async fn compensate_steps(&self, steps: &[Box<dyn SagaStep<E>>]) -> Result<(), String> {
292 for step in steps.iter().rev() {
294 step.compensate().await?;
295 }
296 Ok(())
297 }
298
299 pub async fn get_saga(&self, id: &str) -> Option<SagaMetadata> {
301 let sagas = self.sagas.read().await;
302 sagas.get(id).cloned()
303 }
304
305 pub async fn get_running_sagas(&self) -> Vec<SagaMetadata> {
307 let sagas = self.sagas.read().await;
308 sagas.values().cloned().collect()
309 }
310
311 pub async fn get_history(&self) -> Vec<SagaMetadata> {
313 let history = self.history.read().await;
314 history.clone()
315 }
316
317 pub async fn running_count(&self) -> usize {
319 self.sagas.read().await.len()
320 }
321
322 pub async fn history_count(&self) -> usize {
324 self.history.read().await.len()
325 }
326}
327
328impl<E: Event> Default for SagaOrchestrator<E> {
329 fn default() -> Self {
330 Self::new()
331 }
332}
333
334impl<E: Event> Clone for SagaOrchestrator<E> {
335 fn clone(&self) -> Self {
336 Self {
337 sagas: Arc::clone(&self.sagas),
338 history: Arc::clone(&self.history),
339 _phantom: PhantomData,
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        Debited { account: String, amount: f64 },
        Credited { account: String, amount: f64 },
    }

    impl Event for TestEvent {}

    struct DebitStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for DebitStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "DebitStep"
        }
    }

    struct CreditStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for CreditStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "CreditStep"
        }
    }

    /// Step whose `execute` always fails.
    struct FailingStep;

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for FailingStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Err("insufficient funds".to_string())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "FailingStep"
        }
    }

    /// Step that succeeds and counts how many times it is compensated.
    struct CountingStep {
        compensations: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for CountingStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            self.compensations.fetch_add(1, Ordering::SeqCst);
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "CountingStep"
        }
    }

    /// Step whose `execute` outlasts its own very short timeout.
    struct SlowStep;

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for SlowStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok(Vec::new())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "SlowStep"
        }

        fn timeout_duration(&self) -> Duration {
            Duration::from_millis(10)
        }
    }

    #[tokio::test]
    async fn test_successful_saga() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-1")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 100.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 100.0,
            });

        let events = orchestrator.execute(saga).await.unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(orchestrator.running_count().await, 0);
        assert_eq!(orchestrator.history_count().await, 1);
    }

    #[tokio::test]
    async fn test_saga_metadata() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "A".to_string(),
            amount: 50.0,
        });

        assert_eq!(saga.id(), "transfer-2");
        assert_eq!(saga.status(), SagaStatus::NotStarted);
        assert_eq!(saga.metadata().total_steps, 1);

        orchestrator.execute(saga).await.unwrap();

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Completed);
    }

    #[tokio::test]
    async fn test_saga_definition_builder() {
        let saga = SagaDefinition::<TestEvent>::new("test-saga")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 10.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 10.0,
            });

        assert_eq!(saga.metadata().total_steps, 2);
        assert_eq!(saga.status(), SagaStatus::NotStarted);
    }

    #[tokio::test]
    async fn test_multiple_sagas() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga1 = SagaDefinition::new("transfer-1").add_step(DebitStep {
            account: "A".to_string(),
            amount: 100.0,
        });

        let saga2 = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "B".to_string(),
            amount: 200.0,
        });

        orchestrator.execute(saga1).await.unwrap();
        orchestrator.execute(saga2).await.unwrap();

        assert_eq!(orchestrator.history_count().await, 2);
    }

    #[tokio::test]
    async fn test_failed_step_compensates_completed_steps() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let compensations = Arc::new(AtomicUsize::new(0));

        let saga = SagaDefinition::new("transfer-fail")
            .add_step(CountingStep {
                compensations: Arc::clone(&compensations),
            })
            .add_step(FailingStep);

        match orchestrator.execute(saga).await {
            Err(SagaError::StepFailed {
                step_index,
                step_name,
                error,
            }) => {
                assert_eq!(step_index, 1);
                assert_eq!(step_name, "FailingStep");
                assert_eq!(error, "insufficient funds");
            }
            other => panic!("expected StepFailed, got {other:?}"),
        }

        // The one step that completed before the failure is compensated exactly once.
        assert_eq!(compensations.load(Ordering::SeqCst), 1);
        assert_eq!(orchestrator.running_count().await, 0);

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Compensated);
        assert_eq!(history[0].steps_executed, 1);
    }

    #[tokio::test]
    async fn test_timed_out_step_compensates_completed_steps() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let compensations = Arc::new(AtomicUsize::new(0));

        let saga = SagaDefinition::new("transfer-slow")
            .add_step(CountingStep {
                compensations: Arc::clone(&compensations),
            })
            .add_step(SlowStep);

        match orchestrator.execute(saga).await {
            Err(SagaError::Timeout {
                step_index,
                duration,
            }) => {
                assert_eq!(step_index, 1);
                assert_eq!(duration, Duration::from_millis(10));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }

        assert_eq!(compensations.load(Ordering::SeqCst), 1);
        assert_eq!(orchestrator.running_count().await, 0);
    }

    #[tokio::test]
    async fn test_saga_id_reusable_after_completion() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        for _ in 0..2 {
            let saga = SagaDefinition::new("transfer-repeat").add_step(DebitStep {
                account: "A".to_string(),
                amount: 1.0,
            });
            orchestrator.execute(saga).await.unwrap();
        }

        assert!(orchestrator.get_saga("transfer-repeat").await.is_none());
        assert_eq!(orchestrator.history_count().await, 2);
    }
}