1use anyhow::Result;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use tracing::{debug, error, info, warn};
7use uuid::Uuid;
8
9use crate::error::RustRabbitError;
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub struct SagaId(String);
14
15impl SagaId {
16 pub fn new() -> Self {
17 Self(Uuid::new_v4().to_string())
18 }
19
20 pub fn from_string(id: String) -> Self {
21 Self(id)
22 }
23
24 pub fn as_str(&self) -> &str {
25 &self.0
26 }
27}
28
29impl Default for SagaId {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl std::fmt::Display for SagaId {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 write!(f, "{}", self.0)
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
43pub enum SagaStatus {
44 Running,
46 Completed,
48 Compensating,
50 Compensated,
52 CompensationFailed,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SagaStep {
59 pub step_id: String,
60 pub action: SagaAction,
61 pub compensation: Option<SagaAction>,
62 pub status: StepStatus,
63 pub executed_at: Option<DateTime<Utc>>,
64 pub compensated_at: Option<DateTime<Utc>>,
65 pub retry_count: u32,
66 pub max_retries: u32,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
71pub enum StepStatus {
72 Pending,
73 Running,
74 Completed,
75 Failed,
76 Compensating,
77 Compensated,
78 CompensationFailed,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct SagaAction {
84 pub action_type: String,
85 pub payload: Vec<u8>,
86 pub timeout: std::time::Duration,
87 pub idempotency_key: Option<String>,
88}
89
90impl SagaAction {
91 pub fn new(action_type: String, payload: Vec<u8>) -> Self {
92 Self {
93 action_type,
94 payload,
95 timeout: std::time::Duration::from_secs(30),
96 idempotency_key: None,
97 }
98 }
99
100 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
101 self.timeout = timeout;
102 self
103 }
104
105 pub fn with_idempotency_key(mut self, key: String) -> Self {
106 self.idempotency_key = Some(key);
107 self
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct SagaInstance {
114 pub saga_id: SagaId,
115 pub saga_type: String,
116 pub status: SagaStatus,
117 pub steps: Vec<SagaStep>,
118 pub context: HashMap<String, String>,
119 pub created_at: DateTime<Utc>,
120 pub updated_at: DateTime<Utc>,
121 pub completed_at: Option<DateTime<Utc>>,
122}
123
124impl SagaInstance {
125 pub fn new(saga_type: String, steps: Vec<SagaStep>) -> Self {
126 let now = Utc::now();
127 Self {
128 saga_id: SagaId::new(),
129 saga_type,
130 status: SagaStatus::Running,
131 steps,
132 context: HashMap::new(),
133 created_at: now,
134 updated_at: now,
135 completed_at: None,
136 }
137 }
138
139 pub fn get_current_step(&self) -> Option<&SagaStep> {
140 self.steps
141 .iter()
142 .find(|step| step.status == StepStatus::Pending)
143 }
144
145 pub fn get_current_step_mut(&mut self) -> Option<&mut SagaStep> {
146 self.steps
147 .iter_mut()
148 .find(|step| step.status == StepStatus::Pending)
149 }
150
151 pub fn get_failed_steps(&self) -> Vec<&SagaStep> {
152 self.steps
153 .iter()
154 .filter(|step| step.status == StepStatus::Failed)
155 .collect()
156 }
157
158 pub fn add_context(&mut self, key: String, value: String) {
159 self.context.insert(key, value);
160 self.updated_at = Utc::now();
161 }
162
163 pub fn mark_completed(&mut self) {
164 self.status = SagaStatus::Completed;
165 self.completed_at = Some(Utc::now());
166 self.updated_at = Utc::now();
167 }
168
169 pub fn mark_compensating(&mut self) {
170 self.status = SagaStatus::Compensating;
171 self.updated_at = Utc::now();
172 }
173
174 pub fn mark_compensated(&mut self) {
175 self.status = SagaStatus::Compensated;
176 self.completed_at = Some(Utc::now());
177 self.updated_at = Utc::now();
178 }
179}
180
/// Outcome reported by an executor for one action.
#[derive(Debug)]
pub enum StepResult {
    /// The action succeeded; the map holds entries to merge into the saga
    /// context.
    Success(HashMap<String, String>),
    /// The action failed; the string describes why.
    Failure(String),
    /// The action did not complete and should be attempted again.
    Retry,
}
188
189#[async_trait::async_trait]
191pub trait SagaStepExecutor {
192 async fn execute_step(
193 &self,
194 action: &SagaAction,
195 context: &HashMap<String, String>,
196 ) -> Result<StepResult>;
197 async fn compensate_step(
198 &self,
199 action: &SagaAction,
200 context: &HashMap<String, String>,
201 ) -> Result<StepResult>;
202}
203
204#[derive(Clone)]
206pub struct SagaCoordinator {
207 active_sagas: Arc<Mutex<HashMap<SagaId, SagaInstance>>>,
208 step_executors: HashMap<String, Arc<dyn SagaStepExecutor + Send + Sync>>,
209}
210
211impl Default for SagaCoordinator {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
217impl SagaCoordinator {
218 pub fn new() -> Self {
219 Self {
220 active_sagas: Arc::new(Mutex::new(HashMap::new())),
221 step_executors: HashMap::new(),
222 }
223 }
224
225 pub fn register_executor(
227 &mut self,
228 action_type: String,
229 executor: Arc<dyn SagaStepExecutor + Send + Sync>,
230 ) {
231 self.step_executors.insert(action_type, executor);
232 }
233
234 pub async fn start_saga(&self, saga: SagaInstance) -> Result<()> {
236 let saga_id = saga.saga_id.clone();
237
238 info!(
239 saga_id = %saga_id,
240 saga_type = %saga.saga_type,
241 steps_count = saga.steps.len(),
242 "Starting new saga"
243 );
244
245 {
247 let mut active_sagas = self.active_sagas.lock().unwrap();
248 active_sagas.insert(saga_id.clone(), saga.clone());
249 }
250
251 self.execute_next_step(saga_id).await
253 }
254
255 async fn execute_next_step(&self, saga_id: SagaId) -> Result<()> {
257 let (step_id, action, context) = {
258 let mut active_sagas = self.active_sagas.lock().unwrap();
259 let saga = active_sagas
260 .get_mut(&saga_id)
261 .ok_or_else(|| RustRabbitError::SagaNotFound)?;
262
263 if let Some(step) = saga.get_current_step_mut() {
264 step.status = StepStatus::Running;
265 step.executed_at = Some(Utc::now());
266 (
267 step.step_id.clone(),
268 step.action.clone(),
269 saga.context.clone(),
270 )
271 } else {
272 saga.mark_completed();
274 info!(saga_id = %saga_id, "Saga completed successfully");
275 return Ok(());
276 }
277 };
278
279 debug!(
280 saga_id = %saga_id,
281 step_id = %step_id,
282 action_type = %action.action_type,
283 "Executing saga step"
284 );
285
286 let result = self.execute_step(&action, &context).await;
288
289 {
291 let mut active_sagas = self.active_sagas.lock().unwrap();
292 let saga = active_sagas
293 .get_mut(&saga_id)
294 .ok_or_else(|| RustRabbitError::SagaNotFound)?;
295
296 if let Some(step) = saga.steps.iter_mut().find(|s| s.step_id == step_id) {
297 match result {
298 Ok(StepResult::Success(step_context)) => {
299 step.status = StepStatus::Completed;
300 saga.context.extend(step_context);
301 saga.updated_at = Utc::now();
302
303 info!(
304 saga_id = %saga_id,
305 step_id = %step_id,
306 "Step completed successfully"
307 );
308 }
309 Ok(StepResult::Failure(error)) => {
310 step.status = StepStatus::Failed;
311 saga.status = SagaStatus::Compensating;
312 saga.updated_at = Utc::now();
313
314 error!(
315 saga_id = %saga_id,
316 step_id = %step_id,
317 error = %error,
318 "Step failed, starting compensation"
319 );
320 }
321 Ok(StepResult::Retry) => {
322 step.retry_count += 1;
323 if step.retry_count >= step.max_retries {
324 step.status = StepStatus::Failed;
325 saga.status = SagaStatus::Compensating;
326
327 error!(
328 saga_id = %saga_id,
329 step_id = %step_id,
330 retry_count = step.retry_count,
331 "Step exceeded max retries, starting compensation"
332 );
333 } else {
334 step.status = StepStatus::Pending;
335
336 warn!(
337 saga_id = %saga_id,
338 step_id = %step_id,
339 retry_count = step.retry_count,
340 "Step will be retried"
341 );
342 }
343 saga.updated_at = Utc::now();
344 }
345 Err(error) => {
346 step.status = StepStatus::Failed;
347 saga.status = SagaStatus::Compensating;
348 saga.updated_at = Utc::now();
349
350 error!(
351 saga_id = %saga_id,
352 step_id = %step_id,
353 error = %error,
354 "Step execution error, starting compensation"
355 );
356 }
357 }
358 }
359 }
360
361 let saga_status = {
363 let active_sagas = self.active_sagas.lock().unwrap();
364 active_sagas
365 .get(&saga_id)
366 .map(|s| s.status.clone())
367 .unwrap_or(SagaStatus::Completed)
368 };
369
370 match saga_status {
371 SagaStatus::Running => {
372 debug!(saga_id = %saga_id, "Saga step completed, next step will be processed");
375 Ok(())
376 }
377 SagaStatus::Compensating => {
378 self.compensate_saga(saga_id).await
380 }
381 _ => Ok(()),
382 }
383 }
384
385 async fn execute_step(
387 &self,
388 action: &SagaAction,
389 context: &HashMap<String, String>,
390 ) -> Result<StepResult> {
391 if let Some(executor) = self.step_executors.get(&action.action_type) {
392 executor.execute_step(action, context).await
393 } else {
394 Err(RustRabbitError::SagaExecutorNotFound(action.action_type.clone()).into())
395 }
396 }
397
398 async fn compensate_saga(&self, saga_id: SagaId) -> Result<()> {
400 info!(saga_id = %saga_id, "Starting saga compensation");
401
402 let completed_steps: Vec<SagaStep> = {
403 let active_sagas = self.active_sagas.lock().unwrap();
404 let saga = active_sagas
405 .get(&saga_id)
406 .ok_or_else(|| RustRabbitError::SagaNotFound)?;
407
408 saga.steps
409 .iter()
410 .filter(|step| step.status == StepStatus::Completed)
411 .cloned()
412 .collect()
413 };
414
415 for mut step in completed_steps.into_iter().rev() {
417 if let Some(compensation) = &step.compensation {
418 debug!(
419 saga_id = %saga_id,
420 step_id = %step.step_id,
421 "Compensating step"
422 );
423
424 step.status = StepStatus::Compensating;
425 step.compensated_at = Some(Utc::now());
426
427 let context = {
428 let active_sagas = self.active_sagas.lock().unwrap();
429 active_sagas
430 .get(&saga_id)
431 .map(|s| s.context.clone())
432 .unwrap_or_default()
433 };
434
435 let result = self.compensate_step(compensation, &context).await;
436
437 {
439 let mut active_sagas = self.active_sagas.lock().unwrap();
440 if let Some(saga) = active_sagas.get_mut(&saga_id) {
441 if let Some(saga_step) =
442 saga.steps.iter_mut().find(|s| s.step_id == step.step_id)
443 {
444 match result {
445 Ok(StepResult::Success(_)) => {
446 saga_step.status = StepStatus::Compensated;
447 info!(
448 saga_id = %saga_id,
449 step_id = %step.step_id,
450 "Step compensated successfully"
451 );
452 }
453 _ => {
454 saga_step.status = StepStatus::CompensationFailed;
455 saga.status = SagaStatus::CompensationFailed;
456 error!(
457 saga_id = %saga_id,
458 step_id = %step.step_id,
459 "Step compensation failed"
460 );
461 return Err(RustRabbitError::SagaCompensationFailed.into());
462 }
463 }
464 }
465 }
466 }
467 }
468 }
469
470 {
472 let mut active_sagas = self.active_sagas.lock().unwrap();
473 if let Some(saga) = active_sagas.get_mut(&saga_id) {
474 saga.mark_compensated();
475 }
476 }
477
478 info!(saga_id = %saga_id, "Saga compensation completed");
479 Ok(())
480 }
481
482 async fn compensate_step(
484 &self,
485 action: &SagaAction,
486 context: &HashMap<String, String>,
487 ) -> Result<StepResult> {
488 if let Some(executor) = self.step_executors.get(&action.action_type) {
489 executor.compensate_step(action, context).await
490 } else {
491 Err(RustRabbitError::SagaExecutorNotFound(action.action_type.clone()).into())
492 }
493 }
494
495 pub fn get_saga_status(&self, saga_id: &SagaId) -> Option<SagaStatus> {
497 let active_sagas = self.active_sagas.lock().unwrap();
498 active_sagas.get(saga_id).map(|saga| saga.status.clone())
499 }
500
501 pub fn active_saga_count(&self) -> usize {
503 self.active_sagas.lock().unwrap().len()
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Executor that counts its calls and either succeeds or fails every
    /// forward action; compensation always succeeds.
    struct TestExecutor {
        execution_count: Arc<AtomicU32>,
        compensation_count: Arc<AtomicU32>,
        should_fail: bool,
    }

    impl TestExecutor {
        fn new(should_fail: bool) -> Self {
            Self {
                execution_count: Arc::new(AtomicU32::new(0)),
                compensation_count: Arc::new(AtomicU32::new(0)),
                should_fail,
            }
        }
    }

    #[async_trait::async_trait]
    impl SagaStepExecutor for TestExecutor {
        async fn execute_step(
            &self,
            _action: &SagaAction,
            _context: &HashMap<String, String>,
        ) -> Result<StepResult> {
            self.execution_count.fetch_add(1, Ordering::SeqCst);

            if self.should_fail {
                Ok(StepResult::Failure("Test failure".to_string()))
            } else {
                let mut result_context = HashMap::new();
                result_context.insert("executed".to_string(), "true".to_string());
                Ok(StepResult::Success(result_context))
            }
        }

        async fn compensate_step(
            &self,
            _action: &SagaAction,
            _context: &HashMap<String, String>,
        ) -> Result<StepResult> {
            self.compensation_count.fetch_add(1, Ordering::SeqCst);
            Ok(StepResult::Success(HashMap::new()))
        }
    }

    /// Builds a pending step whose action (and optional compensation) is
    /// handled by the executor registered under `action_type`.
    fn pending_step(step_id: &str, action_type: &str, with_compensation: bool) -> SagaStep {
        let compensation = if with_compensation {
            Some(SagaAction::new(
                action_type.to_string(),
                b"compensate".to_vec(),
            ))
        } else {
            None
        };

        SagaStep {
            step_id: step_id.to_string(),
            action: SagaAction::new(action_type.to_string(), b"test".to_vec()),
            compensation,
            status: StepStatus::Pending,
            executed_at: None,
            compensated_at: None,
            retry_count: 0,
            max_retries: 3,
        }
    }

    #[test]
    fn test_saga_id_generation() {
        let id1 = SagaId::new();
        let id2 = SagaId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_saga_instance_creation() {
        let steps = vec![pending_step("step1", "test", false)];

        let saga = SagaInstance::new("test_saga".to_string(), steps);
        assert_eq!(saga.saga_type, "test_saga");
        assert_eq!(saga.status, SagaStatus::Running);
        assert_eq!(saga.steps.len(), 1);
    }

    #[tokio::test]
    async fn test_successful_saga_execution() {
        let mut coordinator = SagaCoordinator::new();
        let executor = Arc::new(TestExecutor::new(false));
        coordinator.register_executor("test".to_string(), executor.clone());

        let saga = SagaInstance::new(
            "test_saga".to_string(),
            vec![pending_step("step1", "test", true)],
        );
        let saga_id = saga.saga_id.clone();

        // `start_saga` runs the first (here: only) step and leaves the saga
        // running.
        coordinator.start_saga(saga).await.unwrap();
        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Running)
        );

        // Let the coordinator itself discover that nothing is pending and
        // complete the saga, instead of forcing the status by hand.
        coordinator
            .execute_next_step(saga_id.clone())
            .await
            .unwrap();

        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Completed)
        );
        assert_eq!(executor.execution_count.load(Ordering::SeqCst), 1);
        assert_eq!(executor.compensation_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_failed_saga_compensation() {
        let mut coordinator = SagaCoordinator::new();

        let executor1 = Arc::new(TestExecutor::new(false));
        let executor2 = Arc::new(TestExecutor::new(true));

        coordinator.register_executor("success".to_string(), executor1.clone());
        coordinator.register_executor("fail".to_string(), executor2.clone());

        let saga = SagaInstance::new(
            "test_saga".to_string(),
            vec![
                pending_step("step1", "success", true),
                pending_step("step2", "fail", true),
            ],
        );
        let saga_id = saga.saga_id.clone();

        // First step succeeds.
        coordinator.start_saga(saga).await.unwrap();

        // Second step fails, which must trigger compensation of the first.
        coordinator
            .execute_next_step(saga_id.clone())
            .await
            .unwrap();

        // The status is the one the coordinator produced; it is not set by
        // the test.
        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Compensated)
        );

        let step_statuses: Vec<StepStatus> = {
            let sagas = coordinator.active_sagas.lock().unwrap();
            sagas[&saga_id]
                .steps
                .iter()
                .map(|step| step.status.clone())
                .collect()
        };
        assert_eq!(
            step_statuses,
            vec![StepStatus::Compensated, StepStatus::Failed]
        );

        assert_eq!(executor1.execution_count.load(Ordering::SeqCst), 1);
        assert_eq!(executor2.execution_count.load(Ordering::SeqCst), 1);
        // Only the completed step is compensated; the failed one is not.
        assert_eq!(executor1.compensation_count.load(Ordering::SeqCst), 1);
        assert_eq!(executor2.compensation_count.load(Ordering::SeqCst), 0);
    }
}