1use std::sync::Arc;
7
8use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
9use crate::duration::Duration;
10use crate::error::DurableError;
11use crate::operation::{OperationType, OperationUpdate};
12use crate::state::ExecutionState;
13
/// Minimum wait duration, in seconds, accepted by `wait_handler`; shorter
/// durations are rejected with a `DurableError::Validation` error.
const MIN_WAIT_SECONDS: u64 = 1;
16
17pub async fn wait_handler(
36 duration: Duration,
37 state: &Arc<ExecutionState>,
38 op_id: &OperationIdentifier,
39 logger: &Arc<dyn Logger>,
40) -> Result<(), DurableError> {
41 let span = create_operation_span("wait", op_id, state.durable_execution_arn());
44 let _guard = span.enter();
45
46 let mut log_info =
47 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
48 if let Some(ref parent_id) = op_id.parent_id {
49 log_info = log_info.with_parent_id(parent_id);
50 }
51
52 logger.debug(
53 &format!(
54 "Starting wait operation: {} for {} seconds",
55 op_id,
56 duration.to_seconds()
57 ),
58 &log_info,
59 );
60
61 let wait_seconds = duration.to_seconds();
63 if wait_seconds < MIN_WAIT_SECONDS {
64 span.record("status", "validation_failed");
65 return Err(DurableError::Validation {
66 message: format!(
67 "Wait duration must be at least {} second(s), got {} seconds",
68 MIN_WAIT_SECONDS, wait_seconds
69 ),
70 });
71 }
72
73 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
75
76 if checkpoint_result.is_existent() {
77 if let Some(op_type) = checkpoint_result.operation_type() {
79 if op_type != OperationType::Wait {
80 span.record("status", "non_deterministic");
81 return Err(DurableError::NonDeterministic {
82 message: format!(
83 "Expected Wait operation but found {:?} at operation_id {}",
84 op_type, op_id.operation_id
85 ),
86 operation_id: Some(op_id.operation_id.clone()),
87 });
88 }
89 }
90
91 if checkpoint_result.is_succeeded() {
93 logger.debug(&format!("Wait already completed: {}", op_id), &log_info);
94 state.track_replay(&op_id.operation_id).await;
95 span.record("status", "replayed_succeeded");
96 return Ok(());
97 }
98
99 if checkpoint_result.is_cancelled() {
102 logger.debug(&format!("Wait was cancelled: {}", op_id), &log_info);
103 state.track_replay(&op_id.operation_id).await;
104 span.record("status", "replayed_cancelled");
105 return Ok(());
106 }
107
108 if checkpoint_result.is_existent() && !checkpoint_result.is_terminal() {
110 logger.debug(&format!("Wait still in progress: {}", op_id), &log_info);
111 span.record("status", "suspended");
112 return Err(DurableError::Suspend {
113 scheduled_timestamp: None,
114 });
115 }
116 }
117
118 let start_update = create_start_update(op_id, wait_seconds);
120 state.create_checkpoint(start_update, true).await?;
121
122 logger.debug(
123 &format!("Wait started for {} seconds", wait_seconds),
124 &log_info,
125 );
126
127 span.record("status", "suspended");
130 Err(DurableError::Suspend {
131 scheduled_timestamp: None,
132 })
133}
134
135pub async fn wait_cancel_handler(
156 state: &Arc<ExecutionState>,
157 operation_id: &str,
158 logger: &Arc<dyn Logger>,
159) -> Result<(), DurableError> {
160 let log_info = LogInfo::new(state.durable_execution_arn()).with_operation_id(operation_id);
161
162 logger.debug(
163 &format!("Attempting to cancel wait operation: {}", operation_id),
164 &log_info,
165 );
166
167 let checkpoint_result = state.get_checkpoint_result(operation_id).await;
169
170 if !checkpoint_result.is_existent() {
171 logger.debug(
174 &format!(
175 "Wait operation not found, nothing to cancel: {}",
176 operation_id
177 ),
178 &log_info,
179 );
180 return Ok(());
181 }
182
183 if let Some(op_type) = checkpoint_result.operation_type() {
185 if op_type != OperationType::Wait {
186 return Err(DurableError::Validation {
187 message: format!(
188 "Cannot cancel operation {}: expected WAIT operation but found {:?}",
189 operation_id, op_type
190 ),
191 });
192 }
193 }
194
195 if checkpoint_result.is_terminal() {
197 logger.debug(
200 &format!(
201 "Wait already completed, nothing to cancel: {}",
202 operation_id
203 ),
204 &log_info,
205 );
206 return Ok(());
207 }
208
209 let cancel_update = OperationUpdate::cancel(operation_id, OperationType::Wait);
211 state.create_checkpoint(cancel_update, true).await?;
212
213 logger.info(
214 &format!("Wait operation cancelled: {}", operation_id),
215 &log_info,
216 );
217
218 Ok(())
219}
220
221fn create_start_update(op_id: &OperationIdentifier, wait_seconds: u64) -> OperationUpdate {
223 op_id.apply_to(OperationUpdate::start_wait(
224 &op_id.operation_id,
225 wait_seconds,
226 ))
227}
228
#[cfg(test)]
mod tests {
    //! Unit tests for `wait_handler`, `wait_cancel_handler` and
    //! `create_start_update`, driven by a mock durable service client.

    use super::*;
    use crate::client::{
        CheckpointResponse, MockDurableServiceClient, NewExecutionState,
        SharedDurableServiceClient,
    };
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};

    /// Execution ARN shared by every test state.
    const TEST_EXECUTION_ARN: &str =
        "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    /// Operation id used by `create_test_op_id` and the wait replay fixtures.
    const TEST_OPERATION_ID: &str = "test-wait-123";

    /// Mock client with two queued successful checkpoint responses.
    fn create_mock_client() -> SharedDurableServiceClient {
        Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-2"))),
        )
    }

    /// Mock client with no queued checkpoint responses, for tests that are not
    /// expected to create a checkpoint.
    fn create_idle_client() -> SharedDurableServiceClient {
        Arc::new(MockDurableServiceClient::new())
    }

    /// Mock client with exactly one queued successful checkpoint response.
    fn create_single_response_client() -> SharedDurableServiceClient {
        Arc::new(
            MockDurableServiceClient::new()
                .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))),
        )
    }

    /// Builds an execution state from the given initial state and client.
    fn create_state_with(
        initial_state: InitialExecutionState,
        client: SharedDurableServiceClient,
    ) -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            TEST_EXECUTION_ARN,
            "initial-token",
            initial_state,
            client,
        ))
    }

    /// Builds an execution state with no previously recorded operations.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        create_state_with(InitialExecutionState::new(), client)
    }

    /// Builds an execution state whose history holds a single operation with
    /// the given id, type and status.
    fn create_replay_state(
        operation_id: &str,
        operation_type: OperationType,
        status: OperationStatus,
        client: SharedDurableServiceClient,
    ) -> Arc<ExecutionState> {
        let mut op = Operation::new(operation_id, operation_type);
        op.status = status;
        create_state_with(InitialExecutionState::with_operations(vec![op]), client)
    }

    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new(TEST_OPERATION_ID, None, Some("test-wait".to_string()))
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    /// A zero-second wait is rejected before any checkpoint lookup.
    #[tokio::test]
    async fn test_wait_handler_validation_error() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(0), &state, &op_id, &logger).await;

        match result {
            Err(DurableError::Validation { message }) => {
                assert!(message.contains("at least 1 second"));
            }
            _ => panic!("Expected Validation error"),
        }
    }

    /// A wait with no prior checkpoint is started and then suspends.
    #[tokio::test]
    async fn test_wait_handler_suspends_on_new_wait() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        assert!(
            matches!(result, Err(DurableError::Suspend { .. })),
            "Expected Suspend error"
        );
    }

    /// Replaying a wait recorded as succeeded returns `Ok`.
    #[tokio::test]
    async fn test_wait_handler_replay_completed() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Wait,
            OperationStatus::Succeeded,
            create_idle_client(),
        );
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        assert!(result.is_ok());
    }

    /// A recorded operation of a different type under the same id is reported
    /// as non-deterministic.
    #[tokio::test]
    async fn test_wait_handler_non_deterministic_detection() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Step,
            OperationStatus::Succeeded,
            create_idle_client(),
        );
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        match result {
            Err(DurableError::NonDeterministic { operation_id, .. }) => {
                assert_eq!(operation_id, Some(TEST_OPERATION_ID.to_string()));
            }
            _ => panic!("Expected NonDeterministic error"),
        }
    }

    /// Replaying a wait recorded as started suspends again.
    #[tokio::test]
    async fn test_wait_handler_replay_still_waiting() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Wait,
            OperationStatus::Started,
            create_idle_client(),
        );
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(3600), &state, &op_id, &logger).await;

        assert!(
            matches!(result, Err(DurableError::Suspend { .. })),
            "Expected Suspend error"
        );
    }

    /// The start update carries the id, type, wait seconds, parent id and name.
    #[test]
    fn test_create_start_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-wait".to_string()),
        );
        let update = create_start_update(&op_id, 60);

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Wait);
        assert!(update.wait_options.is_some());
        assert_eq!(update.wait_options.unwrap().wait_seconds, 60);
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-wait".to_string()));
    }

    /// Cancelling a started wait succeeds and the state reports it cancelled.
    #[tokio::test]
    async fn test_wait_cancel_handler_cancels_active_wait() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Wait,
            OperationStatus::Started,
            create_single_response_client(),
        );
        let logger = create_test_logger();

        let result = wait_cancel_handler(&state, TEST_OPERATION_ID, &logger).await;

        assert!(result.is_ok());

        let checkpoint_result = state.get_checkpoint_result(TEST_OPERATION_ID).await;
        assert!(checkpoint_result.is_cancelled());
    }

    /// Cancelling a wait that already succeeded returns `Ok`.
    #[tokio::test]
    async fn test_wait_cancel_handler_handles_already_completed_wait() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Wait,
            OperationStatus::Succeeded,
            create_idle_client(),
        );
        let logger = create_test_logger();

        let result = wait_cancel_handler(&state, TEST_OPERATION_ID, &logger).await;

        assert!(result.is_ok());
    }

    /// Cancelling an unknown operation id returns `Ok`.
    #[tokio::test]
    async fn test_wait_cancel_handler_handles_nonexistent_wait() {
        let state = create_test_state(create_idle_client());
        let logger = create_test_logger();

        let result = wait_cancel_handler(&state, "nonexistent-wait", &logger).await;

        assert!(result.is_ok());
    }

    /// Cancelling an operation that is not a wait is a validation error.
    #[tokio::test]
    async fn test_wait_cancel_handler_rejects_non_wait_operation() {
        let state = create_replay_state(
            "test-step-123",
            OperationType::Step,
            OperationStatus::Started,
            create_idle_client(),
        );
        let logger = create_test_logger();

        let result = wait_cancel_handler(&state, "test-step-123", &logger).await;

        match result {
            Err(DurableError::Validation { message }) => {
                assert!(message.contains("expected WAIT operation"));
            }
            _ => panic!("Expected Validation error"),
        }
    }

    /// Replaying a wait recorded as cancelled returns `Ok`.
    #[tokio::test]
    async fn test_wait_handler_replay_cancelled_wait() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Wait,
            OperationStatus::Cancelled,
            create_idle_client(),
        );
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        assert!(result.is_ok());
    }

    /// With no prior checkpoint the handler creates one (of type `Wait`) and
    /// suspends.
    #[tokio::test]
    async fn test_wait_handler_checks_status_before_checkpoint() {
        let state = create_test_state(create_single_response_client());
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        assert!(matches!(result, Err(DurableError::Suspend { .. })));

        let checkpoint_result = state.get_checkpoint_result(TEST_OPERATION_ID).await;
        assert!(
            checkpoint_result.is_existent(),
            "Checkpoint should have been created"
        );
        assert_eq!(
            checkpoint_result.operation_type(),
            Some(OperationType::Wait)
        );
    }

    /// An existing pending wait is detected and the handler suspends; the
    /// client has no queued response for a new checkpoint.
    #[tokio::test]
    async fn test_wait_handler_status_check_detects_existing_operation() {
        let state = create_replay_state(
            TEST_OPERATION_ID,
            OperationType::Wait,
            OperationStatus::Pending,
            create_idle_client(),
        );
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        assert!(matches!(result, Err(DurableError::Suspend { .. })));
    }

    /// When the checkpoint response already reports the wait as succeeded, the
    /// handler still suspends, and the state reflects the succeeded status.
    #[tokio::test]
    async fn test_wait_handler_immediate_completion_via_checkpoint_response() {
        let mut succeeded_op = Operation::new(TEST_OPERATION_ID, OperationType::Wait);
        succeeded_op.status = OperationStatus::Succeeded;

        let checkpoint_response = CheckpointResponse {
            checkpoint_token: "token-1".to_string(),
            new_execution_state: Some(NewExecutionState {
                operations: vec![succeeded_op],
                next_marker: None,
            }),
        };

        let client: SharedDurableServiceClient = Arc::new(
            MockDurableServiceClient::new().with_checkpoint_response(Ok(checkpoint_response)),
        );
        let state = create_test_state(client);
        let op_id = create_test_op_id();
        let logger = create_test_logger();

        let result = wait_handler(Duration::from_seconds(60), &state, &op_id, &logger).await;

        assert!(matches!(result, Err(DurableError::Suspend { .. })));

        let checkpoint_result = state.get_checkpoint_result(TEST_OPERATION_ID).await;
        assert!(
            checkpoint_result.is_succeeded(),
            "State should reflect succeeded status from checkpoint response"
        );
    }
}