1use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13
14use crate::context::DurableContext;
15use crate::error::DurableError;
16
17impl DurableContext {
18 #[allow(clippy::await_holding_lock)]
66 pub async fn invoke<T, P>(
67 &mut self,
68 name: &str,
69 function_name: &str,
70 payload: &P,
71 ) -> Result<T, DurableError>
72 where
73 T: DeserializeOwned,
74 P: Serialize,
75 {
76 let op_id = self.replay_engine_mut().generate_operation_id();
77
78 let span = tracing::info_span!(
79 "durable_operation",
80 op.name = name,
81 op.type = "invoke",
82 op.id = %op_id,
83 );
84 let _guard = span.enter();
85 tracing::trace!("durable_operation");
86
87 if let Some(op) = self.replay_engine().check_result(&op_id) {
89 match &op.status {
90 OperationStatus::Succeeded => {
91 let result = Self::deserialize_invoke_result::<T>(op, name)?;
92 self.replay_engine_mut().track_replay(&op_id);
93 return Ok(result);
94 }
95 _ => {
96 let error_message = Self::extract_invoke_error(op);
98 return Err(DurableError::invoke_failed(name, error_message));
99 }
100 }
101 }
102
103 if self.replay_engine().get_operation(&op_id).is_some() {
105 return Err(DurableError::invoke_suspended(name));
106 }
107
108 let serialized_payload = serde_json::to_string(payload)
110 .map_err(|e| DurableError::serialization(std::any::type_name::<P>(), e))?;
111
112 let invoke_opts = aws_sdk_lambda::types::ChainedInvokeOptions::builder()
113 .function_name(function_name)
114 .build()
115 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
116
117 let start_update = OperationUpdate::builder()
118 .id(op_id.clone())
119 .r#type(OperationType::ChainedInvoke)
120 .action(OperationAction::Start)
121 .sub_type("ChainedInvoke")
122 .name(name)
123 .payload(serialized_payload)
124 .chained_invoke_options(invoke_opts)
125 .build()
126 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
127
128 let start_response = self
129 .backend()
130 .checkpoint(
131 self.arn(),
132 self.checkpoint_token(),
133 vec![start_update],
134 None,
135 )
136 .await?;
137
138 let new_token = start_response.checkpoint_token().ok_or_else(|| {
139 DurableError::checkpoint_failed(
140 name,
141 std::io::Error::new(
142 std::io::ErrorKind::InvalidData,
143 "checkpoint response missing checkpoint_token",
144 ),
145 )
146 })?;
147 self.set_checkpoint_token(new_token.to_string());
148
149 if let Some(new_state) = start_response.new_execution_state() {
151 for op in new_state.operations() {
152 self.replay_engine_mut()
153 .insert_operation(op.id().to_string(), op.clone());
154 }
155 }
156
157 if let Some(op) = self.replay_engine().check_result(&op_id) {
159 match &op.status {
160 OperationStatus::Succeeded => {
161 let result = Self::deserialize_invoke_result::<T>(op, name)?;
162 self.replay_engine_mut().track_replay(&op_id);
163 return Ok(result);
164 }
165 _ => {
166 let error_message = Self::extract_invoke_error(op);
167 return Err(DurableError::invoke_failed(name, error_message));
168 }
169 }
170 }
171
172 Err(DurableError::invoke_suspended(name))
174 }
175
176 fn deserialize_invoke_result<T: DeserializeOwned>(
182 op: &aws_sdk_lambda::types::Operation,
183 name: &str,
184 ) -> Result<T, DurableError> {
185 let result_str = op
186 .chained_invoke_details()
187 .and_then(|d| d.result())
188 .or_else(|| op.step_details().and_then(|d| d.result()))
189 .ok_or_else(|| {
190 DurableError::checkpoint_failed(
191 name,
192 std::io::Error::new(
193 std::io::ErrorKind::InvalidData,
194 "invoke succeeded but no result in chained_invoke_details or step_details",
195 ),
196 )
197 })?;
198
199 serde_json::from_str(result_str)
200 .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))
201 }
202
203 fn extract_invoke_error(op: &aws_sdk_lambda::types::Operation) -> String {
205 op.chained_invoke_details()
206 .and_then(|d| d.error())
207 .map(|e| {
208 format!(
209 "{}: {}",
210 e.error_type().unwrap_or("Unknown"),
211 e.error_data().unwrap_or("")
212 )
213 })
214 .unwrap_or_else(|| "invoke failed".to_string())
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ChainedInvokeDetails, ErrorObject, Operation, OperationAction, OperationStatus,
        OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// Arguments of one recorded `DurableBackend::checkpoint` call.
    #[derive(Debug, Clone)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Backend double that records every checkpoint call and answers each one
    /// with a fixed token and, optionally, one operation in the new state.
    struct InvokeMockBackend {
        calls: Arc<Mutex<Vec<CheckpointCall>>>,
        checkpoint_token: String,
        response_operation: Option<Operation>,
    }

    impl InvokeMockBackend {
        /// Returns the backend plus a shared handle to its recorded calls.
        fn new(
            checkpoint_token: &str,
            response_op: Option<Operation>,
        ) -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: calls.clone(),
                checkpoint_token: checkpoint_token.to_string(),
                response_operation: response_op,
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for InvokeMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });

            let mut builder = CheckpointDurableExecutionOutput::builder()
                .checkpoint_token(&self.checkpoint_token);

            if let Some(ref op) = self.response_operation {
                let new_state = aws_sdk_lambda::types::CheckpointUpdatedExecutionState::builder()
                    .operations(op.clone())
                    .build();
                builder = builder.new_execution_state(new_state);
            }

            Ok(builder.build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// The first ID a fresh `OperationIdGenerator::new(None)` produces. This
    /// presumably matches the ID the context assigns to its first operation,
    /// which is what the replay tests below rely on.
    fn first_op_id() -> String {
        let mut gen = crate::operation_id::OperationIdGenerator::new(None);
        gen.next_id()
    }

    /// Builds a chained-invoke operation with the given status and optional
    /// result/error recorded in its `chained_invoke_details`.
    fn make_invoke_op(
        id: &str,
        status: OperationStatus,
        result: Option<&str>,
        error: Option<ErrorObject>,
    ) -> Operation {
        let mut details_builder = ChainedInvokeDetails::builder();
        if let Some(r) = result {
            details_builder = details_builder.result(r);
        }
        if let Some(e) = error {
            details_builder = details_builder.error(e);
        }

        Operation::builder()
            .id(id)
            .r#type(OperationType::ChainedInvoke)
            .status(status)
            .name("test_invoke")
            .start_timestamp(DateTime::from_secs(0))
            .chained_invoke_details(details_builder.build())
            .build()
            .unwrap()
    }

    /// Creates a context over `backend` with ARN `arn:test`, the given
    /// starting token and initial operations.
    async fn make_ctx(
        backend: InvokeMockBackend,
        token: &str,
        operations: Vec<Operation>,
    ) -> DurableContext {
        DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            token.to_string(),
            operations,
            None,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_invoke_sends_start_checkpoint_and_suspends() {
        let (backend, calls) = InvokeMockBackend::new("new-token", None);
        let mut ctx = make_ctx(backend, "initial-token", vec![]).await;

        let result = ctx
            .invoke::<String, _>(
                "call_processor",
                "target-lambda",
                &serde_json::json!({"id": 42}),
            )
            .await;

        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invoke suspended"), "error: {msg}");
        assert!(msg.contains("call_processor"), "error: {msg}");

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");

        let update = &captured[0].updates[0];
        assert_eq!(update.r#type(), &OperationType::ChainedInvoke);
        assert_eq!(update.action(), &OperationAction::Start);
        assert_eq!(update.name(), Some("call_processor"));
        assert_eq!(update.sub_type(), Some("ChainedInvoke"));

        let payload = update.payload().expect("should have payload");
        assert!(
            payload.contains("42"),
            "payload should contain id: {payload}"
        );

        let invoke_opts = update
            .chained_invoke_options()
            .expect("should have chained_invoke_options");
        assert_eq!(invoke_opts.function_name(), "target-lambda");
    }

    #[tokio::test]
    async fn test_invoke_checkpoints_with_context_arn_and_latest_token() {
        let (backend, calls) = InvokeMockBackend::new("new-token", None);
        let mut ctx = make_ctx(backend, "initial-token", vec![]).await;

        // Both calls suspend (the mock never reports an outcome); only the
        // arguments sent to the backend matter here.
        let _ = ctx
            .invoke::<String, _>("first", "target-lambda", &serde_json::json!({}))
            .await;
        let _ = ctx
            .invoke::<String, _>("second", "target-lambda", &serde_json::json!({}))
            .await;

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 2, "one START checkpoint per invoke");
        assert!(captured.iter().all(|c| c.arn == "arn:test"));
        assert_eq!(captured[0].checkpoint_token, "initial-token");
        assert_eq!(
            captured[1].checkpoint_token, "new-token",
            "second checkpoint must use the token returned by the first"
        );
        assert_ne!(
            captured[0].updates[0].id(),
            captured[1].updates[0].id(),
            "each invoke gets its own operation id"
        );
    }

    #[tokio::test]
    async fn test_invoke_replays_succeeded_result() {
        let op_id = first_op_id();

        let invoke_op = make_invoke_op(
            &op_id,
            OperationStatus::Succeeded,
            Some(r#"{"status":"processed","amount":100}"#),
            None,
        );

        let (backend, calls) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;

        let result: serde_json::Value = ctx
            .invoke("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result["status"], "processed");
        assert_eq!(result["amount"], 100);

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_invoke_replay_rejects_mismatched_result_type() {
        let op_id = first_op_id();

        let invoke_op = make_invoke_op(
            &op_id,
            OperationStatus::Succeeded,
            Some(r#""not-a-number""#),
            None,
        );

        let (backend, calls) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;

        let result = ctx
            .invoke::<u32, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await;

        assert!(result.is_err(), "a JSON string must not deserialize as u32");
        assert_eq!(calls.lock().await.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_invoke_returns_error_on_failed() {
        let op_id = first_op_id();

        let error_obj = ErrorObject::builder()
            .error_type("TargetError")
            .error_data("target function crashed")
            .build();

        let invoke_op = make_invoke_op(&op_id, OperationStatus::Failed, None, Some(error_obj));

        let (backend, calls) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;

        let err = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err();

        let msg = err.to_string();
        assert!(msg.contains("invoke failed"), "error: {msg}");
        assert!(msg.contains("TargetError"), "error: {msg}");
        assert!(msg.contains("target function crashed"), "error: {msg}");

        assert_eq!(calls.lock().await.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_invoke_suspends_on_started() {
        let op_id = first_op_id();

        let invoke_op = make_invoke_op(&op_id, OperationStatus::Started, None, None);

        let (backend, _) = InvokeMockBackend::new("token", None);
        let mut ctx = make_ctx(backend, "tok", vec![invoke_op]).await;

        let err = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err();

        let msg = err.to_string();
        assert!(msg.contains("invoke suspended"), "error: {msg}");
    }

    #[tokio::test]
    async fn test_invoke_double_check_immediate_completion() {
        let op_id = first_op_id();

        // The backend reports the invocation as finished in the START response.
        let completed_op = make_invoke_op(
            &op_id,
            OperationStatus::Succeeded,
            Some(r#""instant-result""#),
            None,
        );

        let (backend, calls) = InvokeMockBackend::new("new-token", Some(completed_op));
        let mut ctx = make_ctx(backend, "tok", vec![]).await;

        let result: String = ctx
            .invoke("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result, "instant-result");

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "START checkpoint sent");
    }

    #[tokio::test]
    async fn test_invoke_double_check_immediate_failure() {
        let op_id = first_op_id();

        // The backend reports the invocation as failed in the START response.
        let error_obj = ErrorObject::builder()
            .error_type("TargetError")
            .error_data("rejected on start")
            .build();
        let failed_op = make_invoke_op(&op_id, OperationStatus::Failed, None, Some(error_obj));

        let (backend, calls) = InvokeMockBackend::new("new-token", Some(failed_op));
        let mut ctx = make_ctx(backend, "tok", vec![]).await;

        let err = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err();

        let msg = err.to_string();
        assert!(msg.contains("invoke failed"), "error: {msg}");
        assert!(msg.contains("TargetError"), "error: {msg}");
        assert!(msg.contains("rejected on start"), "error: {msg}");

        assert_eq!(calls.lock().await.len(), 1, "START checkpoint sent");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_invoke_emits_span() {
        let (backend, _calls) = InvokeMockBackend::new("tok", None);
        let mut ctx = make_ctx(backend, "tok", vec![]).await;
        let _ = ctx
            .invoke::<serde_json::Value, _>("target", "my-lambda", &serde_json::json!({}))
            .await;
        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("target"));
        assert!(logs_contain("invoke"));
    }
}