1use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::config::InvokeConfig;
11use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
12use crate::error::{DurableError, ErrorObject, TerminationReason};
13use crate::operation::{OperationType, OperationUpdate};
14use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
15use crate::state::ExecutionState;
16
17pub async fn invoke_handler<P, R>(
38 function_name: &str,
39 payload: P,
40 state: &Arc<ExecutionState>,
41 op_id: &OperationIdentifier,
42 config: &InvokeConfig<P, R>,
43 logger: &Arc<dyn Logger>,
44) -> Result<R, DurableError>
45where
46 P: Serialize + DeserializeOwned + Send,
47 R: Serialize + DeserializeOwned + Send,
48{
49 let span = create_operation_span("invoke", op_id, state.durable_execution_arn());
52 let _guard = span.enter();
53
54 let mut log_info =
55 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
56 if let Some(ref parent_id) = op_id.parent_id {
57 log_info = log_info.with_parent_id(parent_id);
58 }
59
60 logger.debug(
61 &format!("Starting invoke operation: {} -> {}", op_id, function_name),
62 &log_info,
63 );
64
65 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
67
68 if checkpoint_result.is_existent() {
69 if let Some(op_type) = checkpoint_result.operation_type() {
71 if op_type != OperationType::Invoke {
72 span.record("status", "non_deterministic");
73 return Err(DurableError::NonDeterministic {
74 message: format!(
75 "Expected Invoke operation but found {:?} at operation_id {}",
76 op_type, op_id.operation_id
77 ),
78 operation_id: Some(op_id.operation_id.clone()),
79 });
80 }
81 }
82
83 if checkpoint_result.is_succeeded() {
85 logger.debug(&format!("Replaying succeeded invoke: {}", op_id), &log_info);
86
87 if let Some(result_str) = checkpoint_result.result() {
88 let serdes = JsonSerDes::<R>::new();
89 let serdes_ctx =
90 SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
91 let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
92 DurableError::SerDes {
93 message: format!("Failed to deserialize invoke result: {}", e),
94 }
95 })?;
96
97 state.track_replay(&op_id.operation_id).await;
98 span.record("status", "replayed_succeeded");
99 return Ok(result);
100 }
101 }
102
103 if checkpoint_result.is_failed() {
105 logger.debug(&format!("Replaying failed invoke: {}", op_id), &log_info);
106
107 state.track_replay(&op_id.operation_id).await;
108 span.record("status", "replayed_failed");
109
110 if let Some(error) = checkpoint_result.error() {
111 return Err(DurableError::Invocation {
112 message: error.error_message.clone(),
113 termination_reason: TerminationReason::InvocationError,
114 });
115 } else {
116 return Err(DurableError::invocation("Invoke failed with unknown error"));
117 }
118 }
119
120 if checkpoint_result.is_stopped() {
122 logger.debug(&format!("Replaying stopped invoke: {}", op_id), &log_info);
123
124 state.track_replay(&op_id.operation_id).await;
125 span.record("status", "replayed_stopped");
126
127 return Err(DurableError::Invocation {
128 message: "Invoke was stopped externally".to_string(),
129 termination_reason: TerminationReason::InvocationError,
130 });
131 }
132
133 if checkpoint_result.is_terminal() {
135 state.track_replay(&op_id.operation_id).await;
136 span.record("status", "replayed_terminal");
137
138 let status = checkpoint_result.status().unwrap();
139 return Err(DurableError::Invocation {
140 message: format!("Invoke was {}", status),
141 termination_reason: TerminationReason::InvocationError,
142 });
143 }
144 }
145
146 let payload_serdes = JsonSerDes::<P>::new();
148 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
149 let payload_json = payload_serdes
150 .serialize(&payload, &serdes_ctx)
151 .map_err(|e| DurableError::SerDes {
152 message: format!("Failed to serialize invoke payload: {}", e),
153 })?;
154
155 let start_update = create_invoke_start_update(op_id, function_name, &payload_json, config);
157 state.create_checkpoint(start_update, true).await?;
158
159 logger.debug(&format!("Invoking function: {}", function_name), &log_info);
160
161 span.record("status", "suspended");
168 Err(DurableError::Suspend {
169 scheduled_timestamp: None,
170 })
171}
172
173fn create_invoke_start_update<P, R>(
175 op_id: &OperationIdentifier,
176 function_name: &str,
177 payload_json: &str,
178 config: &InvokeConfig<P, R>,
179) -> OperationUpdate {
180 let mut update = OperationUpdate::start(&op_id.operation_id, OperationType::Invoke);
181
182 update.result = Some(payload_json.to_string());
184
185 update = update.with_chained_invoke_options(function_name, config.tenant_id.clone());
187
188 op_id.apply_to(update)
189}
190
191#[allow(dead_code)]
193fn create_invoke_succeed_update(
194 op_id: &OperationIdentifier,
195 result: Option<String>,
196) -> OperationUpdate {
197 op_id.apply_to(OperationUpdate::succeed(
198 &op_id.operation_id,
199 OperationType::Invoke,
200 result,
201 ))
202}
203
204#[allow(dead_code)]
206fn create_invoke_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
207 op_id.apply_to(OperationUpdate::fail(
208 &op_id.operation_id,
209 OperationType::Invoke,
210 error,
211 ))
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::duration::Duration;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};

    /// Durable execution ARN used by every test state.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    /// Operation id shared by `create_test_op_id` and the seeded history operations.
    const TEST_OP_ID: &str = "test-invoke-123";

    /// Mock client primed with a successful checkpoint response (`token-1`).
    fn create_mock_client() -> SharedDurableServiceClient {
        let primed = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")));
        Arc::new(primed)
    }

    /// Wraps `initial` in an `ExecutionState` backed by `client`.
    fn build_state(
        initial: InitialExecutionState,
        client: SharedDurableServiceClient,
    ) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(TEST_ARN, "initial-token", initial, client))
    }

    /// Execution state with an empty operation history.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        build_state(InitialExecutionState::new(), client)
    }

    /// Execution state whose history already holds `op`; the mock has no canned responses.
    fn create_replay_state(op: Operation) -> Arc<ExecutionState> {
        build_state(
            InitialExecutionState::with_operations(vec![op]),
            Arc::new(MockDurableServiceClient::new()),
        )
    }

    /// A history operation at `TEST_OP_ID` with the given type and status.
    fn seeded_operation(op_type: OperationType, status: OperationStatus) -> Operation {
        let mut op = Operation::new(TEST_OP_ID, op_type);
        op.status = status;
        op
    }

    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(TEST_OP_ID, None, Some("test-invoke".to_string()))
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    fn create_test_config() -> InvokeConfig<String, String> {
        let mut config = InvokeConfig::default();
        config.timeout = Duration::from_minutes(5);
        config
    }

    /// Runs `invoke_handler` against `state` with the standard test inputs.
    async fn run_invoke(state: &Arc<ExecutionState>) -> Result<String, DurableError> {
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        invoke_handler(
            "target-function",
            "test-payload".to_string(),
            state,
            &op_id,
            &config,
            &logger,
        )
        .await
    }

    #[tokio::test]
    async fn test_invoke_handler_suspends_on_new_invoke() {
        let state = create_test_state(create_mock_client());

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), DurableError::Suspend { .. }),
            "Expected Suspend error"
        );
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_success() {
        let mut op = seeded_operation(OperationType::Invoke, OperationStatus::Succeeded);
        op.result = Some(r#""invoke_result""#.to_string());
        let state = create_replay_state(op);

        let result = run_invoke(&state).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "invoke_result");
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_failure() {
        let mut op = seeded_operation(OperationType::Invoke, OperationStatus::Failed);
        op.error = Some(ErrorObject::new("InvokeError", "Target function failed"));
        let state = create_replay_state(op);

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        if let DurableError::Invocation { message, .. } = result.unwrap_err() {
            assert!(message.contains("Target function failed"));
        } else {
            panic!("Expected Invocation error");
        }
    }

    #[tokio::test]
    async fn test_invoke_handler_non_deterministic_detection() {
        // History recorded a Step where the code now issues an Invoke.
        let state = create_replay_state(seeded_operation(
            OperationType::Step,
            OperationStatus::Succeeded,
        ));

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        if let DurableError::NonDeterministic { operation_id, .. } = result.unwrap_err() {
            assert_eq!(operation_id, Some(TEST_OP_ID.to_string()));
        } else {
            panic!("Expected NonDeterministic error");
        }
    }

    #[test]
    fn test_create_invoke_start_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-invoke".to_string()),
        );
        let mut config = create_test_config();
        config.tenant_id = Some("tenant-123".to_string());

        let update =
            create_invoke_start_update(&op_id, "target-function", r#"{"key":"value"}"#, &config);

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Invoke);
        assert!(update.result.is_some());
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-invoke".to_string()));

        assert!(update.chained_invoke_options.is_some());
        let options = update.chained_invoke_options.unwrap();
        assert_eq!(options.function_name, "target-function");
        assert_eq!(options.tenant_id, Some("tenant-123".to_string()));
    }

    #[test]
    fn test_create_invoke_start_update_without_tenant_id() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let config: InvokeConfig<String, String> = InvokeConfig::default();

        let update =
            create_invoke_start_update(&op_id, "target-function", r#"{"key":"value"}"#, &config);

        assert!(update.chained_invoke_options.is_some());
        let options = update.chained_invoke_options.unwrap();
        assert_eq!(options.function_name, "target-function");
        assert!(options.tenant_id.is_none());
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_stopped() {
        let state = create_replay_state(seeded_operation(
            OperationType::Invoke,
            OperationStatus::Stopped,
        ));

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Invocation { message, .. } => {
                assert!(message.contains("stopped externally"));
            }
            other => panic!("Expected Invocation error, got {:?}", other),
        }
    }

    #[test]
    fn test_create_invoke_succeed_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);

        let update = create_invoke_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Invoke);
        assert_eq!(update.result, Some("result".to_string()));
    }

    #[test]
    fn test_create_invoke_fail_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);

        let update =
            create_invoke_fail_update(&op_id, ErrorObject::new("InvokeError", "test message"));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Invoke);
        assert!(update.error.is_some());
        assert_eq!(update.error.unwrap().error_type, "InvokeError");
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_timed_out() {
        let state = create_replay_state(seeded_operation(
            OperationType::Invoke,
            OperationStatus::TimedOut,
        ));

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Invocation { message, .. } => {
                assert!(
                    message.contains("TimedOut"),
                    "Expected message to contain 'TimedOut', got: {}",
                    message
                );
            }
            other => panic!("Expected Invocation error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_stopped_returns_invocation_error() {
        let state = create_replay_state(seeded_operation(
            OperationType::Invoke,
            OperationStatus::Stopped,
        ));

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Invocation {
                message,
                termination_reason,
            } => {
                assert!(
                    message.contains("stopped externally"),
                    "Expected message to contain 'stopped externally', got: {}",
                    message
                );
                assert_eq!(termination_reason, TerminationReason::InvocationError);
            }
            other => panic!("Expected Invocation error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_invoke_handler_replay_started_suspends() {
        // An invoke already in flight: the handler must report Suspend.
        let history = InitialExecutionState::with_operations(vec![seeded_operation(
            OperationType::Invoke,
            OperationStatus::Started,
        )]);
        let state = build_state(history, create_mock_client());

        let result = run_invoke(&state).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::Suspend { .. } => {}
            other => panic!("Expected Suspend error for in-progress invoke, got {:?}", other),
        }
    }
}