// durable_lambda_core/operations/wait.rs
//! Wait operation — time-based suspension.
//!
//! Implements FR12–FR13: suspend for a specified duration, resume after it elapses.
//!
//! The wait operation uses a **single START checkpoint** with `WaitOptions`.
//! The server handles the timer and transitions the operation to SUCCEEDED.
//! On re-invocation, the replay engine finds the completed wait and skips it.

use aws_sdk_lambda::types::{OperationAction, OperationType, OperationUpdate};
use tracing::Instrument;

use crate::context::DurableContext;
use crate::error::DurableError;

14impl DurableContext {
15    /// Suspend execution for the specified duration.
16    ///
17    /// During execution mode, sends a START checkpoint with the wait duration
18    /// and returns [`DurableError::WaitSuspended`] to signal the function
19    /// should exit. The durable execution server re-invokes the Lambda after
20    /// the duration elapses.
21    ///
22    /// During replay mode, returns `Ok(())` immediately if the wait has
23    /// already completed (status SUCCEEDED in history).
24    ///
25    /// # Arguments
26    ///
27    /// * `name` — Human-readable name for the wait operation
28    /// * `duration_secs` — Duration to wait in seconds (1 to 31,622,400)
29    ///
30    /// # Errors
31    ///
32    /// Returns [`DurableError::WaitSuspended`] when the wait has been
33    /// checkpointed — the handler must propagate this to exit the function.
34    /// Returns [`DurableError::CheckpointFailed`] if the AWS checkpoint
35    /// API call fails.
36    ///
37    /// # Examples
38    ///
39    /// ```no_run
40    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
41    /// // Wait 30 seconds before continuing.
42    /// ctx.wait("cooldown", 30).await?;
43    ///
44    /// // Execution continues here after the wait completes.
45    /// println!("Wait completed!");
46    /// # Ok(())
47    /// # }
48    /// ```
49    #[allow(clippy::await_holding_lock)]
50    pub async fn wait(&mut self, name: &str, duration_secs: i32) -> Result<(), DurableError> {
51        let op_id = self.replay_engine_mut().generate_operation_id();
52
53        let span = tracing::info_span!(
54            "durable_operation",
55            op.name = name,
56            op.type = "wait",
57            op.id = %op_id,
58        );
59        let _guard = span.enter();
60        tracing::trace!("durable_operation");
61
62        // Check if we have a completed result (replay path).
63        if self.replay_engine().check_result(&op_id).is_some() {
64            self.replay_engine_mut().track_replay(&op_id);
65            return Ok(());
66        }
67
68        // Execute path — send START checkpoint with WaitOptions.
69        let wait_opts = aws_sdk_lambda::types::WaitOptions::builder()
70            .wait_seconds(duration_secs)
71            .build();
72
73        let start_update = OperationUpdate::builder()
74            .id(op_id.clone())
75            .r#type(OperationType::Wait)
76            .action(OperationAction::Start)
77            .sub_type("Wait")
78            .name(name)
79            .wait_options(wait_opts)
80            .build()
81            .map_err(|e| DurableError::checkpoint_failed(name, e))?;
82
83        let start_response = self
84            .backend()
85            .checkpoint(
86                self.arn(),
87                self.checkpoint_token(),
88                vec![start_update],
89                None,
90            )
91            .await?;
92
93        let new_token = start_response.checkpoint_token().ok_or_else(|| {
94            DurableError::checkpoint_failed(
95                name,
96                std::io::Error::new(
97                    std::io::ErrorKind::InvalidData,
98                    "checkpoint response missing checkpoint_token",
99                ),
100            )
101        })?;
102        self.set_checkpoint_token(new_token.to_string());
103
104        // Merge any new execution state from checkpoint response.
105        if let Some(new_state) = start_response.new_execution_state() {
106            for op in new_state.operations() {
107                self.replay_engine_mut()
108                    .insert_operation(op.id().to_string(), op.clone());
109            }
110        }
111
112        // Double-check: after START, re-check if operation already completed.
113        if self.replay_engine().check_result(&op_id).is_some() {
114            self.replay_engine_mut().track_replay(&op_id);
115            return Ok(());
116        }
117
118        // Wait not yet completed — signal the handler to exit.
119        Err(DurableError::wait_suspended(name))
120    }
121}
122
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        Operation, OperationAction, OperationStatus, OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// One captured `checkpoint` invocation, recorded by the mock backends.
    #[derive(Debug, Clone)]
    // `arn`/`checkpoint_token` are captured for debugging; not every test asserts them.
    #[allow(dead_code)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Minimal backend that records every checkpoint call and always returns
    /// the same fresh checkpoint token, with no new execution state.
    struct MockBackend {
        calls: Arc<Mutex<Vec<CheckpointCall>>>,
        checkpoint_token: String,
    }

    impl MockBackend {
        /// Build the mock plus a shared handle to its captured-call log.
        fn new(checkpoint_token: &str) -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: calls.clone(),
                checkpoint_token: checkpoint_token.to_string(),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for MockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token(&self.checkpoint_token)
                .build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    /// Execution path: `wait` must emit exactly one START checkpoint carrying
    /// `WaitOptions` and then suspend via `WaitSuspended`.
    #[tokio::test]
    async fn test_wait_sends_start_checkpoint_and_suspends() {
        let (backend, calls) = MockBackend::new("new-token");
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "initial-token".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let result = ctx.wait("cooldown", 30).await;

        // Should return WaitSuspended error.
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("cooldown"),
            "error should contain operation name"
        );
        assert!(
            err.to_string().contains("wait suspended"),
            "error should indicate wait suspension"
        );

        // Verify START checkpoint was sent.
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");

        let update = &captured[0].updates[0];
        assert_eq!(update.r#type(), &OperationType::Wait);
        assert_eq!(update.action(), &OperationAction::Start);
        assert_eq!(update.name(), Some("cooldown"));
        assert_eq!(update.sub_type(), Some("Wait"));

        // Verify WaitOptions with duration.
        let wait_opts = update.wait_options().expect("should have wait_options");
        assert_eq!(wait_opts.wait_seconds(), Some(30));
    }

    /// Replay path: a SUCCEEDED wait in history must return `Ok(())` with no
    /// checkpoint traffic at all.
    #[tokio::test]
    async fn test_wait_replays_completed_wait() {
        // Create a completed wait operation in history.
        // NOTE: renamed from `gen` — that identifier is reserved in edition 2024.
        let op_id = {
            let mut id_gen = crate::operation_id::OperationIdGenerator::new(None);
            id_gen.next_id()
        };

        let wait_op = Operation::builder()
            .id(&op_id)
            .r#type(OperationType::Wait)
            .status(OperationStatus::Succeeded)
            .start_timestamp(DateTime::from_secs(0))
            .build()
            .unwrap();

        let (backend, calls) = MockBackend::new("token");
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![wait_op],
            None,
        )
        .await
        .unwrap();

        // Should replay successfully — no suspension.
        let result = ctx.wait("cooldown", 30).await;
        assert!(result.is_ok(), "replay should return Ok(())");

        // No checkpoints during replay.
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    /// Double-check path: if the START response already reports the wait as
    /// SUCCEEDED, `wait` must return `Ok(())` instead of suspending.
    #[tokio::test]
    async fn test_wait_double_check_after_start() {
        // MockBackend that returns a completed wait in new_execution_state after START.
        struct DoubleCheckBackend {
            calls: Arc<Mutex<Vec<CheckpointCall>>>,
            completed_op_id: String,
        }

        #[async_trait::async_trait]
        impl DurableBackend for DoubleCheckBackend {
            async fn checkpoint(
                &self,
                arn: &str,
                checkpoint_token: &str,
                updates: Vec<OperationUpdate>,
                _client_token: Option<&str>,
            ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
                self.calls.lock().await.push(CheckpointCall {
                    arn: arn.to_string(),
                    checkpoint_token: checkpoint_token.to_string(),
                    updates,
                });

                // Simulate server completing the wait synchronously.
                let completed_op = Operation::builder()
                    .id(&self.completed_op_id)
                    .r#type(OperationType::Wait)
                    .status(OperationStatus::Succeeded)
                    .start_timestamp(DateTime::from_secs(0))
                    .build()
                    .unwrap();

                let new_state = aws_sdk_lambda::types::CheckpointUpdatedExecutionState::builder()
                    .operations(completed_op)
                    .build();

                Ok(CheckpointDurableExecutionOutput::builder()
                    .checkpoint_token("new-token")
                    .new_execution_state(new_state)
                    .build())
            }

            async fn get_execution_state(
                &self,
                _arn: &str,
                _checkpoint_token: &str,
                _next_marker: &str,
                _max_items: i32,
            ) -> Result<GetDurableExecutionStateOutput, DurableError> {
                Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
            }
        }

        // Pre-compute the operation ID that will be generated.
        let op_id = {
            let mut id_gen = crate::operation_id::OperationIdGenerator::new(None);
            id_gen.next_id()
        };

        let calls = Arc::new(Mutex::new(Vec::new()));
        let backend = DoubleCheckBackend {
            calls: calls.clone(),
            completed_op_id: op_id,
        };

        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        // Should return Ok(()) because the server completed the wait during START.
        let result = ctx.wait("fast_wait", 1).await;
        assert!(
            result.is_ok(),
            "double-check should detect completion and return Ok(())"
        );

        // START checkpoint was still sent.
        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "START checkpoint sent");
    }

    // ─── span tests (FEAT-17) ─────────────────────────────────────────────

    /// The operation span and trace event must be observable even though the
    /// call ends in suspension.
    #[traced_test]
    #[tokio::test]
    async fn test_wait_emits_span() {
        let (backend, _calls) = MockBackend::new("tok");
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();
        // wait returns WaitSuspended — that's expected; span is emitted before suspension
        let _ = ctx.wait("cooldown", 30).await;
        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("cooldown"));
        assert!(logs_contain("wait"));
    }
}