1use std::future::Future;
12
13use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
14use serde::de::DeserializeOwned;
15use serde::Serialize;
16
17use crate::context::DurableContext;
18use crate::error::DurableError;
19
20impl DurableContext {
21 #[allow(clippy::await_holding_lock)]
54 pub async fn child_context<T, F, Fut>(&mut self, name: &str, f: F) -> Result<T, DurableError>
55 where
56 T: Serialize + DeserializeOwned + Send,
57 F: FnOnce(DurableContext) -> Fut + Send,
58 Fut: Future<Output = Result<T, DurableError>> + Send,
59 {
60 let op_id = self.replay_engine_mut().generate_operation_id();
61
62 let span = tracing::info_span!(
63 "durable_operation",
64 op.name = name,
65 op.type = "child_context",
66 op.id = %op_id,
67 );
68 let _guard = span.enter();
69 tracing::trace!("durable_operation");
70
71 if let Some(op) = self.replay_engine().check_result(&op_id) {
73 if op.status == OperationStatus::Succeeded {
74 let result_str =
75 op.context_details()
76 .and_then(|d| d.result())
77 .ok_or_else(|| {
78 DurableError::checkpoint_failed(
79 name,
80 std::io::Error::new(
81 std::io::ErrorKind::InvalidData,
82 "child context succeeded but no result in context_details",
83 ),
84 )
85 })?;
86
87 let result: T = serde_json::from_str(result_str)
88 .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))?;
89
90 self.replay_engine_mut().track_replay(&op_id);
91 return Ok(result);
92 } else {
93 let error_message = op
95 .context_details()
96 .and_then(|d| d.error())
97 .map(|e| {
98 format!(
99 "{}: {}",
100 e.error_type().unwrap_or("Unknown"),
101 e.error_data().unwrap_or("")
102 )
103 })
104 .unwrap_or_else(|| "child context failed".to_string());
105 return Err(DurableError::child_context_failed(name, error_message));
106 }
107 }
108
109 let start_update = OperationUpdate::builder()
111 .id(op_id.clone())
112 .r#type(OperationType::Context)
113 .action(OperationAction::Start)
114 .sub_type("Context")
115 .name(name)
116 .build()
117 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
118
119 let start_response = self
120 .backend()
121 .checkpoint(
122 self.arn(),
123 self.checkpoint_token(),
124 vec![start_update],
125 None,
126 )
127 .await?;
128
129 let new_token = start_response.checkpoint_token().ok_or_else(|| {
130 DurableError::checkpoint_failed(
131 name,
132 std::io::Error::new(
133 std::io::ErrorKind::InvalidData,
134 "checkpoint response missing checkpoint_token",
135 ),
136 )
137 })?;
138 self.set_checkpoint_token(new_token.to_string());
139
140 if let Some(new_state) = start_response.new_execution_state() {
141 for op in new_state.operations() {
142 self.replay_engine_mut()
143 .insert_operation(op.id().to_string(), op.clone());
144 }
145 }
146
147 let child_ctx = self.create_child_context(&op_id);
149
150 let result = f(child_ctx).await?;
152
153 let serialized_result = serde_json::to_string(&result)
155 .map_err(|e| DurableError::serialization(std::any::type_name::<T>(), e))?;
156
157 let ctx_opts = aws_sdk_lambda::types::ContextOptions::builder()
158 .replay_children(false)
159 .build();
160
161 let succeed_update = OperationUpdate::builder()
162 .id(op_id.clone())
163 .r#type(OperationType::Context)
164 .action(OperationAction::Succeed)
165 .sub_type("Context")
166 .payload(serialized_result)
167 .context_options(ctx_opts)
168 .build()
169 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
170
171 let succeed_response = self
172 .backend()
173 .checkpoint(
174 self.arn(),
175 self.checkpoint_token(),
176 vec![succeed_update],
177 None,
178 )
179 .await?;
180
181 let new_token = succeed_response.checkpoint_token().ok_or_else(|| {
182 DurableError::checkpoint_failed(
183 name,
184 std::io::Error::new(
185 std::io::ErrorKind::InvalidData,
186 "checkpoint response missing checkpoint_token",
187 ),
188 )
189 })?;
190 self.set_checkpoint_token(new_token.to_string());
191
192 if let Some(new_state) = succeed_response.new_execution_state() {
193 for op in new_state.operations() {
194 self.replay_engine_mut()
195 .insert_operation(op.id().to_string(), op.clone());
196 }
197 }
198
199 self.replay_engine_mut().track_replay(&op_id);
200 Ok(result)
201 }
202}
203
#[cfg(test)]
mod tests {
    //! Unit tests for `DurableContext::child_context`. They are driven by an
    //! in-memory backend that records every checkpoint call it receives.

    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ContextDetails, ErrorObject, Operation, OperationAction, OperationStatus, OperationType,
        OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// Token returned in every `ChildContextMockBackend::checkpoint` response.
    const MOCK_CHECKPOINT_TOKEN: &str = "mock-token";

    /// One recorded call to `DurableBackend::checkpoint`.
    #[derive(Debug, Clone)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Backend stub. It records each checkpoint call and answers with
    /// `MOCK_CHECKPOINT_TOKEN`. Execution-state queries return a default-built
    /// output.
    struct ChildContextMockBackend {
        calls: Arc<Mutex<Vec<CheckpointCall>>>,
    }

    impl ChildContextMockBackend {
        /// Creates the backend plus a shared handle to its recorded calls.
        fn new() -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: calls.clone(),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for ChildContextMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token(MOCK_CHECKPOINT_TOKEN)
                .build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// Returns the first id produced by a fresh root `OperationIdGenerator`.
    /// The replay tests assume this matches the id a new root context assigns
    /// to its first operation.
    fn first_op_id() -> String {
        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
        gen.next_id()
    }

    #[tokio::test]
    async fn test_child_context_executes_closure() {
        let (backend, calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let result: i32 = ctx
            .child_context("sub_workflow", |mut child_ctx| async move {
                let r: Result<i32, String> =
                    child_ctx.step("inner_step", || async { Ok(42) }).await?;
                Ok(r.unwrap())
            })
            .await
            .unwrap();

        assert_eq!(result, 42);

        let captured = calls.lock().await;
        // The inner step may add its own checkpoints between START and SUCCEED.
        assert!(
            captured.len() >= 2,
            "should have at least Context/START and Context/SUCCEED, got {}",
            captured.len()
        );

        assert_eq!(captured[0].updates[0].r#type(), &OperationType::Context);
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
        assert_eq!(captured[0].updates[0].sub_type(), Some("Context"));

        let last = &captured[captured.len() - 1];
        assert_eq!(last.updates[0].r#type(), &OperationType::Context);
        assert_eq!(last.updates[0].action(), &OperationAction::Succeed);
        assert_eq!(last.updates[0].sub_type(), Some("Context"));
        assert!(
            last.updates[0].payload().is_some(),
            "should have serialized result payload"
        );
    }

    #[tokio::test]
    async fn test_child_context_replays_from_cached_result() {
        let op_id = first_op_id();

        // A completed child context whose JSON result is cached in context_details.
        let child_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Context)
            .status(OperationStatus::Succeeded)
            .start_timestamp(DateTime::from_secs(0))
            .context_details(
                ContextDetails::builder()
                    .replay_children(false)
                    .result("42")
                    .build(),
            )
            .build()
            .unwrap();

        let (backend, calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![child_op],
            None,
        )
        .await
        .unwrap();

        let result: i32 = ctx
            .child_context("sub_workflow", |_child_ctx| async move {
                panic!("closure should not execute during replay")
            })
            .await
            .unwrap();

        assert_eq!(result, 42);

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_child_context_has_isolated_namespace() {
        let (backend, _calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        // The parent and the child both run a step named "work".
        let parent_result: Result<String, String> = ctx
            .step("work", || async { Ok("parent".to_string()) })
            .await
            .unwrap();
        assert_eq!(parent_result.unwrap(), "parent");

        let child_result: String = ctx
            .child_context("sub_workflow", |mut child_ctx| async move {
                let r: Result<String, String> = child_ctx
                    .step("work", || async { Ok("child".to_string()) })
                    .await?;
                Ok(r.unwrap())
            })
            .await
            .unwrap();

        assert_eq!(child_result, "child");
    }

    #[tokio::test]
    async fn test_child_context_sends_correct_checkpoint_sequence() {
        let (backend, calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let _result: i32 = ctx
            .child_context("seq_test", |_child_ctx| async move { Ok(99) })
            .await
            .unwrap();

        let captured = calls.lock().await;

        assert_eq!(
            captured.len(),
            2,
            "expected exactly 2 checkpoints (START + SUCCEED), got {}",
            captured.len()
        );

        assert_eq!(captured[0].updates[0].r#type(), &OperationType::Context);
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
        assert_eq!(captured[0].updates[0].sub_type(), Some("Context"));
        assert_eq!(captured[0].updates[0].name(), Some("seq_test"));

        assert_eq!(captured[1].updates[0].r#type(), &OperationType::Context);
        assert_eq!(captured[1].updates[0].action(), &OperationAction::Succeed);
        assert_eq!(captured[1].updates[0].sub_type(), Some("Context"));
        assert_eq!(captured[1].updates[0].payload(), Some("99"));

        // Both checkpoints target the execution the context was created for.
        assert!(
            captured.iter().all(|call| call.arn == "arn:test"),
            "every checkpoint should target arn:test"
        );
        // START uses the initial token. SUCCEED must use the token returned by
        // the START response, so the token is threaded between checkpoints.
        assert_eq!(captured[0].checkpoint_token, "tok");
        assert_eq!(captured[1].checkpoint_token, MOCK_CHECKPOINT_TOKEN);
    }

    #[tokio::test]
    async fn test_child_context_closure_failure_propagates() {
        let (backend, calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let result = ctx
            .child_context("failing_sub", |_child_ctx| async move {
                Err::<i32, _>(DurableError::child_context_failed(
                    "failing_sub",
                    "intentional failure",
                ))
            })
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("intentional failure"),
            "error should contain failure message, got: {msg}"
        );

        let captured = calls.lock().await;
        let first = captured
            .first()
            .expect("START should be checkpointed before the closure runs");
        assert_eq!(first.updates[0].action(), &OperationAction::Start);
        assert!(
            captured
                .iter()
                .flat_map(|call| call.updates.iter())
                .all(|update| update.action() != &OperationAction::Succeed),
            "a failed child context must never be checkpointed as succeeded"
        );
    }

    #[tokio::test]
    async fn test_child_context_nested() {
        let (backend, calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let result: i32 = ctx
            .child_context("outer", |mut outer_child| async move {
                let inner_result: i32 = outer_child
                    .child_context("inner", |mut inner_child| async move {
                        let r: Result<i32, String> =
                            inner_child.step("deep_step", || async { Ok(7) }).await?;
                        Ok(r.unwrap())
                    })
                    .await?;
                Ok(inner_result * 6)
            })
            .await
            .unwrap();

        assert_eq!(result, 42);

        let captured = calls.lock().await;
        // Two START/SUCCEED pairs (outer and inner), plus whatever the step sends.
        assert!(
            captured.len() >= 4,
            "expected at least 4 checkpoints for nested child contexts, got {}",
            captured.len()
        );

        // The outer context opens first and closes last.
        assert_eq!(captured[0].updates[0].sub_type(), Some("Context"));
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);

        let last = &captured[captured.len() - 1];
        assert_eq!(last.updates[0].sub_type(), Some("Context"));
        assert_eq!(last.updates[0].action(), &OperationAction::Succeed);
    }

    #[tokio::test]
    async fn test_child_context_replay_failed_status() {
        let op_id = first_op_id();

        // A completed child context that recorded an error.
        let child_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Context)
            .status(OperationStatus::Failed)
            .start_timestamp(DateTime::from_secs(0))
            .context_details(
                ContextDetails::builder()
                    .replay_children(false)
                    .error(
                        ErrorObject::builder()
                            .error_type("RuntimeError")
                            .error_data("something went wrong")
                            .build(),
                    )
                    .build(),
            )
            .build()
            .unwrap();

        let (backend, calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![child_op],
            None,
        )
        .await
        .unwrap();

        let result: Result<i32, DurableError> = ctx
            .child_context("sub_workflow", |_child_ctx| async move {
                panic!("closure should not execute during replay of failed context")
            })
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("child context failed"),
            "error should mention child context failed, got: {err}"
        );
        assert!(
            err.contains("RuntimeError"),
            "error should contain error type, got: {err}"
        );
        assert!(
            err.contains("something went wrong"),
            "error should contain error data, got: {err}"
        );

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0);
    }

    #[traced_test]
    #[tokio::test]
    async fn test_child_context_emits_span() {
        let (backend, _calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();
        let _ = ctx
            .child_context("sub", |_child| async move { Ok::<i32, DurableError>(1) })
            .await;
        // The span name and its op.name / op.type fields should all be captured.
        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("sub"));
        assert!(logs_contain("child_context"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_child_context_span_hierarchy() {
        let (backend, _calls) = ChildContextMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();
        let _ = ctx
            .child_context("parent_flow", |mut child| async move {
                let _: Result<i32, String> = child.step("inner_step", || async { Ok(42) }).await?;
                Ok::<_, DurableError>(1)
            })
            .await;
        assert!(logs_contain("child_context"));
        assert!(logs_contain("parent_flow"));
        assert!(logs_contain("inner_step"));
        assert!(logs_contain("step"));
    }
}