Skip to main content

durable_execution_sdk/handlers/
parallel.rs

//! Parallel operation handler for the AWS Durable Execution SDK.
//!
//! This module implements the parallel handler, which executes multiple
//! independent operations concurrently.
5
6use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::concurrency::{BatchResult, CompletionReason, ConcurrentExecutor};
11use crate::config::ChildConfig;
12use crate::config::ParallelConfig;
13use crate::context::{create_operation_span, DurableContext, LogInfo, Logger, OperationIdentifier};
14use crate::error::{DurableError, ErrorObject};
15use crate::handlers::child::child_handler;
16use crate::operation::{OperationType, OperationUpdate};
17use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
18use crate::state::ExecutionState;
19
20/// Executes multiple operations in parallel.
21///
22/// This handler implements the parallel semantics:
23/// - Creates a child context for each branch
24/// - Uses ConcurrentExecutor for parallel execution
25/// - Returns BatchResult with results for all branches
26///
27/// # Arguments
28///
29/// * `branches` - The list of functions to execute in parallel
30/// * `state` - The execution state for checkpointing
31/// * `op_id` - The operation identifier for the parallel operation
32/// * `parent_ctx` - The parent DurableContext
33/// * `config` - Parallel configuration
34/// * `logger` - Logger for structured logging
35///
36/// # Returns
37///
38/// A `BatchResult` containing results for all branches.
39pub async fn parallel_handler<T, F, Fut>(
40    branches: Vec<F>,
41    state: &Arc<ExecutionState>,
42    op_id: &OperationIdentifier,
43    parent_ctx: &DurableContext,
44    config: &ParallelConfig,
45    logger: &Arc<dyn Logger>,
46) -> Result<BatchResult<T>, DurableError>
47where
48    T: Serialize + DeserializeOwned + Send + 'static,
49    F: FnOnce(DurableContext) -> Fut + Send + 'static,
50    Fut: std::future::Future<Output = Result<T, DurableError>> + Send + 'static,
51{
52    // Create tracing span for this operation
53    // Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 3.6
54    let span = create_operation_span("parallel", op_id, state.durable_execution_arn());
55    let _guard = span.enter();
56
57    let mut log_info =
58        LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
59    if let Some(ref parent_id) = op_id.parent_id {
60        log_info = log_info.with_parent_id(parent_id);
61    }
62
63    logger.debug(
64        &format!(
65            "Starting parallel operation: {} with {} branches",
66            op_id,
67            branches.len()
68        ),
69        &log_info,
70    );
71
72    // Check for existing checkpoint (replay)
73    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
74
75    if checkpoint_result.is_existent() {
76        // Check for non-deterministic execution
77        if let Some(op_type) = checkpoint_result.operation_type() {
78            if op_type != OperationType::Context {
79                span.record("status", "non_deterministic");
80                return Err(DurableError::NonDeterministic {
81                    message: format!(
82                        "Expected Context operation but found {:?} at operation_id {}",
83                        op_type, op_id.operation_id
84                    ),
85                    operation_id: Some(op_id.operation_id.clone()),
86                });
87            }
88        }
89
90        // Handle succeeded checkpoint
91        if checkpoint_result.is_succeeded() {
92            logger.debug(
93                &format!("Replaying succeeded parallel operation: {}", op_id),
94                &log_info,
95            );
96
97            if let Some(result_str) = checkpoint_result.result() {
98                let serdes = JsonSerDes::<BatchResult<T>>::new();
99                let serdes_ctx =
100                    SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
101                let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
102                    DurableError::SerDes {
103                        message: format!("Failed to deserialize parallel result: {}", e),
104                    }
105                })?;
106
107                state.track_replay(&op_id.operation_id).await;
108                span.record("status", "replayed_succeeded");
109                return Ok(result);
110            }
111        }
112
113        // Handle failed checkpoint
114        if checkpoint_result.is_failed() {
115            logger.debug(
116                &format!("Replaying failed parallel operation: {}", op_id),
117                &log_info,
118            );
119
120            state.track_replay(&op_id.operation_id).await;
121            span.record("status", "replayed_failed");
122
123            if let Some(error) = checkpoint_result.error() {
124                return Err(DurableError::UserCode {
125                    message: error.error_message.clone(),
126                    error_type: error.error_type.clone(),
127                    stack_trace: error.stack_trace.clone(),
128                });
129            } else {
130                return Err(DurableError::execution(
131                    "Parallel operation failed with unknown error",
132                ));
133            }
134        }
135
136        // Handle other terminal states
137        if checkpoint_result.is_terminal() {
138            state.track_replay(&op_id.operation_id).await;
139            span.record("status", "replayed_terminal");
140
141            let status = checkpoint_result.status().unwrap();
142            return Err(DurableError::execution(format!(
143                "Parallel operation was {}",
144                status
145            )));
146        }
147    }
148
149    // Handle empty branches
150    if branches.is_empty() {
151        logger.debug("Parallel operation with no branches", &log_info);
152        let result = BatchResult::empty();
153
154        // Checkpoint the empty result
155        let serdes = JsonSerDes::<BatchResult<T>>::new();
156        let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
157        let serialized =
158            serdes
159                .serialize(&result, &serdes_ctx)
160                .map_err(|e| DurableError::SerDes {
161                    message: format!("Failed to serialize parallel result: {}", e),
162                })?;
163
164        let succeed_update = create_succeed_update(op_id, Some(serialized));
165        state.create_checkpoint(succeed_update, true).await?;
166
167        span.record("status", "succeeded_empty");
168        return Ok(result);
169    }
170
171    // Create the parallel context (child of parent)
172    let parallel_ctx = parent_ctx.create_child_context(&op_id.operation_id);
173
174    // Checkpoint START for the parallel operation before spawning children
175    // This ensures the parent operation exists when children reference it
176    let start_update = create_start_update(op_id);
177    state.create_checkpoint(start_update, true).await?;
178
179    // Create the executor
180    let total_branches = branches.len();
181    let executor = ConcurrentExecutor::new(
182        total_branches,
183        config.max_concurrency,
184        config.completion_config.clone(),
185    );
186
187    // Build task closures
188    let tasks: Vec<_> = branches
189        .into_iter()
190        .enumerate()
191        .map(|(index, branch)| {
192            let parallel_ctx = parallel_ctx.clone();
193            let state = state.clone();
194            let logger = logger.clone();
195            let op_id = op_id.clone();
196
197            move |_task_idx: usize| {
198                let parallel_ctx = parallel_ctx.clone();
199                let state = state.clone();
200                let logger = logger.clone();
201                let op_id = op_id.clone();
202
203                Box::pin(async move {
204                    // Create child operation ID for this branch
205                    let child_op_id = OperationIdentifier::new(
206                        parallel_ctx.next_operation_id(),
207                        Some(op_id.operation_id.clone()),
208                        Some(format!("parallel-branch-{}", index)),
209                    );
210
211                    // Execute in child context
212                    child_handler(
213                        branch,
214                        &state,
215                        &child_op_id,
216                        &parallel_ctx,
217                        &ChildConfig::default(),
218                        &logger,
219                    )
220                    .await
221                })
222                    as std::pin::Pin<
223                        Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send>,
224                    >
225            }
226        })
227        .collect();
228
229    // Execute all branches
230    let batch_result = executor.execute(tasks).await;
231
232    logger.debug(
233        &format!(
234            "Parallel operation completed: {} succeeded, {} failed",
235            batch_result.success_count(),
236            batch_result.failure_count()
237        ),
238        &log_info,
239    );
240
241    // Checkpoint the result (only if not suspended)
242    if batch_result.completion_reason != CompletionReason::Suspended {
243        let serdes = JsonSerDes::<BatchResult<T>>::new();
244        let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
245        let serialized =
246            serdes
247                .serialize(&batch_result, &serdes_ctx)
248                .map_err(|e| DurableError::SerDes {
249                    message: format!("Failed to serialize parallel result: {}", e),
250                })?;
251
252        let succeed_update = create_succeed_update(op_id, Some(serialized));
253        state.create_checkpoint(succeed_update, true).await?;
254
255        // Mark parent as done
256        state.mark_parent_done(&op_id.operation_id).await;
257        span.record("status", "succeeded");
258    } else {
259        span.record("status", "suspended");
260    }
261
262    Ok(batch_result)
263}
264
265/// Creates a Start operation update for parallel operation.
266fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
267    op_id.apply_to(OperationUpdate::start(
268        &op_id.operation_id,
269        OperationType::Context,
270    ))
271}
272
273/// Creates a Succeed operation update for parallel operation.
274fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
275    op_id.apply_to(OperationUpdate::succeed(
276        &op_id.operation_id,
277        OperationType::Context,
278        result,
279    ))
280}
281
282/// Creates a Fail operation update for parallel operation.
283#[allow(dead_code)]
284fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
285    op_id.apply_to(OperationUpdate::fail(
286        &op_id.operation_id,
287        OperationType::Context,
288        error,
289    ))
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;

    /// Boxed, `Send` future produced by every test branch.
    type BranchFuture =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>>;

    /// Type-erased branch, so closures with different bodies fit in one `Vec`.
    type Branch = Box<dyn FnOnce(DurableContext) -> BranchFuture + Send>;

    /// Mock service client pre-loaded with checkpoint responses.
    fn create_mock_client() -> SharedDurableServiceClient {
        Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(10))
    }

    /// Execution state built from a default `InitialExecutionState`.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
            "initial-token",
            InitialExecutionState::new(),
            client,
        ))
    }

    /// Identifier of the parallel operation under test.
    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(
            "test-parallel-123",
            Some("parent-op".to_string()),
            Some("test-parallel".to_string()),
        )
    }

    /// Branch that resolves immediately to `Ok(value)`.
    fn ok_branch(value: i32) -> Branch {
        Box::new(move |_ctx: DurableContext| -> BranchFuture {
            Box::pin(async move { Ok(value) })
        })
    }

    /// Runs `parallel_handler` against a fresh mock-backed state.
    async fn run_parallel(
        branches: Vec<Branch>,
        config: ParallelConfig,
    ) -> Result<BatchResult<i32>, DurableError> {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let logger: Arc<dyn Logger> = Arc::new(TracingLogger);
        let parent_ctx = DurableContext::new(state.clone());

        parallel_handler(branches, &state, &op_id, &parent_ctx, &config, &logger).await
    }

    #[tokio::test]
    async fn test_parallel_handler_empty_branches() {
        let batch_result = run_parallel(Vec::new(), ParallelConfig::default())
            .await
            .expect("parallel with no branches should succeed");

        assert!(batch_result.items.is_empty());
        assert_eq!(
            batch_result.completion_reason,
            CompletionReason::AllCompleted
        );
    }

    #[tokio::test]
    async fn test_parallel_handler_single_branch() {
        let batch_result = run_parallel(vec![ok_branch(42)], ParallelConfig::default())
            .await
            .expect("single-branch parallel should succeed");

        assert_eq!(batch_result.total_count(), 1);
        assert_eq!(batch_result.success_count(), 1);
    }

    #[tokio::test]
    async fn test_parallel_handler_multiple_branches() {
        let branches: Vec<Branch> = (1..=3).map(ok_branch).collect();

        let batch_result = run_parallel(branches, ParallelConfig::default())
            .await
            .expect("multi-branch parallel should succeed");

        assert_eq!(batch_result.total_count(), 3);
        assert_eq!(batch_result.success_count(), 3);
    }

    #[tokio::test]
    async fn test_parallel_handler_with_concurrency_limit() {
        let config = ParallelConfig {
            max_concurrency: Some(2),
            ..Default::default()
        };
        let branches: Vec<Branch> = (1..=4).map(ok_branch).collect();

        let batch_result = run_parallel(branches, config)
            .await
            .expect("concurrency-limited parallel should succeed");

        assert_eq!(batch_result.total_count(), 4);
    }

    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-parallel".to_string()),
        );

        let update = create_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.result.as_deref(), Some("result"));
        assert_eq!(update.parent_id.as_deref(), Some("parent-456"));
        assert_eq!(update.name.as_deref(), Some("my-parallel"));
    }
}