Skip to main content

durable_lambda_core/operations/
invoke.rs

1//! Invoke operation — durable Lambda-to-Lambda invocation.
2//!
3//! Implement FR17-FR18: invoke target function, checkpoint result for replay.
4//!
5//! The invoke operation sends a **single START checkpoint** with the serialized
6//! payload and target function name. The server invokes the target Lambda
7//! asynchronously and transitions the operation to SUCCEEDED/FAILED when done.
8//! The wire type is `ChainedInvoke` (matching the Python SDK).
9
10use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13
14use crate::context::DurableContext;
15use crate::error::DurableError;
16
17impl DurableContext {
18    /// Durably invoke another Lambda function and return its result.
19    ///
20    /// During execution mode, serializes the payload, sends a START checkpoint
21    /// with the target function name, and returns [`DurableError::InvokeSuspended`]
22    /// to signal the function should exit. The server invokes the target function
23    /// asynchronously and re-invokes this Lambda when complete.
24    ///
25    /// During replay mode, returns the cached result without re-invoking the
26    /// target function.
27    ///
28    /// If the target function completes immediately (detected via the
29    /// double-check pattern), the result is returned directly without
30    /// suspending.
31    ///
32    /// # Arguments
33    ///
34    /// * `name` — Human-readable name for the invoke operation
35    /// * `function_name` — Name or ARN of the target Lambda function
36    /// * `payload` — Input payload to send to the target function
37    ///
38    /// # Errors
39    ///
40    /// Returns [`DurableError::InvokeSuspended`] when the invoke has been
41    /// checkpointed and the target is still executing — the handler must
42    /// propagate this to exit.
43    /// Returns [`DurableError::InvokeFailed`] if the target function failed,
44    /// timed out, or was stopped.
45    /// Returns [`DurableError::Serialization`] if the payload cannot be
46    /// serialized.
47    /// Returns [`DurableError::Deserialization`] if the result cannot be
48    /// deserialized as type `T`.
49    /// Returns [`DurableError::CheckpointFailed`] if the AWS checkpoint API
50    /// call fails.
51    ///
52    /// # Examples
53    ///
54    /// ```no_run
55    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
56    /// let result: String = ctx.invoke(
57    ///     "call_processor",
58    ///     "payment-processor-lambda",
59    ///     &serde_json::json!({"order_id": 123}),
60    /// ).await?;
61    /// println!("Target returned: {result}");
62    /// # Ok(())
63    /// # }
64    /// ```
65    #[allow(clippy::await_holding_lock)]
66    pub async fn invoke<T, P>(
67        &mut self,
68        name: &str,
69        function_name: &str,
70        payload: &P,
71    ) -> Result<T, DurableError>
72    where
73        T: DeserializeOwned,
74        P: Serialize,
75    {
76        let op_id = self.replay_engine_mut().generate_operation_id();
77
78        let span = tracing::info_span!(
79            "durable_operation",
80            op.name = name,
81            op.type = "invoke",
82            op.id = %op_id,
83        );
84        let _guard = span.enter();
85        tracing::trace!("durable_operation");
86
87        // Replay path: check for completed result (SUCCEEDED/FAILED/TIMED_OUT/etc).
88        if let Some(op) = self.replay_engine().check_result(&op_id) {
89            match &op.status {
90                OperationStatus::Succeeded => {
91                    let result = Self::deserialize_invoke_result::<T>(op, name)?;
92                    self.replay_engine_mut().track_replay(&op_id);
93                    return Ok(result);
94                }
95                _ => {
96                    // Failed/Cancelled/TimedOut/Stopped — completed but not successful.
97                    let error_message = Self::extract_invoke_error(op);
98                    return Err(DurableError::invoke_failed(name, error_message));
99                }
100            }
101        }
102
103        // Check for non-completed status (STARTED/PENDING — target still running).
104        if self.replay_engine().get_operation(&op_id).is_some() {
105            return Err(DurableError::invoke_suspended(name));
106        }
107
108        // Execute path — serialize payload and send START checkpoint.
109        let serialized_payload = serde_json::to_string(payload)
110            .map_err(|e| DurableError::serialization(std::any::type_name::<P>(), e))?;
111
112        let invoke_opts = aws_sdk_lambda::types::ChainedInvokeOptions::builder()
113            .function_name(function_name)
114            .build()
115            .map_err(|e| DurableError::checkpoint_failed(name, e))?;
116
117        let start_update = OperationUpdate::builder()
118            .id(op_id.clone())
119            .r#type(OperationType::ChainedInvoke)
120            .action(OperationAction::Start)
121            .sub_type("ChainedInvoke")
122            .name(name)
123            .payload(serialized_payload)
124            .chained_invoke_options(invoke_opts)
125            .build()
126            .map_err(|e| DurableError::checkpoint_failed(name, e))?;
127
128        let start_response = self
129            .backend()
130            .checkpoint(
131                self.arn(),
132                self.checkpoint_token(),
133                vec![start_update],
134                None,
135            )
136            .await?;
137
138        let new_token = start_response.checkpoint_token().ok_or_else(|| {
139            DurableError::checkpoint_failed(
140                name,
141                std::io::Error::new(
142                    std::io::ErrorKind::InvalidData,
143                    "checkpoint response missing checkpoint_token",
144                ),
145            )
146        })?;
147        self.set_checkpoint_token(new_token.to_string());
148
149        // Merge any new execution state from checkpoint response.
150        if let Some(new_state) = start_response.new_execution_state() {
151            for op in new_state.operations() {
152                self.replay_engine_mut()
153                    .insert_operation(op.id().to_string(), op.clone());
154            }
155        }
156
157        // Double-check: detect immediate completion.
158        if let Some(op) = self.replay_engine().check_result(&op_id) {
159            match &op.status {
160                OperationStatus::Succeeded => {
161                    let result = Self::deserialize_invoke_result::<T>(op, name)?;
162                    self.replay_engine_mut().track_replay(&op_id);
163                    return Ok(result);
164                }
165                _ => {
166                    let error_message = Self::extract_invoke_error(op);
167                    return Err(DurableError::invoke_failed(name, error_message));
168                }
169            }
170        }
171
172        // Target still executing — suspend.
173        Err(DurableError::invoke_suspended(name))
174    }
175
176    /// Deserialize the result from a succeeded invoke operation.
177    ///
178    /// Checks `chained_invoke_details.result` first (per SDK spec), then falls
179    /// back to `step_details.result` (service may store the chained invoke
180    /// result there instead).
181    fn deserialize_invoke_result<T: DeserializeOwned>(
182        op: &aws_sdk_lambda::types::Operation,
183        name: &str,
184    ) -> Result<T, DurableError> {
185        let result_str = op
186            .chained_invoke_details()
187            .and_then(|d| d.result())
188            .or_else(|| op.step_details().and_then(|d| d.result()))
189            .ok_or_else(|| {
190                DurableError::checkpoint_failed(
191                    name,
192                    std::io::Error::new(
193                        std::io::ErrorKind::InvalidData,
194                        "invoke succeeded but no result in chained_invoke_details or step_details",
195                    ),
196                )
197            })?;
198
199        serde_json::from_str(result_str)
200            .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))
201    }
202
203    /// Extract error message from an invoke operation's chained_invoke_details.
204    fn extract_invoke_error(op: &aws_sdk_lambda::types::Operation) -> String {
205        op.chained_invoke_details()
206            .and_then(|d| d.error())
207            .map(|e| {
208                format!(
209                    "{}: {}",
210                    e.error_type().unwrap_or("Unknown"),
211                    e.error_data().unwrap_or("")
212                )
213            })
214            .unwrap_or_else(|| "invoke failed".to_string())
215    }
216}
217
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        ChainedInvokeDetails, ErrorObject, Operation, OperationAction, OperationStatus,
        OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// Arguments captured from a single `checkpoint` call on the mock backend.
    #[derive(Debug, Clone)]
    #[allow(dead_code)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Shared log of checkpoint calls, readable by the test after the backend
    /// has been handed to the context.
    type CallLog = Arc<Mutex<Vec<CheckpointCall>>>;

    /// Backend double for invoke tests.
    ///
    /// Records every checkpoint call and answers with a fixed token; when
    /// `response_operation` is set, that operation is included in the
    /// response's `new_execution_state`.
    struct InvokeMockBackend {
        calls: CallLog,
        checkpoint_token: String,
        response_operation: Option<Operation>,
    }

    impl InvokeMockBackend {
        /// Create the backend together with a handle to its call log.
        fn new(checkpoint_token: &str, response_op: Option<Operation>) -> (Self, CallLog) {
            let log: CallLog = Arc::new(Mutex::new(Vec::new()));
            let mock = Self {
                calls: Arc::clone(&log),
                checkpoint_token: checkpoint_token.to_string(),
                response_operation: response_op,
            };
            (mock, log)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for InvokeMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            // Record the call before producing the canned response.
            let recorded = CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            };
            self.calls.lock().await.push(recorded);

            let output =
                CheckpointDurableExecutionOutput::builder().checkpoint_token(&self.checkpoint_token);

            let output = match &self.response_operation {
                Some(op) => output.new_execution_state(
                    aws_sdk_lambda::types::CheckpointUpdatedExecutionState::builder()
                        .operations(op.clone())
                        .build(),
                ),
                None => output,
            };

            Ok(output.build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// The id a freshly created generator hands out first, i.e. the id the
    /// first operation of a new context is expected to receive.
    fn first_op_id() -> String {
        crate::operation_id::OperationIdGenerator::new(None).next_id()
    }

    /// Build a `ChainedInvoke` operation named `test_invoke` with the given
    /// status and optional result / error details.
    fn make_invoke_op(
        id: &str,
        status: OperationStatus,
        result: Option<&str>,
        error: Option<ErrorObject>,
    ) -> Operation {
        let details = ChainedInvokeDetails::builder()
            .set_result(result.map(str::to_string))
            .set_error(error)
            .build();

        Operation::builder()
            .id(id)
            .r#type(OperationType::ChainedInvoke)
            .status(status)
            .name("test_invoke")
            .start_timestamp(DateTime::from_secs(0))
            .chained_invoke_details(details)
            .build()
            .unwrap()
    }

    /// Construct a context for ARN `arn:test` over `backend`, starting from
    /// `token` and the given operation history.
    async fn context_with(
        backend: InvokeMockBackend,
        token: &str,
        history: Vec<Operation>,
    ) -> DurableContext {
        DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            token.to_string(),
            history,
            None,
        )
        .await
        .unwrap()
    }

    // ─── invoke tests ────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_invoke_sends_start_checkpoint_and_suspends() {
        // The backend reports no completed operation, so the target counts as
        // still running and invoke must suspend.
        let (backend, calls) = InvokeMockBackend::new("new-token", None);
        let mut ctx = context_with(backend, "initial-token", vec![]).await;

        let outcome = ctx
            .invoke::<String, _>(
                "call_processor",
                "target-lambda",
                &serde_json::json!({"id": 42}),
            )
            .await;

        // The error must be the suspension signal and mention the operation.
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("invoke suspended"), "error: {msg}");
        assert!(msg.contains("call_processor"), "error: {msg}");

        // Exactly one checkpoint — the START — must have been recorded.
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");

        let update = &captured[0].updates[0];
        assert_eq!(update.r#type(), &OperationType::ChainedInvoke);
        assert_eq!(update.action(), &OperationAction::Start);
        assert_eq!(update.name(), Some("call_processor"));
        assert_eq!(update.sub_type(), Some("ChainedInvoke"));

        // The serialized payload travels with the update.
        let payload = update.payload().expect("should have payload");
        assert!(
            payload.contains("42"),
            "payload should contain id: {payload}"
        );

        // The target function is carried in ChainedInvokeOptions.
        let options = update
            .chained_invoke_options()
            .expect("should have chained_invoke_options");
        assert_eq!(options.function_name(), "target-lambda");
    }

    #[tokio::test]
    async fn test_invoke_replays_succeeded_result() {
        let history = vec![make_invoke_op(
            &first_op_id(),
            OperationStatus::Succeeded,
            Some(r#"{"status":"processed","amount":100}"#),
            None,
        )];

        let (backend, calls) = InvokeMockBackend::new("token", None);
        let mut ctx = context_with(backend, "tok", history).await;

        let value: serde_json::Value = ctx
            .invoke("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(value["status"], "processed");
        assert_eq!(value["amount"], 100);

        // Replaying a recorded result must not touch the backend.
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_invoke_returns_error_on_failed() {
        let failure = ErrorObject::builder()
            .error_type("TargetError")
            .error_data("target function crashed")
            .build();
        let history = vec![make_invoke_op(
            &first_op_id(),
            OperationStatus::Failed,
            None,
            Some(failure),
        )];

        let (backend, _) = InvokeMockBackend::new("token", None);
        let mut ctx = context_with(backend, "tok", history).await;

        let msg = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err()
            .to_string();

        assert!(msg.contains("invoke failed"), "error: {msg}");
        assert!(msg.contains("TargetError"), "error: {msg}");
        assert!(msg.contains("target function crashed"), "error: {msg}");
    }

    #[tokio::test]
    async fn test_invoke_suspends_on_started() {
        // History holds the operation in STARTED status — target still running.
        let history = vec![make_invoke_op(
            &first_op_id(),
            OperationStatus::Started,
            None,
            None,
        )];

        let (backend, _) = InvokeMockBackend::new("token", None);
        let mut ctx = context_with(backend, "tok", history).await;

        let msg = ctx
            .invoke::<String, _>("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap_err()
            .to_string();

        assert!(msg.contains("invoke suspended"), "error: {msg}");
    }

    #[tokio::test]
    async fn test_invoke_double_check_immediate_completion() {
        // The checkpoint response itself reports the operation as SUCCEEDED.
        let completed = make_invoke_op(
            &first_op_id(),
            OperationStatus::Succeeded,
            Some(r#""instant-result""#),
            None,
        );

        let (backend, calls) = InvokeMockBackend::new("new-token", Some(completed));
        let mut ctx = context_with(backend, "tok", vec![]).await;

        // The double-check after the checkpoint sees the completion, so the
        // call returns Ok instead of suspending.
        let value: String = ctx
            .invoke("call_processor", "target-lambda", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(value, "instant-result");

        // The START checkpoint is still sent exactly once.
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "START checkpoint sent");
    }

    // ─── span tests (FEAT-17) ─────────────────────────────────────────────

    #[traced_test]
    #[tokio::test]
    async fn test_invoke_emits_span() {
        let (backend, _calls) = InvokeMockBackend::new("tok", None);
        let mut ctx = context_with(backend, "tok", vec![]).await;

        // The call itself ends in InvokeSuspended, which is expected here; only
        // the emitted tracing output matters.
        let _ = ctx
            .invoke::<serde_json::Value, _>("target", "my-lambda", &serde_json::json!({}))
            .await;

        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("target"));
        assert!(logs_contain("invoke"));
    }
}