1use std::future::Future;
12
13use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
14use serde::de::DeserializeOwned;
15use serde::Serialize;
16
17use crate::context::DurableContext;
18use crate::error::DurableError;
19use crate::types::{BatchItem, BatchItemStatus, BatchResult, CompletionReason, ParallelOptions};
20
impl DurableContext {
    /// Runs every closure in `branches` concurrently, each inside its own child
    /// durable context, and returns a [`BatchResult`] with one [`BatchItem`] per
    /// branch, in input order.
    ///
    /// Protocol, as implemented below:
    /// 1. If the replay engine already holds an operation for this id, the stored
    ///    result is deserialized and returned without re-executing any branch.
    /// 2. Otherwise an outer `Context`/`Parallel` START is checkpointed, each
    ///    branch is spawned as a tokio task, every branch outcome (including
    ///    `Err`s) is folded into a `BatchItem`, and a SUCCEED carrying the
    ///    JSON-serialized `BatchResult` is checkpointed.
    ///
    /// `_options` is currently unused — presumably reserved for concurrency
    /// limits / completion policies; TODO confirm before relying on it.
    ///
    /// # Errors
    /// Checkpoint round-trips failing, responses missing a token or result,
    /// (de)serialization failures, and a branch task *panicking* all surface as
    /// `DurableError`. A branch returning `Err` does NOT fail the call; it is
    /// recorded as a `Failed` batch item instead.
    // NOTE(review): the allow implies a std-style lock guard is held across an
    // `.await` somewhere below (presumably inside the replay-engine accessors) —
    // confirm this cannot deadlock under a multi-threaded tokio runtime.
    #[allow(clippy::await_holding_lock)]
    pub async fn parallel<T, F, Fut>(
        &mut self,
        name: &str,
        branches: Vec<F>,
        _options: ParallelOptions,
    ) -> Result<BatchResult<T>, DurableError>
    where
        T: Serialize + DeserializeOwned + Send + 'static,
        F: FnOnce(DurableContext) -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
    {
        // Deterministic operation id: must be identical across replays so the
        // cached-result lookup below can find the prior execution.
        let op_id = self.replay_engine_mut().generate_operation_id();

        let span = tracing::info_span!(
            "durable_operation",
            op.name = name,
            op.type = "parallel",
            op.id = %op_id,
        );
        // NOTE(review): this guard stays entered across every `.await` below;
        // tracing recommends `Instrument`/`in_scope` in async code because an
        // entered guard can misattribute events when the task is moved between
        // polls — confirm this is intentional.
        let _guard = span.enter();
        tracing::trace!("durable_operation");

        // --- Replay path: a prior execution already recorded this operation. ---
        if let Some(op) = self.replay_engine().check_result(&op_id) {
            if op.status == OperationStatus::Succeeded {
                // The SUCCEED checkpoint stored the BatchResult JSON in
                // context_details.result; its absence is a protocol violation.
                let result_str =
                    op.context_details()
                        .and_then(|d| d.result())
                        .ok_or_else(|| {
                            DurableError::checkpoint_failed(
                                name,
                                std::io::Error::new(
                                    std::io::ErrorKind::InvalidData,
                                    "parallel succeeded but no result in context_details",
                                ),
                            )
                        })?;

                let batch_result: BatchResult<T> = serde_json::from_str(result_str)
                    .map_err(|e| DurableError::deserialization("BatchResult", e))?;

                self.replay_engine_mut().track_replay(&op_id);
                return Ok(batch_result);
            } else {
                // Any cached non-Succeeded status is reported as a failure,
                // reconstructing "type: data" from the stored error details.
                // NOTE(review): this also covers in-progress statuses, if
                // `check_result` can yield them — confirm it only returns
                // terminal operations.
                let error_message = op
                    .context_details()
                    .and_then(|d| d.error())
                    .map(|e| {
                        format!(
                            "{}: {}",
                            e.error_type().unwrap_or("Unknown"),
                            e.error_data().unwrap_or("")
                        )
                    })
                    .unwrap_or_else(|| "parallel failed".to_string());
                return Err(DurableError::parallel_failed(name, error_message));
            }
        }

        // --- Live path: checkpoint the outer Parallel context START. ---
        let outer_start = OperationUpdate::builder()
            .id(op_id.clone())
            .r#type(OperationType::Context)
            .action(OperationAction::Start)
            .sub_type("Parallel")
            .name(name)
            .build()
            .map_err(|e| DurableError::checkpoint_failed(name, e))?;

        let start_response = self
            .backend()
            .checkpoint(self.arn(), self.checkpoint_token(), vec![outer_start], None)
            .await?;

        // Each checkpoint response carries the next token; losing it would
        // strand the execution, so treat absence as invalid data.
        let new_token = start_response.checkpoint_token().ok_or_else(|| {
            DurableError::checkpoint_failed(
                name,
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "checkpoint response missing checkpoint_token",
                ),
            )
        })?;
        self.set_checkpoint_token(new_token.to_string());

        // Merge any server-pushed state into the replay engine.
        if let Some(new_state) = start_response.new_execution_state() {
            for op in new_state.operations() {
                self.replay_engine_mut()
                    .insert_operation(op.id().to_string(), op.clone());
            }
        }

        let branch_count = branches.len();
        let mut handles = Vec::with_capacity(branch_count);

        // Branch ids are generated in a child namespace rooted at `op_id`, so
        // they are deterministic per-branch across replays.
        let mut branch_id_gen = crate::operation_id::OperationIdGenerator::new(Some(op_id.clone()));

        for (i, branch_fn) in branches.into_iter().enumerate() {
            let branch_op_id = branch_id_gen.next_id();

            let child_ctx = self.create_child_context(&branch_op_id);
            // Each branch gets a snapshot of the CURRENT parent token; branch
            // checkpoints never advance the parent's token (their responses are
            // discarded in `execute_branch`).
            // NOTE(review): all branches therefore reuse the same token
            // concurrently — confirm the backend permits that.
            let config = BranchConfig {
                backend: self.backend().clone(),
                arn: self.arn().to_string(),
                token: self.checkpoint_token().to_string(),
                branch_op_id,
                parent_op_id: op_id.clone(),
                branch_name: format!("parallel-branch-{i}"),
            };

            let handle =
                tokio::spawn(async move { execute_branch(child_ctx, config, branch_fn).await });

            handles.push(handle);
        }

        // Await branches in spawn order; a JoinError (panic/cancellation)
        // aborts the whole parallel call, whereas a branch `Err` is captured
        // as a Failed item.
        let mut results = Vec::with_capacity(branch_count);
        for (i, handle) in handles.into_iter().enumerate() {
            let branch_outcome = handle.await.map_err(|e| {
                DurableError::parallel_failed(name, format!("branch {i} panicked: {e}"))
            })?;

            match branch_outcome {
                Ok(value) => {
                    results.push(BatchItem {
                        index: i,
                        status: BatchItemStatus::Succeeded,
                        result: Some(value),
                        error: None,
                    });
                }
                Err(err) => {
                    results.push(BatchItem {
                        index: i,
                        status: BatchItemStatus::Failed,
                        result: None,
                        error: Some(err.to_string()),
                    });
                }
            }
        }

        // All branches are always awaited, so the reason is unconditionally
        // AllCompleted (no early-termination policy is implemented here).
        let batch_result = BatchResult {
            results,
            completion_reason: CompletionReason::AllCompleted,
        };

        let serialized_result = serde_json::to_string(&batch_result)
            .map_err(|e| DurableError::serialization("BatchResult", e))?;

        // replay_children(false): on replay the cached BatchResult above is
        // used wholesale instead of re-driving the child operations.
        let ctx_opts = aws_sdk_lambda::types::ContextOptions::builder()
            .replay_children(false)
            .build();

        let outer_succeed = OperationUpdate::builder()
            .id(op_id.clone())
            .r#type(OperationType::Context)
            .action(OperationAction::Succeed)
            .sub_type("Parallel")
            .payload(serialized_result)
            .context_options(ctx_opts)
            .build()
            .map_err(|e| DurableError::checkpoint_failed(name, e))?;

        let succeed_response = self
            .backend()
            .checkpoint(
                self.arn(),
                self.checkpoint_token(),
                vec![outer_succeed],
                None,
            )
            .await?;

        let new_token = succeed_response.checkpoint_token().ok_or_else(|| {
            DurableError::checkpoint_failed(
                name,
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "checkpoint response missing checkpoint_token",
                ),
            )
        })?;
        self.set_checkpoint_token(new_token.to_string());

        if let Some(new_state) = succeed_response.new_execution_state() {
            for op in new_state.operations() {
                self.replay_engine_mut()
                    .insert_operation(op.id().to_string(), op.clone());
            }
        }

        self.replay_engine_mut().track_replay(&op_id);
        Ok(batch_result)
    }
}
274
/// Per-branch snapshot of everything a spawned branch task needs to talk to the
/// backend on its own, without borrowing from the parent `DurableContext`.
struct BranchConfig {
    /// Shared checkpoint backend (cloned `Arc`; each task holds its own handle).
    backend: std::sync::Arc<dyn crate::backend::DurableBackend>,
    /// Durable execution ARN the branch checkpoints are issued against.
    arn: String,
    /// Parent checkpoint token snapshot taken at spawn time.
    /// NOTE(review): every branch reuses this same token concurrently and the
    /// branch never adopts the token from its own checkpoint responses —
    /// confirm the backend tolerates that.
    token: String,
    /// Operation id generated for this branch (child of `parent_op_id`).
    branch_op_id: String,
    /// Operation id of the outer `Parallel` context.
    parent_op_id: String,
    /// Human-readable branch name, `parallel-branch-{i}` in spawn order.
    branch_name: String,
}
284
285async fn execute_branch<T, F, Fut>(
290 child_ctx: DurableContext,
291 config: BranchConfig,
292 branch_fn: F,
293) -> Result<T, DurableError>
294where
295 T: Serialize + Send + 'static,
296 F: FnOnce(DurableContext) -> Fut + Send + 'static,
297 Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
298{
299 let branch_start = OperationUpdate::builder()
301 .id(config.branch_op_id.clone())
302 .r#type(OperationType::Context)
303 .action(OperationAction::Start)
304 .sub_type("ParallelBranch")
305 .name(&config.branch_name)
306 .parent_id(config.parent_op_id.clone())
307 .build()
308 .map_err(|e| DurableError::checkpoint_failed(&config.branch_name, e))?;
309
310 let _ = config
311 .backend
312 .checkpoint(&config.arn, &config.token, vec![branch_start], None)
313 .await?;
314
315 let result = branch_fn(child_ctx).await?;
317
318 let serialized = serde_json::to_string(&result)
320 .map_err(|e| DurableError::serialization(&config.branch_name, e))?;
321
322 let ctx_opts = aws_sdk_lambda::types::ContextOptions::builder()
323 .replay_children(false)
324 .build();
325
326 let branch_succeed = OperationUpdate::builder()
327 .id(config.branch_op_id.clone())
328 .r#type(OperationType::Context)
329 .action(OperationAction::Succeed)
330 .sub_type("ParallelBranch")
331 .name(&config.branch_name)
332 .parent_id(config.parent_op_id.clone())
333 .payload(serialized)
334 .context_options(ctx_opts)
335 .build()
336 .map_err(|e| DurableError::checkpoint_failed(&config.branch_name, e))?;
337
338 let _ = config
339 .backend
340 .checkpoint(&config.arn, &config.token, vec![branch_succeed], None)
341 .await?;
342
343 Ok(result)
344}
345
#[cfg(test)]
mod tests {
    //! Unit tests for `DurableContext::parallel`, driven by a mock backend
    //! that records every checkpoint call instead of talking to AWS.

    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ContextDetails, Operation, OperationAction, OperationStatus, OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;
    use crate::types::ParallelOptions;

    /// One recorded `checkpoint` invocation, captured verbatim for assertions.
    #[derive(Debug, Clone)]
    #[allow(dead_code)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Backend stub: appends each checkpoint call to a shared, lock-guarded log
    /// and always answers with a fixed "mock-token".
    struct ParallelMockBackend {
        calls: Arc<Mutex<Vec<CheckpointCall>>>,
    }

    impl ParallelMockBackend {
        /// Returns the backend plus a second handle to its call log so tests
        /// can inspect what was checkpointed.
        fn new() -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: calls.clone(),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for ParallelMockBackend {
        // Record the call; every response carries the same token and no new
        // execution state.
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token("mock-token")
                .build())
        }

        // Not exercised by these tests; returns an empty state.
        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// The id a fresh context assigns to its first operation — must match what
    /// `parallel` generates so the replay test can pre-seed that operation.
    fn first_op_id() -> String {
        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
        gen.next_id()
    }

    /// Three branches run to completion; results come back in input order with
    /// the outer Parallel START first and SUCCEED (with payload) last.
    #[tokio::test]
    async fn test_parallel_executes_branches_concurrently() {
        let (backend, calls) = ParallelMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let branches: Vec<
            Box<
                dyn FnOnce(
                    DurableContext,
                ) -> std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
                > + Send,
            >,
        > = vec![
            Box::new(|mut ctx: DurableContext| {
                Box::pin(async move {
                    let r: Result<i32, String> = ctx.step("validate", || async { Ok(10) }).await?;
                    Ok(r.unwrap())
                })
            }),
            Box::new(|mut ctx: DurableContext| {
                Box::pin(async move {
                    let r: Result<i32, String> = ctx.step("check", || async { Ok(20) }).await?;
                    Ok(r.unwrap())
                })
            }),
            Box::new(|mut ctx: DurableContext| {
                Box::pin(async move {
                    let r: Result<i32, String> = ctx.step("process", || async { Ok(30) }).await?;
                    Ok(r.unwrap())
                })
            }),
        ];

        let result = ctx
            .parallel("fan_out", branches, ParallelOptions::new())
            .await
            .unwrap();

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.results[0].index, 0);
        assert_eq!(result.results[1].index, 1);
        assert_eq!(result.results[2].index, 2);
        assert_eq!(result.results[0].result, Some(10));
        assert_eq!(result.results[1].result, Some(20));
        assert_eq!(result.results[2].result, Some(30));

        let captured = calls.lock().await;
        assert!(
            captured.len() >= 2,
            "should have at least outer START and outer SUCCEED"
        );

        // First call is the outer Parallel START.
        assert_eq!(captured[0].updates[0].r#type(), &OperationType::Context);
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
        assert_eq!(captured[0].updates[0].sub_type(), Some("Parallel"));

        // Last call is the outer SUCCEED carrying the BatchResult payload.
        let last = &captured[captured.len() - 1];
        assert_eq!(last.updates[0].r#type(), &OperationType::Context);
        assert_eq!(last.updates[0].action(), &OperationAction::Succeed);
        assert_eq!(last.updates[0].sub_type(), Some("Parallel"));
        assert!(
            last.updates[0].payload().is_some(),
            "should have BatchResult payload"
        );
    }

    /// A pre-seeded Succeeded operation makes `parallel` return the cached
    /// BatchResult without running branches or issuing any checkpoints — the
    /// lone branch deliberately panics if executed.
    #[tokio::test]
    async fn test_parallel_replays_from_cached_result() {
        let op_id = first_op_id();

        let batch_json = r#"{"results":[{"index":0,"status":"Succeeded","result":42,"error":null},{"index":1,"status":"Succeeded","result":99,"error":null}],"completion_reason":"AllCompleted"}"#;

        let parallel_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Context)
            .status(OperationStatus::Succeeded)
            .start_timestamp(DateTime::from_secs(0))
            .context_details(
                ContextDetails::builder()
                    .replay_children(false)
                    .result(batch_json)
                    .build(),
            )
            .build()
            .unwrap();

        let (backend, calls) = ParallelMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![parallel_op],
            None,
        )
        .await
        .unwrap();

        let branches: Vec<
            Box<
                dyn FnOnce(
                    DurableContext,
                ) -> std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
                > + Send,
            >,
        > = vec![Box::new(|_ctx: DurableContext| {
            Box::pin(async move { panic!("branch should not execute during replay") })
        })];

        let result: crate::types::BatchResult<i32> = ctx
            .parallel("fan_out", branches, ParallelOptions::new())
            .await
            .unwrap();

        assert_eq!(result.results.len(), 2);
        assert_eq!(result.results[0].result, Some(42));
        assert_eq!(result.results[1].result, Some(99));

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    /// Two branches both run a step named "work"; isolation of the per-branch
    /// operation namespaces means neither clobbers the other's result.
    #[tokio::test]
    async fn test_parallel_branches_have_isolated_namespaces() {
        let (backend, _calls) = ParallelMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let branches: Vec<
            Box<
                dyn FnOnce(
                    DurableContext,
                ) -> std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<String, DurableError>> + Send>,
                > + Send,
            >,
        > = vec![
            Box::new(|mut ctx: DurableContext| {
                Box::pin(async move {
                    let r: Result<String, String> = ctx
                        .step("work", || async { Ok("branch-0".to_string()) })
                        .await?;
                    Ok(r.unwrap())
                })
            }),
            Box::new(|mut ctx: DurableContext| {
                Box::pin(async move {
                    let r: Result<String, String> = ctx
                        .step("work", || async { Ok("branch-1".to_string()) })
                        .await?;
                    Ok(r.unwrap())
                })
            }),
        ];

        let result = ctx
            .parallel("isolated_test", branches, ParallelOptions::new())
            .await
            .unwrap();

        assert_eq!(result.results.len(), 2);
        assert_eq!(result.results[0].result.as_deref(), Some("branch-0"));
        assert_eq!(result.results[1].result.as_deref(), Some("branch-1"));
    }

    /// Checkpoint sequence for 2 trivial branches: outer START, 2 branch
    /// STARTs + 2 branch SUCCEEDs (order among branches unspecified), outer
    /// SUCCEED — six calls minimum, outer START first and outer SUCCEED last.
    #[tokio::test]
    async fn test_parallel_sends_correct_checkpoint_sequence() {
        let (backend, calls) = ParallelMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let branches: Vec<
            Box<
                dyn FnOnce(
                    DurableContext,
                ) -> std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
                > + Send,
            >,
        > = vec![
            Box::new(|_ctx: DurableContext| Box::pin(async move { Ok(1) })),
            Box::new(|_ctx: DurableContext| Box::pin(async move { Ok(2) })),
        ];

        let _ = ctx
            .parallel("seq_test", branches, ParallelOptions::new())
            .await
            .unwrap();

        let captured = calls.lock().await;

        assert!(
            captured.len() >= 6,
            "expected at least 6 checkpoints, got {}",
            captured.len()
        );

        assert_eq!(captured[0].updates[0].sub_type(), Some("Parallel"));
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);

        let last_idx = captured.len() - 1;
        assert_eq!(captured[last_idx].updates[0].sub_type(), Some("Parallel"));
        assert_eq!(
            captured[last_idx].updates[0].action(),
            &OperationAction::Succeed
        );

        // Everything between first and last that is branch-related: exactly
        // 2 STARTs and 2 SUCCEEDs.
        let branch_checkpoints: Vec<_> = captured[1..last_idx]
            .iter()
            .filter(|c| c.updates[0].sub_type() == Some("ParallelBranch"))
            .collect();
        assert_eq!(
            branch_checkpoints.len(),
            4,
            "expected 4 branch checkpoints (2 START + 2 SUCCEED)"
        );
    }

    /// One branch returns Err: the call still succeeds overall, with the
    /// failing branch recorded as a Failed item carrying the error message.
    #[tokio::test]
    async fn test_parallel_branch_failure_is_captured() {
        let (backend, _calls) = ParallelMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let branches: Vec<
            Box<
                dyn FnOnce(
                    DurableContext,
                ) -> std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
                > + Send,
            >,
        > = vec![
            Box::new(|_ctx: DurableContext| Box::pin(async move { Ok(42) })),
            Box::new(|_ctx: DurableContext| {
                Box::pin(async move {
                    Err(DurableError::parallel_failed(
                        "branch",
                        "intentional failure",
                    ))
                })
            }),
        ];

        let result = ctx
            .parallel("fail_test", branches, ParallelOptions::new())
            .await
            .unwrap();

        assert_eq!(result.results.len(), 2);
        assert_eq!(
            result.results[0].status,
            crate::types::BatchItemStatus::Succeeded
        );
        assert_eq!(result.results[0].result, Some(42));
        assert_eq!(
            result.results[1].status,
            crate::types::BatchItemStatus::Failed
        );
        assert!(result.results[1].error.is_some());
        assert!(result.results[1]
            .error
            .as_ref()
            .unwrap()
            .contains("intentional failure"));
    }

    /// The tracing span is emitted with the operation name and type, even with
    /// zero branches.
    #[traced_test]
    #[tokio::test]
    async fn test_parallel_emits_span() {
        let (backend, _calls) = ParallelMockBackend::new();
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();
        type BranchFn = Box<
            dyn FnOnce(
                    DurableContext,
                ) -> std::pin::Pin<
                    Box<
                        dyn std::future::Future<Output = Result<i32, crate::error::DurableError>>
                            + Send,
                    >,
                > + Send,
        >;
        let branches: Vec<BranchFn> = vec![];
        let _ = ctx
            .parallel("batch", branches, ParallelOptions::new())
            .await;
        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("batch"));
        assert!(logs_contain("parallel"));
    }
}