Skip to main content

durable_lambda_core/operations/
child_context.rs

//! Child context operation — isolated subflows.
//!
//! Implements FR26-FR28: isolated checkpoint namespace, independent operations,
//! fully owned child contexts sharing only `Arc<dyn DurableBackend>`.
//!
//! The child context operation uses `OperationType::Context` on the wire with
//! sub_type `"Context"`. Unlike the parallel operation, there is only a single
//! closure, which runs inline (no `tokio::spawn`), and the result is returned
//! directly as `T` rather than wrapped in `BatchResult`.
10
11use std::future::Future;
12
13use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
14use serde::de::DeserializeOwned;
15use serde::Serialize;
16
17use crate::context::DurableContext;
18use crate::error::DurableError;
19
20impl DurableContext {
21    /// Execute an isolated subflow with its own checkpoint namespace.
22    ///
23    /// The closure receives an owned child [`DurableContext`] whose operations
24    /// are namespaced under this child context's operation ID, preventing
25    /// collisions with the parent or sibling child contexts.
26    ///
27    /// During replay mode, returns the cached result without re-executing
28    /// the closure.
29    ///
30    /// # Arguments
31    ///
32    /// * `name` — Human-readable name for the child context operation
33    /// * `f` — Closure receiving an owned `DurableContext` for the subflow
34    ///
35    /// # Errors
36    ///
37    /// Returns [`DurableError::ChildContextFailed`] if the child context
38    /// is found in a failed state during replay.
39    /// Returns [`DurableError::CheckpointFailed`] if checkpoint API calls fail.
40    ///
41    /// # Examples
42    ///
43    /// ```no_run
44    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
45    /// let result: i32 = ctx.child_context("sub_workflow", |mut child_ctx| async move {
46    ///     let r: Result<i32, String> = child_ctx.step("inner_step", || async { Ok(42) }).await?;
47    ///     Ok(r.unwrap())
48    /// }).await?;
49    /// assert_eq!(result, 42);
50    /// # Ok(())
51    /// # }
52    /// ```
53    #[allow(clippy::await_holding_lock)]
54    pub async fn child_context<T, F, Fut>(&mut self, name: &str, f: F) -> Result<T, DurableError>
55    where
56        T: Serialize + DeserializeOwned + Send,
57        F: FnOnce(DurableContext) -> Fut + Send,
58        Fut: Future<Output = Result<T, DurableError>> + Send,
59    {
60        let op_id = self.replay_engine_mut().generate_operation_id();
61
62        let span = tracing::info_span!(
63            "durable_operation",
64            op.name = name,
65            op.type = "child_context",
66            op.id = %op_id,
67        );
68        let _guard = span.enter();
69        tracing::trace!("durable_operation");
70
71        // Replay path: check for completed outer child context operation.
72        if let Some(op) = self.replay_engine().check_result(&op_id) {
73            if op.status == OperationStatus::Succeeded {
74                let result_str =
75                    op.context_details()
76                        .and_then(|d| d.result())
77                        .ok_or_else(|| {
78                            DurableError::checkpoint_failed(
79                                name,
80                                std::io::Error::new(
81                                    std::io::ErrorKind::InvalidData,
82                                    "child context succeeded but no result in context_details",
83                                ),
84                            )
85                        })?;
86
87                let result: T = serde_json::from_str(result_str)
88                    .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))?;
89
90                self.replay_engine_mut().track_replay(&op_id);
91                return Ok(result);
92            } else {
93                // Failed/Cancelled/TimedOut/Stopped
94                let error_message = op
95                    .context_details()
96                    .and_then(|d| d.error())
97                    .map(|e| {
98                        format!(
99                            "{}: {}",
100                            e.error_type().unwrap_or("Unknown"),
101                            e.error_data().unwrap_or("")
102                        )
103                    })
104                    .unwrap_or_else(|| "child context failed".to_string());
105                return Err(DurableError::child_context_failed(name, error_message));
106            }
107        }
108
109        // Execute path: send Context/START for the child context block.
110        let start_update = OperationUpdate::builder()
111            .id(op_id.clone())
112            .r#type(OperationType::Context)
113            .action(OperationAction::Start)
114            .sub_type("Context")
115            .name(name)
116            .build()
117            .map_err(|e| DurableError::checkpoint_failed(name, e))?;
118
119        let start_response = self
120            .backend()
121            .checkpoint(
122                self.arn(),
123                self.checkpoint_token(),
124                vec![start_update],
125                None,
126            )
127            .await?;
128
129        let new_token = start_response.checkpoint_token().ok_or_else(|| {
130            DurableError::checkpoint_failed(
131                name,
132                std::io::Error::new(
133                    std::io::ErrorKind::InvalidData,
134                    "checkpoint response missing checkpoint_token",
135                ),
136            )
137        })?;
138        self.set_checkpoint_token(new_token.to_string());
139
140        if let Some(new_state) = start_response.new_execution_state() {
141            for op in new_state.operations() {
142                self.replay_engine_mut()
143                    .insert_operation(op.id().to_string(), op.clone());
144            }
145        }
146
147        // Create child context with isolated namespace.
148        let child_ctx = self.create_child_context(&op_id);
149
150        // Execute closure inline (no tokio::spawn).
151        let result = f(child_ctx).await?;
152
153        // Send Context/SUCCEED with serialized result as payload.
154        let serialized_result = serde_json::to_string(&result)
155            .map_err(|e| DurableError::serialization(std::any::type_name::<T>(), e))?;
156
157        let ctx_opts = aws_sdk_lambda::types::ContextOptions::builder()
158            .replay_children(false)
159            .build();
160
161        let succeed_update = OperationUpdate::builder()
162            .id(op_id.clone())
163            .r#type(OperationType::Context)
164            .action(OperationAction::Succeed)
165            .sub_type("Context")
166            .payload(serialized_result)
167            .context_options(ctx_opts)
168            .build()
169            .map_err(|e| DurableError::checkpoint_failed(name, e))?;
170
171        let succeed_response = self
172            .backend()
173            .checkpoint(
174                self.arn(),
175                self.checkpoint_token(),
176                vec![succeed_update],
177                None,
178            )
179            .await?;
180
181        let new_token = succeed_response.checkpoint_token().ok_or_else(|| {
182            DurableError::checkpoint_failed(
183                name,
184                std::io::Error::new(
185                    std::io::ErrorKind::InvalidData,
186                    "checkpoint response missing checkpoint_token",
187                ),
188            )
189        })?;
190        self.set_checkpoint_token(new_token.to_string());
191
192        if let Some(new_state) = succeed_response.new_execution_state() {
193            for op in new_state.operations() {
194                self.replay_engine_mut()
195                    .insert_operation(op.id().to_string(), op.clone());
196            }
197        }
198
199        self.replay_engine_mut().track_replay(&op_id);
200        Ok(result)
201    }
202}
203
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ContextDetails, ErrorObject, Operation, OperationAction, OperationStatus, OperationType,
        OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// One recorded `checkpoint` call: the ARN and token it was sent with, plus
    /// the batch of updates.
    #[derive(Debug, Clone)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Shared log of every checkpoint call a mock backend received.
    type CallLog = Arc<Mutex<Vec<CheckpointCall>>>;

    /// MockBackend that records all checkpoint calls and always answers with
    /// checkpoint token `"mock-token"`.
    struct ChildContextMockBackend {
        calls: CallLog,
    }

    impl ChildContextMockBackend {
        /// Create a mock backend together with a handle to its call log.
        fn new() -> (Self, CallLog) {
            let calls: CallLog = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: Arc::clone(&calls),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for ChildContextMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token("mock-token")
                .build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// The operation ID produced by a fresh root-level ID generator.
    fn first_op_id() -> String {
        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
        gen.next_id()
    }

    /// Build a root context (ARN `"arn:test"`, token `"tok"`) over a fresh mock
    /// backend, seeded with `operations` for replay. Returns the context and
    /// the backend's call log.
    async fn new_ctx(operations: Vec<Operation>) -> (DurableContext, CallLog) {
        let (backend, calls) = ChildContextMockBackend::new();
        let ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            operations,
            None,
        )
        .await
        .unwrap();
        (ctx, calls)
    }

    // ─── child_context tests ────────────────────────────────────────────

    #[tokio::test]
    async fn test_child_context_executes_closure() {
        let (mut ctx, calls) = new_ctx(vec![]).await;

        let result: i32 = ctx
            .child_context("sub_workflow", |mut child_ctx| async move {
                let r: Result<i32, String> =
                    child_ctx.step("inner_step", || async { Ok(42) }).await?;
                Ok(r.unwrap())
            })
            .await
            .unwrap();

        assert_eq!(result, 42);

        // Verify checkpoints: Context/START + inner step (START+SUCCEED) + Context/SUCCEED
        let captured = calls.lock().await;
        assert!(
            captured.len() >= 2,
            "should have at least Context/START and Context/SUCCEED, got {}",
            captured.len()
        );

        // First: Context/START with sub_type "Context"
        let first = &captured[0].updates[0];
        assert_eq!(first.r#type(), &OperationType::Context);
        assert_eq!(first.action(), &OperationAction::Start);
        assert_eq!(first.sub_type(), Some("Context"));

        // Last: Context/SUCCEED with sub_type "Context" and payload
        let last = &captured[captured.len() - 1].updates[0];
        assert_eq!(last.r#type(), &OperationType::Context);
        assert_eq!(last.action(), &OperationAction::Succeed);
        assert_eq!(last.sub_type(), Some("Context"));
        assert!(
            last.payload().is_some(),
            "should have serialized result payload"
        );
    }

    #[tokio::test]
    async fn test_child_context_replays_from_cached_result() {
        let op_id = first_op_id();

        // A SUCCEEDED child context operation with a cached result.
        let child_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Context)
            .status(OperationStatus::Succeeded)
            .start_timestamp(DateTime::from_secs(0))
            .context_details(
                ContextDetails::builder()
                    .replay_children(false)
                    .result("42")
                    .build(),
            )
            .build()
            .unwrap();

        let (mut ctx, calls) = new_ctx(vec![child_op]).await;

        // Closure should NOT execute during replay
        let result: i32 = ctx
            .child_context("sub_workflow", |_child_ctx| async move {
                panic!("closure should not execute during replay")
            })
            .await
            .unwrap();

        assert_eq!(result, 42);

        // No checkpoints during replay
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_child_context_has_isolated_namespace() {
        let (mut ctx, _calls) = new_ctx(vec![]).await;

        // Parent step with name "work"
        let parent_result: Result<String, String> = ctx
            .step("work", || async { Ok("parent".to_string()) })
            .await
            .unwrap();
        assert_eq!(parent_result.unwrap(), "parent");

        // Child context with a step also named "work" — must NOT collide
        let child_result: String = ctx
            .child_context("sub_workflow", |mut child_ctx| async move {
                let r: Result<String, String> = child_ctx
                    .step("work", || async { Ok("child".to_string()) })
                    .await?;
                Ok(r.unwrap())
            })
            .await
            .unwrap();

        assert_eq!(child_result, "child");
    }

    #[tokio::test]
    async fn test_child_context_sends_correct_checkpoint_sequence() {
        let (mut ctx, calls) = new_ctx(vec![]).await;

        let _result: i32 = ctx
            .child_context("seq_test", |_child_ctx| async move { Ok(99) })
            .await
            .unwrap();

        let captured = calls.lock().await;

        // Expected: Context/START + Context/SUCCEED (closure does no durable ops)
        assert_eq!(
            captured.len(),
            2,
            "expected exactly 2 checkpoints (START + SUCCEED), got {}",
            captured.len()
        );

        // First: Context/START with sub_type "Context"
        let start = &captured[0].updates[0];
        assert_eq!(start.r#type(), &OperationType::Context);
        assert_eq!(start.action(), &OperationAction::Start);
        assert_eq!(start.sub_type(), Some("Context"));
        assert_eq!(start.name(), Some("seq_test"));

        // Second: Context/SUCCEED with sub_type "Context"
        let succeed = &captured[1].updates[0];
        assert_eq!(succeed.r#type(), &OperationType::Context);
        assert_eq!(succeed.action(), &OperationAction::Succeed);
        assert_eq!(succeed.sub_type(), Some("Context"));
        assert_eq!(succeed.payload(), Some("99"));

        // Both calls target the execution ARN, and SUCCEED is sent with the
        // token returned by the START checkpoint.
        assert_eq!(captured[0].arn, "arn:test");
        assert_eq!(captured[1].arn, "arn:test");
        assert_eq!(captured[1].checkpoint_token, "mock-token");
    }

    #[tokio::test]
    async fn test_child_context_closure_failure_propagates() {
        let (mut ctx, calls) = new_ctx(vec![]).await;

        let result = ctx
            .child_context("failing_sub", |_child_ctx| async move {
                Err::<i32, _>(DurableError::child_context_failed(
                    "failing_sub",
                    "intentional failure",
                ))
            })
            .await;

        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("intentional failure"),
            "error should contain failure message, got: {msg}"
        );

        // Only Context/START is recorded; the failing closure prevents SUCCEED.
        let captured = calls.lock().await;
        assert_eq!(
            captured.len(),
            1,
            "only Context/START should be checkpointed when the closure fails"
        );
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
    }

    #[tokio::test]
    async fn test_child_context_nested() {
        let (mut ctx, calls) = new_ctx(vec![]).await;

        let result: i32 = ctx
            .child_context("outer", |mut outer_child| async move {
                let inner_result: i32 = outer_child
                    .child_context("inner", |mut inner_child| async move {
                        let r: Result<i32, String> =
                            inner_child.step("deep_step", || async { Ok(7) }).await?;
                        Ok(r.unwrap())
                    })
                    .await?;
                Ok(inner_result * 6)
            })
            .await
            .unwrap();

        assert_eq!(result, 42);

        // Nested checkpoint structure:
        // outer START, inner START, step START+SUCCEED, inner SUCCEED, outer SUCCEED
        let captured = calls.lock().await;
        assert!(
            captured.len() >= 4,
            "expected at least 4 checkpoints for nested child contexts, got {}",
            captured.len()
        );

        // First: outer Context/START
        assert_eq!(captured[0].updates[0].sub_type(), Some("Context"));
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);

        // Last: outer Context/SUCCEED
        let last = &captured[captured.len() - 1].updates[0];
        assert_eq!(last.sub_type(), Some("Context"));
        assert_eq!(last.action(), &OperationAction::Succeed);
    }

    #[tokio::test]
    async fn test_child_context_replay_failed_status() {
        let op_id = first_op_id();

        // A FAILED child context operation with a recorded error.
        let child_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Context)
            .status(OperationStatus::Failed)
            .start_timestamp(DateTime::from_secs(0))
            .context_details(
                ContextDetails::builder()
                    .replay_children(false)
                    .error(
                        ErrorObject::builder()
                            .error_type("RuntimeError")
                            .error_data("something went wrong")
                            .build(),
                    )
                    .build(),
            )
            .build()
            .unwrap();

        let (mut ctx, calls) = new_ctx(vec![child_op]).await;

        let result: Result<i32, DurableError> = ctx
            .child_context("sub_workflow", |_child_ctx| async move {
                panic!("closure should not execute during replay of failed context")
            })
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("child context failed"),
            "error should mention child context failed, got: {err}"
        );
        assert!(
            err.contains("RuntimeError"),
            "error should contain error type, got: {err}"
        );
        assert!(
            err.contains("something went wrong"),
            "error should contain error data, got: {err}"
        );

        // No checkpoints during replay
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0);
    }

    // ─── span tests (FEAT-17, FEAT-18) ────────────────────────────────────

    #[traced_test]
    #[tokio::test]
    async fn test_child_context_emits_span() {
        let (mut ctx, _calls) = new_ctx(vec![]).await;
        let _ = ctx
            .child_context("sub", |_child| async move { Ok::<i32, DurableError>(1) })
            .await;
        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("sub"));
        assert!(logs_contain("child_context"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_child_context_span_hierarchy() {
        let (mut ctx, _calls) = new_ctx(vec![]).await;
        let _ = ctx
            .child_context("parent_flow", |mut child| async move {
                let _: Result<i32, String> = child.step("inner_step", || async { Ok(42) }).await?;
                Ok::<_, DurableError>(1)
            })
            .await;
        assert!(logs_contain("child_context"));
        assert!(logs_contain("parent_flow"));
        assert!(logs_contain("inner_step"));
        assert!(logs_contain("step"));
    }
}