Skip to main content

durable_lambda_core/operations/
compensation.rs

1//! Saga / compensation pattern operations.
2//!
3//! Implements the saga pattern for distributed transaction rollback:
4//!
5//! 1. [`DurableContext::step_with_compensation`] — execute a forward step and,
6//!    on success, register a type-erased compensation closure that can reverse it.
7//! 2. [`DurableContext::run_compensations`] — execute all registered compensations
8//!    in **reverse registration order** (LIFO — last registered, first executed),
9//!    checkpointing each one with `Context/START + Context/SUCCEED|FAIL` using
10//!    `sub_type = "Compensation"`. All compensations are attempted regardless of
11//!    earlier failures (continue-on-error semantics).
12//!
13//! # Checkpoint Protocol
14//!
15//! Each compensation checkpoint mirrors the child_context pattern:
16//! - `OperationType::Context` + `OperationAction::Start` + `sub_type = "Compensation"`
17//! - `OperationType::Context` + `OperationAction::Succeed` + `sub_type = "Compensation"`
18//! - or `OperationType::Context` + `OperationAction::Fail` on error
19//!
20//! # Replay / Partial Rollback Resume
21//!
22//! During replay, completed compensations (Succeeded or Failed in history) are
23//! skipped — their outcome is read from history. This enables partial rollback
24//! resume: if a Lambda times out mid-compensation, the next invocation replays
25//! the completed ones and continues from the first incomplete one.
26
27use std::future::Future;
28
29use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
30use serde::de::DeserializeOwned;
31use serde::Serialize;
32
33use crate::context::DurableContext;
34use crate::error::DurableError;
35use crate::types::{
36    CompensateFn, CompensationItem, CompensationRecord, CompensationResult, CompensationStatus,
37    StepOptions,
38};
39
40impl DurableContext {
41    /// Execute a forward step and register a compensation closure on success.
42    ///
43    /// Delegates the forward execution to [`step`](Self::step). If the step
44    /// succeeds (returns `Ok(Ok(value))`), the `compensate_fn` closure is
45    /// registered and will be executed by [`run_compensations`](Self::run_compensations).
46    ///
47    /// If the forward step fails (returns `Ok(Err(e))`), no compensation is
48    /// registered — only successful steps have compensations that need undoing.
49    ///
50    /// # Arguments
51    ///
52    /// * `name` — Human-readable name for the forward step operation.
53    /// * `forward_fn` — Closure to execute the forward step.
54    /// * `compensate_fn` — Closure to execute when rolling back; receives the
55    ///   forward step's success value.
56    ///
57    /// # Returns
58    ///
59    /// * `Ok(Ok(T))` — Forward step succeeded; compensation registered.
60    /// * `Ok(Err(E))` — Forward step returned a user error; no compensation registered.
61    /// * `Err(DurableError)` — SDK-level failure (checkpoint, serialization).
62    ///
63    /// # Examples
64    ///
65    /// ```no_run
66    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
67    /// // Book a hotel room and register its cancellation as compensation
68    /// let booking_result: Result<String, String> = ctx.step_with_compensation(
69    ///     "book_hotel",
70    ///     || async { Ok("BOOKING-123".to_string()) },
71    ///     |booking_id| async move {
72    ///         // Cancel the hotel booking
73    ///         println!("Cancelling booking: {booking_id}");
74    ///         Ok(())
75    ///     },
76    /// ).await?;
77    ///
78    /// // Later, roll back all registered compensations
79    /// let comp_result = ctx.run_compensations().await?;
80    /// assert!(comp_result.all_succeeded);
81    /// # Ok(())
82    /// # }
83    /// ```
84    pub async fn step_with_compensation<T, E, F, Fut, G, GFut>(
85        &mut self,
86        name: &str,
87        forward_fn: F,
88        compensate_fn: G,
89    ) -> Result<Result<T, E>, DurableError>
90    where
91        T: Serialize + DeserializeOwned + Send + 'static,
92        E: Serialize + DeserializeOwned + Send + 'static,
93        F: FnOnce() -> Fut + Send + 'static,
94        Fut: Future<Output = Result<T, E>> + Send + 'static,
95        G: FnOnce(T) -> GFut + Send + 'static,
96        GFut: Future<Output = Result<(), DurableError>> + Send + 'static,
97    {
98        let step_result = self.step(name, forward_fn).await?;
99
100        match step_result {
101            Ok(value) => {
102                // Serialize the forward result so we can store it alongside the type-erased fn.
103                let forward_result_json = serde_json::to_value(&value)
104                    .map_err(|e| DurableError::serialization(std::any::type_name::<T>(), e))?;
105
106                // Wrap the typed compensation fn into a type-erased CompensateFn that
107                // deserializes the JSON back to T before calling the original closure.
108                let wrapped: CompensateFn = Box::new(move |json_value: serde_json::Value| {
109                    Box::pin(async move {
110                        let deserialized: T = serde_json::from_value(json_value).map_err(|e| {
111                            DurableError::deserialization(std::any::type_name::<T>(), e)
112                        })?;
113                        compensate_fn(deserialized).await
114                    })
115                });
116
117                self.push_compensation(CompensationRecord {
118                    name: name.to_string(),
119                    forward_result_json,
120                    compensate_fn: wrapped,
121                });
122
123                Ok(Ok(value))
124            }
125            Err(e) => {
126                // Forward step returned a user error — no compensation needed.
127                Ok(Err(e))
128            }
129        }
130    }
131
132    /// Execute a forward step (with options) and register a compensation closure on success.
133    ///
134    /// Like [`step_with_compensation`](Self::step_with_compensation) but accepts
135    /// [`StepOptions`] for configuring retries, backoff, and timeouts on the
136    /// forward step.
137    ///
138    /// # Arguments
139    ///
140    /// * `name` — Human-readable name for the forward step operation.
141    /// * `options` — Step configuration (retries, backoff, timeout).
142    /// * `forward_fn` — Closure to execute the forward step.
143    /// * `compensate_fn` — Closure to execute when rolling back.
144    ///
145    /// # Examples
146    ///
147    /// ```no_run
148    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
149    /// use durable_lambda_core::types::StepOptions;
150    ///
151    /// let result: Result<String, String> = ctx.step_with_compensation_opts(
152    ///     "book_hotel",
153    ///     StepOptions::new().retries(3),
154    ///     || async { Ok("BOOKING-123".to_string()) },
155    ///     |booking_id| async move {
156    ///         println!("Cancelling: {booking_id}");
157    ///         Ok(())
158    ///     },
159    /// ).await?;
160    /// # Ok(())
161    /// # }
162    /// ```
163    pub async fn step_with_compensation_opts<T, E, F, Fut, G, GFut>(
164        &mut self,
165        name: &str,
166        options: StepOptions,
167        forward_fn: F,
168        compensate_fn: G,
169    ) -> Result<Result<T, E>, DurableError>
170    where
171        T: Serialize + DeserializeOwned + Send + 'static,
172        E: Serialize + DeserializeOwned + Send + 'static,
173        F: FnOnce() -> Fut + Send + 'static,
174        Fut: Future<Output = Result<T, E>> + Send + 'static,
175        G: FnOnce(T) -> GFut + Send + 'static,
176        GFut: Future<Output = Result<(), DurableError>> + Send + 'static,
177    {
178        let step_result = self.step_with_options(name, options, forward_fn).await?;
179
180        match step_result {
181            Ok(value) => {
182                let forward_result_json = serde_json::to_value(&value)
183                    .map_err(|e| DurableError::serialization(std::any::type_name::<T>(), e))?;
184
185                let wrapped: CompensateFn = Box::new(move |json_value: serde_json::Value| {
186                    Box::pin(async move {
187                        let deserialized: T = serde_json::from_value(json_value).map_err(|e| {
188                            DurableError::deserialization(std::any::type_name::<T>(), e)
189                        })?;
190                        compensate_fn(deserialized).await
191                    })
192                });
193
194                self.push_compensation(CompensationRecord {
195                    name: name.to_string(),
196                    forward_result_json,
197                    compensate_fn: wrapped,
198                });
199
200                Ok(Ok(value))
201            }
202            Err(e) => Ok(Err(e)),
203        }
204    }
205
206    /// Execute all registered compensations in reverse registration order.
207    ///
208    /// Drains the registered compensations and executes them in LIFO order
209    /// (last registered runs first — stack semantics). Each compensation is
210    /// checkpointed with `Context/START + Context/SUCCEED|FAIL` using
211    /// `sub_type = "Compensation"`.
212    ///
213    /// All compensations are attempted even if earlier ones fail. The returned
214    /// [`CompensationResult`] captures the per-item outcomes.
215    ///
216    /// During replay, completed compensations are skipped — their status is
217    /// read from the execution history to support partial rollback resume.
218    ///
219    /// # Returns
220    ///
221    /// Returns `Ok(CompensationResult)` always (individual failures are captured
222    /// in the result items, not propagated as errors). Returns `Err(DurableError)`
223    /// only on AWS checkpoint failures.
224    ///
225    /// # Examples
226    ///
227    /// ```no_run
228    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
229    /// // After some compensable steps fail:
230    /// let result = ctx.run_compensations().await?;
231    ///
232    /// if !result.all_succeeded {
233    ///     for item in &result.items {
234    ///         if let Some(err) = &item.error {
235    ///             eprintln!("Compensation {} failed: {}", item.name, err);
236    ///         }
237    ///     }
238    /// }
239    /// # Ok(())
240    /// # }
241    /// ```
242    pub async fn run_compensations(&mut self) -> Result<CompensationResult, DurableError> {
243        let mut compensations = self.take_compensations();
244
245        // LIFO: reverse so last registered runs first.
246        compensations.reverse();
247
248        if compensations.is_empty() {
249            return Ok(CompensationResult {
250                items: vec![],
251                all_succeeded: true,
252            });
253        }
254
255        let mut items: Vec<CompensationItem> = Vec::with_capacity(compensations.len());
256
257        for record in compensations {
258            let comp_op_id = self.replay_engine_mut().generate_operation_id();
259            let name = record.name.clone();
260
261            let span = tracing::info_span!(
262                "durable_operation",
263                op.name = %name,
264                op.type = "compensation",
265                op.id = %comp_op_id,
266            );
267            let _guard = span.enter();
268            tracing::trace!("durable_operation");
269
270            // Replay path: check if this compensation already completed.
271            // Extract all needed data BEFORE taking mutable borrow.
272            let replay_outcome = self.replay_engine().check_result(&comp_op_id).map(|op| {
273                let succeeded = op.status == OperationStatus::Succeeded;
274                let error_msg = if !succeeded {
275                    op.context_details()
276                        .and_then(|d| d.error())
277                        .map(|e| {
278                            format!(
279                                "{}: {}",
280                                e.error_type().unwrap_or("Unknown"),
281                                e.error_data().unwrap_or("")
282                            )
283                        })
284                        .or_else(|| Some("compensation failed during replay".to_string()))
285                } else {
286                    None
287                };
288                (succeeded, error_msg)
289            });
290
291            if let Some((succeeded, error_msg)) = replay_outcome {
292                self.replay_engine_mut().track_replay(&comp_op_id);
293                let status = if succeeded {
294                    CompensationStatus::Succeeded
295                } else {
296                    CompensationStatus::Failed
297                };
298                items.push(CompensationItem {
299                    name,
300                    status,
301                    error: error_msg,
302                });
303                continue;
304            }
305
306            // Execute path: send Context/START for this compensation.
307            let start_update = OperationUpdate::builder()
308                .id(comp_op_id.clone())
309                .r#type(OperationType::Context)
310                .action(OperationAction::Start)
311                .sub_type("Compensation")
312                .name(&name)
313                .build()
314                .map_err(|e| DurableError::checkpoint_failed(&name, e))?;
315
316            let start_response = self
317                .backend()
318                .checkpoint(
319                    self.arn(),
320                    self.checkpoint_token(),
321                    vec![start_update],
322                    None,
323                )
324                .await?;
325
326            let new_token = start_response.checkpoint_token().ok_or_else(|| {
327                DurableError::checkpoint_failed(
328                    &name,
329                    std::io::Error::new(
330                        std::io::ErrorKind::InvalidData,
331                        "compensation start checkpoint response missing checkpoint_token",
332                    ),
333                )
334            })?;
335            self.set_checkpoint_token(new_token.to_string());
336
337            if let Some(new_state) = start_response.new_execution_state() {
338                for op in new_state.operations() {
339                    self.replay_engine_mut()
340                        .insert_operation(op.id().to_string(), op.clone());
341                }
342            }
343
344            // Execute the compensation closure inline (no tokio::spawn — strict LIFO order).
345            let comp_result = (record.compensate_fn)(record.forward_result_json).await;
346
347            match comp_result {
348                Ok(()) => {
349                    // Send Context/SUCCEED.
350                    let succeed_update = OperationUpdate::builder()
351                        .id(comp_op_id.clone())
352                        .r#type(OperationType::Context)
353                        .action(OperationAction::Succeed)
354                        .sub_type("Compensation")
355                        .build()
356                        .map_err(|e| DurableError::checkpoint_failed(&name, e))?;
357
358                    let succeed_response = self
359                        .backend()
360                        .checkpoint(
361                            self.arn(),
362                            self.checkpoint_token(),
363                            vec![succeed_update],
364                            None,
365                        )
366                        .await?;
367
368                    let new_token = succeed_response.checkpoint_token().ok_or_else(|| {
369                        DurableError::checkpoint_failed(
370                            &name,
371                            std::io::Error::new(
372                                std::io::ErrorKind::InvalidData,
373                                "compensation succeed checkpoint response missing checkpoint_token",
374                            ),
375                        )
376                    })?;
377                    self.set_checkpoint_token(new_token.to_string());
378
379                    if let Some(new_state) = succeed_response.new_execution_state() {
380                        for op in new_state.operations() {
381                            self.replay_engine_mut()
382                                .insert_operation(op.id().to_string(), op.clone());
383                        }
384                    }
385
386                    self.replay_engine_mut().track_replay(&comp_op_id);
387                    items.push(CompensationItem {
388                        name,
389                        status: CompensationStatus::Succeeded,
390                        error: None,
391                    });
392                }
393                Err(comp_err) => {
394                    let error_msg = comp_err.to_string();
395
396                    // Send Context/FAIL — continue-on-error: do NOT return early.
397                    let fail_update = OperationUpdate::builder()
398                        .id(comp_op_id.clone())
399                        .r#type(OperationType::Context)
400                        .action(OperationAction::Fail)
401                        .sub_type("Compensation")
402                        .build()
403                        .map_err(|e| DurableError::checkpoint_failed(&name, e))?;
404
405                    let fail_response = self
406                        .backend()
407                        .checkpoint(self.arn(), self.checkpoint_token(), vec![fail_update], None)
408                        .await?;
409
410                    let new_token = fail_response.checkpoint_token().ok_or_else(|| {
411                        DurableError::checkpoint_failed(
412                            &name,
413                            std::io::Error::new(
414                                std::io::ErrorKind::InvalidData,
415                                "compensation fail checkpoint response missing checkpoint_token",
416                            ),
417                        )
418                    })?;
419                    self.set_checkpoint_token(new_token.to_string());
420
421                    if let Some(new_state) = fail_response.new_execution_state() {
422                        for op in new_state.operations() {
423                            self.replay_engine_mut()
424                                .insert_operation(op.id().to_string(), op.clone());
425                        }
426                    }
427
428                    self.replay_engine_mut().track_replay(&comp_op_id);
429                    items.push(CompensationItem {
430                        name,
431                        status: CompensationStatus::Failed,
432                        error: Some(error_msg),
433                    });
434                    // Continue to next compensation — do NOT abort.
435                }
436            }
437        }
438
439        let all_succeeded = items
440            .iter()
441            .all(|i| i.status == CompensationStatus::Succeeded);
442
443        Ok(CompensationResult {
444            items,
445            all_succeeded,
446        })
447    }
448}
449
450#[cfg(test)]
451mod tests {
452    use std::sync::Arc;
453
454    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
455    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
456    use aws_sdk_lambda::types::{
457        Operation, OperationAction, OperationStatus, OperationType, OperationUpdate,
458    };
459    use aws_smithy_types::DateTime;
460    use tokio::sync::Mutex;
461
462    use crate::backend::DurableBackend;
463    use crate::context::DurableContext;
464    use crate::error::DurableError;
465    use crate::types::{CompensationRecord, CompensationStatus};
466
467    #[derive(Debug, Clone)]
468    #[allow(dead_code)]
469    struct CheckpointCall {
470        arn: String,
471        checkpoint_token: String,
472        updates: Vec<OperationUpdate>,
473    }
474
475    /// MockBackend that records all checkpoint calls.
476    struct CompensationMockBackend {
477        calls: Arc<Mutex<Vec<CheckpointCall>>>,
478    }
479
480    impl CompensationMockBackend {
481        fn new() -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
482            let calls = Arc::new(Mutex::new(Vec::new()));
483            let backend = Self {
484                calls: calls.clone(),
485            };
486            (backend, calls)
487        }
488    }
489
490    #[async_trait::async_trait]
491    impl DurableBackend for CompensationMockBackend {
492        async fn checkpoint(
493            &self,
494            arn: &str,
495            checkpoint_token: &str,
496            updates: Vec<OperationUpdate>,
497            _client_token: Option<&str>,
498        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
499            self.calls.lock().await.push(CheckpointCall {
500                arn: arn.to_string(),
501                checkpoint_token: checkpoint_token.to_string(),
502                updates,
503            });
504            Ok(CheckpointDurableExecutionOutput::builder()
505                .checkpoint_token("mock-token")
506                .build())
507        }
508
509        async fn get_execution_state(
510            &self,
511            _arn: &str,
512            _checkpoint_token: &str,
513            _next_marker: &str,
514            _max_items: i32,
515        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
516            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
517        }
518    }
519
520    fn first_op_id() -> String {
521        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
522        gen.next_id()
523    }
524
525    fn second_op_id() -> String {
526        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
527        let _ = gen.next_id(); // skip first
528        gen.next_id()
529    }
530
531    async fn make_empty_ctx(backend: CompensationMockBackend) -> DurableContext {
532        DurableContext::new(
533            Arc::new(backend),
534            "arn:test".to_string(),
535            "tok".to_string(),
536            vec![],
537            None,
538        )
539        .await
540        .unwrap()
541    }
542
543    fn make_context_op(id: &str, status: OperationStatus) -> Operation {
544        Operation::builder()
545            .id(id)
546            .r#type(OperationType::Context)
547            .status(status)
548            .start_timestamp(DateTime::from_secs(0))
549            .build()
550            .unwrap()
551    }
552
553    // ─── step_with_compensation tests ────────────────────────────────────
554
555    #[tokio::test]
556    async fn test_step_with_compensation_returns_ok_ok_on_success() {
557        let (backend, _calls) = CompensationMockBackend::new();
558        let mut ctx = make_empty_ctx(backend).await;
559
560        let result: Result<Result<i32, String>, DurableError> = ctx
561            .step_with_compensation(
562                "charge",
563                || async { Ok::<i32, String>(42) },
564                |_value| async move { Ok(()) },
565            )
566            .await;
567
568        let inner = result.unwrap();
569        assert!(inner.is_ok(), "expected Ok(42), got {inner:?}");
570        assert_eq!(inner.unwrap(), 42);
571    }
572
573    #[tokio::test]
574    async fn test_step_with_compensation_returns_ok_err_on_forward_failure() {
575        let (backend, _calls) = CompensationMockBackend::new();
576        let mut ctx = make_empty_ctx(backend).await;
577
578        let result: Result<Result<i32, String>, DurableError> = ctx
579            .step_with_compensation(
580                "charge",
581                || async { Err::<i32, String>("payment declined".to_string()) },
582                |_value| async move { Ok(()) },
583            )
584            .await;
585
586        let inner = result.unwrap();
587        assert!(inner.is_err(), "expected Err, got {inner:?}");
588        assert_eq!(inner.unwrap_err(), "payment declined");
589    }
590
591    #[tokio::test]
592    async fn test_step_with_compensation_registers_compensation_on_success() {
593        let (backend, _calls) = CompensationMockBackend::new();
594        let mut ctx = make_empty_ctx(backend).await;
595
596        assert_eq!(ctx.compensation_count(), 0);
597
598        let _: Result<Result<i32, String>, DurableError> = ctx
599            .step_with_compensation(
600                "charge",
601                || async { Ok::<i32, String>(42) },
602                |_value| async move { Ok(()) },
603            )
604            .await;
605
606        assert_eq!(
607            ctx.compensation_count(),
608            1,
609            "compensation should be registered"
610        );
611    }
612
613    #[tokio::test]
614    async fn test_step_with_compensation_does_not_register_on_forward_failure() {
615        let (backend, _calls) = CompensationMockBackend::new();
616        let mut ctx = make_empty_ctx(backend).await;
617
618        let _: Result<Result<i32, String>, DurableError> = ctx
619            .step_with_compensation(
620                "charge",
621                || async { Err::<i32, String>("declined".to_string()) },
622                |_value| async move { Ok(()) },
623            )
624            .await;
625
626        assert_eq!(
627            ctx.compensation_count(),
628            0,
629            "no compensation should be registered when forward step fails"
630        );
631    }
632
633    // ─── run_compensations tests ─────────────────────────────────────────
634
635    #[tokio::test]
636    async fn test_run_compensations_with_zero_returns_empty_all_succeeded() {
637        let (backend, _calls) = CompensationMockBackend::new();
638        let mut ctx = make_empty_ctx(backend).await;
639
640        let result = ctx.run_compensations().await.unwrap();
641
642        assert!(
643            result.all_succeeded,
644            "empty run should be all_succeeded=true"
645        );
646        assert!(result.items.is_empty(), "items should be empty");
647    }
648
649    #[tokio::test]
650    async fn test_run_compensations_executes_in_reverse_order() {
651        let execution_order: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
652
653        let (backend, _calls) = CompensationMockBackend::new();
654        let mut ctx = make_empty_ctx(backend).await;
655
656        // Register 3 compensations
657        for i in 1..=3_i32 {
658            let order_clone = execution_order.clone();
659            let label = format!("step{i}");
660            let _: Result<Result<i32, String>, DurableError> = ctx
661                .step_with_compensation(
662                    &label.clone(),
663                    move || async move { Ok::<i32, String>(i) },
664                    move |_value| {
665                        let order = order_clone.clone();
666                        let label = label.clone();
667                        async move {
668                            order.lock().await.push(label);
669                            Ok(())
670                        }
671                    },
672                )
673                .await;
674        }
675
676        assert_eq!(ctx.compensation_count(), 3);
677        let result = ctx.run_compensations().await.unwrap();
678        assert!(result.all_succeeded);
679
680        // Registered: step1, step2, step3 → executes: step3, step2, step1
681        let order = execution_order.lock().await;
682        assert_eq!(
683            order.as_slice(),
684            &["step3", "step2", "step1"],
685            "compensations must run in reverse registration order, got: {order:?}"
686        );
687    }
688
689    #[tokio::test]
690    async fn test_run_compensations_sends_context_start_and_succeed() {
691        let (backend, calls) = CompensationMockBackend::new();
692        let mut ctx = make_empty_ctx(backend).await;
693
694        let _: Result<Result<i32, String>, DurableError> = ctx
695            .step_with_compensation(
696                "refund",
697                || async { Ok::<i32, String>(99) },
698                |_value| async move { Ok(()) },
699            )
700            .await;
701
702        // Clear the step checkpoints by getting their count
703        let step_calls_count = calls.lock().await.len();
704
705        let result = ctx.run_compensations().await.unwrap();
706        assert!(result.all_succeeded);
707
708        let all_calls = calls.lock().await;
709        let comp_calls = &all_calls[step_calls_count..]; // only compensation checkpoints
710
711        assert_eq!(
712            comp_calls.len(),
713            2,
714            "expected Context/START + Context/SUCCEED for compensation, got {}",
715            comp_calls.len()
716        );
717
718        // First: Context/START with sub_type "Compensation"
719        assert_eq!(comp_calls[0].updates[0].r#type(), &OperationType::Context);
720        assert_eq!(comp_calls[0].updates[0].action(), &OperationAction::Start);
721        assert_eq!(comp_calls[0].updates[0].sub_type(), Some("Compensation"));
722        assert_eq!(comp_calls[0].updates[0].name(), Some("refund"));
723
724        // Second: Context/SUCCEED with sub_type "Compensation"
725        assert_eq!(comp_calls[1].updates[0].r#type(), &OperationType::Context);
726        assert_eq!(comp_calls[1].updates[0].action(), &OperationAction::Succeed);
727        assert_eq!(comp_calls[1].updates[0].sub_type(), Some("Compensation"));
728    }
729
730    #[tokio::test]
731    async fn test_run_compensations_captures_failure_per_item() {
732        let (backend, _calls) = CompensationMockBackend::new();
733        let mut ctx = make_empty_ctx(backend).await;
734
735        let _: Result<Result<i32, String>, DurableError> = ctx
736            .step_with_compensation(
737                "charge",
738                || async { Ok::<i32, String>(10) },
739                |_value| async move {
740                    Err(DurableError::checkpoint_failed(
741                        "charge",
742                        std::io::Error::new(std::io::ErrorKind::Other, "reversal failed"),
743                    ))
744                },
745            )
746            .await;
747
748        let result = ctx.run_compensations().await.unwrap();
749
750        assert!(
751            !result.all_succeeded,
752            "should not be all_succeeded when a compensation fails"
753        );
754        assert_eq!(result.items.len(), 1);
755        assert_eq!(result.items[0].status, CompensationStatus::Failed);
756        assert!(
757            result.items[0].error.is_some(),
758            "failed compensation should have error message"
759        );
760    }
761
762    #[tokio::test]
763    async fn test_run_compensations_continues_after_failure() {
764        let execution_order: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
765
766        let (backend, _calls) = CompensationMockBackend::new();
767        let mut ctx = make_empty_ctx(backend).await;
768
769        // Register step1 with a FAILING compensation
770        let order1 = execution_order.clone();
771        let _: Result<Result<i32, String>, DurableError> = ctx
772            .step_with_compensation(
773                "step1",
774                || async { Ok::<i32, String>(1) },
775                move |_| {
776                    let order = order1.clone();
777                    async move {
778                        order.lock().await.push("step1".to_string());
779                        Err(DurableError::checkpoint_failed(
780                            "step1",
781                            std::io::Error::new(std::io::ErrorKind::Other, "fail"),
782                        ))
783                    }
784                },
785            )
786            .await;
787
788        // Register step2 with a SUCCEEDING compensation
789        let order2 = execution_order.clone();
790        let _: Result<Result<i32, String>, DurableError> = ctx
791            .step_with_compensation(
792                "step2",
793                || async { Ok::<i32, String>(2) },
794                move |_| {
795                    let order = order2.clone();
796                    async move {
797                        order.lock().await.push("step2".to_string());
798                        Ok(())
799                    }
800                },
801            )
802            .await;
803
804        let result = ctx.run_compensations().await.unwrap();
805
806        // Both should have been attempted: step2 first (LIFO), then step1
807        let order = execution_order.lock().await;
808        assert_eq!(
809            order.as_slice(),
810            &["step2", "step1"],
811            "both compensations must run regardless of step1 failure"
812        );
813
814        assert!(!result.all_succeeded);
815        assert_eq!(result.items.len(), 2);
816        assert_eq!(result.items[0].status, CompensationStatus::Succeeded); // step2 ran first
817        assert_eq!(result.items[1].status, CompensationStatus::Failed); // step1 ran second
818    }
819
820    #[tokio::test]
821    async fn test_run_compensations_all_succeeded_false_when_any_fails() {
822        let (backend, _calls) = CompensationMockBackend::new();
823        let mut ctx = make_empty_ctx(backend).await;
824
825        let _: Result<Result<i32, String>, DurableError> = ctx
826            .step_with_compensation(
827                "step",
828                || async { Ok::<i32, String>(1) },
829                |_| async move {
830                    Err(DurableError::checkpoint_failed(
831                        "step",
832                        std::io::Error::new(std::io::ErrorKind::Other, "fail"),
833                    ))
834                },
835            )
836            .await;
837
838        let result = ctx.run_compensations().await.unwrap();
839        assert!(!result.all_succeeded);
840    }
841
842    #[tokio::test]
843    async fn test_run_compensations_replay_skips_completed() {
844        // Pre-load a context_op at the FIRST op_id (since no other ops have consumed it).
845        // When run_compensations() is called, it generates op_ids starting from 0.
846        // The compensation registered here will get first_op_id() as its op_id.
847        // Since first_op_id() is pre-loaded as Succeeded, the closure must NOT execute.
848
849        let first_op = first_op_id();
850        let comp_op_replay = make_context_op(&first_op, OperationStatus::Succeeded);
851
852        let (backend, calls) = CompensationMockBackend::new();
853        let mut ctx = DurableContext::new(
854            Arc::new(backend),
855            "arn:test".to_string(),
856            "tok".to_string(),
857            vec![comp_op_replay],
858            None,
859        )
860        .await
861        .unwrap();
862
863        let compensation_ran = Arc::new(Mutex::new(false));
864        let ran_clone = compensation_ran.clone();
865
866        let record = CompensationRecord {
867            name: "refund".to_string(),
868            forward_result_json: serde_json::json!(42),
869            compensate_fn: Box::new(move |_| {
870                let flag = ran_clone.clone();
871                Box::pin(async move {
872                    *flag.lock().await = true;
873                    Ok(())
874                })
875            }),
876        };
877        ctx.push_compensation(record);
878
879        let result = ctx.run_compensations().await.unwrap();
880
881        // Compensation should be replayed (skipped from execution)
882        let ran = *compensation_ran.lock().await;
883        assert!(
884            !ran,
885            "compensation closure should NOT execute during replay"
886        );
887
888        // No checkpoint calls during replay
889        let captured = calls.lock().await;
890        assert_eq!(captured.len(), 0, "no checkpoints during replay");
891
892        // Result should reflect the replayed status
893        assert_eq!(result.items.len(), 1);
894        assert_eq!(result.items[0].status, CompensationStatus::Succeeded);
895        assert!(result.all_succeeded);
896    }
897
898    #[tokio::test]
899    async fn test_run_compensations_partial_rollback_resume() {
900        // Simulate partial rollback: 3 compensations registered, first 2 already
901        // completed in history. Only the 3rd should execute.
902        //
903        // Compensation op_ids are generated FIRST (before any step op_ids in this
904        // fresh context). So:
905        // - comp3 (last registered, runs first LIFO) → op_id = first_op_id()
906        // - comp2 → op_id = second_op_id()
907        // - comp1 (first registered, runs last LIFO) → op_id = third_op_id()
908
909        let comp3_op_id = first_op_id();
910        let comp2_op_id = second_op_id();
911
912        // Pre-load history: comp3 and comp2 already completed (Succeeded)
913        let comp3_op = make_context_op(&comp3_op_id, OperationStatus::Succeeded);
914        let comp2_op = make_context_op(&comp2_op_id, OperationStatus::Succeeded);
915
916        let (backend, calls) = CompensationMockBackend::new();
917        let mut ctx = DurableContext::new(
918            Arc::new(backend),
919            "arn:test".to_string(),
920            "tok".to_string(),
921            vec![comp3_op, comp2_op],
922            None,
923        )
924        .await
925        .unwrap();
926
927        let execution_order: Arc<Mutex<Vec<i32>>> = Arc::new(Mutex::new(Vec::new()));
928
929        // Register 3 compensations (step1, step2, step3)
930        for i in [1_i32, 2, 3] {
931            let order = execution_order.clone();
932            let record = CompensationRecord {
933                name: format!("step{i}"),
934                forward_result_json: serde_json::json!(i),
935                compensate_fn: Box::new(move |_| {
936                    let o = order.clone();
937                    Box::pin(async move {
938                        o.lock().await.push(i);
939                        Ok(())
940                    })
941                }),
942            };
943            ctx.push_compensation(record);
944        }
945
946        let result = ctx.run_compensations().await.unwrap();
947
948        // comp3 (i=3) and comp2 (i=2) replayed → only comp1 (i=1) actually executed
949        let order = execution_order.lock().await;
950        assert_eq!(
951            order.as_slice(),
952            &[1],
953            "only comp1 should execute; comp3 and comp2 are already done in history"
954        );
955
956        assert!(result.all_succeeded);
957        assert_eq!(result.items.len(), 3);
958
959        // Check that we only sent checkpoints for comp1 (the one that actually executed)
960        let captured = calls.lock().await;
961        assert_eq!(
962            captured.len(),
963            2,
964            "only 2 checkpoints (START+SUCCEED) for the one unfinished compensation, got {}",
965            captured.len()
966        );
967    }
968}