Skip to main content

changeset_saga/
saga.rs

1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use crate::audit::SagaAuditLog;
5use crate::cloneable::CloneableAny;
6use crate::erased::ErasedStep;
7use crate::error::{CompensationError, SagaError};
8
9/// A compiled saga ready for execution.
10///
11/// Sagas execute a sequence of steps, where each step's output becomes the
12/// next step's input. If any step fails, previously completed steps are
13/// compensated in reverse order (LIFO).
14pub struct Saga<Input, Output, Ctx, Err> {
15    steps: Vec<Box<dyn ErasedStep<Ctx, Err>>>,
16    _phantom: PhantomData<(Input, Output)>,
17}
18
19impl<Input, Output, Ctx, Err> Saga<Input, Output, Ctx, Err>
20where
21    Input: Clone + Send + 'static,
22    Output: Send + 'static,
23    Err: Debug,
24{
25    /// Execute the saga, returning the final output on success.
26    ///
27    /// On failure, compensates all previously completed steps in reverse order.
28    ///
29    /// # Errors
30    ///
31    /// Returns `SagaError::StepFailed` if a step fails and all compensations succeed.
32    /// Returns `SagaError::CompensationFailed` if a step fails and some compensations also fail.
33    pub fn execute(&self, ctx: &Ctx, input: Input) -> Result<Output, SagaError<Err>> {
34        let (result, _audit_log) = self.execute_internal(ctx, input);
35        result
36    }
37
38    /// Execute the saga and return both the result and an audit log.
39    ///
40    /// The audit log tracks all step executions and compensations.
41    pub fn execute_with_audit(
42        &self,
43        ctx: &Ctx,
44        input: Input,
45    ) -> (Result<Output, SagaError<Err>>, SagaAuditLog) {
46        self.execute_internal(ctx, input)
47    }
48
49    pub(crate) fn from_steps(steps: Vec<Box<dyn ErasedStep<Ctx, Err>>>) -> Self {
50        Self {
51            steps,
52            _phantom: PhantomData,
53        }
54    }
55
56    fn execute_internal(
57        &self,
58        ctx: &Ctx,
59        input: Input,
60    ) -> (Result<Output, SagaError<Err>>, SagaAuditLog) {
61        let mut audit_log = SagaAuditLog::new();
62        let mut compensation_stack: Vec<(usize, Box<dyn CloneableAny>)> = Vec::new();
63
64        let mut current_input: Box<dyn CloneableAny> = Box::new(input);
65
66        for (index, step) in self.steps.iter().enumerate() {
67            audit_log.record_start(step.name());
68
69            let input_clone = current_input.clone_box();
70
71            match step.execute_erased(ctx, current_input) {
72                Ok(output) => {
73                    let description = step.compensation_description();
74                    audit_log.record_success(description);
75                    compensation_stack.push((index, input_clone));
76
77                    if index == self.steps.len() - 1 {
78                        let typed_output = output
79                            .into_any()
80                            .downcast::<Output>()
81                            .expect("type-state builder guarantees final output type");
82                        return (Ok(*typed_output), audit_log);
83                    }
84
85                    current_input = output;
86                }
87                Err(error) => {
88                    audit_log.record_failure();
89                    let saga_error = self.compensate(
90                        ctx,
91                        &mut audit_log,
92                        compensation_stack,
93                        step.name(),
94                        error,
95                    );
96                    return (Err(saga_error), audit_log);
97                }
98            }
99        }
100
101        unreachable!("saga must have at least one step")
102    }
103
104    fn compensate(
105        &self,
106        ctx: &Ctx,
107        audit_log: &mut SagaAuditLog,
108        mut compensation_stack: Vec<(usize, Box<dyn CloneableAny>)>,
109        failed_step: &str,
110        step_error: Err,
111    ) -> SagaError<Err> {
112        let mut compensation_errors = Vec::new();
113
114        while let Some((index, stored_input)) = compensation_stack.pop() {
115            let step = &self.steps[index];
116            let step_name = step.name();
117            let description = step.compensation_description();
118
119            match step.compensate_erased(ctx, stored_input) {
120                Ok(()) => {
121                    audit_log.record_compensated(step_name);
122                }
123                Err(error) => {
124                    audit_log.record_compensation_failed(step_name);
125                    compensation_errors.push(CompensationError {
126                        step: step_name.to_string(),
127                        description,
128                        error,
129                    });
130                }
131            }
132        }
133
134        if compensation_errors.is_empty() {
135            SagaError::StepFailed {
136                step: failed_step.to_string(),
137                source: step_error,
138            }
139        } else {
140            SagaError::CompensationFailed {
141                failed_step: failed_step.to_string(),
142                step_error,
143                compensation_errors,
144            }
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use super::*;
    use crate::audit::StepStatus;
    use crate::builder::SagaBuilder;
    use crate::step::SagaStep;

    /// Shared test context. Compensating steps append one line per call so that
    /// tests can assert on both compensation order and the inputs received.
    #[derive(Default)]
    struct TestContext {
        compensation_log: RefCell<Vec<String>>,
    }

    impl TestContext {
        /// Appends one entry to the compensation log.
        fn log_compensation(&self, entry: String) {
            self.compensation_log.borrow_mut().push(entry);
        }
    }

    #[derive(Debug, PartialEq, thiserror::Error)]
    #[error("{0}")]
    struct TestError(String);

    /// Adds `value` to its input; compensation is logged under `name`.
    struct AddStep {
        name: &'static str,
        value: i32,
    }

    impl SagaStep for AddStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            self.name
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input + self.value)
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.log_compensation(format!("compensate {} with input {}", self.name, input));
            Ok(())
        }
    }

    /// Multiplies its input by `factor`; compensation is logged.
    struct MultiplyStep {
        factor: i32,
    }

    impl SagaStep for MultiplyStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "multiply"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input * self.factor)
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.log_compensation(format!("compensate multiply with input {input}"));
            Ok(())
        }
    }

    /// Always fails with `error_msg`; relies on the default compensation.
    struct FailingStep {
        error_msg: String,
    }

    impl SagaStep for FailingStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "failing"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            _input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Err(TestError(self.error_msg.clone()))
        }
    }

    /// Passes its input through, but its compensation always fails.
    struct FailingCompensationStep {
        name: &'static str,
    }

    impl SagaStep for FailingCompensationStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            self.name
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input)
        }

        fn compensate(&self, _ctx: &Self::Context, _input: Self::Input) -> Result<(), Self::Error> {
            Err(TestError(format!("compensation failed for {}", self.name)))
        }
    }

    /// Passes its input through and uses the default (no-op) compensation.
    struct ReadOnlyStep;

    impl SagaStep for ReadOnlyStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "read_only"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input)
        }
    }

    /// Converts an `i32` into its decimal `String`; compensation is logged.
    struct IntToString;

    impl SagaStep for IntToString {
        type Input = i32;
        type Output = String;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "int_to_string"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input.to_string())
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.log_compensation(format!("compensate int_to_string with input {input}"));
            Ok(())
        }
    }

    /// Appends `suffix` to its `String` input; compensation is logged.
    struct AppendSuffix {
        suffix: &'static str,
    }

    impl SagaStep for AppendSuffix {
        type Input = String;
        type Output = String;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "append_suffix"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(format!("{}{}", input, self.suffix))
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.log_compensation(format!("compensate append_suffix with input {input}"));
            Ok(())
        }
    }

    /// `String`-typed counterpart of [`FailingStep`].
    struct FailingStringStep {
        error_msg: String,
    }

    impl SagaStep for FailingStringStep {
        type Input = String;
        type Output = String;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "failing_string"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            _input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Err(TestError(self.error_msg.clone()))
        }
    }

    fn add(name: &'static str, value: i32) -> AddStep {
        AddStep { name, value }
    }

    fn failing(error_msg: &str) -> FailingStep {
        FailingStep {
            error_msg: error_msg.to_string(),
        }
    }

    #[test]
    fn multi_step_saga_flows_data_through_steps() -> anyhow::Result<()> {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(MultiplyStep { factor: 3 })
            .then(add("add_5", 5))
            .build();

        // (5 + 10) * 3 + 5
        assert_eq!(saga.execute(&ctx, 5)?, 50);
        Ok(())
    }

    #[test]
    fn successful_saga_runs_no_compensation() -> anyhow::Result<()> {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(MultiplyStep { factor: 3 })
            .build();

        assert_eq!(saga.execute(&ctx, 5)?, 45);
        assert!(ctx.compensation_log.borrow().is_empty());
        Ok(())
    }

    #[test]
    fn compensation_happens_in_lifo_order_with_stored_inputs() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(MultiplyStep { factor: 3 })
            .then(failing("boom"))
            .build();

        assert!(saga.execute(&ctx, 5).is_err());
        assert_eq!(
            *ctx.compensation_log.borrow(),
            ["compensate multiply with input 15", "compensate add_10 with input 5"]
        );
    }

    #[test]
    fn read_only_step_uses_default_no_op_compensation() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(ReadOnlyStep)
            .then(failing("boom"))
            .build();

        assert!(saga.execute(&ctx, 42).is_err());
        assert!(ctx.compensation_log.borrow().is_empty());
    }

    #[test]
    fn first_step_failure_requires_no_compensation() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(failing("immediate failure"))
            .build();

        let err = saga.execute(&ctx, 42).expect_err("should be an error");
        assert!(matches!(err, SagaError::StepFailed { step, .. } if step == "failing"));
        assert!(ctx.compensation_log.borrow().is_empty());
    }

    #[test]
    fn step_failure_preserves_the_step_error() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(failing("boom"))
            .build();

        match saga.execute(&ctx, 5).expect_err("should be an error") {
            SagaError::StepFailed { step, source } => {
                assert_eq!(step, "failing");
                assert_eq!(source, TestError("boom".to_string()));
            }
            SagaError::CompensationFailed { .. } => {
                panic!("expected StepFailed error");
            }
        }
        assert_eq!(
            *ctx.compensation_log.borrow(),
            ["compensate add_10 with input 5"]
        );
    }

    #[test]
    fn compensation_failure_returns_compensation_failed_error() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(FailingCompensationStep {
                name: "will_fail_comp",
            })
            .then(failing("trigger compensation"))
            .build();

        match saga.execute(&ctx, 5).expect_err("should be an error") {
            SagaError::CompensationFailed {
                failed_step,
                step_error,
                compensation_errors,
            } => {
                assert_eq!(failed_step, "failing");
                assert_eq!(step_error, TestError("trigger compensation".to_string()));
                assert_eq!(compensation_errors.len(), 1);
                assert_eq!(compensation_errors[0].step, "will_fail_comp");
                assert_eq!(
                    compensation_errors[0].error,
                    TestError("compensation failed for will_fail_comp".to_string())
                );
            }
            SagaError::StepFailed { .. } => {
                panic!("expected CompensationFailed error");
            }
        }

        // The failing compensation must not stop the earlier step from being undone.
        assert_eq!(
            *ctx.compensation_log.borrow(),
            ["compensate add_10 with input 5"]
        );
    }

    #[test]
    fn execute_with_audit_returns_audit_log() -> anyhow::Result<()> {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(MultiplyStep { factor: 2 })
            .build();

        let (result, audit_log) = saga.execute_with_audit(&ctx, 5);
        assert_eq!(result?, 30);

        let records = audit_log.records();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].name, "add_10");
        assert_eq!(records[0].status, StepStatus::Executed);
        assert_eq!(records[1].name, "multiply");
        assert_eq!(records[1].status, StepStatus::Executed);

        Ok(())
    }

    #[test]
    fn audit_log_records_failed_first_step() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(failing("immediate failure"))
            .build();

        let (result, audit_log) = saga.execute_with_audit(&ctx, 42);
        assert!(result.is_err());

        let records = audit_log.records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "failing");
        assert_eq!(records[0].status, StepStatus::Failed);
    }

    #[test]
    fn audit_log_tracks_compensation_status() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(add("add_10", 10))
            .then(FailingCompensationStep {
                name: "will_fail_comp",
            })
            .then(failing("trigger compensation"))
            .build();

        let (result, audit_log) = saga.execute_with_audit(&ctx, 5);
        assert!(result.is_err());

        let records = audit_log.records();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].name, "add_10");
        assert_eq!(records[0].status, StepStatus::Compensated);
        assert_eq!(records[1].name, "will_fail_comp");
        assert_eq!(records[1].status, StepStatus::CompensationFailed);
        assert_eq!(records[2].name, "failing");
        assert_eq!(records[2].status, StepStatus::Failed);
    }

    #[test]
    fn typed_data_flow_across_different_types() -> anyhow::Result<()> {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(IntToString)
            .then(AppendSuffix { suffix: "_suffix" })
            .build();

        assert_eq!(saga.execute(&ctx, 42)?, "42_suffix");
        Ok(())
    }

    #[test]
    fn compensation_with_different_types_uses_correct_inputs() {
        let ctx = TestContext::default();
        let saga = SagaBuilder::new()
            .first_step(IntToString)
            .then(AppendSuffix { suffix: "_suffix" })
            .then(FailingStringStep {
                error_msg: "boom".to_string(),
            })
            .build();

        assert!(saga.execute(&ctx, 42).is_err());
        assert_eq!(
            *ctx.compensation_log.borrow(),
            [
                "compensate append_suffix with input 42",
                "compensate int_to_string with input 42",
            ]
        );
    }
}