1use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::concurrency::{BatchResult, CompletionReason, ConcurrentExecutor};
11use crate::config::ChildConfig;
12use crate::config::{ItemBatcher, MapConfig};
13use crate::context::{create_operation_span, DurableContext, LogInfo, Logger, OperationIdentifier};
14use crate::error::{DurableError, ErrorObject};
15use crate::handlers::child::child_handler;
16use crate::operation::{OperationType, OperationUpdate};
17use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
18use crate::state::ExecutionState;
19
20pub async fn map_handler<T, U, F, Fut>(
42 items: Vec<T>,
43 func: F,
44 state: &Arc<ExecutionState>,
45 op_id: &OperationIdentifier,
46 parent_ctx: &DurableContext,
47 config: &MapConfig,
48 logger: &Arc<dyn Logger>,
49) -> Result<BatchResult<U>, DurableError>
50where
51 T: Serialize + DeserializeOwned + Send + Sync + Clone + 'static,
52 U: Serialize + DeserializeOwned + Send + 'static,
53 F: Fn(DurableContext, T, usize) -> Fut + Send + Sync + Clone + 'static,
54 Fut: std::future::Future<Output = Result<U, DurableError>> + Send + 'static,
55{
56 let span = create_operation_span("map", op_id, state.durable_execution_arn());
59 let _guard = span.enter();
60
61 let mut log_info =
62 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
63 if let Some(ref parent_id) = op_id.parent_id {
64 log_info = log_info.with_parent_id(parent_id);
65 }
66
67 logger.debug(
68 &format!(
69 "Starting map operation: {} with {} items",
70 op_id,
71 items.len()
72 ),
73 &log_info,
74 );
75
76 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
78
79 if checkpoint_result.is_existent() {
80 if let Some(op_type) = checkpoint_result.operation_type() {
82 if op_type != OperationType::Context {
83 span.record("status", "non_deterministic");
84 return Err(DurableError::NonDeterministic {
85 message: format!(
86 "Expected Context operation but found {:?} at operation_id {}",
87 op_type, op_id.operation_id
88 ),
89 operation_id: Some(op_id.operation_id.clone()),
90 });
91 }
92 }
93
94 if checkpoint_result.is_succeeded() {
96 logger.debug(
97 &format!("Replaying succeeded map operation: {}", op_id),
98 &log_info,
99 );
100
101 if let Some(result_str) = checkpoint_result.result() {
102 let serdes = JsonSerDes::<BatchResult<U>>::new();
103 let serdes_ctx =
104 SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
105 let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
106 DurableError::SerDes {
107 message: format!("Failed to deserialize map result: {}", e),
108 }
109 })?;
110
111 state.track_replay(&op_id.operation_id).await;
112 span.record("status", "replayed_succeeded");
113 return Ok(result);
114 }
115 }
116
117 if checkpoint_result.is_failed() {
119 logger.debug(
120 &format!("Replaying failed map operation: {}", op_id),
121 &log_info,
122 );
123
124 state.track_replay(&op_id.operation_id).await;
125 span.record("status", "replayed_failed");
126
127 if let Some(error) = checkpoint_result.error() {
128 return Err(DurableError::UserCode {
129 message: error.error_message.clone(),
130 error_type: error.error_type.clone(),
131 stack_trace: error.stack_trace.clone(),
132 });
133 } else {
134 return Err(DurableError::execution(
135 "Map operation failed with unknown error",
136 ));
137 }
138 }
139
140 if checkpoint_result.is_terminal() {
142 state.track_replay(&op_id.operation_id).await;
143 span.record("status", "replayed_terminal");
144
145 let status = checkpoint_result.status().unwrap();
146 return Err(DurableError::execution(format!(
147 "Map operation was {}",
148 status
149 )));
150 }
151 }
152
153 if items.is_empty() {
155 logger.debug("Map operation with empty collection", &log_info);
156 let result = BatchResult::empty();
157
158 let serdes = JsonSerDes::<BatchResult<U>>::new();
160 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
161 let serialized =
162 serdes
163 .serialize(&result, &serdes_ctx)
164 .map_err(|e| DurableError::SerDes {
165 message: format!("Failed to serialize map result: {}", e),
166 })?;
167
168 let succeed_update = create_succeed_update(op_id, Some(serialized));
169 state.create_checkpoint(succeed_update, true).await?;
170
171 span.record("status", "succeeded_empty");
172 return Ok(result);
173 }
174
175 let batched_items = if let Some(ref batcher) = config.item_batcher {
177 batch_items(&items, batcher)
178 } else {
179 items
180 .into_iter()
181 .enumerate()
182 .map(|(i, item)| (i, vec![item]))
183 .collect()
184 };
185
186 let start_update = create_start_update(op_id);
189 state.create_checkpoint(start_update, true).await?;
190
191 let map_ctx = parent_ctx.create_child_context(&op_id.operation_id);
193
194 let total_tasks = batched_items.len();
196 let executor = ConcurrentExecutor::new(
197 total_tasks,
198 config.max_concurrency,
199 config.completion_config.clone(),
200 );
201
202 let tasks: Vec<_> = batched_items
204 .into_iter()
205 .map(|(original_index, batch)| {
206 let func = func.clone();
207 let map_ctx = map_ctx.clone();
208 let state = state.clone();
209 let logger = logger.clone();
210 let op_id = op_id.clone();
211
212 move |_task_idx: usize| {
213 Box::pin(async move {
214 let item = batch
217 .into_iter()
218 .next()
219 .ok_or_else(|| DurableError::validation("Empty batch in map operation"))?;
220
221 let child_op_id = OperationIdentifier::new(
223 map_ctx.next_operation_id(),
224 Some(op_id.operation_id.clone()),
225 Some(format!("map-item-{}", original_index)),
226 );
227
228 child_handler(
230 |ctx| {
231 async move { func(ctx, item, original_index).await }
233 },
234 &state,
235 &child_op_id,
236 &map_ctx,
237 &ChildConfig::default(),
238 &logger,
239 )
240 .await
241 })
242 as std::pin::Pin<
243 Box<dyn std::future::Future<Output = Result<U, DurableError>> + Send>,
244 >
245 }
246 })
247 .collect();
248
249 let batch_result = executor.execute(tasks).await;
251
252 logger.debug(
253 &format!(
254 "Map operation completed: {} succeeded, {} failed",
255 batch_result.success_count(),
256 batch_result.failure_count()
257 ),
258 &log_info,
259 );
260
261 let serdes = JsonSerDes::<BatchResult<U>>::new();
263 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
264
265 if batch_result.completion_reason != CompletionReason::Suspended {
267 let serialized =
268 serdes
269 .serialize(&batch_result, &serdes_ctx)
270 .map_err(|e| DurableError::SerDes {
271 message: format!("Failed to serialize map result: {}", e),
272 })?;
273
274 let succeed_update = create_succeed_update(op_id, Some(serialized));
275 state.create_checkpoint(succeed_update, true).await?;
276
277 state.mark_parent_done(&op_id.operation_id).await;
279 span.record("status", "succeeded");
280 } else {
281 span.record("status", "suspended");
282 }
283
284 Ok(batch_result)
285}
286
287fn batch_items<T: Serialize + Clone>(items: &[T], batcher: &ItemBatcher) -> Vec<(usize, Vec<T>)> {
292 batcher.batch(items)
293}
294
295fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
297 op_id.apply_to(OperationUpdate::start(
298 &op_id.operation_id,
299 OperationType::Context,
300 ))
301}
302
303fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
305 op_id.apply_to(OperationUpdate::succeed(
306 &op_id.operation_id,
307 OperationType::Context,
308 result,
309 ))
310}
311
312#[allow(dead_code)]
314fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
315 op_id.apply_to(OperationUpdate::fail(
316 &op_id.operation_id,
317 OperationType::Context,
318 error,
319 ))
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;

    /// Durable execution ARN shared by every test state.
    const TEST_EXECUTION_ARN: &str =
        "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    /// Mock client that answers the first five checkpoint calls with `token-1`..`token-5`.
    fn create_mock_client() -> SharedDurableServiceClient {
        let mock = ["token-1", "token-2", "token-3", "token-4", "token-5"]
            .iter()
            .fold(MockDurableServiceClient::new(), |mock, &token| {
                mock.with_checkpoint_response(Ok(CheckpointResponse::new(token)))
            });
        Arc::new(mock)
    }

    /// Fresh execution state backed by `client`.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        let state = ExecutionState::new(
            TEST_EXECUTION_ARN,
            "initial-token",
            InitialExecutionState::new(),
            client,
        );
        Arc::new(state)
    }

    /// Identifier for the map operation under test, with a parent id and a name.
    fn create_test_op_id() -> OperationIdentifier {
        let parent = Some("parent-op".to_string());
        let name = Some("test-map".to_string());
        OperationIdentifier::new("test-map-123", parent, name)
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    fn create_test_config() -> MapConfig {
        MapConfig::default()
    }

    fn create_test_parent_ctx(state: Arc<ExecutionState>) -> DurableContext {
        DurableContext::new(state)
    }

    /// Runs `map_handler` over `items` against a fresh execution state backed by `client`,
    /// using the shared test operation id, logger and parent context.
    async fn run_map<F, Fut>(
        client: SharedDurableServiceClient,
        config: MapConfig,
        items: Vec<i32>,
        func: F,
    ) -> Result<BatchResult<i32>, DurableError>
    where
        F: Fn(DurableContext, i32, usize) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Result<i32, DurableError>> + Send + 'static,
    {
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        map_handler(items, func, &state, &op_id, &parent_ctx, &config, &logger).await
    }

    #[tokio::test]
    async fn test_map_handler_empty_collection() {
        let batch_result = run_map(
            create_mock_client(),
            create_test_config(),
            Vec::new(),
            |_ctx, item: i32, _idx| async move { Ok(item * 2) },
        )
        .await
        .expect("empty map should succeed");

        assert!(batch_result.items.is_empty());
        assert_eq!(
            batch_result.completion_reason,
            CompletionReason::AllCompleted
        );
    }

    #[tokio::test]
    async fn test_map_handler_single_item() {
        let batch_result = run_map(
            create_mock_client(),
            create_test_config(),
            vec![5],
            |_ctx, item: i32, _idx| async move { Ok(item * 2) },
        )
        .await
        .expect("single-item map should succeed");

        assert_eq!(batch_result.total_count(), 1);
        assert_eq!(batch_result.success_count(), 1);
        assert!(batch_result.all_succeeded());
    }

    #[tokio::test]
    async fn test_map_handler_multiple_items() {
        let batch_result = run_map(
            Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(10)),
            create_test_config(),
            vec![1, 2, 3],
            |_ctx, item: i32, _idx| async move { Ok(item * 10) },
        )
        .await
        .expect("multi-item map should succeed");

        assert_eq!(batch_result.total_count(), 3);
        assert_eq!(batch_result.success_count(), 3);
    }

    #[tokio::test]
    async fn test_map_handler_with_concurrency_limit() {
        let limited = MapConfig {
            max_concurrency: Some(2),
            ..Default::default()
        };

        let batch_result = run_map(
            Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(10)),
            limited,
            vec![1, 2, 3, 4, 5],
            |_ctx, item: i32, _idx| async move { Ok(item) },
        )
        .await
        .expect("concurrency-limited map should succeed");

        assert_eq!(batch_result.total_count(), 5);
    }

    #[test]
    fn test_batch_items_no_batching() {
        let items = vec![1, 2, 3, 4, 5];
        let batcher = ItemBatcher {
            max_items_per_batch: 10,
            max_bytes_per_batch: 1024,
        };

        let batches = batch_items(&items, &batcher);
        assert_eq!(batches, vec![(0, vec![1, 2, 3, 4, 5])]);
    }

    #[test]
    fn test_batch_items_with_batching() {
        let items = vec![1, 2, 3, 4, 5];
        let batcher = ItemBatcher {
            max_items_per_batch: 2,
            max_bytes_per_batch: 1024,
        };

        let groups: Vec<Vec<i32>> = batch_items(&items, &batcher)
            .into_iter()
            .map(|(_, group)| group)
            .collect();
        assert_eq!(groups, vec![vec![1, 2], vec![3, 4], vec![5]]);
    }

    #[test]
    fn test_batch_items_empty() {
        let items: Vec<i32> = Vec::new();
        assert!(batch_items(&items, &ItemBatcher::default()).is_empty());
    }

    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-map".to_string()),
        );
        let update = create_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.result.as_deref(), Some("result"));
        assert_eq!(update.parent_id.as_deref(), Some("parent-456"));
        assert_eq!(update.name.as_deref(), Some("my-map"));
    }
}