Skip to main content

durable_execution_sdk/handlers/
invoke.rs

1//! Invoke operation handler for the AWS Durable Execution SDK.
2//!
3//! This module implements the invoke handler which calls other
4//! durable Lambda functions from within a workflow.
5
6use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::config::InvokeConfig;
11use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
12use crate::error::{DurableError, ErrorObject, TerminationReason};
13use crate::operation::{OperationType, OperationUpdate};
14use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
15use crate::state::ExecutionState;
16
17/// Invokes another durable Lambda function.
18///
19/// This handler implements the invoke semantics:
20/// - Calls the target Lambda function via service client
21/// - Handles timeout configuration
22/// - Checkpoints invocation and result
23/// - Propagates errors from invoked function
24///
25/// # Arguments
26///
27/// * `function_name` - The name or ARN of the Lambda function to invoke
28/// * `payload` - The payload to send to the function
29/// * `state` - The execution state for checkpointing
30/// * `op_id` - The operation identifier
31/// * `config` - Invoke configuration (timeout, serdes)
32/// * `logger` - Logger for structured logging
33///
34/// # Returns
35///
36/// The result from the invoked function, or an error if invocation fails.
37pub async fn invoke_handler<P, R>(
38    function_name: &str,
39    payload: P,
40    state: &Arc<ExecutionState>,
41    op_id: &OperationIdentifier,
42    config: &InvokeConfig<P, R>,
43    logger: &Arc<dyn Logger>,
44) -> Result<R, DurableError>
45where
46    P: Serialize + DeserializeOwned + Send,
47    R: Serialize + DeserializeOwned + Send,
48{
49    // Create tracing span for this operation
50    // Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 3.6
51    let span = create_operation_span("invoke", op_id, state.durable_execution_arn());
52    let _guard = span.enter();
53
54    let mut log_info =
55        LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
56    if let Some(ref parent_id) = op_id.parent_id {
57        log_info = log_info.with_parent_id(parent_id);
58    }
59
60    logger.debug(
61        &format!("Starting invoke operation: {} -> {}", op_id, function_name),
62        &log_info,
63    );
64
65    // Check for existing checkpoint (replay)
66    let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
67
68    if checkpoint_result.is_existent() {
69        // Check for non-deterministic execution
70        if let Some(op_type) = checkpoint_result.operation_type() {
71            if op_type != OperationType::Invoke {
72                span.record("status", "non_deterministic");
73                return Err(DurableError::NonDeterministic {
74                    message: format!(
75                        "Expected Invoke operation but found {:?} at operation_id {}",
76                        op_type, op_id.operation_id
77                    ),
78                    operation_id: Some(op_id.operation_id.clone()),
79                });
80            }
81        }
82
83        // Handle succeeded checkpoint
84        if checkpoint_result.is_succeeded() {
85            logger.debug(&format!("Replaying succeeded invoke: {}", op_id), &log_info);
86
87            if let Some(result_str) = checkpoint_result.result() {
88                let serdes = JsonSerDes::<R>::new();
89                let serdes_ctx =
90                    SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
91                let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
92                    DurableError::SerDes {
93                        message: format!("Failed to deserialize invoke result: {}", e),
94                    }
95                })?;
96
97                state.track_replay(&op_id.operation_id).await;
98                span.record("status", "replayed_succeeded");
99                return Ok(result);
100            }
101        }
102
103        // Handle failed checkpoint
104        if checkpoint_result.is_failed() {
105            logger.debug(&format!("Replaying failed invoke: {}", op_id), &log_info);
106
107            state.track_replay(&op_id.operation_id).await;
108            span.record("status", "replayed_failed");
109
110            if let Some(error) = checkpoint_result.error() {
111                return Err(DurableError::Invocation {
112                    message: error.error_message.clone(),
113                    termination_reason: TerminationReason::InvocationError,
114                });
115            } else {
116                return Err(DurableError::invocation("Invoke failed with unknown error"));
117            }
118        }
119
120        // Handle STOPPED status (Requirement 7.7)
121        if checkpoint_result.is_stopped() {
122            logger.debug(&format!("Replaying stopped invoke: {}", op_id), &log_info);
123
124            state.track_replay(&op_id.operation_id).await;
125            span.record("status", "replayed_stopped");
126
127            return Err(DurableError::Invocation {
128                message: "Invoke was stopped externally".to_string(),
129                termination_reason: TerminationReason::InvocationError,
130            });
131        }
132
133        // Handle other terminal states
134        if checkpoint_result.is_terminal() {
135            state.track_replay(&op_id.operation_id).await;
136            span.record("status", "replayed_terminal");
137
138            let status = checkpoint_result.status().unwrap();
139            return Err(DurableError::Invocation {
140                message: format!("Invoke was {}", status),
141                termination_reason: TerminationReason::InvocationError,
142            });
143        }
144    }
145
146    // Serialize the payload
147    let payload_serdes = JsonSerDes::<P>::new();
148    let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
149    let payload_json = payload_serdes
150        .serialize(&payload, &serdes_ctx)
151        .map_err(|e| DurableError::SerDes {
152            message: format!("Failed to serialize invoke payload: {}", e),
153        })?;
154
155    // Checkpoint the invocation start (Requirement 7.4)
156    let start_update = create_invoke_start_update(op_id, function_name, &payload_json, config);
157    state.create_checkpoint(start_update, true).await?;
158
159    logger.debug(&format!("Invoking function: {}", function_name), &log_info);
160
161    // For now, we simulate the invoke by suspending
162    // In a real implementation, this would call the Lambda service
163    // and the result would be delivered via the durable execution service
164
165    // The actual invocation is handled by the Lambda durable execution service
166    // We suspend here and wait for the result to be checkpointed
167    span.record("status", "suspended");
168    Err(DurableError::Suspend {
169        scheduled_timestamp: None,
170    })
171}
172
173/// Creates a Start operation update for invoke.
174fn create_invoke_start_update<P, R>(
175    op_id: &OperationIdentifier,
176    function_name: &str,
177    payload_json: &str,
178    config: &InvokeConfig<P, R>,
179) -> OperationUpdate {
180    let mut update = OperationUpdate::start(&op_id.operation_id, OperationType::Invoke);
181
182    // Store the payload in the result field
183    update.result = Some(payload_json.to_string());
184
185    // Set ChainedInvokeOptions with function name and optional tenant_id (Requirement 7.6)
186    update = update.with_chained_invoke_options(function_name, config.tenant_id.clone());
187
188    op_id.apply_to(update)
189}
190
191/// Creates a Succeed operation update for invoke.
192#[allow(dead_code)]
193fn create_invoke_succeed_update(
194    op_id: &OperationIdentifier,
195    result: Option<String>,
196) -> OperationUpdate {
197    op_id.apply_to(OperationUpdate::succeed(
198        &op_id.operation_id,
199        OperationType::Invoke,
200        result,
201    ))
202}
203
204/// Creates a Fail operation update for invoke.
205#[allow(dead_code)]
206fn create_invoke_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
207    op_id.apply_to(OperationUpdate::fail(
208        &op_id.operation_id,
209        OperationType::Invoke,
210        error,
211    ))
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::duration::Duration;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};

    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";
    const TEST_OP_ID: &str = "test-invoke-123";

    /// Mock client whose checkpoint calls answer with token "token-1".
    fn create_mock_client() -> SharedDurableServiceClient {
        Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
        )
    }

    /// Mock client with no canned checkpoint response (used by the replay tests).
    fn bare_client() -> SharedDurableServiceClient {
        Arc::new(MockDurableServiceClient::new())
    }

    /// Execution state for the test ARN, seeded with `initial`.
    fn build_state(
        client: SharedDurableServiceClient,
        initial: InitialExecutionState,
    ) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(TEST_ARN, "initial-token", initial, client))
    }

    /// Execution state whose history contains exactly `op`.
    fn state_with_operation(client: SharedDurableServiceClient, op: Operation) -> Arc<ExecutionState> {
        build_state(client, InitialExecutionState::with_operations(vec![op]))
    }

    /// An operation recorded under the test op id with the given type and status.
    fn recorded_op(op_type: OperationType, status: OperationStatus) -> Operation {
        let mut op = Operation::new(TEST_OP_ID, op_type);
        op.status = status;
        op
    }

    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(TEST_OP_ID, None, Some("test-invoke".to_string()))
    }

    fn create_test_config() -> InvokeConfig<String, String> {
        let mut config = InvokeConfig::default();
        config.timeout = Duration::from_minutes(5);
        config
    }

    /// Runs `invoke_handler` against `state` with the standard test inputs.
    async fn run_invoke(state: &Arc<ExecutionState>) -> Result<String, DurableError> {
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger: Arc<dyn Logger> = Arc::new(TracingLogger);

        invoke_handler(
            "target-function",
            "test-payload".to_string(),
            state,
            &op_id,
            &config,
            &logger,
        )
        .await
    }

    #[tokio::test]
    async fn test_invoke_handler_suspends_on_new_invoke() {
        let state = build_state(create_mock_client(), InitialExecutionState::new());

        // A fresh invoke is asynchronous, so the handler must suspend.
        let outcome = run_invoke(&state).await;
        assert!(
            matches!(outcome, Err(DurableError::Suspend { .. })),
            "Expected Suspend error"
        );
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_success() {
        let mut op = recorded_op(OperationType::Invoke, OperationStatus::Succeeded);
        op.result = Some(r#""invoke_result""#.to_string());
        let state = state_with_operation(bare_client(), op);

        let value = run_invoke(&state)
            .await
            .expect("replayed invoke should succeed");
        assert_eq!(value, "invoke_result");
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_failure() {
        let mut op = recorded_op(OperationType::Invoke, OperationStatus::Failed);
        op.error = Some(ErrorObject::new("InvokeError", "Target function failed"));
        let state = state_with_operation(bare_client(), op);

        match run_invoke(&state).await {
            Err(DurableError::Invocation { message, .. }) => {
                assert!(message.contains("Target function failed"));
            }
            _ => panic!("Expected Invocation error"),
        }
    }

    #[tokio::test]
    async fn test_invoke_handler_non_deterministic_detection() {
        // A Step recorded at the invoke's operation id must be rejected.
        let state = state_with_operation(
            bare_client(),
            recorded_op(OperationType::Step, OperationStatus::Succeeded),
        );

        match run_invoke(&state).await {
            Err(DurableError::NonDeterministic { operation_id, .. }) => {
                assert_eq!(operation_id.as_deref(), Some(TEST_OP_ID));
            }
            _ => panic!("Expected NonDeterministic error"),
        }
    }

    #[test]
    fn test_create_invoke_start_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-invoke".to_string()),
        );
        let mut config: InvokeConfig<String, String> = InvokeConfig::default();
        config.timeout = Duration::from_minutes(5);
        config.tenant_id = Some("tenant-123".to_string());

        let update =
            create_invoke_start_update(&op_id, "target-function", r#"{"key":"value"}"#, &config);

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Invoke);
        assert!(update.result.is_some());
        assert_eq!(update.parent_id.as_deref(), Some("parent-456"));
        assert_eq!(update.name.as_deref(), Some("my-invoke"));

        // ChainedInvokeOptions must carry the function name and tenant id (Requirement 7.6)
        let invoke_options = update
            .chained_invoke_options
            .expect("chained invoke options should be set");
        assert_eq!(invoke_options.function_name, "target-function");
        assert_eq!(invoke_options.tenant_id.as_deref(), Some("tenant-123"));
    }

    #[test]
    fn test_create_invoke_start_update_without_tenant_id() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let config: InvokeConfig<String, String> = InvokeConfig::default();

        let update =
            create_invoke_start_update(&op_id, "target-function", r#"{"key":"value"}"#, &config);

        let invoke_options = update
            .chained_invoke_options
            .expect("chained invoke options should be set");
        assert_eq!(invoke_options.function_name, "target-function");
        assert!(invoke_options.tenant_id.is_none());
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_stopped() {
        // Pre-existing stopped invoke (Requirement 7.7)
        let state = state_with_operation(
            bare_client(),
            recorded_op(OperationType::Invoke, OperationStatus::Stopped),
        );

        match run_invoke(&state).await {
            Err(DurableError::Invocation { message, .. }) => {
                assert!(message.contains("stopped externally"));
            }
            other => panic!("Expected Invocation error, got {:?}", other.unwrap_err()),
        }
    }

    #[test]
    fn test_create_invoke_succeed_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);

        let update = create_invoke_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Invoke);
        assert_eq!(update.result.as_deref(), Some("result"));
    }

    #[test]
    fn test_create_invoke_fail_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);

        let update =
            create_invoke_fail_update(&op_id, ErrorObject::new("InvokeError", "test message"));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Invoke);
        let error = update.error.expect("fail update should carry an error");
        assert_eq!(error.error_type, "InvokeError");
    }

    // Gap Tests for Invoke Handler (Task 10)
    // Requirements: 12.1, 12.2, 12.3

    /// Test for TIMED_OUT status handling (Requirement 12.1)
    /// WHEN an invoke operation times out, THE Test_Suite SHALL verify TIMED_OUT status is handled correctly
    #[tokio::test]
    async fn test_invoke_handler_replay_timed_out() {
        let state = state_with_operation(
            bare_client(),
            recorded_op(OperationType::Invoke, OperationStatus::TimedOut),
        );

        // The replay must surface an Invocation error naming the timeout.
        match run_invoke(&state).await {
            Err(DurableError::Invocation { message, .. }) => {
                assert!(
                    message.contains("TimedOut"),
                    "Expected message to contain 'TimedOut', got: {}",
                    message
                );
            }
            other => panic!("Expected Invocation error, got {:?}", other.unwrap_err()),
        }
    }

    /// Test for STOPPED status handling (Requirement 12.2)
    /// WHEN an invoke operation is stopped externally, THE Test_Suite SHALL verify STOPPED status is handled correctly
    /// Note: This test validates the explicit STOPPED handling path, including the termination reason.
    #[tokio::test]
    async fn test_invoke_handler_replay_stopped_returns_invocation_error() {
        let state = state_with_operation(
            bare_client(),
            recorded_op(OperationType::Invoke, OperationStatus::Stopped),
        );

        match run_invoke(&state).await {
            Err(DurableError::Invocation {
                message,
                termination_reason,
            }) => {
                assert!(
                    message.contains("stopped externally"),
                    "Expected message to contain 'stopped externally', got: {}",
                    message
                );
                assert_eq!(termination_reason, TerminationReason::InvocationError);
            }
            other => panic!("Expected Invocation error, got {:?}", other.unwrap_err()),
        }
    }

    /// Test for replaying STARTED invoke (Requirement 12.3)
    /// WHEN replaying a STARTED invoke, THE Test_Suite SHALL verify execution suspends
    #[tokio::test]
    async fn test_invoke_handler_replay_started_suspends() {
        // The invoke is still in progress (STARTED), so the handler must suspend.
        let state = state_with_operation(
            create_mock_client(),
            recorded_op(OperationType::Invoke, OperationStatus::Started),
        );

        match run_invoke(&state).await {
            Err(DurableError::Suspend { .. }) => {}
            other => panic!(
                "Expected Suspend error for in-progress invoke, got {:?}",
                other.unwrap_err()
            ),
        }
    }
}