Skip to main content

allframe_core/cqrs/
saga_orchestrator.rs

//! Saga orchestrator for distributed transaction coordination.
//!
//! This module provides automatic saga orchestration, eliminating boilerplate
//! for multi-aggregate transactions: steps run in order with per-step timeouts,
//! and when a step fails or times out, the steps that already completed are
//! automatically compensated in reverse order.
6
7use std::{collections::HashMap, fmt, marker::PhantomData, sync::Arc, time::Duration};
8
9use tokio::{sync::RwLock, time::timeout};
10
11use super::Event;
12
/// Result type for saga operations.
pub type SagaResult<T> = Result<T, SagaError>;

/// Errors that can occur during saga execution.
///
/// Derives `PartialEq`/`Eq` (like [`SagaStatus`]) so callers and tests can
/// compare errors and [`SagaResult`] values directly.
///
/// NOTE: `SagaOrchestrator::execute` in this module does not currently
/// construct `CompensationFailed` or `InvalidStep`; a failed compensation is
/// reported through the saga's metadata status instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaError {
    /// A step's `execute` returned an error.
    StepFailed {
        /// Zero-based index of the failed step.
        step_index: usize,
        /// Name of the failed step.
        step_name: String,
        /// Error message returned by the step.
        error: String,
    },
    /// Compensation of a step failed.
    CompensationFailed {
        /// Zero-based index of the step being compensated.
        step_index: usize,
        /// Error message.
        error: String,
    },
    /// A step did not finish within its timeout.
    Timeout {
        /// Zero-based index of the timed-out step.
        step_index: usize,
        /// Timeout that was exceeded.
        duration: Duration,
    },
    /// Invalid step index.
    InvalidStep(usize),
    /// A saga with the same ID is already executing.
    AlreadyExecuting,
}

impl fmt::Display for SagaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SagaError::StepFailed {
                step_index,
                step_name,
                error,
            } => {
                write!(f, "Step {} ({}) failed: {}", step_index, step_name, error)
            }
            SagaError::CompensationFailed { step_index, error } => {
                write!(f, "Compensation for step {} failed: {}", step_index, error)
            }
            SagaError::Timeout {
                step_index,
                duration,
            } => {
                write!(f, "Step {} timed out after {:?}", step_index, duration)
            }
            SagaError::InvalidStep(index) => write!(f, "Invalid step index: {}", index),
            SagaError::AlreadyExecuting => write!(f, "Saga is already executing"),
        }
    }
}

impl std::error::Error for SagaError {}
74
/// Lifecycle state of a saga.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SagaStatus {
    /// Created but not yet handed to an orchestrator.
    NotStarted,
    /// Currently running its steps.
    Executing,
    /// Every step finished successfully.
    Completed,
    /// A step failed and every compensation succeeded.
    Compensated,
    /// A step failed and at least one compensation also failed.
    Failed,
}

/// Snapshot of a saga's identity and progress.
#[derive(Clone, Debug)]
pub struct SagaMetadata {
    /// Saga identifier.
    pub id: String,
    /// Lifecycle state at the time of the snapshot.
    pub status: SagaStatus,
    /// How many steps have completed successfully.
    pub steps_executed: usize,
    /// How many steps the saga definition contains.
    pub total_steps: usize,
    /// Wall-clock time this snapshot was last modified.
    pub updated_at: std::time::SystemTime,
}
104
105/// A single step in a saga
106#[async_trait::async_trait]
107pub trait SagaStep<E: Event>: Send + Sync {
108    /// Execute the step
109    async fn execute(&self) -> Result<Vec<E>, String>;
110
111    /// Compensate for this step (rollback)
112    async fn compensate(&self) -> Result<Vec<E>, String>;
113
114    /// Get the step name for logging/debugging
115    fn name(&self) -> &str;
116
117    /// Get the timeout for this step
118    fn timeout_duration(&self) -> Duration {
119        Duration::from_secs(30) // Default 30 seconds
120    }
121}
122
123/// Saga definition with ordered steps
124pub struct SagaDefinition<E: Event> {
125    /// Unique saga ID
126    id: String,
127    /// Ordered list of steps
128    steps: Vec<Box<dyn SagaStep<E>>>,
129    /// Metadata
130    metadata: SagaMetadata,
131    /// Compensation strategy (optional, for UC-036.7)
132    compensation_strategy: Option<super::saga::CompensationStrategy>,
133    /// Directory for snapshots (optional, for UC-036.7)
134    snapshot_dir: Option<std::path::PathBuf>,
135}
136
137impl<E: Event> SagaDefinition<E> {
138    /// Create a new saga definition
139    pub fn new(id: impl Into<String>) -> Self {
140        let id = id.into();
141        Self {
142            metadata: SagaMetadata {
143                id: id.clone(),
144                status: SagaStatus::NotStarted,
145                steps_executed: 0,
146                total_steps: 0,
147                updated_at: std::time::SystemTime::now(),
148            },
149            id,
150            steps: Vec::new(),
151            compensation_strategy: None,
152            snapshot_dir: None,
153        }
154    }
155
156    /// Add a step to the saga
157    pub fn add_step<S: SagaStep<E> + 'static>(mut self, step: S) -> Self {
158        self.steps.push(Box::new(step));
159        self.metadata.total_steps = self.steps.len();
160        self
161    }
162
163    /// Get saga ID
164    pub fn id(&self) -> &str {
165        &self.id
166    }
167
168    /// Get current status
169    pub fn status(&self) -> SagaStatus {
170        self.metadata.status.clone()
171    }
172
173    /// Get metadata
174    pub fn metadata(&self) -> &SagaMetadata {
175        &self.metadata
176    }
177
178    /// Set the compensation strategy.
179    pub fn with_compensation(mut self, strategy: super::saga::CompensationStrategy) -> Self {
180        self.compensation_strategy = Some(strategy);
181        self
182    }
183
184    /// Set the snapshot directory for file-based compensation.
185    pub fn with_snapshot_dir(mut self, dir: &std::path::Path) -> Self {
186        self.snapshot_dir = Some(dir.to_path_buf());
187        self
188    }
189}
190
191/// Orchestrator for executing sagas
192pub struct SagaOrchestrator<E: Event> {
193    /// Running sagas
194    sagas: Arc<RwLock<HashMap<String, SagaMetadata>>>,
195    /// Completed sagas history
196    history: Arc<RwLock<Vec<SagaMetadata>>>,
197    _phantom: PhantomData<E>,
198}
199
200impl<E: Event> SagaOrchestrator<E> {
201    /// Create a new saga orchestrator
202    pub fn new() -> Self {
203        Self {
204            sagas: Arc::new(RwLock::new(HashMap::new())),
205            history: Arc::new(RwLock::new(Vec::new())),
206            _phantom: PhantomData,
207        }
208    }
209
210    /// Execute a saga with automatic compensation on failure
211    pub async fn execute(&self, mut saga: SagaDefinition<E>) -> SagaResult<Vec<E>> {
212        // Check if saga is already running
213        {
214            let sagas = self.sagas.read().await;
215            if sagas.contains_key(&saga.id) {
216                return Err(SagaError::AlreadyExecuting);
217            }
218        }
219
220        // Mark as executing
221        saga.metadata.status = SagaStatus::Executing;
222        saga.metadata.updated_at = std::time::SystemTime::now();
223        {
224            let mut sagas = self.sagas.write().await;
225            sagas.insert(saga.id.clone(), saga.metadata.clone());
226        }
227
228        let mut all_events = Vec::new();
229        let mut executed_steps = 0;
230
231        // Execute each step
232        for (index, step) in saga.steps.iter().enumerate() {
233            // Execute step with timeout
234            let step_timeout = step.timeout_duration();
235            let result = timeout(step_timeout, step.execute()).await;
236
237            match result {
238                Ok(Ok(events)) => {
239                    // Step succeeded
240                    all_events.extend(events);
241                    executed_steps += 1;
242                    saga.metadata.steps_executed = executed_steps;
243                    saga.metadata.updated_at = std::time::SystemTime::now();
244                }
245                Ok(Err(error)) => {
246                    // Step failed - compensate previous steps
247                    saga.metadata.status = SagaStatus::Failed;
248                    let compensation_result = self.compensate_steps(&saga.steps[0..index]).await;
249
250                    // Remove from active sagas
251                    {
252                        let mut sagas = self.sagas.write().await;
253                        sagas.remove(&saga.id);
254                    }
255
256                    // Add to history
257                    {
258                        let mut history = self.history.write().await;
259                        saga.metadata.status = if compensation_result.is_ok() {
260                            SagaStatus::Compensated
261                        } else {
262                            SagaStatus::Failed
263                        };
264                        history.push(saga.metadata.clone());
265                    }
266
267                    return Err(SagaError::StepFailed {
268                        step_index: index,
269                        step_name: step.name().to_string(),
270                        error,
271                    });
272                }
273                Err(_) => {
274                    // Timeout
275                    saga.metadata.status = SagaStatus::Failed;
276                    let _ = self.compensate_steps(&saga.steps[0..index]).await;
277
278                    {
279                        let mut sagas = self.sagas.write().await;
280                        sagas.remove(&saga.id);
281                    }
282
283                    return Err(SagaError::Timeout {
284                        step_index: index,
285                        duration: step_timeout,
286                    });
287                }
288            }
289        }
290
291        // All steps completed successfully
292        saga.metadata.status = SagaStatus::Completed;
293        saga.metadata.updated_at = std::time::SystemTime::now();
294
295        // Remove from active and add to history
296        {
297            let mut sagas = self.sagas.write().await;
298            sagas.remove(&saga.id);
299        }
300        {
301            let mut history = self.history.write().await;
302            history.push(saga.metadata);
303        }
304
305        Ok(all_events)
306    }
307
308    /// Compensate (rollback) executed steps in reverse order
309    async fn compensate_steps(&self, steps: &[Box<dyn SagaStep<E>>]) -> Result<(), String> {
310        // Compensate in reverse order
311        for step in steps.iter().rev() {
312            step.compensate().await?;
313        }
314        Ok(())
315    }
316
317    /// Get metadata for a running saga
318    pub async fn get_saga(&self, id: &str) -> Option<SagaMetadata> {
319        let sagas = self.sagas.read().await;
320        sagas.get(id).cloned()
321    }
322
323    /// Get all running sagas
324    pub async fn get_running_sagas(&self) -> Vec<SagaMetadata> {
325        let sagas = self.sagas.read().await;
326        sagas.values().cloned().collect()
327    }
328
329    /// Get saga history
330    pub async fn get_history(&self) -> Vec<SagaMetadata> {
331        let history = self.history.read().await;
332        history.clone()
333    }
334
335    /// Get number of running sagas
336    pub async fn running_count(&self) -> usize {
337        self.sagas.read().await.len()
338    }
339
340    /// Get number of completed sagas (including failed)
341    pub async fn history_count(&self) -> usize {
342        self.history.read().await.len()
343    }
344}
345
346impl<E: Event> Default for SagaOrchestrator<E> {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
352impl<E: Event> Clone for SagaOrchestrator<E> {
353    fn clone(&self) -> Self {
354        Self {
355            sagas: Arc::clone(&self.sagas),
356            history: Arc::clone(&self.history),
357            _phantom: PhantomData,
358        }
359    }
360}
361
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::cqrs::EventTypeName;

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        Debited { account: String, amount: f64 },
        Credited { account: String, amount: f64 },
    }

    impl EventTypeName for TestEvent {}
    impl Event for TestEvent {}

    struct DebitStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for DebitStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            // Compensate by crediting back
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "DebitStep"
        }
    }

    struct CreditStep {
        account: String,
        amount: f64,
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for CreditStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            Ok(vec![TestEvent::Credited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            // Compensate by debiting back
            Ok(vec![TestEvent::Debited {
                account: self.account.clone(),
                amount: self.amount,
            }])
        }

        fn name(&self) -> &str {
            "CreditStep"
        }
    }

    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Step that records every call as "action:name" and can be told to fail.
    struct ScriptedStep {
        name: &'static str,
        log: CallLog,
        fail_execute: bool,
        fail_compensate: bool,
    }

    impl ScriptedStep {
        fn new(name: &'static str, log: &CallLog) -> Self {
            Self {
                name,
                log: Arc::clone(log),
                fail_execute: false,
                fail_compensate: false,
            }
        }

        fn record(&self, action: &str) {
            self.log
                .lock()
                .unwrap()
                .push(format!("{}:{}", action, self.name));
        }
    }

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for ScriptedStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            self.record("execute");
            if self.fail_execute {
                Err(format!("{} failed", self.name))
            } else {
                Ok(Vec::new())
            }
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            self.record("compensate");
            if self.fail_compensate {
                Err(format!("{} compensation failed", self.name))
            } else {
                Ok(Vec::new())
            }
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    /// Step whose `execute` never completes, so it always hits its short timeout.
    struct SlowStep;

    #[async_trait::async_trait]
    impl SagaStep<TestEvent> for SlowStep {
        async fn execute(&self) -> Result<Vec<TestEvent>, String> {
            std::future::pending().await
        }

        async fn compensate(&self) -> Result<Vec<TestEvent>, String> {
            Ok(Vec::new())
        }

        fn name(&self) -> &str {
            "SlowStep"
        }

        fn timeout_duration(&self) -> Duration {
            Duration::from_millis(20)
        }
    }

    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    #[tokio::test]
    async fn test_successful_saga() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-1")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 100.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 100.0,
            });

        let events = orchestrator.execute(saga).await.unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(orchestrator.running_count().await, 0);
        assert_eq!(orchestrator.history_count().await, 1);
    }

    #[tokio::test]
    async fn test_saga_metadata() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "A".to_string(),
            amount: 50.0,
        });

        assert_eq!(saga.id(), "transfer-2");
        assert_eq!(saga.status(), SagaStatus::NotStarted);
        assert_eq!(saga.metadata().total_steps, 1);

        orchestrator.execute(saga).await.unwrap();

        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Completed);
    }

    #[tokio::test]
    async fn test_saga_definition_builder() {
        let saga = SagaDefinition::<TestEvent>::new("test-saga")
            .add_step(DebitStep {
                account: "A".to_string(),
                amount: 10.0,
            })
            .add_step(CreditStep {
                account: "B".to_string(),
                amount: 10.0,
            });

        assert_eq!(saga.metadata().total_steps, 2);
        assert_eq!(saga.status(), SagaStatus::NotStarted);
    }

    #[tokio::test]
    async fn test_multiple_sagas() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let saga1 = SagaDefinition::new("transfer-1").add_step(DebitStep {
            account: "A".to_string(),
            amount: 100.0,
        });

        let saga2 = SagaDefinition::new("transfer-2").add_step(DebitStep {
            account: "B".to_string(),
            amount: 200.0,
        });

        orchestrator.execute(saga1).await.unwrap();
        orchestrator.execute(saga2).await.unwrap();

        assert_eq!(orchestrator.history_count().await, 2);
    }

    #[tokio::test]
    async fn test_failed_step_compensates_completed_steps_in_reverse() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let log = new_log();

        let saga = SagaDefinition::new("transfer-fail")
            .add_step(ScriptedStep::new("first", &log))
            .add_step(ScriptedStep::new("second", &log))
            .add_step(ScriptedStep {
                fail_execute: true,
                ..ScriptedStep::new("third", &log)
            });

        match orchestrator.execute(saga).await {
            Err(SagaError::StepFailed {
                step_index,
                step_name,
                error,
            }) => {
                assert_eq!(step_index, 2);
                assert_eq!(step_name, "third");
                assert_eq!(error, "third failed");
            }
            other => panic!("unexpected result: {:?}", other),
        }

        assert_eq!(
            *log.lock().unwrap(),
            vec![
                "execute:first",
                "execute:second",
                "execute:third",
                "compensate:second",
                "compensate:first",
            ]
        );
        assert_eq!(orchestrator.running_count().await, 0);
        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Compensated);
        assert_eq!(history[0].steps_executed, 2);
    }

    #[tokio::test]
    async fn test_compensation_failure_marks_saga_failed() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let log = new_log();

        let saga = SagaDefinition::new("transfer-broken-rollback")
            .add_step(ScriptedStep {
                fail_compensate: true,
                ..ScriptedStep::new("first", &log)
            })
            .add_step(ScriptedStep {
                fail_execute: true,
                ..ScriptedStep::new("second", &log)
            });

        let result = orchestrator.execute(saga).await;
        assert!(matches!(
            result,
            Err(SagaError::StepFailed { step_index: 1, .. })
        ));
        assert_eq!(
            *log.lock().unwrap(),
            vec!["execute:first", "execute:second", "compensate:first"]
        );
        let history = orchestrator.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].status, SagaStatus::Failed);
    }

    #[tokio::test]
    async fn test_timed_out_step_compensates_completed_steps() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();
        let log = new_log();

        let saga = SagaDefinition::new("transfer-timeout")
            .add_step(ScriptedStep::new("first", &log))
            .add_step(SlowStep);

        match orchestrator.execute(saga).await {
            Err(SagaError::Timeout {
                step_index,
                duration,
            }) => {
                assert_eq!(step_index, 1);
                assert_eq!(duration, Duration::from_millis(20));
            }
            other => panic!("unexpected result: {:?}", other),
        }

        assert_eq!(
            *log.lock().unwrap(),
            vec!["execute:first", "compensate:first"]
        );
        assert_eq!(orchestrator.running_count().await, 0);
    }

    #[tokio::test]
    async fn test_duplicate_saga_id_rejected_while_running() {
        let orchestrator = SagaOrchestrator::<TestEvent>::new();

        let runner = orchestrator.clone();
        let first = SagaDefinition::new("dup").add_step(SlowStep);
        let handle = tokio::spawn(async move { runner.execute(first).await });

        // Let the spawned saga register itself before submitting the duplicate.
        while orchestrator.running_count().await == 0 {
            tokio::task::yield_now().await;
        }

        let second = SagaDefinition::new("dup").add_step(DebitStep {
            account: "A".to_string(),
            amount: 1.0,
        });
        let duplicate = orchestrator.execute(second).await;
        assert!(matches!(duplicate, Err(SagaError::AlreadyExecuting)));

        let first_result = handle.await.expect("saga task panicked");
        assert!(matches!(
            first_result,
            Err(SagaError::Timeout { step_index: 0, .. })
        ));
        assert_eq!(orchestrator.running_count().await, 0);
    }
}
506}