1use std::future::Future;
12use std::time::Duration;
13
14use aws_sdk_lambda::types::{
15 ErrorObject, OperationAction, OperationStatus, OperationType, OperationUpdate,
16};
17use serde::de::DeserializeOwned;
18use serde::Serialize;
19
20use crate::context::DurableContext;
21use crate::error::DurableError;
22use crate::types::StepOptions;
23
24impl DurableContext {
25 pub async fn step<T, E, F, Fut>(
68 &mut self,
69 name: &str,
70 f: F,
71 ) -> Result<Result<T, E>, DurableError>
72 where
73 T: Serialize + DeserializeOwned + Send + 'static,
74 E: Serialize + DeserializeOwned + Send + 'static,
75 F: FnOnce() -> Fut + Send + 'static,
76 Fut: Future<Output = Result<T, E>> + Send + 'static,
77 {
78 self.step_with_options(name, StepOptions::default(), f)
79 .await
80 }
81
82 #[allow(clippy::await_holding_lock)]
120 pub async fn step_with_options<T, E, F, Fut>(
121 &mut self,
122 name: &str,
123 options: StepOptions,
124 f: F,
125 ) -> Result<Result<T, E>, DurableError>
126 where
127 T: Serialize + DeserializeOwned + Send + 'static,
128 E: Serialize + DeserializeOwned + Send + 'static,
129 F: FnOnce() -> Fut + Send + 'static,
130 Fut: Future<Output = Result<T, E>> + Send + 'static,
131 {
132 let op_id = self.replay_engine_mut().generate_operation_id();
133
134 let span = tracing::info_span!(
135 "durable_operation",
136 op.name = name,
137 op.type = "step",
138 op.id = %op_id,
139 );
140 let _guard = span.enter();
141 tracing::trace!("durable_operation");
142
143 if let Some(operation) = self.replay_engine().check_result(&op_id) {
145 let result = extract_step_result::<T, E>(operation)?;
146 self.replay_engine_mut().track_replay(&op_id);
147 return Ok(result);
148 }
149
150 let is_retry_reexecution =
154 self.replay_engine()
155 .operations()
156 .get(&op_id)
157 .is_some_and(|op| {
158 matches!(
159 op.status,
160 OperationStatus::Pending
161 | OperationStatus::Ready
162 | OperationStatus::Started
163 )
164 });
165
166 let current_attempt = if is_retry_reexecution {
167 self.replay_engine()
169 .operations()
170 .get(&op_id)
171 .and_then(|op| op.step_details())
172 .map(|d| d.attempt())
173 .unwrap_or(1)
174 } else {
175 let start_update = OperationUpdate::builder()
177 .id(op_id.clone())
178 .r#type(OperationType::Step)
179 .action(OperationAction::Start)
180 .name(name)
181 .sub_type("Step")
182 .build()
183 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
184
185 if self.is_batch_mode() {
186 self.push_pending_update(start_update);
188 } else {
190 let start_response = self
191 .backend()
192 .checkpoint(
193 self.arn(),
194 self.checkpoint_token(),
195 vec![start_update],
196 None,
197 )
198 .await?;
199
200 let new_token = start_response.checkpoint_token().ok_or_else(|| {
201 DurableError::checkpoint_failed(
202 name,
203 std::io::Error::new(
204 std::io::ErrorKind::InvalidData,
205 "checkpoint response missing checkpoint_token",
206 ),
207 )
208 })?;
209 self.set_checkpoint_token(new_token.to_string());
210
211 if let Some(new_state) = start_response.new_execution_state() {
213 for op in new_state.operations() {
214 self.replay_engine_mut()
215 .insert_operation(op.id().to_string(), op.clone());
216 }
217 }
218
219 if let Some(operation) = self.replay_engine().check_result(&op_id) {
221 let result = extract_step_result::<T, E>(operation)?;
222 self.replay_engine_mut().track_replay(&op_id);
223 return Ok(result);
224 }
225 }
226
227 1 };
229
230 let name_owned = name.to_string();
236 let mut handle = tokio::spawn(async move { f().await });
237 let user_result = if let Some(secs) = options.get_timeout_seconds() {
238 match tokio::time::timeout(Duration::from_secs(secs), &mut handle).await {
239 Ok(join_result) => join_result.map_err(|join_err| {
240 DurableError::checkpoint_failed(
241 &name_owned,
242 std::io::Error::other(format!("step closure panicked: {join_err}")),
243 )
244 })?,
245 Err(_elapsed) => {
246 handle.abort();
247 return Err(DurableError::step_timeout(&name_owned));
248 }
249 }
250 } else {
251 handle.await.map_err(|join_err| {
252 DurableError::checkpoint_failed(
253 &name_owned,
254 std::io::Error::other(format!("step closure panicked: {join_err}")),
255 )
256 })?
257 };
258
259 match &user_result {
261 Ok(value) => {
262 let payload = serde_json::to_string(value)
263 .map_err(|e| DurableError::serialization(std::any::type_name::<T>(), e))?;
264
265 let succeed_update = OperationUpdate::builder()
266 .id(op_id.clone())
267 .r#type(OperationType::Step)
268 .action(OperationAction::Succeed)
269 .name(name)
270 .sub_type("Step")
271 .payload(payload)
272 .build()
273 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
274
275 if self.is_batch_mode() {
276 self.push_pending_update(succeed_update);
277 } else {
278 let response = self
279 .backend()
280 .checkpoint(
281 self.arn(),
282 self.checkpoint_token(),
283 vec![succeed_update],
284 None,
285 )
286 .await?;
287
288 let new_token = response.checkpoint_token().ok_or_else(|| {
289 DurableError::checkpoint_failed(
290 name,
291 std::io::Error::new(
292 std::io::ErrorKind::InvalidData,
293 "checkpoint response missing checkpoint_token",
294 ),
295 )
296 })?;
297 self.set_checkpoint_token(new_token.to_string());
298 }
299 }
300 Err(error) => {
301 let max_retries = options.get_retries().unwrap_or(0);
302
303 let should_retry = if let Some(pred) = options.get_retry_if() {
306 pred(error as &dyn std::any::Any)
307 } else {
308 true };
310
311 if should_retry && (current_attempt as u32) <= max_retries {
312 let delay = options.get_backoff_seconds().unwrap_or(0);
314 let aws_step_options = aws_sdk_lambda::types::StepOptions::builder()
315 .next_attempt_delay_seconds(delay)
316 .build();
317
318 let retry_update = OperationUpdate::builder()
319 .id(op_id.clone())
320 .r#type(OperationType::Step)
321 .action(OperationAction::Retry)
322 .name(name)
323 .sub_type("Step")
324 .step_options(aws_step_options)
325 .build()
326 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
327
328 if self.is_batch_mode() {
329 self.push_pending_update(retry_update);
331 self.flush_batch().await?;
332 } else {
333 let response = self
334 .backend()
335 .checkpoint(
336 self.arn(),
337 self.checkpoint_token(),
338 vec![retry_update],
339 None,
340 )
341 .await?;
342
343 let new_token = response.checkpoint_token().ok_or_else(|| {
344 DurableError::checkpoint_failed(
345 name,
346 std::io::Error::new(
347 std::io::ErrorKind::InvalidData,
348 "checkpoint response missing checkpoint_token",
349 ),
350 )
351 })?;
352 self.set_checkpoint_token(new_token.to_string());
353 }
354
355 return Err(DurableError::step_retry_scheduled(name));
356 }
357
358 let error_data = serde_json::to_string(error)
360 .map_err(|e| DurableError::serialization(std::any::type_name::<E>(), e))?;
361
362 let error_object = ErrorObject::builder()
363 .error_type(std::any::type_name::<E>())
364 .error_data(error_data)
365 .build();
366
367 let fail_update = OperationUpdate::builder()
368 .id(op_id.clone())
369 .r#type(OperationType::Step)
370 .action(OperationAction::Fail)
371 .name(name)
372 .sub_type("Step")
373 .error(error_object)
374 .build()
375 .map_err(|e| DurableError::checkpoint_failed(name, e))?;
376
377 if self.is_batch_mode() {
378 self.push_pending_update(fail_update);
379 } else {
380 let response = self
381 .backend()
382 .checkpoint(self.arn(), self.checkpoint_token(), vec![fail_update], None)
383 .await?;
384
385 let new_token = response.checkpoint_token().ok_or_else(|| {
386 DurableError::checkpoint_failed(
387 name,
388 std::io::Error::new(
389 std::io::ErrorKind::InvalidData,
390 "checkpoint response missing checkpoint_token",
391 ),
392 )
393 })?;
394 self.set_checkpoint_token(new_token.to_string());
395 }
396 }
397 }
398
399 Ok(user_result)
400 }
401}
402
403fn extract_step_result<T, E>(
408 operation: &aws_sdk_lambda::types::Operation,
409) -> Result<Result<T, E>, DurableError>
410where
411 T: DeserializeOwned,
412 E: DeserializeOwned,
413{
414 match &operation.status {
415 OperationStatus::Succeeded => {
416 let result_json = operation
417 .step_details()
418 .and_then(|d| d.result())
419 .ok_or_else(|| {
420 DurableError::checkpoint_failed(
421 "step",
422 std::io::Error::new(
423 std::io::ErrorKind::InvalidData,
424 "SUCCEEDED operation missing step_details.result",
425 ),
426 )
427 })?;
428
429 let value: T = serde_json::from_str(result_json)
430 .map_err(|e| DurableError::deserialization(std::any::type_name::<T>(), e))?;
431 Ok(Ok(value))
432 }
433 OperationStatus::Failed => {
434 let error_data = operation
435 .step_details()
436 .and_then(|d| d.error())
437 .and_then(|e| e.error_data())
438 .ok_or_else(|| {
439 DurableError::checkpoint_failed(
440 "step",
441 std::io::Error::new(
442 std::io::ErrorKind::InvalidData,
443 "FAILED operation missing step_details.error.error_data",
444 ),
445 )
446 })?;
447
448 let error: E = serde_json::from_str(error_data)
449 .map_err(|e| DurableError::deserialization(std::any::type_name::<E>(), e))?;
450 Ok(Err(error))
451 }
452 other => Err(DurableError::replay_mismatch(
453 "Succeeded or Failed",
454 format!("{other:?}"),
455 0,
456 )),
457 }
458}
459
460#[cfg(test)]
461mod tests {
462 use std::sync::Arc;
463
464 use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
465 use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
466 use aws_sdk_lambda::types::{
467 ErrorObject, Operation, OperationStatus, OperationType, OperationUpdate, StepDetails,
468 };
469 use aws_smithy_types::DateTime;
470 use serde::{Deserialize, Serialize};
471 use tokio::sync::Mutex;
472 use tracing_test::traced_test;
473
474 use crate::backend::DurableBackend;
475 use crate::context::DurableContext;
476 use crate::error::DurableError;
477 use crate::operation_id::OperationIdGenerator;
478 use crate::types::StepOptions;
479
480 #[derive(Debug, Clone)]
482 #[allow(dead_code)]
483 struct CheckpointCall {
484 arn: String,
485 checkpoint_token: String,
486 updates: Vec<OperationUpdate>,
487 }
488
489 struct MockBackend {
491 calls: Arc<Mutex<Vec<CheckpointCall>>>,
492 checkpoint_token: String,
493 }
494
495 impl MockBackend {
496 fn new(checkpoint_token: &str) -> (Self, Arc<Mutex<Vec<CheckpointCall>>>) {
497 let calls = Arc::new(Mutex::new(Vec::new()));
498 let backend = Self {
499 calls: calls.clone(),
500 checkpoint_token: checkpoint_token.to_string(),
501 };
502 (backend, calls)
503 }
504 }
505
506 #[async_trait::async_trait]
507 impl DurableBackend for MockBackend {
508 async fn checkpoint(
509 &self,
510 arn: &str,
511 checkpoint_token: &str,
512 updates: Vec<OperationUpdate>,
513 _client_token: Option<&str>,
514 ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
515 self.calls.lock().await.push(CheckpointCall {
516 arn: arn.to_string(),
517 checkpoint_token: checkpoint_token.to_string(),
518 updates,
519 });
520 Ok(CheckpointDurableExecutionOutput::builder()
521 .checkpoint_token(&self.checkpoint_token)
522 .build())
523 }
524
525 async fn get_execution_state(
526 &self,
527 _arn: &str,
528 _checkpoint_token: &str,
529 _next_marker: &str,
530 _max_items: i32,
531 ) -> Result<GetDurableExecutionStateOutput, DurableError> {
532 Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
533 }
534 }
535
536 #[tokio::test]
537 async fn test_step_executes_closure_in_executing_mode() {
538 let (backend, calls) = MockBackend::new("new-token");
539 let backend = Arc::new(backend);
540
541 let mut ctx = DurableContext::new(
542 backend,
543 "arn:test".to_string(),
544 "initial-token".to_string(),
545 vec![],
546 None,
547 )
548 .await
549 .unwrap();
550
551 let result: Result<i32, String> = ctx.step("my_step", || async { Ok(42) }).await.unwrap();
552
553 assert_eq!(result.unwrap(), 42);
554
555 let captured = calls.lock().await;
556 assert_eq!(captured.len(), 2, "expected START + SUCCEED checkpoints");
557
558 let start_call = &captured[0];
560 assert_eq!(start_call.updates.len(), 1);
561 let start_update = &start_call.updates[0];
562 assert_eq!(start_update.r#type(), &OperationType::Step);
563 assert_eq!(
564 start_update.action(),
565 &aws_sdk_lambda::types::OperationAction::Start
566 );
567 assert_eq!(start_update.name(), Some("my_step"));
568
569 let succeed_call = &captured[1];
571 assert_eq!(succeed_call.updates.len(), 1);
572 let succeed_update = &succeed_call.updates[0];
573 assert_eq!(succeed_update.r#type(), &OperationType::Step);
574 assert_eq!(
575 succeed_update.action(),
576 &aws_sdk_lambda::types::OperationAction::Succeed
577 );
578 assert_eq!(succeed_update.payload().unwrap(), "42");
579
580 assert_eq!(succeed_call.checkpoint_token, "new-token");
582 }
583
584 #[tokio::test]
585 async fn test_step_returns_cached_result_in_replaying_mode() {
586 let (backend, calls) = MockBackend::new("new-token");
587 let backend = Arc::new(backend);
588
589 let mut gen = OperationIdGenerator::new(None);
591 let expected_op_id = gen.next_id();
592
593 let cached_op = Operation::builder()
594 .id(&expected_op_id)
595 .r#type(OperationType::Step)
596 .status(OperationStatus::Succeeded)
597 .start_timestamp(DateTime::from_secs(0))
598 .step_details(
599 StepDetails::builder()
600 .attempt(1)
601 .result(r#"{"value":42}"#)
602 .build(),
603 )
604 .build()
605 .unwrap();
606
607 let mut ctx = DurableContext::new(
608 backend,
609 "arn:test".to_string(),
610 "initial-token".to_string(),
611 vec![cached_op],
612 None,
613 )
614 .await
615 .unwrap();
616
617 let closure_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
619 let closure_called_clone = closure_called.clone();
620
621 #[derive(Serialize, Deserialize, Debug, PartialEq)]
622 struct MyResult {
623 value: i32,
624 }
625
626 let result: Result<MyResult, String> = ctx
627 .step("my_step", move || {
628 let flag = closure_called_clone.clone();
629 async move {
630 flag.store(true, std::sync::atomic::Ordering::SeqCst);
631 Ok(MyResult { value: 999 })
632 }
633 })
634 .await
635 .unwrap();
636
637 assert_eq!(result.unwrap(), MyResult { value: 42 });
638 assert!(
639 !closure_called.load(std::sync::atomic::Ordering::SeqCst),
640 "closure should NOT have been called during replay"
641 );
642
643 let captured = calls.lock().await;
645 assert_eq!(captured.len(), 0, "no checkpoint calls during replay");
646 }
647
648 #[tokio::test]
649 async fn test_step_returns_cached_error_in_replaying_mode() {
650 let (backend, _calls) = MockBackend::new("new-token");
651 let backend = Arc::new(backend);
652
653 let mut gen = OperationIdGenerator::new(None);
654 let expected_op_id = gen.next_id();
655
656 #[derive(Serialize, Deserialize, Debug, PartialEq)]
657 struct MyError {
658 code: i32,
659 message: String,
660 }
661
662 let error_data = serde_json::to_string(&MyError {
663 code: 404,
664 message: "not found".to_string(),
665 })
666 .unwrap();
667
668 let cached_op = Operation::builder()
669 .id(&expected_op_id)
670 .r#type(OperationType::Step)
671 .status(OperationStatus::Failed)
672 .start_timestamp(DateTime::from_secs(0))
673 .step_details(
674 StepDetails::builder()
675 .attempt(1)
676 .error(
677 ErrorObject::builder()
678 .error_type("MyError")
679 .error_data(&error_data)
680 .build(),
681 )
682 .build(),
683 )
684 .build()
685 .unwrap();
686
687 let mut ctx = DurableContext::new(
688 backend,
689 "arn:test".to_string(),
690 "initial-token".to_string(),
691 vec![cached_op],
692 None,
693 )
694 .await
695 .unwrap();
696
697 let result: Result<String, MyError> = ctx
698 .step("my_step", || async { Ok("nope".to_string()) })
699 .await
700 .unwrap();
701
702 let err = result.unwrap_err();
703 assert_eq!(err.code, 404);
704 assert_eq!(err.message, "not found");
705 }
706
707 #[tokio::test]
708 async fn test_step_serialization_roundtrip() {
709 let (backend, _calls) = MockBackend::new("new-token");
710 let backend = Arc::new(backend);
711
712 #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
713 struct ComplexData {
714 name: String,
715 values: Vec<i32>,
716 nested: NestedData,
717 optional: Option<String>,
718 }
719
720 #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
721 struct NestedData {
722 flag: bool,
723 score: f64,
724 }
725
726 let expected = ComplexData {
727 name: "test".to_string(),
728 values: vec![1, 2, 3],
729 nested: NestedData {
730 flag: true,
731 score: 99.5,
732 },
733 optional: Some("present".to_string()),
734 };
735
736 let mut gen = OperationIdGenerator::new(None);
738 let expected_op_id = gen.next_id();
739
740 let serialized = serde_json::to_string(&expected).unwrap();
741
742 let cached_op = Operation::builder()
743 .id(&expected_op_id)
744 .r#type(OperationType::Step)
745 .status(OperationStatus::Succeeded)
746 .start_timestamp(DateTime::from_secs(0))
747 .step_details(
748 StepDetails::builder()
749 .attempt(1)
750 .result(&serialized)
751 .build(),
752 )
753 .build()
754 .unwrap();
755
756 let mut ctx = DurableContext::new(
757 backend,
758 "arn:test".to_string(),
759 "initial-token".to_string(),
760 vec![cached_op],
761 None,
762 )
763 .await
764 .unwrap();
765
766 let result: Result<ComplexData, String> = ctx
767 .step("complex_step", || async {
768 panic!("should not execute during replay")
769 })
770 .await
771 .unwrap();
772
773 assert_eq!(result.unwrap(), expected);
774 }
775
776 #[tokio::test]
777 async fn test_step_sequential_unique_ids() {
778 let (backend, calls) = MockBackend::new("new-token");
779 let backend = Arc::new(backend);
780
781 let mut ctx = DurableContext::new(
782 backend,
783 "arn:test".to_string(),
784 "initial-token".to_string(),
785 vec![],
786 None,
787 )
788 .await
789 .unwrap();
790
791 let _r1: Result<i32, String> = ctx.step("step_1", || async { Ok(1) }).await.unwrap();
792 let _r2: Result<i32, String> = ctx.step("step_2", || async { Ok(2) }).await.unwrap();
793
794 let captured = calls.lock().await;
795 assert_eq!(captured.len(), 4);
797
798 let step1_id = captured[0].updates[0].id().to_string();
800 let step2_id = captured[2].updates[0].id().to_string();
801
802 assert_ne!(
803 step1_id, step2_id,
804 "sequential steps must have different operation IDs"
805 );
806
807 assert_eq!(step1_id, captured[1].updates[0].id());
809 assert_eq!(step2_id, captured[3].updates[0].id());
810 }
811
812 #[tokio::test]
813 async fn test_step_tracks_replay() {
814 let (backend, _calls) = MockBackend::new("new-token");
815 let backend = Arc::new(backend);
816
817 let mut gen = OperationIdGenerator::new(None);
819 let expected_op_id = gen.next_id();
820
821 let cached_op = Operation::builder()
822 .id(&expected_op_id)
823 .r#type(OperationType::Step)
824 .status(OperationStatus::Succeeded)
825 .start_timestamp(DateTime::from_secs(0))
826 .step_details(StepDetails::builder().attempt(1).result("100").build())
827 .build()
828 .unwrap();
829
830 let mut ctx = DurableContext::new(
831 backend,
832 "arn:test".to_string(),
833 "initial-token".to_string(),
834 vec![cached_op],
835 None,
836 )
837 .await
838 .unwrap();
839
840 assert!(
842 ctx.is_replaying(),
843 "should be replaying before visiting cached ops"
844 );
845
846 let result: Result<i32, String> =
847 ctx.step("cached_step", || async { Ok(999) }).await.unwrap();
848 assert_eq!(result.unwrap(), 100);
849
850 assert!(
852 !ctx.is_replaying(),
853 "should transition to executing after all cached ops replayed"
854 );
855 }
856
857 #[tokio::test]
858 async fn test_step_with_options_basic_success() {
859 let (backend, calls) = MockBackend::new("new-token");
860 let backend = Arc::new(backend);
861
862 let mut ctx = DurableContext::new(
863 backend,
864 "arn:test".to_string(),
865 "initial-token".to_string(),
866 vec![],
867 None,
868 )
869 .await
870 .unwrap();
871
872 let result: Result<i32, String> = ctx
873 .step_with_options("opts_step", StepOptions::default(), || async { Ok(42) })
874 .await
875 .unwrap();
876
877 assert_eq!(result.unwrap(), 42);
878
879 let captured = calls.lock().await;
880 assert_eq!(captured.len(), 2, "expected START + SUCCEED checkpoints");
881
882 let start_update = &captured[0].updates[0];
883 assert_eq!(start_update.r#type(), &OperationType::Step);
884 assert_eq!(
885 start_update.action(),
886 &aws_sdk_lambda::types::OperationAction::Start
887 );
888 assert_eq!(start_update.name(), Some("opts_step"));
889
890 let succeed_update = &captured[1].updates[0];
891 assert_eq!(succeed_update.r#type(), &OperationType::Step);
892 assert_eq!(
893 succeed_update.action(),
894 &aws_sdk_lambda::types::OperationAction::Succeed
895 );
896 assert_eq!(succeed_update.payload().unwrap(), "42");
897 }
898
899 #[tokio::test]
900 async fn test_step_with_options_retry_on_failure() {
901 let (backend, calls) = MockBackend::new("new-token");
902 let backend = Arc::new(backend);
903
904 let mut ctx = DurableContext::new(
905 backend,
906 "arn:test".to_string(),
907 "initial-token".to_string(),
908 vec![],
909 None,
910 )
911 .await
912 .unwrap();
913
914 let options = StepOptions::new().retries(3).backoff_seconds(5);
915 let result: Result<Result<i32, String>, DurableError> = ctx
916 .step_with_options("retry_step", options, || async {
917 Err("transient failure".to_string())
918 })
919 .await;
920
921 let err = result.unwrap_err();
923 match err {
924 DurableError::StepRetryScheduled { .. } => {}
925 other => panic!("expected StepRetryScheduled, got {other:?}"),
926 }
927
928 let captured = calls.lock().await;
929 assert_eq!(captured.len(), 2, "expected START + RETRY checkpoints");
930
931 let start_update = &captured[0].updates[0];
933 assert_eq!(
934 start_update.action(),
935 &aws_sdk_lambda::types::OperationAction::Start
936 );
937
938 let retry_update = &captured[1].updates[0];
940 assert_eq!(
941 retry_update.action(),
942 &aws_sdk_lambda::types::OperationAction::Retry
943 );
944 let step_opts = retry_update
945 .step_options()
946 .expect("should have step_options");
947 assert_eq!(step_opts.next_attempt_delay_seconds(), Some(5));
948 }
949
950 #[tokio::test]
951 async fn test_step_with_options_retry_exhaustion() {
952 let (backend, calls) = MockBackend::new("new-token");
953 let backend = Arc::new(backend);
954
955 let mut gen = OperationIdGenerator::new(None);
957 let expected_op_id = gen.next_id();
958
959 let cached_op = Operation::builder()
961 .id(&expected_op_id)
962 .r#type(OperationType::Step)
963 .status(OperationStatus::Pending)
964 .start_timestamp(DateTime::from_secs(0))
965 .step_details(StepDetails::builder().attempt(4).build())
966 .build()
967 .unwrap();
968
969 let mut ctx = DurableContext::new(
970 backend,
971 "arn:test".to_string(),
972 "initial-token".to_string(),
973 vec![cached_op],
974 None,
975 )
976 .await
977 .unwrap();
978
979 let options = StepOptions::new().retries(3).backoff_seconds(5);
980 let result: Result<Result<i32, String>, DurableError> = ctx
981 .step_with_options("exhaust_step", options, || async {
982 Err("final failure".to_string())
983 })
984 .await;
985
986 let inner = result.unwrap();
988 let user_error = inner.unwrap_err();
989 assert_eq!(user_error, "final failure");
990
991 let captured = calls.lock().await;
993 assert_eq!(captured.len(), 1, "expected only FAIL checkpoint");
994
995 let fail_update = &captured[0].updates[0];
996 assert_eq!(
997 fail_update.action(),
998 &aws_sdk_lambda::types::OperationAction::Fail
999 );
1000 }
1001
1002 #[tokio::test]
1003 async fn test_step_with_options_replay_succeeded_with_retries() {
1004 let (backend, calls) = MockBackend::new("new-token");
1005 let backend = Arc::new(backend);
1006
1007 let mut gen = OperationIdGenerator::new(None);
1008 let expected_op_id = gen.next_id();
1009
1010 let cached_op = Operation::builder()
1012 .id(&expected_op_id)
1013 .r#type(OperationType::Step)
1014 .status(OperationStatus::Succeeded)
1015 .start_timestamp(DateTime::from_secs(0))
1016 .step_details(StepDetails::builder().attempt(3).result("99").build())
1017 .build()
1018 .unwrap();
1019
1020 let mut ctx = DurableContext::new(
1021 backend,
1022 "arn:test".to_string(),
1023 "initial-token".to_string(),
1024 vec![cached_op],
1025 None,
1026 )
1027 .await
1028 .unwrap();
1029
1030 let closure_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
1031 let closure_called_clone = closure_called.clone();
1032
1033 let options = StepOptions::new().retries(3);
1034 let result: Result<i32, String> = ctx
1035 .step_with_options("replay_retry_step", options, move || {
1036 let flag = closure_called_clone.clone();
1037 async move {
1038 flag.store(true, std::sync::atomic::Ordering::SeqCst);
1039 Ok(999)
1040 }
1041 })
1042 .await
1043 .unwrap();
1044
1045 assert_eq!(result.unwrap(), 99);
1046 assert!(
1047 !closure_called.load(std::sync::atomic::Ordering::SeqCst),
1048 "closure should NOT have been called during replay"
1049 );
1050
1051 let captured = calls.lock().await;
1052 assert_eq!(captured.len(), 0, "no checkpoint calls during replay");
1053 }
1054
1055 #[tokio::test]
1056 async fn test_step_backward_compatibility() {
1057 let (backend, calls) = MockBackend::new("compat-token");
1058 let backend = Arc::new(backend);
1059
1060 let mut ctx = DurableContext::new(
1061 backend,
1062 "arn:test".to_string(),
1063 "initial-token".to_string(),
1064 vec![],
1065 None,
1066 )
1067 .await
1068 .unwrap();
1069
1070 let result: Result<String, String> = ctx
1072 .step("compat_step", || async { Ok("hello".to_string()) })
1073 .await
1074 .unwrap();
1075
1076 assert_eq!(result.unwrap(), "hello");
1077
1078 let captured = calls.lock().await;
1079 assert_eq!(captured.len(), 2, "expected START + SUCCEED checkpoints");
1080
1081 let start_update = &captured[0].updates[0];
1082 assert_eq!(
1083 start_update.action(),
1084 &aws_sdk_lambda::types::OperationAction::Start
1085 );
1086 assert_eq!(start_update.name(), Some("compat_step"));
1087
1088 let succeed_update = &captured[1].updates[0];
1089 assert_eq!(
1090 succeed_update.action(),
1091 &aws_sdk_lambda::types::OperationAction::Succeed
1092 );
1093 assert_eq!(succeed_update.payload().unwrap(), r#""hello""#);
1094 }
1095
1096 #[test]
1097 fn test_step_options_builder() {
1098 let default_opts = StepOptions::default();
1100 assert_eq!(default_opts.get_retries(), None);
1101 assert_eq!(default_opts.get_backoff_seconds(), None);
1102
1103 let new_opts = StepOptions::new();
1105 assert_eq!(new_opts.get_retries(), None);
1106 assert_eq!(new_opts.get_backoff_seconds(), None);
1107
1108 let opts = StepOptions::new().retries(5).backoff_seconds(10);
1110 assert_eq!(opts.get_retries(), Some(5));
1111 assert_eq!(opts.get_backoff_seconds(), Some(10));
1112
1113 let opts2 = StepOptions::new().retries(1).retries(3);
1115 assert_eq!(opts2.get_retries(), Some(3));
1116 }
1117
1118 #[tokio::test]
1119 async fn test_step_with_options_typed_error_roundtrip() {
1120 let (backend, calls) = MockBackend::new("new-token");
1121 let backend = Arc::new(backend);
1122
1123 #[derive(Serialize, Deserialize, Debug, PartialEq)]
1124 enum DomainError {
1125 NotFound { resource: String },
1126 PermissionDenied { user: String, action: String },
1127 RateLimited { retry_after_secs: u64 },
1128 }
1129
1130 let mut gen = OperationIdGenerator::new(None);
1132 let expected_op_id = gen.next_id();
1133
1134 let original_error = DomainError::PermissionDenied {
1135 user: "alice".to_string(),
1136 action: "delete".to_string(),
1137 };
1138 let error_data = serde_json::to_string(&original_error).unwrap();
1139
1140 let cached_op = Operation::builder()
1142 .id(&expected_op_id)
1143 .r#type(OperationType::Step)
1144 .status(OperationStatus::Failed)
1145 .start_timestamp(DateTime::from_secs(0))
1146 .step_details(
1147 StepDetails::builder()
1148 .attempt(1)
1149 .error(
1150 ErrorObject::builder()
1151 .error_type("DomainError")
1152 .error_data(&error_data)
1153 .build(),
1154 )
1155 .build(),
1156 )
1157 .build()
1158 .unwrap();
1159
1160 let mut ctx = DurableContext::new(
1161 backend,
1162 "arn:test".to_string(),
1163 "initial-token".to_string(),
1164 vec![cached_op],
1165 None,
1166 )
1167 .await
1168 .unwrap();
1169
1170 let result: Result<String, DomainError> = ctx
1171 .step_with_options("typed_err_step", StepOptions::default(), || async {
1172 Ok("should not run".to_string())
1173 })
1174 .await
1175 .unwrap();
1176
1177 let err = result.unwrap_err();
1178 assert_eq!(
1179 err,
1180 DomainError::PermissionDenied {
1181 user: "alice".to_string(),
1182 action: "delete".to_string(),
1183 }
1184 );
1185
1186 let captured = calls.lock().await;
1188 assert_eq!(captured.len(), 0, "no checkpoint calls during replay");
1189 }
1190
1191 #[tokio::test]
1192 async fn test_step_execute_fail_checkpoint() {
1193 let (backend, calls) = MockBackend::new("new-token");
1194 let backend = Arc::new(backend);
1195
1196 let mut ctx = DurableContext::new(
1197 backend,
1198 "arn:test".to_string(),
1199 "initial-token".to_string(),
1200 vec![],
1201 None,
1202 )
1203 .await
1204 .unwrap();
1205
1206 let result: Result<i32, String> = ctx
1208 .step("failing_step", || async {
1209 Err("something went wrong".to_string())
1210 })
1211 .await
1212 .unwrap();
1213
1214 assert_eq!(result.unwrap_err(), "something went wrong");
1215
1216 let captured = calls.lock().await;
1217 assert_eq!(captured.len(), 2, "expected START + FAIL checkpoints");
1218
1219 assert_eq!(
1221 captured[0].updates[0].action(),
1222 &aws_sdk_lambda::types::OperationAction::Start
1223 );
1224
1225 assert_eq!(
1227 captured[1].updates[0].action(),
1228 &aws_sdk_lambda::types::OperationAction::Fail
1229 );
1230 }
1231
1232 struct NoneTokenMockBackend;
1235
1236 #[async_trait::async_trait]
1237 impl DurableBackend for NoneTokenMockBackend {
1238 async fn checkpoint(
1239 &self,
1240 _arn: &str,
1241 _checkpoint_token: &str,
1242 _updates: Vec<OperationUpdate>,
1243 _client_token: Option<&str>,
1244 ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
1245 Ok(CheckpointDurableExecutionOutput::builder().build())
1247 }
1248
1249 async fn get_execution_state(
1250 &self,
1251 _arn: &str,
1252 _checkpoint_token: &str,
1253 _next_marker: &str,
1254 _max_items: i32,
1255 ) -> Result<GetDurableExecutionStateOutput, DurableError> {
1256 Ok(GetDurableExecutionStateOutput::builder().build().unwrap())
1257 }
1258 }
1259
1260 #[tokio::test]
1263 async fn test_step_timeout_aborts_slow_closure() {
1264 let (backend, _calls) = MockBackend::new("new-token");
1265 let backend = Arc::new(backend);
1266
1267 let mut ctx = DurableContext::new(
1268 backend,
1269 "arn:test".to_string(),
1270 "initial-token".to_string(),
1271 vec![],
1272 None,
1273 )
1274 .await
1275 .unwrap();
1276
1277 let options = StepOptions::new().timeout_seconds(1);
1278 let result: Result<Result<i32, String>, DurableError> = ctx
1279 .step_with_options("slow_step", options, || async {
1280 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1281 Ok::<i32, String>(42)
1282 })
1283 .await;
1284
1285 let err = result.unwrap_err();
1286 match err {
1287 DurableError::StepTimeout { operation_name } => {
1288 assert_eq!(operation_name, "slow_step");
1289 }
1290 other => panic!("expected StepTimeout, got {other:?}"),
1291 }
1292 }
1293
1294 #[tokio::test]
1295 async fn test_step_timeout_does_not_fire_when_fast_enough() {
1296 let (backend, _calls) = MockBackend::new("new-token");
1297 let backend = Arc::new(backend);
1298
1299 let mut ctx = DurableContext::new(
1300 backend,
1301 "arn:test".to_string(),
1302 "initial-token".to_string(),
1303 vec![],
1304 None,
1305 )
1306 .await
1307 .unwrap();
1308
1309 let options = StepOptions::new().timeout_seconds(5);
1310 let result: Result<i32, String> = ctx
1311 .step_with_options("fast_step", options, || async { Ok(99) })
1312 .await
1313 .unwrap();
1314
1315 assert_eq!(result.unwrap(), 99);
1316 }
1317
1318 #[tokio::test]
1319 async fn test_retry_if_false_causes_immediate_fail_no_retry_budget_consumed() {
1320 let (backend, calls) = MockBackend::new("new-token");
1321 let backend = Arc::new(backend);
1322
1323 let mut ctx = DurableContext::new(
1324 backend,
1325 "arn:test".to_string(),
1326 "initial-token".to_string(),
1327 vec![],
1328 None,
1329 )
1330 .await
1331 .unwrap();
1332
1333 let options = StepOptions::new().retries(3).retry_if(|_e: &String| false);
1335
1336 let result: Result<Result<i32, String>, DurableError> = ctx
1337 .step_with_options("no_retry_step", options, || async {
1338 Err("permanent error".to_string())
1339 })
1340 .await;
1341
1342 let inner = result.unwrap();
1344 let user_error = inner.unwrap_err();
1345 assert_eq!(user_error, "permanent error");
1346
1347 let captured = calls.lock().await;
1348 assert_eq!(
1350 captured.len(),
1351 2,
1352 "expected START + FAIL, got {}",
1353 captured.len()
1354 );
1355 assert_eq!(
1356 captured[1].updates[0].action(),
1357 &aws_sdk_lambda::types::OperationAction::Fail,
1358 "second checkpoint should be FAIL not RETRY"
1359 );
1360 }
1361
1362 #[tokio::test]
1363 async fn test_retry_if_true_retries_normally() {
1364 let (backend, calls) = MockBackend::new("new-token");
1365 let backend = Arc::new(backend);
1366
1367 let mut ctx = DurableContext::new(
1368 backend,
1369 "arn:test".to_string(),
1370 "initial-token".to_string(),
1371 vec![],
1372 None,
1373 )
1374 .await
1375 .unwrap();
1376
1377 let options = StepOptions::new().retries(3).retry_if(|_e: &String| true);
1379
1380 let result: Result<Result<i32, String>, DurableError> = ctx
1381 .step_with_options("retry_true_step", options, || async {
1382 Err("transient error".to_string())
1383 })
1384 .await;
1385
1386 let err = result.unwrap_err();
1387 match err {
1388 DurableError::StepRetryScheduled { .. } => {}
1389 other => panic!("expected StepRetryScheduled, got {other:?}"),
1390 }
1391
1392 let captured = calls.lock().await;
1393 assert_eq!(captured.len(), 2, "expected START + RETRY");
1394 assert_eq!(
1395 captured[1].updates[0].action(),
1396 &aws_sdk_lambda::types::OperationAction::Retry,
1397 );
1398 }
1399
1400 #[tokio::test]
1401 async fn test_no_retry_if_retries_all_errors_backward_compatible() {
1402 let (backend, calls) = MockBackend::new("new-token");
1403 let backend = Arc::new(backend);
1404
1405 let mut ctx = DurableContext::new(
1406 backend,
1407 "arn:test".to_string(),
1408 "initial-token".to_string(),
1409 vec![],
1410 None,
1411 )
1412 .await
1413 .unwrap();
1414
1415 let options = StepOptions::new().retries(2);
1417
1418 let result: Result<Result<i32, String>, DurableError> = ctx
1419 .step_with_options("compat_retry_step", options, || async {
1420 Err("any error".to_string())
1421 })
1422 .await;
1423
1424 let err = result.unwrap_err();
1425 match err {
1426 DurableError::StepRetryScheduled { .. } => {}
1427 other => panic!("expected StepRetryScheduled, got {other:?}"),
1428 }
1429
1430 let captured = calls.lock().await;
1431 assert_eq!(captured.len(), 2, "expected START + RETRY");
1432 }
1433
1434 #[tokio::test]
1435 async fn checkpoint_none_token_returns_error() {
1436 let backend = Arc::new(NoneTokenMockBackend);
1437
1438 let mut ctx = DurableContext::new(
1439 backend,
1440 "arn:test".to_string(),
1441 "initial-token".to_string(),
1442 vec![], None,
1444 )
1445 .await
1446 .unwrap();
1447
1448 let result: Result<Result<i32, String>, DurableError> =
1452 ctx.step("test_step", || async { Ok(42) }).await;
1453
1454 let err = result
1456 .expect_err("step should fail when checkpoint response has None checkpoint_token");
1457
1458 match &err {
1460 DurableError::CheckpointFailed { operation_name, .. } => {
1461 assert!(
1462 operation_name.contains("test_step"),
1463 "error should reference the operation name, got: {}",
1464 operation_name
1465 );
1466 }
1467 other => panic!("expected DurableError::CheckpointFailed, got: {:?}", other),
1468 }
1469
1470 let err_msg = err.to_string();
1472 assert!(
1473 err_msg.contains("checkpoint response missing checkpoint_token"),
1474 "error message should mention missing checkpoint_token, got: {}",
1475 err_msg
1476 );
1477 }
1478
1479 #[traced_test]
1482 #[tokio::test]
1483 async fn test_step_emits_span() {
1484 let (backend, _calls) = MockBackend::new("tok");
1485 let mut ctx = DurableContext::new(
1486 Arc::new(backend),
1487 "arn:test".to_string(),
1488 "tok".to_string(),
1489 vec![],
1490 None,
1491 )
1492 .await
1493 .unwrap();
1494 let _: Result<i32, String> = ctx.step("validate", || async { Ok(42) }).await.unwrap();
1495 assert!(logs_contain("durable_operation"));
1496 assert!(logs_contain("validate"));
1497 assert!(logs_contain("step"));
1498 }
1499
1500 #[traced_test]
1501 #[tokio::test]
1502 async fn test_span_includes_op_id() {
1503 let (backend, _calls) = MockBackend::new("tok");
1504 let mut ctx = DurableContext::new(
1505 Arc::new(backend),
1506 "arn:test".to_string(),
1507 "tok".to_string(),
1508 vec![],
1509 None,
1510 )
1511 .await
1512 .unwrap();
1513 let _: Result<i32, String> = ctx.step("id_check", || async { Ok(42) }).await.unwrap();
1514 assert!(logs_contain("durable_operation"));
1515 assert!(logs_contain("op.id"));
1516 }
1517}