// allframe_core/cqrs/saga_orchestrator.rs

//! Saga Orchestrator for distributed transaction coordination
//!
//! This module provides automatic saga orchestration, eliminating boilerplate
//! for multi-aggregate transactions with automatic compensation and per-step
//! timeouts.
6
7use std::{collections::HashMap, fmt, marker::PhantomData, sync::Arc, time::Duration};
8
9use tokio::{sync::RwLock, time::timeout};
10
11use super::Event;
12
/// Convenience alias for results produced by saga operations.
pub type SagaResult<T> = Result<T, SagaError>;

/// Failure modes of saga execution.
#[derive(Clone, Debug)]
pub enum SagaError {
    /// A step's `execute` returned an error
    StepFailed {
        /// Zero-based position of the failing step
        step_index: usize,
        /// Name reported by the failing step
        step_name: String,
        /// Error message returned by the step
        error: String,
    },
    /// Compensating a step did not succeed
    CompensationFailed {
        /// Zero-based position of the step being compensated
        step_index: usize,
        /// Error message describing the failure
        error: String,
    },
    /// A step did not finish within its timeout
    Timeout {
        /// Zero-based position of the step that timed out
        step_index: usize,
        /// The timeout that was exceeded
        duration: Duration,
    },
    /// An invalid step index was supplied
    InvalidStep(usize),
    /// A saga with the same ID is already running
    AlreadyExecuting,
}

impl fmt::Display for SagaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::StepFailed {
                step_index,
                step_name,
                error,
            } => write!(f, "Step {step_index} ({step_name}) failed: {error}"),
            Self::CompensationFailed { step_index, error } => {
                write!(f, "Compensation for step {step_index} failed: {error}")
            }
            Self::Timeout {
                step_index,
                duration,
            } => write!(f, "Step {step_index} timed out after {duration:?}"),
            Self::InvalidStep(index) => write!(f, "Invalid step index: {index}"),
            Self::AlreadyExecuting => f.write_str("Saga is already executing"),
        }
    }
}

impl std::error::Error for SagaError {}
74
/// Lifecycle state of a saga run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SagaStatus {
    /// The saga has not been started yet
    NotStarted,
    /// The saga's steps are currently running
    Executing,
    /// Every step completed successfully
    Completed,
    /// A step failed and compensation succeeded
    Compensated,
    /// A step failed and compensation failed as well
    Failed,
}

/// Snapshot of a saga's identity, status, and progress.
#[derive(Clone, Debug)]
pub struct SagaMetadata {
    /// Unique saga identifier
    pub id: String,
    /// Current lifecycle state
    pub status: SagaStatus,
    /// How many steps have completed so far
    pub steps_executed: usize,
    /// How many steps the saga contains in total
    pub total_steps: usize,
    /// Wall-clock time of the most recent status/progress change
    pub updated_at: std::time::SystemTime,
}
104
105/// A single step in a saga
106#[async_trait::async_trait]
107pub trait SagaStep<E: Event>: Send + Sync {
108    /// Execute the step
109    async fn execute(&self) -> Result<Vec<E>, String>;
110
111    /// Compensate for this step (rollback)
112    async fn compensate(&self) -> Result<Vec<E>, String>;
113
114    /// Get the step name for logging/debugging
115    fn name(&self) -> &str;
116
117    /// Get the timeout for this step
118    fn timeout_duration(&self) -> Duration {
119        Duration::from_secs(30) // Default 30 seconds
120    }
121}
122
123/// Saga definition with ordered steps
124pub struct SagaDefinition<E: Event> {
125    /// Unique saga ID
126    id: String,
127    /// Ordered list of steps
128    steps: Vec<Box<dyn SagaStep<E>>>,
129    /// Metadata
130    metadata: SagaMetadata,
131}
132
133impl<E: Event> SagaDefinition<E> {
134    /// Create a new saga definition
135    pub fn new(id: impl Into<String>) -> Self {
136        let id = id.into();
137        Self {
138            metadata: SagaMetadata {
139                id: id.clone(),
140                status: SagaStatus::NotStarted,
141                steps_executed: 0,
142                total_steps: 0,
143                updated_at: std::time::SystemTime::now(),
144            },
145            id,
146            steps: Vec::new(),
147        }
148    }
149
150    /// Add a step to the saga
151    pub fn add_step<S: SagaStep<E> + 'static>(mut self, step: S) -> Self {
152        self.steps.push(Box::new(step));
153        self.metadata.total_steps = self.steps.len();
154        self
155    }
156
157    /// Get saga ID
158    pub fn id(&self) -> &str {
159        &self.id
160    }
161
162    /// Get current status
163    pub fn status(&self) -> SagaStatus {
164        self.metadata.status.clone()
165    }
166
167    /// Get metadata
168    pub fn metadata(&self) -> &SagaMetadata {
169        &self.metadata
170    }
171}
172
173/// Orchestrator for executing sagas
174pub struct SagaOrchestrator<E: Event> {
175    /// Running sagas
176    sagas: Arc<RwLock<HashMap<String, SagaMetadata>>>,
177    /// Completed sagas history
178    history: Arc<RwLock<Vec<SagaMetadata>>>,
179    _phantom: PhantomData<E>,
180}
181
182impl<E: Event> SagaOrchestrator<E> {
183    /// Create a new saga orchestrator
184    pub fn new() -> Self {
185        Self {
186            sagas: Arc::new(RwLock::new(HashMap::new())),
187            history: Arc::new(RwLock::new(Vec::new())),
188            _phantom: PhantomData,
189        }
190    }
191
192    /// Execute a saga with automatic compensation on failure
193    pub async fn execute(&self, mut saga: SagaDefinition<E>) -> SagaResult<Vec<E>> {
194        // Check if saga is already running
195        {
196            let sagas = self.sagas.read().await;
197            if sagas.contains_key(&saga.id) {
198                return Err(SagaError::AlreadyExecuting);
199            }
200        }
201
202        // Mark as executing
203        saga.metadata.status = SagaStatus::Executing;
204        saga.metadata.updated_at = std::time::SystemTime::now();
205        {
206            let mut sagas = self.sagas.write().await;
207            sagas.insert(saga.id.clone(), saga.metadata.clone());
208        }
209
210        let mut all_events = Vec::new();
211        let mut executed_steps = 0;
212
213        // Execute each step
214        for (index, step) in saga.steps.iter().enumerate() {
215            // Execute step with timeout
216            let step_timeout = step.timeout_duration();
217            let result = timeout(step_timeout, step.execute()).await;
218
219            match result {
220                Ok(Ok(events)) => {
221                    // Step succeeded
222                    all_events.extend(events);
223                    executed_steps += 1;
224                    saga.metadata.steps_executed = executed_steps;
225                    saga.metadata.updated_at = std::time::SystemTime::now();
226                }
227                Ok(Err(error)) => {
228                    // Step failed - compensate previous steps
229                    saga.metadata.status = SagaStatus::Failed;
230                    let compensation_result = self.compensate_steps(&saga.steps[0..index]).await;
231
232                    // Remove from active sagas
233                    {
234                        let mut sagas = self.sagas.write().await;
235                        sagas.remove(&saga.id);
236                    }
237
238                    // Add to history
239                    {
240                        let mut history = self.history.write().await;
241                        saga.metadata.status = if compensation_result.is_ok() {
242                            SagaStatus::Compensated
243                        } else {
244                            SagaStatus::Failed
245                        };
246                        history.push(saga.metadata.clone());
247                    }
248
249                    return Err(SagaError::StepFailed {
250                        step_index: index,
251                        step_name: step.name().to_string(),
252                        error,
253                    });
254                }
255                Err(_) => {
256                    // Timeout
257                    saga.metadata.status = SagaStatus::Failed;
258                    let _ = self.compensate_steps(&saga.steps[0..index]).await;
259
260                    {
261                        let mut sagas = self.sagas.write().await;
262                        sagas.remove(&saga.id);
263                    }
264
265                    return Err(SagaError::Timeout {
266                        step_index: index,
267                        duration: step_timeout,
268                    });
269                }
270            }
271        }
272
273        // All steps completed successfully
274        saga.metadata.status = SagaStatus::Completed;
275        saga.metadata.updated_at = std::time::SystemTime::now();
276
277        // Remove from active and add to history
278        {
279            let mut sagas = self.sagas.write().await;
280            sagas.remove(&saga.id);
281        }
282        {
283            let mut history = self.history.write().await;
284            history.push(saga.metadata);
285        }
286
287        Ok(all_events)
288    }
289
290    /// Compensate (rollback) executed steps in reverse order
291    async fn compensate_steps(&self, steps: &[Box<dyn SagaStep<E>>]) -> Result<(), String> {
292        // Compensate in reverse order
293        for step in steps.iter().rev() {
294            step.compensate().await?;
295        }
296        Ok(())
297    }
298
299    /// Get metadata for a running saga
300    pub async fn get_saga(&self, id: &str) -> Option<SagaMetadata> {
301        let sagas = self.sagas.read().await;
302        sagas.get(id).cloned()
303    }
304
305    /// Get all running sagas
306    pub async fn get_running_sagas(&self) -> Vec<SagaMetadata> {
307        let sagas = self.sagas.read().await;
308        sagas.values().cloned().collect()
309    }
310
311    /// Get saga history
312    pub async fn get_history(&self) -> Vec<SagaMetadata> {
313        let history = self.history.read().await;
314        history.clone()
315    }
316
317    /// Get number of running sagas
318    pub async fn running_count(&self) -> usize {
319        self.sagas.read().await.len()
320    }
321
322    /// Get number of completed sagas (including failed)
323    pub async fn history_count(&self) -> usize {
324        self.history.read().await.len()
325    }
326}
327
328impl<E: Event> Default for SagaOrchestrator<E> {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334impl<E: Event> Clone for SagaOrchestrator<E> {
335    fn clone(&self) -> Self {
336        Self {
337            sagas: Arc::clone(&self.sagas),
338            history: Arc::clone(&self.history),
339            _phantom: PhantomData,
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use super::*;

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        Debited { account: String, amount: f64 },
        Credited { account: String, amount: f64 },
    }

    impl Event for TestEvent {}

    struct DebitStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for DebitStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            // Compensate by crediting back
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "DebitStep"
        }
    }

    struct CreditStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for CreditStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            // Compensate by debiting back
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "CreditStep"
        }
    }

    /// Step that always succeeds, counts its compensations, and can be
    /// configured to fail when compensated.
    struct TrackedStep {
        compensations: Arc<AtomicUsize>,
        fail_compensation: bool,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for TrackedStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            self.compensations.fetch_add(1, Ordering::SeqCst);
            if self.fail_compensation {
                Err("compensation failed".to_string())
            } else {
                Ok(Vec::new())
            }
        }

        fn name(&self) -> &str {
            "TrackedStep"
        }
    }

    /// Step whose execution always fails.
    struct FailingStep;

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for FailingStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Err("insufficient funds".to_string())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "FailingStep"
        }
    }

    /// Step that runs far longer than its (very short) timeout.
    struct SlowStep;

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for SlowStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok(Vec::new())
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "SlowStep"
        }

        fn timeout_duration(&self) -> Duration {
            Duration::from_millis(10)
        }
    }

    #[tokio::test]
    async fn test_successful_saga() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-1")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 100.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 100.0,
            });

        let events = orchestrator.execute(saga).await.unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(orchestrator.running_count().await, 0);
        assert_eq!(orchestrator.history_count().await, 1);
    }

    #[tokio::test]
    async fn test_saga_metadata() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "A".to_string(),
            amount: 50.0,
        });

        assert_eq!(saga.id(), "transfer-2");
        assert_eq!(saga.status(), SagaStatus::NotStarted);
        assert_eq!(saga.metadata().total_steps, 1);

        orchestrator.execute(saga).await.unwrap();

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Completed);
    }

    #[tokio::test]
    async fn test_saga_definition_builder() {
        let saga = SagaDefinition::<TestEvent>::new("test-saga")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 10.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 10.0,
            });

        assert_eq!(saga.metadata().total_steps, 2);
        assert_eq!(saga.status(), SagaStatus::NotStarted);
    }

    #[tokio::test]
    async fn test_multiple_sagas() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga1 = SagaDefinition::new("transfer-1").add_step(DebitStep {
            account: "A".to_string(),
            amount: 100.0,
        });

        let saga2 = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "B".to_string(),
            amount: 200.0,
        });

        orchestrator.execute(saga1).await.unwrap();
        orchestrator.execute(saga2).await.unwrap();

        assert_eq!(orchestrator.history_count().await, 2);
    }

    #[tokio::test]
    async fn test_failed_step_compensates_previous_steps() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let compensations = Arc::new(AtomicUsize::new(0));

        let saga = SagaDefinition::new("transfer-fail")
            .add_step(TrackedStep {
                compensations: Arc::clone(&compensations),
                fail_compensation: false,
            })
            .add_step(FailingStep);

        match orchestrator.execute(saga).await {
            Err(SagaError::StepFailed {
                step_index,
                step_name,
                error,
            }) => {
                assert_eq!(step_index, 1);
                assert_eq!(step_name, "FailingStep");
                assert_eq!(error, "insufficient funds");
            }
            other => panic!("unexpected result: {:?}", other),
        }

        assert_eq!(compensations.load(Ordering::SeqCst), 1);
        assert_eq!(orchestrator.running_count().await, 0);
        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Compensated);
        assert_eq!(history[0].steps_executed, 1);
    }

    #[tokio::test]
    async fn test_failed_compensation_marks_saga_failed() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let compensations = Arc::new(AtomicUsize::new(0));

        let saga = SagaDefinition::new("transfer-comp-fail")
            .add_step(TrackedStep {
                compensations: Arc::clone(&compensations),
                fail_compensation: true,
            })
            .add_step(FailingStep);

        let result = orchestrator.execute(saga).await;
        assert!(matches!(result, Err(SagaError::StepFailed { step_index: 1, .. })));
        assert_eq!(compensations.load(Ordering::SeqCst), 1);

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Failed);
    }

    #[tokio::test]
    async fn test_step_timeout_compensates_and_clears_running() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let compensations = Arc::new(AtomicUsize::new(0));

        let saga = SagaDefinition::new("transfer-timeout")
            .add_step(TrackedStep {
                compensations: Arc::clone(&compensations),
                fail_compensation: false,
            })
            .add_step(SlowStep);

        let result = orchestrator.execute(saga).await;
        assert!(matches!(result, Err(SagaError::Timeout { step_index: 1, .. })));
        assert_eq!(compensations.load(Ordering::SeqCst), 1);
        assert_eq!(orchestrator.running_count().await, 0);
    }

    #[tokio::test]
    async fn test_duplicate_saga_id_rejected_while_running() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        // The first saga parks inside its slow step, so it is still registered
        // as running when the second saga with the same ID is started.
        let slow = SagaDefinition::new("dup").add_step(SlowStep);
        let duplicate = SagaDefinition::new("dup").add_step(DebitStep {
            account: "A".to_string(),
            amount: 1.0,
        });

        let (first, second) =
            tokio::join!(orchestrator.execute(slow), orchestrator.execute(duplicate));

        assert!(matches!(first, Err(SagaError::Timeout { .. })));
        assert!(matches!(second, Err(SagaError::AlreadyExecuting)));
        assert_eq!(orchestrator.running_count().await, 0);
    }
}