use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration as StdDuration;
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::time::{timeout, Instant};
use crate::client::SharedDurableServiceClient;
use crate::error::DurableError;
use crate::operation::{OperationAction, OperationType, OperationUpdate};
/// Tuning knobs controlling how the `CheckpointBatcher` groups queued
/// checkpoint requests into a single service call.
#[derive(Debug, Clone)]
pub struct CheckpointBatcherConfig {
    // Flush once the batch's estimated payload reaches this many bytes.
    pub max_batch_size_bytes: usize,
    // Maximum time window for collecting requests into one batch.
    pub max_batch_time_ms: u64,
    // Flush once this many operations have been collected.
    pub max_batch_operations: usize,
}
impl Default for CheckpointBatcherConfig {
    fn default() -> Self {
        Self {
            // 750 KiB — presumably headroom under a ~1 MB service payload
            // limit; TODO confirm against the checkpoint API quota.
            max_batch_size_bytes: 750 * 1024,
            // Flush at least once per second.
            max_batch_time_ms: 1000,
            // No operation-count cap by default; size/time limits govern.
            max_batch_operations: usize::MAX,
        }
    }
}
/// One operation update queued for checkpointing, optionally carrying a
/// oneshot channel used to report the batch outcome back to the caller.
#[derive(Debug)]
pub struct CheckpointRequest {
    pub operation: OperationUpdate,
    // Some(..) for sync requests whose caller awaits the result;
    // None for fire-and-forget requests.
    pub completion: Option<oneshot::Sender<Result<(), DurableError>>>,
}
impl CheckpointRequest {
    /// Builds a request carrying a completion channel; the returned receiver
    /// resolves once the batcher has settled this operation's checkpoint.
    pub fn sync(operation: OperationUpdate) -> (Self, oneshot::Receiver<Result<(), DurableError>>) {
        let (done_tx, done_rx) = oneshot::channel();
        let request = Self {
            operation,
            completion: Some(done_tx),
        };
        (request, done_rx)
    }

    /// Builds a fire-and-forget request with no completion channel.
    pub fn async_request(operation: OperationUpdate) -> Self {
        Self {
            operation,
            completion: None,
        }
    }

    /// True when a caller is awaiting this request's completion.
    pub fn is_sync(&self) -> bool {
        self.completion.is_some()
    }

    /// Rough wire-size estimate used for batch byte budgeting: a fixed
    /// per-operation overhead (100 bytes) plus the lengths of the optional
    /// string fields that are present.
    pub fn estimated_size(&self) -> usize {
        let op = &self.operation;
        let error_len = op
            .error
            .as_ref()
            .map_or(0, |e| e.error_message.len() + e.error_type.len());
        100 + op.operation_id.len()
            + op.result.as_ref().map_or(0, |r| r.len())
            + error_len
            + op.parent_id.as_ref().map_or(0, |p| p.len())
            + op.name.as_ref().map_or(0, |n| n.len())
    }
}
/// Outcome of submitting one batch to the service.
/// NOTE(review): not constructed anywhere in this file — presumably consumed
/// by other modules; confirm before removing.
#[derive(Debug)]
pub struct BatchResult {
    pub success: bool,
    pub new_token: Option<String>,
    pub error: Option<DurableError>,
}
/// Background worker that drains `queue_rx`, groups requests into batches
/// (bounded by size, count, and time), and submits each batch through the
/// service client while rotating the shared checkpoint token.
pub struct CheckpointBatcher {
    config: CheckpointBatcherConfig,
    queue_rx: mpsc::Receiver<CheckpointRequest>,
    service_client: SharedDurableServiceClient,
    durable_execution_arn: String,
    // Shared with the rest of the runtime; replaced after each successful batch.
    checkpoint_token: Arc<RwLock<String>>,
    // Set after the first successful checkpoint. NOTE(review): written but
    // never read within this file — confirm it is consumed elsewhere.
    initial_token_consumed: AtomicBool,
}
impl CheckpointBatcher {
    /// Constructs a batcher over the given queue receiver and service client.
    /// The initial checkpoint token is read from the shared `checkpoint_token`
    /// cell when the first batch is processed.
    pub fn new(
        config: CheckpointBatcherConfig,
        queue_rx: mpsc::Receiver<CheckpointRequest>,
        service_client: SharedDurableServiceClient,
        durable_execution_arn: String,
        checkpoint_token: Arc<RwLock<String>>,
    ) -> Self {
        Self {
            config,
            queue_rx,
            service_client,
            durable_execution_arn,
            checkpoint_token,
            initial_token_consumed: AtomicBool::new(false),
        }
    }
pub async fn run(&mut self) {
loop {
let batch = self.collect_batch().await;
if batch.is_empty() {
break;
}
self.process_batch(batch).await;
}
}
async fn collect_batch(&mut self) -> Vec<CheckpointRequest> {
let mut batch = Vec::new();
let mut batch_size = 0usize;
let batch_deadline =
Instant::now() + StdDuration::from_millis(self.config.max_batch_time_ms);
match self.queue_rx.recv().await {
Some(request) => {
batch_size += request.estimated_size();
batch.push(request);
}
None => return batch, }
loop {
if batch.len() >= self.config.max_batch_operations {
break;
}
let now = Instant::now();
if now >= batch_deadline {
break;
}
let remaining = batch_deadline - now;
match timeout(remaining, self.queue_rx.recv()).await {
Ok(Some(request)) => {
let request_size = request.estimated_size();
if batch_size + request_size > self.config.max_batch_size_bytes
&& !batch.is_empty()
{
batch.push(request);
break;
}
batch_size += request_size;
batch.push(request);
}
Ok(None) => break, Err(_) => break, }
}
batch
}
pub fn sort_checkpoint_batch(batch: Vec<CheckpointRequest>) -> Vec<CheckpointRequest> {
if batch.len() <= 1 {
return batch;
}
let context_starts: HashSet<String> = batch
.iter()
.filter(|req| {
req.operation.operation_type == OperationType::Context
&& req.operation.action == OperationAction::Start
})
.map(|req| req.operation.operation_id.clone())
.collect();
let parent_map: HashMap<String, Option<String>> = batch
.iter()
.map(|req| {
(
req.operation.operation_id.clone(),
req.operation.parent_id.clone(),
)
})
.collect();
let is_ancestor = |operation_id: &str, ancestor_id: &str| -> bool {
let mut current = parent_map.get(operation_id).and_then(|p| p.as_ref());
while let Some(parent_id) = current {
if parent_id == ancestor_id {
return true;
}
current = parent_map.get(parent_id).and_then(|p| p.as_ref());
}
false
};
let mut indexed_batch: Vec<(usize, CheckpointRequest)> =
batch.into_iter().enumerate().collect();
indexed_batch.sort_by(|(idx_a, a), (idx_b, b)| {
let a_is_exec_completion = a.operation.operation_type == OperationType::Execution
&& matches!(
a.operation.action,
OperationAction::Succeed | OperationAction::Fail
);
let b_is_exec_completion = b.operation.operation_type == OperationType::Execution
&& matches!(
b.operation.action,
OperationAction::Succeed | OperationAction::Fail
);
if a_is_exec_completion && !b_is_exec_completion {
return std::cmp::Ordering::Greater;
}
if !a_is_exec_completion && b_is_exec_completion {
return std::cmp::Ordering::Less;
}
if b.operation.action == OperationAction::Start
&& context_starts.contains(&b.operation.operation_id)
{
if let Some(ref parent_id) = a.operation.parent_id {
if *parent_id == b.operation.operation_id
|| is_ancestor(&a.operation.operation_id, &b.operation.operation_id)
{
return std::cmp::Ordering::Greater;
}
}
}
if a.operation.action == OperationAction::Start
&& context_starts.contains(&a.operation.operation_id)
{
if let Some(ref parent_id) = b.operation.parent_id {
if *parent_id == a.operation.operation_id
|| is_ancestor(&b.operation.operation_id, &a.operation.operation_id)
{
return std::cmp::Ordering::Less;
}
}
}
if a.operation.operation_id == b.operation.operation_id {
let a_is_start = a.operation.action == OperationAction::Start;
let b_is_start = b.operation.action == OperationAction::Start;
if a_is_start && !b_is_start {
return std::cmp::Ordering::Less;
}
if !a_is_start && b_is_start {
return std::cmp::Ordering::Greater;
}
}
idx_a.cmp(idx_b)
});
indexed_batch.into_iter().map(|(_, req)| req).collect()
}
    /// Submits one batch to the service and settles its completions.
    ///
    /// On success: marks the initial token consumed, installs the new token
    /// returned by the service, and resolves every sync completion with Ok.
    /// On failure: resolves every sync completion with a `Checkpoint` error,
    /// adding extra context when the token was invalid. Fire-and-forget
    /// requests (no completion channel) are silently dropped in both cases.
    async fn process_batch(&self, batch: Vec<CheckpointRequest>) {
        if batch.is_empty() {
            return;
        }
        // Order operations so the service never sees a child before its
        // context start, nor an EXECUTION completion before other updates.
        let sorted_batch = Self::sort_checkpoint_batch(batch);
        let (operations, completions): (Vec<_>, Vec<_>) = sorted_batch
            .into_iter()
            .map(|req| (req.operation, req.completion))
            .unzip();
        // Snapshot the current token; it is only replaced on success.
        let token = self.checkpoint_token.read().await.clone();
        let result = self.checkpoint_with_retry(&token, operations).await;
        match result {
            Ok(response) => {
                self.initial_token_consumed.store(true, Ordering::Release);
                {
                    // Scope the write guard so it is released before we
                    // notify waiters.
                    let mut token_guard = self.checkpoint_token.write().await;
                    *token_guard = response.checkpoint_token;
                }
                for completion in completions.into_iter().flatten() {
                    // The receiver may have been dropped; ignore send errors.
                    let _ = completion.send(Ok(()));
                }
            }
            Err(error) => {
                let is_invalid_token = error.is_invalid_checkpoint_token();
                for completion in completions.into_iter().flatten() {
                    let error_msg = if is_invalid_token {
                        format!(
                            "Invalid checkpoint token - token may have been consumed. \
                             Lambda will retry with a fresh token. Original error: {}",
                            error
                        )
                    } else {
                        error.to_string()
                    };
                    let _ = completion.send(Err(DurableError::Checkpoint {
                        message: error_msg,
                        is_retriable: error.is_retriable(),
                        aws_error: None,
                    }));
                }
            }
        }
    }
    /// Calls the service `checkpoint` API, retrying only throttling errors
    /// with exponential backoff: 100 ms doubling up to 10 s, at most 5
    /// retries. A server-provided retry-after hint overrides the computed
    /// delay for that attempt. Any non-throttling error returns immediately.
    async fn checkpoint_with_retry(
        &self,
        token: &str,
        operations: Vec<OperationUpdate>,
    ) -> Result<crate::client::CheckpointResponse, DurableError> {
        const MAX_RETRIES: u32 = 5;
        const INITIAL_DELAY_MS: u64 = 100;
        const MAX_DELAY_MS: u64 = 10_000;
        const BACKOFF_MULTIPLIER: u64 = 2;
        let mut attempt = 0;
        let mut delay_ms = INITIAL_DELAY_MS;
        loop {
            // Operations are cloned per attempt because the client consumes
            // them and we may need to resend after a throttle.
            let result = self
                .service_client
                .checkpoint(&self.durable_execution_arn, token, operations.clone())
                .await;
            match result {
                Ok(response) => return Ok(response),
                Err(error) if error.is_throttling() => {
                    attempt += 1;
                    if attempt > MAX_RETRIES {
                        tracing::warn!(
                            attempt = attempt,
                            "Checkpoint throttling: max retries exceeded"
                        );
                        return Err(error);
                    }
                    // Prefer the service's retry-after hint when present.
                    let actual_delay = error.get_retry_after_ms().unwrap_or(delay_ms);
                    tracing::debug!(
                        attempt = attempt,
                        delay_ms = actual_delay,
                        "Checkpoint throttled, retrying with exponential backoff"
                    );
                    tokio::time::sleep(StdDuration::from_millis(actual_delay)).await;
                    // The backoff schedule advances even when a hint was used.
                    delay_ms = (delay_ms * BACKOFF_MULTIPLIER).min(MAX_DELAY_MS);
                }
                Err(error) => return Err(error),
            }
        }
    }
}
/// Cloneable handle for submitting checkpoint requests to the batcher queue.
#[derive(Clone)]
pub struct CheckpointSender {
    // NOTE(review): the raw sender is public and used directly by the tests;
    // callers are expected to go through the checkpoint_* methods.
    pub tx: mpsc::Sender<CheckpointRequest>,
}
impl CheckpointSender {
    /// Wraps the sending half of the batcher's request queue.
    pub fn new(tx: mpsc::Sender<CheckpointRequest>) -> Self {
        Self { tx }
    }

    /// Error returned when the batcher side of the queue has shut down.
    /// Extracted so sync and async enqueue failures stay consistent.
    fn queue_closed_error() -> DurableError {
        DurableError::Checkpoint {
            message: "Checkpoint queue closed".to_string(),
            is_retriable: false,
            aws_error: None,
        }
    }

    /// Enqueues `operation` and waits for the batcher to confirm it was
    /// checkpointed (or to report the failure).
    pub async fn checkpoint_sync(&self, operation: OperationUpdate) -> Result<(), DurableError> {
        let (request, rx) = CheckpointRequest::sync(operation);
        self.tx
            .send(request)
            .await
            .map_err(|_| Self::queue_closed_error())?;
        // The batcher dropped the completion channel without answering:
        // surface that as a non-retriable checkpoint failure.
        rx.await.map_err(|_| DurableError::Checkpoint {
            message: "Checkpoint completion channel closed".to_string(),
            is_retriable: false,
            aws_error: None,
        })?
    }

    /// Enqueues `operation` without waiting for the checkpoint result.
    pub async fn checkpoint_async(&self, operation: OperationUpdate) -> Result<(), DurableError> {
        let request = CheckpointRequest::async_request(operation);
        self.tx
            .send(request)
            .await
            .map_err(|_| Self::queue_closed_error())
    }

    /// Dispatches to `checkpoint_sync` or `checkpoint_async` per `is_sync`.
    pub async fn checkpoint(
        &self,
        operation: OperationUpdate,
        is_sync: bool,
    ) -> Result<(), DurableError> {
        if is_sync {
            self.checkpoint_sync(operation).await
        } else {
            self.checkpoint_async(operation).await
        }
    }
}
/// Creates a bounded request queue of capacity `buffer_size`, returning the
/// sending half wrapped in a `CheckpointSender` and the raw receiver for the
/// `CheckpointBatcher`.
pub fn create_checkpoint_queue(
    buffer_size: usize,
) -> (CheckpointSender, mpsc::Receiver<CheckpointRequest>) {
    let (tx, rx) = mpsc::channel(buffer_size);
    let sender = CheckpointSender::new(tx);
    (sender, rx)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient};
    use crate::error::ErrorObject;

    // Helper: START update for a STEP operation.
    fn create_test_update(id: &str) -> OperationUpdate {
        OperationUpdate::start(id, OperationType::Step)
    }

    // Helper: START update of an arbitrary operation type.
    fn create_start_update(id: &str, op_type: OperationType) -> OperationUpdate {
        OperationUpdate::start(id, op_type)
    }

    // Helper: SUCCEED update with a fixed result payload.
    fn create_succeed_update(id: &str, op_type: OperationType) -> OperationUpdate {
        OperationUpdate::succeed(id, op_type, Some("result".to_string()))
    }

    // Helper: FAIL update with a fixed error object.
    fn create_fail_update(id: &str, op_type: OperationType) -> OperationUpdate {
        OperationUpdate::fail(id, op_type, ErrorObject::new("Error", "message"))
    }

    // Helper: wrap an update in a fire-and-forget request.
    fn create_request(update: OperationUpdate) -> CheckpointRequest {
        CheckpointRequest::async_request(update)
    }

    // Sync constructor yields a request holding a completion channel.
    #[test]
    fn test_checkpoint_request_sync() {
        let update = create_test_update("op-1");
        let (request, _rx) = CheckpointRequest::sync(update);
        assert!(request.is_sync());
        assert_eq!(request.operation.operation_id, "op-1");
    }

    // Async constructor yields a request without a completion channel.
    #[test]
    fn test_checkpoint_request_async() {
        let update = create_test_update("op-1");
        let request = CheckpointRequest::async_request(update);
        assert!(!request.is_sync());
        assert_eq!(request.operation.operation_id, "op-1");
    }

    // Size estimate always includes the fixed 100-byte base overhead.
    #[test]
    fn test_checkpoint_request_estimated_size() {
        let update = create_test_update("op-1");
        let request = CheckpointRequest::async_request(update);
        let size = request.estimated_size();
        assert!(size > 0);
        assert!(size >= 100);
    }

    // Size estimate grows with the result payload length.
    #[test]
    fn test_checkpoint_request_estimated_size_with_result() {
        let mut update = create_test_update("op-1");
        update.result = Some("a".repeat(1000));
        let request = CheckpointRequest::async_request(update);
        let size = request.estimated_size();
        assert!(size >= 1100);
    }

    // Queue construction smoke test.
    #[test]
    fn test_create_checkpoint_queue() {
        let (sender, _rx) = create_checkpoint_queue(100);
        drop(sender);
    }

    // checkpoint_sync resolves once the receiver answers the completion.
    #[tokio::test]
    async fn test_checkpoint_sender_sync() {
        let (sender, mut rx) = create_checkpoint_queue(10);
        let handle = tokio::spawn(async move {
            if let Some(request) = rx.recv().await {
                assert!(request.is_sync());
                if let Some(completion) = request.completion {
                    let _ = completion.send(Ok(()));
                }
            }
        });
        let update = create_test_update("op-1");
        let result = sender.checkpoint_sync(update).await;
        assert!(result.is_ok());
        handle.await.unwrap();
    }

    // checkpoint_async enqueues without waiting for a completion.
    #[tokio::test]
    async fn test_checkpoint_sender_async() {
        let (sender, mut rx) = create_checkpoint_queue(10);
        let update = create_test_update("op-1");
        let result = sender.checkpoint_async(update).await;
        assert!(result.is_ok());
        let request = rx.recv().await.unwrap();
        assert!(!request.is_sync());
        assert_eq!(request.operation.operation_id, "op-1");
    }

    // Sending into a closed queue surfaces an error.
    #[tokio::test]
    async fn test_checkpoint_sender_queue_closed() {
        let (sender, rx) = create_checkpoint_queue(10);
        drop(rx);
        let update = create_test_update("op-1");
        let result = sender.checkpoint_async(update).await;
        assert!(result.is_err());
    }

    // Happy path: batcher checkpoints, rotates the token, resolves sync waiters.
    #[tokio::test]
    async fn test_checkpoint_batcher_processes_batch() {
        let client = Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("new-token"))),
        );
        let (sender, rx) = create_checkpoint_queue(10);
        let checkpoint_token = Arc::new(RwLock::new("initial-token".to_string()));
        let config = CheckpointBatcherConfig {
            max_batch_time_ms: 10,
            ..Default::default()
        };
        let mut batcher = CheckpointBatcher::new(
            config,
            rx,
            client,
            "arn:test".to_string(),
            checkpoint_token.clone(),
        );
        let update = create_test_update("op-1");
        let (request, completion_rx) = CheckpointRequest::sync(update);
        sender.tx.send(request).await.unwrap();
        // Dropping the sender closes the queue so run() terminates.
        drop(sender);
        batcher.run().await;
        let result = completion_rx.await.unwrap();
        assert!(result.is_ok());
        let token = checkpoint_token.read().await;
        assert_eq!(*token, "new-token");
    }

    // Failure path: sync waiter receives an error and the token is untouched.
    #[tokio::test]
    async fn test_checkpoint_batcher_handles_error() {
        let client = Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error"))),
        );
        let (sender, rx) = create_checkpoint_queue(10);
        let checkpoint_token = Arc::new(RwLock::new("initial-token".to_string()));
        let config = CheckpointBatcherConfig {
            max_batch_time_ms: 10,
            ..Default::default()
        };
        let mut batcher = CheckpointBatcher::new(
            config,
            rx,
            client,
            "arn:test".to_string(),
            checkpoint_token.clone(),
        );
        let update = create_test_update("op-1");
        let (request, completion_rx) = CheckpointRequest::sync(update);
        sender.tx.send(request).await.unwrap();
        drop(sender);
        batcher.run().await;
        let result = completion_rx.await.unwrap();
        assert!(result.is_err());
        let token = checkpoint_token.read().await;
        assert_eq!(*token, "initial-token");
    }

    // Multiple queued requests are shipped and the token rotates once.
    #[tokio::test]
    async fn test_checkpoint_batcher_batches_multiple_requests() {
        let client = Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("new-token"))),
        );
        let (sender, rx) = create_checkpoint_queue(10);
        let checkpoint_token = Arc::new(RwLock::new("initial-token".to_string()));
        let config = CheckpointBatcherConfig {
            max_batch_time_ms: 50,
            max_batch_operations: 3,
            ..Default::default()
        };
        let mut batcher = CheckpointBatcher::new(
            config,
            rx,
            client,
            "arn:test".to_string(),
            checkpoint_token.clone(),
        );
        for i in 0..3 {
            let update = create_test_update(&format!("op-{}", i));
            sender
                .tx
                .send(CheckpointRequest::async_request(update))
                .await
                .unwrap();
        }
        drop(sender);
        batcher.run().await;
        let token = checkpoint_token.read().await;
        assert_eq!(*token, "new-token");
    }

    // Sorting: EXECUTION Succeed must be the final element.
    #[test]
    fn test_execution_completion_is_last() {
        let batch = vec![
            create_request(create_succeed_update("exec-1", OperationType::Execution)),
            create_request(create_start_update("step-1", OperationType::Step)),
            create_request(create_succeed_update("step-2", OperationType::Step)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 3);
        assert_eq!(sorted[2].operation.operation_id, "exec-1");
        assert_eq!(sorted[2].operation.action, OperationAction::Succeed);
        assert_eq!(sorted[2].operation.operation_type, OperationType::Execution);
    }

    // Sorting: EXECUTION Fail must also come last.
    #[test]
    fn test_execution_fail_is_last() {
        let batch = vec![
            create_request(create_fail_update("exec-1", OperationType::Execution)),
            create_request(create_start_update("step-1", OperationType::Step)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[1].operation.operation_id, "exec-1");
        assert_eq!(sorted[1].operation.action, OperationAction::Fail);
    }

    // Sorting: a child comes after its parent CONTEXT Start.
    #[test]
    fn test_child_after_parent_context_start() {
        let mut child_update = create_start_update("child-1", OperationType::Step);
        child_update.parent_id = Some("parent-ctx".to_string());
        let batch = vec![
            create_request(child_update),
            create_request(create_start_update("parent-ctx", OperationType::Context)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].operation.operation_id, "parent-ctx");
        assert_eq!(sorted[0].operation.action, OperationAction::Start);
        assert_eq!(sorted[1].operation.operation_id, "child-1");
    }

    // Sorting: Start precedes Succeed for the same operation id.
    #[test]
    fn test_start_and_completion_same_operation() {
        let batch = vec![
            create_request(create_succeed_update("step-1", OperationType::Step)),
            create_request(create_start_update("step-1", OperationType::Step)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].operation.operation_id, "step-1");
        assert_eq!(sorted[0].operation.action, OperationAction::Start);
        assert_eq!(sorted[1].operation.operation_id, "step-1");
        assert_eq!(sorted[1].operation.action, OperationAction::Succeed);
    }

    // Same rule applies to CONTEXT operations.
    #[test]
    fn test_context_start_and_completion_same_batch() {
        let batch = vec![
            create_request(create_succeed_update("ctx-1", OperationType::Context)),
            create_request(create_start_update("ctx-1", OperationType::Context)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].operation.operation_id, "ctx-1");
        assert_eq!(sorted[0].operation.action, OperationAction::Start);
        assert_eq!(sorted[1].operation.operation_id, "ctx-1");
        assert_eq!(sorted[1].operation.action, OperationAction::Succeed);
    }

    // Start precedes Fail for the same operation id.
    #[test]
    fn test_step_start_and_fail_same_batch() {
        let batch = vec![
            create_request(create_fail_update("step-1", OperationType::Step)),
            create_request(create_start_update("step-1", OperationType::Step)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].operation.operation_id, "step-1");
        assert_eq!(sorted[0].operation.action, OperationAction::Start);
        assert_eq!(sorted[1].operation.operation_id, "step-1");
        assert_eq!(sorted[1].operation.action, OperationAction::Fail);
    }

    // All the ordering rules combined in one batch.
    #[test]
    fn test_complex_ordering_scenario() {
        let mut child1 = create_start_update("child-1", OperationType::Step);
        child1.parent_id = Some("parent-ctx".to_string());
        let mut child2 = create_succeed_update("child-2", OperationType::Step);
        child2.parent_id = Some("parent-ctx".to_string());
        let batch = vec![
            create_request(create_succeed_update("exec-1", OperationType::Execution)),
            create_request(child1),
            create_request(create_succeed_update("parent-ctx", OperationType::Context)),
            create_request(create_start_update("parent-ctx", OperationType::Context)),
            create_request(child2),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 5);
        let parent_start_pos = sorted
            .iter()
            .position(|r| {
                r.operation.operation_id == "parent-ctx"
                    && r.operation.action == OperationAction::Start
            })
            .unwrap();
        let parent_succeed_pos = sorted
            .iter()
            .position(|r| {
                r.operation.operation_id == "parent-ctx"
                    && r.operation.action == OperationAction::Succeed
            })
            .unwrap();
        let child1_pos = sorted
            .iter()
            .position(|r| r.operation.operation_id == "child-1")
            .unwrap();
        let child2_pos = sorted
            .iter()
            .position(|r| r.operation.operation_id == "child-2")
            .unwrap();
        let exec_pos = sorted
            .iter()
            .position(|r| r.operation.operation_id == "exec-1")
            .unwrap();
        assert!(parent_start_pos < parent_succeed_pos);
        assert!(parent_start_pos < child1_pos);
        assert!(parent_start_pos < child2_pos);
        assert_eq!(exec_pos, 4);
    }

    // Sorting an empty batch is a no-op.
    #[test]
    fn test_empty_batch() {
        let batch: Vec<CheckpointRequest> = vec![];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert!(sorted.is_empty());
    }

    // Sorting a single-element batch is a no-op.
    #[test]
    fn test_single_item_batch() {
        let batch = vec![create_request(create_start_update(
            "step-1",
            OperationType::Step,
        ))];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 1);
        assert_eq!(sorted[0].operation.operation_id, "step-1");
    }

    // With no applicable rules, submission order is preserved.
    #[test]
    fn test_preserves_original_order() {
        let batch = vec![
            create_request(create_start_update("step-1", OperationType::Step)),
            create_request(create_start_update("step-2", OperationType::Step)),
            create_request(create_start_update("step-3", OperationType::Step)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 3);
        assert_eq!(sorted[0].operation.operation_id, "step-1");
        assert_eq!(sorted[1].operation.operation_id, "step-2");
        assert_eq!(sorted[2].operation.operation_id, "step-3");
    }

    // Parent CONTEXT Start floats ahead of all of its children.
    #[test]
    fn test_multiple_children_same_parent() {
        let mut child1 = create_start_update("child-1", OperationType::Step);
        child1.parent_id = Some("parent-ctx".to_string());
        let mut child2 = create_start_update("child-2", OperationType::Step);
        child2.parent_id = Some("parent-ctx".to_string());
        let mut child3 = create_start_update("child-3", OperationType::Step);
        child3.parent_id = Some("parent-ctx".to_string());
        let batch = vec![
            create_request(child1),
            create_request(child2),
            create_request(create_start_update("parent-ctx", OperationType::Context)),
            create_request(child3),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted[0].operation.operation_id, "parent-ctx");
        for item in &sorted[1..4] {
            assert!(item.operation.parent_id.as_deref() == Some("parent-ctx"));
        }
    }

    // Ancestor chains are respected transitively across nested contexts.
    #[test]
    fn test_nested_contexts() {
        let mut parent_ctx = create_start_update("parent-ctx", OperationType::Context);
        parent_ctx.parent_id = Some("grandparent-ctx".to_string());
        let mut child = create_start_update("child-1", OperationType::Step);
        child.parent_id = Some("parent-ctx".to_string());
        let batch = vec![
            create_request(child),
            create_request(parent_ctx),
            create_request(create_start_update(
                "grandparent-ctx",
                OperationType::Context,
            )),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        let grandparent_pos = sorted
            .iter()
            .position(|r| r.operation.operation_id == "grandparent-ctx")
            .unwrap();
        let parent_pos = sorted
            .iter()
            .position(|r| r.operation.operation_id == "parent-ctx")
            .unwrap();
        let child_pos = sorted
            .iter()
            .position(|r| r.operation.operation_id == "child-1")
            .unwrap();
        assert!(grandparent_pos < parent_pos);
        assert!(parent_pos < child_pos);
    }

    // Only EXECUTION *completions* are pushed to the end; a Start is not.
    #[test]
    fn test_execution_start_not_affected() {
        let batch = vec![
            create_request(create_start_update("step-1", OperationType::Step)),
            create_request(create_start_update("exec-1", OperationType::Execution)),
            create_request(create_start_update("step-2", OperationType::Step)),
        ];
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        assert_eq!(sorted.len(), 3);
        assert_eq!(sorted[0].operation.operation_id, "step-1");
        assert_eq!(sorted[1].operation.operation_id, "exec-1");
        assert_eq!(sorted[2].operation.operation_id, "step-2");
    }
}
#[cfg(test)]
mod property_tests {
use super::*;
use crate::client::{CheckpointResponse, DurableServiceClient, GetOperationsResponse};
use async_trait::async_trait;
use proptest::prelude::*;
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
    /// Mock client that counts checkpoint calls and returns a fresh token
    /// (`token-N`) per call; `get_operations` returns an empty page.
    struct CountingMockClient {
        checkpoint_count: Arc<AtomicUsize>,
    }

    impl CountingMockClient {
        // Returns the client plus a shared handle to its call counter.
        fn new() -> (Self, Arc<AtomicUsize>) {
            let count = Arc::new(AtomicUsize::new(0));
            (
                Self {
                    checkpoint_count: count.clone(),
                },
                count,
            )
        }
    }

    #[async_trait]
    impl DurableServiceClient for CountingMockClient {
        async fn checkpoint(
            &self,
            _durable_execution_arn: &str,
            _checkpoint_token: &str,
            _operations: Vec<OperationUpdate>,
        ) -> Result<CheckpointResponse, DurableError> {
            self.checkpoint_count.fetch_add(1, AtomicOrdering::SeqCst);
            // Token reflects the (post-increment) call count.
            Ok(CheckpointResponse::new(format!(
                "token-{}",
                self.checkpoint_count.load(AtomicOrdering::SeqCst)
            )))
        }

        async fn get_operations(
            &self,
            _durable_execution_arn: &str,
            _next_marker: &str,
        ) -> Result<GetOperationsResponse, DurableError> {
            Ok(GetOperationsResponse {
                operations: vec![],
                next_marker: None,
            })
        }
    }
fn create_test_update_with_size(id: &str, result_size: usize) -> OperationUpdate {
let mut update = OperationUpdate::start(id, OperationType::Step);
if result_size > 0 {
update.result = Some("x".repeat(result_size));
}
update
}
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // With an unbounded byte budget, the number of API calls is bounded
        // by ceil(num_requests / max_ops_per_batch), and at least one call
        // is made.
        #[test]
        fn prop_checkpoint_batching_respects_operation_count_limit(
            num_requests in 1usize..20,
            max_ops_per_batch in 1usize..10,
        ) {
            // proptest tests are sync; drive the async batcher on a fresh runtime.
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let (client, call_count) = CountingMockClient::new();
                let client = Arc::new(client);
                let (sender, rx) = create_checkpoint_queue(100);
                let checkpoint_token = Arc::new(RwLock::new("initial-token".to_string()));
                let config = CheckpointBatcherConfig {
                    max_batch_time_ms: 10,
                    max_batch_operations: max_ops_per_batch,
                    max_batch_size_bytes: usize::MAX,
                };
                let mut batcher = CheckpointBatcher::new(
                    config,
                    rx,
                    client,
                    "arn:test".to_string(),
                    checkpoint_token.clone(),
                );
                for i in 0..num_requests {
                    let update = create_test_update_with_size(&format!("op-{}", i), 0);
                    sender.tx.send(CheckpointRequest::async_request(update)).await.unwrap();
                }
                // Closing the queue lets run() drain and exit.
                drop(sender);
                batcher.run().await;
                let expected_max_calls = num_requests.div_ceil(max_ops_per_batch);
                let actual_calls = call_count.load(AtomicOrdering::SeqCst);
                if actual_calls > expected_max_calls {
                    return Err(TestCaseError::fail(format!(
                        "Expected at most {} API calls for {} requests with batch size {}, got {}",
                        expected_max_calls, num_requests, max_ops_per_batch, actual_calls
                    )));
                }
                if actual_calls < 1 {
                    return Err(TestCaseError::fail(format!(
                        "Expected at least 1 API call for {} requests, got {}",
                        num_requests, actual_calls
                    )));
                }
                Ok(())
            });
            result?;
        }

        // With an unbounded op-count limit, API calls scale with total payload
        // size (a 2x slack is allowed since batches may exceed the byte budget
        // by one request).
        #[test]
        fn prop_checkpoint_batching_respects_size_limit(
            num_requests in 1usize..10,
            result_size in 100usize..500,
            max_batch_size in 500usize..2000,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let (client, call_count) = CountingMockClient::new();
                let client = Arc::new(client);
                let (sender, rx) = create_checkpoint_queue(100);
                let checkpoint_token = Arc::new(RwLock::new("initial-token".to_string()));
                let config = CheckpointBatcherConfig {
                    max_batch_time_ms: 10,
                    max_batch_operations: usize::MAX,
                    max_batch_size_bytes: max_batch_size,
                };
                let mut batcher = CheckpointBatcher::new(
                    config,
                    rx,
                    client,
                    "arn:test".to_string(),
                    checkpoint_token.clone(),
                );
                for i in 0..num_requests {
                    let update = create_test_update_with_size(&format!("op-{}", i), result_size);
                    sender.tx.send(CheckpointRequest::async_request(update)).await.unwrap();
                }
                drop(sender);
                batcher.run().await;
                // Mirrors CheckpointRequest::estimated_size's 100-byte base.
                let estimated_request_size = 100 + result_size;
                let total_size = num_requests * estimated_request_size;
                let expected_max_calls = total_size.div_ceil(max_batch_size);
                let actual_calls = call_count.load(AtomicOrdering::SeqCst);
                if actual_calls > expected_max_calls.max(1) * 2 {
                    return Err(TestCaseError::fail(format!(
                        "Expected at most ~{} API calls for {} requests of size {}, got {}",
                        expected_max_calls, num_requests, estimated_request_size, actual_calls
                    )));
                }
                if actual_calls < 1 {
                    return Err(TestCaseError::fail(format!(
                        "Expected at least 1 API call, got {}",
                        actual_calls
                    )));
                }
                Ok(())
            });
            result?;
        }

        // Every sync request's completion must eventually resolve Ok.
        #[test]
        fn prop_all_requests_are_processed(
            num_requests in 1usize..20,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let (client, _call_count) = CountingMockClient::new();
                let client = Arc::new(client);
                let (sender, rx) = create_checkpoint_queue(100);
                let checkpoint_token = Arc::new(RwLock::new("initial-token".to_string()));
                let config = CheckpointBatcherConfig {
                    max_batch_time_ms: 10,
                    max_batch_operations: 5,
                    ..Default::default()
                };
                let mut batcher = CheckpointBatcher::new(
                    config,
                    rx,
                    client,
                    "arn:test".to_string(),
                    checkpoint_token.clone(),
                );
                let mut receivers = Vec::new();
                for i in 0..num_requests {
                    let update = create_test_update_with_size(&format!("op-{}", i), 0);
                    let (request, rx) = CheckpointRequest::sync(update);
                    sender.tx.send(request).await.unwrap();
                    receivers.push(rx);
                }
                drop(sender);
                batcher.run().await;
                let mut success_count = 0;
                for rx in receivers {
                    if let Ok(result) = rx.await {
                        if result.is_ok() {
                            success_count += 1;
                        }
                    }
                }
                if success_count != num_requests {
                    return Err(TestCaseError::fail(format!(
                        "Expected all {} requests to succeed, got {}",
                        num_requests, success_count
                    )));
                }
                Ok(())
            });
            result?;
        }
    }
fn operation_id_strategy() -> impl Strategy<Value = String> {
"[a-z]{1,8}-[0-9]{1,4}".prop_map(|s| s)
}
    /// Uniform choice over every operation type.
    fn operation_type_strategy() -> impl Strategy<Value = OperationType> {
        prop_oneof![
            Just(OperationType::Execution),
            Just(OperationType::Step),
            Just(OperationType::Wait),
            Just(OperationType::Callback),
            Just(OperationType::Invoke),
            Just(OperationType::Context),
        ]
    }

    /// Uniform choice over Start/Succeed/Fail.
    /// NOTE(review): does not generate `OperationAction::Retry` even though
    /// the enum appears to have it — confirm whether that is intentional.
    fn operation_action_strategy() -> impl Strategy<Value = OperationAction> {
        prop_oneof![
            Just(OperationAction::Start),
            Just(OperationAction::Succeed),
            Just(OperationAction::Fail),
        ]
    }

    /// Builds a fire-and-forget request with the given id/type/action/parent.
    fn create_checkpoint_request(
        id: &str,
        op_type: OperationType,
        action: OperationAction,
        parent_id: Option<String>,
    ) -> CheckpointRequest {
        let mut update = match action {
            OperationAction::Start => OperationUpdate::start(id, op_type),
            OperationAction::Succeed => {
                OperationUpdate::succeed(id, op_type, Some("result".to_string()))
            }
            OperationAction::Fail => OperationUpdate::fail(
                id,
                op_type,
                crate::error::ErrorObject::new("Error", "message"),
            ),
            // Any other action (e.g. Retry) falls back to a Start update.
            _ => OperationUpdate::start(id, op_type),
        };
        update.parent_id = parent_id;
        CheckpointRequest::async_request(update)
    }

    /// Batch where several non-EXECUTION children precede their parent
    /// CONTEXT Start in submission order; sorting must invert that.
    fn parent_child_batch_strategy() -> impl Strategy<Value = Vec<CheckpointRequest>> {
        (
            operation_id_strategy(),
            prop::collection::vec(operation_id_strategy(), 1..5),
            prop::collection::vec(operation_type_strategy(), 1..5),
        )
            .prop_map(|(parent_id, child_ids, child_types)| {
                let mut batch = Vec::new();
                for (child_id, child_type) in child_ids.iter().zip(child_types.iter()) {
                    // EXECUTION children are excluded: they have their own rule.
                    if *child_type != OperationType::Execution {
                        batch.push(create_checkpoint_request(
                            child_id,
                            *child_type,
                            OperationAction::Start,
                            Some(parent_id.clone()),
                        ));
                    }
                }
                // Parent CONTEXT Start is submitted last on purpose.
                batch.push(create_checkpoint_request(
                    &parent_id,
                    OperationType::Context,
                    OperationAction::Start,
                    None,
                ));
                batch
            })
    }

    /// Batch where an EXECUTION completion is submitted first, followed by
    /// unrelated Starts; sorting must move the completion to the end.
    fn execution_completion_batch_strategy() -> impl Strategy<Value = Vec<CheckpointRequest>> {
        (
            operation_id_strategy(),
            prop::collection::vec((operation_id_strategy(), operation_type_strategy()), 1..5),
            prop::bool::ANY,
        )
            .prop_map(|(exec_id, other_ops, succeed)| {
                let mut batch = Vec::new();
                let exec_action = if succeed {
                    OperationAction::Succeed
                } else {
                    OperationAction::Fail
                };
                batch.push(create_checkpoint_request(
                    &exec_id,
                    OperationType::Execution,
                    exec_action,
                    None,
                ));
                for (op_id, op_type) in other_ops {
                    if op_type != OperationType::Execution {
                        batch.push(create_checkpoint_request(
                            &op_id,
                            op_type,
                            OperationAction::Start,
                            None,
                        ));
                    }
                }
                batch
            })
    }

    /// Batch containing a completion followed by the Start for the same
    /// operation id; sorting must place the Start first.
    fn same_operation_batch_strategy() -> impl Strategy<Value = Vec<CheckpointRequest>> {
        (
            operation_id_strategy(),
            prop_oneof![Just(OperationType::Step), Just(OperationType::Context),],
            prop::bool::ANY,
        )
            .prop_map(|(op_id, op_type, succeed)| {
                let mut batch = Vec::new();
                let completion_action = if succeed {
                    OperationAction::Succeed
                } else {
                    OperationAction::Fail
                };
                batch.push(create_checkpoint_request(
                    &op_id,
                    op_type,
                    completion_action,
                    None,
                ));
                batch.push(create_checkpoint_request(
                    &op_id,
                    op_type,
                    OperationAction::Start,
                    None,
                ));
                batch
            })
    }
// Property tests for the ordering invariants of
// CheckpointBatcher::sort_checkpoint_batch. Plain `//` comments are used
// (rather than doc comments) so the proptest! macro expansion is unaffected.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]

    // A parent CONTEXT START must be ordered before every request that
    // names that context as its parent_id.
    #[test]
    fn prop_parent_context_start_before_children(batch in parent_child_batch_strategy()) {
        if batch.is_empty() {
            return Ok(());
        }
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        // Locate the root parent: a CONTEXT START with no parent of its own.
        let parent_pos = sorted.iter().position(|r| {
            r.operation.operation_type == OperationType::Context
                && r.operation.action == OperationAction::Start
                && r.operation.parent_id.is_none()
        });
        if let Some(parent_idx) = parent_pos {
            let parent_id = &sorted[parent_idx].operation.operation_id;
            // Every child referencing this parent must appear strictly after it.
            for (idx, req) in sorted.iter().enumerate() {
                if let Some(ref req_parent_id) = req.operation.parent_id {
                    if req_parent_id == parent_id {
                        prop_assert!(
                            idx > parent_idx,
                            "Child operation {} at index {} should come after parent {} at index {}",
                            req.operation.operation_id,
                            idx,
                            parent_id,
                            parent_idx
                        );
                    }
                }
            }
        }
    }

    // An EXECUTION completion (Succeed/Fail) must be the final request in
    // the sorted batch — nothing may be checkpointed after the execution ends.
    #[test]
    fn prop_execution_completion_is_last(batch in execution_completion_batch_strategy()) {
        if batch.is_empty() {
            return Ok(());
        }
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        let exec_completion_pos = sorted.iter().position(|r| {
            r.operation.operation_type == OperationType::Execution
                && matches!(r.operation.action, OperationAction::Succeed | OperationAction::Fail)
        });
        if let Some(exec_idx) = exec_completion_pos {
            prop_assert_eq!(
                exec_idx,
                sorted.len() - 1,
                "EXECUTION completion should be at index {} (last), but found at index {}",
                sorted.len() - 1,
                exec_idx
            );
        }
    }

    // An operation id may appear at most twice in a sorted batch, and only
    // as the START + completion (Succeed/Fail/Retry) pair for that operation.
    #[test]
    fn prop_operation_id_uniqueness_with_start_completion_exception(
        batch in same_operation_batch_strategy()
    ) {
        if batch.is_empty() {
            return Ok(());
        }
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        // Group the observed actions by operation id.
        let mut seen: std::collections::HashMap<String, Vec<OperationAction>> = std::collections::HashMap::new();
        for req in &sorted {
            seen.entry(req.operation.operation_id.clone())
                .or_default()
                .push(req.operation.action);
        }
        for (op_id, actions) in &seen {
            if actions.len() > 1 {
                let has_start = actions.contains(&OperationAction::Start);
                let has_completion = actions.iter().any(|a| {
                    matches!(a, OperationAction::Succeed | OperationAction::Fail | OperationAction::Retry)
                });
                // Exactly two entries, one START and one completion, is the
                // only allowed form of duplication.
                prop_assert!(
                    has_start && has_completion && actions.len() == 2,
                    "Operation {} appears {} times with actions {:?}. \
                     Multiple occurrences only allowed for START + completion pair.",
                    op_id,
                    actions.len(),
                    actions
                );
            }
        }
    }

    // When the same operation has both a START and a completion in one
    // batch, the START must be ordered first.
    #[test]
    fn prop_start_before_completion_for_same_operation(
        batch in same_operation_batch_strategy()
    ) {
        if batch.is_empty() {
            return Ok(());
        }
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        // Record the index of each START and each completion per operation id.
        let mut start_positions: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
        let mut completion_positions: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
        for (idx, req) in sorted.iter().enumerate() {
            let op_id = &req.operation.operation_id;
            match req.operation.action {
                OperationAction::Start => {
                    start_positions.insert(op_id.clone(), idx);
                }
                OperationAction::Succeed | OperationAction::Fail | OperationAction::Retry => {
                    completion_positions.insert(op_id.clone(), idx);
                }
                _ => {}
            }
        }
        for (op_id, start_idx) in &start_positions {
            if let Some(completion_idx) = completion_positions.get(op_id) {
                prop_assert!(
                    start_idx < completion_idx,
                    "For operation {}, START at index {} should come before completion at index {}",
                    op_id,
                    start_idx,
                    completion_idx
                );
            }
        }
    }
}
/// Strategy producing 1–9 arbitrary non-EXECUTION checkpoint requests
/// with random ids, types, and actions.
///
/// EXECUTION operations are excluded because they carry special ordering
/// rules that the generic permutation/idempotence properties do not model.
/// Note the filter may yield an empty batch; callers guard for that.
fn random_batch_strategy() -> impl Strategy<Value = Vec<CheckpointRequest>> {
    let entry = (
        operation_id_strategy(),
        operation_type_strategy(),
        operation_action_strategy(),
    );
    prop::collection::vec(entry, 1..10).prop_map(|entries| {
        let mut batch = Vec::with_capacity(entries.len());
        for (op_id, op_type, action) in entries {
            if op_type != OperationType::Execution {
                batch.push(create_checkpoint_request(&op_id, op_type, action, None));
            }
        }
        batch
    })
}
// Property tests asserting that sort_checkpoint_batch is a pure permutation
// (no loss, duplication, or mutation of requests) and is idempotent.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]

    // Sorting must preserve every request. The HashSet comparison alone is
    // not sufficient: it collapses duplicate (id, action) pairs, so a sorter
    // that dropped or duplicated one copy of a repeated operation would pass
    // undetected. The raw length assertion below closes that gap.
    #[test]
    fn prop_batch_sorting_preserves_all_operations(batch in random_batch_strategy()) {
        if batch.is_empty() {
            return Ok(());
        }
        // Capture the count before `batch` is consumed by the sort.
        let original_len = batch.len();
        let original_ops: HashSet<(String, String)> = batch.iter()
            .map(|r| (r.operation.operation_id.clone(), format!("{:?}", r.operation.action)))
            .collect();
        let sorted = CheckpointBatcher::sort_checkpoint_batch(batch);
        // Multiplicity check: same number of requests in and out.
        prop_assert_eq!(
            original_len,
            sorted.len(),
            "Sorting changed the number of requests: original {}, sorted {}",
            original_len,
            sorted.len()
        );
        // Content check: same set of (operation_id, action) pairs.
        let sorted_ops: HashSet<(String, String)> = sorted.iter()
            .map(|r| (r.operation.operation_id.clone(), format!("{:?}", r.operation.action)))
            .collect();
        prop_assert_eq!(
            original_ops,
            sorted_ops,
            "Sorting changed the set of operations"
        );
    }

    // Sorting an already-sorted batch must leave the order unchanged.
    #[test]
    fn prop_sorting_is_idempotent(batch in random_batch_strategy()) {
        if batch.is_empty() {
            return Ok(());
        }
        let sorted_once = CheckpointBatcher::sort_checkpoint_batch(batch);
        // Re-wrap the operations so the second sort consumes an independent
        // Vec<CheckpointRequest> (CheckpointRequest holds a oneshot sender
        // and is not Clone; async_request carries no completion channel).
        let sorted_once_clone: Vec<CheckpointRequest> = sorted_once.iter()
            .map(|r| CheckpointRequest::async_request(r.operation.clone()))
            .collect();
        let sorted_twice = CheckpointBatcher::sort_checkpoint_batch(sorted_once_clone);
        prop_assert_eq!(
            sorted_once.len(),
            sorted_twice.len(),
            "Double sorting changed the number of operations"
        );
        // Positional comparison: every (id, action) must be at the same index.
        for (idx, (first, second)) in sorted_once.iter().zip(sorted_twice.iter()).enumerate() {
            prop_assert_eq!(
                &first.operation.operation_id,
                &second.operation.operation_id,
                "Operation ID mismatch at index {}: {} vs {}",
                idx,
                &first.operation.operation_id,
                &second.operation.operation_id
            );
            prop_assert_eq!(
                first.operation.action,
                second.operation.action,
                "Action mismatch at index {} for operation {}: {:?} vs {:?}",
                idx,
                &first.operation.operation_id,
                first.operation.action,
                second.operation.action
            );
        }
    }
}
}