1use std::sync::Arc;
7
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::config::{StepConfig, StepSemantics};
11use crate::context::{create_operation_span, LogInfo, Logger, OperationIdentifier};
12use crate::error::{DurableError, ErrorObject, StepResult, TerminationReason};
13use crate::operation::{OperationType, OperationUpdate};
14use crate::serdes::{JsonSerDes, SerDes, SerDesContext};
15use crate::state::{CheckpointedResult, ExecutionState};
16use crate::traits::{DurableValue, StepFn};
17
/// Context value handed to a step closure when it is invoked.
///
/// Bundles the identifiers of the step operation with the retry bookkeeping
/// (attempt counter and optional retry payload) for the current invocation.
#[derive(Clone, Debug)]
pub struct StepContext {
    /// Identifier of the step operation this context belongs to.
    pub operation_id: String,
    /// Identifier of the parent operation, when one was supplied.
    pub parent_id: Option<String>,
    /// Step name, when one was supplied.
    pub name: Option<String>,
    /// Durable execution ARN this step runs under.
    pub durable_execution_arn: String,
    /// Attempt counter for this step; `StepContext::new` starts it at 0.
    pub attempt: u32,
    /// Raw retry payload string, if any; `get_retry_payload` parses it as JSON.
    pub retry_payload: Option<String>,
}
67
68impl StepContext {
69 pub fn new(operation_id: impl Into<String>, durable_execution_arn: impl Into<String>) -> Self {
71 Self {
72 operation_id: operation_id.into(),
73 parent_id: None,
74 name: None,
75 durable_execution_arn: durable_execution_arn.into(),
76 attempt: 0,
77 retry_payload: None,
78 }
79 }
80
81 pub fn with_parent_id(mut self, parent_id: impl Into<String>) -> Self {
83 self.parent_id = Some(parent_id.into());
84 self
85 }
86
87 pub fn with_name(mut self, name: impl Into<String>) -> Self {
89 self.name = Some(name.into());
90 self
91 }
92
93 pub fn with_attempt(mut self, attempt: u32) -> Self {
95 self.attempt = attempt;
96 self
97 }
98
99 pub fn with_retry_payload(mut self, payload: impl Into<String>) -> Self {
102 self.retry_payload = Some(payload.into());
103 self
104 }
105
106 pub fn serdes_context(&self) -> SerDesContext {
108 SerDesContext::new(&self.operation_id, &self.durable_execution_arn)
109 }
110
111 pub fn get_retry_payload<T>(
149 &self,
150 ) -> Result<Option<T>, Box<dyn std::error::Error + Send + Sync>>
151 where
152 T: serde::de::DeserializeOwned,
153 {
154 match &self.retry_payload {
155 Some(payload) => {
156 let value: T = serde_json::from_str(payload)?;
157 Ok(Some(value))
158 }
159 None => Ok(None),
160 }
161 }
162}
163
164pub async fn step_handler<T, F>(
182 func: F,
183 state: &Arc<ExecutionState>,
184 op_id: &OperationIdentifier,
185 config: &StepConfig,
186 logger: &Arc<dyn Logger>,
187) -> StepResult<T>
188where
189 T: DurableValue,
190 F: StepFn<T>,
191{
192 let span = create_operation_span("step", op_id, state.durable_execution_arn());
195 let _guard = span.enter();
196
197 let mut log_info =
198 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
199 if let Some(ref parent_id) = op_id.parent_id {
200 log_info = log_info.with_parent_id(parent_id);
201 }
202
203 logger.debug(&format!("Starting step operation: {}", op_id), &log_info);
204
205 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
207
208 let skip_start_checkpoint = checkpoint_result.is_ready();
211
212 let attempt = checkpoint_result.attempt().unwrap_or(0);
215 let retry_payload = checkpoint_result.retry_payload().map(|s| s.to_string());
216
217 if let Some(result) = handle_replay::<T>(&checkpoint_result, state, op_id, logger).await? {
218 span.record("status", "replayed");
220 return Ok(result);
221 }
222
223 let mut step_ctx =
226 StepContext::new(&op_id.operation_id, state.durable_execution_arn()).with_attempt(attempt);
227 if let Some(ref parent_id) = op_id.parent_id {
228 step_ctx = step_ctx.with_parent_id(parent_id);
229 }
230 if let Some(ref name) = op_id.name {
231 step_ctx = step_ctx.with_name(name);
232 }
233 if let Some(payload) = retry_payload {
234 step_ctx = step_ctx.with_retry_payload(payload);
235 }
236
237 let serdes = JsonSerDes::<T>::new();
239 let serdes_ctx = step_ctx.serdes_context();
240
241 let result = match config.step_semantics {
243 StepSemantics::AtMostOncePerRetry => {
244 execute_at_most_once(
245 func,
246 state,
247 op_id,
248 &step_ctx,
249 &serdes,
250 &serdes_ctx,
251 config,
252 logger,
253 skip_start_checkpoint,
254 )
255 .await
256 }
257 StepSemantics::AtLeastOncePerRetry => {
258 execute_at_least_once(
259 func,
260 state,
261 op_id,
262 &step_ctx,
263 &serdes,
264 &serdes_ctx,
265 config,
266 logger,
267 )
268 .await
269 }
270 };
271
272 match &result {
275 Ok(_) => span.record("status", "succeeded"),
276 Err(_) => span.record("status", "failed"),
277 };
278
279 result
280}
281
282async fn handle_replay<T>(
284 checkpoint_result: &CheckpointedResult,
285 state: &Arc<ExecutionState>,
286 op_id: &OperationIdentifier,
287 logger: &Arc<dyn Logger>,
288) -> StepResult<Option<T>>
289where
290 T: Serialize + DeserializeOwned,
291{
292 if !checkpoint_result.is_existent() {
293 return Ok(None);
294 }
295
296 let mut log_info =
297 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
298 if let Some(ref parent_id) = op_id.parent_id {
299 log_info = log_info.with_parent_id(parent_id);
300 }
301
302 if let Some(op_type) = checkpoint_result.operation_type() {
304 if op_type != OperationType::Step {
305 return Err(DurableError::NonDeterministic {
306 message: format!(
307 "Expected Step operation but found {:?} at operation_id {}",
308 op_type, op_id.operation_id
309 ),
310 operation_id: Some(op_id.operation_id.clone()),
311 });
312 }
313 }
314
315 if checkpoint_result.is_succeeded() {
317 logger.debug(&format!("Replaying succeeded step: {}", op_id), &log_info);
318
319 state.track_replay(&op_id.operation_id).await;
321
322 if let Some(result_str) = checkpoint_result.result() {
324 let serdes = JsonSerDes::<T>::new();
325 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
326 let result =
327 serdes
328 .deserialize(result_str, &serdes_ctx)
329 .map_err(|e| DurableError::SerDes {
330 message: format!("Failed to deserialize checkpointed result: {}", e),
331 })?;
332
333 return Ok(Some(result));
334 } else {
335 let serdes = JsonSerDes::<T>::new();
337 let serdes_ctx = SerDesContext::new(&op_id.operation_id, state.durable_execution_arn());
338 match serdes.deserialize("null", &serdes_ctx) {
339 Ok(result) => return Ok(Some(result)),
340 Err(_) => {
341 return Err(DurableError::SerDes {
343 message:
344 "Step succeeded but no result was stored and type requires a value"
345 .to_string(),
346 });
347 }
348 }
349 }
350 }
351
352 if checkpoint_result.is_failed() {
354 logger.debug(&format!("Replaying failed step: {}", op_id), &log_info);
355
356 state.track_replay(&op_id.operation_id).await;
358
359 if let Some(error) = checkpoint_result.error() {
360 return Err(DurableError::UserCode {
361 message: error.error_message.clone(),
362 error_type: error.error_type.clone(),
363 stack_trace: error.stack_trace.clone(),
364 });
365 } else {
366 return Err(DurableError::execution("Step failed with unknown error"));
367 }
368 }
369
370 if checkpoint_result.is_terminal() {
372 state.track_replay(&op_id.operation_id).await;
373
374 let status = checkpoint_result.status().unwrap();
375 return Err(DurableError::Execution {
376 message: format!("Step was {}", status),
377 termination_reason: TerminationReason::StepInterrupted,
378 });
379 }
380
381 if checkpoint_result.is_ready() {
384 logger.debug(&format!("Resuming READY step: {}", op_id), &log_info);
385 return Ok(None);
388 }
389
390 if checkpoint_result.is_pending() {
393 logger.debug(
394 &format!("Step is PENDING, waiting for retry: {}", op_id),
395 &log_info,
396 );
397 return Err(DurableError::Suspend {
399 scheduled_timestamp: None,
400 });
401 }
402
403 Ok(None)
405}
406
407#[allow(clippy::too_many_arguments)]
413async fn execute_at_most_once<T, F>(
414 func: F,
415 state: &Arc<ExecutionState>,
416 op_id: &OperationIdentifier,
417 step_ctx: &StepContext,
418 serdes: &JsonSerDes<T>,
419 serdes_ctx: &SerDesContext,
420 config: &StepConfig,
421 logger: &Arc<dyn Logger>,
422 skip_start_checkpoint: bool,
423) -> StepResult<T>
424where
425 T: DurableValue,
426 F: StepFn<T>,
427{
428 let mut log_info =
429 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
430 if let Some(ref parent_id) = op_id.parent_id {
431 log_info = log_info.with_parent_id(parent_id);
432 }
433
434 if !skip_start_checkpoint {
437 logger.debug("Checkpointing step start (AT_MOST_ONCE)", &log_info);
438 let start_update = create_start_update(op_id);
439 state.create_checkpoint(start_update, true).await?;
440 } else {
441 logger.debug(
442 "Skipping START checkpoint for READY operation (AT_MOST_ONCE)",
443 &log_info,
444 );
445 }
446
447 let result = execute_with_retry(func, step_ctx.clone(), config, logger, &log_info);
449
450 match result {
452 Ok(value) => {
453 let serialized =
454 serdes
455 .serialize(&value, serdes_ctx)
456 .map_err(|e| DurableError::SerDes {
457 message: format!("Failed to serialize step result: {}", e),
458 })?;
459
460 let succeed_update = create_succeed_update(op_id, Some(serialized));
461 state.create_checkpoint(succeed_update, true).await?;
462
463 logger.debug("Step completed successfully", &log_info);
464 Ok(value)
465 }
466 Err(error) => {
467 let error_obj = ErrorObject::new("UserCodeError", error.to_string());
468 let fail_update = create_fail_update(op_id, error_obj);
469 state.create_checkpoint(fail_update, true).await?;
470
471 logger.error(&format!("Step failed: {}", error), &log_info);
472 Err(DurableError::UserCode {
473 message: error.to_string(),
474 error_type: "UserCodeError".to_string(),
475 stack_trace: None,
476 })
477 }
478 }
479}
480
481#[allow(clippy::too_many_arguments)]
485async fn execute_at_least_once<T, F>(
486 func: F,
487 state: &Arc<ExecutionState>,
488 op_id: &OperationIdentifier,
489 step_ctx: &StepContext,
490 serdes: &JsonSerDes<T>,
491 serdes_ctx: &SerDesContext,
492 config: &StepConfig,
493 logger: &Arc<dyn Logger>,
494) -> StepResult<T>
495where
496 T: DurableValue,
497 F: StepFn<T>,
498{
499 let mut log_info =
500 LogInfo::new(state.durable_execution_arn()).with_operation_id(&op_id.operation_id);
501 if let Some(ref parent_id) = op_id.parent_id {
502 log_info = log_info.with_parent_id(parent_id);
503 }
504
505 logger.debug("Executing step (AT_LEAST_ONCE)", &log_info);
506
507 let result = execute_with_retry(func, step_ctx.clone(), config, logger, &log_info);
509
510 match result {
512 Ok(value) => {
513 let serialized =
514 serdes
515 .serialize(&value, serdes_ctx)
516 .map_err(|e| DurableError::SerDes {
517 message: format!("Failed to serialize step result: {}", e),
518 })?;
519
520 let succeed_update = create_succeed_update(op_id, Some(serialized));
521 state.create_checkpoint(succeed_update, true).await?;
522
523 logger.debug("Step completed successfully", &log_info);
524 Ok(value)
525 }
526 Err(error) => {
527 let error_obj = ErrorObject::new("UserCodeError", error.to_string());
528 let fail_update = create_fail_update(op_id, error_obj);
529 state.create_checkpoint(fail_update, true).await?;
530
531 logger.error(&format!("Step failed: {}", error), &log_info);
532 Err(DurableError::UserCode {
533 message: error.to_string(),
534 error_type: "UserCodeError".to_string(),
535 stack_trace: None,
536 })
537 }
538 }
539}
540
541fn execute_with_retry<T, F>(
543 func: F,
544 step_ctx: StepContext,
545 config: &StepConfig,
546 logger: &Arc<dyn Logger>,
547 log_info: &LogInfo,
548) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
549where
550 T: Send,
551 F: FnOnce(StepContext) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + Send,
552{
553 if config.retry_strategy.is_some() {
558 logger.debug(
559 "Retry strategy configured but not yet implemented for consumed closures",
560 log_info,
561 );
562 }
563
564 let result = func(step_ctx);
565
566 if let Err(ref err) = result {
571 if let Some(ref filter) = config.retryable_error_filter {
572 let error_msg = err.to_string();
573 if !filter.is_retryable(&error_msg) {
574 logger.debug(
575 &format!(
576 "Error does not match retryable error filter, skipping retry: {}",
577 error_msg
578 ),
579 log_info,
580 );
581 return result;
583 }
584 logger.debug(
585 &format!("Error matches retryable error filter: {}", error_msg),
586 log_info,
587 );
588 }
589 }
590
591 result
592}
593
594fn create_start_update(op_id: &OperationIdentifier) -> OperationUpdate {
596 let mut update = OperationUpdate::start(&op_id.operation_id, OperationType::Step);
597 if let Some(ref parent_id) = op_id.parent_id {
598 update = update.with_parent_id(parent_id);
599 }
600 if let Some(ref name) = op_id.name {
601 update = update.with_name(name);
602 }
603 update
604}
605
606fn create_succeed_update(op_id: &OperationIdentifier, result: Option<String>) -> OperationUpdate {
608 let mut update = OperationUpdate::succeed(&op_id.operation_id, OperationType::Step, result);
609 if let Some(ref parent_id) = op_id.parent_id {
610 update = update.with_parent_id(parent_id);
611 }
612 if let Some(ref name) = op_id.name {
613 update = update.with_name(name);
614 }
615 update
616}
617
618fn create_fail_update(op_id: &OperationIdentifier, error: ErrorObject) -> OperationUpdate {
620 let mut update = OperationUpdate::fail(&op_id.operation_id, OperationType::Step, error);
621 if let Some(ref parent_id) = op_id.parent_id {
622 update = update.with_parent_id(parent_id);
623 }
624 if let Some(ref name) = op_id.name {
625 update = update.with_name(name);
626 }
627 update
628}
629
630#[allow(dead_code)]
641fn create_retry_update(
642 op_id: &OperationIdentifier,
643 payload: Option<String>,
644 next_attempt_delay_seconds: Option<u64>,
645) -> OperationUpdate {
646 let mut update = OperationUpdate::retry(
647 &op_id.operation_id,
648 OperationType::Step,
649 payload,
650 next_attempt_delay_seconds,
651 );
652 if let Some(ref parent_id) = op_id.parent_id {
653 update = update.with_parent_id(parent_id);
654 }
655 if let Some(ref name) = op_id.name {
656 update = update.with_name(name);
657 }
658 update
659}
660
661#[allow(dead_code)]
672fn create_retry_with_error_update(
673 op_id: &OperationIdentifier,
674 error: ErrorObject,
675 next_attempt_delay_seconds: Option<u64>,
676) -> OperationUpdate {
677 let mut update = OperationUpdate::retry_with_error(
678 &op_id.operation_id,
679 OperationType::Step,
680 error,
681 next_attempt_delay_seconds,
682 );
683 if let Some(ref parent_id) = op_id.parent_id {
684 update = update.with_parent_id(parent_id);
685 }
686 if let Some(ref name) = op_id.name {
687 update = update.with_name(name);
688 }
689 update
690}
691
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
    use crate::context::TracingLogger;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{Operation, OperationStatus};

    /// Execution ARN shared by every `ExecutionState` built in this module.
    const TEST_EXECUTION_ARN: &str =
        "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    /// Mock client primed with two successful checkpoint responses.
    fn create_mock_client() -> SharedDurableServiceClient {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-2")));
        Arc::new(client)
    }

    /// Execution state with no previously recorded operations.
    fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
        let initial_state = InitialExecutionState::new();
        Arc::new(ExecutionState::new(
            TEST_EXECUTION_ARN,
            "initial-token",
            initial_state,
            client,
        ))
    }

    /// Execution state pre-loaded with `operations`, backed by an unprimed mock client.
    fn create_replay_state(operations: Vec<Operation>) -> Arc<ExecutionState> {
        let client = Arc::new(MockDurableServiceClient::new());
        let initial_state = InitialExecutionState::with_operations(operations);
        Arc::new(ExecutionState::new(
            TEST_EXECUTION_ARN,
            "initial-token",
            initial_state,
            client,
        ))
    }

    /// Operation identifier used by the handler tests.
    fn create_test_op_id() -> OperationIdentifier {
        OperationIdentifier::new("test-op-123", None, Some("test-step".to_string()))
    }

    fn create_test_logger() -> Arc<dyn Logger> {
        Arc::new(TracingLogger)
    }

    // ---- StepContext ----

    #[test]
    fn test_step_context_new() {
        let ctx = StepContext::new("op-123", "arn:test");

        assert_eq!(ctx.operation_id, "op-123");
        assert_eq!(ctx.durable_execution_arn, "arn:test");
        assert!(ctx.parent_id.is_none());
        assert!(ctx.name.is_none());
        assert_eq!(ctx.attempt, 0);
    }

    #[test]
    fn test_step_context_with_parent_id() {
        let ctx = StepContext::new("op-123", "arn:test").with_parent_id("parent-456");
        assert_eq!(ctx.parent_id, Some("parent-456".to_string()));
    }

    #[test]
    fn test_step_context_with_name() {
        let ctx = StepContext::new("op-123", "arn:test").with_name("my-step");
        assert_eq!(ctx.name, Some("my-step".to_string()));
    }

    #[test]
    fn test_step_context_with_attempt() {
        let ctx = StepContext::new("op-123", "arn:test").with_attempt(3);
        assert_eq!(ctx.attempt, 3);
    }

    #[test]
    fn test_step_context_serdes_context() {
        let serdes_ctx = StepContext::new("op-123", "arn:test").serdes_context();

        assert_eq!(serdes_ctx.operation_id, "op-123");
        assert_eq!(serdes_ctx.durable_execution_arn, "arn:test");
    }

    #[test]
    fn test_step_context_with_retry_payload() {
        let raw = r#"{"counter": 5}"#;
        let ctx = StepContext::new("op-123", "arn:test").with_retry_payload(raw);
        assert_eq!(ctx.retry_payload, Some(raw.to_string()));
    }

    #[test]
    fn test_step_context_get_retry_payload() {
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct State {
            counter: i32,
        }

        let ctx = StepContext::new("op-123", "arn:test").with_retry_payload(r#"{"counter": 5}"#);

        let payload: Option<State> = ctx.get_retry_payload().unwrap();
        let state = payload.expect("payload should be present");
        assert_eq!(state.counter, 5);
    }

    #[test]
    fn test_step_context_get_retry_payload_none() {
        // `counter` is never read; the struct only provides a target type.
        #[derive(serde::Deserialize)]
        #[allow(dead_code)]
        struct State {
            counter: i32,
        }

        let ctx = StepContext::new("op-123", "arn:test");
        let payload: Option<State> = ctx.get_retry_payload().unwrap();
        assert!(payload.is_none());
    }

    // ---- retry update builders ----

    #[test]
    fn test_create_retry_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-step".to_string()),
        );
        let payload = r#"{"state": "waiting"}"#;

        let update = create_retry_update(&op_id, Some(payload.to_string()), Some(5));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.action, crate::operation::OperationAction::Retry);
        assert_eq!(update.operation_type, OperationType::Step);
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-step".to_string()));
        assert_eq!(update.result, Some(payload.to_string()));

        let step_options = update
            .step_options
            .as_ref()
            .expect("step options should be set");
        assert_eq!(step_options.next_attempt_delay_seconds, Some(5));
    }

    #[test]
    fn test_create_retry_with_error_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let error = ErrorObject::new("RetryableError", "Temporary failure");

        let update = create_retry_with_error_update(&op_id, error, Some(10));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.action, crate::operation::OperationAction::Retry);
        assert!(update.result.is_none());
        let stored = update.error.as_ref().expect("error should be set");
        assert_eq!(stored.error_type, "RetryableError");
    }

    // ---- step_handler: fresh execution ----

    #[tokio::test]
    async fn test_step_handler_success() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let result: Result<i32, DurableError> =
            step_handler(|_ctx| Ok(42), &state, &op_id, &config, &logger).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_step_handler_failure() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let result: Result<i32, DurableError> = step_handler(
            |_ctx| Err("test error".into()),
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::UserCode { message, .. } => assert!(message.contains("test error")),
            _ => panic!("Expected UserCode error"),
        }
    }

    // ---- step_handler: replay ----

    #[tokio::test]
    async fn test_step_handler_replay_success() {
        // A succeeded step with a stored result is already checkpointed.
        let mut op = Operation::new("test-op-123", OperationType::Step);
        op.status = OperationStatus::Succeeded;
        op.result = Some("42".to_string());
        let state = create_replay_state(vec![op]);

        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let result: Result<i32, DurableError> = step_handler(
            |_ctx| panic!("Function should not be called during replay"),
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_step_handler_replay_failure() {
        // A failed step with a stored error is already checkpointed.
        let mut op = Operation::new("test-op-123", OperationType::Step);
        op.status = OperationStatus::Failed;
        op.error = Some(ErrorObject::new("TestError", "Previous failure"));
        let state = create_replay_state(vec![op]);

        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let result: Result<i32, DurableError> = step_handler(
            |_ctx| panic!("Function should not be called during replay"),
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::UserCode { message, .. } => {
                assert!(message.contains("Previous failure"))
            }
            _ => panic!("Expected UserCode error"),
        }
    }

    #[tokio::test]
    async fn test_step_handler_non_deterministic_detection() {
        // The checkpoint under this id was recorded as a Wait, not a Step.
        let mut op = Operation::new("test-op-123", OperationType::Wait);
        op.status = OperationStatus::Succeeded;
        let state = create_replay_state(vec![op]);

        let op_id = create_test_op_id();
        let config = StepConfig::default();
        let logger = create_test_logger();

        let result: Result<i32, DurableError> =
            step_handler(|_ctx| Ok(42), &state, &op_id, &config, &logger).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            DurableError::NonDeterministic { operation_id, .. } => {
                assert_eq!(operation_id, Some("test-op-123".to_string()))
            }
            _ => panic!("Expected NonDeterministic error"),
        }
    }

    // ---- step_handler: explicit semantics ----

    #[tokio::test]
    async fn test_step_handler_at_most_once_semantics() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig {
            step_semantics: StepSemantics::AtMostOncePerRetry,
            ..Default::default()
        };
        let logger = create_test_logger();

        let result: Result<String, DurableError> = step_handler(
            |_ctx| Ok("at_most_once_result".to_string()),
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "at_most_once_result");
    }

    #[tokio::test]
    async fn test_step_handler_at_least_once_semantics() {
        let state = create_test_state(create_mock_client());
        let op_id = create_test_op_id();
        let config = StepConfig {
            step_semantics: StepSemantics::AtLeastOncePerRetry,
            ..Default::default()
        };
        let logger = create_test_logger();

        let result: Result<String, DurableError> = step_handler(
            |_ctx| Ok("at_least_once_result".to_string()),
            &state,
            &op_id,
            &config,
            &logger,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "at_least_once_result");
    }

    // ---- start / succeed / fail update builders ----

    #[test]
    fn test_create_start_update() {
        let op_id = OperationIdentifier::new(
            "op-123",
            Some("parent-456".to_string()),
            Some("my-step".to_string()),
        );

        let update = create_start_update(&op_id);

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.operation_type, OperationType::Step);
        assert_eq!(update.parent_id, Some("parent-456".to_string()));
        assert_eq!(update.name, Some("my-step".to_string()));
    }

    #[test]
    fn test_create_succeed_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);

        let update = create_succeed_update(&op_id, Some("result".to_string()));

        assert_eq!(update.operation_id, "op-123");
        assert_eq!(update.result, Some("result".to_string()));
    }

    #[test]
    fn test_create_fail_update() {
        let op_id = OperationIdentifier::new("op-123", None, None);
        let error = ErrorObject::new("TestError", "test message");

        let update = create_fail_update(&op_id, error);

        assert_eq!(update.operation_id, "op-123");
        assert!(update.error.is_some());
        assert_eq!(update.error.unwrap().error_type, "TestError");
    }
}
1064
1065#[cfg(test)]
1066mod property_tests {
1067 use super::*;
1068 use crate::client::{CheckpointResponse, MockDurableServiceClient, SharedDurableServiceClient};
1069 use crate::context::TracingLogger;
1070 use crate::lambda::InitialExecutionState;
1071 use crate::operation::{Operation, OperationStatus};
1072 use proptest::prelude::*;
1073
1074 mod step_semantics_tests {
1081 use super::*;
1082 use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1083
1084 fn create_test_state(client: SharedDurableServiceClient) -> Arc<ExecutionState> {
1085 Arc::new(ExecutionState::new(
1086 "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
1087 "initial-token",
1088 InitialExecutionState::new(),
1089 client,
1090 ))
1091 }
1092
1093 fn create_test_logger() -> Arc<dyn Logger> {
1094 Arc::new(TracingLogger)
1095 }
1096
1097 proptest! {
1098 #![proptest_config(ProptestConfig::with_cases(100))]
1099
1100 #[test]
1104 fn prop_at_most_once_checkpoints_before_execution(
1105 result_value in any::<i32>(),
1106 ) {
1107 let rt = tokio::runtime::Runtime::new().unwrap();
1108 rt.block_on(async {
1109 let checkpoint_order = Arc::new(AtomicU32::new(0));
1111 let execution_order = Arc::new(AtomicU32::new(0));
1112 let order_counter = Arc::new(AtomicU32::new(0));
1113
1114 let _checkpoint_order_clone = checkpoint_order.clone();
1115 let execution_order_clone = execution_order.clone();
1116 let order_counter_clone = order_counter.clone();
1117
1118 let client = Arc::new(MockDurableServiceClient::new()
1119 .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
1120 .with_checkpoint_response(Ok(CheckpointResponse::new("token-2"))));
1121
1122 let state = create_test_state(client);
1123 let op_id = OperationIdentifier::new(
1124 format!("test-op-{}", result_value),
1125 None,
1126 Some("test-step".to_string()),
1127 );
1128 let config = StepConfig {
1129 step_semantics: StepSemantics::AtMostOncePerRetry,
1130 ..Default::default()
1131 };
1132 let logger = create_test_logger();
1133
1134 let result: Result<i32, DurableError> = step_handler(
1135 move |_ctx| {
1136 let order = order_counter_clone.fetch_add(1, Ordering::SeqCst);
1138 execution_order_clone.store(order, Ordering::SeqCst);
1139 Ok(result_value)
1140 },
1141 &state,
1142 &op_id,
1143 &config,
1144 &logger,
1145 ).await;
1146
1147 prop_assert!(result.is_ok(), "Step should succeed");
1149 prop_assert_eq!(result.unwrap(), result_value, "Result should match input");
1150
1151 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
1153 prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
1154 prop_assert!(checkpoint_result.is_succeeded(), "Checkpoint should be succeeded");
1155
1156 Ok(())
1157 })?;
1158 }
1159
1160 #[test]
1164 fn prop_at_least_once_checkpoints_after_execution(
1165 result_value in any::<i32>(),
1166 ) {
1167 let rt = tokio::runtime::Runtime::new().unwrap();
1168 rt.block_on(async {
1169 let client = Arc::new(MockDurableServiceClient::new()
1170 .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))));
1171
1172 let state = create_test_state(client);
1173 let op_id = OperationIdentifier::new(
1174 format!("test-op-{}", result_value),
1175 None,
1176 Some("test-step".to_string()),
1177 );
1178 let config = StepConfig {
1179 step_semantics: StepSemantics::AtLeastOncePerRetry,
1180 ..Default::default()
1181 };
1182 let logger = create_test_logger();
1183
1184 let result: Result<i32, DurableError> = step_handler(
1185 move |_ctx| Ok(result_value),
1186 &state,
1187 &op_id,
1188 &config,
1189 &logger,
1190 ).await;
1191
1192 prop_assert!(result.is_ok(), "Step should succeed");
1194 prop_assert_eq!(result.unwrap(), result_value, "Result should match input");
1195
1196 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
1198 prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
1199 prop_assert!(checkpoint_result.is_succeeded(), "Checkpoint should be succeeded");
1200
1201 if let Some(result_str) = checkpoint_result.result() {
1203 let deserialized: i32 = serde_json::from_str(result_str).unwrap();
1204 prop_assert_eq!(deserialized, result_value, "Checkpointed result should match");
1205 }
1206
1207 Ok(())
1208 })?;
1209 }
1210
1211 #[test]
1215 fn prop_at_most_once_checkpoints_error_on_failure(
1216 error_msg in "[a-zA-Z0-9 ]{1,50}",
1217 ) {
1218 let rt = tokio::runtime::Runtime::new().unwrap();
1219 rt.block_on(async {
1220 let client = Arc::new(MockDurableServiceClient::new()
1221 .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
1222 .with_checkpoint_response(Ok(CheckpointResponse::new("token-2"))));
1223
1224 let state = create_test_state(client);
1225 let op_id = OperationIdentifier::new(
1226 format!("test-op-fail-{}", error_msg.len()),
1227 None,
1228 Some("test-step".to_string()),
1229 );
1230 let config = StepConfig {
1231 step_semantics: StepSemantics::AtMostOncePerRetry,
1232 ..Default::default()
1233 };
1234 let logger = create_test_logger();
1235
1236 let error_msg_clone = error_msg.clone();
1237 let result: Result<i32, DurableError> = step_handler(
1238 move |_ctx| Err(error_msg_clone.into()),
1239 &state,
1240 &op_id,
1241 &config,
1242 &logger,
1243 ).await;
1244
1245 prop_assert!(result.is_err(), "Step should fail");
1247
1248 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
1250 prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
1251 prop_assert!(checkpoint_result.is_failed(), "Checkpoint should be failed");
1252
1253 Ok(())
1254 })?;
1255 }
1256
1257 #[test]
1261 fn prop_at_least_once_checkpoints_error_on_failure(
1262 error_msg in "[a-zA-Z0-9 ]{1,50}",
1263 ) {
1264 let rt = tokio::runtime::Runtime::new().unwrap();
1265 rt.block_on(async {
1266 let client = Arc::new(MockDurableServiceClient::new()
1267 .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))));
1268
1269 let state = create_test_state(client);
1270 let op_id = OperationIdentifier::new(
1271 format!("test-op-fail-{}", error_msg.len()),
1272 None,
1273 Some("test-step".to_string()),
1274 );
1275 let config = StepConfig {
1276 step_semantics: StepSemantics::AtLeastOncePerRetry,
1277 ..Default::default()
1278 };
1279 let logger = create_test_logger();
1280
1281 let error_msg_clone = error_msg.clone();
1282 let result: Result<i32, DurableError> = step_handler(
1283 move |_ctx| Err(error_msg_clone.into()),
1284 &state,
1285 &op_id,
1286 &config,
1287 &logger,
1288 ).await;
1289
1290 prop_assert!(result.is_err(), "Step should fail");
1292
1293 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
1295 prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
1296 prop_assert!(checkpoint_result.is_failed(), "Checkpoint should be failed");
1297
1298 Ok(())
1299 })?;
1300 }
1301
1302 #[test]
1306 fn prop_replay_returns_checkpointed_result(
1307 result_value in any::<i32>(),
1308 semantics in prop_oneof![
1309 Just(StepSemantics::AtMostOncePerRetry),
1310 Just(StepSemantics::AtLeastOncePerRetry),
1311 ],
1312 ) {
1313 let rt = tokio::runtime::Runtime::new().unwrap();
1314 rt.block_on(async {
1315 let client = Arc::new(MockDurableServiceClient::new());
1316
1317 let mut op = Operation::new("test-op-replay", OperationType::Step);
1319 op.status = OperationStatus::Succeeded;
1320 op.result = Some(result_value.to_string());
1321
1322 let initial_state = InitialExecutionState::with_operations(vec![op]);
1323 let state = Arc::new(ExecutionState::new(
1324 "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
1325 "initial-token",
1326 initial_state,
1327 client,
1328 ));
1329
1330 let op_id = OperationIdentifier::new("test-op-replay", None, None);
1331 let config = StepConfig {
1332 step_semantics: semantics,
1333 ..Default::default()
1334 };
1335 let logger = create_test_logger();
1336
1337 let was_called = Arc::new(AtomicBool::new(false));
1339 let was_called_clone = was_called.clone();
1340
1341 let result: Result<i32, DurableError> = step_handler(
1342 move |_ctx| {
1343 was_called_clone.store(true, Ordering::SeqCst);
1344 Ok(999) },
1346 &state,
1347 &op_id,
1348 &config,
1349 &logger,
1350 ).await;
1351
1352 prop_assert!(result.is_ok(), "Replay should succeed");
1354 prop_assert_eq!(result.unwrap(), result_value, "Should return checkpointed value");
1355
1356 prop_assert!(!was_called.load(Ordering::SeqCst), "Function should not be called during replay");
1358
1359 Ok(())
1360 })?;
1361 }
1362
1363 #[test]
1369 fn prop_ready_status_resumes_without_start_checkpoint(
1370 result_value in any::<i32>(),
1371 ) {
1372 let rt = tokio::runtime::Runtime::new().unwrap();
1373 rt.block_on(async {
1374 let client = Arc::new(MockDurableServiceClient::new()
1376 .with_checkpoint_response(Ok(CheckpointResponse::new("token-1"))));
1377
1378 let mut op = Operation::new("test-op-ready", OperationType::Step);
1380 op.status = OperationStatus::Ready;
1381
1382 let initial_state = InitialExecutionState::with_operations(vec![op]);
1383 let state = Arc::new(ExecutionState::new(
1384 "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123",
1385 "initial-token",
1386 initial_state,
1387 client,
1388 ));
1389
1390 let op_id = OperationIdentifier::new("test-op-ready", None, None);
1391 let config = StepConfig {
1393 step_semantics: StepSemantics::AtMostOncePerRetry,
1394 ..Default::default()
1395 };
1396 let logger = create_test_logger();
1397
1398 let was_called = Arc::new(AtomicBool::new(false));
1400 let was_called_clone = was_called.clone();
1401
1402 let result: Result<i32, DurableError> = step_handler(
1403 move |_ctx| {
1404 was_called_clone.store(true, Ordering::SeqCst);
1405 Ok(result_value)
1406 },
1407 &state,
1408 &op_id,
1409 &config,
1410 &logger,
1411 ).await;
1412
1413 prop_assert!(result.is_ok(), "Step should succeed");
1415 prop_assert_eq!(result.unwrap(), result_value, "Result should match input");
1416
1417 prop_assert!(was_called.load(Ordering::SeqCst), "Function should be called for READY status");
1419
1420 let checkpoint_result = state.get_checkpoint_result(&op_id.operation_id).await;
1422 prop_assert!(checkpoint_result.is_existent(), "Checkpoint should exist");
1423 prop_assert!(checkpoint_result.is_succeeded(), "Checkpoint should be succeeded");
1424
1425 Ok(())
1426 })?;
1427 }
1428 }
1429 }
1430}