use std::future::Future;
use std::time::Duration;
use aws_sdk_lambda::types::{
ErrorObject, OperationAction, OperationStatus, OperationType, OperationUpdate,
};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::context::DurableContext;
use crate::error::DurableError;
use crate::types::StepOptions;
impl DurableContext {
pub async fn step<T, E, F, Fut>(
&mut self,
name: &str,
f: F,
) -> Result<Result<T, E>, DurableError>
where
T: Serialize + DeserializeOwned + Send + 'static,
E: Serialize + DeserializeOwned + Send + 'static,
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<T, E>> + Send + 'static,
{
self.step_with_options(name, StepOptions::default(), f)
.await
}
#[allow(clippy::await_holding_lock)]
pub async fn step_with_options<T, E, F, Fut>(
&mut self,
name: &str,
options: StepOptions,
f: F,
) -> Result<Result<T, E>, DurableError>
where
T: Serialize + DeserializeOwned + Send + 'static,
E: Serialize + DeserializeOwned + Send + 'static,
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<T, E>> + Send + 'static,
{
let op_id = self.replay_engine_mut().generate_operation_id();
let span = tracing::info_span!(
"durable_operation",
op.name = name,
op.type = "step",
op.id = %op_id,
);
let _guard = span.enter();
tracing::trace!("durable_operation");
if let Some(operation) = self.replay_engine().check_result(&op_id) {
let result = extract_step_result::<T, E>(operation)?;
self.replay_engine_mut().track_replay(&op_id);
return Ok(result);
}
let is_retry_reexecution =
self.replay_engine()
.operations()
.get(&op_id)
.is_some_and(|op| {
matches!(
op.status,
OperationStatus::Pending
| OperationStatus::Ready
| OperationStatus::Started
)
});
let current_attempt = if is_retry_reexecution {
self.replay_engine()
.operations()
.get(&op_id)
.and_then(|op| op.step_details())
.map(|d| d.attempt())
.unwrap_or(1)
} else {
let start_update = OperationUpdate::builder()
.id(op_id.clone())
.r#type(OperationType::Step)
.action(OperationAction::Start)
.name(name)
.sub_type("Step")
.build()
.map_err(|e| DurableError::checkpoint_failed(name, e))?;
if self.is_batch_mode() {
self.push_pending_update(start_update);
} else {
let start_response = self
.backend()
.checkpoint(
self.arn(),
self.checkpoint_token(),
vec![start_update],
None,
)
.await?;
let new_token = start_response.checkpoint_token().ok_or_else(|| {
DurableError::checkpoint_failed(
name,
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"checkpoint response missing checkpoint_token",
),
)
})?;
self.set_checkpoint_token(new_token.to_string());
if let Some(new_state) = start_response.new_execution_state() {
for op in new_state.operations() {
self.replay_engine_mut()
.insert_operation(op.id().to_string(), op.clone());
}
}
if let Some(operation) = self.replay_engine().check_result(&op_id) {
let result = extract_step_result::<T, E>(operation)?;
self.replay_engine_mut().track_replay(&op_id);
return Ok(result);
}
}
1 };
let name_owned = name.to_string();
let mut handle = tokio::spawn(async move { f().await });
let user_result = if let Some(secs) = options.get_timeout_seconds() {
match tokio::time::timeout(Duration::from_secs(secs), &mut handle).await {
Ok(join_result) => join_result.map_err(|join_err| {
DurableError::checkpoint_failed(
&name_owned,
std::io::Error::other(format!("step closure panicked: {join_err}")),
)
})?,
Err(_elapsed) => {
handle.abort();
return Err(DurableError::step_timeout(&name_owned));
}
}
} else {
handle.await.map_err(|join_err| {
DurableError::checkpoint_failed(
&name_owned,
std::io::Error::other(format!("step closure panicked: {join_err}")),
)
})?
};
match &user_result {
Ok(value) => {
let payload = serde_json::to_string(value)
.map_err(|e| DurableError::serialization(std::any::type_name::<T>(), e))?;
let succeed_update = OperationUpdate::builder()
.id(op_id.clone())
.r#type(OperationType::Step)
.action(OperationAction::Succeed)
.name(name)
.sub_type("Step")
.payload(payload)
.build()
.map_err(|e| DurableError::checkpoint_failed(name, e))?;
if self.is_batch_mode() {
self.push_pending_update(succeed_update);
} else {
let response = self
.backend()
.checkpoint(
self.arn(),
self.checkpoint_token(),
vec![succeed_update],
None,
)
.await?;
let new_token = response.checkpoint_token().ok_or_else(|| {
DurableError::checkpoint_failed(
name,
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"checkpoint response missing checkpoint_token",
),
)
})?;
self.set_checkpoint_token(new_token.to_string());
}
}
Err(error) => {
let max_retries = options.get_retries().unwrap_or(0);
let should_retry = if let Some(pred) = options.get_retry_if() {
pred(error as &dyn std::any::Any)
} else {
true };
if should_retry && (current_attempt as u32) <= max_retries {
let delay = options.get_backoff_seconds().unwrap_or(0);
let aws_step_options = aws_sdk_lambda::types::StepOptions::builder()
.next_attempt_delay_seconds(delay)
.build();
let retry_update = OperationUpdate::builder()
.id(op_id.clone())
.r#type(OperationType::Step)
.action(OperationAction::Retry)
.name(name)
.sub_type("Step")
.step_options(aws_step_options)
.build()
.map_err(|e| DurableError::checkpoint_failed(name, e))?;
if self.is_batch_mode() {
self.push_pending_update(retry_update);
self.flush_batch().await?;
} else {
let response = self
.backend()
.checkpoint(
self.arn(),
self.checkpoint_token(),
vec![retry_update],
None,
)
.await?;
let new_token = response.checkpoint_token().ok_or_else(|| {
DurableError::checkpoint_failed(
name,
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"checkpoint response missing checkpoint_token",
),
)
})?;
self.set_checkpoint_token(new_token.to_string());
}
return Err(DurableError::step_retry_scheduled(name));
}
let error_data = serde_json::to_string(error)
.map_err(|e| DurableError::serialization(std::any::type_name::<E>(), e))?;
let error_object = ErrorObject::builder()
.error_type(std::any::type_name::<E>())
.error_data(error_data)
.build();
let fail_update = OperationUpdate::builder()
.id(op_id.clone())
.r#type(OperationType::Step)
.action(OperationAction::Fail)
.name(name)
.sub_type("Step")
.error(error_object)
.build()
.map_err(|e| DurableError::checkpoint_failed(name, e))?;
if self.is_batch_mode() {
self.push_pending_update(fail_update);
} else {
let response = self
.backend()
.checkpoint(self.arn(), self.checkpoint_token(), vec![fail_update], None)
.await?;
let new_token = response.checkpoint_token().ok_or_else(|| {
DurableError::checkpoint_failed(
name,
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"checkpoint response missing checkpoint_token",
),
)
})?;
self.set_checkpoint_token(new_token.to_string());
}
}
}
Ok(user_result)
}
}
/// Rebuilds the user-visible outcome of a step from a stored `operation`.
///
/// * `Succeeded` — `step_details.result` is parsed as JSON into `T` and
///   returned as `Ok(Ok(value))`.
/// * `Failed` — `step_details.error.error_data` is parsed as JSON into `E`
///   and returned as `Ok(Err(error))`.
/// * any other status — a `DurableError::replay_mismatch` is returned.
///
/// A missing payload is reported through `DurableError::checkpoint_failed`,
/// and a payload that does not parse through `DurableError::deserialization`.
fn extract_step_result<T, E>(
    operation: &aws_sdk_lambda::types::Operation,
) -> Result<Result<T, E>, DurableError>
where
    T: DeserializeOwned,
    E: DeserializeOwned,
{
    // Error used when a terminal operation does not carry the expected field.
    let missing_field = |message: &'static str| {
        DurableError::checkpoint_failed(
            "step",
            std::io::Error::new(std::io::ErrorKind::InvalidData, message),
        )
    };

    match &operation.status {
        OperationStatus::Succeeded => {
            let Some(result_json) = operation.step_details().and_then(|d| d.result()) else {
                return Err(missing_field(
                    "SUCCEEDED operation missing step_details.result",
                ));
            };
            serde_json::from_str::<T>(result_json)
                .map(Ok)
                .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))
        }
        OperationStatus::Failed => {
            let stored_error = operation
                .step_details()
                .and_then(|d| d.error())
                .and_then(|e| e.error_data());
            let Some(error_data) = stored_error else {
                return Err(missing_field(
                    "FAILED operation missing step_details.error.error_data",
                ));
            };
            serde_json::from_str::<E>(error_data)
                .map(Err)
                .map_err(|e| DurableError::deserialization(std::any::type_name::<E>(), e))
        }
        unexpected => Err(DurableError::replay_mismatch(
            "Succeeded or Failed",
            format!("{unexpected:?}"),
            0,
        )),
    }
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
use aws_sdk_lambda::types::{
ErrorObject, Operation, OperationStatus, OperationType, OperationUpdate, StepDetails,
};
use aws_smithy_types::DateTime;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tracing_test::traced_test;
use crate::backend::DurableBackend;
use crate::context::DurableContext;
use crate::error::DurableError;
use crate::operation_id::OperationIdGenerator;
use crate::types::StepOptions;
/// Arguments captured from one call to `MockBackend::checkpoint`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct CheckpointCall {
// ARN passed to the call.
arn: String,
// Checkpoint token the caller presented (not the one returned).
checkpoint_token: String,
// Updates submitted in this call.
updates: Vec<OperationUpdate>,
}
/// Test double for `DurableBackend`: records every checkpoint call and
/// answers each one with a fixed checkpoint token.
struct MockBackend {
// Shared log of received checkpoint calls; tests keep a clone to inspect it.
calls: Arc<Mutex<Vec<CheckpointCall>>>,
// Token returned in every checkpoint response.
checkpoint_token: String,
}
impl MockBackend {
    /// Builds a mock that answers every checkpoint with `checkpoint_token`.
    ///
    /// Returns the mock together with a second handle to its call log, so a
    /// test can inspect the recorded calls after handing the mock away.
    fn new(checkpoint_token: &str) -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
        let call_log: Arc<Mutex<Vec<CheckpointCall>>> = Arc::new(Mutex::new(Vec::new()));
        let mock = Self {
            calls: Arc::clone(&call_log),
            checkpoint_token: checkpoint_token.to_owned(),
        };
        (mock, call_log)
    }
}
/// `DurableBackend` implementation used by the tests in this module.
#[async_trait::async_trait]
impl DurableBackend for MockBackend {
// Records the call's arguments, then returns a response carrying the
// mock's fixed checkpoint token. The client token is ignored.
async fn checkpoint(
&self,
arn: &str,
checkpoint_token: &str,
updates: Vec<OperationUpdate>,
_client_token: Option<&str>,
) -> Result<CheckpointDurableExecutionOutput, DurableError> {
self.calls.lock().await.push(CheckpointCall {
arn: arn.to_string(),
checkpoint_token: checkpoint_token.to_string(),
updates,
});
Ok(CheckpointDurableExecutionOutput::builder()
.checkpoint_token(&self.checkpoint_token)
.build())
}
// Always returns an output built with no fields set; all arguments are ignored.
async fn get_execution_state(
&self,
_arn: &str,
_checkpoint_token: &str,
_next_marker: &str,
_max_items: i32,
) -> Result<GetDurableExecutionStateOutput, DurableError> {
Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
}
}
/// With no stored operations, `step` runs the closure and issues two
/// checkpoints: START (carrying the step name) then SUCCEED (carrying the
/// JSON payload).
#[tokio::test]
async fn test_step_executes_closure_in_executing_mode() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let result: Result<i32, String> = ctx.step("my_step", || async { Ok(42) }).await.unwrap();
assert_eq!(result.unwrap(), 42);
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + SUCCEED checkpoints");
let start_call = &captured[0];
assert_eq!(start_call.updates.len(), 1);
let start_update = &start_call.updates[0];
assert_eq!(start_update.r#type(), &OperationType::Step);
assert_eq!(
start_update.action(),
&aws_sdk_lambda::types::OperationAction::Start
);
assert_eq!(start_update.name(), Some("my_step"));
let succeed_call = &captured[1];
assert_eq!(succeed_call.updates.len(), 1);
let succeed_update = &succeed_call.updates[0];
assert_eq!(succeed_update.r#type(), &OperationType::Step);
assert_eq!(
succeed_update.action(),
&aws_sdk_lambda::types::OperationAction::Succeed
);
assert_eq!(succeed_update.payload().unwrap(), "42");
// The SUCCEED call must present the token returned by the START call.
assert_eq!(succeed_call.checkpoint_token, "new-token");
}
/// When the context is seeded with a SUCCEEDED operation for the step's id,
/// `step` returns the stored value, never calls the closure and issues no
/// checkpoint.
#[tokio::test]
async fn test_step_returns_cached_result_in_replaying_mode() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
// A fresh generator is used to obtain the id the test expects the first step to get.
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Succeeded)
.start_timestamp(DateTime::from_secs(0))
.step_details(
StepDetails::builder()
.attempt(1)
.result(r#"{"value":42}"#)
.build(),
)
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
let closure_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let closure_called_clone = closure_called.clone();
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct MyResult {
value: i32,
}
let result: Result<MyResult, String> = ctx
.step("my_step", move || {
let flag = closure_called_clone.clone();
async move {
flag.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(MyResult { value: 999 })
}
})
.await
.unwrap();
// 42 comes from the stored operation; 999 would mean the closure ran.
assert_eq!(result.unwrap(), MyResult { value: 42 });
assert!(
!closure_called.load(std::sync::atomic::Ordering::SeqCst),
"closure should NOT have been called during replay"
);
let captured = calls.lock().await;
assert_eq!(captured.len(), 0, "no checkpoint calls during replay");
}
/// When the context is seeded with a FAILED operation for the step's id,
/// `step` returns the stored error (deserialized into the caller's error
/// type) as the inner `Err`.
#[tokio::test]
async fn test_step_returns_cached_error_in_replaying_mode() {
let (backend, _calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct MyError {
code: i32,
message: String,
}
let error_data = serde_json::to_string(&MyError {
code: 404,
message: "not found".to_string(),
})
.unwrap();
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Failed)
.start_timestamp(DateTime::from_secs(0))
.step_details(
StepDetails::builder()
.attempt(1)
.error(
ErrorObject::builder()
.error_type("MyError")
.error_data(&error_data)
.build(),
)
.build(),
)
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
// The closure would succeed if it ran; getting `Err` proves the stored error was used.
let result: Result<String, MyError> = ctx
.step("my_step", || async { Ok("nope".to_string()) })
.await
.unwrap();
let err = result.unwrap_err();
assert_eq!(err.code, 404);
assert_eq!(err.message, "not found");
}
/// A nested struct serialized with `serde_json` and stored as a SUCCEEDED
/// result is returned by `step` equal to the original value.
#[tokio::test]
async fn test_step_serialization_roundtrip() {
let (backend, _calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
struct ComplexData {
name: String,
values: Vec<i32>,
nested: NestedData,
optional: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
struct NestedData {
flag: bool,
score: f64,
}
let expected = ComplexData {
name: "test".to_string(),
values: vec![1, 2, 3],
nested: NestedData {
flag: true,
score: 99.5,
},
optional: Some("present".to_string()),
};
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
let serialized = serde_json::to_string(&expected).unwrap();
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Succeeded)
.start_timestamp(DateTime::from_secs(0))
.step_details(
StepDetails::builder()
.attempt(1)
.result(&serialized)
.build(),
)
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
// The closure panics if invoked, so the test also fails if replay is skipped.
let result: Result<ComplexData, String> = ctx
.step("complex_step", || async {
panic!("should not execute during replay")
})
.await
.unwrap();
assert_eq!(result.unwrap(), expected);
}
/// Two consecutive steps get different operation ids, and each step uses the
/// same id for its START and SUCCEED updates.
#[tokio::test]
async fn test_step_sequential_unique_ids() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let _r1: Result<i32, String> = ctx.step("step_1", || async { Ok(1) }).await.unwrap();
let _r2: Result<i32, String> = ctx.step("step_2", || async { Ok(2) }).await.unwrap();
let captured = calls.lock().await;
// Expected call order: step_1 START, step_1 SUCCEED, step_2 START, step_2 SUCCEED.
assert_eq!(captured.len(), 4);
let step1_id = captured[0].updates[0].id().to_string();
let step2_id = captured[2].updates[0].id().to_string();
assert_ne!(
step1_id, step2_id,
"sequential steps must have different operation IDs"
);
assert_eq!(step1_id, captured[1].updates[0].id());
assert_eq!(step2_id, captured[3].updates[0].id());
}
/// A context seeded with one stored operation reports `is_replaying()` until
/// that operation has been replayed by `step`, and not afterwards.
#[tokio::test]
async fn test_step_tracks_replay() {
let (backend, _calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Succeeded)
.start_timestamp(DateTime::from_secs(0))
.step_details(StepDetails::builder().attempt(1).result("100").build())
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
assert!(
ctx.is_replaying(),
"should be replaying before visiting cached ops"
);
let result: Result<i32, String> =
ctx.step("cached_step", || async { Ok(999) }).await.unwrap();
assert_eq!(result.unwrap(), 100);
assert!(
!ctx.is_replaying(),
"should transition to executing after all cached ops replayed"
);
}
/// `step_with_options` with default options behaves like `step` on success:
/// START then SUCCEED, with the JSON payload on the SUCCEED update.
#[tokio::test]
async fn test_step_with_options_basic_success() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let result: Result<i32, String> = ctx
.step_with_options("opts_step", StepOptions::default(), || async { Ok(42) })
.await
.unwrap();
assert_eq!(result.unwrap(), 42);
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + SUCCEED checkpoints");
let start_update = &captured[0].updates[0];
assert_eq!(start_update.r#type(), &OperationType::Step);
assert_eq!(
start_update.action(),
&aws_sdk_lambda::types::OperationAction::Start
);
assert_eq!(start_update.name(), Some("opts_step"));
let succeed_update = &captured[1].updates[0];
assert_eq!(succeed_update.r#type(), &OperationType::Step);
assert_eq!(
succeed_update.action(),
&aws_sdk_lambda::types::OperationAction::Succeed
);
assert_eq!(succeed_update.payload().unwrap(), "42");
}
/// A failing first attempt with `retries(3).backoff_seconds(5)` yields
/// `DurableError::StepRetryScheduled` and a RETRY update whose
/// `next_attempt_delay_seconds` is 5.
#[tokio::test]
async fn test_step_with_options_retry_on_failure() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let options = StepOptions::new().retries(3).backoff_seconds(5);
let result: Result<Result<i32, String>, DurableError> = ctx
.step_with_options("retry_step", options, || async {
Err("transient failure".to_string())
})
.await;
let err = result.unwrap_err();
match err {
DurableError::StepRetryScheduled { .. } => {}
other => panic!("expected StepRetryScheduled, got {other:?}"),
}
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + RETRY checkpoints");
let start_update = &captured[0].updates[0];
assert_eq!(
start_update.action(),
&aws_sdk_lambda::types::OperationAction::Start
);
let retry_update = &captured[1].updates[0];
assert_eq!(
retry_update.action(),
&aws_sdk_lambda::types::OperationAction::Retry
);
let step_opts = retry_update
.step_options()
.expect("should have step_options");
assert_eq!(step_opts.next_attempt_delay_seconds(), Some(5));
}
/// When the stored operation is PENDING at attempt 4 and only 3 retries are
/// allowed, a further failure is recorded as FAIL (no START, no RETRY) and
/// the user error is returned as the inner `Err`.
#[tokio::test]
async fn test_step_with_options_retry_exhaustion() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
// Pending status + attempt 4 simulates re-execution after three scheduled retries.
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Pending)
.start_timestamp(DateTime::from_secs(0))
.step_details(StepDetails::builder().attempt(4).build())
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
let options = StepOptions::new().retries(3).backoff_seconds(5);
let result: Result<Result<i32, String>, DurableError> = ctx
.step_with_options("exhaust_step", options, || async {
Err("final failure".to_string())
})
.await;
let inner = result.unwrap();
let user_error = inner.unwrap_err();
assert_eq!(user_error, "final failure");
let captured = calls.lock().await;
assert_eq!(captured.len(), 1, "expected only FAIL checkpoint");
let fail_update = &captured[0].updates[0];
assert_eq!(
fail_update.action(),
&aws_sdk_lambda::types::OperationAction::Fail
);
}
/// A stored SUCCEEDED operation recorded at attempt 3 is replayed like any
/// other success: stored value returned, closure not called, no checkpoint.
#[tokio::test]
async fn test_step_with_options_replay_succeeded_with_retries() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Succeeded)
.start_timestamp(DateTime::from_secs(0))
.step_details(StepDetails::builder().attempt(3).result("99").build())
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
let closure_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let closure_called_clone = closure_called.clone();
let options = StepOptions::new().retries(3);
let result: Result<i32, String> = ctx
.step_with_options("replay_retry_step", options, move || {
let flag = closure_called_clone.clone();
async move {
flag.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(999)
}
})
.await
.unwrap();
assert_eq!(result.unwrap(), 99);
assert!(
!closure_called.load(std::sync::atomic::Ordering::SeqCst),
"closure should NOT have been called during replay"
);
let captured = calls.lock().await;
assert_eq!(captured.len(), 0, "no checkpoint calls during replay");
}
/// The option-less `step` entry point produces START then SUCCEED, and a
/// `String` result is stored as a JSON string (including the quotes).
#[tokio::test]
async fn test_step_backward_compatibility() {
let (backend, calls) = MockBackend::new("compat-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let result: Result<String, String> = ctx
.step("compat_step", || async { Ok("hello".to_string()) })
.await
.unwrap();
assert_eq!(result.unwrap(), "hello");
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + SUCCEED checkpoints");
let start_update = &captured[0].updates[0];
assert_eq!(
start_update.action(),
&aws_sdk_lambda::types::OperationAction::Start
);
assert_eq!(start_update.name(), Some("compat_step"));
let succeed_update = &captured[1].updates[0];
assert_eq!(
succeed_update.action(),
&aws_sdk_lambda::types::OperationAction::Succeed
);
assert_eq!(succeed_update.payload().unwrap(), r#""hello""#);
}
/// `StepOptions` getters return `None` for `default()` / `new()`, return the
/// values set through the builder methods, and the last `retries` call wins.
#[test]
fn test_step_options_builder() {
let default_opts = StepOptions::default();
assert_eq!(default_opts.get_retries(), None);
assert_eq!(default_opts.get_backoff_seconds(), None);
let new_opts = StepOptions::new();
assert_eq!(new_opts.get_retries(), None);
assert_eq!(new_opts.get_backoff_seconds(), None);
let opts = StepOptions::new().retries(5).backoff_seconds(10);
assert_eq!(opts.get_retries(), Some(5));
assert_eq!(opts.get_backoff_seconds(), Some(10));
// Calling `retries` twice keeps only the later value.
let opts2 = StepOptions::new().retries(1).retries(3);
assert_eq!(opts2.get_retries(), Some(3));
}
/// An enum error with struct-like variants, stored as a FAILED operation, is
/// returned by `step_with_options` equal to the original variant, with no
/// checkpoint issued.
#[tokio::test]
async fn test_step_with_options_typed_error_roundtrip() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
#[derive(Serialize, Deserialize, Debug, PartialEq)]
enum DomainError {
NotFound { resource: String },
PermissionDenied { user: String, action: String },
RateLimited { retry_after_secs: u64 },
}
let mut gen = OperationIdGenerator::new(None);
let expected_op_id = gen.next_id();
let original_error = DomainError::PermissionDenied {
user: "alice".to_string(),
action: "delete".to_string(),
};
let error_data = serde_json::to_string(&original_error).unwrap();
let cached_op = Operation::builder()
.id(&expected_op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Failed)
.start_timestamp(DateTime::from_secs(0))
.step_details(
StepDetails::builder()
.attempt(1)
.error(
ErrorObject::builder()
.error_type("DomainError")
.error_data(&error_data)
.build(),
)
.build(),
)
.build()
.unwrap();
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![cached_op],
None,
)
.await
.unwrap();
let result: Result<String, DomainError> = ctx
.step_with_options("typed_err_step", StepOptions::default(), || async {
Ok("should not run".to_string())
})
.await
.unwrap();
let err = result.unwrap_err();
assert_eq!(
err,
DomainError::PermissionDenied {
user: "alice".to_string(),
action: "delete".to_string(),
}
);
let captured = calls.lock().await;
assert_eq!(captured.len(), 0, "no checkpoint calls during replay");
}
/// With no retries configured, a failing closure is recorded as START then
/// FAIL and the user error is returned as the inner `Err`.
#[tokio::test]
async fn test_step_execute_fail_checkpoint() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let result: Result<i32, String> = ctx
.step("failing_step", || async {
Err("something went wrong".to_string())
})
.await
.unwrap();
assert_eq!(result.unwrap_err(), "something went wrong");
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + FAIL checkpoints");
assert_eq!(
captured[0].updates[0].action(),
&aws_sdk_lambda::types::OperationAction::Start
);
assert_eq!(
captured[1].updates[0].action(),
&aws_sdk_lambda::types::OperationAction::Fail
);
}
/// Test double for `DurableBackend` whose checkpoint responses are built
/// without a checkpoint token; used to exercise the missing-token error path.
struct NoneTokenMockBackend;
#[async_trait::async_trait]
impl DurableBackend for NoneTokenMockBackend {
// Ignores all arguments and returns a response with no fields set.
async fn checkpoint(
&self,
_arn: &str,
_checkpoint_token: &str,
_updates: Vec<OperationUpdate>,
_client_token: Option<&str>,
) -> Result<CheckpointDurableExecutionOutput, DurableError> {
Ok(CheckpointDurableExecutionOutput::builder().build())
}
// Ignores all arguments and returns an output with no fields set.
async fn get_execution_state(
&self,
_arn: &str,
_checkpoint_token: &str,
_next_marker: &str,
_max_items: i32,
) -> Result<GetDurableExecutionStateOutput, DurableError> {
Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
}
}
/// A closure that sleeps for 60 s under a 1 s timeout makes
/// `step_with_options` return `DurableError::StepTimeout` naming the step.
// NOTE(review): no paused clock is configured here, so this test appears to
// wait for the real 1 s timeout — confirm that is acceptable.
#[tokio::test]
async fn test_step_timeout_aborts_slow_closure() {
let (backend, _calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let options = StepOptions::new().timeout_seconds(1);
let result: Result<Result<i32, String>, DurableError> = ctx
.step_with_options("slow_step", options, || async {
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
Ok::<i32, String>(42)
})
.await;
let err = result.unwrap_err();
match err {
DurableError::StepTimeout { operation_name } => {
assert_eq!(operation_name, "slow_step");
}
other => panic!("expected StepTimeout, got {other:?}"),
}
}
/// A closure that completes immediately under a 5 s timeout returns its
/// value normally.
#[tokio::test]
async fn test_step_timeout_does_not_fire_when_fast_enough() {
let (backend, _calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let options = StepOptions::new().timeout_seconds(5);
let result: Result<i32, String> = ctx
.step_with_options("fast_step", options, || async { Ok(99) })
.await
.unwrap();
assert_eq!(result.unwrap(), 99);
}
/// When the `retry_if` predicate returns `false`, a failure is recorded as
/// FAIL straight away even though 3 retries are configured.
#[tokio::test]
async fn test_retry_if_false_causes_immediate_fail_no_retry_budget_consumed() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let options = StepOptions::new().retries(3).retry_if(|_e: &String| false);
let result: Result<Result<i32, String>, DurableError> = ctx
.step_with_options("no_retry_step", options, || async {
Err("permanent error".to_string())
})
.await;
let inner = result.unwrap();
let user_error = inner.unwrap_err();
assert_eq!(user_error, "permanent error");
let captured = calls.lock().await;
assert_eq!(
captured.len(),
2,
"expected START + FAIL, got {}",
captured.len()
);
assert_eq!(
captured[1].updates[0].action(),
&aws_sdk_lambda::types::OperationAction::Fail,
"second checkpoint should be FAIL not RETRY"
);
}
/// When the `retry_if` predicate returns `true` and retries remain, a
/// failure is recorded as RETRY and `StepRetryScheduled` is returned.
#[tokio::test]
async fn test_retry_if_true_retries_normally() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let options = StepOptions::new().retries(3).retry_if(|_e: &String| true);
let result: Result<Result<i32, String>, DurableError> = ctx
.step_with_options("retry_true_step", options, || async {
Err("transient error".to_string())
})
.await;
let err = result.unwrap_err();
match err {
DurableError::StepRetryScheduled { .. } => {}
other => panic!("expected StepRetryScheduled, got {other:?}"),
}
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + RETRY");
assert_eq!(
captured[1].updates[0].action(),
&aws_sdk_lambda::types::OperationAction::Retry,
);
}
/// Without a `retry_if` predicate, any error is retried while retries
/// remain: the call returns `StepRetryScheduled` after two checkpoints.
#[tokio::test]
async fn test_no_retry_if_retries_all_errors_backward_compatible() {
let (backend, calls) = MockBackend::new("new-token");
let backend = Arc::new(backend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![],
None,
)
.await
.unwrap();
let options = StepOptions::new().retries(2);
let result: Result<Result<i32, String>, DurableError> = ctx
.step_with_options("compat_retry_step", options, || async {
Err("any error".to_string())
})
.await;
let err = result.unwrap_err();
match err {
DurableError::StepRetryScheduled { .. } => {}
other => panic!("expected StepRetryScheduled, got {other:?}"),
}
let captured = calls.lock().await;
assert_eq!(captured.len(), 2, "expected START + RETRY");
}
/// When the backend's checkpoint response has no checkpoint token, `step`
/// fails with `DurableError::CheckpointFailed` that names the step and
/// mentions the missing token in its message.
#[tokio::test]
async fn checkpoint_none_token_returns_error() {
let backend = Arc::new(NoneTokenMockBackend);
let mut ctx = DurableContext::new(
backend,
"arn:test".to_string(),
"initial-token".to_string(),
vec![], None,
)
.await
.unwrap();
let result: Result<Result<i32, String>, DurableError> =
ctx.step("test_step", || async { Ok(42) }).await;
let err = result
.expect_err("step should fail when checkpoint response has None checkpoint_token");
match &err {
DurableError::CheckpointFailed { operation_name, .. } => {
assert!(
operation_name.contains("test_step"),
"error should reference the operation name, got: {}",
operation_name
);
}
other => panic!("expected DurableError::CheckpointFailed, got: {:?}", other),
}
let err_msg = err.to_string();
assert!(
err_msg.contains("checkpoint response missing checkpoint_token"),
"error message should mention missing checkpoint_token, got: {}",
err_msg
);
}
/// Running a step produces captured log output containing the span name
/// ("durable_operation"), the step name and the operation type.
#[traced_test]
#[tokio::test]
async fn test_step_emits_span() {
let (backend, _calls) = MockBackend::new("tok");
let mut ctx = DurableContext::new(
Arc::new(backend),
"arn:test".to_string(),
"tok".to_string(),
vec![],
None,
)
.await
.unwrap();
let _: Result<i32, String> = ctx.step("validate", || async { Ok(42) }).await.unwrap();
assert!(logs_contain("durable_operation"));
assert!(logs_contain("validate"));
assert!(logs_contain("step"));
}
/// Running a step produces captured log output that mentions the `op.id`
/// span field.
#[traced_test]
#[tokio::test]
async fn test_span_includes_op_id() {
let (backend, _calls) = MockBackend::new("tok");
let mut ctx = DurableContext::new(
Arc::new(backend),
"arn:test".to_string(),
"tok".to_string(),
vec![],
None,
)
.await
.unwrap();
let _: Result<i32, String> = ctx.step("id_check", || async { Ok(42) }).await.unwrap();
assert!(logs_contain("durable_operation"));
assert!(logs_contain("op.id"));
}
}