Skip to main content

durable_lambda_core/operations/
callback.rs

//! Callback operation — external signal coordination.
//!
//! Implements FR14-FR16: register a callback, suspend until it is signaled,
//! and handle success/failure/heartbeat signals.
//!
//! The callback operation is a **two-phase** operation:
//! 1. [`DurableContext::create_callback`] — sends a START checkpoint and
//!    returns a [`CallbackHandle`] with the server-generated `callback_id`.
//! 2. [`DurableContext::callback_result`] — checks whether the callback has
//!    been signaled and returns its result, or returns
//!    [`DurableError::CallbackSuspended`] so the handler can propagate it and exit.
11
12use aws_sdk_lambda::types::{OperationAction, OperationStatus, OperationType, OperationUpdate};
13use serde::de::DeserializeOwned;
14
15use crate::context::DurableContext;
16use crate::error::DurableError;
17use crate::types::{CallbackHandle, CallbackOptions};
18
19impl DurableContext {
20    /// Register a callback and return a handle with the server-generated callback ID.
21    ///
22    /// During execution mode, sends a START checkpoint with callback configuration
23    /// and returns a [`CallbackHandle`] containing the `callback_id` that external
24    /// systems use to signal completion via `SendDurableExecutionCallbackSuccess`,
25    /// `SendDurableExecutionCallbackFailure`, or `SendDurableExecutionCallbackHeartbeat`.
26    ///
27    /// During replay mode, extracts the cached `callback_id` from history without
28    /// sending any checkpoint.
29    ///
30    /// **Important:** This method NEVER suspends. Suspension happens in
31    /// [`callback_result`](Self::callback_result) when the callback hasn't
32    /// been signaled yet.
33    ///
34    /// # Arguments
35    ///
36    /// * `name` — Human-readable name for the callback operation
37    /// * `options` — Timeout configuration (see [`CallbackOptions`])
38    ///
39    /// # Errors
40    ///
41    /// Returns [`DurableError::CheckpointFailed`] if the AWS checkpoint API
42    /// call fails or if the callback_id cannot be extracted from the response.
43    ///
44    /// # Examples
45    ///
46    /// ```no_run
47    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
48    /// use durable_lambda_core::types::CallbackOptions;
49    ///
50    /// let handle = ctx.create_callback("approval", CallbackOptions::new()).await?;
51    /// println!("Callback ID for external system: {}", handle.callback_id);
52    /// # Ok(())
53    /// # }
54    /// ```
55    #[allow(clippy::await_holding_lock)]
56    pub async fn create_callback(
57        &mut self,
58        name: &str,
59        options: CallbackOptions,
60    ) -> Result<CallbackHandle, DurableError> {
61        let op_id = self.replay_engine_mut().generate_operation_id();
62
63        let span = tracing::info_span!(
64            "durable_operation",
65            op.name = name,
66            op.type = "callback",
67            op.id = %op_id,
68        );
69        let _guard = span.enter();
70        tracing::trace!("durable_operation");
71
72        // Check if operation exists in history (any status — not just completed).
73        if let Some(op) = self.replay_engine().get_operation(&op_id) {
74            let callback_id = op
75                .callback_details()
76                .and_then(|d| d.callback_id())
77                .ok_or_else(|| {
78                    DurableError::checkpoint_failed(
79                        name,
80                        std::io::Error::new(
81                            std::io::ErrorKind::InvalidData,
82                            "callback_details missing callback_id in history",
83                        ),
84                    )
85                })?
86                .to_string();
87
88            self.replay_engine_mut().track_replay(&op_id);
89            return Ok(CallbackHandle {
90                callback_id,
91                operation_id: op_id,
92            });
93        }
94
95        // Execute path — send START checkpoint with CallbackOptions.
96        let callback_opts = aws_sdk_lambda::types::CallbackOptions::builder()
97            .timeout_seconds(options.get_timeout_seconds())
98            .heartbeat_timeout_seconds(options.get_heartbeat_timeout_seconds())
99            .build();
100
101        let start_update = OperationUpdate::builder()
102            .id(op_id.clone())
103            .r#type(OperationType::Callback)
104            .action(OperationAction::Start)
105            .sub_type("Callback")
106            .name(name)
107            .callback_options(callback_opts)
108            .build()
109            .map_err(|e| DurableError::checkpoint_failed(name, e))?;
110
111        let start_response = self
112            .backend()
113            .checkpoint(
114                self.arn(),
115                self.checkpoint_token(),
116                vec![start_update],
117                None,
118            )
119            .await?;
120
121        let new_token = start_response.checkpoint_token().ok_or_else(|| {
122            DurableError::checkpoint_failed(
123                name,
124                std::io::Error::new(
125                    std::io::ErrorKind::InvalidData,
126                    "checkpoint response missing checkpoint_token",
127                ),
128            )
129        })?;
130        self.set_checkpoint_token(new_token.to_string());
131
132        // Merge any new execution state from checkpoint response.
133        if let Some(new_state) = start_response.new_execution_state() {
134            for op in new_state.operations() {
135                self.replay_engine_mut()
136                    .insert_operation(op.id().to_string(), op.clone());
137            }
138        }
139
140        // Extract callback_id from the merged operation's callback_details.
141        let callback_id = self
142            .replay_engine()
143            .get_operation(&op_id)
144            .and_then(|op| op.callback_details())
145            .and_then(|d| d.callback_id())
146            .ok_or_else(|| {
147                DurableError::checkpoint_failed(
148                    name,
149                    std::io::Error::new(
150                        std::io::ErrorKind::InvalidData,
151                        "no callback_id in checkpoint response",
152                    ),
153                )
154            })?
155            .to_string();
156
157        self.replay_engine_mut().track_replay(&op_id);
158
159        Ok(CallbackHandle {
160            callback_id,
161            operation_id: op_id,
162        })
163    }
164
165    /// Check the result of a previously created callback.
166    ///
167    /// Return the deserialized success payload if the callback has been
168    /// signaled with success. Return an error if the callback failed,
169    /// timed out, or hasn't been signaled yet.
170    ///
171    /// **Important:** This is NOT an async/durable operation — it only reads
172    /// the current operation state. It does NOT generate an operation ID or
173    /// create checkpoints.
174    ///
175    /// # Arguments
176    ///
177    /// * `handle` — The [`CallbackHandle`] returned by [`create_callback`](Self::create_callback)
178    ///
179    /// # Errors
180    ///
181    /// Returns [`DurableError::CallbackSuspended`] if the callback has not
182    /// been signaled yet (the handler should propagate this to exit).
183    /// Returns [`DurableError::CallbackFailed`] if the callback was signaled
184    /// with failure, was cancelled, or timed out.
185    /// Returns [`DurableError::Deserialization`] if the callback result
186    /// cannot be deserialized as type `T`.
187    ///
188    /// # Examples
189    ///
190    /// ```no_run
191    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
192    /// use durable_lambda_core::types::CallbackOptions;
193    ///
194    /// let handle = ctx.create_callback("approval", CallbackOptions::new()).await?;
195    /// // ... pass handle.callback_id to external system ...
196    /// let result: String = ctx.callback_result(&handle)?;
197    /// # Ok(())
198    /// # }
199    /// ```
200    pub fn callback_result<T: DeserializeOwned>(
201        &self,
202        handle: &CallbackHandle,
203    ) -> Result<T, DurableError> {
204        let Some(op) = self.replay_engine().get_operation(&handle.operation_id) else {
205            // Operation not found — shouldn't happen if create_callback was called,
206            // but treat as suspended to be safe.
207            return Err(DurableError::callback_suspended(
208                "unknown",
209                &handle.callback_id,
210            ));
211        };
212
213        match &op.status {
214            OperationStatus::Succeeded => {
215                let result_str =
216                    op.callback_details()
217                        .and_then(|d| d.result())
218                        .ok_or_else(|| {
219                            DurableError::checkpoint_failed(
220                                op.name().unwrap_or("callback"),
221                                std::io::Error::new(
222                                    std::io::ErrorKind::InvalidData,
223                                    "callback succeeded but no result in callback_details",
224                                ),
225                            )
226                        })?;
227
228                serde_json::from_str(result_str)
229                    .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))
230            }
231            OperationStatus::Failed
232            | OperationStatus::Cancelled
233            | OperationStatus::TimedOut
234            | OperationStatus::Stopped => {
235                let error_message = op
236                    .callback_details()
237                    .and_then(|d| d.error())
238                    .map(|e| {
239                        format!(
240                            "{}: {}",
241                            e.error_type().unwrap_or("Unknown"),
242                            e.error_data().unwrap_or("")
243                        )
244                    })
245                    .unwrap_or_else(|| "callback failed".to_string());
246
247                Err(DurableError::callback_failed(
248                    op.name().unwrap_or("callback"),
249                    &handle.callback_id,
250                    error_message,
251                ))
252            }
253            // Started, Pending, Ready, or any other status — not yet signaled.
254            _ => Err(DurableError::callback_suspended(
255                op.name().unwrap_or("callback"),
256                &handle.callback_id,
257            )),
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        CallbackDetails, CheckpointUpdatedExecutionState, ErrorObject, Operation, OperationAction,
        OperationStatus, OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;
    use crate::operation_id::OperationIdGenerator;
    use crate::types::{CallbackHandle, CallbackOptions};

    /// Shared log of every `checkpoint` call the mock backend receives.
    type CallLog = Arc<Mutex<Vec<CheckpointCall>>>;

    /// Snapshot of the arguments passed to one `checkpoint` call.
    #[derive(Debug, Clone)]
    // `arn` and `checkpoint_token` are recorded but no test reads them yet.
    #[allow(dead_code)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Mock backend that records checkpoint calls and answers each one with a
    /// fixed token plus a single operation in `new_execution_state`.
    struct CallbackMockBackend {
        calls: CallLog,
        checkpoint_token: String,
        /// The operation to return in new_execution_state after checkpoint.
        response_operation: Option<Operation>,
    }

    impl CallbackMockBackend {
        /// Create the mock together with a handle to its call log.
        fn new(checkpoint_token: &str, response_op: Operation) -> (Self, CallLog) {
            let calls: CallLog = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: Arc::clone(&calls),
                checkpoint_token: checkpoint_token.to_owned(),
                response_operation: Some(response_op),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for CallbackMockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            let call = CheckpointCall {
                arn: arn.to_owned(),
                checkpoint_token: checkpoint_token.to_owned(),
                updates,
            };
            self.calls.lock().await.push(call);

            let new_state = self.response_operation.as_ref().map(|op| {
                CheckpointUpdatedExecutionState::builder()
                    .operations(op.clone())
                    .build()
            });

            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token(&self.checkpoint_token)
                .set_new_execution_state(new_state)
                .build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// The ID the replay engine assigns to the first operation of a fresh context.
    fn first_op_id() -> String {
        OperationIdGenerator::new(None).next_id()
    }

    /// Build a `Callback` operation named `test_callback` with the given status
    /// and callback details; absent `result`/`error` are left unset.
    fn make_callback_op(
        id: &str,
        status: OperationStatus,
        callback_id: &str,
        result: Option<&str>,
        error: Option<ErrorObject>,
    ) -> Operation {
        let details = CallbackDetails::builder()
            .callback_id(callback_id)
            .set_result(result.map(str::to_owned))
            .set_error(error)
            .build();

        Operation::builder()
            .id(id)
            .r#type(OperationType::Callback)
            .status(status)
            .name("test_callback")
            .start_timestamp(DateTime::from_secs(0))
            .callback_details(details)
            .build()
            .unwrap()
    }

    /// Replay `create_callback("approval", ..)` against a history holding
    /// `history_op`, returning the context, the handle and the checkpoint log.
    ///
    /// The backend's response operation is a `Started` placeholder that replay
    /// is not expected to request.
    async fn replay_approval(history_op: Operation) -> (DurableContext, CallbackHandle, CallLog) {
        let placeholder =
            make_callback_op(&first_op_id(), OperationStatus::Started, "unused", None, None);
        let (backend, calls) = CallbackMockBackend::new("token", placeholder);

        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![history_op],
            None,
        )
        .await
        .unwrap();

        let handle = ctx
            .create_callback("approval", CallbackOptions::new())
            .await
            .unwrap();

        (ctx, handle, calls)
    }

    // ─── create_callback tests ───────────────────────────────────────────

    #[tokio::test]
    async fn test_create_callback_sends_start_checkpoint_and_returns_handle() {
        // The server answers the START checkpoint with the generated callback_id.
        let server_op = make_callback_op(
            &first_op_id(),
            OperationStatus::Started,
            "cb-server-123",
            None,
            None,
        );
        let (backend, calls) = CallbackMockBackend::new("new-token", server_op);
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "initial-token".to_string(),
            Vec::new(),
            None,
        )
        .await
        .unwrap();

        let options = CallbackOptions::new().timeout_seconds(300);
        let handle = ctx.create_callback("approval", options).await.unwrap();
        assert_eq!(handle.callback_id, "cb-server-123");

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");

        let update = &captured[0].updates[0];
        assert_eq!(update.r#type(), &OperationType::Callback);
        assert_eq!(update.action(), &OperationAction::Start);
        assert_eq!(update.name(), Some("approval"));
        assert_eq!(update.sub_type(), Some("Callback"));

        // Only the overall timeout was configured; the heartbeat timeout stays 0.
        let callback_opts = update
            .callback_options()
            .expect("should have callback_options");
        assert_eq!(callback_opts.timeout_seconds(), 300);
        assert_eq!(callback_opts.heartbeat_timeout_seconds(), 0);
    }

    #[tokio::test]
    async fn test_create_callback_replays_from_history() {
        let history_op = make_callback_op(
            &first_op_id(),
            OperationStatus::Succeeded,
            "cb-cached-456",
            Some(r#""approved""#),
            None,
        );

        let (_ctx, handle, calls) = replay_approval(history_op).await;

        // The handle carries the callback_id recorded in history ...
        assert_eq!(handle.callback_id, "cb-cached-456");
        // ... and replay never talks to the backend.
        assert_eq!(calls.lock().await.len(), 0, "no checkpoints during replay");
    }

    // ─── callback_result tests ───────────────────────────────────────────

    #[tokio::test]
    async fn test_callback_result_returns_deserialized_value_on_succeeded() {
        let history_op = make_callback_op(
            &first_op_id(),
            OperationStatus::Succeeded,
            "cb-789",
            Some(r#"{"status":"approved","approver":"alice"}"#),
            None,
        );
        let (ctx, handle, _) = replay_approval(history_op).await;

        let result: serde_json::Value = ctx.callback_result(&handle).unwrap();
        assert_eq!(result["status"], "approved");
        assert_eq!(result["approver"], "alice");
    }

    #[tokio::test]
    async fn test_callback_result_returns_error_on_failed() {
        let rejection = ErrorObject::builder()
            .error_type("RejectionError")
            .error_data("reviewer declined the request")
            .build();
        let history_op = make_callback_op(
            &first_op_id(),
            OperationStatus::Failed,
            "cb-fail-1",
            None,
            Some(rejection),
        );
        let (ctx, handle, _) = replay_approval(history_op).await;

        let msg = ctx.callback_result::<String>(&handle).unwrap_err().to_string();
        assert!(msg.contains("callback failed"), "error: {msg}");
        assert!(msg.contains("cb-fail-1"), "should contain callback_id: {msg}");
        assert!(msg.contains("RejectionError"), "should contain error type: {msg}");
    }

    #[tokio::test]
    async fn test_callback_result_returns_error_on_timed_out() {
        let history_op = make_callback_op(
            &first_op_id(),
            OperationStatus::TimedOut,
            "cb-timeout-1",
            None,
            None,
        );
        let (ctx, handle, _) = replay_approval(history_op).await;

        let msg = ctx.callback_result::<String>(&handle).unwrap_err().to_string();
        assert!(msg.contains("callback failed"), "error: {msg}");
        assert!(msg.contains("cb-timeout-1"), "should contain callback_id: {msg}");
    }

    #[tokio::test]
    async fn test_callback_result_suspends_on_started() {
        // A STARTED operation means the external system has not signaled yet.
        let history_op = make_callback_op(
            &first_op_id(),
            OperationStatus::Started,
            "cb-pending-1",
            None,
            None,
        );
        let (ctx, handle, _) = replay_approval(history_op).await;

        let msg = ctx.callback_result::<String>(&handle).unwrap_err().to_string();
        assert!(msg.contains("callback suspended"), "error: {msg}");
        assert!(msg.contains("cb-pending-1"), "should contain callback_id: {msg}");
    }

    // ─── span tests (FEAT-17) ─────────────────────────────────────────────

    #[traced_test]
    #[tokio::test]
    async fn test_callback_emits_span() {
        // The response carries a callback_id so create_callback can succeed.
        let server_op = make_callback_op(
            &first_op_id(),
            OperationStatus::Started,
            "cb-span-test",
            None,
            None,
        );
        let (backend, _calls) = CallbackMockBackend::new("tok", server_op);
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            Vec::new(),
            None,
        )
        .await
        .unwrap();

        let _ = ctx.create_callback("notify", CallbackOptions::new()).await;

        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("notify"));
        assert!(logs_contain("callback"));
    }
}