mod common;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use durable_execution_sdk::client::CheckpointResponse;
use durable_execution_sdk::config::CallbackConfig;
use durable_execution_sdk::context::DurableContext;
use durable_execution_sdk::duration::Duration;
use durable_execution_sdk::error::DurableError;
use durable_execution_sdk::lambda::InitialExecutionState;
use durable_execution_sdk::operation::OperationType;
use durable_execution_sdk::state::ExecutionState;
use proptest::prelude::*;
use common::*;
/// Builds a shared mock service client configured with four checkpoint responses.
///
/// The first configured response pairs the token `"token-1"` with the callback
/// id `"test-callback-id"`; the other three are plain `Ok` responses carrying
/// the tokens `"token-2"`, `"token-3"` and `"token-4"`.
// NOTE(review): presumably the mock hands these out in registration order — confirm in `common`.
fn create_callback_mock_client() -> Arc<MockDurableServiceClient> {
    let mock = MockDurableServiceClient::new();
    let mock = mock.with_checkpoint_response_with_callback("token-1", "test-callback-id");
    let mock = mock.with_checkpoint_response(Ok(CheckpointResponse::new("token-2")));
    let mock = mock.with_checkpoint_response(Ok(CheckpointResponse::new("token-3")));
    let mock = mock.with_checkpoint_response(Ok(CheckpointResponse::new("token-4")));
    Arc::new(mock)
}
/// Creates a shared [`ExecutionState`] backed by `client`.
///
/// The state is constructed from `TEST_EXECUTION_ARN`, `TEST_CHECKPOINT_TOKEN`
/// and a newly created `InitialExecutionState`.
fn create_fresh_state(client: Arc<MockDurableServiceClient>) -> Arc<ExecutionState> {
    let initial = InitialExecutionState::new();
    let state = ExecutionState::new(TEST_EXECUTION_ARN, TEST_CHECKPOINT_TOKEN, initial, client);
    Arc::new(state)
}
/// On a fresh execution, `wait_for_callback` must invoke the submitter exactly
/// once, return `DurableError::Suspend`, and record at least two checkpoint
/// calls on the mock client.
#[tokio::test]
async fn test_wait_for_callback_checkpoints_submitter_fresh_execution() {
    let client = create_callback_mock_client();
    let ctx = DurableContext::new(create_fresh_state(Arc::clone(&client)));

    // Counter shared with the submitter so the test can observe how often it ran.
    let submitter_call_count = Arc::new(AtomicU32::new(0));
    let submitter = {
        let shared_counter = Arc::clone(&submitter_call_count);
        move |_callback_id: String| {
            let counter = Arc::clone(&shared_counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
            }
        }
    };

    let result: Result<String, DurableError> = ctx.wait_for_callback(submitter, None).await;

    assert!(result.is_err());
    assert!(
        matches!(&result, Err(DurableError::Suspend { .. })),
        "Should suspend waiting for callback, got: {:?}",
        result
    );

    let observed_calls = submitter_call_count.load(Ordering::SeqCst);
    assert_eq!(
        observed_calls,
        1,
        "Submitter should be called exactly once during fresh execution"
    );

    let recorded = client.get_checkpoint_calls();
    assert!(
        recorded.len() >= 2,
        "Should have created checkpoints for callback and child context operations"
    );
}
#[tokio::test]
async fn test_wait_for_callback_replay_behavior() {
let client = create_callback_mock_client();
let state = create_fresh_state(client.clone());
let ctx = DurableContext::new(state);
let submitter = |_callback_id: String| async move {
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
};
let _result: Result<String, DurableError> = ctx.wait_for_callback(submitter, None).await;
let calls = client.get_checkpoint_calls();
let has_callback = calls.iter().any(|call| {
call.operations
.iter()
.any(|op| op.operation_type == OperationType::Callback)
});
assert!(has_callback, "Should have callback checkpoint");
let has_context = calls.iter().any(|call| {
call.operations
.iter()
.any(|op| op.operation_type == OperationType::Context)
});
assert!(
has_context,
"Should have context checkpoint for child context"
);
}
// Property: when the submitter fails with an arbitrary message, `wait_for_callback`
// returns `DurableError::UserCode` whose `error_type` is "SubmitterError" and whose
// `message` contains the submitter's original text.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    #[test]
    fn prop_wait_for_callback_error_propagation(
        error_message in "[a-zA-Z0-9 _-]{1,100}",
    ) {
        // Each generated case drives the async call on its own runtime.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let ctx = DurableContext::new(create_fresh_state(create_callback_mock_client()));

            // The submitter owns its own copy so `error_message` stays usable below.
            let failure_text = error_message.clone();
            let submitter = move |_callback_id: String| {
                let text = failure_text.clone();
                async move {
                    Err::<(), Box<dyn std::error::Error + Send + Sync>>(text.into())
                }
            };

            let result: Result<String, DurableError> =
                ctx.wait_for_callback(submitter, None).await;
            prop_assert!(result.is_err(), "Should return error when submitter fails");

            match result.unwrap_err() {
                DurableError::UserCode { message, error_type, .. } => {
                    prop_assert_eq!(
                        error_type,
                        "SubmitterError",
                        "Error type should be SubmitterError"
                    );
                    prop_assert!(
                        message.contains(&error_message),
                        "Error message should contain original message. Expected to contain '{}', got '{}'",
                        error_message,
                        message
                    );
                }
                other => {
                    // Any other variant fails the case and reports what was received.
                    prop_assert!(
                        false,
                        "Expected UserCode error with SubmitterError type, got {:?}",
                        other
                    );
                }
            }
            Ok(())
        })?;
    }
}
// Property: the `timeout` and `heartbeat_timeout` given in `CallbackConfig` show up,
// converted to seconds, in the `callback_options` of the checkpointed Callback operation.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    #[test]
    fn prop_wait_for_callback_configuration_passthrough(
        timeout_hours in 1u64..168u64,
        heartbeat_minutes in 1u64..60u64,
    ) {
        // Each generated case drives the async call on its own runtime.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let client = create_callback_mock_client();
            let ctx = DurableContext::new(create_fresh_state(Arc::clone(&client)));

            let config = CallbackConfig {
                timeout: Duration::from_hours(timeout_hours),
                heartbeat_timeout: Duration::from_minutes(heartbeat_minutes),
                ..Default::default()
            };
            let submitter = |_callback_id: String| async move {
                Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
            };
            let _outcome: Result<String, DurableError> =
                ctx.wait_for_callback(submitter, Some(config)).await;

            let calls = client.get_checkpoint_calls();
            prop_assert!(
                !calls.is_empty(),
                "Should have made at least one checkpoint call"
            );

            // First Callback operation across all recorded calls, in recording order.
            let callback_op = calls
                .iter()
                .flat_map(|call| call.operations.iter())
                .find(|op| op.operation_type == OperationType::Callback);
            prop_assert!(
                callback_op.is_some(),
                "Should have a checkpoint call with Callback operation"
            );

            let options = callback_op.and_then(|op| op.callback_options.as_ref());
            prop_assert!(
                options.is_some(),
                "Callback operation should have callback_options"
            );
            let options = options.unwrap();

            let expected_timeout_seconds = timeout_hours * 3600;
            prop_assert_eq!(
                options.timeout_seconds,
                Some(expected_timeout_seconds),
                "Timeout should be {} seconds (from {} hours)",
                expected_timeout_seconds,
                timeout_hours
            );

            let expected_heartbeat_seconds = heartbeat_minutes * 60;
            prop_assert_eq!(
                options.heartbeat_timeout_seconds,
                Some(expected_heartbeat_seconds),
                "Heartbeat timeout should be {} seconds (from {} minutes)",
                expected_heartbeat_seconds,
                heartbeat_minutes
            );
            Ok(())
        })?;
    }
}
#[tokio::test]
async fn test_wait_for_callback_creates_correct_checkpoint_structure() {
let client = create_callback_mock_client();
let state = create_fresh_state(client.clone());
let ctx = DurableContext::new(state);
let submitter = |_callback_id: String| async move {
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
};
let _result: Result<String, DurableError> = ctx.wait_for_callback(submitter, None).await;
let calls = client.get_checkpoint_calls();
assert!(
calls.len() >= 2,
"Should have at least 2 checkpoint calls, got {}",
calls.len()
);
let first_call = &calls[0];
assert!(
first_call
.operations
.iter()
.any(|op| op.operation_type == OperationType::Callback),
"First checkpoint should include Callback operation"
);
}
/// The submitter must receive the callback id supplied by the mock's first
/// checkpoint response. If `wait_for_callback` returns `Ok`, the value must
/// equal the expected string; a `Suspend` error is also tolerated, and any
/// other error fails the test.
#[tokio::test]
async fn test_wait_for_callback_result_return() {
    use durable_execution_sdk::client::NewExecutionState;
    use durable_execution_sdk::operation::{CallbackDetails, Operation, OperationStatus};

    let expected_result = "approval_granted";
    let callback_id = "test-callback-result-id";

    // A succeeded Callback operation whose result is the expected value wrapped
    // in double quotes.
    let mut succeeded_op = Operation::new("__CALLBACK_PLACEHOLDER__", OperationType::Callback);
    succeeded_op.status = OperationStatus::Succeeded;
    succeeded_op.callback_details = Some(CallbackDetails {
        callback_id: Some(callback_id.to_string()),
        result: Some(format!("\"{}\"", expected_result)),
        error: None,
    });

    // Fourth mock response: carries the succeeded operation as new execution state.
    let final_response = CheckpointResponse {
        checkpoint_token: "token-4".to_string(),
        new_execution_state: Some(NewExecutionState {
            operations: vec![succeeded_op],
            next_marker: None,
        }),
    };

    let client = Arc::new(
        MockDurableServiceClient::new()
            .with_checkpoint_response_with_callback("token-1", callback_id)
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-2")))
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-3")))
            .with_checkpoint_response(Ok(final_response)),
    );
    let ctx = DurableContext::new(create_fresh_state(Arc::clone(&client)));

    // Slot the submitter writes the id it was handed into.
    let received_callback_id = Arc::new(std::sync::Mutex::new(String::new()));
    let submitter = {
        let slot = Arc::clone(&received_callback_id);
        move |cb_id: String| {
            let slot = Arc::clone(&slot);
            async move {
                *slot.lock().unwrap() = cb_id;
                Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
            }
        }
    };

    let result: Result<String, DurableError> = ctx.wait_for_callback(submitter, None).await;

    let seen_id = received_callback_id.lock().unwrap().clone();
    assert_eq!(
        seen_id, callback_id,
        "Submitter should receive the callback_id"
    );

    match result {
        Ok(value) => assert_eq!(
            value, expected_result,
            "Result should be deserialized correctly"
        ),
        // Suspension is tolerated by this test.
        Err(DurableError::Suspend { .. }) => {}
        Err(e) => panic!("Unexpected error: {:?}", e),
    }
}
/// With a mock whose fourth response reports the Callback operation as
/// `TimedOut`, `wait_for_callback` must not return `Ok`.
///
/// Every error is accepted — `Suspend`, `Callback`, or any other variant —
/// so this test only rules out a successful result.
#[tokio::test]
async fn test_wait_for_callback_timeout_error() {
    use durable_execution_sdk::client::NewExecutionState;
    use durable_execution_sdk::error::ErrorObject;
    use durable_execution_sdk::operation::{CallbackDetails, Operation, OperationStatus};

    let callback_id = "test-callback-timeout-id";

    // A timed-out Callback operation carrying a "TimeoutError" error object.
    let mut timed_out_op = Operation::new("__CALLBACK_PLACEHOLDER__", OperationType::Callback);
    timed_out_op.status = OperationStatus::TimedOut;
    timed_out_op.callback_details = Some(CallbackDetails {
        callback_id: Some(callback_id.to_string()),
        result: None,
        error: Some(ErrorObject::new("TimeoutError", "Callback timed out")),
    });

    // Fourth mock response: carries the timed-out operation as new execution state.
    let final_response = CheckpointResponse {
        checkpoint_token: "token-4".to_string(),
        new_execution_state: Some(NewExecutionState {
            operations: vec![timed_out_op],
            next_marker: None,
        }),
    };

    let client = Arc::new(
        MockDurableServiceClient::new()
            .with_checkpoint_response_with_callback("token-1", callback_id)
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-2")))
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-3")))
            .with_checkpoint_response(Ok(final_response)),
    );
    let ctx = DurableContext::new(create_fresh_state(Arc::clone(&client)));

    let submitter = |_callback_id: String| async move {
        Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
    };
    let result: Result<String, DurableError> = ctx.wait_for_callback(submitter, None).await;

    // Only success is a failure; all error variants pass without further checks.
    assert!(result.is_err(), "Expected error or suspend, got success");
}
#[tokio::test]
async fn test_wait_for_callback_default_config() {
let client = create_callback_mock_client();
let state = create_fresh_state(client.clone());
let ctx = DurableContext::new(state);
let submitter = |_callback_id: String| async move {
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
};
let _result: Result<String, DurableError> = ctx.wait_for_callback(submitter, None).await;
let calls = client.get_checkpoint_calls();
let callback_call = calls.iter().find(|call| {
call.operations
.iter()
.any(|op| op.operation_type == OperationType::Callback)
});
assert!(callback_call.is_some(), "Should have callback checkpoint");
let callback_op = callback_call
.unwrap()
.operations
.iter()
.find(|op| op.operation_type == OperationType::Callback)
.unwrap();
assert!(
callback_op.callback_options.is_some(),
"Should have callback options even with default config"
);
}