use std::sync::Arc;
use serde::{de::DeserializeOwned, Serialize};
use crate::config::{StepConfig, StepSemantics};
use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
use crate::error::{DurableError, ErrorObject, StepResult, TerminationReason};
use crate::operation::{OperationType, OperationUpdate};
use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
use crate::state::{CheckpointedResult, ExecutionState};
use crate::traits::DurableValue;
/// Context value passed to the user-supplied step closure.
///
/// Identifies the step operation being executed and carries the retry
/// metadata (`attempt`, `retry_payload`) that `step_handler` reads from the
/// operation's checkpoint before invoking the closure.
#[derive(Debug, Clone)]
pub struct StepContext {
/// Identifier of the step operation.
pub operation_id: String,
/// Identifier of the parent operation, when one was supplied.
pub parent_id: Option<String>,
/// Optional name of the step.
pub name: Option<String>,
/// ARN of the durable execution this step belongs to.
pub durable_execution_arn: String,
/// Attempt counter; `StepContext::new` initialises it to 0.
pub attempt: u32,
/// Optional payload string; `get_retry_payload` parses it as JSON.
pub retry_payload: Option<String>,
}
impl StepContext {
    /// Builds a context for `operation_id` inside `durable_execution_arn`.
    ///
    /// Parent id, name and retry payload start out unset and the attempt
    /// counter starts at 0; use the `with_*` builders to fill them in.
    pub fn new(operation_id: impl Into<String>, durable_execution_arn: impl Into<String>) -> Self {
        let operation_id = operation_id.into();
        let durable_execution_arn = durable_execution_arn.into();
        Self {
            operation_id,
            parent_id: None,
            name: None,
            durable_execution_arn,
            attempt: 0,
            retry_payload: None,
        }
    }

    /// Returns the context with its parent operation id set.
    pub fn with_parent_id(self, parent_id: impl Into<String>) -> Self {
        Self {
            parent_id: Some(parent_id.into()),
            ..self
        }
    }

    /// Returns the context with its step name set.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Returns the context with its attempt counter replaced by `attempt`.
    pub fn with_attempt(self, attempt: u32) -> Self {
        Self { attempt, ..self }
    }

    /// Returns the context with its retry payload set.
    pub fn with_retry_payload(self, payload: impl Into<String>) -> Self {
        Self {
            retry_payload: Some(payload.into()),
            ..self
        }
    }

    /// Creates a `SerDesContext` from this context's operation id and execution ARN.
    pub fn serdes_context(&self) -> SerDesContext {
        let operation_id = &self.operation_id;
        let execution_arn = &self.durable_execution_arn;
        SerDesContext::new(operation_id, execution_arn)
    }

    /// Parses the stored retry payload as JSON into `T`.
    ///
    /// Returns `Ok(None)` when no payload is stored.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error (boxed) when a payload is present but
    /// cannot be deserialized into `T`.
    pub fn get_retry_payload<T>(
        &self,
    ) -> Result<Option<T>, Box<dyn std::error::Error + Send + Sync>>
    where
        T: serde::de::DeserializeOwned,
    {
        self.retry_payload
            .as_deref()
            .map(|payload| serde_json::from_str::<T>(payload))
            .transpose()
            .map_err(|err| err.into())
    }
}
/// Runs a step operation, replaying a previously checkpointed outcome when one exists.
///
/// The checkpoint for `op_id` is looked up first. If `handle_replay` yields a
/// value, that value is returned and `func` is never invoked. Otherwise a
/// [`StepContext`] is built (including the attempt number and retry payload
/// read from the checkpoint) and `func` is executed through the path selected
/// by `config.step_semantics`.
///
/// # Errors
///
/// Propagates any error from `handle_replay` and from the selected execution
/// path (`execute_at_most_once` / `execute_at_least_once`).
pub async fn step_handler<T, F, Fut>(
    func: F,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    config: &StepConfig,
    logger: &Arc<dyn Logger>,
) -> StepResult<T>
where
    T: DurableValue,
    F: FnOnce(StepContext) -> Fut + Send,
    Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let span = create_operation_span("step", op_id, state.durable_execution_arn());
    // NOTE(review): this guard stays alive across the `.await` points below. If
    // `span` is a `tracing::Span`, the tracing docs discourage that pattern in
    // async code (prefer `Instrument`) — confirm and revisit.
    let _guard = span.enter();

    let mut log_info =
        LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
    if let Some(ref parent_id) = op_id.parent_id {
        log_info = log_info.with_parent_id(parent_id);
    }
    logger.debug(&format!("Starting step operation: {}", op_id), &log_info);

    // Capture everything needed from the checkpoint before deciding whether to replay.
    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
    let skip_start_checkpoint = checkpoint_result.is_ready();
    let attempt = checkpoint_result.attempt().unwrap_or(0);
    let retry_payload = checkpoint_result.retry_payload().map(|s| s.to_string());

    if let Some(result) = handle_replay::<T>(&checkpoint_result, state, op_id, logger).await? {
        span.record("status", "replayed");
        return Ok(result);
    }

    let mut step_ctx =
        StepContext::new(&op_id.operation_id, state.durable_execution_arn()).with_attempt(attempt);
    if let Some(ref parent_id) = op_id.parent_id {
        step_ctx = step_ctx.with_parent_id(parent_id);
    }
    if let Some(ref name) = op_id.name {
        step_ctx = step_ctx.with_name(name);
    }
    if let Some(payload) = retry_payload {
        step_ctx = step_ctx.with_retry_payload(payload);
    }

    let serdes = JsonSerDes::<T>::new();
    let serdes_ctx = step_ctx.serdes_context();
    let params = StepExecParams {
        state,
        op_id,
        step_ctx: &step_ctx,
        serdes: &serdes,
        serdes_ctx: &serdes_ctx,
        config,
        logger,
    };

    let result = match config.step_semantics {
        StepSemantics::AtMostOncePerRetry => {
            execute_at_most_once(func, &params, skip_start_checkpoint).await
        }
        StepSemantics::AtLeastOncePerRetry => execute_at_least_once(func, &params).await,
    };

    span.record(
        "status",
        if result.is_ok() { "succeeded" } else { "failed" },
    );
    result
}
async fn handle_replay<T>(
checkpoint_result: &CheckpointedResult,
state: &Arc<ExecutionState>,
op_id: &OperationIdentifier,
logger: &Arc<dyn Logger>,
) -> StepResult<Option<T>>
where
T: Serialize + DeserializeOwned,
{
if !checkpoint_result.is_existent() {
return Ok(None);
}
let mut log_info =
LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
if let Some(ref parent_id) = op_id.parent_id {
log_info = log_info.with_parent_id(parent_id);
}
if let Some(op_type) = checkpoint_result.operation_type() {
if op_type != OperationType::Step {
return Err(DurableError::NonDeterministic {
message: format!(
"Expected Step operation but found {:?} at operation_id {}",
op_type, op_id.operation_id
),
operation_id: Some(op_id.operation_id.clone()),
});
}
}
if checkpoint_result.is_succeeded() {
logger.debug(&format!("Replaying succeeded step: {}", op_id), &log_info);
state.track_replay(&op_id.operation_id).await;
if let Some(result_str) = checkpoint_result.result() {
let serdes = JsonSerDes::<T>::new();
let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
let result =
serdes
.deserialize(result_str, &serdes_ctx)
.map_err(|e| DurableError::SerDes {
message: format!("Failed to deserialize checkpointed result: {}", e),
})?;
return Ok(Some(result));
} else {
let serdes = JsonSerDes::<T>::new();
let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
match serdes.deserialize("null", &serdes_ctx) {
Ok(result) => return Ok(Some(result)),
Err(_) => {
return Err(DurableError::SerDes {
message:
"Step succeeded but no result was stored and type requires a value"
.to_string(),
});
}
}
}
}
if checkpoint_result.is_failed() {
logger.debug(&format!("Replaying failed step: {}", op_id), &log_info);
state.track_replay(&op_id.operation_id).await;
if let Some(error) = checkpoint_result.error() {
return Err(DurableError::UserCode {
message: error.error_message.clone(),
error_type: error.error_type.clone(),
stack_trace: error.stack_trace.clone(),
});
} else {
return Err(DurableError::execution("Step failed with unknown error"));
}
}
if checkpoint_result.is_terminal() {
state.track_replay(&op_id.operation_id).await;
let status = checkpoint_result.status().unwrap();
return Err(DurableError::Execution {
message: format!("Step was {}", status),
termination_reason: TerminationReason::StepInterrupted,
});
}
if checkpoint_result.is_ready() {
logger.debug(&format!("Resuming READY step: {}", op_id), &log_info);
return Ok(None);
}
if checkpoint_result.is_pending() {
logger.debug(
&format!("Step is PENDING, waiting for retry: {}", op_id),
&log_info,
);
return Err(DurableError::Suspend {
scheduled_timestamp: None,
});
}
Ok(None)
}
/// Bundle of borrowed inputs shared by the step execution helpers
/// (`execute_at_most_once`, `execute_at_least_once`, `checkpoint_result`).
struct StepExecParams<'a, T> {
/// Execution state used to create checkpoints.
state: &'a Arc<ExecutionState>,
/// Identifier (id, optional parent id, optional name) of the step operation.
op_id: &'a OperationIdentifier,
/// Context that is cloned and handed to the user closure.
step_ctx: &'a StepContext,
/// Serializer used for the step's result value.
serdes: &'a JsonSerDes<T>,
/// Context passed to `serdes` when serializing.
serdes_ctx: &'a SerDesContext,
/// Step configuration (semantics, retry strategy, retryable error filter).
config: &'a StepConfig,
/// Logger used for debug/error output.
logger: &'a Arc<dyn Logger>,
}
/// Executes a step with at-most-once semantics.
///
/// Unless `skip_start_checkpoint` is set, a START checkpoint is created
/// *before* the user closure runs; a failure to create it aborts the step
/// without invoking `func`. The closure's outcome is then recorded through
/// `checkpoint_result`.
///
/// # Errors
///
/// Returns the checkpoint error if the START checkpoint cannot be created,
/// otherwise whatever `checkpoint_result` returns.
async fn execute_at_most_once<T, F, Fut>(
    func: F,
    params: &StepExecParams<'_, T>,
    skip_start_checkpoint: bool,
) -> StepResult<T>
where
    T: DurableValue,
    F: FnOnce(StepContext) -> Fut + Send,
    Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let mut log_info = LogInfo::new(params.state.durable_execution_arn())
        .with_operation_id(&params.op_id.operation_id);
    if let Some(ref parent_id) = params.op_id.parent_id {
        log_info = log_info.with_parent_id(parent_id);
    }

    if skip_start_checkpoint {
        params.logger.debug(
            "Skipping START checkpoint for READY operation (AT_MOST_ONCE)",
            &log_info,
        );
    } else {
        params
            .logger
            .debug("Checkpointing step start (AT_MOST_ONCE)", &log_info);
        let start_update = create_start_update(params.op_id);
        params.state.create_checkpoint(start_update, true).await?;
    }

    let result = execute_with_retry(
        func,
        params.step_ctx.clone(),
        params.config,
        params.logger,
        &log_info,
    )
    .await;

    checkpoint_result(result, params, &log_info).await
}
/// Executes a step with at-least-once semantics.
///
/// No START checkpoint is written: the user closure runs first and only its
/// outcome is recorded through `checkpoint_result`.
///
/// # Errors
///
/// Returns whatever `checkpoint_result` returns (the user error mapped to
/// `DurableError::UserCode`, or a checkpoint/serialization error).
async fn execute_at_least_once<T, F, Fut>(func: F, params: &StepExecParams<'_, T>) -> StepResult<T>
where
    T: DurableValue,
    F: FnOnce(StepContext) -> Fut + Send,
    Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let mut log_info = LogInfo::new(params.state.durable_execution_arn())
        .with_operation_id(&params.op_id.operation_id);
    if let Some(ref parent_id) = params.op_id.parent_id {
        log_info = log_info.with_parent_id(parent_id);
    }

    params
        .logger
        .debug("Executing step (AT_LEAST_ONCE)", &log_info);

    let result = execute_with_retry(
        func,
        params.step_ctx.clone(),
        params.config,
        params.logger,
        &log_info,
    )
    .await;

    checkpoint_result(result, params, &log_info).await
}
/// Records the outcome of a step closure as a checkpoint and converts it to a `StepResult`.
///
/// * On success the value is serialized, a SUCCEED checkpoint is created and
///   the value is returned.
/// * On failure the error is logged, a FAIL checkpoint carrying a
///   `"UserCodeError"` error object is created and `DurableError::UserCode`
///   is returned.
///
/// # Errors
///
/// Returns `DurableError::SerDes` if the value cannot be serialized, the
/// checkpoint error if a checkpoint cannot be created, or
/// `DurableError::UserCode` for a failed step.
async fn checkpoint_result<T>(
    result: Result<T, Box<dyn std::error::Error + Send + Sync>>,
    params: &StepExecParams<'_, T>,
    log_info: &LogInfo,
) -> StepResult<T>
where
    T: DurableValue,
{
    match result {
        Ok(value) => {
            let serialized = params
                .serdes
                .serialize(&value, params.serdes_ctx)
                .map_err(|e| DurableError::SerDes {
                    message: format!("Failed to serialize step result: {}", e),
                })?;
            let succeed_update = create_succeed_update(params.op_id, Some(serialized));
            params.state.create_checkpoint(succeed_update, true).await?;
            params.logger.debug("Step completed successfully", log_info);
            Ok(value)
        }
        Err(error) => {
            // Render the message once and reuse it for the log line, the
            // checkpointed error object and the returned error.
            let message = error.to_string();

            // Log before checkpointing so the user error is not lost if the
            // FAIL checkpoint itself cannot be created.
            params
                .logger
                .error(&format!("Step failed: {}", message), log_info);

            let error_obj = ErrorObject::new("UserCodeError", message.clone());
            let fail_update = create_fail_update(params.op_id, error_obj);
            params.state.create_checkpoint(fail_update, true).await?;

            Err(DurableError::UserCode {
                message,
                error_type: "UserCodeError".to_string(),
                stack_trace: None,
            })
        }
    }
}
/// Invokes the step closure once and returns its outcome unchanged.
///
/// Retries are not performed here: a configured retry strategy only produces
/// a debug message, and the retryable-error filter is consulted purely to log
/// whether a failure matches it.
async fn execute_with_retry<T, F, Fut>(
    func: F,
    step_ctx: StepContext,
    config: &StepConfig,
    logger: &Arc<dyn Logger>,
    log_info: &LogInfo,
) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
where
    T: Send,
    F: FnOnce(StepContext) -> Fut + Send,
    Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    if config.retry_strategy.is_some() {
        logger.debug(
            "Retry strategy configured but not yet implemented for consumed closures",
            log_info,
        );
    }

    let outcome = func(step_ctx).await;

    // Only a failure combined with a configured filter produces a log line.
    if let (Err(err), Some(filter)) = (&outcome, &config.retryable_error_filter) {
        let error_msg = err.to_string();
        let note = if filter.is_retryable(&error_msg) {
            format!("Error matches retryable error filter: {}", error_msg)
        } else {
            format!(
                "Error does not match retryable error filter, skipping retry: {}",
                error_msg
            )
        };
        logger.debug(&note, log_info);
    }

    outcome
}
/// Builds the START update for a step, copying the parent id and name from
/// `op_id` when they are present.
fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
    let base = OperationUpdate::start(&op_id.operation_id, OperationType::Step);
    let with_parent = match op_id.parent_id {
        Some(ref parent_id) => base.with_parent_id(parent_id),
        None => base,
    };
    match op_id.name {
        Some(ref name) => with_parent.with_name(name),
        None => with_parent,
    }
}
/// Builds the SUCCEED update for a step with the given serialized `result`,
/// copying the parent id and name from `op_id` when they are present.
fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
    let succeeded = OperationUpdate::succeed(&op_id.operation_id, OperationType::Step, result);
    // Folding over the `Option` applies each builder at most once.
    let with_parent = op_id
        .parent_id
        .iter()
        .fold(succeeded, |update, parent_id| update.with_parent_id(parent_id));
    op_id
        .name
        .iter()
        .fold(with_parent, |update, name| update.with_name(name))
}
fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
let mut update = OperationUpdate::fail(&op_id.operation_id, OperationType::Step, error);
if let Some(ref parent_id) = op_id.parent_id {
update = update.with_parent_id(parent_id);
}
if let Some(ref name) = op_id.name {
update = update.with_name(name);
}
update
}
/// Builds a RETRY update carrying an optional `payload` and an optional delay
/// (in seconds) before the next attempt, copying the parent id and name from
/// `op_id` when they are present.
#[allow(dead_code)]
fn create_retry_update(
    op_id: &OperationIdentifier,
    payload: Option<String>,
    next_attempt_delay_seconds: Option<u64>,
) -> OperationUpdate {
    let retry = OperationUpdate::retry(
        &op_id.operation_id,
        OperationType::Step,
        payload,
        next_attempt_delay_seconds,
    );
    let retry = match &op_id.parent_id {
        Some(parent_id) => retry.with_parent_id(parent_id),
        None => retry,
    };
    match &op_id.name {
        Some(name) => retry.with_name(name),
        None => retry,
    }
}
/// Builds a RETRY update carrying `error` and an optional delay (in seconds)
/// before the next attempt, copying the parent id and name from `op_id` when
/// they are present.
#[allow(dead_code)]
fn create_retry_with_error_update(
    op_id: &OperationIdentifier,
    error: ErrorObject,
    next_attempt_delay_seconds: Option<u64>,
) -> OperationUpdate {
    let retry = OperationUpdate::retry_with_error(
        &op_id.operation_id,
        OperationType::Step,
        error,
        next_attempt_delay_seconds,
    );
    let retry = op_id
        .parent_id
        .iter()
        .fold(retry, |update, parent_id| update.with_parent_id(parent_id));
    if let Some(name) = &op_id.name {
        retry.with_name(name)
    } else {
        retry
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};

    /// Execution ARN shared by every state fixture in this module.
    const TEST_EXECUTION_ARN: &str =
        "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    /// Mock service client primed with two successful checkpoint responses.
    fn create_mock_client() -> SharedDurableServiceClient {
        let mock = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-2")));
        Arc::new(mock)
    }

    /// Execution state with no previously recorded operations.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        let state = ExecutionState::new(
            TEST_EXECUTION_ARN,
            "initial-token",
            InitialExecutionState::new(),
            client,
        );
        Arc::new(state)
    }

    /// Execution state pre-loaded with `op`, backed by a mock client that has
    /// no queued checkpoint responses.
    fn create_replay_state(op: Operation) -> Arc<ExecutionState> {
        let client = Arc::new(MockDurableServiceClient::new());
        let initial_state = InitialExecutionState::with_operations(vec![op]);
        Arc::new(ExecutionState::new(
            TEST_EXECUTION_ARN,
            "initial-token",
            initial_state,
            client,
        ))
    }

    /// Operation identifier used by the handler tests (no parent, named "test-step").
    fn create_test_op_id() -> OperationIdentifier {
        let name = Some("test-step".to_string());
        OperationIdentifier::new("test-op-123", None, name)
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    // ---- StepContext builders -------------------------------------------

    #[test]
    fn test_step_context_new() {
        let StepContext {
            operation_id,
            parent_id,
            name,
            durable_execution_arn,
            attempt,
            ..
        } = StepContext::new("op-123", "arn:test");
        assert_eq!(operation_id, "op-123");
        assert_eq!(durable_execution_arn, "arn:test");
        assert!(parent_id.is_none());
        assert!(name.is_none());
        assert_eq!(attempt, 0);
    }

    #[test]
    fn test_step_context_with_parent_id() {
        let ctx = StepContext::new("op-123", "arn:test").with_parent_id("parent-456");
        assert_eq!(ctx.parent_id.as_deref(), Some("parent-456"));
    }

    #[test]
    fn test_step_context_with_name() {
        let ctx = StepContext::new("op-123", "arn:test").with_name("my-step");
        assert_eq!(ctx.name.as_deref(), Some("my-step"));
    }

    #[test]
    fn test_step_context_with_attempt() {
        let ctx = StepContext::new("op-123", "arn:test").with_attempt(3);
        assert_eq!(ctx.attempt, 3);
    }

    #[test]
    fn test_step_context_serdes_context() {
        let serdes_ctx = StepContext::new("op-123", "arn:test").serdes_context();
        assert_eq!(serdes_ctx.operation_id, "op-123");
        assert_eq!(serdes_ctx.durable_execution_arn, "arn:test");
    }

    #[test]
    fn test_step_context_with_retry_payload() {
        let raw = r#"{"counter": 5}"#;
        let ctx = StepContext::new("op-123", "arn:test").with_retry_payload(raw);
        assert_eq!(ctx.retry_payload.as_deref(), Some(raw));
    }

    #[test]
    fn test_step_context_get_retry_payload() {
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct State {
            counter: i32,
        }
        let ctx = StepContext::new("op-123", "arn:test").with_retry_payload(r#"{"counter": 5}"#);
        let payload: Option<State> = ctx.get_retry_payload().unwrap();
        assert_eq!(payload, Some(State { counter: 5 }));
    }

    #[test]
    fn test_step_context_get_retry_payload_none() {
        #[derive(serde::Deserialize)]
        #[allow(dead_code)]
        struct State {
            counter: i32,
        }
        let ctx = StepContext::new("op-123", "arn:test");
        let payload: Option<State> = ctx.get_retry_payload().unwrap();
        assert!(payload.is_none());
    }

    // ---- OperationUpdate factories --------------------------------------

    #[test]
    fn test_create_retry_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-step".to_string()),
        );
        let payload = r#"{"state": "waiting"}"#;
        let update = create_retry_update(&op_id, Some(payload.to_string()), Some(5));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.action, crate::operation::OperationAction::Retry);
        assert_eq!(update.operation_type, OperationType::Step);
        assert_eq!(update.parent_id.as_deref(), Some("parent-456"));
        assert_eq!(update.name.as_deref(), Some("my-step"));
        assert_eq!(update.result.as_deref(), Some(payload));

        let step_options = update
            .step_options
            .as_ref()
            .expect("retry update should carry step options");
        assert_eq!(step_options.next_attempt_delay_seconds, Some(5));
    }

    #[test]
    fn test_create_retry_with_error_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let update = create_retry_with_error_update(
            &op_id,
            ErrorObject::new("RetryableError", "Temporary failure"),
            Some(10),
        );
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.action, crate::operation::OperationAction::Retry);
        assert!(update.result.is_none());
        let error = update.error.as_ref().expect("error should be set");
        assert_eq!(error.error_type, "RetryableError");
    }

    #[test]
    fn test_create_start_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-step".to_string()),
        );
        let update = create_start_update(&op_id);
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Step);
        assert_eq!(update.parent_id.as_deref(), Some("parent-456"));
        assert_eq!(update.name.as_deref(), Some("my-step"));
    }

    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let update = create_succeed_update(&op_id, Some("result".to_string()));
        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.result.as_deref(), Some("result"));
    }

    #[test]
    fn test_create_fail_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let update = create_fail_update(&op_id, ErrorObject::new("TestError", "test message"));
        assert_eq!(update.operation_id, "op-123");
        let error = update.error.expect("error should be set");
        assert_eq!(error.error_type, "TestError");
    }

    // ---- step_handler: fresh execution ----------------------------------

    #[tokio::test]
    async fn test_step_handler_success() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let outcome: Result<i32, DurableError> = step_handler(
            |_ctx| async move { Ok(42) },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(matches!(outcome, Ok(42)));
    }

    #[tokio::test]
    async fn test_step_handler_failure() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let outcome: Result<i32, DurableError> = step_handler(
            |_ctx| async move { Err("test error".into()) },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        match outcome {
            Err(DurableError::UserCode { message, .. }) => {
                assert!(message.contains("test error"));
            }
            _ => panic!("Expected UserCode error"),
        }
    }

    #[tokio::test]
    async fn test_step_handler_at_most_once_semantics() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig {
            step_semantics: StepSemantics::AtMostOncePerRetry,
            ..Default::default()
        };
        let logger = create_test_logger();

        let outcome: Result<String, DurableError> = step_handler(
            |_ctx| async move { Ok("at_most_once_result".to_string()) },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "at_most_once_result");
    }

    #[tokio::test]
    async fn test_step_handler_at_least_once_semantics() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig {
            step_semantics: StepSemantics::AtLeastOncePerRetry,
            ..Default::default()
        };
        let logger = create_test_logger();

        let outcome: Result<String, DurableError> = step_handler(
            |_ctx| async move { Ok("at_least_once_result".to_string()) },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "at_least_once_result");
    }

    #[tokio::test]
    async fn test_step_handler_genuinely_async_closure() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let outcome: Result<String, DurableError> = step_handler(
            |_ctx| async move {
                // Yield to the runtime so the closure is not trivially ready.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Ok("async_result".to_string())
            },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "async_result");
    }

    // ---- step_handler: replay -------------------------------------------

    #[tokio::test]
    async fn test_step_handler_replay_success() {
        let mut op = Operation::new("test-op-123", OperationType::Step);
        op.status = OperationStatus::Succeeded;
        op.result = Some("42".to_string());
        let state = create_replay_state(op);
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let outcome: Result<i32, DurableError> = step_handler(
            |_ctx| async move { panic!("Function should not be called during replay") },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(matches!(outcome, Ok(42)));
    }

    #[tokio::test]
    async fn test_step_handler_replay_failure() {
        let mut op = Operation::new("test-op-123", OperationType::Step);
        op.status = OperationStatus::Failed;
        op.error = Some(ErrorObject::new("TestError", "Previous failure"));
        let state = create_replay_state(op);
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let outcome: Result<i32, DurableError> = step_handler(
            |_ctx| async move { panic!("Function should not be called during replay") },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        match outcome {
            Err(DurableError::UserCode { message, .. }) => {
                assert!(message.contains("Previous failure"));
            }
            _ => panic!("Expected UserCode error"),
        }
    }

    #[tokio::test]
    async fn test_step_handler_non_deterministic_detection() {
        // The checkpoint under this id was recorded for a Wait operation.
        let mut op = Operation::new("test-op-123", OperationType::Wait);
        op.status = OperationStatus::Succeeded;
        let state = create_replay_state(op);
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let outcome: Result<i32, DurableError> = step_handler(
            |_ctx| async move { Ok(42) },
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        match outcome {
            Err(DurableError::NonDeterministic { operation_id, .. }) => {
                assert_eq!(operation_id.as_deref(), Some("test-op-123"));
            }
            _ => panic!("Expected NonDeterministic error"),
        }
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};
    use proptest::prelude::*;

    /// Property tests for `step_handler` under both step semantics.
    mod step_semantics_tests {
        use super::*;
        use std::sync::atomic::{AtomicBool, Ordering};

        /// Execution ARN shared by every state fixture in this module.
        const TEST_EXECUTION_ARN: &str =
            "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

        /// Execution state with no previously recorded operations.
        fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
            Arc::new(ExecutionState::new(
                TEST_EXECUTION_ARN,
                "initial-token",
                InitialExecutionState::new(),
                client,
            ))
        }

        fn create_test_logger() -> Arc<dyn Logger> {
            Arc::new(TracingLogger)
        }

        /// Default step configuration with only the semantics overridden.
        fn config_with(semantics: StepSemantics) -> StepConfig {
            StepConfig {
                step_semantics: semantics,
                ..Default::default()
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            // AT_MOST_ONCE success path: the step returns the closure's value
            // and leaves a succeeded checkpoint behind.
            // NOTE(review): despite its name this test does not observe the
            // relative order of the START checkpoint and the closure call.
            #[test]
            fn prop_at_most_once_checkpoints_before_execution(
                result_value in any::<i32>(),
            ) {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async {
                    // Two responses: one for START, one for SUCCEED.
                    let client = Arc::new(MockDurableServiceClient::new()
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-2"))));
                    let state = create_test_state(client);
                    let op_id = OperationIdentifier::new(
                        format!("test-op-{}", result_value),
                        None,
                        Some("test-step".to_string()),
                    );
                    let config = config_with(StepSemantics::AtMostOncePerRetry);
                    let logger = create_test_logger();

                    let result: Result<i32, DurableError> = step_handler(
                        move |_ctx| async move { Ok(result_value) },
                        &state,
                        &op_id,
                        &config,
                        &logger,
                    ).await;

                    prop_assert!(result.is_ok(), "Step should succeed");
                    prop_assert_eq!(result.unwrap(), result_value, "Result should match input");
                    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
                    prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
                    prop_assert!(checkpoint_result.is_succeeded(), "Checkpoint should be succeeded");
                    Ok(())
                })?;
            }

            // AT_LEAST_ONCE success path: a single checkpoint response is
            // enough, and the checkpointed result round-trips to the value.
            #[test]
            fn prop_at_least_once_checkpoints_after_execution(
                result_value in any::<i32>(),
            ) {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async {
                    let client = Arc::new(MockDurableServiceClient::new()
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))));
                    let state = create_test_state(client);
                    let op_id = OperationIdentifier::new(
                        format!("test-op-{}", result_value),
                        None,
                        Some("test-step".to_string()),
                    );
                    let config = config_with(StepSemantics::AtLeastOncePerRetry);
                    let logger = create_test_logger();

                    let result: Result<i32, DurableError> = step_handler(
                        move |_ctx| async move { Ok(result_value) },
                        &state,
                        &op_id,
                        &config,
                        &logger,
                    ).await;

                    prop_assert!(result.is_ok(), "Step should succeed");
                    prop_assert_eq!(result.unwrap(), result_value, "Result should match input");
                    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
                    prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
                    prop_assert!(checkpoint_result.is_succeeded(), "Checkpoint should be succeeded");
                    if let Some(result_str) = checkpoint_result.result() {
                        let deserialized: i32 = serde_json::from_str(result_str).unwrap();
                        prop_assert_eq!(deserialized, result_value, "Checkpointed result should match");
                    }
                    Ok(())
                })?;
            }

            // AT_MOST_ONCE failure path: the error surfaces and a failed
            // checkpoint is recorded.
            #[test]
            fn prop_at_most_once_checkpoints_error_on_failure(
                error_msg in "[a-zA-Z0-9 ]{1,50}",
            ) {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async {
                    let client = Arc::new(MockDurableServiceClient::new()
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-2"))));
                    let state = create_test_state(client);
                    let op_id = OperationIdentifier::new(
                        format!("test-op-fail-{}", error_msg.len()),
                        None,
                        Some("test-step".to_string()),
                    );
                    let config = config_with(StepSemantics::AtMostOncePerRetry);
                    let logger = create_test_logger();
                    let error_msg_clone = error_msg.clone();

                    let result: Result<i32, DurableError> = step_handler(
                        move |_ctx| async move { Err(error_msg_clone.into()) },
                        &state,
                        &op_id,
                        &config,
                        &logger,
                    ).await;

                    prop_assert!(result.is_err(), "Step should fail");
                    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
                    prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
                    prop_assert!(checkpoint_result.is_failed(), "Checkpoint should be failed");
                    Ok(())
                })?;
            }

            // AT_LEAST_ONCE failure path: the error surfaces and a failed
            // checkpoint is recorded.
            #[test]
            fn prop_at_least_once_checkpoints_error_on_failure(
                error_msg in "[a-zA-Z0-9 ]{1,50}",
            ) {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async {
                    let client = Arc::new(MockDurableServiceClient::new()
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))));
                    let state = create_test_state(client);
                    let op_id = OperationIdentifier::new(
                        format!("test-op-fail-{}", error_msg.len()),
                        None,
                        Some("test-step".to_string()),
                    );
                    let config = config_with(StepSemantics::AtLeastOncePerRetry);
                    let logger = create_test_logger();
                    let error_msg_clone = error_msg.clone();

                    let result: Result<i32, DurableError> = step_handler(
                        move |_ctx| async move { Err(error_msg_clone.into()) },
                        &state,
                        &op_id,
                        &config,
                        &logger,
                    ).await;

                    prop_assert!(result.is_err(), "Step should fail");
                    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
                    prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
                    prop_assert!(checkpoint_result.is_failed(), "Checkpoint should be failed");
                    Ok(())
                })?;
            }

            // Replay: a succeeded checkpoint is returned as-is under either
            // semantics and the closure is never invoked.
            #[test]
            fn prop_replay_returns_checkpointed_result(
                result_value in any::<i32>(),
                semantics in prop_oneof![
                    Just(StepSemantics::AtMostOncePerRetry),
                    Just(StepSemantics::AtLeastOncePerRetry),
                ],
            ) {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async {
                    let client = Arc::new(MockDurableServiceClient::new());
                    let mut op = Operation::new("test-op-replay", OperationType::Step);
                    op.status = OperationStatus::Succeeded;
                    op.result = Some(result_value.to_string());
                    let initial_state = InitialExecutionState::with_operations(vec![op]);
                    let state = Arc::new(ExecutionState::new(
                        TEST_EXECUTION_ARN,
                        "initial-token",
                        initial_state,
                        client,
                    ));
                    let op_id = OperationIdentifier::new("test-op-replay", None, None);
                    let config = config_with(semantics);
                    let logger = create_test_logger();
                    let was_called = Arc::new(AtomicBool::new(false));
                    let was_called_clone = was_called.clone();

                    let result: Result<i32, DurableError> = step_handler(
                        move |_ctx| async move {
                            was_called_clone.store(true, Ordering::SeqCst);
                            // Sentinel value; must never be observed on replay.
                            Ok(999)
                        },
                        &state,
                        &op_id,
                        &config,
                        &logger,
                    ).await;

                    prop_assert!(result.is_ok(), "Replay should succeed");
                    prop_assert_eq!(result.unwrap(), result_value, "Should return checkpointed value");
                    prop_assert!(!was_called.load(Ordering::SeqCst), "Function should not be called during replay");
                    Ok(())
                })?;
            }

            // READY checkpoint under AT_MOST_ONCE: the closure runs and only
            // one checkpoint response (for SUCCEED) is required.
            #[test]
            fn prop_ready_status_resumes_without_start_checkpoint(
                result_value in any::<i32>(),
            ) {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async {
                    let client = Arc::new(MockDurableServiceClient::new()
                        .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))));
                    let mut op = Operation::new("test-op-ready", OperationType::Step);
                    op.status = OperationStatus::Ready;
                    let initial_state = InitialExecutionState::with_operations(vec![op]);
                    let state = Arc::new(ExecutionState::new(
                        TEST_EXECUTION_ARN,
                        "initial-token",
                        initial_state,
                        client,
                    ));
                    let op_id = OperationIdentifier::new("test-op-ready", None, None);
                    let config = config_with(StepSemantics::AtMostOncePerRetry);
                    let logger = create_test_logger();
                    let was_called = Arc::new(AtomicBool::new(false));
                    let was_called_clone = was_called.clone();

                    let result: Result<i32, DurableError> = step_handler(
                        move |_ctx| async move {
                            was_called_clone.store(true, Ordering::SeqCst);
                            Ok(result_value)
                        },
                        &state,
                        &op_id,
                        &config,
                        &logger,
                    ).await;

                    prop_assert!(result.is_ok(), "Step should succeed");
                    prop_assert_eq!(result.unwrap(), result_value, "Result should match input");
                    prop_assert!(was_called.load(Ordering::SeqCst), "Function should be called for READY status");
                    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
                    prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
                    prop_assert!(checkpoint_result.is_succeeded(), "Checkpoint should be succeeded");
                    Ok(())
                })?;
            }
        }
    }
}