use std::sync::Arc;
use serde::{de::DeserializeOwned, Serialize};
use crate::context::{
LogInfo, Logger, OperationIdentifier, WaitForConditionConfig, WaitForConditionContext,
};
use crate::error::{DurableError, ErrorObject};
use crate::operation::{OperationType, OperationUpdate};
use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
use crate::state::{CheckpointedResult, ExecutionState};
/// Retry payload persisted between `wait_for_condition` attempts.
///
/// The handler serializes this to JSON when it checkpoints a retry, and
/// `get_current_state` reads it back on the next invocation. The field names
/// are therefore part of the stored JSON and should not be renamed casually.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
struct WaitForConditionState<S> {
/// Caller-defined state that is handed to the condition check.
user_state: S,
/// Attempt number the next invocation should run as.
attempt: usize,
}
/// Runs a single attempt of a durable `wait_for_condition` operation.
///
/// The flow is:
/// 1. If a checkpoint already holds a terminal outcome, replay it via
///    `handle_replay` and return without calling `check`.
/// 2. Otherwise restore the attempt number and user state via
///    `get_current_state`.
/// 3. On attempt 1 with no existing checkpoint, write a start checkpoint.
/// 4. Call `check`. On `Ok`, serialize the value, write a succeed checkpoint
///    and return it.
/// 5. On `Err`, ask `config.wait_strategy` what to do next:
///    * `Done` writes a fail checkpoint (`MaxAttemptsExceeded`) and returns
///      `DurableError::Execution`.
///    * `Continue { delay }` writes a retry checkpoint carrying the next
///      attempt number and the (unchanged) user state, then returns
///      `DurableError::Suspend`.
///
/// # Errors
///
/// * `DurableError::SerDes` if the result or the retry state cannot be
///   serialized.
/// * `DurableError::Execution` when the wait strategy gives up.
/// * `DurableError::Suspend` when another attempt has been scheduled.
/// * Any error propagated from `handle_replay`, `get_current_state` or
///   `ExecutionState::create_checkpoint`.
pub async fn wait_for_condition_handler<T, S, F, Fut>(
    check: F,
    config: WaitForConditionConfig<S>,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<T, DurableError>
where
    T: Serialize + DeserializeOwned + Send,
    S: Serialize + DeserializeOwned + Clone + Send + Sync,
    F: Fn(&S, &WaitForConditionContext) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let operation_id = &op_id.operation_id;

    // Logging context: execution ARN + operation id, plus the parent id if any.
    let base_info = LogInfo::new(state.durable_execution_arn()).with_operation_id(operation_id);
    let log_info = match op_id.parent_id.as_ref() {
        Some(parent) => base_info.with_parent_id(parent),
        None => base_info,
    };

    logger.debug(
        &format!("Starting wait_for_condition operation: {}", op_id),
        &log_info,
    );

    // A terminal checkpoint short-circuits the whole handler.
    let checkpoint = state.get_checkpoint_result(operation_id).await;
    if let Some(replayed) = handle_replay::<T>(&checkpoint, state, op_id, logger).await? {
        return Ok(replayed);
    }

    let (attempt, user_state) = get_current_state::<S>(&checkpoint, &config)?;
    let check_ctx = WaitForConditionContext {
        attempt,
        max_attempts: None,
    };

    logger.debug(
        &format!("Checking condition (attempt {})", attempt),
        &log_info,
    );

    // Only a brand-new operation needs a start checkpoint.
    let is_fresh_operation = attempt == 1 && !checkpoint.is_existent();
    if is_fresh_operation {
        state
            .create_checkpoint(create_start_update(op_id), true)
            .await?;
    }

    // Success path returns early; everything after this `match` handles failure.
    let check_error = match check(&user_state, &check_ctx).await {
        Ok(value) => {
            logger.debug(
                &format!("Condition met on attempt {}", attempt),
                &log_info,
            );

            let serdes_ctx = SerDesContext::new(operation_id, state.durable_execution_arn());
            let encoded = JsonSerDes::<T>::new()
                .serialize(&value, &serdes_ctx)
                .map_err(|e| DurableError::SerDes {
                    message: format!("Failed to serialize wait_for_condition result: {}", e),
                })?;

            state
                .create_checkpoint(create_succeed_update(op_id, Some(encoded)), true)
                .await?;
            return Ok(value);
        }
        Err(err) => err,
    };

    logger.debug(
        &format!(
            "Condition not met on attempt {}: {}",
            attempt, check_error
        ),
        &log_info,
    );

    let next_step = (config.wait_strategy)(&user_state, attempt);
    match next_step {
        crate::config::WaitDecision::Done => {
            logger.error(
                &format!(
                    "Max attempts exceeded for wait_for_condition on attempt {}",
                    attempt
                ),
                &log_info,
            );

            let failure = ErrorObject::new(
                "MaxAttemptsExceeded",
                format!(
                    "Max attempts exceeded for wait_for_condition. Last error: {}",
                    check_error
                ),
            );
            state
                .create_checkpoint(create_fail_update(op_id, failure), true)
                .await?;

            Err(DurableError::Execution {
                message: String::from("Max attempts exceeded for wait_for_condition"),
                termination_reason: crate::error::TerminationReason::ExecutionError,
            })
        }
        crate::config::WaitDecision::Continue { delay } => {
            // The user state is carried forward as-is; only the attempt advances.
            let carried = WaitForConditionState {
                user_state: user_state.clone(),
                attempt: attempt + 1,
            };

            let serdes_ctx = SerDesContext::new(operation_id, state.durable_execution_arn());
            let encoded_state = JsonSerDes::<WaitForConditionState<S>>::new()
                .serialize(&carried, &serdes_ctx)
                .map_err(|e| DurableError::SerDes {
                    message: format!("Failed to serialize wait_for_condition state: {}", e),
                })?;

            let retry = create_retry_update(op_id, Some(encoded_state), Some(delay.to_seconds()));
            state.create_checkpoint(retry, true).await?;

            Err(DurableError::Suspend {
                scheduled_timestamp: None,
            })
        }
    }
}
async fn handle_replay<T>(
checkpoint_result: &CheckpointedResult,
state: &Arc<ExecutionState>,
op_id: &OperationIdentifier,
logger: &Arc<dyn Logger>,
) -> Result<Option<T>, DurableError>
where
T: Serialize + DeserializeOwned,
{
if !checkpoint_result.is_existent() {
return Ok(None);
}
let mut log_info =
LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
if let Some(ref parent_id) = op_id.parent_id {
log_info = log_info.with_parent_id(parent_id);
}
if let Some(op_type) = checkpoint_result.operation_type() {
if op_type != OperationType::Step {
return Err(DurableError::NonDeterministic {
message: format!(
"Expected Step operation but found {:?} at operation_id {}",
op_type, op_id.operation_id
),
operation_id: Some(op_id.operation_id.clone()),
});
}
}
if checkpoint_result.is_succeeded() {
logger.debug(
&format!("Replaying succeeded wait_for_condition: {}", op_id),
&log_info,
);
state.track_replay(&op_id.operation_id).await;
if let Some(result_str) = checkpoint_result.result() {
let serdes = JsonSerDes::<T>::new();
let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
let result =
serdes
.deserialize(result_str, &serdes_ctx)
.map_err(|e| DurableError::SerDes {
message: format!("Failed to deserialize checkpointed result: {}", e),
})?;
return Ok(Some(result));
}
}
if checkpoint_result.is_failed() {
logger.debug(
&format!("Replaying failed wait_for_condition: {}", op_id),
&log_info,
);
state.track_replay(&op_id.operation_id).await;
if let Some(error) = checkpoint_result.error() {
return Err(DurableError::UserCode {
message: error.error_message.clone(),
error_type: error.error_type.clone(),
stack_trace: error.stack_trace.clone(),
});
} else {
return Err(DurableError::execution(
"wait_for_condition failed with unknown error",
));
}
}
if checkpoint_result.is_ready() {
logger.debug(
&format!("Resuming READY wait_for_condition: {}", op_id),
&log_info,
);
return Ok(None);
}
if checkpoint_result.is_pending() {
logger.debug(
&format!("Resuming PENDING wait_for_condition: {}", op_id),
&log_info,
);
return Ok(None);
}
Ok(None)
}
/// Works out which attempt is about to run and which user state it should see.
///
/// Resolution order:
/// 1. A retry payload on the checkpoint is deserialized as
///    `WaitForConditionState<S>` and supplies both the attempt and the state.
/// 2. Otherwise, an attempt number recorded on the checkpoint is paired with a
///    clone of `config.initial_state`.
/// 3. Otherwise the result is attempt `1` with a clone of
///    `config.initial_state`.
///
/// The SerDes context is built from empty strings because this function is
/// given neither the operation id nor the execution ARN.
///
/// # Errors
///
/// Returns `DurableError::SerDes` if the retry payload cannot be deserialized.
fn get_current_state<S>(
    checkpoint_result: &CheckpointedResult,
    config: &WaitForConditionConfig<S>,
) -> Result<(usize, S), DurableError>
where
    S: Serialize + DeserializeOwned + Clone,
{
    match checkpoint_result.retry_payload() {
        Some(payload) => {
            let serdes_ctx = SerDesContext::new("", "");
            let restored: WaitForConditionState<S> =
                JsonSerDes::<WaitForConditionState<S>>::new()
                    .deserialize(payload, &serdes_ctx)
                    .map_err(|e| DurableError::SerDes {
                        message: format!("Failed to deserialize wait_for_condition state: {}", e),
                    })?;
            Ok((restored.attempt, restored.user_state))
        }
        None => {
            // No payload: fall back to the recorded attempt, or 1 for a fresh run.
            let attempt = checkpoint_result
                .attempt()
                .map_or(1, |recorded| recorded as usize);
            Ok((attempt, config.initial_state.clone()))
        }
    }
}
/// Builds the START update for a `wait_for_condition` step.
///
/// The update is tagged with the `wait_for_condition` sub-type and carries
/// the parent id and name from `op_id` when they are set.
fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
    let base = OperationUpdate::start(&op_id.operation_id, OperationType::Step)
        .with_sub_type("wait_for_condition");

    let with_parent = match op_id.parent_id.as_ref() {
        Some(parent) => base.with_parent_id(parent),
        None => base,
    };

    match op_id.name.as_ref() {
        Some(label) => with_parent.with_name(label),
        None => with_parent,
    }
}
/// Builds the SUCCEED update for a `wait_for_condition` step.
///
/// `result` is the already-serialized result payload, if any. The update is
/// tagged with the `wait_for_condition` sub-type and carries the parent id and
/// name from `op_id` when they are set.
fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
    let update = OperationUpdate::succeed(&op_id.operation_id, OperationType::Step, result)
        .with_sub_type("wait_for_condition");

    let update = match &op_id.parent_id {
        Some(parent) => update.with_parent_id(parent),
        None => update,
    };

    match &op_id.name {
        Some(label) => update.with_name(label),
        None => update,
    }
}
/// Builds the FAIL update for a `wait_for_condition` step.
///
/// `error` is recorded as the failure reason. The update is tagged with the
/// `wait_for_condition` sub-type and carries the parent id and name from
/// `op_id` when they are set.
///
/// The handler calls this when the wait strategy gives up, so no
/// `#[allow(dead_code)]` is needed; leaving the lint active lets the compiler
/// flag the function if that call site is ever removed.
fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
    let mut update = OperationUpdate::fail(&op_id.operation_id, OperationType::Step, error)
        .with_sub_type("wait_for_condition");
    if let Some(ref parent_id) = op_id.parent_id {
        update = update.with_parent_id(parent_id);
    }
    if let Some(ref name) = op_id.name {
        update = update.with_name(name);
    }
    update
}
/// Builds the RETRY update for a `wait_for_condition` step.
///
/// `payload` is the serialized state to hand to the next attempt, and
/// `next_attempt_delay_seconds` is the delay passed through to
/// `OperationUpdate::retry`. The update is tagged with the
/// `wait_for_condition` sub-type and carries the parent id and name from
/// `op_id` when they are set.
fn create_retry_update(
    op_id: &OperationIdentifier,
    payload: Option<String>,
    next_attempt_delay_seconds: Option<u64>,
) -> OperationUpdate {
    let tagged = OperationUpdate::retry(
        &op_id.operation_id,
        OperationType::Step,
        payload,
        next_attempt_delay_seconds,
    )
    .with_sub_type("wait_for_condition");

    let parented = match op_id.parent_id.as_ref() {
        Some(parent) => tagged.with_parent_id(parent),
        None => tagged,
    };

    match op_id.name.as_ref() {
        Some(label) => parented.with_name(label),
        None => parented,
    }
}
// Unit tests for `wait_for_condition_handler`, driven through a mock service
// client and, where needed, an execution state pre-seeded with operations.
#[cfg(test)]
mod tests {
use super::*;
use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
use crate::context::{OperationIdentifier, TracingLogger};
use crate::duration::Duration;
use crate::lambda::InitialExecutionState;
use crate::operation::{Operation, OperationStatus, StepDetails};
use std::sync::atomic::{AtomicUsize, Ordering};
// Mock client queued with three successful checkpoint responses.
fn create_mock_client() -> SharedDurableServiceClient {
Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
.with_checkpoint_response(Ok(CheckpointResponse::new("token-2")))
.with_checkpoint_response(Ok(CheckpointResponse::new("token-3"))),
)
}
// Execution state with a fixed ARN and token and an empty initial state.
fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
InitialExecutionState::new(),
client,
))
}
// Same as `create_test_state`, but seeded with the given operations so the
// handler sees an existing checkpoint.
fn create_test_state_with_operations(
client: SharedDurableServiceClient,
operations: Vec<Operation>,
) -> Arc<ExecutionState> {
Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
InitialExecutionState::with_operations(operations),
client,
))
}
// Operation identifier used by every test; its id matches the id of the
// operations seeded in the replay tests below.
fn create_test_op_id() -> OperationIdentifier {
OperationIdentifier::new(
"test-wait-cond-123",
None,
Some("test-wait-condition".to_string()),
)
}
fn create_test_logger() -> Arc<dyn Logger> {
Arc::new(TracingLogger)
}
// Config built with `from_interval`: a 5-second duration and `Some(3)`
// (presumably the polling interval and max attempts — confirm against
// `WaitForConditionConfig::from_interval`).
fn create_test_config<S: Clone + Send + Sync + 'static>(
initial_state: S,
) -> WaitForConditionConfig<S> {
WaitForConditionConfig::from_interval(initial_state, Duration::from_seconds(5), Some(3))
}
// A new operation runs the check exactly once, as attempt 1, and returns
// the check's value.
#[tokio::test]
async fn test_initial_condition_check_executed_on_new_operation() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let check_called = Arc::new(AtomicUsize::new(0));
let check_called_clone = check_called.clone();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
move |_state: &i32, ctx: &WaitForConditionContext| {
check_called_clone.fetch_add(1, Ordering::SeqCst);
assert_eq!(ctx.attempt, 1, "First attempt should be 1");
async move { Ok::<_, Box<dyn std::error::Error + Send + Sync>>(42) }
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok(), "Should succeed when condition passes");
assert_eq!(result.unwrap(), 42);
assert_eq!(
check_called.load(Ordering::SeqCst),
1,
"Check function should be called once"
);
}
// The first check receives the configured initial state (struct-valued here).
#[tokio::test]
async fn test_initial_check_receives_initial_state() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct TestState {
counter: i32,
name: String,
}
let initial_state = TestState {
counter: 42,
name: "test".to_string(),
};
let config = WaitForConditionConfig::from_interval(
initial_state,
Duration::from_seconds(5),
Some(3),
);
let result =
wait_for_condition_handler(
|state: &TestState, _ctx: &WaitForConditionContext| {
assert_eq!(state.counter, 42);
assert_eq!(state.name, "test");
async move {
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("success".to_string())
}
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "success");
}
// A failing check with attempts remaining yields `Suspend` and leaves the
// operation in `Pending` status. Note: the stored payload itself is not
// inspected here, despite the test name.
#[tokio::test]
async fn test_retry_action_checkpointed_with_state_payload() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
Err::<String, _>("condition not met".into())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Suspend { .. } => {
}
other => panic!("Expected Suspend error, got: {:?}", other),
}
let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
let op = checkpoint_result
.operation()
.expect("Operation should exist");
assert_eq!(
op.status,
OperationStatus::Pending,
"Status should be Pending after RETRY"
);
}
// A passing check leaves a succeeded checkpoint whose stored result
// contains the returned value.
#[tokio::test]
async fn test_succeed_action_checkpointed_when_condition_passes() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("success_result".to_string())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "success_result");
let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
assert!(
checkpoint_result.is_succeeded(),
"Checkpoint should be succeeded"
);
let result_str = checkpoint_result.result().expect("Result should exist");
assert!(result_str.contains("success_result"));
}
// Seeds a PENDING operation whose payload records attempt 2. Despite the
// test name, the assertions check that the handler runs the check as
// attempt 2 and returns its value; no suspension is asserted. The mock client
// has no queued checkpoint responses in this test.
#[tokio::test]
async fn test_suspension_when_replaying_pending_status() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-cond-123", OperationType::Step);
op.status = OperationStatus::Pending;
op.step_details = Some(StepDetails {
result: None,
attempt: Some(1),
next_attempt_timestamp: Some(9999999999000), error: None,
payload: Some(r#"{"user_state":0,"attempt":2}"#.to_string()),
});
let state = create_test_state_with_operations(client, vec![op]);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
|_state: &i32, ctx: &WaitForConditionContext| {
assert_eq!(ctx.attempt, 2, "Should be on attempt 2");
async move {
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("resumed_result".to_string())
}
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "resumed_result");
}
// With the config built using `Some(1)`, a single failing check returns an
// `Execution` error mentioning "Max attempts" and leaves a failed checkpoint.
#[tokio::test]
async fn test_fail_action_checkpointed_when_retries_exhausted() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config =
WaitForConditionConfig::from_interval(0i32, Duration::from_seconds(5), Some(1));
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
Err::<String, _>("condition never met".into())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Execution { message, .. } => {
assert!(
message.contains("Max attempts"),
"Error should mention max attempts"
);
}
other => panic!("Expected Execution error, got: {:?}", other),
}
let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
assert!(checkpoint_result.is_failed(), "Checkpoint should be failed");
}
// The state and attempt stored in a retry payload take precedence over
// the configured initial state (counter 42 vs. 0) when the check runs.
#[tokio::test]
async fn test_retry_payload_contains_previous_state() {
let client = Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
);
#[derive(Clone, serde::Serialize, serde::Deserialize, PartialEq, Debug)]
struct TestState {
counter: i32,
}
let previous_state = WaitForConditionState {
user_state: TestState { counter: 42 },
attempt: 2,
};
let payload = serde_json::to_string(&previous_state).unwrap();
let mut op = Operation::new("test-wait-cond-123", OperationType::Step);
op.status = OperationStatus::Pending;
op.step_details = Some(StepDetails {
result: None,
attempt: Some(1),
next_attempt_timestamp: Some(1234567890000),
error: None,
payload: Some(payload),
});
let state = create_test_state_with_operations(client, vec![op]);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = WaitForConditionConfig::from_interval(
TestState { counter: 0 },
Duration::from_seconds(5),
Some(5),
);
let result =
wait_for_condition_handler(
|state: &TestState, ctx: &WaitForConditionContext| {
assert_eq!(state.counter, 42, "State should be from retry payload");
assert_eq!(ctx.attempt, 2, "Attempt should be 2 from payload");
async move {
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("success".to_string())
}
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
}
// With no prior checkpoint, the check sees attempt 1 and the initial state.
#[tokio::test]
async fn test_condition_receives_initial_state_on_first_attempt() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
#[derive(Clone, serde::Serialize, serde::Deserialize, PartialEq, Debug)]
struct TestState {
value: String,
}
let config = WaitForConditionConfig::from_interval(
TestState {
value: "initial".to_string(),
},
Duration::from_seconds(5),
Some(3),
);
let result = wait_for_condition_handler(
|state: &TestState, ctx: &WaitForConditionContext| {
assert_eq!(ctx.attempt, 1, "Should be first attempt");
assert_eq!(state.value, "initial", "Should receive initial state");
async move { Ok::<_, Box<dyn std::error::Error + Send + Sync>>("done".to_string()) }
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
}
// An error from the check is treated as "not met yet": the handler
// returns `Suspend` rather than surfacing the error.
#[tokio::test]
async fn test_condition_function_error_triggers_retry() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config =
WaitForConditionConfig::from_interval(0i32, Duration::from_seconds(5), Some(3));
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
Err::<String, _>("condition check failed".into())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::Suspend { .. } => {
}
other => panic!("Expected Suspend error for retry, got: {:?}", other),
}
}
// A succeeded checkpoint is replayed: the stored result is returned and
// the check closure (which panics) is never invoked.
#[tokio::test]
async fn test_replay_succeeded_operation_returns_cached_result() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-cond-123", OperationType::Step);
op.status = OperationStatus::Succeeded;
op.step_details = Some(StepDetails {
result: Some(r#""cached_result""#.to_string()),
attempt: Some(2),
next_attempt_timestamp: None,
error: None,
payload: None,
});
let state = create_test_state_with_operations(client, vec![op]);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result: Result<String, DurableError> = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
panic!("Function should not be called during replay of succeeded operation")
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "cached_result");
}
// A failed checkpoint is replayed as a `UserCode` error carrying the
// stored message; the check closure (which panics) is never invoked.
#[tokio::test]
async fn test_replay_failed_operation_returns_error() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-cond-123", OperationType::Step);
op.status = OperationStatus::Failed;
op.error = Some(ErrorObject::new(
"MaxAttemptsExceeded",
"Max attempts exceeded",
));
let state = create_test_state_with_operations(client, vec![op]);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result: Result<String, DurableError> = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
panic!("Function should not be called during replay of failed operation")
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::UserCode { message, .. } => {
assert!(message.contains("Max attempts"));
}
other => panic!("Expected UserCode error, got: {:?}", other),
}
}
// A checkpoint recorded as a `Wait` operation under the same id is
// reported as `NonDeterministic`, tagged with the operation id.
#[tokio::test]
async fn test_non_deterministic_detection_wrong_operation_type() {
let client = Arc::new(MockDurableServiceClient::new());
let mut op = Operation::new("test-wait-cond-123", OperationType::Wait);
op.status = OperationStatus::Succeeded;
let state = create_test_state_with_operations(client, vec![op]);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("should not reach".to_string())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_err());
match result.unwrap_err() {
DurableError::NonDeterministic { operation_id, .. } => {
assert_eq!(operation_id, Some("test-wait-cond-123".to_string()));
}
other => panic!("Expected NonDeterministic error, got: {:?}", other),
}
}
// A READY checkpoint does not short-circuit: the check runs and its
// value is returned.
#[tokio::test]
async fn test_ready_status_continues_execution() {
let client = Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
);
let mut op = Operation::new("test-wait-cond-123", OperationType::Step);
op.status = OperationStatus::Ready;
op.step_details = Some(StepDetails {
result: None,
attempt: Some(1),
next_attempt_timestamp: None,
error: None,
payload: None,
});
let state = create_test_state_with_operations(client, vec![op]);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("ready_result".to_string())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "ready_result");
}
// The check may genuinely await (here a 10 ms sleep) before resolving.
#[tokio::test]
async fn test_wait_for_condition_genuinely_async_closure() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id();
let logger = create_test_logger();
let config = create_test_config(0i32);
let result = wait_for_condition_handler(
|_state: &i32, _ctx: &WaitForConditionContext| async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
Ok::<_, Box<dyn std::error::Error + Send + Sync>>("async_poll_result".to_string())
},
config,
&state,
&op_id,
&logger,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "async_poll_result");
}
}