use std::sync::Arc;
use aws_sdk_lambda::types::{
CallbackDetails, ChainedInvokeDetails, Operation, OperationStatus, OperationType, StepDetails,
};
use durable_lambda_core::context::DurableContext;
use durable_lambda_core::operation_id::OperationIdGenerator;
use crate::mock_backend::{BatchCallCounter, CheckpointRecorder, MockBackend, OperationRecorder};
/// Builder for a [`DurableContext`] wired to an in-memory [`MockBackend`] and pre-seeded with
/// completed operations, so handler code can be exercised in replay mode without AWS access.
///
/// Each `with_*` call appends one operation to the seeded history. The `_name` arguments are
/// not stored. Each operation is identified only by the next ID from the internal
/// [`OperationIdGenerator`], so seed operations in the same order the code under test
/// performs them. Finish with `build` or `build_with_batch_counter`.
pub struct MockDurableContext {
    /// Produces the ID for each seeded operation, in seeding order.
    id_gen: OperationIdGenerator,
    /// Seeded operation history handed to `DurableContext::new` at build time.
    operations: Vec<Operation>,
}
impl MockDurableContext {
pub fn new() -> Self {
Self {
id_gen: OperationIdGenerator::new(None),
operations: Vec::new(),
}
}
pub fn with_step_result(mut self, _name: &str, result_json: &str) -> Self {
let op_id = self.id_gen.next_id();
let op = Operation::builder()
.id(&op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Succeeded)
.start_timestamp(aws_smithy_types::DateTime::from_secs(0))
.step_details(StepDetails::builder().result(result_json).build())
.build()
.unwrap_or_else(|e| panic!("failed to build mock Operation: {e}"));
self.operations.push(op);
self
}
pub fn with_step_error(mut self, _name: &str, error_type: &str, error_json: &str) -> Self {
let op_id = self.id_gen.next_id();
let error_obj = aws_sdk_lambda::types::ErrorObject::builder()
.error_type(error_type)
.error_data(error_json)
.build();
let op = Operation::builder()
.id(&op_id)
.r#type(OperationType::Step)
.status(OperationStatus::Failed)
.start_timestamp(aws_smithy_types::DateTime::from_secs(0))
.step_details(StepDetails::builder().error(error_obj).build())
.build()
.unwrap_or_else(|e| panic!("failed to build mock Operation: {e}"));
self.operations.push(op);
self
}
pub fn with_wait(mut self, _name: &str) -> Self {
let op_id = self.id_gen.next_id();
let op = Operation::builder()
.id(&op_id)
.r#type(OperationType::Wait)
.status(OperationStatus::Succeeded)
.start_timestamp(aws_smithy_types::DateTime::from_secs(0))
.build()
.unwrap_or_else(|e| panic!("failed to build mock Wait Operation: {e}"));
self.operations.push(op);
self
}
pub fn with_callback(mut self, _name: &str, callback_id: &str, result_json: &str) -> Self {
let op_id = self.id_gen.next_id();
let cb_details = CallbackDetails::builder()
.callback_id(callback_id)
.result(result_json)
.build();
let op = Operation::builder()
.id(&op_id)
.r#type(OperationType::Callback)
.status(OperationStatus::Succeeded)
.start_timestamp(aws_smithy_types::DateTime::from_secs(0))
.callback_details(cb_details)
.build()
.unwrap_or_else(|e| panic!("failed to build mock Callback Operation: {e}"));
self.operations.push(op);
self
}
pub fn with_invoke(mut self, _name: &str, result_json: &str) -> Self {
let op_id = self.id_gen.next_id();
let details = ChainedInvokeDetails::builder().result(result_json).build();
let op = Operation::builder()
.id(&op_id)
.r#type(OperationType::ChainedInvoke)
.status(OperationStatus::Succeeded)
.start_timestamp(aws_smithy_types::DateTime::from_secs(0))
.chained_invoke_details(details)
.build()
.unwrap_or_else(|e| panic!("failed to build mock ChainedInvoke Operation: {e}"));
self.operations.push(op);
self
}
pub async fn build(self) -> (DurableContext, CheckpointRecorder, OperationRecorder) {
let (backend, calls, operations) = MockBackend::new("mock-token");
let ctx = DurableContext::new(
Arc::new(backend),
"arn:aws:lambda:us-east-1:000000000000:durable-execution/mock".to_string(),
"mock-checkpoint-token".to_string(),
self.operations,
None,
)
.await
.expect("MockDurableContext::build should not fail");
(ctx, calls, operations)
}
pub async fn build_with_batch_counter(
self,
) -> (
DurableContext,
CheckpointRecorder,
OperationRecorder,
BatchCallCounter,
) {
let (backend, calls, operations) = MockBackend::new("mock-token");
let batch_counter = backend.batch_call_counter();
let ctx = DurableContext::new(
Arc::new(backend),
"arn:aws:lambda:us-east-1:000000000000:durable-execution/mock".to_string(),
"mock-checkpoint-token".to_string(),
self.operations,
None,
)
.await
.expect("MockDurableContext::build_with_batch_counter should not fail");
(ctx, calls, operations, batch_counter)
}
}
impl Default for MockDurableContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[tokio::test]
    async fn test_mock_context_replays_step_result() {
        let (mut ctx, calls, _ops) = MockDurableContext::new()
            .with_step_result("validate", r#"42"#)
            .build()
            .await;
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        let result: Result<i32, String> = ctx
            .step("validate", move || {
                let executed = executed_clone.clone();
                async move {
                    executed.store(true, Ordering::SeqCst);
                    // Would be returned only if the closure actually ran.
                    Ok(999)
                }
            })
            .await
            .unwrap();
        assert_eq!(result.unwrap(), 42);
        assert!(
            !executed.load(Ordering::SeqCst),
            "closure should not execute during replay"
        );
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0);
    }

    #[tokio::test]
    async fn test_mock_context_replays_multiple_steps() {
        let (mut ctx, calls, _ops) = MockDurableContext::new()
            .with_step_result("step1", r#""hello""#)
            .with_step_result("step2", r#""world""#)
            .build()
            .await;
        let r1: Result<String, String> = ctx
            .step("step1", || async { panic!("not executed") })
            .await
            .unwrap();
        assert_eq!(r1.unwrap(), "hello");
        let r2: Result<String, String> = ctx
            .step("step2", || async { panic!("not executed") })
            .await
            .unwrap();
        assert_eq!(r2.unwrap(), "world");
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0);
    }

    #[tokio::test]
    async fn test_mock_context_replays_step_error() {
        let (mut ctx, _calls, _ops) = MockDurableContext::new()
            .with_step_error("charge", "PaymentError", r#""insufficient_funds""#)
            .build()
            .await;
        let result: Result<i32, String> = ctx
            .step("charge", || async { panic!("not executed") })
            .await
            .unwrap();
        assert_eq!(result.unwrap_err(), "insufficient_funds");
    }

    #[tokio::test]
    async fn test_mock_context_executing_mode_when_empty() {
        let (ctx, _calls, _ops) = MockDurableContext::new().build().await;
        assert!(!ctx.is_replaying());
        assert_eq!(
            ctx.execution_mode(),
            durable_lambda_core::types::ExecutionMode::Executing
        );
    }

    #[tokio::test]
    async fn test_mock_context_replaying_mode_with_operations() {
        let (ctx, _calls, _ops) = MockDurableContext::new()
            .with_step_result("step1", r#"1"#)
            .build()
            .await;
        assert!(ctx.is_replaying());
        assert_eq!(
            ctx.execution_mode(),
            durable_lambda_core::types::ExecutionMode::Replaying
        );
    }

    #[tokio::test]
    async fn test_mock_context_no_aws_credentials_needed() {
        let (mut ctx, _calls, _ops) = MockDurableContext::new()
            .with_step_result("test", r#"true"#)
            .build()
            .await;
        let result: Result<bool, String> = ctx
            .step("test", || async { panic!("not executed") })
            .await
            .unwrap();
        assert!(result.unwrap());
    }

    // Coverage for `build_with_batch_counter`, which previously had no tests.
    #[tokio::test]
    async fn test_mock_context_batch_counter_build_executing_when_empty() {
        let (ctx, _calls, _ops, _batch_counter) =
            MockDurableContext::new().build_with_batch_counter().await;
        assert!(!ctx.is_replaying());
        assert_eq!(
            ctx.execution_mode(),
            durable_lambda_core::types::ExecutionMode::Executing
        );
    }

    #[tokio::test]
    async fn test_mock_context_batch_counter_build_replays_step_result() {
        let (mut ctx, calls, _ops, _batch_counter) = MockDurableContext::new()
            .with_step_result("validate", r#"7"#)
            .build_with_batch_counter()
            .await;
        let result: Result<i32, String> = ctx
            .step("validate", || async { panic!("not executed") })
            .await
            .unwrap();
        assert_eq!(result.unwrap(), 7);
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0);
    }

    // Coverage for the non-step seeding helpers (wait, callback, chained invoke).
    #[tokio::test]
    async fn test_mock_context_replaying_mode_with_non_step_operations() {
        let (ctx, _calls, _ops) = MockDurableContext::new()
            .with_wait("pause")
            .with_callback("approval", "cb-1", r#""approved""#)
            .with_invoke("child", r#"{"ok":true}"#)
            .build()
            .await;
        assert!(ctx.is_replaying());
        assert_eq!(
            ctx.execution_mode(),
            durable_lambda_core::types::ExecutionMode::Replaying
        );
    }
}