1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use crate::audit::SagaAuditLog;
5use crate::cloneable::CloneableAny;
6use crate::erased::ErasedStep;
7use crate::error::{CompensationError, SagaError};
8
9pub struct Saga<Input, Output, Ctx, Err> {
15 steps: Vec<Box<dyn ErasedStep<Ctx, Err>>>,
16 _phantom: PhantomData<(Input, Output)>,
17}
18
19impl<Input, Output, Ctx, Err> Saga<Input, Output, Ctx, Err>
20where
21 Input: Clone + Send + 'static,
22 Output: Send + 'static,
23 Err: Debug,
24{
25 pub fn execute(&self, ctx: &Ctx, input: Input) -> Result<Output, SagaError<Err>> {
34 let (result, _audit_log) = self.execute_internal(ctx, input);
35 result
36 }
37
38 pub fn execute_with_audit(
42 &self,
43 ctx: &Ctx,
44 input: Input,
45 ) -> (Result<Output, SagaError<Err>>, SagaAuditLog) {
46 self.execute_internal(ctx, input)
47 }
48
49 pub(crate) fn from_steps(steps: Vec<Box<dyn ErasedStep<Ctx, Err>>>) -> Self {
50 Self {
51 steps,
52 _phantom: PhantomData,
53 }
54 }
55
56 fn execute_internal(
57 &self,
58 ctx: &Ctx,
59 input: Input,
60 ) -> (Result<Output, SagaError<Err>>, SagaAuditLog) {
61 let mut audit_log = SagaAuditLog::new();
62 let mut compensation_stack: Vec<(usize, Box<dyn CloneableAny>)> = Vec::new();
63
64 let mut current_input: Box<dyn CloneableAny> = Box::new(input);
65
66 for (index, step) in self.steps.iter().enumerate() {
67 audit_log.record_start(step.name());
68
69 let input_clone = current_input.clone_box();
70
71 match step.execute_erased(ctx, current_input) {
72 Ok(output) => {
73 let description = step.compensation_description();
74 audit_log.record_success(description);
75 compensation_stack.push((index, input_clone));
76
77 if index == self.steps.len() - 1 {
78 let typed_output = output
79 .into_any()
80 .downcast::<Output>()
81 .expect("type-state builder guarantees final output type");
82 return (Ok(*typed_output), audit_log);
83 }
84
85 current_input = output;
86 }
87 Err(error) => {
88 audit_log.record_failure();
89 let saga_error = self.compensate(
90 ctx,
91 &mut audit_log,
92 compensation_stack,
93 step.name(),
94 error,
95 );
96 return (Err(saga_error), audit_log);
97 }
98 }
99 }
100
101 unreachable!("saga must have at least one step")
102 }
103
104 fn compensate(
105 &self,
106 ctx: &Ctx,
107 audit_log: &mut SagaAuditLog,
108 mut compensation_stack: Vec<(usize, Box<dyn CloneableAny>)>,
109 failed_step: &str,
110 step_error: Err,
111 ) -> SagaError<Err> {
112 let mut compensation_errors = Vec::new();
113
114 while let Some((index, stored_input)) = compensation_stack.pop() {
115 let step = &self.steps[index];
116 let step_name = step.name();
117 let description = step.compensation_description();
118
119 match step.compensate_erased(ctx, stored_input) {
120 Ok(()) => {
121 audit_log.record_compensated(step_name);
122 }
123 Err(error) => {
124 audit_log.record_compensation_failed(step_name);
125 compensation_errors.push(CompensationError {
126 step: step_name.to_string(),
127 description,
128 error,
129 });
130 }
131 }
132 }
133
134 if compensation_errors.is_empty() {
135 SagaError::StepFailed {
136 step: failed_step.to_string(),
137 source: step_error,
138 }
139 } else {
140 SagaError::CompensationFailed {
141 failed_step: failed_step.to_string(),
142 step_error,
143 compensation_errors,
144 }
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use super::*;
    use crate::audit::StepStatus;
    use crate::builder::SagaBuilder;
    use crate::step::SagaStep;

    /// Context shared by all test steps. Steps that compensate append a line to
    /// `compensation_log`, so tests can check compensation order and inputs.
    struct TestContext {
        compensation_log: RefCell<Vec<String>>,
    }

    impl TestContext {
        /// Creates a context with an empty compensation log.
        fn new() -> Self {
            Self {
                compensation_log: RefCell::new(Vec::new()),
            }
        }
    }

    #[derive(Debug, PartialEq, thiserror::Error)]
    #[error("{0}")]
    struct TestError(String);

    /// Shorthand for building a [`TestError`] from a string slice.
    fn test_error(msg: &str) -> TestError {
        TestError(msg.to_string())
    }

    /// Asserts that `err` is a plain step failure of `expected_step` whose
    /// source error carries `expected_msg`.
    fn assert_step_failed(err: SagaError<TestError>, expected_step: &str, expected_msg: &str) {
        match err {
            SagaError::StepFailed { step, source } => {
                assert_eq!(step, expected_step);
                assert_eq!(source, test_error(expected_msg));
            }
            other => panic!("expected StepFailed, got {other:?}"),
        }
    }

    struct AddStep {
        name: &'static str,
        value: i32,
    }

    impl SagaStep for AddStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            self.name
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input + self.value)
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.compensation_log
                .borrow_mut()
                .push(format!("compensate {} with input {}", self.name, input));
            Ok(())
        }
    }

    struct MultiplyStep {
        factor: i32,
    }

    impl SagaStep for MultiplyStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "multiply"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input * self.factor)
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.compensation_log
                .borrow_mut()
                .push(format!("compensate multiply with input {input}"));
            Ok(())
        }
    }

    /// Always fails in `execute`; relies on the trait's default compensation.
    struct FailingStep {
        error_msg: String,
    }

    impl SagaStep for FailingStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "failing"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            _input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Err(TestError(self.error_msg.clone()))
        }
    }

    /// Passes its input through unchanged, but its compensation always fails.
    struct FailingCompensationStep {
        name: &'static str,
    }

    impl SagaStep for FailingCompensationStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            self.name
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input)
        }

        fn compensate(&self, _ctx: &Self::Context, _input: Self::Input) -> Result<(), Self::Error> {
            Err(TestError(format!("compensation failed for {}", self.name)))
        }
    }

    /// Pass-through step that relies on the trait's default compensation.
    struct ReadOnlyStep;

    impl SagaStep for ReadOnlyStep {
        type Input = i32;
        type Output = i32;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "read_only"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input)
        }
    }

    struct IntToString;

    impl SagaStep for IntToString {
        type Input = i32;
        type Output = String;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "int_to_string"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input.to_string())
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.compensation_log
                .borrow_mut()
                .push(format!("compensate int_to_string with input {input}"));
            Ok(())
        }
    }

    struct AppendSuffix {
        suffix: &'static str,
    }

    impl SagaStep for AppendSuffix {
        type Input = String;
        type Output = String;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "append_suffix"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(format!("{}{}", input, self.suffix))
        }

        fn compensate(&self, ctx: &Self::Context, input: Self::Input) -> Result<(), Self::Error> {
            ctx.compensation_log
                .borrow_mut()
                .push(format!("compensate append_suffix with input {input}"));
            Ok(())
        }
    }

    struct FailingStringStep {
        error_msg: String,
    }

    impl SagaStep for FailingStringStep {
        type Input = String;
        type Output = String;
        type Context = TestContext;
        type Error = TestError;

        fn name(&self) -> &'static str {
            "failing_string"
        }

        fn execute(
            &self,
            _ctx: &Self::Context,
            _input: Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Err(TestError(self.error_msg.clone()))
        }
    }

    #[test]
    fn single_step_saga_returns_step_output() -> anyhow::Result<()> {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(AddStep {
                name: "add_1",
                value: 1,
            })
            .build();

        assert_eq!(saga.execute(&ctx, 41)?, 42);
        assert!(ctx.compensation_log.borrow().is_empty());
        Ok(())
    }

    #[test]
    fn multi_step_saga_flows_data_through_steps() -> anyhow::Result<()> {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(AddStep {
                name: "add_10",
                value: 10,
            })
            .then(MultiplyStep { factor: 3 })
            .then(AddStep {
                name: "add_5",
                value: 5,
            })
            .build();

        // (5 + 10) * 3 + 5
        assert_eq!(saga.execute(&ctx, 5)?, 50);
        // A successful run must not compensate anything.
        assert!(ctx.compensation_log.borrow().is_empty());
        Ok(())
    }

    #[test]
    fn compensation_happens_in_lifo_order_with_stored_inputs() {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(AddStep {
                name: "add_10",
                value: 10,
            })
            .then(MultiplyStep { factor: 3 })
            .then(FailingStep {
                error_msg: "boom".to_string(),
            })
            .build();

        let err = saga
            .execute(&ctx, 5)
            .expect_err("failing step should abort the saga");
        assert_step_failed(err, "failing", "boom");

        let comp_log = ctx.compensation_log.borrow();
        assert_eq!(
            *comp_log,
            [
                "compensate multiply with input 15",
                "compensate add_10 with input 5",
            ]
        );
    }

    #[test]
    fn read_only_step_uses_default_no_op_compensation() {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(ReadOnlyStep)
            .then(FailingStep {
                error_msg: "boom".to_string(),
            })
            .build();

        let err = saga
            .execute(&ctx, 42)
            .expect_err("failing step should abort the saga");
        // The default compensation is a successful no-op, so this is a plain step failure.
        assert_step_failed(err, "failing", "boom");

        assert!(ctx.compensation_log.borrow().is_empty());
    }

    #[test]
    fn first_step_failure_requires_no_compensation() {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(FailingStep {
                error_msg: "immediate failure".to_string(),
            })
            .build();

        let err = saga.execute(&ctx, 42).expect_err("should be an error");
        assert_step_failed(err, "failing", "immediate failure");

        assert!(ctx.compensation_log.borrow().is_empty());
    }

    #[test]
    fn compensation_failure_returns_compensation_failed_error() {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(AddStep {
                name: "add_10",
                value: 10,
            })
            .then(FailingCompensationStep {
                name: "will_fail_comp",
            })
            .then(FailingStep {
                error_msg: "trigger compensation".to_string(),
            })
            .build();

        let err = saga.execute(&ctx, 5).expect_err("should be an error");
        match err {
            SagaError::CompensationFailed {
                failed_step,
                step_error,
                compensation_errors,
            } => {
                assert_eq!(failed_step, "failing");
                assert_eq!(step_error, test_error("trigger compensation"));
                assert_eq!(compensation_errors.len(), 1);
                assert_eq!(compensation_errors[0].step, "will_fail_comp");
                assert_eq!(
                    compensation_errors[0].error,
                    test_error("compensation failed for will_fail_comp")
                );
            }
            other => panic!("expected CompensationFailed, got {other:?}"),
        }

        // A failed compensation must not stop earlier steps from being compensated.
        let comp_log = ctx.compensation_log.borrow();
        assert_eq!(*comp_log, ["compensate add_10 with input 5"]);
    }

    #[test]
    fn execute_with_audit_returns_audit_log() -> anyhow::Result<()> {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(AddStep {
                name: "add_10",
                value: 10,
            })
            .then(MultiplyStep { factor: 2 })
            .build();

        let (result, audit_log) = saga.execute_with_audit(&ctx, 5);

        assert_eq!(result?, 30);

        let records = audit_log.records();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].name, "add_10");
        assert_eq!(records[0].status, StepStatus::Executed);
        assert_eq!(records[1].name, "multiply");
        assert_eq!(records[1].status, StepStatus::Executed);

        Ok(())
    }

    #[test]
    fn audit_log_tracks_compensation_status() {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(AddStep {
                name: "add_10",
                value: 10,
            })
            .then(FailingCompensationStep {
                name: "will_fail_comp",
            })
            .then(FailingStep {
                error_msg: "trigger compensation".to_string(),
            })
            .build();

        let (result, audit_log) = saga.execute_with_audit(&ctx, 5);

        assert!(matches!(
            result,
            Err(SagaError::CompensationFailed { .. })
        ));

        let records = audit_log.records();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].name, "add_10");
        assert_eq!(records[0].status, StepStatus::Compensated);
        assert_eq!(records[1].name, "will_fail_comp");
        assert_eq!(records[1].status, StepStatus::CompensationFailed);
        assert_eq!(records[2].name, "failing");
        assert_eq!(records[2].status, StepStatus::Failed);
    }

    #[test]
    fn typed_data_flow_across_different_types() -> anyhow::Result<()> {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(IntToString)
            .then(AppendSuffix { suffix: "_suffix" })
            .build();

        assert_eq!(saga.execute(&ctx, 42)?, "42_suffix");
        Ok(())
    }

    #[test]
    fn compensation_with_different_types_uses_correct_inputs() {
        let ctx = TestContext::new();

        let saga = SagaBuilder::new()
            .first_step(IntToString)
            .then(AppendSuffix { suffix: "_suffix" })
            .then(FailingStringStep {
                error_msg: "boom".to_string(),
            })
            .build();

        let err = saga
            .execute(&ctx, 42)
            .expect_err("failing step should abort the saga");
        assert_step_failed(err, "failing_string", "boom");

        // append_suffix received the String "42"; int_to_string received the i32 42.
        let comp_log = ctx.compensation_log.borrow();
        assert_eq!(
            *comp_log,
            [
                "compensate append_suffix with input 42",
                "compensate int_to_string with input 42",
            ]
        );
    }
}