use std::sync::Arc;
use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
use crate::duration::Duration;
use crate::error::DurableError;
use crate::operation::{OperationType, OperationUpdate};
use crate::state::ExecutionState;
/// Minimum wait duration, in seconds, that `wait_handler` accepts; shorter
/// requests are rejected with a validation error.
const MIN_WAIT_SECONDS: u64 = 1;
/// Runs a durable wait operation.
///
/// Validates the requested duration, looks up the checkpoint recorded under
/// `op_id`, and then either replays a finished wait or suspends the execution.
///
/// # Arguments
///
/// * `duration` - Requested wait length; must be at least `MIN_WAIT_SECONDS`.
/// * `state` - Execution state used to read and create checkpoints.
/// * `op_id` - Identifier of this wait operation.
/// * `logger` - Receives debug messages describing the handler's progress.
///
/// # Returns
///
/// `Ok(())` when the existing checkpoint shows the wait already succeeded or
/// was cancelled.
///
/// # Errors
///
/// * [`DurableError::Validation`] if the duration is below `MIN_WAIT_SECONDS`.
/// * [`DurableError::NonDeterministic`] if the checkpoint stored under this
///   operation id has an operation type other than `Wait`.
/// * [`DurableError::Suspend`] if the wait exists but is not terminal, or
///   immediately after a new wait has been checkpointed.
/// * Any error returned by `create_checkpoint` when recording the new wait.
pub async fn wait_handler(
    duration: Duration,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<(), DurableError> {
    let span = create_operation_span("wait", op_id, state.durable_execution_arn());

    let mut log_info =
        LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
    if let Some(ref parent_id) = op_id.parent_id {
        log_info = log_info.with_parent_id(parent_id);
    }

    // Enter the span only around synchronous logging. The tracing docs warn
    // that an `enter()` guard held across an `.await` leaves the span entered
    // while the task is suspended, so unrelated work polled on the same thread
    // gets attributed to it. (Assumes `span` is a `tracing::Span`.)
    // NOTE(review): the awaited `state` calls below now run outside the span;
    // instrument them if their events should be attributed to this wait.
    let debug_in_span = |message: &str| span.in_scope(|| logger.debug(message, &log_info));

    let wait_seconds = duration.to_seconds();
    debug_in_span(&format!(
        "Starting wait operation: {} for {} seconds",
        op_id, wait_seconds
    ));

    if wait_seconds < MIN_WAIT_SECONDS {
        span.record("status", "validation_failed");
        return Err(DurableError::Validation {
            message: format!(
                "Wait duration must be at least {} second(s), got {} seconds",
                MIN_WAIT_SECONDS, wait_seconds
            ),
        });
    }

    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
    if checkpoint_result.is_existent() {
        // A checkpoint of another type under this id means the replayed code
        // diverged from the recorded history.
        if let Some(op_type) = checkpoint_result.operation_type() {
            if op_type != OperationType::Wait {
                span.record("status", "non_deterministic");
                return Err(DurableError::NonDeterministic {
                    message: format!(
                        "Expected Wait operation but found {:?} at operation_id {}",
                        op_type, op_id.operation_id
                    ),
                    operation_id: Some(op_id.operation_id.clone()),
                });
            }
        }

        if checkpoint_result.is_succeeded() {
            debug_in_span(&format!("Wait already completed: {}", op_id));
            state.track_replay(&op_id.operation_id).await;
            span.record("status", "replayed_succeeded");
            return Ok(());
        }

        if checkpoint_result.is_cancelled() {
            debug_in_span(&format!("Wait was cancelled: {}", op_id));
            state.track_replay(&op_id.operation_id).await;
            span.record("status", "replayed_cancelled");
            return Ok(());
        }

        // Existence is already established by the enclosing branch, so only
        // the terminal check is needed here.
        if !checkpoint_result.is_terminal() {
            debug_in_span(&format!("Wait still in progress: {}", op_id));
            span.record("status", "suspended");
            return Err(DurableError::Suspend {
                scheduled_timestamp: None,
            });
        }

        // NOTE(review): a terminal status that is neither succeeded nor
        // cancelled falls through and records a fresh START checkpoint below;
        // confirm that this is intended.
    }

    let start_update = create_start_update(op_id, wait_seconds);
    state.create_checkpoint(start_update, true).await?;
    debug_in_span(&format!("Wait started for {} seconds", wait_seconds));

    // The wait has only just been recorded, so the execution suspends here.
    span.record("status", "suspended");
    Err(DurableError::Suspend {
        scheduled_timestamp: None,
    })
}
/// Cancels a wait operation that has been checkpointed but not yet finished.
///
/// The call is a no-op (returning `Ok(())`) when no checkpoint exists for
/// `operation_id` or when the recorded operation is already terminal.
///
/// # Arguments
///
/// * `state` - Execution state used to read and create checkpoints.
/// * `operation_id` - Id of the wait operation to cancel.
/// * `logger` - Receives debug/info messages describing the outcome.
///
/// # Errors
///
/// * [`DurableError::Validation`] if the checkpoint under `operation_id` has
///   an operation type other than `Wait`.
/// * Any error returned by `create_checkpoint` when recording the cancel.
pub async fn wait_cancel_handler(
    state: &Arc<ExecutionState>,
    operation_id: &str,
    logger: &Arc<dyn Logger>,
) -> Result<(), DurableError> {
    let info = LogInfo::new(state.durable_execution_arn()).with_operation_id(operation_id);
    let attempt_msg = format!("Attempting to cancel wait operation: {}", operation_id);
    logger.debug(&attempt_msg, &info);

    let existing = state.get_checkpoint_result(operation_id).await;

    // Nothing recorded under this id: there is nothing to cancel.
    if !existing.is_existent() {
        let not_found_msg = format!(
            "Wait operation not found, nothing to cancel: {}",
            operation_id
        );
        logger.debug(&not_found_msg, &info);
        return Ok(());
    }

    // Refuse to cancel a checkpoint that was recorded as another operation type.
    match existing.operation_type() {
        Some(found) if found != OperationType::Wait => {
            let message = format!(
                "Cannot cancel operation {}: expected WAIT operation but found {:?}",
                operation_id, found
            );
            return Err(DurableError::Validation { message });
        }
        Some(_) | None => {}
    }

    // A terminal operation cannot be cancelled any more; report success.
    if existing.is_terminal() {
        let done_msg = format!(
            "Wait already completed, nothing to cancel: {}",
            operation_id
        );
        logger.debug(&done_msg, &info);
        return Ok(());
    }

    let cancellation = OperationUpdate::cancel(operation_id, OperationType::Wait);
    state.create_checkpoint(cancellation, true).await?;

    let cancelled_msg = format!("Wait operation cancelled: {}", operation_id);
    logger.info(&cancelled_msg, &info);
    Ok(())
}
/// Builds the checkpoint update that starts a wait of `wait_seconds` seconds,
/// then passes it through `op_id.apply_to` before returning it.
fn create_start_update(op_id: &OperationIdentifier, wait_seconds: u64) -> OperationUpdate {
    let start = OperationUpdate::start_wait(&op_id.operation_id, wait_seconds);
    op_id.apply_to(start)
}
// Unit tests for `wait_handler`, `wait_cancel_handler` and
// `create_start_update`, driven by `MockDurableServiceClient`.
#[cfg(test)]
mod tests {
use super::*;
use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
use crate::context::TracingLogger;
use crate::lambda::InitialExecutionState;
use crate::operation::{Operation, OperationStatus};
// Mock client queued with two successful checkpoint responses
// ("token-1", then "token-2").
fn create_mock_client() -> SharedDurableServiceClient {
Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
.with_checkpoint_response(Ok(CheckpointResponse::new("token-2"))),
)
}
// Execution state built from `InitialExecutionState::new()` (no operations
// supplied) on top of the given client.
fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
InitialExecutionState::new(),
client,
))
}
// Operation identifier "test-wait-123" with no parent and the name "test-wait".
fn create_test_op_id() -> OperationIdentifier {
OperationIdentifier::new("test-wait-123", None, Some("test-wait".to_string()))
}
// Logger backed by `TracingLogger`.
fn create_test_logger() -> Arc<dyn Logger> {
Arc::new(TracingLogger)
}
// A zero-second wait is below MIN_WAIT_SECONDS and must yield a Validation error.
#[tokio::test]
async fn test_wait_handler_validation_error() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(0), &state, &op_id, &logger).await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Validation { message } => {
assert!(message.contains("at least 1 second"));
}
_ => panic!("Expected Validation error"),
}
}
// With no prior checkpoint, a valid wait is recorded and the handler suspends.
#[tokio::test]
async fn test_wait_handler_suspends_on_new_wait() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Suspend {
scheduled_timestamp: _,
} => {
}
_ => panic!("Expected Suspend error"),
}
}
// Replaying a wait whose checkpoint is already Succeeded returns Ok.
#[tokio::test]
async fn test_wait_handler_replay_completed() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-123", OperationType::Wait);
op.status = OperationStatus::Succeeded;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(result.is_ok());
}
// A Step operation stored under the wait's id must be reported as
// NonDeterministic, carrying that operation id.
#[tokio::test]
async fn test_wait_handler_non_deterministic_detection() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-123", OperationType::Step);
op.status = OperationStatus::Succeeded;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::NonDeterministic { operation_id, .. } => {
assert_eq!(operation_id, Some("test-wait-123".to_string()));
}
_ => panic!("Expected NonDeterministic error"),
}
}
// Replaying a wait whose checkpoint is in Started status suspends again.
#[tokio::test]
async fn test_wait_handler_replay_still_waiting() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-123", OperationType::Wait);
op.status = OperationStatus::Started;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(
Duration::from_seconds(3600), &state,
&op_id,
&logger,
)
.await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Suspend {
scheduled_timestamp: _,
} => {
}
_ => panic!("Expected Suspend error"),
}
}
// The start update carries the operation id, Wait type, wait seconds, and the
// parent id and name taken from the identifier.
#[test]
fn test_create_start_update() {
let op_id = OperationIdentifier::new(
"op-123",
Some("parent-456".to_string()),
Some("my-wait".to_string()),
);
let update = create_start_update(&op_id, 60);
assert_eq!(update.operation_id, "op-123");
assert_eq!(update.operation_type, OperationType::Wait);
assert!(update.wait_options.is_some());
assert_eq!(update.wait_options.unwrap().wait_seconds, 60);
assert_eq!(update.parent_id, Some("parent-456".to_string()));
assert_eq!(update.name, Some("my-wait".to_string()));
}
// Cancelling a Started wait succeeds and the state then reports it as cancelled.
#[tokio::test]
async fn test_wait_cancel_handler_cancels_active_wait() {
let client = Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
);
let mut op = Operation::new("test-wait-123", OperationType::Wait);
op.status = OperationStatus::Started;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let logger = create_test_logger();
let result = wait_cancel_handler(&state, "test-wait-123", &logger).await;
assert!(result.is_ok());
let checkpoint_result = state.get_checkpoint_result("test-wait-123").await;
assert!(checkpoint_result.is_cancelled());
}
// Cancelling a wait that already Succeeded returns Ok; the mock client has no
// checkpoint responses queued for this case.
#[tokio::test]
async fn test_wait_cancel_handler_handles_already_completed_wait() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-123", OperationType::Wait);
op.status = OperationStatus::Succeeded;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let logger = create_test_logger();
let result = wait_cancel_handler(&state, "test-wait-123", &logger).await;
assert!(result.is_ok());
}
// Cancelling an operation id with no checkpoint returns Ok.
#[tokio::test]
async fn test_wait_cancel_handler_handles_nonexistent_wait() {
let client = Arc::new(MockDurableServiceClient::new());
let initial_state = InitialExecutionState::new();
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let logger = create_test_logger();
let result = wait_cancel_handler(&state, "nonexistent-wait", &logger).await;
assert!(result.is_ok());
}
// Cancelling an id that belongs to a Step operation yields a Validation error.
#[tokio::test]
async fn test_wait_cancel_handler_rejects_non_wait_operation() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-step-123", OperationType::Step);
op.status = OperationStatus::Started;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let logger = create_test_logger();
let result = wait_cancel_handler(&state, "test-step-123", &logger).await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Validation { message } => {
assert!(message.contains("expected WAIT operation"));
}
_ => panic!("Expected Validation error"),
}
}
// Replaying a wait whose checkpoint is Cancelled returns Ok.
#[tokio::test]
async fn test_wait_handler_replay_cancelled_wait() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-123", OperationType::Wait);
op.status = OperationStatus::Cancelled;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(result.is_ok());
}
// Starting from an empty state, the handler suspends and afterwards a Wait
// checkpoint exists for the operation id.
#[tokio::test]
async fn test_wait_handler_checks_status_before_checkpoint() {
let client = Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
);
let initial_state = InitialExecutionState::new();
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client.clone(),
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(matches!(result, Err(DurableError::Suspend { .. })));
let checkpoint_result = state.get_checkpoint_result("test-wait-123").await;
assert!(
checkpoint_result.is_existent(),
"Checkpoint should have been created"
);
assert_eq!(
checkpoint_result.operation_type(),
Some(OperationType::Wait)
);
}
// An existing Pending wait makes the handler suspend; the mock client has no
// checkpoint responses queued for this case.
#[tokio::test]
async fn test_wait_handler_status_check_detects_existing_operation() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-123", OperationType::Wait);
op.status = OperationStatus::Pending;
let initial_state = InitialExecutionState::with_operations(vec![op]);
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(matches!(result, Err(DurableError::Suspend { .. })));
}
// The mocked checkpoint response already contains the wait as Succeeded. The
// handler still returns Suspend, and the state afterwards reports the
// operation as succeeded.
#[tokio::test]
async fn test_wait_handler_immediate_completion_via_checkpoint_response() {
use crate::client::{CheckpointResponse, NewExecutionState};
let mut succeeded_op = Operation::new("test-wait-123", OperationType::Wait);
succeeded_op.status = OperationStatus::Succeeded;
let checkpoint_response = CheckpointResponse {
checkpoint_token: "token-1".to_string(),
new_execution_state: Some(NewExecutionState {
operations: vec![succeeded_op],
next_marker: None,
}),
};
let client = Arc::new(
MockDurableServiceClient::new().with_checkpoint_response(Ok(checkpoint_response)),
);
let initial_state = InitialExecutionState::new();
let state = Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
initial_state,
client,
));
let op_id = create_test_op_id();
let logger = create_test_logger();
let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;
assert!(matches!(result, Err(DurableError::Suspend { .. })));
let checkpoint_result = state.get_checkpoint_result("test-wait-123").await;
assert!(
checkpoint_result.is_succeeded(),
"State should reflect succeeded status from checkpoint response"
);
}
}