use std::sync::Arc;
use serde::{de::DeserializeOwned, Serialize};
use crate::config::ChildConfig;
use crate::context::{DurableContext, LogInfo, Logger, OperationIdentifier};
use crate::error::{DurableError, ErrorObject};
use crate::operation::{OperationType, OperationUpdate};
use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
use crate::state::ExecutionState;
pub async fn child_handler<T, F, Fut>(
func: F,
state: &Arc<ExecutionState>,
op_id: &OperationIdentifier,
parent_ctx: &DurableContext,
config: &ChildConfig,
logger: &Arc<dyn Logger>,
) -> Result<T, DurableError>
where
T: Serialize + DeserializeOwned + Send,
F: FnOnce(DurableContext) -> Fut + Send,
Fut: std::future::Future<Output = Result<T, DurableError>> + Send,
{
let mut log_info =
LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
if let Some(ref parent_id) = op_id.parent_id {
log_info = log_info.with_parent_id(parent_id);
}
logger.debug(&format!("Starting child context: {}", op_id), &log_info);
let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
if checkpoint_result.is_existent() {
if let Some(op_type) = checkpoint_result.operation_type() {
if op_type != OperationType::Context {
return Err(DurableError::NonDeterministic {
message: format!(
"Expected Context operation but found {:?} at operation_id {}",
op_type, op_id.operation_id
),
operation_id: Some(op_id.operation_id.clone()),
});
}
}
if checkpoint_result.is_succeeded() {
logger.debug(
&format!("Replaying succeeded child context: {}", op_id),
&log_info,
);
if config.replay_children && state.has_replay_children(&op_id.operation_id).await {
logger.debug(
&format!(
"ReplayChildren enabled, replaying child operations for: {}",
op_id
),
&log_info,
);
let child_ctx = parent_ctx.create_child_context(&op_id.operation_id);
let result = func(child_ctx).await;
state.track_replay(&op_id.operation_id).await;
return result;
}
if let Some(result_str) = checkpoint_result.result() {
let serdes = JsonSerDes::<T>::new();
let serdes_ctx =
SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
DurableError::SerDes {
message: format!("Failed to deserialize child context result: {}", e),
}
})?;
state.track_replay(&op_id.operation_id).await;
return Ok(result);
}
}
if checkpoint_result.is_failed() {
logger.debug(
&format!("Replaying failed child context: {}", op_id),
&log_info,
);
state.track_replay(&op_id.operation_id).await;
if let Some(error) = checkpoint_result.error() {
return Err(DurableError::UserCode {
message: error.error_message.clone(),
error_type: error.error_type.clone(),
stack_trace: error.stack_trace.clone(),
});
} else {
return Err(DurableError::execution(
"Child context failed with unknown error",
));
}
}
if checkpoint_result.is_terminal() {
state.track_replay(&op_id.operation_id).await;
let status = checkpoint_result.status().unwrap();
return Err(DurableError::execution(format!(
"Child context was {}",
status
)));
}
} else {
let start_update = create_start_update(op_id, config);
state.create_checkpoint(start_update, true).await?;
}
let child_ctx = parent_ctx.create_child_context(&op_id.operation_id);
logger.debug(&format!("Executing child context: {}", op_id), &log_info);
let result = func(child_ctx).await;
let serdes = JsonSerDes::<T>::new();
let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
match result {
Ok(value) => {
let serialized =
serdes
.serialize(&value, &serdes_ctx)
.map_err(|e| DurableError::SerDes {
message: format!("Failed to serialize child context result: {}", e),
})?;
const SUMMARY_THRESHOLD: usize = 256 * 1024;
let result_to_store = match &config.summary_generator {
Some(generator) if serialized.len() > SUMMARY_THRESHOLD => {
logger.debug(
&format!(
"Result size ({} bytes) exceeds 256KB threshold, generating summary",
serialized.len()
),
&log_info,
);
generator(&serialized)
}
_ => serialized,
};
let succeed_update = create_succeed_update(op_id, Some(result_to_store));
state.create_checkpoint(succeed_update, true).await?;
state.mark_parent_done(&op_id.operation_id).await;
logger.debug("Child context completed successfully", &log_info);
Ok(value)
}
Err(error) => {
if error.is_suspend() {
return Err(error);
}
let error = match &config.error_mapper {
Some(mapper) => mapper(error),
None => error,
};
let error_obj = ErrorObject::from(&error);
let fail_update = create_fail_update(op_id, error_obj);
state.create_checkpoint(fail_update, true).await?;
state.mark_parent_done(&op_id.operation_id).await;
logger.error(&format!("Child context failed: {}", error), &log_info);
Err(error)
}
}
}
/// Builds the `START` checkpoint update for a child context.
///
/// The update carries `op_id`'s identifying data via `apply_to`. When
/// `config.replay_children` is set, it also carries context options with
/// `replay_children` enabled.
fn create_start_update(op_id: &OperationIdentifier, config: &ChildConfig) -> OperationUpdate {
    let base = OperationUpdate::start(&op_id.operation_id, OperationType::Context);
    let update = op_id.apply_to(base);
    if !config.replay_children {
        return update;
    }
    update.with_context_options(Some(true))
}
/// Builds the `SUCCEED` checkpoint update for a child context.
///
/// `result` is the payload to store (the serialized result, or a summary of
/// it). The update also carries `op_id`'s identifying data via `apply_to`.
fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
    let id = &op_id.operation_id;
    let update = OperationUpdate::succeed(id, OperationType::Context, result);
    op_id.apply_to(update)
}
fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
op_id.apply_to(OperationUpdate::fail(
&op_id.operation_id,
OperationType::Context,
error,
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};

    /// Durable execution ARN shared by every test state.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    /// Mock client configured with a single successful checkpoint response (`token-1`).
    fn create_mock_client() -> SharedDurableServiceClient {
        Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
        )
    }

    /// Execution state built from the default `InitialExecutionState`.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        create_test_state_with_initial(client, InitialExecutionState::new())
    }

    /// Execution state seeded with `initial_state` and backed by `client`.
    fn create_test_state_with_initial(
        client: SharedDurableServiceClient,
        initial_state: InitialExecutionState,
    ) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "initial-token",
            initial_state,
            client,
        ))
    }

    /// Execution state whose history already contains `op`, used to drive the
    /// replay paths. The client has no canned checkpoint responses.
    fn create_replay_state(op: Operation) -> Arc<ExecutionState> {
        create_test_state_with_initial(
            Arc::new(MockDurableServiceClient::new()),
            InitialExecutionState::with_operations(vec![op]),
        )
    }

    /// Identifier `test-child-123` under parent `parent-op`, named `test-child`.
    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(
            "test-child-123",
            Some("parent-op".to_string()),
            Some("test-child".to_string()),
        )
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    fn create_test_config() -> ChildConfig {
        ChildConfig::default()
    }

    fn create_test_parent_ctx(state: Arc<ExecutionState>) -> DurableContext {
        DurableContext::new(state)
    }

    /// A fresh child context returns the value produced by its function.
    #[tokio::test]
    async fn test_child_handler_success() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Ok(42) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    /// A failing child function's error is returned to the caller.
    #[tokio::test]
    async fn test_child_handler_failure() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Err(DurableError::execution("child error")) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Execution { message, .. } => {
                assert!(message.contains("child error"));
            }
            _ => panic!("Expected Execution error"),
        }
    }

    /// A succeeded checkpoint is replayed from its stored result without calling the function.
    #[tokio::test]
    async fn test_child_handler_replay_success() {
        let mut op = Operation::new("test-child-123", OperationType::Context);
        op.status = OperationStatus::Succeeded;
        op.result = Some("42".to_string());
        let state = create_replay_state(op);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { panic!("Function should not be called during replay") },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    /// A failed checkpoint is replayed as a `UserCode` error without calling the function.
    #[tokio::test]
    async fn test_child_handler_replay_failure() {
        let mut op = Operation::new("test-child-123", OperationType::Context);
        op.status = OperationStatus::Failed;
        op.error = Some(ErrorObject::new("ChildError", "Previous failure"));
        let state = create_replay_state(op);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { panic!("Function should not be called during replay") },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::UserCode { message, .. } => {
                assert!(message.contains("Previous failure"));
            }
            _ => panic!("Expected UserCode error"),
        }
    }

    /// A checkpoint of a different operation type at the same id is non-deterministic.
    #[tokio::test]
    async fn test_child_handler_non_deterministic_detection() {
        let mut op = Operation::new("test-child-123", OperationType::Step);
        op.status = OperationStatus::Succeeded;
        let state = create_replay_state(op);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Ok(42) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::NonDeterministic { operation_id, .. } => {
                assert_eq!(operation_id, Some("test-child-123".to_string()));
            }
            _ => panic!("Expected NonDeterministic error"),
        }
    }

    /// A suspend error from the child function is propagated unchanged.
    #[tokio::test]
    async fn test_child_handler_suspend_not_checkpointed() {
        let client = Arc::new(MockDurableServiceClient::new());
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Err(DurableError::suspend()) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Suspend { .. } => {}
            _ => panic!("Expected Suspend error"),
        }
    }

    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-child".to_string()),
        );
        let update = create_succeed_update(&op_id, Some("result".to_string()));
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.result, Some("result".to_string()));
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-child".to_string()));
    }

    #[test]
    fn test_create_fail_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let error = ErrorObject::new("ChildError", "test message");
        let update = create_fail_update(&op_id, error);
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert!(update.error.is_some());
        assert_eq!(update.error.unwrap().error_type, "ChildError");
    }

    #[test]
    fn test_create_start_update_without_replay_children() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-child".to_string()),
        );
        let config = ChildConfig::default();
        let update = create_start_update(&op_id, &config);
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-child".to_string()));
        assert!(update.context_options.is_none());
    }

    #[test]
    fn test_create_start_update_with_replay_children() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-child".to_string()),
        );
        let config = ChildConfig::with_replay_children();
        let update = create_start_update(&op_id, &config);
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Context);
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-child".to_string()));
        assert!(update.context_options.is_some());
        assert_eq!(update.context_options.unwrap().replay_children, Some(true));
    }

    /// Enabling `replay_children` does not change the first-execution result.
    #[tokio::test]
    async fn test_child_handler_with_replay_children_config() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = ChildConfig::with_replay_children();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Ok(42) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    /// A configured error mapper transforms the error returned to the caller.
    #[tokio::test]
    async fn test_child_handler_error_mapper_applied() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let config = ChildConfig::default().set_error_mapper(Arc::new(|_err| {
            DurableError::execution("mapped error message")
        }));
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Err(DurableError::execution("original error")) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Execution { message, .. } => {
                assert_eq!(message, "mapped error message");
            }
            other => panic!("Expected Execution error, got: {:?}", other),
        }
    }

    /// The error mapper is never applied to suspend errors.
    #[tokio::test]
    async fn test_child_handler_error_mapper_skipped_for_suspend() {
        let client = Arc::new(MockDurableServiceClient::new());
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let config = ChildConfig::default().set_error_mapper(Arc::new(|_err| {
            DurableError::execution("should not be mapped")
        }));
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Err(DurableError::suspend()) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Suspend { .. } => {}
            other => panic!("Expected Suspend error, got: {:?}", other),
        }
    }

    /// Without an error mapper the original error is returned unchanged.
    #[tokio::test]
    async fn test_child_handler_error_mapper_none_preserves_behavior() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let config = ChildConfig::default();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let result: Result<i32, DurableError> = child_handler(
            |_ctx| async { Err(DurableError::execution("unchanged error")) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Execution { message, .. } => {
                assert_eq!(message, "unchanged error");
            }
            other => panic!("Expected Execution error, got: {:?}", other),
        }
    }

    /// The summary generator runs when the serialized result exceeds 256KB.
    #[tokio::test]
    async fn test_child_handler_summary_generator_invoked_when_over_256kb() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let generator_called = Arc::new(AtomicBool::new(false));
        let generator_called_clone = generator_called.clone();
        let config =
            ChildConfig::default().set_summary_generator(Arc::new(move |_serialized: &str| {
                generator_called_clone.store(true, Ordering::SeqCst);
                r#"{"summary":"large result"}"#.to_string()
            }));
        let large_string = "x".repeat(300_000);
        let result: Result<String, DurableError> = child_handler(
            |_ctx| {
                let s = large_string.clone();
                async move { Ok(s) }
            },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_ok());
        assert!(
            generator_called.load(Ordering::SeqCst),
            "summary_generator should have been invoked for result > 256KB"
        );
    }

    /// The summary generator is not run for results at or below 256KB.
    #[tokio::test]
    async fn test_child_handler_summary_generator_not_invoked_when_under_256kb() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let generator_called = Arc::new(AtomicBool::new(false));
        let generator_called_clone = generator_called.clone();
        let config =
            ChildConfig::default().set_summary_generator(Arc::new(move |_serialized: &str| {
                generator_called_clone.store(true, Ordering::SeqCst);
                r#"{"summary":"should not be called"}"#.to_string()
            }));
        let result: Result<String, DurableError> = child_handler(
            |_ctx| async { Ok("small result".to_string()) },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_ok());
        assert!(
            !generator_called.load(Ordering::SeqCst),
            "summary_generator should NOT have been invoked for result <= 256KB"
        );
    }

    /// Without a summary generator a large result is still returned in full.
    #[tokio::test]
    async fn test_child_handler_summary_generator_none_preserves_behavior() {
        let client = create_mock_client();
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();
        let parent_ctx = create_test_parent_ctx(state.clone());
        let config = ChildConfig::default();
        let large_string = "y".repeat(300_000);
        let result: Result<String, DurableError> = child_handler(
            |_ctx| {
                let s = large_string.clone();
                async move { Ok(s) }
            },
            &state,
            &op_id,
            &parent_ctx,
            &config,
            &logger,
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "y".repeat(300_000));
    }
}