1use std::future::Future;
12
13use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
14use serde::de::DeserializeOwned;
15use serde::Serialize;
16
17use crate::context::DurableContext;
18use crate::error::DurableError;
19use crate::types::{BatchItem, BatchItemStatus, BatchResult, CompletionReason, MapOptions};
20
21impl DurableContext {
22 #[allow(clippy::await_holding_lock)]
71 pub async fn map<T, I, F, Fut>(
72 &mut self,
73 name: &str,
74 items: Vec<I>,
75 options: MapOptions,
76 f: F,
77 ) -> Result<BatchResult<T>, DurableError>
78 where
79 T: Serialize + DeserializeOwned + Send + 'static,
80 I: Send + 'static,
81 F: FnOnce(I, DurableContext) -> Fut + Send + 'static + Clone,
82 Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
83 {
84 let op_id = self.replay_engine_mut().generate_operation_id();
85
86 let span = tracing::info_span!(
87 "durable_operation",
88 op.name = name,
89 op.type = "map",
90 op.id = %op_id,
91 );
92 let _guard = span.enter();
93 tracing::trace!("durable_operation");
94
95 if let Some(op) = self.replay_engine().check_result(&op_id) {
97 if op.status == OperationStatus::Succeeded {
98 let result_str =
99 op.context_details()
100 .and_then(|d| d.result())
101 .ok_or_else(|| {
102 DurableError::checkpoint_failed(
103 name,
104 std::io::Error::new(
105 std::io::ErrorKind::InvalidData,
106 "map succeeded but no result in context_details",
107 ),
108 )
109 })?;
110
111 let batch_result: BatchResult<T> = serde_json::from_str(result_str)
112 .map_err(|e| DurableError::deserialization("BatchResult", e))?;
113
114 self.replay_engine_mut().track_replay(&op_id);
115 return Ok(batch_result);
116 } else {
117 let error_message = op
118 .context_details()
119 .and_then(|d| d.error())
120 .map(|e| {
121 format!(
122 "{}: {}",
123 e.error_type().unwrap_or("Unknown"),
124 e.error_data().unwrap_or("")
125 )
126 })
127 .unwrap_or_else(|| "map failed".to_string());
128 return Err(DurableError::map_failed(name, error_message));
129 }
130 }
131
132 let outer_start = OperationUpdate::builder()
134 .id(op_id.clone())
135 .r#type(OperationType::Context)
136 .action(OperationAction::Start)
137 .sub_type("Map")
138 .name(name)
139 .build()
140 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
141
142 let start_response = self
143 .backend()
144 .checkpoint(self.arn(), self.checkpoint_token(), vec![outer_start], None)
145 .await?;
146
147 let new_token = start_response.checkpoint_token().ok_or_else(|| {
148 DurableError::checkpoint_failed(
149 name,
150 std::io::Error::new(
151 std::io::ErrorKind::InvalidData,
152 "checkpoint response missing checkpoint_token",
153 ),
154 )
155 })?;
156 self.set_checkpoint_token(new_token.to_string());
157
158 if let Some(new_state) = start_response.new_execution_state() {
159 for op in new_state.operations() {
160 self.replay_engine_mut()
161 .insert_operation(op.id().to_string(), op.clone());
162 }
163 }
164
165 let item_count = items.len();
167 let batch_size = options.get_batch_size().unwrap_or(item_count).max(1);
168 let mut all_results: Vec<(usize, Result<T, DurableError>)> = Vec::with_capacity(item_count);
169
170 let mut item_id_gen = crate::operation_id::OperationIdGenerator::new(Some(op_id.clone()));
172
173 let mut items_iter = items.into_iter().enumerate().peekable();
174
175 while items_iter.peek().is_some() {
176 let batch: Vec<(usize, I)> = items_iter.by_ref().take(batch_size).collect();
177 let mut handles = Vec::with_capacity(batch.len());
178
179 for (index, item) in batch {
180 let item_op_id = item_id_gen.next_id();
181 let child_ctx = self.create_child_context(&item_op_id);
182 let config = ItemConfig {
183 backend: self.backend().clone(),
184 arn: self.arn().to_string(),
185 token: self.checkpoint_token().to_string(),
186 item_op_id,
187 parent_op_id: op_id.clone(),
188 item_name: format!("map-item-{index}"),
189 };
190 let f_clone = f.clone();
191
192 let handle = tokio::spawn(async move {
193 let result = execute_item(child_ctx, config, item, f_clone).await;
194 (index, result)
195 });
196
197 handles.push(handle);
198 }
199
200 for handle in handles {
202 let (index, result) = handle
203 .await
204 .map_err(|e| DurableError::map_failed(name, format!("item panicked: {e}")))?;
205 all_results.push((index, result));
206 }
207 }
208
209 all_results.sort_by_key(|(index, _)| *index);
211
212 let results: Vec<BatchItem<T>> = all_results
214 .into_iter()
215 .map(|(index, result)| match result {
216 Ok(value) => BatchItem {
217 index,
218 status: BatchItemStatus::Succeeded,
219 result: Some(value),
220 error: None,
221 },
222 Err(err) => BatchItem {
223 index,
224 status: BatchItemStatus::Failed,
225 result: None,
226 error: Some(err.to_string()),
227 },
228 })
229 .collect();
230
231 let batch_result = BatchResult {
232 results,
233 completion_reason: CompletionReason::AllCompleted,
234 };
235
236 let serialized_result = serde_json::to_string(&batch_result)
238 .map_err(|e| DurableError::serialization("BatchResult", e))?;
239
240 let ctx_opts = aws_sdk_lambda::types::ContextOptions::builder()
241 .replay_children(false)
242 .build();
243
244 let outer_succeed = OperationUpdate::builder()
245 .id(op_id.clone())
246 .r#type(OperationType::Context)
247 .action(OperationAction::Succeed)
248 .sub_type("Map")
249 .payload(serialized_result)
250 .context_options(ctx_opts)
251 .build()
252 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
253
254 let succeed_response = self
255 .backend()
256 .checkpoint(
257 self.arn(),
258 self.checkpoint_token(),
259 vec![outer_succeed],
260 None,
261 )
262 .await?;
263
264 let new_token = succeed_response.checkpoint_token().ok_or_else(|| {
265 DurableError::checkpoint_failed(
266 name,
267 std::io::Error::new(
268 std::io::ErrorKind::InvalidData,
269 "checkpoint response missing checkpoint_token",
270 ),
271 )
272 })?;
273 self.set_checkpoint_token(new_token.to_string());
274
275 if let Some(new_state) = succeed_response.new_execution_state() {
276 for op in new_state.operations() {
277 self.replay_engine_mut()
278 .insert_operation(op.id().to_string(), op.clone());
279 }
280 }
281
282 self.replay_engine_mut().track_replay(&op_id);
283 Ok(batch_result)
284 }
285}
286
287struct ItemConfig {
289 backend: std::sync::Arc<dyn crate::backend::DurableBackend>,
290 arn: String,
291 token: String,
292 item_op_id: String,
293 parent_op_id: String,
294 item_name: String,
295}
296
297async fn execute_item<T, I, F, Fut>(
302 child_ctx: DurableContext,
303 config: ItemConfig,
304 item: I,
305 f: F,
306) -> Result<T, DurableError>
307where
308 T: Serialize + Send + 'static,
309 I: Send + 'static,
310 F: FnOnce(I, DurableContext) -> Fut + Send + 'static,
311 Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
312{
313 let item_start = OperationUpdate::builder()
315 .id(config.item_op_id.clone())
316 .r#type(OperationType::Context)
317 .action(OperationAction::Start)
318 .sub_type("MapItem")
319 .name(&config.item_name)
320 .parent_id(config.parent_op_id.clone())
321 .build()
322 .map_err(|e| DurableError::checkpoint_failed(&config.item_name, e))?;
323
324 let _ = config
325 .backend
326 .checkpoint(&config.arn, &config.token, vec![item_start], None)
327 .await?;
328
329 let result = f(item, child_ctx).await?;
331
332 let serialized = serde_json::to_string(&result)
334 .map_err(|e| DurableError::serialization(&config.item_name, e))?;
335
336 let ctx_opts = aws_sdk_lambda::types::ContextOptions::builder()
337 .replay_children(false)
338 .build();
339
340 let item_succeed = OperationUpdate::builder()
341 .id(config.item_op_id.clone())
342 .r#type(OperationType::Context)
343 .action(OperationAction::Succeed)
344 .sub_type("MapItem")
345 .name(&config.item_name)
346 .parent_id(config.parent_op_id.clone())
347 .payload(serialized)
348 .context_options(ctx_opts)
349 .build()
350 .map_err(|e| DurableError::checkpoint_failed(&config.item_name, e))?;
351
352 let _ = config
353 .backend
354 .checkpoint(&config.arn, &config.token, vec![item_succeed], None)
355 .await?;
356
357 Ok(result)
358}
359
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ContextDetails, Operation, OperationAction, OperationStatus, OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;
    use crate::types::MapOptions;

    /// One recorded `checkpoint` call made against [`MapMockBackend`].
    #[derive(Debug, Clone)]
    // `arn` and `checkpoint_token` are recorded for debugging but no test asserts on them.
    #[allow(dead_code)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Backend double that records every checkpoint and always answers with `"mock-token"`.
    struct MapMockBackend {
        calls: Arc<Mutex<Vec<CheckpointCall>>>,
    }

    impl MapMockBackend {
        /// Returns the backend plus a shared handle to its recorded calls.
        fn new() -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: calls.clone(),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for MapMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token("mock-token")
                .build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// Operation id the first top-level durable operation of a fresh context receives.
    fn first_op_id() -> String {
        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
        gen.next_id()
    }

    /// Builds a context over `backend`, seeded with `operations` as previously recorded state.
    async fn new_ctx(backend: MapMockBackend, operations: Vec<Operation>) -> DurableContext {
        DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            operations,
            None,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_map_executes_items_concurrently() {
        let (backend, calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let items = vec![10, 20, 30];
        let result = ctx
            .map(
                "process",
                items,
                MapOptions::new(),
                |item: i32, mut child_ctx: DurableContext| async move {
                    let r: Result<i32, String> = child_ctx
                        .step("double", move || async move { Ok(item * 2) })
                        .await?;
                    Ok(r.unwrap())
                },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.results[0].index, 0);
        assert_eq!(result.results[1].index, 1);
        assert_eq!(result.results[2].index, 2);
        assert_eq!(result.results[0].result, Some(20));
        assert_eq!(result.results[1].result, Some(40));
        assert_eq!(result.results[2].result, Some(60));

        let captured = calls.lock().await;
        assert!(
            captured.len() >= 2,
            "should have at least outer START and outer SUCCEED"
        );

        assert_eq!(captured[0].updates[0].r#type(), &OperationType::Context);
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
        assert_eq!(captured[0].updates[0].sub_type(), Some("Map"));

        let last = &captured[captured.len() - 1];
        assert_eq!(last.updates[0].r#type(), &OperationType::Context);
        assert_eq!(last.updates[0].action(), &OperationAction::Succeed);
        assert_eq!(last.updates[0].sub_type(), Some("Map"));
        assert!(
            last.updates[0].payload().is_some(),
            "should have BatchResult payload"
        );
    }

    #[tokio::test]
    async fn test_map_batch_items_overlap_in_time() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        // Each item bumps an in-flight counter, yields to the scheduler, then decrements.
        // If items of one batch really run concurrently, the peak exceeds one.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let (in_flight_handle, peak_handle) = (in_flight.clone(), peak.clone());

        let result = ctx
            .map(
                "overlap",
                vec![1, 2, 3],
                MapOptions::new(),
                move |item: i32, _ctx: DurableContext| {
                    let in_flight = in_flight_handle.clone();
                    let peak = peak_handle.clone();
                    async move {
                        let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        peak.fetch_max(now, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok(item)
                    }
                },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 3);
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);
        assert!(
            peak.load(Ordering::SeqCst) >= 2,
            "items of one batch should be in flight at the same time"
        );
    }

    #[tokio::test]
    async fn test_map_replays_from_cached_result() {
        let op_id = first_op_id();

        let batch_json = r#"{"results":[{"index":0,"status":"Succeeded","result":100,"error":null},{"index":1,"status":"Succeeded","result":200,"error":null}],"completion_reason":"AllCompleted"}"#;

        let map_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Context)
            .status(OperationStatus::Succeeded)
            .start_timestamp(DateTime::from_secs(0))
            .context_details(
                ContextDetails::builder()
                    .replay_children(false)
                    .result(batch_json)
                    .build(),
            )
            .build()
            .unwrap();

        let (backend, calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![map_op]).await;

        let result: crate::types::BatchResult<i32> = ctx
            .map(
                "process",
                vec![1],
                MapOptions::new(),
                |_item: i32, _ctx: DurableContext| async move {
                    panic!("item should not execute during replay")
                },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 2);
        assert_eq!(result.results[0].result, Some(100));
        assert_eq!(result.results[1].result, Some(200));

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_map_items_have_isolated_namespaces() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let items = vec!["alpha", "beta"];
        let result = ctx
            .map(
                "isolated_test",
                items,
                MapOptions::new(),
                |item: &str, mut child_ctx: DurableContext| async move {
                    let r: Result<String, String> = child_ctx
                        .step("work", move || async move { Ok(format!("result-{item}")) })
                        .await?;
                    Ok(r.unwrap())
                },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 2);
        assert_eq!(result.results[0].result.as_deref(), Some("result-alpha"));
        assert_eq!(result.results[1].result.as_deref(), Some("result-beta"));
    }

    #[tokio::test]
    async fn test_map_sends_correct_checkpoint_sequence() {
        let (backend, calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let items = vec![1, 2];
        let _ = ctx
            .map(
                "seq_test",
                items,
                MapOptions::new(),
                |_item: i32, _ctx: DurableContext| async move { Ok(0i32) },
            )
            .await
            .unwrap();

        let captured = calls.lock().await;

        // outer START + 2 x (item START + item SUCCEED) + outer SUCCEED
        assert!(
            captured.len() >= 6,
            "expected at least 6 checkpoints, got {}",
            captured.len()
        );

        assert_eq!(captured[0].updates[0].sub_type(), Some("Map"));
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);

        let last_idx = captured.len() - 1;
        assert_eq!(captured[last_idx].updates[0].sub_type(), Some("Map"));
        assert_eq!(
            captured[last_idx].updates[0].action(),
            &OperationAction::Succeed
        );

        let item_checkpoints: Vec<_> = captured[1..last_idx]
            .iter()
            .filter(|c| c.updates[0].sub_type() == Some("MapItem"))
            .collect();
        assert_eq!(
            item_checkpoints.len(),
            4,
            "expected 4 item checkpoints (2 START + 2 SUCCEED)"
        );
    }

    #[tokio::test]
    async fn test_map_empty_input_checkpoints_start_and_succeed() {
        let (backend, calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let result = ctx
            .map(
                "empty",
                Vec::<i32>::new(),
                MapOptions::new(),
                |item: i32, _ctx: DurableContext| async move { Ok(item) },
            )
            .await
            .unwrap();

        assert!(result.results.is_empty());

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 2, "only outer START and outer SUCCEED");
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
        assert_eq!(captured[1].updates[0].action(), &OperationAction::Succeed);
    }

    #[tokio::test]
    async fn test_map_item_failure_is_captured() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let items = vec![1, 2];
        let result = ctx
            .map(
                "fail_test",
                items,
                MapOptions::new(),
                |item: i32, _ctx: DurableContext| async move {
                    if item == 2 {
                        Err(DurableError::map_failed("item", "intentional failure"))
                    } else {
                        Ok(item * 10)
                    }
                },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 2);
        assert_eq!(
            result.results[0].status,
            crate::types::BatchItemStatus::Succeeded
        );
        assert_eq!(result.results[0].result, Some(10));
        assert_eq!(
            result.results[1].status,
            crate::types::BatchItemStatus::Failed
        );
        assert!(result.results[1].error.is_some());
        assert!(result.results[1]
            .error
            .as_ref()
            .unwrap()
            .contains("intentional failure"));
    }

    #[tokio::test]
    async fn test_map_item_panic_returns_map_error() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let outcome = ctx
            .map(
                "panic_test",
                vec![1, 2],
                MapOptions::new(),
                |item: i32, _ctx: DurableContext| async move {
                    if item == 2 {
                        panic!("item blew up");
                    }
                    Ok(item)
                },
            )
            .await;

        let err = match outcome {
            Ok(_) => panic!("a panicking item must fail the whole map"),
            Err(err) => err,
        };
        assert!(
            err.to_string().contains("item panicked"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn test_map_batching_processes_sequentially() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        // Each item records the global order in which it started executing.
        let execution_order = Arc::new(AtomicUsize::new(0));

        let items = vec![0usize, 1, 2, 3];
        let order = execution_order.clone();
        let result = ctx
            .map(
                "batch_test",
                items,
                MapOptions::new().batch_size(2),
                move |item: usize, _ctx: DurableContext| {
                    let order = order.clone();
                    async move {
                        let seq = order.fetch_add(1, Ordering::SeqCst);
                        Ok((item, seq))
                    }
                },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 4);

        let item0 = result.results[0].result.as_ref().unwrap();
        let item1 = result.results[1].result.as_ref().unwrap();
        let item2 = result.results[2].result.as_ref().unwrap();
        let item3 = result.results[3].result.as_ref().unwrap();

        assert!(item0.1 < 2, "batch 1 item should execute before batch 2");
        assert!(item1.1 < 2, "batch 1 item should execute before batch 2");
        assert!(item2.1 >= 2, "batch 2 item should execute after batch 1");
        assert!(item3.1 >= 2, "batch 2 item should execute after batch 1");
    }

    #[tokio::test]
    async fn test_map_default_options_all_concurrent() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let items = vec![1, 2, 3, 4, 5];
        let result = ctx
            .map(
                "all_concurrent",
                items,
                MapOptions::new(),
                |item: i32, _ctx: DurableContext| async move { Ok(item) },
            )
            .await
            .unwrap();

        assert_eq!(result.results.len(), 5);
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(r.index, i);
            assert_eq!(r.result, Some((i + 1) as i32));
        }
    }

    #[traced_test]
    #[tokio::test]
    async fn test_map_emits_span() {
        let (backend, _calls) = MapMockBackend::new();
        let mut ctx = new_ctx(backend, vec![]).await;

        let _ = ctx
            .map(
                "process",
                Vec::<i32>::new(),
                MapOptions::new(),
                |item: i32, _ctx: DurableContext| async move { Ok(item) },
            )
            .await;

        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("process"));
        assert!(logs_contain("map"));
    }
}