1use std::marker::PhantomData;
7use std::sync::Arc;
8
9use serde::de::DeserializeOwned;
10
11use crate::config::CallbackConfig;
12use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
13use crate::error::DurableError;
14use crate::operation::{OperationType, OperationUpdate};
15use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
16use crate::state::ExecutionState;
17use crate::types::CallbackId;
18
/// Handle to a durable callback operation.
///
/// Pairs the callback ID with the checkpoint operation ID under which the callback
/// was recorded. [`Callback::result`] looks that operation up in the execution state
/// to report the callback's outcome, deserialized into `T`.
pub struct Callback<T> {
    /// The callback ID (also returned by `id()`). See `callback_handler` for how it
    /// is obtained on first execution and on replay.
    pub callback_id: String,
    /// Checkpoint operation ID under which this callback is recorded.
    operation_id: String,
    /// Shared execution state used to read the checkpointed outcome.
    state: Arc<ExecutionState>,
    /// Destination for debug logging.
    logger: Arc<dyn Logger>,
    /// Ties `T` to the handle without storing a `T`. Using `fn() -> T` (rather than
    /// `T`) keeps the marker `Send + Sync` whatever `T` is.
    _marker: PhantomData<fn() -> T>,
}

// Compile-time check that `Callback` is `Send + Sync` even when `T` is neither
// (`*mut ()` is the canonical `!Send + !Sync` type), i.e. the marker above does not
// leak `T`'s auto traits into the handle.
static_assertions::assert_impl_all!(Callback<*mut ()>: Send, Sync);
41
42impl<T> std::fmt::Debug for Callback<T> {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("Callback")
45 .field("callback_id", &self.callback_id)
46 .field("operation_id", &self.operation_id)
47 .finish_non_exhaustive()
48 }
49}
50
51impl<T> Callback<T>
52where
53 T: serde::Serialize + DeserializeOwned,
54{
55 pub(crate) fn new(
57 callback_id: impl Into<String>,
58 operation_id: impl Into<String>,
59 state: Arc<ExecutionState>,
60 logger: Arc<dyn Logger>,
61 ) -> Self {
62 Self {
63 callback_id: callback_id.into(),
64 operation_id: operation_id.into(),
65 state,
66 logger,
67 _marker: PhantomData,
68 }
69 }
70
71 pub fn id(&self) -> &str {
73 &self.callback_id
74 }
75
76 #[inline]
78 pub fn id_typed(&self) -> CallbackId {
79 CallbackId::from(self.callback_id.clone())
80 }
81
82 pub async fn result(&self) -> Result<T, DurableError> {
92 let log_info = LogInfo::new(self.state.durable_execution_arn())
93 .with_operation_id(&self.operation_id)
94 .with_extra("callback_id", &self.callback_id);
95
96 self.logger.debug(
97 &format!("Checking callback result: {}", self.callback_id),
98 &log_info,
99 );
100
101 let checkpoint_result = self.state.get_checkpoint_result(&self.operation_id).await;
103
104 if !checkpoint_result.is_existent() {
105 return Err(DurableError::Callback {
107 message: "Callback not found".to_string(),
108 callback_id: Some(self.callback_id.clone()),
109 });
110 }
111
112 if checkpoint_result.is_succeeded() {
114 if let Some(result_str) = checkpoint_result.result() {
115 let serdes = JsonSerDes::<T>::new();
116 let serdes_ctx =
117 SerDesContext::new(&self.operation_id, self.state.durable_execution_arn());
118
119 let result = serdes.deserialize(result_str, &serdes_ctx).map_err(|e| {
120 DurableError::SerDes {
121 message: format!("Failed to deserialize callback result: {}", e),
122 }
123 })?;
124
125 self.logger
126 .debug("Callback completed successfully", &log_info);
127 self.state.track_replay(&self.operation_id).await;
128
129 return Ok(result);
130 }
131 }
132
133 if checkpoint_result.is_failed() {
135 self.state.track_replay(&self.operation_id).await;
136
137 if let Some(error) = checkpoint_result.error() {
138 return Err(DurableError::Callback {
139 message: error.error_message.clone(),
140 callback_id: Some(self.callback_id.clone()),
141 });
142 } else {
143 return Err(DurableError::Callback {
144 message: "Callback failed with unknown error".to_string(),
145 callback_id: Some(self.callback_id.clone()),
146 });
147 }
148 }
149
150 if checkpoint_result.is_timed_out() {
152 self.state.track_replay(&self.operation_id).await;
153
154 return Err(DurableError::Callback {
155 message: "Callback timed out".to_string(),
156 callback_id: Some(self.callback_id.clone()),
157 });
158 }
159
160 self.logger
162 .debug("Callback pending, suspending execution", &log_info);
163
164 Err(DurableError::Suspend {
165 scheduled_timestamp: None,
166 })
167 }
168}
169
170pub async fn callback_handler<T>(
188 state: &Arc<ExecutionState>,
189 op_id: &OperationIdentifier,
190 config: &CallbackConfig,
191 logger: &Arc<dyn Logger>,
192) -> Result<Callback<T>, DurableError>
193where
194 T: serde::Serialize + DeserializeOwned,
195{
196 let span = create_operation_span("callback", op_id, state.durable_execution_arn());
199 let _guard = span.enter();
200
201 let mut log_info =
202 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
203 if let Some(ref parent_id) = op_id.parent_id {
204 log_info = log_info.with_parent_id(parent_id);
205 }
206
207 logger.debug(&format!("Creating callback: {}", op_id), &log_info);
208
209 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
211
212 if checkpoint_result.is_existent() {
213 if let Some(op_type) = checkpoint_result.operation_type() {
215 if op_type != OperationType::Callback {
216 span.record("status", "non_deterministic");
217 return Err(DurableError::NonDeterministic {
218 message: format!(
219 "Expected Callback operation but found {:?} at operation_id {}",
220 op_type, op_id.operation_id
221 ),
222 operation_id: Some(op_id.operation_id.clone()),
223 });
224 }
225 }
226
227 let callback_id = checkpoint_result.callback_id().unwrap_or_else(|| {
230 op_id
231 .name
232 .clone()
233 .unwrap_or_else(|| op_id.operation_id.clone())
234 });
235
236 logger.debug(
237 &format!("Returning existing callback: {}", callback_id),
238 &log_info,
239 );
240 span.record("status", "replayed");
241
242 return Ok(Callback::new(
243 callback_id,
244 &op_id.operation_id,
245 state.clone(),
246 logger.clone(),
247 ));
248 }
249
250 let start_update = create_callback_start_update(op_id, config);
252
253 let response = state.create_checkpoint_with_response(start_update).await?;
255
256 let callback_id = response
259 .new_execution_state
260 .as_ref()
261 .and_then(|new_state| new_state.find_operation(&op_id.operation_id))
262 .and_then(|op| op.callback_details.as_ref())
263 .and_then(|details| details.callback_id.clone())
264 .ok_or_else(|| {
265 span.record("status", "failed");
266 DurableError::Callback {
267 message: format!(
268 "Service did not return callback_id in checkpoint response for operation {}",
269 op_id.operation_id
270 ),
271 callback_id: None,
272 }
273 })?;
274
275 logger.debug(
276 &format!("Callback created with ID: {}", callback_id),
277 &log_info,
278 );
279 span.record("status", "created");
280
281 Ok(Callback::new(
282 callback_id,
283 &op_id.operation_id,
284 state.clone(),
285 logger.clone(),
286 ))
287}
288
289fn create_callback_start_update(
291 op_id: &OperationIdentifier,
292 config: &CallbackConfig,
293) -> OperationUpdate {
294 let mut update = OperationUpdate::start(&op_id.operation_id, OperationType::Callback);
295
296 update.callback_options = Some(crate::operation::CallbackOptions {
298 timeout_seconds: Some(config.timeout.to_seconds()),
299 heartbeat_timeout_seconds: Some(config.heartbeat_timeout.to_seconds()),
300 });
301
302 op_id.apply_to(update)
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{
        CheckpointResponse, MockDurableServiceClient, NewExecutionState, SharedDurableServiceClient,
    };
    use crate::context::TracingLogger;
    use crate::duration::Duration;
    use crate::error::ErrorObject;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{CallbackDetails, Operation, OperationStatus};

    /// Durable execution ARN shared by every test state.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";
    /// Operation ID of the default test `OperationIdentifier`.
    const TEST_OP_ID: &str = "test-callback-123";

    // ----- fixtures -------------------------------------------------------

    /// A `Callback` operation whose details carry only `callback_id`.
    fn callback_op_with_id(operation_id: &str, callback_id: &str) -> Operation {
        let mut op = Operation::new(operation_id, OperationType::Callback);
        op.callback_details = Some(CallbackDetails {
            callback_id: Some(callback_id.to_string()),
            result: None,
            error: None,
        });
        op
    }

    /// An operation at `TEST_OP_ID` with the given type and status.
    fn op_with_status(operation_type: OperationType, status: OperationStatus) -> Operation {
        let mut op = Operation::new(TEST_OP_ID, operation_type);
        op.status = status;
        op
    }

    /// A checkpoint response whose new execution state lists exactly `operations`.
    fn response_with(operations: Vec<Operation>) -> CheckpointResponse {
        CheckpointResponse {
            checkpoint_token: "token-1".to_string(),
            new_execution_state: Some(NewExecutionState {
                operations,
                next_marker: None,
            }),
        }
    }

    /// A mock client that answers its checkpoint call with `response`.
    fn client_returning(response: CheckpointResponse) -> SharedDurableServiceClient {
        Arc::new(MockDurableServiceClient::new().with_checkpoint_response(Ok(response)))
    }

    fn create_mock_client() -> SharedDurableServiceClient {
        client_returning(response_with(vec![callback_op_with_id(
            TEST_OP_ID,
            "service-generated-callback-id",
        )]))
    }

    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "initial-token",
            InitialExecutionState::new(),
            client,
        ))
    }

    /// A state replaying `operations`, backed by a mock client with no canned responses.
    fn replay_state(operations: Vec<Operation>) -> Arc<ExecutionState> {
        let client: SharedDurableServiceClient = Arc::new(MockDurableServiceClient::new());
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "initial-token",
            InitialExecutionState::with_operations(operations),
            client,
        ))
    }

    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(TEST_OP_ID, None, Some("test-callback".to_string()))
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    fn create_test_config() -> CallbackConfig {
        CallbackConfig {
            timeout: Duration::from_hours(24),
            heartbeat_timeout: Duration::from_minutes(5),
            ..Default::default()
        }
    }

    /// Runs `callback_handler` with the default op id, config and logger.
    async fn run_handler(state: &Arc<ExecutionState>) -> Result<Callback<String>, DurableError> {
        let op_id = create_test_op_id();
        let config = create_test_config();
        let logger = create_test_logger();
        callback_handler(state, &op_id, &config, &logger).await
    }

    /// A `Callback<String>` handle over `state` for the operation at `TEST_OP_ID`.
    fn callback_for(state: Arc<ExecutionState>) -> Callback<String> {
        Callback::new(TEST_OP_ID, TEST_OP_ID, state, create_test_logger())
    }

    /// Unwraps a `DurableError::Callback` into `(message, callback_id)`, panicking on
    /// any other outcome.
    fn expect_callback_error<V: std::fmt::Debug>(
        result: Result<V, DurableError>,
    ) -> (String, Option<String>) {
        match result {
            Err(DurableError::Callback {
                message,
                callback_id,
            }) => (message, callback_id),
            other => panic!("Expected Callback error, got {:?}", other),
        }
    }

    // ----- callback_handler: first execution ------------------------------

    #[tokio::test]
    async fn test_callback_handler_creates_callback() {
        let state = create_test_state(create_mock_client());

        let callback = run_handler(&state)
            .await
            .expect("callback should be created");

        assert_eq!(callback.id(), "service-generated-callback-id");
    }

    #[tokio::test]
    async fn test_callback_handler_error_when_no_callback_id_returned() {
        let state = create_test_state(client_returning(CheckpointResponse {
            checkpoint_token: "token-1".to_string(),
            new_execution_state: None,
        }));

        let (message, callback_id) = expect_callback_error(run_handler(&state).await);

        assert!(message.contains("did not return callback_id"));
        assert!(callback_id.is_none());
    }

    #[tokio::test]
    async fn test_callback_handler_error_when_operation_not_in_response() {
        // The response only describes some other operation.
        let state = create_test_state(client_returning(response_with(vec![
            callback_op_with_id("different-operation-id", "some-callback-id"),
        ])));

        let (message, callback_id) = expect_callback_error(run_handler(&state).await);

        assert!(message.contains("did not return callback_id"));
        assert!(callback_id.is_none());
    }

    #[tokio::test]
    async fn test_callback_handler_error_when_callback_details_missing() {
        let state = create_test_state(client_returning(response_with(vec![Operation::new(
            TEST_OP_ID,
            OperationType::Callback,
        )])));

        let (message, callback_id) = expect_callback_error(run_handler(&state).await);

        assert!(message.contains("did not return callback_id"));
        assert!(callback_id.is_none());
    }

    // ----- callback_handler: replay ---------------------------------------

    #[tokio::test]
    async fn test_callback_handler_replay_existing() {
        // No callback id is checkpointed, so the handler falls back to the op name.
        let mut op = op_with_status(OperationType::Callback, OperationStatus::Started);
        op.name = Some("test-callback".to_string());
        let state = replay_state(vec![op]);

        let callback = run_handler(&state)
            .await
            .expect("replayed callback should be returned");

        assert_eq!(callback.id(), "test-callback");
    }

    #[tokio::test]
    async fn test_callback_handler_non_deterministic_detection() {
        let state = replay_state(vec![op_with_status(
            OperationType::Step,
            OperationStatus::Succeeded,
        )]);

        match run_handler(&state).await {
            Err(DurableError::NonDeterministic { operation_id, .. }) => {
                assert_eq!(operation_id, Some(TEST_OP_ID.to_string()));
            }
            _ => panic!("Expected NonDeterministic error"),
        }
    }

    // ----- Callback::result -----------------------------------------------

    #[tokio::test]
    async fn test_callback_result_pending() {
        let state = replay_state(vec![op_with_status(
            OperationType::Callback,
            OperationStatus::Started,
        )]);

        let outcome = callback_for(state).result().await;

        assert!(
            matches!(outcome, Err(DurableError::Suspend { .. })),
            "Expected Suspend error"
        );
    }

    #[tokio::test]
    async fn test_callback_result_succeeded() {
        let mut op = op_with_status(OperationType::Callback, OperationStatus::Succeeded);
        op.result = Some(r#""callback_result""#.to_string());
        let state = replay_state(vec![op]);

        let value = callback_for(state)
            .result()
            .await
            .expect("callback should have succeeded");

        assert_eq!(value, "callback_result");
    }

    #[tokio::test]
    async fn test_callback_result_failed() {
        let mut op = op_with_status(OperationType::Callback, OperationStatus::Failed);
        op.error = Some(ErrorObject::new("CallbackError", "External system failed"));
        let state = replay_state(vec![op]);

        let (message, callback_id) = expect_callback_error(callback_for(state).result().await);

        assert!(message.contains("External system failed"));
        assert_eq!(callback_id, Some(TEST_OP_ID.to_string()));
    }

    #[tokio::test]
    async fn test_callback_result_timed_out() {
        let state = replay_state(vec![op_with_status(
            OperationType::Callback,
            OperationStatus::TimedOut,
        )]);

        let (message, _) = expect_callback_error(callback_for(state).result().await);

        assert!(message.contains("timed out"));
    }

    // ----- replay of terminal callbacks -----------------------------------

    #[tokio::test]
    async fn test_callback_handler_replay_failed_returns_callback_id() {
        let mut op = op_with_status(OperationType::Callback, OperationStatus::Failed);
        op.name = Some("test-callback".to_string());
        op.callback_details = Some(CallbackDetails {
            callback_id: Some("failed-callback-id".to_string()),
            result: None,
            error: Some(ErrorObject::new("CallbackError", "External system failed")),
        });
        op.error = Some(ErrorObject::new("CallbackError", "External system failed"));
        let state = replay_state(vec![op]);

        // The handler still hands back a callback; the failure surfaces from `result()`.
        let callback = run_handler(&state)
            .await
            .expect("Handler should return Ok with callback_id for FAILED callback");
        assert_eq!(callback.id(), "failed-callback-id");

        let (message, callback_id) = expect_callback_error(callback.result().await);
        assert!(message.contains("External system failed"));
        assert_eq!(callback_id, Some("failed-callback-id".to_string()));
    }

    #[tokio::test]
    async fn test_callback_handler_replay_timed_out_returns_callback_id() {
        let mut op = callback_op_with_id(TEST_OP_ID, "timed-out-callback-id");
        op.status = OperationStatus::TimedOut;
        op.name = Some("test-callback".to_string());
        let state = replay_state(vec![op]);

        let callback = run_handler(&state)
            .await
            .expect("Handler should return Ok with callback_id for TIMED_OUT callback");
        assert_eq!(callback.id(), "timed-out-callback-id");

        let (message, _) = expect_callback_error(callback.result().await);
        assert!(message.contains("timed out"));
    }

    // ----- create_callback_start_update -----------------------------------

    #[test]
    fn test_create_callback_start_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-callback".to_string()),
        );

        let update = create_callback_start_update(&op_id, &create_test_config());

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Callback);
        let callback_opts = update
            .callback_options
            .expect("callback options should be set");
        // 24 hours
        assert_eq!(callback_opts.timeout_seconds, Some(86400));
        // 5 minutes
        assert_eq!(callback_opts.heartbeat_timeout_seconds, Some(300));
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-callback".to_string()));
    }
}