use std::sync::Arc;
use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
use aws_sdk_lambda::types::OperationUpdate;
use durable_lambda_core::backend::DurableBackend;
use durable_lambda_core::error::DurableError;
use tokio::sync::Mutex;
/// A single operation start observed by [`MockBackend`], reduced to the
/// operation's name and a short type label.
///
/// [`MockBackend`] fills `operation_type` with `"step"`, `"wait"`,
/// `"callback"`, `"invoke"`, or `"unknown"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationRecord {
    /// Operation name taken from the update (empty string when the update had
    /// no name, as populated by [`MockBackend`]).
    pub name: String,
    /// Short operation type label, e.g. `"step"` or `"wait"`.
    pub operation_type: String,
}

impl OperationRecord {
    /// Returns the record formatted as `"<operation_type>:<name>"`,
    /// e.g. `"step:charge"`.
    ///
    /// This is exactly the [`Display`](std::fmt::Display) output; delegating to
    /// it keeps a single source of truth for the format.
    pub fn to_type_name(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for OperationRecord {
    /// Formats the record as `"<operation_type>:<name>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}", self.operation_type, self.name)
    }
}
/// One call received by [`MockBackend`]'s `checkpoint` or `batch_checkpoint`,
/// stored verbatim so tests can inspect exactly what was sent.
#[derive(Clone, Debug)]
pub struct CheckpointCall {
    /// Execution ARN the caller passed in.
    pub arn: String,
    /// Checkpoint token the caller passed in (not the token the mock returns).
    pub checkpoint_token: String,
    /// All updates carried by the call, in the order supplied.
    pub updates: Vec<OperationUpdate>,
}
/// Shared, lock-guarded log of every checkpoint call the mock receives.
pub type CheckpointRecorder = Arc<Mutex<Vec<CheckpointCall>>>;
/// Shared, lock-guarded log of operations whose update action was `Start`.
pub type OperationRecorder = Arc<Mutex<Vec<OperationRecord>>>;
/// Shared, lock-guarded count of `batch_checkpoint` invocations.
pub type BatchCallCounter = Arc<Mutex<usize>>;
/// In-memory [`DurableBackend`] for tests.
///
/// Records every checkpoint call and every started operation into shared
/// logs, and answers every checkpoint with the same fixed token.
pub struct MockBackend {
    /// Log of all `checkpoint` / `batch_checkpoint` calls, in arrival order.
    calls: CheckpointRecorder,
    /// Log of operations whose update action was `Start`.
    operations: OperationRecorder,
    /// Token placed in every checkpoint response.
    checkpoint_token: String,
    /// Number of `batch_checkpoint` calls seen so far.
    batch_call_count: BatchCallCounter,
}
impl MockBackend {
pub fn new(checkpoint_token: &str) -> (Self, CheckpointRecorder, OperationRecorder) {
let calls = Arc::new(Mutex::new(Vec::new()));
let operations = Arc::new(Mutex::new(Vec::new()));
let backend = Self {
calls: calls.clone(),
operations: operations.clone(),
checkpoint_token: checkpoint_token.to_string(),
batch_call_count: Arc::new(Mutex::new(0)),
};
(backend, calls, operations)
}
pub fn batch_call_counter(&self) -> BatchCallCounter {
self.batch_call_count.clone()
}
}
#[async_trait::async_trait]
impl DurableBackend for MockBackend {
async fn checkpoint(
&self,
arn: &str,
checkpoint_token: &str,
updates: Vec<OperationUpdate>,
_client_token: Option<&str>,
) -> Result<CheckpointDurableExecutionOutput, DurableError> {
for update in &updates {
if update.action() == &aws_sdk_lambda::types::OperationAction::Start {
let op_type = match update.r#type() {
aws_sdk_lambda::types::OperationType::Step => "step",
aws_sdk_lambda::types::OperationType::Wait => "wait",
aws_sdk_lambda::types::OperationType::Callback => "callback",
aws_sdk_lambda::types::OperationType::ChainedInvoke => "invoke",
_ => "unknown",
};
let name = update.name().unwrap_or("").to_string();
self.operations.lock().await.push(OperationRecord {
name,
operation_type: op_type.to_string(),
});
}
}
self.calls.lock().await.push(CheckpointCall {
arn: arn.to_string(),
checkpoint_token: checkpoint_token.to_string(),
updates,
});
Ok(CheckpointDurableExecutionOutput::builder()
.checkpoint_token(&self.checkpoint_token)
.build())
}
async fn batch_checkpoint(
&self,
arn: &str,
checkpoint_token: &str,
updates: Vec<OperationUpdate>,
_client_token: Option<&str>,
) -> Result<
aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput,
DurableError,
> {
*self.batch_call_count.lock().await += 1;
for update in &updates {
if update.action() == &aws_sdk_lambda::types::OperationAction::Start {
let op_type = match update.r#type() {
aws_sdk_lambda::types::OperationType::Step => "step",
aws_sdk_lambda::types::OperationType::Wait => "wait",
aws_sdk_lambda::types::OperationType::Callback => "callback",
aws_sdk_lambda::types::OperationType::ChainedInvoke => "invoke",
_ => "unknown",
};
let name = update.name().unwrap_or("").to_string();
self.operations.lock().await.push(OperationRecord {
name,
operation_type: op_type.to_string(),
});
}
}
self.calls.lock().await.push(CheckpointCall {
arn: arn.to_string(),
checkpoint_token: checkpoint_token.to_string(),
updates,
});
Ok(
aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput::builder()
.checkpoint_token(&self.checkpoint_token)
.build(),
)
}
async fn get_execution_state(
&self,
_arn: &str,
_checkpoint_token: &str,
_next_marker: &str,
_max_items: i32,
) -> Result<GetDurableExecutionStateOutput, DurableError> {
Ok(GetDurableExecutionStateOutput::builder()
.build()
.expect("empty execution state"))
}
}