use std::future::Future;
use std::sync::Arc;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Mutex;
use crate::concurrency::{BatchItem, BatchResult, CompletionReason};
use crate::context::{LogInfo, Logger, OperationIdentifier};
use crate::error::{DurableError, ErrorObject};
use crate::operation::OperationType;
use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
use crate::state::ExecutionState;
/// Aggregated, serializable summary of a promise combinator run: the per-future outcomes,
/// the index of a deciding future (if any), and the reason the combinator finished.
///
/// NOTE(review): none of the handlers in this file construct this type; presumably it is part
/// of the public API consumed elsewhere — confirm against its users.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct PromiseCombinatorResult<T> {
// Outcome of each individual future (presumably one per input future — confirm).
pub results: Vec<PromiseOutcome<T>>,
// Index of the future that decided the result, when there is one.
pub winner_index: Option<usize>,
// Why the combinator completed.
pub completion_reason: PromiseCompletionReason,
}
/// Outcome of a single future inside a promise combinator, tagged with its input position.
///
/// The `success` / `failure` constructors keep the fields consistent: a success has
/// `succeeded == true`, `result == Some(..)`, `error == None`; a failure is the mirror image.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct PromiseOutcome<T> {
// Position of the future in the input list.
pub index: usize,
// `true` when the future resolved to a value.
pub succeeded: bool,
// The resolved value; `Some` only for successes built via `success`.
pub result: Option<T>,
// The failure details; `Some` only for failures built via `failure`.
pub error: Option<ErrorObject>,
}
/// Reason a promise combinator finished.
///
/// NOTE(review): variant meanings are inferred from their names only — they appear to mirror
/// the outcomes of JavaScript's `Promise.all` / `allSettled` / `race` / `any`. No handler in
/// this file constructs this enum; confirm the exact semantics against its users.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, serde::Deserialize)]
pub enum PromiseCompletionReason {
AllSucceeded,
AllSettled,
FirstSettled,
FirstSucceeded,
AllFailed,
OneFailed,
}
impl<T> PromiseOutcome<T> {
pub fn success(index: usize, result: T) -> Self {
Self {
index,
succeeded: true,
result: Some(result),
error: None,
}
}
pub fn failure(index: usize, error: ErrorObject) -> Self {
Self {
index,
succeeded: false,
result: None,
error: Some(error),
}
}
}
/// Runs every future concurrently and resolves to all of their values, in input order.
///
/// This mirrors JavaScript's `Promise.all`: as soon as any future reports an error the
/// combinator fails, without waiting for the remaining futures. Each future is spawned onto
/// the Tokio runtime, so futures still running after a failure keep running detached and their
/// results are discarded (the same behaviour as `race_handler` / `any_handler`).
///
/// On replay, the checkpointed outcome is returned without polling `futures`.
///
/// # Errors
/// - Any error produced by replay checks (non-deterministic operation type, deserialization
///   failure, or a replayed failure).
/// - `DurableError::UserCode` with `error_type` `"AllCombinatorError"` when a future fails. The
///   identical error is checkpointed, so a replay reproduces exactly what the first run returned.
/// - `DurableError::execution("Promise result missing")` if a future never reports a result
///   (for example, because its task panicked).
/// - Any error from checkpointing.
pub async fn all_handler<T, Fut>(
    futures: Vec<Fut>,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<Vec<T>, DurableError>
where
    T: Serialize + DeserializeOwned + Send + Clone + 'static,
    Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
{
    let log_info = create_log_info(state, op_id);
    logger.debug(
        &format!("Starting 'all' combinator with {} futures", futures.len()),
        &log_info,
    );
    if let Some(result) = check_replay::<Vec<T>>(state, op_id, logger).await? {
        return Ok(result);
    }
    checkpoint_start(state, op_id).await?;
    let total = futures.len();
    if total == 0 {
        checkpoint_succeed(state, op_id, &Vec::<T>::new()).await?;
        return Ok(Vec::new());
    }

    let (tx, mut rx) = tokio::sync::mpsc::channel::<(usize, Result<T, DurableError>)>(total);
    for (index, future) in futures.into_iter().enumerate() {
        let tx = tx.clone();
        tokio::spawn(async move {
            let result = future.await;
            // The receiver is gone after a fail-fast return; a failed send is expected then.
            let _ = tx.send((index, result)).await;
        });
    }
    // Drop our own sender so `recv` yields `None` once every task has reported or panicked.
    drop(tx);

    let mut slots: Vec<Option<T>> = (0..total).map(|_| None).collect();
    while let Some((index, result)) = rx.recv().await {
        match result {
            Ok(value) => slots[index] = Some(value),
            Err(e) => {
                let message = format!("Promise at index {} failed: {}", index, e);
                // Checkpoint the same combinator-level error we return, so replay (which
                // rebuilds a `UserCode` from the stored type/message) matches the first run.
                let error_obj = ErrorObject::new("AllCombinatorError", &message);
                checkpoint_fail(state, op_id, error_obj).await?;
                return Err(DurableError::UserCode {
                    message,
                    error_type: "AllCombinatorError".to_string(),
                    stack_trace: None,
                });
            }
        }
    }

    // `None` here means at least one task exited without sending (e.g. it panicked).
    match slots.into_iter().collect::<Option<Vec<T>>>() {
        Some(final_results) => {
            checkpoint_succeed(state, op_id, &final_results).await?;
            logger.debug("'all' combinator completed successfully", &log_info);
            Ok(final_results)
        }
        None => {
            let error_obj = ErrorObject::new("InternalError", "Promise result missing");
            checkpoint_fail(state, op_id, error_obj).await?;
            Err(DurableError::execution("Promise result missing"))
        }
    }
}
/// Runs every future to completion and reports each one's outcome, in input order.
///
/// This mirrors JavaScript's `Promise.allSettled`: individual failures never fail the
/// combinator; they are recorded as failed `BatchItem`s. A future whose task does not complete
/// normally (for example, it panicked) is also recorded as a failed item, rather than being
/// left in its initial pending state.
///
/// On replay, the checkpointed `BatchResult` is returned without polling `futures`.
///
/// # Errors
/// - Any error produced by replay checks or checkpointing.
/// - `DurableError::execution("Failed to unwrap results")` if the shared result buffer is still
///   referenced after all tasks have been joined.
pub async fn all_settled_handler<T, Fut>(
    futures: Vec<Fut>,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<BatchResult<T>, DurableError>
where
    T: Serialize + DeserializeOwned + Send + Clone + 'static,
    Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
{
    let log_info = create_log_info(state, op_id);
    logger.debug(
        &format!(
            "Starting 'all_settled' combinator with {} futures",
            futures.len()
        ),
        &log_info,
    );
    if let Some(result) = check_replay_batch::<T>(state, op_id, logger).await? {
        return Ok(result);
    }
    checkpoint_start(state, op_id).await?;
    let total = futures.len();
    if total == 0 {
        let result = BatchResult::empty();
        checkpoint_succeed_batch(state, op_id, &result).await?;
        return Ok(result);
    }

    // One pre-filled slot per future; each task overwrites only its own slot.
    let results: Arc<Mutex<Vec<BatchItem<T>>>> =
        Arc::new(Mutex::new((0..total).map(BatchItem::pending).collect()));
    let mut handles = Vec::with_capacity(total);
    for (index, future) in futures.into_iter().enumerate() {
        let results = Arc::clone(&results);
        handles.push(tokio::spawn(async move {
            // Build the item before locking so the lock is held only for the store.
            let item = match future.await {
                Ok(value) => BatchItem::succeeded(index, value),
                Err(e) => BatchItem::failed(index, ErrorObject::from(&e)),
            };
            let mut results_guard = results.lock().await;
            results_guard[index] = item;
        }));
    }

    for (index, handle) in handles.into_iter().enumerate() {
        if let Err(join_error) = handle.await {
            // The task never stored its outcome; record it as failed instead of leaving a
            // pending item inside an "all completed" result.
            let error_obj = ErrorObject::new(
                "InternalError",
                &format!("Future at index {} did not complete: {}", index, join_error),
            );
            let mut results_guard = results.lock().await;
            results_guard[index] = BatchItem::failed(index, error_obj);
        }
    }

    // Every task has been joined, so this should be the last reference to the buffer.
    let items = Arc::try_unwrap(results)
        .map_err(|_| DurableError::execution("Failed to unwrap results"))?
        .into_inner();
    let batch_result = BatchResult::new(items, CompletionReason::AllCompleted);
    checkpoint_succeed_batch(state, op_id, &batch_result).await?;
    logger.debug("'all_settled' combinator completed", &log_info);
    Ok(batch_result)
}
/// Settles with the outcome of whichever future finishes first, mirroring `Promise.race`.
///
/// The first future to report — success or failure — decides the result, and that outcome is
/// checkpointed. All futures are spawned onto the Tokio runtime; the losers are not awaited.
/// On replay, the checkpointed outcome is returned without polling `futures`.
///
/// # Errors
/// - `DurableError::validation` when `futures` is empty (also checkpointed as a failure).
/// - The winning future's own error, when the first future to finish failed.
/// - `DurableError::execution("No futures completed")` if no task ever reports.
/// - Any error from replay checks or checkpointing.
pub async fn race_handler<T, Fut>(
    futures: Vec<Fut>,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<T, DurableError>
where
    T: Serialize + DeserializeOwned + Send + Clone + 'static,
    Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
{
    let log_info = create_log_info(state, op_id);
    logger.debug(
        &format!("Starting 'race' combinator with {} futures", futures.len()),
        &log_info,
    );

    if let Some(replayed) = check_replay::<T>(state, op_id, logger).await? {
        return Ok(replayed);
    }
    checkpoint_start(state, op_id).await?;

    let contender_count = futures.len();
    if contender_count == 0 {
        checkpoint_fail(
            state,
            op_id,
            ErrorObject::new("ValidationError", "race requires at least one future"),
        )
        .await?;
        return Err(DurableError::validation(
            "race requires at least one future",
        ));
    }

    let (sender, mut receiver) =
        tokio::sync::mpsc::channel::<(usize, Result<T, DurableError>)>(contender_count);
    futures
        .into_iter()
        .enumerate()
        .for_each(|(position, contender)| {
            let tx = sender.clone();
            tokio::spawn(async move {
                let outcome = contender.await;
                let _ = tx.send((position, outcome)).await;
            });
        });
    // Release the original sender so the channel closes once every task is done.
    drop(sender);

    match receiver.recv().await {
        Some((position, Ok(value))) => {
            checkpoint_succeed(state, op_id, &value).await?;
            logger.debug(
                &format!("'race' combinator won by future at index {}", position),
                &log_info,
            );
            Ok(value)
        }
        Some((_, Err(failure))) => {
            checkpoint_fail(state, op_id, ErrorObject::from(&failure)).await?;
            Err(failure)
        }
        None => {
            checkpoint_fail(
                state,
                op_id,
                ErrorObject::new("InternalError", "No futures completed"),
            )
            .await?;
            Err(DurableError::execution("No futures completed"))
        }
    }
}
/// Resolves to the value of the first future (in completion order) that succeeds, mirroring
/// JavaScript's `Promise.any`.
///
/// All futures are spawned onto the Tokio runtime and report through a single channel.
/// Failures are collected until a success arrives. Futures still running at that point keep
/// running detached and their results are discarded. On replay, the checkpointed outcome is
/// returned without polling `futures`.
///
/// # Errors
/// - `DurableError::validation` when `futures` is empty (also checkpointed as a failure).
/// - `DurableError::UserCode` with `error_type` `"AnyCombinatorError"` when no future succeeds.
///   The message lists every reported failure in input-index order, so it is deterministic for a
///   given set of outcomes. The identical error is checkpointed.
/// - Any error from replay checks or checkpointing.
pub async fn any_handler<T, Fut>(
    futures: Vec<Fut>,
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<T, DurableError>
where
    T: Serialize + DeserializeOwned + Send + Clone + 'static,
    Fut: Future<Output = Result<T, DurableError>> + Send + 'static,
{
    let log_info = create_log_info(state, op_id);
    logger.debug(
        &format!("Starting 'any' combinator with {} futures", futures.len()),
        &log_info,
    );
    if let Some(result) = check_replay::<T>(state, op_id, logger).await? {
        return Ok(result);
    }
    checkpoint_start(state, op_id).await?;
    let total = futures.len();
    if total == 0 {
        let error_obj = ErrorObject::new("ValidationError", "any requires at least one future");
        checkpoint_fail(state, op_id, error_obj).await?;
        return Err(DurableError::validation("any requires at least one future"));
    }

    // Capacity `total` means no task ever blocks on send; each task sends exactly once.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<(usize, Result<T, DurableError>)>(total);
    for (index, future) in futures.into_iter().enumerate() {
        let tx = tx.clone();
        tokio::spawn(async move {
            let result = future.await;
            // After the first success the receiver is dropped; a failed send is expected then.
            let _ = tx.send((index, result)).await;
        });
    }
    // Drop our own sender so `recv` yields `None` once every task has reported or panicked.
    drop(tx);

    let mut failures: Vec<(usize, DurableError)> = Vec::with_capacity(total);
    while let Some((index, result)) = rx.recv().await {
        match result {
            Ok(value) => {
                checkpoint_succeed(state, op_id, &value).await?;
                logger.debug(
                    &format!("'any' combinator succeeded with future at index {}", index),
                    &log_info,
                );
                return Ok(value);
            }
            Err(e) => failures.push((index, e)),
        }
    }

    // Completion order is nondeterministic; report failures by input index instead.
    failures.sort_unstable_by_key(|(index, _)| *index);
    let error_messages: Vec<String> = failures
        .iter()
        .map(|(i, e)| format!("Future {}: {}", i, e))
        .collect();
    let combined_message = format!(
        "All {} futures failed: {}",
        total,
        error_messages.join("; ")
    );
    let error_obj = ErrorObject::new("AnyCombinatorError", &combined_message);
    checkpoint_fail(state, op_id, error_obj).await?;
    logger.debug("'any' combinator failed - all futures failed", &log_info);
    Err(DurableError::UserCode {
        message: combined_message,
        error_type: "AnyCombinatorError".to_string(),
        stack_trace: None,
    })
}
/// Builds the structured logging context for a combinator operation.
///
/// The context carries the durable execution ARN and the operation id, plus the parent id when
/// `op_id` has one.
fn create_log_info(state: &Arc<ExecutionState>, op_id: &OperationIdentifier) -> LogInfo {
    let base = LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
    match &op_id.parent_id {
        Some(parent_id) => base.with_parent_id(parent_id),
        None => base,
    }
}
async fn check_replay<T>(
state: &Arc<ExecutionState>,
op_id: &OperationIdentifier,
logger: &Arc<dyn Logger>,
) -> Result<Option<T>, DurableError>
where
T: Serialize + DeserializeOwned,
{
let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
if !checkpoint_result.is_existent() {
return Ok(None);
}
let log_info = create_log_info(state, op_id);
if let Some(op_type) = checkpoint_result.operation_type() {
if op_type != OperationType::Step {
return Err(DurableError::NonDeterministic {
message: format!(
"Expected Step operation but found {:?} at operation_id {}",
op_type, op_id.operation_id
),
operation_id: Some(op_id.operation_id.clone()),
});
}
}
if checkpoint_result.is_succeeded() {
logger.debug(
&format!("Replaying succeeded promise combinator: {}", op_id),
&log_info,
);
state.track_replay(&op_id.operation_id).await;
if let Some(result_str) = checkpoint_result.result() {
let serdes = JsonSerDes::<T>::new();
let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
let result =
serdes
.deserialize(result_str, &serdes_ctx)
.map_err(|e| DurableError::SerDes {
message: format!("Failed to deserialize checkpointed result: {}", e),
})?;
return Ok(Some(result));
}
}
if checkpoint_result.is_failed() {
logger.debug(
&format!("Replaying failed promise combinator: {}", op_id),
&log_info,
);
state.track_replay(&op_id.operation_id).await;
if let Some(error) = checkpoint_result.error() {
return Err(DurableError::UserCode {
message: error.error_message.clone(),
error_type: error.error_type.clone(),
stack_trace: error.stack_trace.clone(),
});
} else {
return Err(DurableError::execution(
"Promise combinator failed with unknown error",
));
}
}
Ok(None)
}
/// Replay check for combinators whose checkpointed result is a `BatchResult<T>`.
///
/// Thin typed wrapper around `check_replay`, with identical semantics and errors.
async fn check_replay_batch<T>(
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    logger: &Arc<dyn Logger>,
) -> Result<Option<BatchResult<T>>, DurableError>
where
    T: Serialize + DeserializeOwned,
{
    let replayed: Option<BatchResult<T>> = check_replay(state, op_id, logger).await?;
    Ok(replayed)
}
/// Records the start of a combinator as a `Step` operation with sub-type `"PromiseCombinator"`.
///
/// The parent id and name from `op_id` are attached when present. The update is passed to
/// `state.create_checkpoint` with `true` as its second argument.
async fn checkpoint_start(
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
) -> Result<(), DurableError> {
    use crate::operation::OperationUpdate;
    let update = OperationUpdate::start(&op_id.operation_id, OperationType::Step);
    let update = match op_id.parent_id.as_ref() {
        Some(parent_id) => update.with_parent_id(parent_id),
        None => update,
    };
    let update = match op_id.name.as_ref() {
        Some(name) => update.with_name(name),
        None => update,
    };
    state
        .create_checkpoint(update.with_sub_type("PromiseCombinator"), true)
        .await
}
/// Records a successful `Step` completion whose payload is `result` serialized as JSON.
///
/// The parent id and name from `op_id` are attached when present.
///
/// # Errors
/// - `DurableError::SerDes` when `result` cannot be serialized.
/// - Any error returned by `state.create_checkpoint`.
async fn checkpoint_succeed<T>(
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    result: &T,
) -> Result<(), DurableError>
where
    T: Serialize + DeserializeOwned,
{
    use crate::operation::OperationUpdate;
    let codec = JsonSerDes::<T>::new();
    let codec_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
    let payload = codec
        .serialize(result, &codec_ctx)
        .map_err(|e| DurableError::SerDes {
            message: format!("Failed to serialize result: {}", e),
        })?;
    let update = OperationUpdate::succeed(&op_id.operation_id, OperationType::Step, Some(payload));
    let update = match op_id.parent_id.as_ref() {
        Some(parent_id) => update.with_parent_id(parent_id),
        None => update,
    };
    let update = match op_id.name.as_ref() {
        Some(name) => update.with_name(name),
        None => update,
    };
    state.create_checkpoint(update, true).await
}
/// Records a successful completion whose payload is a `BatchResult<T>`.
///
/// Thin typed wrapper around `checkpoint_succeed`, with identical semantics and errors.
async fn checkpoint_succeed_batch<T>(
    state: &Arc<ExecutionState>,
    op_id: &OperationIdentifier,
    result: &BatchResult<T>,
) -> Result<(), DurableError>
where
    T: Serialize + DeserializeOwned,
{
    checkpoint_succeed::<BatchResult<T>>(state, op_id, result).await
}
async fn checkpoint_fail(
state: &Arc<ExecutionState>,
op_id: &OperationIdentifier,
error: ErrorObject,
) -> Result<(), DurableError> {
use crate::operation::OperationUpdate;
let mut update = OperationUpdate::fail(&op_id.operation_id, OperationType::Step, error);
if let Some(ref parent_id) = op_id.parent_id {
update = update.with_parent_id(parent_id);
}
if let Some(ref name) = op_id.name {
update = update.with_name(name);
}
state.create_checkpoint(update, true).await
}
// Example-based tests for the combinator handlers, driven by a mock durable service client.
#[cfg(test)]
mod tests {
use super::*;
use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
use crate::context::TracingLogger;
use crate::lambda::InitialExecutionState;
use std::pin::Pin;
// Boxed, type-erased future so differently-typed `async` blocks can share one `Vec`.
type DurableFuture<T> = Pin<Box<dyn Future<Output = Result<T, DurableError>> + Send>>;
// Mock client queued with three successful checkpoint responses. Each handler call below issues
// at most two checkpoint updates (start, then succeed or fail).
// NOTE(review): assumes the mock hands out one queued response per checkpoint call — confirm.
fn create_mock_client() -> SharedDurableServiceClient {
Arc::new(
MockDurableServiceClient::new()
.with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
.with_checkpoint_response(Ok(CheckpointResponse::new("token-2")))
.with_checkpoint_response(Ok(CheckpointResponse::new("token-3"))),
)
}
// Fresh execution state with no prior checkpoints, so the handlers never take the replay path.
fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
InitialExecutionState::new(),
client,
))
}
// Top-level operation id "test-op-<name>" with no parent and the given name.
fn create_test_op_id(name: &str) -> OperationIdentifier {
OperationIdentifier::new(format!("test-op-{}", name), None, Some(name.to_string()))
}
fn create_test_logger() -> Arc<dyn Logger> {
Arc::new(TracingLogger)
}
// `all` returns every value, in input order, when all futures succeed.
#[tokio::test]
async fn test_all_handler_success() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("all-success");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Ok::<_, DurableError>(1) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Ok(2) }),
Box::pin(async { Ok(3) }),
];
let result = all_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), vec![1, 2, 3]);
}
// `all` fails when any single future fails.
#[tokio::test]
async fn test_all_handler_failure() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("all-failure");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Ok::<_, DurableError>(1) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Err(DurableError::execution("test error")) }),
Box::pin(async { Ok(3) }),
];
let result = all_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_err());
}
// `all` over no futures succeeds with an empty vector.
#[tokio::test]
async fn test_all_handler_empty() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("all-empty");
let logger = create_test_logger();
let futures: Vec<DurableFuture<i32>> = vec![];
let result = all_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
// `all_settled` succeeds and counts successes and failures separately.
#[tokio::test]
async fn test_all_settled_handler_mixed() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("all-settled-mixed");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Ok::<_, DurableError>(1) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Err(DurableError::execution("test error")) }),
Box::pin(async { Ok(3) }),
];
let result = all_settled_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
let batch = result.unwrap();
assert_eq!(batch.items.len(), 3);
assert_eq!(batch.success_count(), 2);
assert_eq!(batch.failure_count(), 1);
}
// `all_settled` reports all-succeeded when no future fails.
#[tokio::test]
async fn test_all_settled_handler_all_success() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("all-settled-success");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Ok::<_, DurableError>(1) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Ok(2) }),
];
let result = all_settled_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
let batch = result.unwrap();
assert!(batch.all_succeeded());
}
// `race` returns the value of the fastest future; the second is delayed by 100 ms.
#[tokio::test]
async fn test_race_handler_first_wins() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("race-first");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Ok::<_, DurableError>(1) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
Ok(2)
}),
];
let result = race_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 1);
}
// `race` fails when the fastest future fails, even though a slower one would succeed.
#[tokio::test]
async fn test_race_handler_error_wins() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("race-error");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Err::<i32, _>(DurableError::execution("fast error")) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
Ok(2)
}),
];
let result = race_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_err());
}
// `race` over no futures is a validation error.
#[tokio::test]
async fn test_race_handler_empty() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("race-empty");
let logger = create_test_logger();
let futures: Vec<DurableFuture<i32>> = vec![];
let result = race_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_err());
}
// `any` returns one of the successful values; which one depends on scheduling, so both are accepted.
#[tokio::test]
async fn test_any_handler_first_success() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("any-first");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Ok::<_, DurableError>(1) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Err(DurableError::execution("error")) }),
Box::pin(async { Ok(3) }),
];
let result = any_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
let value = result.unwrap();
assert!(value == 1 || value == 3);
}
// `any` fails with an aggregated `UserCode` error when every future fails.
#[tokio::test]
async fn test_any_handler_all_fail() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("any-all-fail");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Err::<i32, _>(DurableError::execution("error 1")) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Err(DurableError::execution("error 2")) }),
];
let result = any_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_err());
if let Err(DurableError::UserCode { message, .. }) = result {
assert!(message.contains("All 2 futures failed"));
} else {
panic!("Expected UserCode error");
}
}
// `any` over no futures is a validation error.
#[tokio::test]
async fn test_any_handler_empty() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("any-empty");
let logger = create_test_logger();
let futures: Vec<DurableFuture<i32>> = vec![];
let result = any_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_err());
}
// `any` keeps waiting past early failures until the (delayed) success arrives.
#[tokio::test]
async fn test_any_handler_success_after_failures() {
let client = create_mock_client();
let state = create_test_state(client);
let op_id = create_test_op_id("any-success-after-fail");
let logger = create_test_logger();
let futures = vec![
Box::pin(async { Err::<i32, _>(DurableError::execution("error 1")) })
as Pin<Box<dyn Future<Output = Result<i32, DurableError>> + Send>>,
Box::pin(async { Err(DurableError::execution("error 2")) }),
Box::pin(async {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
Ok(42)
}),
];
let result = any_handler(futures, &state, &op_id, &logger).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
}
}
// Property-based tests (proptest) for the combinator handlers. Each case builds its own Tokio
// runtime and a mock client queued with enough checkpoint responses for one handler call.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
use crate::context::TracingLogger;
use crate::lambda::InitialExecutionState;
use proptest::prelude::*;
use std::pin::Pin;
// Boxed, type-erased future so differently-typed `async` blocks can share one `Vec`.
type DurableFuture<T> = Pin<Box<dyn Future<Output = Result<T, DurableError>> + Send>>;
// Mock client queued with `count` successful checkpoint responses ("token-0", "token-1", ...).
fn create_mock_client_with_responses(count: usize) -> SharedDurableServiceClient {
let mut client = MockDurableServiceClient::new();
for i in 0..count {
client = client
.with_checkpoint_response(Ok(CheckpointResponse::new(format!("token-{}", i))));
}
Arc::new(client)
}
// Fresh execution state with no prior checkpoints, so the handlers never take the replay path.
fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
Arc::new(ExecutionState::new(
"arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
"initial-token",
InitialExecutionState::new(),
client,
))
}
// Top-level operation id "test-op-<name>" with no parent and the given name.
fn create_test_op_id(name: &str) -> OperationIdentifier {
OperationIdentifier::new(format!("test-op-{}", name), None, Some(name.to_string()))
}
fn create_test_logger() -> Arc<dyn Logger> {
Arc::new(TracingLogger)
}
mod promise_combinator_tests {
use super::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
// `all` over 1..10 succeeding futures returns exactly their values, in input order.
#[test]
fn prop_all_returns_all_results_on_success(
values in prop::collection::vec(0i32..1000, 1..10),
) {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let client = create_mock_client_with_responses(10);
let state = create_test_state(client);
let op_id = create_test_op_id(&format!("all-prop-{}", values.len()));
let logger = create_test_logger();
let expected = values.clone();
let futures: Vec<DurableFuture<i32>> =
values.into_iter()
.map(|v| Box::pin(async move { Ok(v) }) as DurableFuture<i32>)
.collect();
let result = all_handler(futures, &state, &op_id, &logger).await;
prop_assert!(result.is_ok(), "all should succeed when all futures succeed");
prop_assert_eq!(result.unwrap(), expected, "all should return all results in order");
Ok(())
})?;
}
// `all` fails whenever exactly one future (at a random position) fails.
#[test]
fn prop_all_fails_on_any_error(
success_count in 0usize..5,
error_index in 0usize..10,
) {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// One extra slot for the failing future; fold the raw index into range.
let total = success_count + 1; let error_index = error_index % total;
let client = create_mock_client_with_responses(10);
let state = create_test_state(client);
let op_id = create_test_op_id(&format!("all-fail-prop-{}", total));
let logger = create_test_logger();
let futures: Vec<DurableFuture<i32>> =
(0..total)
.map(|i| {
if i == error_index {
Box::pin(async move { Err(DurableError::execution("test error")) }) as DurableFuture<i32>
} else {
Box::pin(async move { Ok(i as i32) }) as DurableFuture<i32>
}
})
.collect();
let result = all_handler(futures, &state, &op_id, &logger).await;
prop_assert!(result.is_err(), "all should fail when any future fails");
Ok(())
})?;
}
// `all_settled` always succeeds and its success/failure counts match the generated mix.
#[test]
fn prop_all_settled_returns_all_outcomes(
success_indices in prop::collection::vec(0usize..10, 0..10),
total in 1usize..10,
) {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let client = create_mock_client_with_responses(10);
let state = create_test_state(client);
let op_id = create_test_op_id(&format!("all-settled-prop-{}", total));
let logger = create_test_logger();
// Deduplicate and drop out-of-range indices; the rest are the succeeding positions.
let success_set: std::collections::HashSet<usize> =
success_indices.into_iter().filter(|&i| i < total).collect();
let expected_successes = success_set.len();
let expected_failures = total - expected_successes;
let futures: Vec<DurableFuture<i32>> =
(0..total)
.map(|i| {
if success_set.contains(&i) {
Box::pin(async move { Ok(i as i32) }) as DurableFuture<i32>
} else {
Box::pin(async move { Err(DurableError::execution("test error")) }) as DurableFuture<i32>
}
})
.collect();
let result = all_settled_handler(futures, &state, &op_id, &logger).await;
prop_assert!(result.is_ok(), "all_settled should always succeed");
let batch = result.unwrap();
prop_assert_eq!(batch.total_count(), total, "all_settled should return all items");
prop_assert_eq!(batch.success_count(), expected_successes, "success count should match");
prop_assert_eq!(batch.failure_count(), expected_failures, "failure count should match");
Ok(())
})?;
}
// `any` returns the single success value regardless of how many futures fail.
#[test]
fn prop_any_succeeds_if_one_succeeds(
failure_count in 0usize..5,
success_value in 0i32..1000,
) {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let client = create_mock_client_with_responses(10);
let state = create_test_state(client);
let op_id = create_test_op_id(&format!("any-prop-{}", failure_count));
let logger = create_test_logger();
let mut futures: Vec<DurableFuture<i32>> =
(0..failure_count)
.map(|_| Box::pin(async { Err(DurableError::execution("test error")) }) as DurableFuture<i32>)
.collect();
futures.push(Box::pin(async move { Ok(success_value) }));
let result = any_handler(futures, &state, &op_id, &logger).await;
prop_assert!(result.is_ok(), "any should succeed when at least one future succeeds");
prop_assert_eq!(result.unwrap(), success_value, "any should return the successful value");
Ok(())
})?;
}
// `any` fails, and its message reports the total count, when every future fails.
#[test]
fn prop_any_fails_when_all_fail(
failure_count in 1usize..10,
) {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let client = create_mock_client_with_responses(10);
let state = create_test_state(client);
let op_id = create_test_op_id(&format!("any-all-fail-prop-{}", failure_count));
let logger = create_test_logger();
let futures: Vec<DurableFuture<i32>> =
(0..failure_count)
.map(|i| Box::pin(async move { Err(DurableError::execution(format!("error {}", i))) }) as DurableFuture<i32>)
.collect();
let result = any_handler(futures, &state, &op_id, &logger).await;
prop_assert!(result.is_err(), "any should fail when all futures fail");
if let Err(DurableError::UserCode { message, .. }) = result {
prop_assert!(message.contains(&format!("All {} futures failed", failure_count)),
"Error message should indicate all futures failed");
}
Ok(())
})?;
}
}
}
}