durable_lambda_core/operations/
wait.rs1use aws_sdk_lambda::types::{OperationAction, OperationType, OperationUpdate};
10
11use crate::context::DurableContext;
12use crate::error::DurableError;
13
14impl DurableContext {
15 #[allow(clippy::await_holding_lock)]
50 pub async fn wait(&mut self, name: &str, duration_secs: i32) -> Result<(), DurableError> {
51 let op_id = self.replay_engine_mut().generate_operation_id();
52
53 let span = tracing::info_span!(
54 "durable_operation",
55 op.name = name,
56 op.type = "wait",
57 op.id = %op_id,
58 );
59 let _guard = span.enter();
60 tracing::trace!("durable_operation");
61
62 if self.replay_engine().check_result(&op_id).is_some() {
64 self.replay_engine_mut().track_replay(&op_id);
65 return Ok(());
66 }
67
68 let wait_opts = aws_sdk_lambda::types::WaitOptions::builder()
70 .wait_seconds(duration_secs)
71 .build();
72
73 let start_update = OperationUpdate::builder()
74 .id(op_id.clone())
75 .r#type(OperationType::Wait)
76 .action(OperationAction::Start)
77 .sub_type("Wait")
78 .name(name)
79 .wait_options(wait_opts)
80 .build()
81 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
82
83 let start_response = self
84 .backend()
85 .checkpoint(
86 self.arn(),
87 self.checkpoint_token(),
88 vec![start_update],
89 None,
90 )
91 .await?;
92
93 let new_token = start_response.checkpoint_token().ok_or_else(|| {
94 DurableError::checkpoint_failed(
95 name,
96 std::io::Error::new(
97 std::io::ErrorKind::InvalidData,
98 "checkpoint response missing checkpoint_token",
99 ),
100 )
101 })?;
102 self.set_checkpoint_token(new_token.to_string());
103
104 if let Some(new_state) = start_response.new_execution_state() {
106 for op in new_state.operations() {
107 self.replay_engine_mut()
108 .insert_operation(op.id().to_string(), op.clone());
109 }
110 }
111
112 if self.replay_engine().check_result(&op_id).is_some() {
114 self.replay_engine_mut().track_replay(&op_id);
115 return Ok(());
116 }
117
118 Err(DurableError::wait_suspended(name))
120 }
121}
122
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{
        Operation, OperationAction, OperationStatus, OperationType, OperationUpdate,
    };
    use aws_smithy_types::DateTime;
    use tokio::sync::Mutex;
    use tracing_test::traced_test;

    use crate::backend::DurableBackend;
    use crate::context::DurableContext;
    use crate::error::DurableError;

    /// One recorded invocation of `DurableBackend::checkpoint`.
    #[derive(Debug, Clone)]
    struct CheckpointCall {
        arn: String,
        checkpoint_token: String,
        updates: Vec<OperationUpdate>,
    }

    /// Shared log of checkpoint calls. It stays readable after the backend
    /// has been moved into the context.
    type CallLog = Arc<Mutex<Vec<CheckpointCall>>>;

    /// The id a fresh, parentless `OperationIdGenerator` hands out first.
    /// The replay, id and double-check tests assume `DurableContext` assigns
    /// this same id to its first operation.
    fn first_op_id() -> String {
        let mut ids = crate::operation_id::OperationIdGenerator::new(None);
        ids.next_id()
    }

    /// Builds a `Wait` operation with the given id and status `Succeeded`.
    fn succeeded_wait(id: &str) -> Operation {
        Operation::builder()
            .id(id)
            .r#type(OperationType::Wait)
            .status(OperationStatus::Succeeded)
            .start_timestamp(DateTime::from_secs(0))
            .build()
            .unwrap()
    }

    /// Backend that records every checkpoint call and always answers with a
    /// fixed checkpoint token and no new execution state.
    struct MockBackend {
        calls: CallLog,
        checkpoint_token: String,
    }

    impl MockBackend {
        fn new(checkpoint_token: &str) -> (Self, CallLog) {
            let calls: CallLog = Arc::new(Mutex::new(Vec::new()));
            let backend = Self {
                calls: Arc::clone(&calls),
                checkpoint_token: checkpoint_token.to_string(),
            };
            (backend, calls)
        }
    }

    #[async_trait::async_trait]
    impl DurableBackend for MockBackend {
        async fn checkpoint(
            &self,
            arn: &str,
            checkpoint_token: &str,
            updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            self.calls.lock().await.push(CheckpointCall {
                arn: arn.to_string(),
                checkpoint_token: checkpoint_token.to_string(),
                updates,
            });
            Ok(CheckpointDurableExecutionOutput::builder()
                .checkpoint_token(&self.checkpoint_token)
                .build())
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
        }
    }

    #[tokio::test]
    async fn test_wait_sends_start_checkpoint_and_suspends() {
        let (backend, calls) = MockBackend::new("new-token");
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "initial-token".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let err = ctx.wait("cooldown", 30).await.unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("cooldown"),
            "error should contain operation name"
        );
        assert!(
            message.contains("wait suspended"),
            "error should indicate wait suspension"
        );

        // The context must adopt the token returned by the START checkpoint.
        assert_eq!(ctx.checkpoint_token(), "new-token");

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "expected exactly 1 checkpoint (START)");

        let call = &captured[0];
        assert_eq!(call.arn, "arn:test");
        assert_eq!(
            call.checkpoint_token, "initial-token",
            "START must be sent with the pre-wait token"
        );
        assert_eq!(call.updates.len(), 1, "START carries a single update");

        let update = &call.updates[0];
        assert_eq!(update.id(), first_op_id());
        assert_eq!(update.r#type(), &OperationType::Wait);
        assert_eq!(update.action(), &OperationAction::Start);
        assert_eq!(update.name(), Some("cooldown"));
        assert_eq!(update.sub_type(), Some("Wait"));

        let wait_opts = update.wait_options().expect("should have wait_options");
        assert_eq!(wait_opts.wait_seconds(), Some(30));
    }

    #[tokio::test]
    async fn test_wait_replays_completed_wait() {
        let wait_op = succeeded_wait(&first_op_id());

        let (backend, calls) = MockBackend::new("token");
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![wait_op],
            None,
        )
        .await
        .unwrap();

        let result = ctx.wait("cooldown", 30).await;
        assert!(result.is_ok(), "replay should return Ok(())");

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 0, "no checkpoints during replay");
    }

    #[tokio::test]
    async fn test_wait_double_check_after_start() {
        /// Backend whose checkpoint response already reports the wait as
        /// `Succeeded`, which exercises the post-START double-check.
        struct DoubleCheckBackend {
            calls: CallLog,
            completed_op_id: String,
        }

        #[async_trait::async_trait]
        impl DurableBackend for DoubleCheckBackend {
            async fn checkpoint(
                &self,
                arn: &str,
                checkpoint_token: &str,
                updates: Vec<OperationUpdate>,
                _client_token: Option<&str>,
            ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
                self.calls.lock().await.push(CheckpointCall {
                    arn: arn.to_string(),
                    checkpoint_token: checkpoint_token.to_string(),
                    updates,
                });

                let new_state = aws_sdk_lambda::types::CheckpointUpdatedExecutionState::builder()
                    .operations(succeeded_wait(&self.completed_op_id))
                    .build();

                Ok(CheckpointDurableExecutionOutput::builder()
                    .checkpoint_token("new-token")
                    .new_execution_state(new_state)
                    .build())
            }

            async fn get_execution_state(
                &self,
                _arn: &str,
                _checkpoint_token: &str,
                _next_marker: &str,
                _max_items: i32,
            ) -> Result<GetDurableExecutionStateOutput, DurableError> {
                Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
            }
        }

        let calls: CallLog = Arc::new(Mutex::new(Vec::new()));
        let backend = DoubleCheckBackend {
            calls: Arc::clone(&calls),
            completed_op_id: first_op_id(),
        };

        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();

        let result = ctx.wait("fast_wait", 1).await;
        assert!(
            result.is_ok(),
            "double-check should detect completion and return Ok(())"
        );

        let captured = calls.lock().await;
        assert_eq!(captured.len(), 1, "START checkpoint sent");
        assert_eq!(captured[0].updates[0].action(), &OperationAction::Start);
    }

    #[traced_test]
    #[tokio::test]
    async fn test_wait_emits_span() {
        let (backend, _calls) = MockBackend::new("tok");
        let mut ctx = DurableContext::new(
            Arc::new(backend),
            "arn:test".to_string(),
            "tok".to_string(),
            vec![],
            None,
        )
        .await
        .unwrap();
        let _ = ctx.wait("cooldown", 30).await;

        assert!(logs_contain("durable_operation"));
        assert!(logs_contain("cooldown"));
        assert!(logs_contain("wait"));
    }
}