Skip to main content

durable_execution_sdk/handlers/
map.rs

1//! Map operation handler for the AWS Durable Execution SDK.
2//!
3//! This module implements the map handler which processes collections
4//! in parallel with configurable concurrency and failure tolerance.
5
6use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::concurrency::{BatchResult, CompletionReason, ConcurrentExecutor};
11use crate::config::ChildConfig;
12use crate::config::{ItemBatcher, MapConfig};
13use crate::context::{create_operation_span, DurableContext, LogInfo, Logger, OperationIdentifier};
14use crate::error::{DurableError, ErrorObject};
15use crate::handlers::child::child_handler;
16use crate::operation::{OperationType, OperationUpdate};
17use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
18use crate::state::ExecutionState;
19
/// Executes a map operation over a collection with parallel processing.
///
/// This handler implements the map semantics:
/// - Replays from an existing checkpoint when one is present (succeeded,
///   failed, or otherwise terminal)
/// - Creates a child context for each item
/// - Uses ConcurrentExecutor for parallel execution
/// - Supports ItemBatcher for batching
/// - Returns BatchResult with results for all items
///
/// # Arguments
///
/// * `items` - The collection of items to process
/// * `func` - The function to apply to each item
/// * `state` - The execution state for checkpointing
/// * `op_id` - The operation identifier for the map operation
/// * `parent_ctx` - The parent DurableContext
/// * `config` - Map configuration
/// * `logger` - Logger for structured logging
///
/// # Returns
///
/// A `BatchResult` containing results for all items.
///
/// # Errors
///
/// * `DurableError::NonDeterministic` - a checkpoint exists at this operation
///   id but with a non-`Context` operation type (replay mismatch)
/// * `DurableError::SerDes` - the checkpointed or computed result could not be
///   (de)serialized
/// * `DurableError::UserCode` / `DurableError::execution` - replayed from a
///   failed or otherwise terminal checkpoint
/// * any error returned by `state.create_checkpoint`
pub async fn map_handler<T, U, F, Fut>(
    items: Vec<T>,
    func: F,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    parent_ctx: &DurableContext,
    config: &MapConfig,
    logger: &Arc<dyn Logger>,
) -> Result<BatchResult<U>, DurableError>
where
    T: Serialize + DeserializeOwned + Send + Sync + Clone + 'static,
    U: Serialize + DeserializeOwned + Send + 'static,
    F: Fn(DurableContext, T, usize) -> Fut + Send + Sync + Clone + 'static,
    Fut: std::future::Future<Output = Result<U, DurableError>> + Send + 'static,
{
    // Create tracing span for this operation
    // Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 3.6
    let span = create_operation_span("map", op_id, state.durable_execution_arn());
    // NOTE(review): this `Entered` guard is held across every `.await` below.
    // The `tracing` docs warn that entering a span across await points can
    // attach the span to unrelated tasks polled on the same thread; consider
    // `Instrument::instrument` instead — TODO confirm against the project's
    // tracing conventions before changing.
    let _guard = span.enter();

    // Structured log context: execution ARN + this operation id (+ parent id
    // when the identifier carries one).
    let mut log_info =
        LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
    if let Some(ref parent_id) = op_id.parent_id {
        log_info = log_info.with_parent_id(parent_id);
    }

    logger.debug(
        &format!(
            "Starting map operation: {} with {} items",
            op_id,
            items.len()
        ),
        &log_info,
    );

    // Check for existing checkpoint (replay)
    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;

    if checkpoint_result.is_existent() {
        // Check for non-deterministic execution: a map operation always
        // checkpoints as OperationType::Context (see create_start_update /
        // create_succeed_update below), so anything else at this id means the
        // user code took a different path on replay.
        if let Some(op_type) = checkpoint_result.operation_type() {
            if op_type != OperationType::Context {
                span.record("status", "non_deterministic");
                return Err(DurableError::NonDeterministic {
                    message: format!(
                        "Expected Context operation but found {:?} at operation_id {}",
                        op_type, op_id.operation_id
                    ),
                    operation_id: Some(op_id.operation_id.clone()),
                });
            }
        }

        // Handle succeeded checkpoint: deserialize and return the stored
        // BatchResult without re-running any item.
        if checkpoint_result.is_succeeded() {
            logger.debug(
                &format!("Replaying succeeded map operation: {}", op_id),
                &log_info,
            );

            if let Some(result_str) = checkpoint_result.result() {
                let serdes = JsonSerDes::<BatchResult<U>>::new();
                let serdes_ctx =
                    SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
                let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
                    DurableError::SerDes {
                        message: format!("Failed to deserialize map result: {}", e),
                    }
                })?;

                state.track_replay(&op_id.operation_id).await;
                span.record("status", "replayed_succeeded");
                return Ok(result);
            }
            // NOTE(review): a succeeded checkpoint with no stored result falls
            // through to fresh execution below — presumably intentional, but
            // worth confirming.
        }

        // Handle failed checkpoint: surface the recorded error (or a generic
        // one if no error object was stored) instead of re-running.
        if checkpoint_result.is_failed() {
            logger.debug(
                &format!("Replaying failed map operation: {}", op_id),
                &log_info,
            );

            state.track_replay(&op_id.operation_id).await;
            span.record("status", "replayed_failed");

            if let Some(error) = checkpoint_result.error() {
                return Err(DurableError::UserCode {
                    message: error.error_message.clone(),
                    error_type: error.error_type.clone(),
                    stack_trace: error.stack_trace.clone(),
                });
            } else {
                return Err(DurableError::execution(
                    "Map operation failed with unknown error",
                ));
            }
        }

        // Handle other terminal states (neither succeeded nor failed).
        if checkpoint_result.is_terminal() {
            state.track_replay(&op_id.operation_id).await;
            span.record("status", "replayed_terminal");

            // Relies on the invariant that a terminal checkpoint always
            // carries a status — presumably guaranteed by is_terminal();
            // TODO confirm, otherwise this unwrap can panic.
            let status = checkpoint_result.status().unwrap();
            return Err(DurableError::execution(format!(
                "Map operation was {}",
                status
            )));
        }
    }

    // Handle empty collection: checkpoint an empty BatchResult immediately so
    // replays short-circuit, then return it.
    if items.is_empty() {
        logger.debug("Map operation with empty collection", &log_info);
        let result = BatchResult::empty();

        // Checkpoint the empty result
        let serdes = JsonSerDes::<BatchResult<U>>::new();
        let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
        let serialized =
            serdes
                .serialize(&result, &serdes_ctx)
                .map_err(|e| DurableError::SerDes {
                    message: format!("Failed to serialize map result: {}", e),
                })?;

        let succeed_update = create_succeed_update(op_id, Some(serialized));
        state.create_checkpoint(succeed_update, true).await?;

        span.record("status", "succeeded_empty");
        return Ok(result);
    }

    // Apply batching if configured; otherwise each item becomes its own
    // single-element batch keyed by its original index.
    // NOTE(review): the task closure below only executes the FIRST item of
    // each batch, so an ItemBatcher that packs >1 item per batch silently
    // drops the remaining items — TODO confirm whether multi-item batches are
    // reachable here or guarded against upstream.
    let batched_items = if let Some(ref batcher) = config.item_batcher {
        batch_items(&items, batcher)
    } else {
        items
            .into_iter()
            .enumerate()
            .map(|(i, item)| (i, vec![item]))
            .collect()
    };

    // Checkpoint START for the map operation before spawning children
    // This ensures the parent operation exists when children reference it
    let start_update = create_start_update(op_id);
    state.create_checkpoint(start_update, true).await?;

    // Create the map context (child of parent); child operation ids are
    // allocated from this context inside each task.
    let map_ctx = parent_ctx.create_child_context(&op_id.operation_id);

    // Create tasks for concurrent execution
    let total_tasks = batched_items.len();
    let executor = ConcurrentExecutor::new(
        total_tasks,
        config.max_concurrency,
        config.completion_config.clone(),
    );

    // Build task closures. Each closure captures clones of the shared handles
    // so it can be 'static and run on the executor independently.
    let tasks: Vec<_> = batched_items
        .into_iter()
        .map(|(original_index, batch)| {
            let func = func.clone();
            let map_ctx = map_ctx.clone();
            let state = state.clone();
            let logger = logger.clone();
            let op_id = op_id.clone();

            move |_task_idx: usize| {
                Box::pin(async move {
                    // For batched items, process each item in the batch
                    // For now, we process the first item (single item batches)
                    let item = batch
                        .into_iter()
                        .next()
                        .ok_or_else(|| DurableError::validation("Empty batch in map operation"))?;

                    // Create child operation ID for this item. The name embeds
                    // the item's original index for traceability.
                    let child_op_id = OperationIdentifier::new(
                        map_ctx.next_operation_id(),
                        Some(op_id.operation_id.clone()),
                        Some(format!("map-item-{}", original_index)),
                    );

                    // Execute in child context so each item gets its own
                    // checkpointing scope via child_handler.
                    child_handler(
                        |ctx| {
                            // item and func are moved into this FnOnce closure, no clone needed
                            async move { func(ctx, item, original_index).await }
                        },
                        &state,
                        &child_op_id,
                        &map_ctx,
                        &ChildConfig::default(),
                        &logger,
                    )
                    .await
                })
                    // Coerce to a boxed, pinned, Send future so every task has
                    // the same concrete type for the executor.
                    as std::pin::Pin<
                        Box<dyn std::future::Future<Output = Result<U, DurableError>> + Send>,
                    >
            }
        })
        .collect();

    // Execute all tasks
    let batch_result = executor.execute(tasks).await;

    logger.debug(
        &format!(
            "Map operation completed: {} succeeded, {} failed",
            batch_result.success_count(),
            batch_result.failure_count()
        ),
        &log_info,
    );

    // Checkpoint the result
    let serdes = JsonSerDes::<BatchResult<U>>::new();
    let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());

    // Only checkpoint if not suspended — a suspended run writes no terminal
    // checkpoint, presumably so the operation can resume later; TODO confirm.
    if batch_result.completion_reason != CompletionReason::Suspended {
        let serialized =
            serdes
                .serialize(&batch_result, &serdes_ctx)
                .map_err(|e| DurableError::SerDes {
                    message: format!("Failed to serialize map result: {}", e),
                })?;

        let succeed_update = create_succeed_update(op_id, Some(serialized));
        state.create_checkpoint(succeed_update, true).await?;

        // Mark parent as done
        state.mark_parent_done(&op_id.operation_id).await;
        span.record("status", "succeeded");
    } else {
        span.record("status", "suspended");
    }

    Ok(batch_result)
}
286
287/// Batches items according to the ItemBatcher configuration.
288///
289/// This function delegates to `ItemBatcher::batch()` which respects both
290/// item count and byte size limits.
291fn batch_items<T: Serialize + Clone>(items: &[T], batcher: &ItemBatcher) -> Vec<(usize, Vec<T>)> {
292    batcher.batch(items)
293}
294
295/// Creates a Start operation update for map operation.
296fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
297    op_id.apply_to(OperationUpdate::start(
298        &op_id.operation_id,
299        OperationType::Context,
300    ))
301}
302
303/// Creates a Succeed operation update for map operation.
304fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
305    op_id.apply_to(OperationUpdate::succeed(
306        &op_id.operation_id,
307        OperationType::Context,
308        result,
309    ))
310}
311
312/// Creates a Fail operation update for map operation.
313#[allow(dead_code)]
314fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
315    op_id.apply_to(OperationUpdate::fail(
316        &op_id.operation_id,
317        OperationType::Context,
318        error,
319    ))
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;

    /// Mock service client pre-loaded with five successful checkpoint
    /// responses — enough for the small-item tests below; tests needing more
    /// build their own client with `with_checkpoint_responses`.
    fn create_mock_client() -> SharedDurableServiceClient {
        Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-2")))
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-3")))
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-4")))
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-5"))),
        )
    }

    /// Fresh ExecutionState (no prior checkpoints) backed by the given client.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
            "initial-token",
            InitialExecutionState::new(),
            client,
        ))
    }

    /// Operation identifier with a fixed id, parent id, and name.
    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(
            "test-map-123",
            Some("parent-op".to_string()),
            Some("test-map".to_string()),
        )
    }

    /// Logger that forwards to the `tracing` subscriber.
    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    /// Default map configuration (no batching, no concurrency limit).
    fn create_test_config() -> MapConfig {
        MapConfig::default()
    }

    /// Root context to act as the map operation's parent.
    fn create_test_parent_ctx(state: Arc<ExecutionState>) -> DurableContext {
        DurableContext::new(state)
    }

    // Empty input short-circuits: no tasks run, an empty BatchResult is
    // checkpointed and returned with AllCompleted.
    #[tokio::test]
    async fn test_map_handler_empty_collection() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let items: Vec<i32> = vec![];
        let result = map_handler(
            items,
            |_ctx, item: i32, _idx| async move { Ok(item * 2) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert!(batch_result.items.is_empty());
        assert_eq!(
            batch_result.completion_reason,
            CompletionReason::AllCompleted
        );
    }

    // One item runs through a child context and succeeds.
    #[tokio::test]
    async fn test_map_handler_single_item() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let items = vec![5];
        let result = map_handler(
            items,
            |_ctx, item: i32, _idx| async move { Ok(item * 2) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert_eq!(batch_result.total_count(), 1);
        assert_eq!(batch_result.success_count(), 1);
        assert!(batch_result.all_succeeded());
    }

    // Multiple items all succeed; needs more checkpoint responses than the
    // default mock provides, hence with_checkpoint_responses(10).
    #[tokio::test]
    async fn test_map_handler_multiple_items() {
        let client = Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(10));
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let items = vec![1, 2, 3];
        let result = map_handler(
            items,
            |_ctx, item: i32, _idx| async move { Ok(item * 10) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert_eq!(batch_result.total_count(), 3);
        assert_eq!(batch_result.success_count(), 3);
    }

    // A max_concurrency cap still processes every item; only the total count
    // is asserted here (ordering/timing is the executor's concern).
    #[tokio::test]
    async fn test_map_handler_with_concurrency_limit() {
        let client = Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(10));
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = MapConfig {
            max_concurrency: Some(2),
            ..Default::default()
        };
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let items = vec![1, 2, 3, 4, 5];
        let result = map_handler(
            items,
            |_ctx, item: i32, _idx| async move { Ok(item) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert_eq!(batch_result.total_count(), 5);
    }

    // Limits larger than the input collapse everything into one batch at
    // original index 0.
    #[test]
    fn test_batch_items_no_batching() {
        let items = vec![1, 2, 3, 4, 5];
        let batcher = ItemBatcher {
            max_items_per_batch: 10,
            max_bytes_per_batch: 1024,
        };

        let batches = batch_items(&items, &batcher);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].0, 0);
        assert_eq!(batches[0].1, vec![1, 2, 3, 4, 5]);
    }

    // max_items_per_batch splits 5 items into batches of 2, 2, and 1.
    #[test]
    fn test_batch_items_with_batching() {
        let items = vec![1, 2, 3, 4, 5];
        let batcher = ItemBatcher {
            max_items_per_batch: 2,
            max_bytes_per_batch: 1024,
        };

        let batches = batch_items(&items, &batcher);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].1, vec![1, 2]);
        assert_eq!(batches[1].1, vec![3, 4]);
        assert_eq!(batches[2].1, vec![5]);
    }

    // Empty input yields no batches at all (not one empty batch).
    #[test]
    fn test_batch_items_empty() {
        let items: Vec<i32> = vec![];
        let batcher = ItemBatcher::default();

        let batches = batch_items(&items, &batcher);
        assert!(batches.is_empty());
    }

    // The succeed update carries the op id, Context type, result payload, and
    // the parent/name metadata applied from the OperationIdentifier.
    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-map".to_string()),
        );
        let update = create_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.result, Some("result".to_string()));
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-map".to_string()));
    }
}