1use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::concurrency::{BatchResult, CompletionReason, ConcurrentExecutor};
11use crate::config::ChildConfig;
12use crate::config::ParallelConfig;
13use crate::context::{create_operation_span, DurableContext, LogInfo, Logger, OperationIdentifier};
14use crate::error::{DurableError, ErrorObject};
15use crate::handlers::child::child_handler;
16use crate::operation::{OperationType, OperationUpdate};
17use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
18use crate::state::ExecutionState;
19
20pub async fn parallel_handler<T, F, Fut>(
40 branches: Vec<F>,
41 state: &Arc<ExecutionState>,
42 op_id: &OperationIdentifier,
43 parent_ctx: &DurableContext,
44 config: &ParallelConfig,
45 logger: &Arc<dyn Logger>,
46) -> Result<BatchResult<T>, DurableError>
47where
48 T: Serialize + DeserializeOwned + Send + 'static,
49 F: FnOnce(DurableContext) -> Fut + Send + 'static,
50 Fut: std::future::Future<Output = Result<T, DurableError>> + Send + 'static,
51{
52 let span = create_operation_span("parallel", op_id, state.durable_execution_arn());
55 let _guard = span.enter();
56
57 let mut log_info =
58 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
59 if let Some(ref parent_id) = op_id.parent_id {
60 log_info = log_info.with_parent_id(parent_id);
61 }
62
63 logger.debug(
64 &format!(
65 "Starting parallel operation: {} with {} branches",
66 op_id,
67 branches.len()
68 ),
69 &log_info,
70 );
71
72 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
74
75 if checkpoint_result.is_existent() {
76 if let Some(op_type) = checkpoint_result.operation_type() {
78 if op_type != OperationType::Context {
79 span.record("status", "non_deterministic");
80 return Err(DurableError::NonDeterministic {
81 message: format!(
82 "Expected Context operation but found {:?} at operation_id {}",
83 op_type, op_id.operation_id
84 ),
85 operation_id: Some(op_id.operation_id.clone()),
86 });
87 }
88 }
89
90 if checkpoint_result.is_succeeded() {
92 logger.debug(
93 &format!("Replaying succeeded parallel operation: {}", op_id),
94 &log_info,
95 );
96
97 if let Some(result_str) = checkpoint_result.result() {
98 let serdes = JsonSerDes::<BatchResult<T>>::new();
99 let serdes_ctx =
100 SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
101 let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
102 DurableError::SerDes {
103 message: format!("Failed to deserialize parallel result: {}", e),
104 }
105 })?;
106
107 state.track_replay(&op_id.operation_id).await;
108 span.record("status", "replayed_succeeded");
109 return Ok(result);
110 }
111 }
112
113 if checkpoint_result.is_failed() {
115 logger.debug(
116 &format!("Replaying failed parallel operation: {}", op_id),
117 &log_info,
118 );
119
120 state.track_replay(&op_id.operation_id).await;
121 span.record("status", "replayed_failed");
122
123 if let Some(error) = checkpoint_result.error() {
124 return Err(DurableError::UserCode {
125 message: error.error_message.clone(),
126 error_type: error.error_type.clone(),
127 stack_trace: error.stack_trace.clone(),
128 });
129 } else {
130 return Err(DurableError::execution(
131 "Parallel operation failed with unknown error",
132 ));
133 }
134 }
135
136 if checkpoint_result.is_terminal() {
138 state.track_replay(&op_id.operation_id).await;
139 span.record("status", "replayed_terminal");
140
141 let status = checkpoint_result.status().unwrap();
142 return Err(DurableError::execution(format!(
143 "Parallel operation was {}",
144 status
145 )));
146 }
147 }
148
149 if branches.is_empty() {
151 logger.debug("Parallel operation with no branches", &log_info);
152 let result = BatchResult::empty();
153
154 let serdes = JsonSerDes::<BatchResult<T>>::new();
156 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
157 let serialized =
158 serdes
159 .serialize(&result, &serdes_ctx)
160 .map_err(|e| DurableError::SerDes {
161 message: format!("Failed to serialize parallel result: {}", e),
162 })?;
163
164 let succeed_update = create_succeed_update(op_id, Some(serialized));
165 state.create_checkpoint(succeed_update, true).await?;
166
167 span.record("status", "succeeded_empty");
168 return Ok(result);
169 }
170
171 let parallel_ctx = parent_ctx.create_child_context(&op_id.operation_id);
173
174 let start_update = create_start_update(op_id);
177 state.create_checkpoint(start_update, true).await?;
178
179 let total_branches = branches.len();
181 let executor = ConcurrentExecutor::new(
182 total_branches,
183 config.max_concurrency,
184 config.completion_config.clone(),
185 );
186
187 let tasks: Vec<_> = branches
189 .into_iter()
190 .enumerate()
191 .map(|(index, branch)| {
192 let parallel_ctx = parallel_ctx.clone();
193 let state = state.clone();
194 let logger = logger.clone();
195 let op_id = op_id.clone();
196
197 move |_task_idx: usize| {
198 let parallel_ctx = parallel_ctx.clone();
199 let state = state.clone();
200 let logger = logger.clone();
201 let op_id = op_id.clone();
202
203 Box::pin(async move {
204 let child_op_id = OperationIdentifier::new(
206 parallel_ctx.next_operation_id(),
207 Some(op_id.operation_id.clone()),
208 Some(format!("parallel-branch-{}", index)),
209 );
210
211 child_handler(
213 branch,
214 &state,
215 &child_op_id,
216 ¶llel_ctx,
217 &ChildConfig::default(),
218 &logger,
219 )
220 .await
221 })
222 as std::pin::Pin<
223 Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send>,
224 >
225 }
226 })
227 .collect();
228
229 let batch_result = executor.execute(tasks).await;
231
232 logger.debug(
233 &format!(
234 "Parallel operation completed: {} succeeded, {} failed",
235 batch_result.success_count(),
236 batch_result.failure_count()
237 ),
238 &log_info,
239 );
240
241 if batch_result.completion_reason != CompletionReason::Suspended {
243 let serdes = JsonSerDes::<BatchResult<T>>::new();
244 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
245 let serialized =
246 serdes
247 .serialize(&batch_result, &serdes_ctx)
248 .map_err(|e| DurableError::SerDes {
249 message: format!("Failed to serialize parallel result: {}", e),
250 })?;
251
252 let succeed_update = create_succeed_update(op_id, Some(serialized));
253 state.create_checkpoint(succeed_update, true).await?;
254
255 state.mark_parent_done(&op_id.operation_id).await;
257 span.record("status", "succeeded");
258 } else {
259 span.record("status", "suspended");
260 }
261
262 Ok(batch_result)
263}
264
265fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
267 op_id.apply_to(OperationUpdate::start(
268 &op_id.operation_id,
269 OperationType::Context,
270 ))
271}
272
273fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
275 op_id.apply_to(OperationUpdate::succeed(
276 &op_id.operation_id,
277 OperationType::Context,
278 result,
279 ))
280}
281
282#[allow(dead_code)]
284fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
285 op_id.apply_to(OperationUpdate::fail(
286 &op_id.operation_id,
287 OperationType::Context,
288 error,
289 ))
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;

    /// Boxed branch closure type shared by every test, so the verbose
    /// trait-object signature is spelled out once instead of at each call site.
    type BoxedBranch = Box<
        dyn FnOnce(
                DurableContext,
            ) -> std::pin::Pin<
                Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
            > + Send,
    >;

    /// Builds a branch that immediately resolves to `value`.
    fn ok_branch(value: i32) -> BoxedBranch {
        Box::new(move |_ctx| {
            Box::pin(async move { Ok(value) })
                as std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
                >
        })
    }

    fn create_mock_client() -> SharedDurableServiceClient {
        Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(10))
    }

    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
            "initial-token",
            InitialExecutionState::new(),
            client,
        ))
    }

    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(
            "test-parallel-123",
            Some("parent-op".to_string()),
            Some("test-parallel".to_string()),
        )
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    fn create_test_config() -> ParallelConfig {
        ParallelConfig::default()
    }

    fn create_test_parent_ctx(state: Arc<ExecutionState>) -> DurableContext {
        DurableContext::new(state)
    }

    #[tokio::test]
    async fn test_parallel_handler_empty_branches() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let branches: Vec<BoxedBranch> = vec![];
        let result =
            parallel_handler(branches, &state, &op_id, &parent_ctx, &config, &logger).await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert!(batch_result.items.is_empty());
        assert_eq!(
            batch_result.completion_reason,
            CompletionReason::AllCompleted
        );
    }

    #[tokio::test]
    async fn test_parallel_handler_single_branch() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let branches: Vec<BoxedBranch> = vec![ok_branch(42)];

        let result =
            parallel_handler(branches, &state, &op_id, &parent_ctx, &config, &logger).await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert_eq!(batch_result.total_count(), 1);
        assert_eq!(batch_result.success_count(), 1);
    }

    #[tokio::test]
    async fn test_parallel_handler_multiple_branches() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let branches: Vec<BoxedBranch> = vec![ok_branch(1), ok_branch(2), ok_branch(3)];

        let result =
            parallel_handler(branches, &state, &op_id, &parent_ctx, &config, &logger).await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert_eq!(batch_result.total_count(), 3);
        assert_eq!(batch_result.success_count(), 3);
    }

    #[tokio::test]
    async fn test_parallel_handler_with_concurrency_limit() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = ParallelConfig {
            max_concurrency: Some(2),
            ..Default::default()
        };
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());

        let branches: Vec<BoxedBranch> =
            vec![ok_branch(1), ok_branch(2), ok_branch(3), ok_branch(4)];

        let result =
            parallel_handler(branches, &state, &op_id, &parent_ctx, &config, &logger).await;

        assert!(result.is_ok());
        let batch_result = result.unwrap();
        assert_eq!(batch_result.total_count(), 4);
    }

    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-parallel".to_string()),
        );
        let update = create_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.result, Some("result".to_string()));
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-parallel".to_string()));
    }
}