use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::context::DurableContext;
use crate::error::DurableError;
impl DurableContext {
    /// Invoke another Lambda function as a durable, checkpointed operation.
    ///
    /// Each call takes the next operation id from the replay engine
    /// (`generate_operation_id`) and then:
    ///
    /// - if the replay engine already reports a result for that id, returns the
    ///   recorded outcome without contacting the backend;
    /// - if the operation is known but has no result yet, suspends;
    /// - otherwise sends a `START` checkpoint carrying the JSON-serialized
    ///   `payload` and the target `function_name`, stores the returned
    ///   checkpoint token, merges any operations from the response into the
    ///   replay engine, and returns the outcome if the invoke already
    ///   completed — suspending otherwise.
    ///
    /// # Errors
    ///
    /// - `DurableError::invoke_suspended` when the invoke has not completed yet.
    /// - `DurableError::invoke_failed` when the recorded operation has any status
    ///   other than `Succeeded`; the message carries the recorded error type and
    ///   data when present.
    /// - `DurableError::serialization` if `payload` cannot be serialized to JSON.
    /// - `DurableError::deserialization` if the recorded result cannot be parsed
    ///   as `T`.
    /// - `DurableError::checkpoint_failed` if building the checkpoint update
    ///   fails, the checkpoint response has no checkpoint token, or a succeeded
    ///   invoke carries no result.
    /// - Any error returned by the backend's `checkpoint` call.
    pub async fn invoke<T, P>(
        &mut self,
        name: &str,
        function_name: &str,
        payload: &P,
    ) -> Result<T, DurableError>
    where
        T: DeserializeOwned,
        P: Serialize,
    {
        let op_id = self.replay_engine_mut().generate_operation_id();
        let span = tracing::info_span!(
            "durable_operation",
            op.name = name,
            op.type = "invoke",
            op.id = %op_id,
        );

        // The span is attached with `Instrument` instead of a `Span::enter`
        // guard: a guard held across `.await` produces incorrect traces when the
        // task yields, and it is `!Send`, which would make this future `!Send`.
        let body = async {
            tracing::trace!("durable_operation");

            // Replay: the invoke already finished in an earlier execution.
            if let Some(outcome) = self.completed_invoke_outcome::<T>(&op_id, name) {
                return outcome;
            }
            // The operation exists but has no result yet: keep waiting.
            if self.replay_engine().get_operation(&op_id).is_some() {
                return Err(DurableError::invoke_suspended(name));
            }

            let serialized_payload = serde_json::to_string(payload)
                .map_err(|e| DurableError::serialization(std::any::type_name::<P>(), e))?;
            let invoke_opts = aws_sdk_lambda::types::ChainedInvokeOptions::builder()
                .function_name(function_name)
                .build()
                .map_err(|e| DurableError::checkpoint_failed(name, e))?;
            let start_update = OperationUpdate::builder()
                .id(op_id.clone())
                .r#type(OperationType::ChainedInvoke)
                .action(OperationAction::Start)
                .sub_type("ChainedInvoke")
                .name(name)
                .payload(serialized_payload)
                .chained_invoke_options(invoke_opts)
                .build()
                .map_err(|e| DurableError::checkpoint_failed(name, e))?;

            let start_response = self
                .backend()
                .checkpoint(
                    self.arn(),
                    self.checkpoint_token(),
                    vec![start_update],
                    None,
                )
                .await?;

            let new_token = start_response.checkpoint_token().ok_or_else(|| {
                DurableError::checkpoint_failed(
                    name,
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "checkpoint response missing checkpoint_token",
                    ),
                )
            })?;
            self.set_checkpoint_token(new_token.to_string());

            if let Some(new_state) = start_response.new_execution_state() {
                for op in new_state.operations() {
                    self.replay_engine_mut()
                        .insert_operation(op.id().to_string(), op.clone());
                }
            }

            // Double-check: the START response may already report completion.
            self.completed_invoke_outcome::<T>(&op_id, name)
                .unwrap_or_else(|| Err(DurableError::invoke_suspended(name)))
        };

        tracing::Instrument::instrument(body, span).await
    }

    /// Resolve the outcome of an invoke that the replay engine has a result for.
    ///
    /// Returns `None` when `check_result` has nothing for `op_id`. Otherwise a
    /// `Succeeded` operation yields its deserialized result, and any other status
    /// yields an `invoke_failed` error. The operation is marked as replayed
    /// (`track_replay`) only when its result deserializes successfully.
    fn completed_invoke_outcome<T: DeserializeOwned>(
        &mut self,
        op_id: &str,
        name: &str,
    ) -> Option<Result<T, DurableError>> {
        let op = self.replay_engine().check_result(op_id)?;
        let outcome = match &op.status {
            OperationStatus::Succeeded => Self::deserialize_invoke_result::<T>(op, name),
            _ => Err(DurableError::invoke_failed(
                name,
                Self::extract_invoke_error(op),
            )),
        };
        if outcome.is_ok() {
            self.replay_engine_mut().track_replay(op_id);
        }
        Some(outcome)
    }

    /// Deserialize the JSON result recorded on a succeeded invoke operation.
    ///
    /// The result is read from `chained_invoke_details`, falling back to
    /// `step_details`.
    ///
    /// # Errors
    ///
    /// `checkpoint_failed` if neither details object carries a result, or
    /// `deserialization` if the result is not valid JSON for `T`.
    fn deserialize_invoke_result<T: DeserializeOwned>(
        op: &aws_sdk_lambda::types::Operation,
        name: &str,
    ) -> Result<T, DurableError> {
        let result_str = op
            .chained_invoke_details()
            .and_then(|d| d.result())
            .or_else(|| op.step_details().and_then(|d| d.result()))
            .ok_or_else(|| {
                DurableError::checkpoint_failed(
                    name,
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "invoke succeeded but no result in chained_invoke_details or step_details",
                    ),
                )
            })?;
        serde_json::from_str(result_str)
            .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))
    }

    /// Build a human-readable `"<error_type>: <error_data>"` message from a
    /// failed invoke's `chained_invoke_details`.
    ///
    /// Missing parts default to `"Unknown"` / `""`. When there is no error
    /// object at all, the message is `"invoke failed"`.
    fn extract_invoke_error(op: &aws_sdk_lambda::types::Operation) -> String {
        op.chained_invoke_details()
            .and_then(|d| d.error())
            .map(|e| {
                format!(
                    "{}: {}",
                    e.error_type().unwrap_or("Unknown"),
                    e.error_data().unwrap_or("")
                )
            })
            .unwrap_or_else(|| "invoke failed".to_string())
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ChainedInvokeDetails, ErrorObject, Operation, OperationAction, OperationStatus,
        OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// One recorded call to `InvokeMockBackend::checkpoint`.
    #[derive(Debug, Clone)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Backend double that records every checkpoint call, answers with a fixed
    /// checkpoint token, and optionally returns `response_operation` as the new
    /// execution state.
    struct InvokeMockBackend {
        calls: Arc<Mutex<Vec<CheckpointCall>>>,
        checkpoint_token: String,
        response_operation: Option<Operation>,
    }

    impl InvokeMockBackend {
        /// Create the mock plus a shared handle to its recorded calls.
        fn new(
            checkpoint_token: &str,
            response_op: Option<Operation>,
        ) -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: calls.clone(),
                checkpoint_token: checkpoint_token.to_string(),
                response_operation: response_op,
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for InvokeMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            let mut builder = CheckpointDurableExecutionOutput::builder()
                .checkpoint_token(&self.checkpoint_token);
            if let Some(ref op) = self.response_operation {
                let new_state = aws_sdk_lambda::types::CheckpointUpdatedExecutionState::builder()
                    .operations(op.clone())
                    .build();
                builder = builder.new_execution_state(new_state);
            }
            Ok(builder.build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// The id the context assigns to its first operation.
    fn first_op_id() -> String {
        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
        gen.next_id()
    }

    /// Build a `DurableContext` over `backend` with ARN `arn:test`.
    async fn make_ctx(
        backend: InvokeMockBackend,
        initial_token: &str,
        operations: Vec<Operation>,
    ) -> DurableContext {
        DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            initial_token.to_string(),
            operations,
            None,
        )
        .await
        .unwrap()
    }

    /// Build a `ChainedInvoke` operation with the given status, result and error.
    fn make_invoke_op(
        id: &str,
        status: OperationStatus,
        result: Option<&str>,
        error: Option<ErrorObject>,
    ) -> Operation {
        let mut details_builder = ChainedInvokeDetails::builder();
        if let Some(r) = result {
            details_builder = details_builder.result(r);
        }
        if let Some(e) = error {
            details_builder = details_builder.error(e);
        }
        Operation::builder()
            .id(id)
            .r#type(OperationType::ChainedInvoke)
            .status(status)
            .name("test_invoke")
            .start_timestamp(DateTime::from_secs(0))
            .chained_invoke_details(details_builder.build())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_invoke_sends_start_checkpoint_and_suspends() {
        let (backend, calls) = InvokeMockBackend::new("new-token", None);
        let mut ctx = make_ctx(backend, "initial-token", vec![]).await;
        let result = ctx
            .invoke::<String, _>(
                "call_processor",
                "target-lambda",
                &serde_json::json!({"id": 42}),
            )
            .await;
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invoke suspended"), "error: {msg}");
        assert!(msg.contains("call_processor"), "error: {msg}");

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");
        assert_eq!(captured[0].arn, "arn:test");
        assert_eq!(captured[0].checkpoint_token, "initial-token");
        assert_eq!(captured[0].updates.len(), 1, "START carries one update");

        let update = &captured[0].updates[0];
        assert_eq!(update.r#type(), &OperationType::ChainedInvoke);
        assert_eq!(update.action(), &OperationAction::Start);
        assert_eq!(update.name(), Some("call_processor"));
        assert_eq!(update.sub_type(), Some("ChainedInvoke"));
        let payload = update.payload().expect("should have payload");
        assert!(
            payload.contains("42"),
            "payload should contain id: {payload}"
        );
        let invoke_opts = update
            .chained_invoke_options()
            .expect("should have chained_invoke_options");
        assert_eq!(invoke_opts.function_name(), "target-lambda");
    }

    #[tokio::test]
    async fn test_invoke_uses_refreshed_checkpoint_token() {
        let (backend, calls) = InvokeMockBackend::new("refreshed-token", None);
        let mut ctx = make_ctx(backend, "initial-token", vec![]).await;
        let _ = ctx
            .invoke::<String, _>("first", "target-lambda", &serde_json::json!({}))
            .await;
        let _ = ctx
            .invoke::<String, _>("second", "target-lambda", &serde_json::json!({}))
            .await;

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 2, "one START checkpoint per invoke");
        assert_eq!(captured[0].checkpoint_token, "initial-token");
        assert_eq!(
            captured[1].checkpoint_token, "refreshed-token",
            "second checkpoint must use the token returned by the first"
        );
    }

    #[tokio::test]
    async fn test_invoke_replays_succeeded_result() {
        let op_id = first_op_id();
        let invoke_op = make_invoke_op(
            &op_id,
            OperationStatus::Succeeded,
            Some(r#"{"status":"processed","amount":100}"#),
            None,
        );
        let (backend, calls) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;
        let result: serde_json::Value = ctx
            .invoke("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result["status"], "processed");
        assert_eq!(result["amount"], 100);
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_invoke_errors_when_succeeded_without_result() {
        let op_id = first_op_id();
        let invoke_op = make_invoke_op(&op_id, OperationStatus::Succeeded, None, None);
        let (backend, calls) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;
        let result = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await;
        assert!(
            result.is_err(),
            "a succeeded invoke without a result must not yield a value"
        );
        assert_eq!(calls.lock().await.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_invoke_returns_error_on_failed() {
        let op_id = first_op_id();
        let error_obj = ErrorObject::builder()
            .error_type("TargetError")
            .error_data("target function crashed")
            .build();
        let invoke_op = make_invoke_op(&op_id, OperationStatus::Failed, None, Some(error_obj));
        let (backend, _) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;
        let err = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invoke failed"), "error: {msg}");
        assert!(msg.contains("TargetError"), "error: {msg}");
        assert!(msg.contains("target function crashed"), "error: {msg}");
    }

    #[tokio::test]
    async fn test_invoke_suspends_on_started() {
        let op_id = first_op_id();
        let invoke_op = make_invoke_op(&op_id, OperationStatus::Started, None, None);
        let (backend, _) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;
        let err = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invoke suspended"), "error: {msg}");
    }

    #[tokio::test]
    async fn test_invoke_double_check_immediate_completion() {
        let op_id = first_op_id();
        let completed_op = make_invoke_op(
            &op_id,
            OperationStatus::Succeeded,
            Some(r#""instant-result""#),
            None,
        );
        let (backend, calls) = InvokeMockBackend::new("new-token", Some(completed_op));
        let mut ctx = make_ctx(backend, "tok", vec![]).await;
        let result: String = ctx
            .invoke("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result, "instant-result");
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "START checkpoint sent");
    }

    #[tokio::test]
    async fn test_invoke_double_check_immediate_failure() {
        let op_id = first_op_id();
        let error_obj = ErrorObject::builder()
            .error_type("TargetError")
            .error_data("boom")
            .build();
        let failed_op = make_invoke_op(&op_id, OperationStatus::Failed, None, Some(error_obj));
        let (backend, calls) = InvokeMockBackend::new("new-token", Some(failed_op));
        let mut ctx = make_ctx(backend, "tok", vec![]).await;
        let err = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invoke failed"), "error: {msg}");
        assert!(msg.contains("TargetError"), "error: {msg}");
        assert_eq!(calls.lock().await.len(), 1, "START checkpoint sent");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_invoke_emits_span() {
        let (backend, _calls) = InvokeMockBackend::new("tok", None);
        let mut ctx = make_ctx(backend, "tok", vec![]).await;
        let _ = ctx
            .invoke::<serde_json::Value, _>("target", "my-lambda", &serde_json::json!({}))
            .await;
        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("target"));
        assert!(logs_contain("invoke"));
    }
}