1use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
13use serde::de::DeserializeOwned;
14
15use crate::context::DurableContext;
16use crate::error::DurableError;
17use crate::types::{CallbackHandle, CallbackOptions};
18
19impl DurableContext {
20 #[allow(clippy::await_holding_lock)]
56 pub async fn create_callback(
57 &mut self,
58 name: &str,
59 options: CallbackOptions,
60 ) -> Result<CallbackHandle, DurableError> {
61 let op_id = self.replay_engine_mut().generate_operation_id();
62
63 let span = tracing::info_span!(
64 "durable_operation",
65 op.name = name,
66 op.type = "callback",
67 op.id = %op_id,
68 );
69 let _guard = span.enter();
70 tracing::trace!("durable_operation");
71
72 if let Some(op) = self.replay_engine().get_operation(&op_id) {
74 let callback_id = op
75 .callback_details()
76 .and_then(|d| d.callback_id())
77 .ok_or_else(|| {
78 DurableError::checkpoint_failed(
79 name,
80 std::io::Error::new(
81 std::io::ErrorKind::InvalidData,
82 "callback_details missing callback_id in history",
83 ),
84 )
85 })?
86 .to_string();
87
88 self.replay_engine_mut().track_replay(&op_id);
89 return Ok(CallbackHandle {
90 callback_id,
91 operation_id: op_id,
92 });
93 }
94
95 let callback_opts = aws_sdk_lambda::types::CallbackOptions::builder()
97 .timeout_seconds(options.get_timeout_seconds())
98 .heartbeat_timeout_seconds(options.get_heartbeat_timeout_seconds())
99 .build();
100
101 let start_update = OperationUpdate::builder()
102 .id(op_id.clone())
103 .r#type(OperationType::Callback)
104 .action(OperationAction::Start)
105 .sub_type("Callback")
106 .name(name)
107 .callback_options(callback_opts)
108 .build()
109 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
110
111 let start_response = self
112 .backend()
113 .checkpoint(
114 self.arn(),
115 self.checkpoint_token(),
116 vec![start_update],
117 None,
118 )
119 .await?;
120
121 let new_token = start_response.checkpoint_token().ok_or_else(|| {
122 DurableError::checkpoint_failed(
123 name,
124 std::io::Error::new(
125 std::io::ErrorKind::InvalidData,
126 "checkpoint response missing checkpoint_token",
127 ),
128 )
129 })?;
130 self.set_checkpoint_token(new_token.to_string());
131
132 if let Some(new_state) = start_response.new_execution_state() {
134 for op in new_state.operations() {
135 self.replay_engine_mut()
136 .insert_operation(op.id().to_string(), op.clone());
137 }
138 }
139
140 let callback_id = self
142 .replay_engine()
143 .get_operation(&op_id)
144 .and_then(|op| op.callback_details())
145 .and_then(|d| d.callback_id())
146 .ok_or_else(|| {
147 DurableError::checkpoint_failed(
148 name,
149 std::io::Error::new(
150 std::io::ErrorKind::InvalidData,
151 "no callback_id in checkpoint response",
152 ),
153 )
154 })?
155 .to_string();
156
157 self.replay_engine_mut().track_replay(&op_id);
158
159 Ok(CallbackHandle {
160 callback_id,
161 operation_id: op_id,
162 })
163 }
164
165 pub fn callback_result<T: DeserializeOwned>(
201 &self,
202 handle: &CallbackHandle,
203 ) -> Result<T, DurableError> {
204 let Some(op) = self.replay_engine().get_operation(&handle.operation_id) else {
205 return Err(DurableError::callback_suspended(
208 "unknown",
209 &handle.callback_id,
210 ));
211 };
212
213 match &op.status {
214 OperationStatus::Succeeded => {
215 let result_str =
216 op.callback_details()
217 .and_then(|d| d.result())
218 .ok_or_else(|| {
219 DurableError::checkpoint_failed(
220 op.name().unwrap_or("callback"),
221 std::io::Error::new(
222 std::io::ErrorKind::InvalidData,
223 "callback succeeded but no result in callback_details",
224 ),
225 )
226 })?;
227
228 serde_json::from_str(result_str)
229 .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))
230 }
231 OperationStatus::Failed
232 | OperationStatus::Cancelled
233 | OperationStatus::TimedOut
234 | OperationStatus::Stopped => {
235 let error_message = op
236 .callback_details()
237 .and_then(|d| d.error())
238 .map(|e| {
239 format!(
240 "{}: {}",
241 e.error_type().unwrap_or("Unknown"),
242 e.error_data().unwrap_or("")
243 )
244 })
245 .unwrap_or_else(|| "callback failed".to_string());
246
247 Err(DurableError::callback_failed(
248 op.name().unwrap_or("callback"),
249 &handle.callback_id,
250 error_message,
251 ))
252 }
253 _ => Err(DurableError::callback_suspended(
255 op.name().unwrap_or("callback"),
256 &handle.callback_id,
257 )),
258 }
259 }
260}
261
262#[cfg(test)]
263mod tests {
264 use std::sync::Arc;
265
266 use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
267 use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
268 use aws_sdk_lambda::types::{
269 CallbackDetails, ErrorObject, Operation, OperationAction, OperationStatus, OperationType,
270 OperationUpdate,
271 };
272 use aws_smithy_types::DateTime;
273 use tokio::sync::Mutex;
274 use tracing_test::traced_test;
275
276 use crate::backend::DurableBackend;
277 use crate::context::DurableContext;
278 use crate::error::DurableError;
279 use crate::types::CallbackOptions;
280
281 #[derive(Debug, Clone)]
282 #[allow(dead_code)]
283 struct CheckpointCall {
284 arn: String,
285 checkpoint_token: String,
286 updates: Vec<OperationUpdate>,
287 }
288
289 struct CallbackMockBackend {
291 calls: Arc<Mutex<Vec<CheckpointCall>>>,
292 checkpoint_token: String,
293 response_operation: Option<Operation>,
295 }
296
297 impl CallbackMockBackend {
298 fn new(
299 checkpoint_token: &str,
300 response_op: Operation,
301 ) -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
302 let calls = Arc::new(Mutex::new(Vec::new()));
303 let backend = Self {
304 calls: calls.clone(),
305 checkpoint_token: checkpoint_token.to_string(),
306 response_operation: Some(response_op),
307 };
308 (backend, calls)
309 }
310 }
311
312 #[async_trait::async_trait]
313 impl DurableBackend for CallbackMockBackend {
314 async fn checkpoint(
315 &self,
316 arn: &str,
317 checkpoint_token: &str,
318 updates: Vec<OperationUpdate>,
319 _client_token: Option<&str>,
320 ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
321 self.calls.lock().await.push(CheckpointCall {
322 arn: arn.to_string(),
323 checkpoint_token: checkpoint_token.to_string(),
324 updates,
325 });
326
327 let mut builder = CheckpointDurableExecutionOutput::builder()
328 .checkpoint_token(&self.checkpoint_token);
329
330 if let Some(ref op) = self.response_operation {
331 let new_state = aws_sdk_lambda::types::CheckpointUpdatedExecutionState::builder()
332 .operations(op.clone())
333 .build();
334 builder = builder.new_execution_state(new_state);
335 }
336
337 Ok(builder.build())
338 }
339
340 async fn get_execution_state(
341 &self,
342 _arn: &str,
343 _checkpoint_token: &str,
344 _next_marker: &str,
345 _max_items: i32,
346 ) -> Result<GetDurableExecutionStateOutput, DurableError> {
347 Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
348 }
349 }
350
351 fn first_op_id() -> String {
353 let mut gen = crate::operation_id::OperationIdGenerator::new(None);
354 gen.next_id()
355 }
356
357 fn make_callback_op(
358 id: &str,
359 status: OperationStatus,
360 callback_id: &str,
361 result: Option<&str>,
362 error: Option<ErrorObject>,
363 ) -> Operation {
364 let mut cb_builder = CallbackDetails::builder().callback_id(callback_id);
365 if let Some(r) = result {
366 cb_builder = cb_builder.result(r);
367 }
368 if let Some(e) = error {
369 cb_builder = cb_builder.error(e);
370 }
371
372 Operation::builder()
373 .id(id)
374 .r#type(OperationType::Callback)
375 .status(status)
376 .name("test_callback")
377 .start_timestamp(DateTime::from_secs(0))
378 .callback_details(cb_builder.build())
379 .build()
380 .unwrap()
381 }
382
383 #[tokio::test]
386 async fn test_create_callback_sends_start_checkpoint_and_returns_handle() {
387 let op_id = first_op_id();
388
389 let response_op = make_callback_op(
391 &op_id,
392 OperationStatus::Started,
393 "cb-server-123",
394 None,
395 None,
396 );
397
398 let (backend, calls) = CallbackMockBackend::new("new-token", response_op);
399 let mut ctx = DurableContext::new(
400 Arc::new(backend),
401 "arn:test".to_string(),
402 "initial-token".to_string(),
403 vec![],
404 None,
405 )
406 .await
407 .unwrap();
408
409 let handle = ctx
410 .create_callback("approval", CallbackOptions::new().timeout_seconds(300))
411 .await
412 .unwrap();
413
414 assert_eq!(handle.callback_id, "cb-server-123");
416
417 let captured = calls.lock().await;
419 assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");
420
421 let update = &captured[0].updates[0];
422 assert_eq!(update.r#type(), &OperationType::Callback);
423 assert_eq!(update.action(), &OperationAction::Start);
424 assert_eq!(update.name(), Some("approval"));
425 assert_eq!(update.sub_type(), Some("Callback"));
426
427 let callback_opts = update
429 .callback_options()
430 .expect("should have callback_options");
431 assert_eq!(callback_opts.timeout_seconds(), 300);
432 assert_eq!(callback_opts.heartbeat_timeout_seconds(), 0);
433 }
434
435 #[tokio::test]
436 async fn test_create_callback_replays_from_history() {
437 let op_id = first_op_id();
438
439 let callback_op = make_callback_op(
441 &op_id,
442 OperationStatus::Succeeded,
443 "cb-cached-456",
444 Some(r#""approved""#),
445 None,
446 );
447
448 let response_op = make_callback_op(&op_id, OperationStatus::Started, "unused", None, None);
450 let (backend, calls) = CallbackMockBackend::new("token", response_op);
451
452 let mut ctx = DurableContext::new(
453 Arc::new(backend),
454 "arn:test".to_string(),
455 "tok".to_string(),
456 vec![callback_op],
457 None,
458 )
459 .await
460 .unwrap();
461
462 let handle = ctx
463 .create_callback("approval", CallbackOptions::new())
464 .await
465 .unwrap();
466
467 assert_eq!(handle.callback_id, "cb-cached-456");
469
470 let captured = calls.lock().await;
472 assert_eq!(captured.len(), 0, "no checkpoints during replay");
473 }
474
475 #[tokio::test]
478 async fn test_callback_result_returns_deserialized_value_on_succeeded() {
479 let op_id = first_op_id();
480
481 let callback_op = make_callback_op(
482 &op_id,
483 OperationStatus::Succeeded,
484 "cb-789",
485 Some(r#"{"status":"approved","approver":"alice"}"#),
486 None,
487 );
488
489 let response_op = make_callback_op(&op_id, OperationStatus::Started, "unused", None, None);
490 let (backend, _) = CallbackMockBackend::new("token", response_op);
491
492 let mut ctx = DurableContext::new(
493 Arc::new(backend),
494 "arn:test".to_string(),
495 "tok".to_string(),
496 vec![callback_op],
497 None,
498 )
499 .await
500 .unwrap();
501
502 let handle = ctx
504 .create_callback("approval", CallbackOptions::new())
505 .await
506 .unwrap();
507
508 let result: serde_json::Value = ctx.callback_result(&handle).unwrap();
510 assert_eq!(result["status"], "approved");
511 assert_eq!(result["approver"], "alice");
512 }
513
514 #[tokio::test]
515 async fn test_callback_result_returns_error_on_failed() {
516 let op_id = first_op_id();
517
518 let error_obj = ErrorObject::builder()
519 .error_type("RejectionError")
520 .error_data("reviewer declined the request")
521 .build();
522
523 let callback_op = make_callback_op(
524 &op_id,
525 OperationStatus::Failed,
526 "cb-fail-1",
527 None,
528 Some(error_obj),
529 );
530
531 let response_op = make_callback_op(&op_id, OperationStatus::Started, "unused", None, None);
532 let (backend, _) = CallbackMockBackend::new("token", response_op);
533
534 let mut ctx = DurableContext::new(
535 Arc::new(backend),
536 "arn:test".to_string(),
537 "tok".to_string(),
538 vec![callback_op],
539 None,
540 )
541 .await
542 .unwrap();
543
544 let handle = ctx
545 .create_callback("approval", CallbackOptions::new())
546 .await
547 .unwrap();
548
549 let err = ctx.callback_result::<String>(&handle).unwrap_err();
550 let msg = err.to_string();
551 assert!(msg.contains("callback failed"), "error: {msg}");
552 assert!(
553 msg.contains("cb-fail-1"),
554 "should contain callback_id: {msg}"
555 );
556 assert!(
557 msg.contains("RejectionError"),
558 "should contain error type: {msg}"
559 );
560 }
561
562 #[tokio::test]
563 async fn test_callback_result_returns_error_on_timed_out() {
564 let op_id = first_op_id();
565
566 let callback_op = make_callback_op(
567 &op_id,
568 OperationStatus::TimedOut,
569 "cb-timeout-1",
570 None,
571 None,
572 );
573
574 let response_op = make_callback_op(&op_id, OperationStatus::Started, "unused", None, None);
575 let (backend, _) = CallbackMockBackend::new("token", response_op);
576
577 let mut ctx = DurableContext::new(
578 Arc::new(backend),
579 "arn:test".to_string(),
580 "tok".to_string(),
581 vec![callback_op],
582 None,
583 )
584 .await
585 .unwrap();
586
587 let handle = ctx
588 .create_callback("approval", CallbackOptions::new())
589 .await
590 .unwrap();
591
592 let err = ctx.callback_result::<String>(&handle).unwrap_err();
593 let msg = err.to_string();
594 assert!(msg.contains("callback failed"), "error: {msg}");
595 assert!(
596 msg.contains("cb-timeout-1"),
597 "should contain callback_id: {msg}"
598 );
599 }
600
601 #[tokio::test]
602 async fn test_callback_result_suspends_on_started() {
603 let op_id = first_op_id();
604
605 let callback_op =
607 make_callback_op(&op_id, OperationStatus::Started, "cb-pending-1", None, None);
608
609 let response_op = make_callback_op(&op_id, OperationStatus::Started, "unused", None, None);
610 let (backend, _) = CallbackMockBackend::new("token", response_op);
611
612 let mut ctx = DurableContext::new(
613 Arc::new(backend),
614 "arn:test".to_string(),
615 "tok".to_string(),
616 vec![callback_op],
617 None,
618 )
619 .await
620 .unwrap();
621
622 let handle = ctx
623 .create_callback("approval", CallbackOptions::new())
624 .await
625 .unwrap();
626
627 let err = ctx.callback_result::<String>(&handle).unwrap_err();
628 let msg = err.to_string();
629 assert!(msg.contains("callback suspended"), "error: {msg}");
630 assert!(
631 msg.contains("cb-pending-1"),
632 "should contain callback_id: {msg}"
633 );
634 }
635
636 #[traced_test]
639 #[tokio::test]
640 async fn test_callback_emits_span() {
641 let op_id = first_op_id();
642 let response_op =
644 make_callback_op(&op_id, OperationStatus::Started, "cb-span-test", None, None);
645 let (backend, _calls) = CallbackMockBackend::new("tok", response_op);
646 let mut ctx = DurableContext::new(
647 Arc::new(backend),
648 "arn:test".to_string(),
649 "tok".to_string(),
650 vec![],
651 None,
652 )
653 .await
654 .unwrap();
655 let _ = ctx.create_callback("notify", CallbackOptions::new()).await;
656 assert!(logs_contain("durable_operation"));
657 assert!(logs_contain("notify"));
658 assert!(logs_contain("callback"));
659 }
660}