1use std::collections::HashMap;
30use std::future::Future;
31use std::marker::PhantomData;
32use std::pin::Pin;
33use std::sync::atomic::{AtomicBool, Ordering};
34use std::sync::Arc;
35
36use serde::de::DeserializeOwned;
37use serde::Serialize;
38use tokio::sync::RwLock;
39
40use crate::checkpoint_server::{
41 ApiType, CheckpointWorkerManager, CheckpointWorkerParams, InvocationResult, SkipTimeConfig,
42 StartDurableExecutionRequest, TestExecutionOrchestrator,
43};
44use crate::error::TestError;
45use crate::mock_client::MockDurableServiceClient;
46use crate::operation::CallbackSender;
47use crate::operation_handle::{OperationHandle, OperationMatcher};
48use crate::test_result::TestResult;
49use crate::types::{ExecutionStatus, Invocation, TestResultError};
50use durable_execution_sdk::{
51 DurableContext, DurableError, DurableServiceClient, ErrorObject, Operation,
52};
53
/// Process-wide flag: set to `true` by `setup_test_environment` and reset to
/// `false` by `teardown_test_environment`.
static TEST_ENVIRONMENT_SETUP: AtomicBool = AtomicBool::new(false);

/// Process-wide flag: set when `setup_test_environment` is called with
/// `skip_time == true`; read by runs to build their `SkipTimeConfig` and cleared
/// (after resuming the Tokio clock if it is paused) by `teardown_test_environment`.
static TIME_SKIPPING_ENABLED: AtomicBool = AtomicBool::new(false);
59
60struct CheckpointCallbackSender {
66 checkpoint_worker: Arc<CheckpointWorkerManager>,
67}
68
69impl CheckpointCallbackSender {
70 fn new(checkpoint_worker: Arc<CheckpointWorkerManager>) -> Self {
71 Self { checkpoint_worker }
72 }
73}
74
75#[async_trait::async_trait]
76impl CallbackSender for CheckpointCallbackSender {
77 async fn send_success(&self, callback_id: &str, result: &str) -> Result<(), TestError> {
78 self.checkpoint_worker
79 .send_callback_success(callback_id, result)
80 .await
81 .map_err(|e| TestError::CheckpointServerError(e.to_string()))
82 }
83
84 async fn send_failure(
85 &self,
86 callback_id: &str,
87 error: &TestResultError,
88 ) -> Result<(), TestError> {
89 let error_obj = ErrorObject::new(
90 error.error_type.clone().unwrap_or_default(),
91 error.error_message.clone().unwrap_or_default(),
92 );
93 self.checkpoint_worker
94 .send_callback_failure(callback_id, &error_obj)
95 .await
96 .map_err(|e| TestError::CheckpointServerError(e.to_string()))
97 }
98
99 async fn send_heartbeat(&self, callback_id: &str) -> Result<(), TestError> {
100 self.checkpoint_worker
101 .send_callback_heartbeat(callback_id)
102 .await
103 .map_err(|e| TestError::CheckpointServerError(e.to_string()))
104 }
105}
106
/// Options accepted by `LocalDurableTestRunner::setup_test_environment`.
///
/// The default enables time skipping and sets no checkpoint delay.
#[derive(Debug, Clone)]
pub struct TestEnvironmentConfig {
    /// When `true`, setup pauses the Tokio clock and marks time skipping as enabled.
    pub skip_time: bool,

    /// Optional checkpoint delay; `None` by default.
    ///
    /// NOTE(review): `setup_test_environment` does not read this field, and its unit
    /// is not visible here — confirm where it is consumed.
    pub checkpoint_delay: Option<u64>,
}

impl Default for TestEnvironmentConfig {
    /// Time skipping on, no checkpoint delay.
    fn default() -> Self {
        TestEnvironmentConfig {
            checkpoint_delay: None,
            skip_time: true,
        }
    }
}
140
141#[derive(Debug, Default)]
143struct OperationStorage {
144 operations: Vec<Operation>,
146 operations_by_id: HashMap<String, usize>,
148 operations_by_name: HashMap<String, Vec<usize>>,
150}
151
152impl OperationStorage {
153 fn new() -> Self {
154 Self::default()
155 }
156
157 fn add_operation(&mut self, operation: Operation) {
158 let index = self.operations.len();
159 let id = operation.operation_id.clone();
160 let name = operation.name.clone();
161
162 self.operations.push(operation);
163 self.operations_by_id.insert(id, index);
164
165 if let Some(name) = name {
166 self.operations_by_name.entry(name).or_default().push(index);
167 }
168 }
169
170 fn get_by_id(&self, id: &str) -> Option<&Operation> {
171 self.operations_by_id
172 .get(id)
173 .and_then(|&idx| self.operations.get(idx))
174 }
175
176 fn get_by_name(&self, name: &str) -> Option<&Operation> {
177 self.operations_by_name
178 .get(name)
179 .and_then(|indices| indices.first())
180 .and_then(|&idx| self.operations.get(idx))
181 }
182
183 fn get_by_name_and_index(&self, name: &str, index: usize) -> Option<&Operation> {
184 self.operations_by_name
185 .get(name)
186 .and_then(|indices| indices.get(index))
187 .and_then(|&idx| self.operations.get(idx))
188 }
189
190 fn get_by_index(&self, index: usize) -> Option<&Operation> {
191 self.operations.get(index)
192 }
193
194 fn get_all(&self) -> &[Operation] {
195 &self.operations
196 }
197
198 fn clear(&mut self) {
199 self.operations.clear();
200 self.operations_by_id.clear();
201 self.operations_by_name.clear();
202 }
203
204 #[allow(dead_code)]
205 fn len(&self) -> usize {
206 self.operations.len()
207 }
208
209 #[allow(dead_code)]
210 fn is_empty(&self) -> bool {
211 self.operations.is_empty()
212 }
213}
214
/// Type-erased, reference-counted user handler.
///
/// Takes the input and a `DurableContext` and returns a pinned, boxed, `Send`
/// future that resolves to the handler's `Result<O, DurableError>`. Being an
/// `Arc`, it can be cloned cheaply into each run's future.
type SharedAsyncFn<I, O> = Arc<
    dyn Fn(I, DurableContext) -> Pin<Box<dyn Future<Output = Result<O, DurableError>> + Send>>
        + Send
        + Sync,
>;
221
/// Type-erased durable function over JSON values, as stored by
/// `register_durable_function`.
type BoxedDurableFn = Box<
    dyn Fn(
            serde_json::Value,
            DurableContext,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, DurableError>> + Send>>
        + Send
        + Sync,
>;
231
/// A function registered on the runner by name.
// The stored closures are never invoked by the visible code in this file (only
// inserted, counted, looked up by key and cleared), hence `allow(dead_code)`.
#[allow(dead_code)]
enum RegisteredFunction {
    /// Registered via `register_durable_function`; receives a `DurableContext`.
    Durable(BoxedDurableFn),
    /// Registered via `register_function`; a plain synchronous JSON -> JSON function.
    Regular(
        Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, DurableError> + Send + Sync>,
    ),
}
244
/// Runs a durable handler locally against the checkpoint worker and keeps the
/// operations it produced so tests can inspect them.
pub struct LocalDurableTestRunner<I, O>
where
    I: DeserializeOwned + Send + 'static,
    O: Serialize + DeserializeOwned + Send + 'static,
{
    /// The user handler, type-erased so it can be cloned into run futures.
    handler: SharedAsyncFn<I, O>,
    /// Worker used to start executions and fetch operations / history events.
    checkpoint_worker: Arc<CheckpointWorkerManager>,
    /// Legacy mock client. The runner only clears its recorded calls and exposes it
    /// through the deprecated `mock_client()` getter.
    #[deprecated(note = "Use checkpoint_worker instead. Retained for backward compatibility.")]
    mock_client: Arc<MockDurableServiceClient>,
    /// Operations captured from the most recent run.
    operation_storage: Arc<RwLock<OperationStorage>>,
    /// Functions registered by name via `register_function` / `register_durable_function`.
    registered_functions: Arc<RwLock<HashMap<String, RegisteredFunction>>>,
    /// Handles created by the `get_operation_handle*` methods; handed to the
    /// orchestrator on each orchestrated run.
    registered_handles: Vec<OperationHandle>,
    /// Operation list shared with every handle created by this runner.
    shared_operations: Arc<RwLock<Vec<Operation>>>,
    /// Ties the `I` / `O` type parameters to the struct.
    _phantom: PhantomData<(I, O)>,
}
296
297impl<I, O> LocalDurableTestRunner<I, O>
298where
299 I: DeserializeOwned + Send + Serialize + 'static,
300 O: Serialize + DeserializeOwned + Send + 'static,
301{
302 pub async fn setup_test_environment(config: TestEnvironmentConfig) -> Result<(), TestError> {
321 if config.skip_time {
325 use std::panic;
330 let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
331 tokio::time::pause();
332 }));
333
334 if let Err(panic_info) = result {
335 let is_runtime_error = panic_info
337 .downcast_ref::<&str>()
338 .map(|msg| msg.contains("current_thread"))
339 .unwrap_or(false)
340 || panic_info
341 .downcast_ref::<String>()
342 .map(|msg| msg.contains("current_thread"))
343 .unwrap_or(false);
344
345 let is_already_frozen = panic_info
347 .downcast_ref::<&str>()
348 .map(|msg| msg.contains("already frozen") || msg.contains("already paused"))
349 .unwrap_or(false)
350 || panic_info
351 .downcast_ref::<String>()
352 .map(|msg| msg.contains("already frozen") || msg.contains("already paused"))
353 .unwrap_or(false);
354
355 if is_runtime_error {
356 tracing::warn!(
358 "Time control requires current_thread Tokio runtime. \
359 Time skipping may not work correctly."
360 );
361 } else if is_already_frozen {
362 } else {
365 panic::resume_unwind(panic_info);
367 }
368 }
369 TIME_SKIPPING_ENABLED.store(true, Ordering::SeqCst);
370 }
371
372 TEST_ENVIRONMENT_SETUP.store(true, Ordering::SeqCst);
373 Ok(())
374 }
375
376 pub async fn teardown_test_environment() -> Result<(), TestError> {
388 if !TEST_ENVIRONMENT_SETUP.load(Ordering::SeqCst) {
390 return Ok(());
391 }
392
393 if TIME_SKIPPING_ENABLED.load(Ordering::SeqCst) {
397 use std::panic;
400 let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
401 if crate::time_control::is_time_paused() {
402 tokio::time::resume();
403 }
404 }));
405 TIME_SKIPPING_ENABLED.store(false, Ordering::SeqCst);
406 }
407
408 TEST_ENVIRONMENT_SETUP.store(false, Ordering::SeqCst);
409 Ok(())
410 }
411
412 pub fn is_environment_setup() -> bool {
414 TEST_ENVIRONMENT_SETUP.load(Ordering::SeqCst)
415 }
416
417 pub fn is_time_skipping_enabled() -> bool {
419 TIME_SKIPPING_ENABLED.load(Ordering::SeqCst)
420 }
421
422 pub fn new<F, Fut>(handler: F) -> Self
439 where
440 F: Fn(I, DurableContext) -> Fut + Send + Sync + 'static,
441 Fut: Future<Output = Result<O, DurableError>> + Send + 'static,
442 {
443 let handler: SharedAsyncFn<I, O> = Arc::new(move |input: I, ctx: DurableContext| {
445 let fut = handler(input, ctx);
446 Box::pin(fut) as Pin<Box<dyn Future<Output = Result<O, DurableError>> + Send>>
447 });
448
449 let checkpoint_worker = match CheckpointWorkerManager::get_instance(None) {
452 Ok(worker) => worker,
453 Err(_) => {
454 CheckpointWorkerManager::reset_instance_for_testing();
456 CheckpointWorkerManager::get_instance(None)
457 .expect("Failed to create CheckpointWorkerManager after reset")
458 }
459 };
460
461 #[allow(deprecated)]
462 Self {
463 handler,
464 checkpoint_worker,
465 mock_client: Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(100)),
466 operation_storage: Arc::new(RwLock::new(OperationStorage::new())),
467 registered_functions: Arc::new(RwLock::new(HashMap::new())),
468 registered_handles: Vec::new(),
469 shared_operations: Arc::new(RwLock::new(Vec::new())),
470 _phantom: PhantomData,
471 }
472 }
473
474 #[deprecated(note = "Use new() instead. The checkpoint worker manager is now the default.")]
481 pub fn with_mock_client<F, Fut>(handler: F, mock_client: MockDurableServiceClient) -> Self
482 where
483 F: Fn(I, DurableContext) -> Fut + Send + Sync + 'static,
484 Fut: Future<Output = Result<O, DurableError>> + Send + 'static,
485 {
486 let handler: SharedAsyncFn<I, O> = Arc::new(move |input: I, ctx: DurableContext| {
488 let fut = handler(input, ctx);
489 Box::pin(fut) as Pin<Box<dyn Future<Output = Result<O, DurableError>> + Send>>
490 });
491
492 let checkpoint_worker = CheckpointWorkerManager::get_instance(None)
494 .expect("Failed to create CheckpointWorkerManager");
495
496 #[allow(deprecated)]
497 Self {
498 handler,
499 checkpoint_worker,
500 mock_client: Arc::new(mock_client),
501 operation_storage: Arc::new(RwLock::new(OperationStorage::new())),
502 registered_functions: Arc::new(RwLock::new(HashMap::new())),
503 registered_handles: Vec::new(),
504 shared_operations: Arc::new(RwLock::new(Vec::new())),
505 _phantom: PhantomData,
506 }
507 }
508
509 pub fn with_checkpoint_params<F, Fut>(handler: F, params: CheckpointWorkerParams) -> Self
516 where
517 F: Fn(I, DurableContext) -> Fut + Send + Sync + 'static,
518 Fut: Future<Output = Result<O, DurableError>> + Send + 'static,
519 {
520 let handler: SharedAsyncFn<I, O> = Arc::new(move |input: I, ctx: DurableContext| {
522 let fut = handler(input, ctx);
523 Box::pin(fut) as Pin<Box<dyn Future<Output = Result<O, DurableError>> + Send>>
524 });
525
526 let checkpoint_worker = CheckpointWorkerManager::get_instance(Some(params))
528 .expect("Failed to create CheckpointWorkerManager");
529
530 #[allow(deprecated)]
531 Self {
532 handler,
533 checkpoint_worker,
534 mock_client: Arc::new(MockDurableServiceClient::new().with_checkpoint_responses(100)),
535 operation_storage: Arc::new(RwLock::new(OperationStorage::new())),
536 registered_functions: Arc::new(RwLock::new(HashMap::new())),
537 registered_handles: Vec::new(),
538 shared_operations: Arc::new(RwLock::new(Vec::new())),
539 _phantom: PhantomData,
540 }
541 }
542
543 pub fn checkpoint_worker(&self) -> &Arc<CheckpointWorkerManager> {
545 &self.checkpoint_worker
546 }
547
548 #[deprecated(note = "Use checkpoint_worker() instead.")]
550 #[allow(deprecated)]
551 pub fn mock_client(&self) -> &Arc<MockDurableServiceClient> {
552 &self.mock_client
553 }
554
555 pub async fn operation_count(&self) -> usize {
557 self.operation_storage.read().await.len()
558 }
559
560 pub fn run(
591 &mut self,
592 input: impl Into<crate::types::InvokeRequest<I>>,
593 ) -> crate::run_future::RunFuture<O>
594 where
595 I: Clone + Send + Sync + 'static,
596 O: Send + 'static,
597 {
598 let invoke_request: crate::types::InvokeRequest<I> = input.into();
599
600 let payload_result: Result<I, _> = match invoke_request.payload {
605 Some(p) => Ok(p),
606 None => serde_json::from_value(serde_json::Value::Null),
607 };
608
609 let payload = match payload_result {
611 Ok(p) => p,
612 Err(e) => {
613 return crate::run_future::RunFuture::from_future(Box::pin(async move {
614 Err(TestError::InvalidConfiguration(format!(
615 "InvokeRequest has no payload and the input type cannot be \
616 deserialized from null: {}. Use InvokeRequest::with_payload() \
617 to provide a value.",
618 e
619 )))
620 }));
621 }
622 };
623
624 if let Ok(mut storage) = self.operation_storage.try_write() {
627 storage.clear();
628 }
629 #[allow(deprecated)]
630 self.mock_client.clear_all_calls();
631
632 let handler = Arc::clone(&self.handler);
634 let checkpoint_worker = self.checkpoint_worker.clone();
635 let operation_storage = self.operation_storage.clone();
636 let registered_handles = self.registered_handles.clone();
637 let shared_operations = self.shared_operations.clone();
638
639 crate::run_future::RunFuture::from_future(Box::pin(async move {
640 use crate::checkpoint_server::OperationStorage as OrchestratorOperationStorage;
641
642 operation_storage.write().await.clear();
644
645 let skip_time_config = SkipTimeConfig {
647 enabled: LocalDurableTestRunner::<I, O>::is_time_skipping_enabled(),
648 };
649
650 let orchestrator_storage =
652 Arc::new(tokio::sync::RwLock::new(OrchestratorOperationStorage::new()));
653
654 let handler_clone = Arc::clone(&handler);
656 let mut orchestrator = TestExecutionOrchestrator::new(
657 move |input: I, ctx: DurableContext| {
658 let handler = Arc::clone(&handler_clone);
659 async move { handler(input, ctx).await }
660 },
661 orchestrator_storage.clone(),
662 checkpoint_worker.clone(),
663 skip_time_config,
664 );
665
666 if !registered_handles.is_empty() {
668 let callback_sender: Option<Arc<dyn CallbackSender>> = Some(Arc::new(
669 CheckpointCallbackSender::new(checkpoint_worker.clone()),
670 ));
671 orchestrator = orchestrator.with_handles(
672 registered_handles,
673 shared_operations,
674 callback_sender,
675 );
676 }
677
678 let execution_result = orchestrator.execute_handler(payload).await?;
680
681 {
683 let orch_storage = orchestrator_storage.read().await;
684 let mut our_storage = operation_storage.write().await;
685 for op in orch_storage.get_all() {
686 our_storage.add_operation(op.clone());
687 }
688 }
689
690 let mut test_result = match execution_result.status {
692 ExecutionStatus::Succeeded => {
693 if let Some(result) = execution_result.result {
694 TestResult::success(result, execution_result.operations)
695 } else {
696 TestResult::with_status(
697 ExecutionStatus::Succeeded,
698 execution_result.operations,
699 )
700 }
701 }
702 ExecutionStatus::Failed => {
703 if let Some(error) = execution_result.error {
704 TestResult::failure(error, execution_result.operations)
705 } else {
706 TestResult::with_status(
707 ExecutionStatus::Failed,
708 execution_result.operations,
709 )
710 }
711 }
712 ExecutionStatus::Running => {
713 TestResult::with_status(ExecutionStatus::Running, execution_result.operations)
714 }
715 _ => TestResult::with_status(execution_result.status, execution_result.operations),
716 };
717
718 for invocation in execution_result.invocations {
720 test_result.add_invocation(invocation);
721 }
722
723 if let Ok(nodejs_events) = checkpoint_worker
725 .get_nodejs_history_events(&execution_result.execution_id)
726 .await
727 {
728 test_result.set_nodejs_history_events(nodejs_events);
729 }
730
731 Ok(test_result)
732 }))
733 }
734
735 pub async fn run_single_invocation(&mut self, payload: I) -> Result<TestResult<O>, TestError> {
760 use durable_execution_sdk::lambda::InitialExecutionState;
761 use durable_execution_sdk::state::ExecutionState;
762
763 self.operation_storage.write().await.clear();
765
766 #[allow(deprecated)]
768 self.mock_client.clear_all_calls();
769
770 let payload_json = serde_json::to_string(&payload)?;
772
773 let invocation_id = uuid::Uuid::new_v4().to_string();
776 let start_request = StartDurableExecutionRequest {
777 invocation_id: invocation_id.clone(),
778 payload: Some(payload_json),
779 };
780 let start_payload = serde_json::to_string(&start_request)?;
781
782 let start_response = self
783 .checkpoint_worker
784 .send_api_request(ApiType::StartDurableExecution, start_payload)
785 .await?;
786
787 if let Some(error) = start_response.error {
788 return Err(TestError::CheckpointServerError(error));
789 }
790
791 let invocation_result: InvocationResult =
792 serde_json::from_str(&start_response.payload.ok_or_else(|| {
793 TestError::CheckpointServerError(
794 "Empty response from checkpoint server".to_string(),
795 )
796 })?)?;
797
798 let execution_arn = invocation_result.execution_id;
799 let checkpoint_token = invocation_result.checkpoint_token;
800
801 let initial_state = InitialExecutionState::new();
803
804 let execution_state = Arc::new(ExecutionState::new(
807 &execution_arn,
808 &checkpoint_token,
809 initial_state,
810 self.checkpoint_worker.clone(),
811 ));
812
813 let ctx = DurableContext::new(execution_state.clone());
815
816 let start_time = chrono::Utc::now();
818 let mut invocation = Invocation::with_start(start_time);
819
820 let handler_result = (self.handler)(payload, ctx).await;
822
823 let end_time = chrono::Utc::now();
825 invocation = invocation.with_end(end_time);
826
827 let operations = match self
829 .checkpoint_worker
830 .get_operations(&execution_arn, "")
831 .await
832 {
833 Ok(response) => {
834 let mut storage = self.operation_storage.write().await;
835 for op in &response.operations {
836 storage.add_operation(op.clone());
837 }
838 response.operations
839 }
840 Err(_) => {
841 Vec::new()
843 }
844 };
845
846 match handler_result {
848 Ok(result) => {
849 let mut test_result = TestResult::success(result, operations);
850 test_result.add_invocation(invocation);
851 Ok(test_result)
852 }
853 Err(error) => {
854 if error.is_suspend() {
856 let mut test_result =
857 TestResult::with_status(ExecutionStatus::Running, operations);
858 test_result.add_invocation(invocation);
859 Ok(test_result)
860 } else {
861 let error_obj = durable_execution_sdk::ErrorObject::from(&error);
863 let test_error = TestResultError::new(error_obj.error_type, error.to_string());
864 invocation = invocation.with_error(test_error.clone());
865 let mut test_result = TestResult::failure(test_error, operations);
866 test_result.add_invocation(invocation);
867 Ok(test_result)
868 }
869 }
870 }
871 }
872
873 pub async fn run_with_orchestrator(&mut self, payload: I) -> Result<TestResult<O>, TestError>
900 where
901 I: Clone,
902 {
903 use crate::checkpoint_server::OperationStorage as OrchestratorOperationStorage;
904
905 self.operation_storage.write().await.clear();
907
908 #[allow(deprecated)]
910 self.mock_client.clear_all_calls();
911
912 let skip_time_config = SkipTimeConfig {
914 enabled: Self::is_time_skipping_enabled(),
915 };
916
917 let orchestrator_storage =
919 Arc::new(tokio::sync::RwLock::new(OrchestratorOperationStorage::new()));
920
921 let handler = Arc::clone(&self.handler);
923
924 let mut orchestrator = TestExecutionOrchestrator::new(
927 move |input: I, ctx: DurableContext| {
928 let handler = Arc::clone(&handler);
929 async move { handler(input, ctx).await }
930 },
931 orchestrator_storage.clone(),
932 self.checkpoint_worker.clone(),
933 skip_time_config,
934 );
935
936 if !self.registered_handles.is_empty() {
939 let callback_sender: Option<Arc<dyn CallbackSender>> = Some(Arc::new(
940 CheckpointCallbackSender::new(self.checkpoint_worker.clone()),
941 ));
942 orchestrator = orchestrator.with_handles(
943 self.registered_handles.clone(),
944 self.shared_operations.clone(),
945 callback_sender,
946 );
947 }
948
949 let execution_result = orchestrator.execute_handler(payload.clone()).await?;
951
952 {
954 let orch_storage = orchestrator_storage.read().await;
955 let mut our_storage = self.operation_storage.write().await;
956 for op in orch_storage.get_all() {
957 our_storage.add_operation(op.clone());
958 }
959 }
960
961 let mut test_result = match execution_result.status {
963 ExecutionStatus::Succeeded => {
964 if let Some(result) = execution_result.result {
965 TestResult::success(result, execution_result.operations)
966 } else {
967 TestResult::with_status(ExecutionStatus::Succeeded, execution_result.operations)
968 }
969 }
970 ExecutionStatus::Failed => {
971 if let Some(error) = execution_result.error {
972 TestResult::failure(error, execution_result.operations)
973 } else {
974 TestResult::with_status(ExecutionStatus::Failed, execution_result.operations)
975 }
976 }
977 ExecutionStatus::Running => {
978 TestResult::with_status(ExecutionStatus::Running, execution_result.operations)
979 }
980 _ => TestResult::with_status(execution_result.status, execution_result.operations),
981 };
982
983 for invocation in execution_result.invocations {
985 test_result.add_invocation(invocation);
986 }
987
988 if let Ok(nodejs_events) = self
990 .checkpoint_worker
991 .get_nodejs_history_events(&execution_result.execution_id)
992 .await
993 {
994 test_result.set_nodejs_history_events(nodejs_events);
995 }
996
997 Ok(test_result)
998 }
999
1000 pub fn get_operation_handle(&mut self, name: &str) -> OperationHandle {
1013 let handle = OperationHandle::new(
1014 OperationMatcher::ByName(name.to_string()),
1015 self.shared_operations.clone(),
1016 );
1017 self.registered_handles.push(handle.clone());
1018 handle
1019 }
1020
1021 pub fn get_operation_handle_by_index(&mut self, index: usize) -> OperationHandle {
1034 let handle = OperationHandle::new(
1035 OperationMatcher::ByIndex(index),
1036 self.shared_operations.clone(),
1037 );
1038 self.registered_handles.push(handle.clone());
1039 handle
1040 }
1041
1042 pub fn get_operation_handle_by_id(&mut self, id: &str) -> OperationHandle {
1055 let handle = OperationHandle::new(
1056 OperationMatcher::ById(id.to_string()),
1057 self.shared_operations.clone(),
1058 );
1059 self.registered_handles.push(handle.clone());
1060 handle
1061 }
1062
1063 pub async fn reset(&mut self) {
1084 self.operation_storage.write().await.clear();
1086
1087 CheckpointWorkerManager::reset_instance_for_testing();
1089
1090 self.checkpoint_worker = CheckpointWorkerManager::get_instance(None)
1092 .expect("Failed to create CheckpointWorkerManager after reset");
1093
1094 self.registered_handles.clear();
1096 self.shared_operations.write().await.clear();
1097
1098 #[allow(deprecated)]
1100 self.mock_client.clear_all_calls();
1101 }
1102
1103 pub async fn get_operation_by_id(&self, id: &str) -> Option<Operation> {
1125 self.operation_storage.read().await.get_by_id(id).cloned()
1126 }
1127
1128 pub async fn get_operation(&self, name: &str) -> Option<Operation> {
1150 self.operation_storage
1151 .read()
1152 .await
1153 .get_by_name(name)
1154 .cloned()
1155 }
1156
1157 pub async fn get_operation_by_index(&self, index: usize) -> Option<Operation> {
1180 self.operation_storage
1181 .read()
1182 .await
1183 .get_by_index(index)
1184 .cloned()
1185 }
1186
1187 pub async fn get_operation_by_name_and_index(
1214 &self,
1215 name: &str,
1216 index: usize,
1217 ) -> Option<Operation> {
1218 self.operation_storage
1219 .read()
1220 .await
1221 .get_by_name_and_index(name, index)
1222 .cloned()
1223 }
1224
1225 pub async fn get_all_operations(&self) -> Vec<Operation> {
1242 self.operation_storage.read().await.get_all().to_vec()
1243 }
1244
1245 pub async fn register_durable_function<F, Fut>(&self, name: impl Into<String>, func: F)
1268 where
1269 F: Fn(serde_json::Value, DurableContext) -> Fut + Send + Sync + 'static,
1270 Fut: Future<Output = Result<serde_json::Value, DurableError>> + Send + 'static,
1271 {
1272 let boxed_func = Box::new(move |input: serde_json::Value, ctx: DurableContext| {
1273 let fut = func(input, ctx);
1274 Box::pin(fut)
1275 as Pin<Box<dyn Future<Output = Result<serde_json::Value, DurableError>> + Send>>
1276 });
1277
1278 self.registered_functions
1279 .write()
1280 .await
1281 .insert(name.into(), RegisteredFunction::Durable(boxed_func));
1282 }
1283
1284 pub async fn register_function<F>(&self, name: impl Into<String>, func: F)
1306 where
1307 F: Fn(serde_json::Value) -> Result<serde_json::Value, DurableError> + Send + Sync + 'static,
1308 {
1309 self.registered_functions
1310 .write()
1311 .await
1312 .insert(name.into(), RegisteredFunction::Regular(Box::new(func)));
1313 }
1314
1315 pub async fn has_registered_function(&self, name: &str) -> bool {
1337 self.registered_functions.read().await.contains_key(name)
1338 }
1339
1340 pub async fn registered_function_count(&self) -> usize {
1346 self.registered_functions.read().await.len()
1347 }
1348
1349 pub async fn clear_registered_functions(&mut self) {
1364 self.registered_functions.write().await.clear();
1365 }
1366}
1367
1368impl<I, O> std::fmt::Debug for LocalDurableTestRunner<I, O>
1369where
1370 I: DeserializeOwned + Send + 'static,
1371 O: Serialize + DeserializeOwned + Send + 'static,
1372{
1373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1374 f.debug_struct("LocalDurableTestRunner")
1375 .field("checkpoint_worker", &"CheckpointWorkerManager")
1376 .finish()
1377 }
1378}
1379
1380#[cfg(test)]
1381mod tests {
1382 use super::*;
1383 use durable_execution_sdk::OperationType;
1384
1385 async fn simple_handler(input: String, _ctx: DurableContext) -> Result<String, DurableError> {
1386 Ok(format!("processed: {}", input))
1387 }
1388
1389 #[test]
1390 fn test_test_environment_config_default() {
1391 let config = TestEnvironmentConfig::default();
1392 assert!(config.skip_time);
1393 assert!(config.checkpoint_delay.is_none());
1394 }
1395
1396 #[test]
1397 fn test_operation_storage_new() {
1398 let storage = OperationStorage::new();
1399 assert!(storage.is_empty());
1400 assert_eq!(storage.len(), 0);
1401 }
1402
1403 #[test]
1404 fn test_operation_storage_add_and_get() {
1405 let mut storage = OperationStorage::new();
1406
1407 let mut op1 = Operation::new("op-1", OperationType::Step);
1408 op1.name = Some("step1".to_string());
1409 storage.add_operation(op1);
1410
1411 let mut op2 = Operation::new("op-2", OperationType::Wait);
1412 op2.name = Some("wait1".to_string());
1413 storage.add_operation(op2);
1414
1415 assert_eq!(storage.len(), 2);
1416
1417 let found = storage.get_by_id("op-1");
1419 assert!(found.is_some());
1420 assert_eq!(found.unwrap().operation_id, "op-1");
1421
1422 let found = storage.get_by_name("step1");
1424 assert!(found.is_some());
1425 assert_eq!(found.unwrap().operation_id, "op-1");
1426
1427 let found = storage.get_by_index(1);
1429 assert!(found.is_some());
1430 assert_eq!(found.unwrap().operation_id, "op-2");
1431
1432 let all = storage.get_all();
1434 assert_eq!(all.len(), 2);
1435 }
1436
1437 #[test]
1438 fn test_operation_storage_get_by_name_and_index() {
1439 let mut storage = OperationStorage::new();
1440
1441 let mut op1 = Operation::new("op-1", OperationType::Step);
1443 op1.name = Some("step".to_string());
1444 storage.add_operation(op1);
1445
1446 let mut op2 = Operation::new("op-2", OperationType::Step);
1447 op2.name = Some("step".to_string());
1448 storage.add_operation(op2);
1449
1450 let mut op3 = Operation::new("op-3", OperationType::Step);
1451 op3.name = Some("step".to_string());
1452 storage.add_operation(op3);
1453
1454 let found = storage.get_by_name_and_index("step", 0);
1456 assert!(found.is_some());
1457 assert_eq!(found.unwrap().operation_id, "op-1");
1458
1459 let found = storage.get_by_name_and_index("step", 1);
1461 assert!(found.is_some());
1462 assert_eq!(found.unwrap().operation_id, "op-2");
1463
1464 let found = storage.get_by_name_and_index("step", 2);
1466 assert!(found.is_some());
1467 assert_eq!(found.unwrap().operation_id, "op-3");
1468
1469 let found = storage.get_by_name_and_index("step", 3);
1471 assert!(found.is_none());
1472 }
1473
1474 #[test]
1475 fn test_operation_storage_clear() {
1476 let mut storage = OperationStorage::new();
1477
1478 let op = Operation::new("op-1", OperationType::Step);
1479 storage.add_operation(op);
1480 assert_eq!(storage.len(), 1);
1481
1482 storage.clear();
1483 assert!(storage.is_empty());
1484 assert!(storage.get_by_id("op-1").is_none());
1485 }
1486
1487 #[tokio::test]
1488 async fn test_local_runner_creation() {
1489 let runner = LocalDurableTestRunner::new(simple_handler);
1490 assert_eq!(runner.operation_count().await, 0);
1491 }
1492
1493 #[tokio::test]
1494 #[allow(deprecated)]
1495 async fn test_local_runner_with_mock_client() {
1496 let mock_client = MockDurableServiceClient::new().with_checkpoint_responses(10);
1497 let runner = LocalDurableTestRunner::with_mock_client(simple_handler, mock_client);
1498 assert_eq!(runner.operation_count().await, 0);
1499 }
1500
1501 #[tokio::test]
1502 async fn test_setup_teardown_environment() {
1503 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1505 .await
1506 .unwrap();
1507
1508 assert!(!LocalDurableTestRunner::<String, String>::is_environment_setup());
1509 assert!(!LocalDurableTestRunner::<String, String>::is_time_skipping_enabled());
1510
1511 LocalDurableTestRunner::<String, String>::setup_test_environment(TestEnvironmentConfig {
1513 skip_time: true,
1514 checkpoint_delay: None,
1515 })
1516 .await
1517 .unwrap();
1518
1519 assert!(LocalDurableTestRunner::<String, String>::is_environment_setup());
1520 assert!(LocalDurableTestRunner::<String, String>::is_time_skipping_enabled());
1521
1522 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1524 .await
1525 .unwrap();
1526
1527 assert!(!LocalDurableTestRunner::<String, String>::is_environment_setup());
1528 assert!(!LocalDurableTestRunner::<String, String>::is_time_skipping_enabled());
1529 }
1530
1531 #[tokio::test]
1532 async fn test_setup_without_time_skipping() {
1533 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1535 .await
1536 .unwrap();
1537
1538 LocalDurableTestRunner::<String, String>::setup_test_environment(TestEnvironmentConfig {
1540 skip_time: false,
1541 checkpoint_delay: None,
1542 })
1543 .await
1544 .unwrap();
1545
1546 assert!(LocalDurableTestRunner::<String, String>::is_environment_setup());
1547 assert!(!LocalDurableTestRunner::<String, String>::is_time_skipping_enabled());
1548
1549 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1551 .await
1552 .unwrap();
1553 }
1554
1555 #[tokio::test]
1556 async fn test_double_setup_is_idempotent() {
1557 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1559 .await
1560 .unwrap();
1561
1562 LocalDurableTestRunner::<String, String>::setup_test_environment(
1564 TestEnvironmentConfig::default(),
1565 )
1566 .await
1567 .unwrap();
1568
1569 LocalDurableTestRunner::<String, String>::setup_test_environment(
1571 TestEnvironmentConfig::default(),
1572 )
1573 .await
1574 .unwrap();
1575
1576 assert!(LocalDurableTestRunner::<String, String>::is_environment_setup());
1577
1578 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1580 .await
1581 .unwrap();
1582 }
1583
1584 #[tokio::test]
1585 async fn test_double_teardown_is_idempotent() {
1586 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1588 .await
1589 .unwrap();
1590
1591 LocalDurableTestRunner::<String, String>::setup_test_environment(
1593 TestEnvironmentConfig::default(),
1594 )
1595 .await
1596 .unwrap();
1597
1598 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1600 .await
1601 .unwrap();
1602
1603 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1605 .await
1606 .unwrap();
1607
1608 assert!(!LocalDurableTestRunner::<String, String>::is_environment_setup());
1609 }
1610
1611 #[tokio::test]
1614 async fn test_run_successful_execution() {
1615 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1617 .await
1618 .unwrap();
1619
1620 let mut runner = LocalDurableTestRunner::new(simple_handler);
1621 let result = runner.run("hello".to_string()).await.unwrap();
1622
1623 assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
1625
1626 let output = result.get_result().unwrap();
1628 assert_eq!(output, "processed: hello");
1629
1630 assert_eq!(result.get_invocations().len(), 1);
1632 }
1633
1634 #[tokio::test]
1635 async fn test_run_failed_execution() {
1636 async fn failing_handler(
1637 _input: String,
1638 _ctx: DurableContext,
1639 ) -> Result<String, DurableError> {
1640 Err(DurableError::execution("Test failure"))
1641 }
1642
1643 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1645 .await
1646 .unwrap();
1647
1648 let mut runner = LocalDurableTestRunner::new(failing_handler);
1649 let result = runner.run("hello".to_string()).await.unwrap();
1650
1651 assert_eq!(result.get_status(), ExecutionStatus::Failed);
1653
1654 let error = result.get_error().unwrap();
1656 assert!(error
1657 .error_message
1658 .as_ref()
1659 .unwrap()
1660 .contains("Test failure"));
1661
1662 let invocations = result.get_invocations();
1664 assert_eq!(invocations.len(), 1);
1665 assert!(invocations[0].error.is_some());
1666 }
1667
1668 #[tokio::test]
1669 async fn test_run_multiple_times_clears_previous_operations() {
1670 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1672 .await
1673 .unwrap();
1674
1675 let mut runner = LocalDurableTestRunner::new(simple_handler);
1676
1677 let result1 = runner.run("first".to_string()).await.unwrap();
1679 assert_eq!(result1.get_status(), ExecutionStatus::Succeeded);
1680 assert_eq!(result1.get_result().unwrap(), "processed: first");
1681
1682 let result2 = runner.run("second".to_string()).await.unwrap();
1684 assert_eq!(result2.get_status(), ExecutionStatus::Succeeded);
1685 assert_eq!(result2.get_result().unwrap(), "processed: second");
1686 }
1687
1688 #[tokio::test]
1691 async fn test_reset_clears_operation_storage() {
1692 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1694 .await
1695 .unwrap();
1696
1697 CheckpointWorkerManager::reset_instance_for_testing();
1699
1700 let mut runner = LocalDurableTestRunner::new(simple_handler);
1701
1702 let _ = runner.run("hello".to_string()).await.unwrap();
1704
1705 runner.reset().await;
1707
1708 assert_eq!(runner.operation_count().await, 0);
1710 }
1711
1712 #[tokio::test]
1713 async fn test_reset_allows_fresh_run() {
1714 LocalDurableTestRunner::<String, String>::teardown_test_environment()
1716 .await
1717 .unwrap();
1718
1719 let mut runner = LocalDurableTestRunner::new(simple_handler);
1720
1721 let result1 = runner.run("first".to_string()).await.unwrap();
1723 assert_eq!(result1.get_result().unwrap(), "processed: first");
1724
1725 runner.reset().await;
1727
1728 let result2 = runner.run("second".to_string()).await.unwrap();
1730 assert_eq!(result2.get_result().unwrap(), "processed: second");
1731 assert_eq!(result2.get_status(), ExecutionStatus::Succeeded);
1732 }
1733
1734 #[tokio::test]
1737 async fn test_get_operation_by_id() {
1738 let runner: LocalDurableTestRunner<String, String> =
1739 LocalDurableTestRunner::new(simple_handler);
1740
1741 {
1743 let mut storage = runner.operation_storage.write().await;
1744 let mut op = Operation::new("test-op-id", OperationType::Step);
1745 op.name = Some("test_step".to_string());
1746 storage.add_operation(op);
1747 }
1748
1749 let found = runner.get_operation_by_id("test-op-id").await;
1751 assert!(found.is_some());
1752 assert_eq!(found.unwrap().operation_id, "test-op-id");
1753
1754 let not_found = runner.get_operation_by_id("nonexistent").await;
1756 assert!(not_found.is_none());
1757 }
1758
1759 #[tokio::test]
1760 async fn test_get_operation_by_name() {
1761 let runner: LocalDurableTestRunner<String, String> =
1762 LocalDurableTestRunner::new(simple_handler);
1763
1764 {
1766 let mut storage = runner.operation_storage.write().await;
1767 let mut op1 = Operation::new("op-1", OperationType::Step);
1768 op1.name = Some("process".to_string());
1769 storage.add_operation(op1);
1770
1771 let mut op2 = Operation::new("op-2", OperationType::Step);
1772 op2.name = Some("validate".to_string());
1773 storage.add_operation(op2);
1774 }
1775
1776 let found = runner.get_operation("process").await;
1778 assert!(found.is_some());
1779 assert_eq!(found.unwrap().operation_id, "op-1");
1780
1781 let not_found = runner.get_operation("nonexistent").await;
1783 assert!(not_found.is_none());
1784 }
1785
1786 #[tokio::test]
1787 async fn test_get_operation_by_index() {
1788 let runner: LocalDurableTestRunner<String, String> =
1789 LocalDurableTestRunner::new(simple_handler);
1790
1791 {
1793 let mut storage = runner.operation_storage.write().await;
1794 storage.add_operation(Operation::new("op-0", OperationType::Step));
1795 storage.add_operation(Operation::new("op-1", OperationType::Wait));
1796 storage.add_operation(Operation::new("op-2", OperationType::Callback));
1797 }
1798
1799 let op0 = runner.get_operation_by_index(0).await;
1801 assert!(op0.is_some());
1802 assert_eq!(op0.unwrap().operation_id, "op-0");
1803
1804 let op1 = runner.get_operation_by_index(1).await;
1805 assert!(op1.is_some());
1806 assert_eq!(op1.unwrap().operation_id, "op-1");
1807
1808 let op2 = runner.get_operation_by_index(2).await;
1809 assert!(op2.is_some());
1810 assert_eq!(op2.unwrap().operation_id, "op-2");
1811
1812 let out_of_bounds = runner.get_operation_by_index(3).await;
1814 assert!(out_of_bounds.is_none());
1815 }
1816
1817 #[tokio::test]
1818 async fn test_get_operation_by_name_and_index() {
1819 let runner: LocalDurableTestRunner<String, String> =
1820 LocalDurableTestRunner::new(simple_handler);
1821
1822 {
1824 let mut storage = runner.operation_storage.write().await;
1825 let mut op1 = Operation::new("op-1", OperationType::Step);
1826 op1.name = Some("process".to_string());
1827 storage.add_operation(op1);
1828
1829 let mut op2 = Operation::new("op-2", OperationType::Step);
1830 op2.name = Some("process".to_string());
1831 storage.add_operation(op2);
1832
1833 let mut op3 = Operation::new("op-3", OperationType::Step);
1834 op3.name = Some("process".to_string());
1835 storage.add_operation(op3);
1836 }
1837
1838 let first = runner.get_operation_by_name_and_index("process", 0).await;
1840 assert!(first.is_some());
1841 assert_eq!(first.unwrap().operation_id, "op-1");
1842
1843 let second = runner.get_operation_by_name_and_index("process", 1).await;
1844 assert!(second.is_some());
1845 assert_eq!(second.unwrap().operation_id, "op-2");
1846
1847 let third = runner.get_operation_by_name_and_index("process", 2).await;
1848 assert!(third.is_some());
1849 assert_eq!(third.unwrap().operation_id, "op-3");
1850
1851 let out_of_bounds = runner.get_operation_by_name_and_index("process", 3).await;
1853 assert!(out_of_bounds.is_none());
1854 }
1855
1856 #[tokio::test]
1857 async fn test_get_all_operations() {
1858 let runner: LocalDurableTestRunner<String, String> =
1859 LocalDurableTestRunner::new(simple_handler);
1860
1861 {
1863 let mut storage = runner.operation_storage.write().await;
1864 storage.add_operation(Operation::new("op-0", OperationType::Step));
1865 storage.add_operation(Operation::new("op-1", OperationType::Wait));
1866 storage.add_operation(Operation::new("op-2", OperationType::Callback));
1867 }
1868
1869 let all_ops = runner.get_all_operations().await;
1871 assert_eq!(all_ops.len(), 3);
1872 assert_eq!(all_ops[0].operation_id, "op-0");
1873 assert_eq!(all_ops[1].operation_id, "op-1");
1874 assert_eq!(all_ops[2].operation_id, "op-2");
1875 }
1876
1877 #[tokio::test]
1880 async fn test_register_durable_function() {
1881 async fn helper_func(
1882 _input: serde_json::Value,
1883 _ctx: DurableContext,
1884 ) -> Result<serde_json::Value, DurableError> {
1885 Ok(serde_json::json!({"result": "ok"}))
1886 }
1887
1888 let runner: LocalDurableTestRunner<String, String> =
1889 LocalDurableTestRunner::new(simple_handler);
1890
1891 runner
1893 .register_durable_function("helper", helper_func)
1894 .await;
1895
1896 assert!(runner.has_registered_function("helper").await);
1898 assert_eq!(runner.registered_function_count().await, 1);
1899 }
1900
1901 #[tokio::test]
1902 async fn test_register_regular_function() {
1903 fn simple_func(_input: serde_json::Value) -> Result<serde_json::Value, DurableError> {
1904 Ok(serde_json::json!({"result": "ok"}))
1905 }
1906
1907 let runner: LocalDurableTestRunner<String, String> =
1908 LocalDurableTestRunner::new(simple_handler);
1909
1910 runner.register_function("simple", simple_func).await;
1912
1913 assert!(runner.has_registered_function("simple").await);
1915 assert_eq!(runner.registered_function_count().await, 1);
1916 }
1917
1918 #[tokio::test]
1919 async fn test_register_multiple_functions() {
1920 async fn durable_func(
1921 _input: serde_json::Value,
1922 _ctx: DurableContext,
1923 ) -> Result<serde_json::Value, DurableError> {
1924 Ok(serde_json::json!({}))
1925 }
1926
1927 fn regular_func(_input: serde_json::Value) -> Result<serde_json::Value, DurableError> {
1928 Ok(serde_json::json!({}))
1929 }
1930
1931 let runner: LocalDurableTestRunner<String, String> =
1932 LocalDurableTestRunner::new(simple_handler);
1933
1934 runner
1936 .register_durable_function("durable1", durable_func)
1937 .await;
1938 runner.register_function("regular1", regular_func).await;
1939 runner
1940 .register_durable_function("durable2", durable_func)
1941 .await;
1942
1943 assert!(runner.has_registered_function("durable1").await);
1945 assert!(runner.has_registered_function("regular1").await);
1946 assert!(runner.has_registered_function("durable2").await);
1947 assert!(!runner.has_registered_function("nonexistent").await);
1948 assert_eq!(runner.registered_function_count().await, 3);
1949 }
1950
1951 #[tokio::test]
1952 async fn test_clear_registered_functions() {
1953 fn simple_func(_input: serde_json::Value) -> Result<serde_json::Value, DurableError> {
1954 Ok(serde_json::json!({}))
1955 }
1956
1957 let mut runner: LocalDurableTestRunner<String, String> =
1958 LocalDurableTestRunner::new(simple_handler);
1959
1960 runner.register_function("func1", simple_func).await;
1962 runner.register_function("func2", simple_func).await;
1963 assert_eq!(runner.registered_function_count().await, 2);
1964
1965 runner.clear_registered_functions().await;
1967 assert_eq!(runner.registered_function_count().await, 0);
1968 assert!(!runner.has_registered_function("func1").await);
1969 assert!(!runner.has_registered_function("func2").await);
1970 }
1971
1972 #[tokio::test]
1973 async fn test_register_function_overwrites_existing() {
1974 fn func_v1(_input: serde_json::Value) -> Result<serde_json::Value, DurableError> {
1975 Ok(serde_json::json!({"version": 1}))
1976 }
1977
1978 fn func_v2(_input: serde_json::Value) -> Result<serde_json::Value, DurableError> {
1979 Ok(serde_json::json!({"version": 2}))
1980 }
1981
1982 let runner: LocalDurableTestRunner<String, String> =
1983 LocalDurableTestRunner::new(simple_handler);
1984
1985 runner.register_function("func", func_v1).await;
1987 assert_eq!(runner.registered_function_count().await, 1);
1988
1989 runner.register_function("func", func_v2).await;
1991
1992 assert_eq!(runner.registered_function_count().await, 1);
1994 assert!(runner.has_registered_function("func").await);
1995 }
1996
1997 #[tokio::test]
2000 async fn test_run_with_orchestrator_successful_execution() {
2001 LocalDurableTestRunner::<String, String>::teardown_test_environment()
2003 .await
2004 .unwrap();
2005
2006 let mut runner = LocalDurableTestRunner::new(simple_handler);
2007 let result = runner
2008 .run_with_orchestrator("hello".to_string())
2009 .await
2010 .unwrap();
2011
2012 assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
2014
2015 let output = result.get_result().unwrap();
2017 assert_eq!(output, "processed: hello");
2018
2019 assert!(!result.get_invocations().is_empty());
2021 }
2022
2023 #[tokio::test]
2024 async fn test_run_with_orchestrator_failed_execution() {
2025 async fn failing_handler(
2026 _input: String,
2027 _ctx: DurableContext,
2028 ) -> Result<String, DurableError> {
2029 Err(DurableError::execution("Test failure"))
2030 }
2031
2032 LocalDurableTestRunner::<String, String>::teardown_test_environment()
2034 .await
2035 .unwrap();
2036
2037 let mut runner = LocalDurableTestRunner::new(failing_handler);
2038 let result = runner
2039 .run_with_orchestrator("hello".to_string())
2040 .await
2041 .unwrap();
2042
2043 assert_eq!(result.get_status(), ExecutionStatus::Failed);
2045
2046 let error = result.get_error().unwrap();
2048 assert!(error
2049 .error_message
2050 .as_ref()
2051 .unwrap()
2052 .contains("Test failure"));
2053 }
2054
2055 #[tokio::test]
2056 async fn test_run_with_orchestrator_with_time_skipping() {
2057 LocalDurableTestRunner::<String, String>::teardown_test_environment()
2059 .await
2060 .unwrap();
2061
2062 LocalDurableTestRunner::<String, String>::setup_test_environment(TestEnvironmentConfig {
2064 skip_time: true,
2065 checkpoint_delay: None,
2066 })
2067 .await
2068 .unwrap();
2069
2070 let mut runner = LocalDurableTestRunner::new(simple_handler);
2071 let result = runner
2072 .run_with_orchestrator("hello".to_string())
2073 .await
2074 .unwrap();
2075
2076 assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
2078
2079 LocalDurableTestRunner::<String, String>::teardown_test_environment()
2081 .await
2082 .unwrap();
2083 }
2084
2085 #[tokio::test]
2086 async fn test_run_with_orchestrator_populates_nodejs_history_events() {
2087 use crate::checkpoint_server::NodeJsEventType;
2088
2089 LocalDurableTestRunner::<String, String>::teardown_test_environment()
2091 .await
2092 .unwrap();
2093
2094 let mut runner = LocalDurableTestRunner::new(simple_handler);
2095 let result = runner
2096 .run_with_orchestrator("hello".to_string())
2097 .await
2098 .unwrap();
2099
2100 assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
2102
2103 let nodejs_events = result.get_nodejs_history_events();
2105 assert!(
2106 !nodejs_events.is_empty(),
2107 "Node.js history events should be populated"
2108 );
2109
2110 assert_eq!(
2112 nodejs_events[0].event_type,
2113 NodeJsEventType::ExecutionStarted,
2114 "First event should be ExecutionStarted"
2115 );
2116
2117 for (i, event) in nodejs_events.iter().enumerate() {
2119 assert_eq!(
2120 event.event_id,
2121 (i + 1) as u64,
2122 "Event IDs should be sequential starting from 1"
2123 );
2124 }
2125
2126 for event in nodejs_events {
2128 assert!(
2129 event.event_timestamp.contains('T') && event.event_timestamp.contains('Z'),
2130 "Timestamps should be in ISO 8601 format"
2131 );
2132 }
2133 }
2134}
2135
#[cfg(test)]
mod property_tests {
    use super::*;
    use durable_execution_sdk::OperationType;
    use proptest::prelude::*;

    /// Strategy producing 1-32 characters drawn from ASCII letters, digits,
    /// `_` and space. Used both as handler input and as error-message text.
    fn non_empty_string_strategy() -> impl Strategy<Value = String> {
        // A string literal is itself a regex strategy yielding `String`, so no
        // identity `prop_map` is needed.
        "[a-zA-Z0-9_ ]{1,32}"
    }

    /// Strategy producing identifier-like operation names (1-16 characters,
    /// not starting with a digit).
    fn operation_name_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z_][a-zA-Z0-9_]{0,15}"
    }

    /// Strategy producing identifier-like function names (1-16 characters,
    /// not starting with a digit).
    fn function_name_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z_][a-zA-Z0-9_]{0,15}"
    }

    /// Builds a fresh single-threaded Tokio runtime for one proptest case.
    ///
    /// Every case constructs its own runtime, so all properties use the
    /// current-thread flavour. This avoids starting a worker pool per case.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    proptest! {
        // A handler returning Ok must yield Succeeded with exactly that output.
        #[test]
        fn prop_execution_status_consistency_success(input in non_empty_string_strategy()) {
            let rt = current_thread_runtime();
            rt.block_on(async {
                LocalDurableTestRunner::<String, String>::teardown_test_environment()
                    .await
                    .unwrap();

                async fn success_handler(
                    input: String,
                    _ctx: DurableContext,
                ) -> Result<String, DurableError> {
                    Ok(format!("success: {}", input))
                }

                let mut runner = LocalDurableTestRunner::new(success_handler);
                let result = runner.run(input.clone()).await.unwrap();

                prop_assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
                prop_assert!(result.get_result().is_ok());
                let expected = format!("success: {}", input);
                prop_assert_eq!(result.get_result().unwrap(), &expected);

                Ok(())
            })?;
        }

        // A handler returning Err must yield Failed, and the reported error
        // must carry the handler's message.
        #[test]
        fn prop_execution_status_consistency_failure(error_msg in non_empty_string_strategy()) {
            let rt = current_thread_runtime();
            rt.block_on(async {
                LocalDurableTestRunner::<String, String>::teardown_test_environment()
                    .await
                    .unwrap();

                let error_msg_clone = error_msg.clone();
                let failing_handler = move |_input: String, _ctx: DurableContext| {
                    let msg = error_msg_clone.clone();
                    async move { Err::<String, DurableError>(DurableError::execution(msg)) }
                };

                let mut runner = LocalDurableTestRunner::new(failing_handler);
                let result = runner.run("test".to_string()).await.unwrap();

                prop_assert_eq!(result.get_status(), ExecutionStatus::Failed);
                prop_assert!(result.get_error().is_ok());

                // Without this check the generated `error_msg` would be irrelevant
                // to the property. Trim it because the strategy can generate
                // leading/trailing spaces.
                let error = result.get_error().unwrap();
                let reported = error.error_message.clone().unwrap_or_default();
                prop_assert!(
                    reported.contains(error_msg.trim()),
                    "Reported error message {:?} should contain {:?}",
                    reported,
                    error_msg
                );

                Ok(())
            })?;
        }

        // After `reset`, storage is empty and a second run behaves like a fresh one.
        #[test]
        fn prop_reset_clears_state(
            input1 in non_empty_string_strategy(),
            input2 in non_empty_string_strategy()
        ) {
            let rt = current_thread_runtime();
            rt.block_on(async {
                LocalDurableTestRunner::<String, String>::teardown_test_environment()
                    .await
                    .unwrap();

                CheckpointWorkerManager::reset_instance_for_testing();

                async fn simple_handler(
                    input: String,
                    _ctx: DurableContext,
                ) -> Result<String, DurableError> {
                    Ok(format!("processed: {}", input))
                }

                let mut runner = LocalDurableTestRunner::new(simple_handler);

                let _ = runner.run(input1).await.unwrap();

                runner.reset().await;

                prop_assert_eq!(runner.operation_count().await, 0);

                let result2 = runner.run(input2.clone()).await.unwrap();
                prop_assert_eq!(result2.get_status(), ExecutionStatus::Succeeded);
                let expected = format!("processed: {}", input2);
                prop_assert_eq!(result2.get_result().unwrap(), &expected);

                Ok(())
            })?;
        }

        // Lookups by index, name, and id, plus `get_all_operations`, agree with
        // the insertion order of seeded operations.
        #[test]
        fn prop_operation_lookup_consistency(
            names in prop::collection::vec(operation_name_strategy(), 1..=5)
        ) {
            let rt = current_thread_runtime();
            rt.block_on(async {
                let runner: LocalDurableTestRunner<String, String> =
                    LocalDurableTestRunner::new(|input: String, _ctx: DurableContext| async move {
                        Ok(input)
                    });

                {
                    let mut storage = runner.operation_storage.write().await;
                    for (i, name) in names.iter().enumerate() {
                        let mut op = Operation::new(&format!("op-{}", i), OperationType::Step);
                        op.name = Some(name.clone());
                        storage.add_operation(op);
                    }
                }

                for (i, _name) in names.iter().enumerate() {
                    let op = runner.get_operation_by_index(i).await;
                    prop_assert!(op.is_some());
                    let expected_id = format!("op-{}", i);
                    prop_assert_eq!(&op.unwrap().operation_id, &expected_id);
                }

                // Names may repeat, so only the returned name is checked here.
                for name in &names {
                    let op = runner.get_operation(name).await;
                    prop_assert!(op.is_some());
                    prop_assert_eq!(op.as_ref().unwrap().name.as_ref().unwrap(), name);
                }

                for i in 0..names.len() {
                    let op = runner.get_operation_by_id(&format!("op-{}", i)).await;
                    prop_assert!(op.is_some());
                    let expected_id = format!("op-{}", i);
                    prop_assert_eq!(&op.unwrap().operation_id, &expected_id);
                }

                let all_ops = runner.get_all_operations().await;
                prop_assert_eq!(all_ops.len(), names.len());
                for (i, op) in all_ops.iter().enumerate() {
                    let expected_id = format!("op-{}", i);
                    prop_assert_eq!(&op.operation_id, &expected_id);
                }

                Ok(())
            })?;
        }

        // Every registered name is retrievable; duplicate registrations collapse.
        #[test]
        fn prop_function_registration_retrieval(
            func_names in prop::collection::vec(function_name_strategy(), 1..=5)
        ) {
            let rt = current_thread_runtime();
            rt.block_on(async {
                let runner: LocalDurableTestRunner<String, String> =
                    LocalDurableTestRunner::new(|input: String, _ctx: DurableContext| async move {
                        Ok(input)
                    });

                for name in &func_names {
                    runner.register_function(name.clone(), |_input: serde_json::Value| {
                        Ok(serde_json::json!({}))
                    }).await;
                }

                for name in &func_names {
                    prop_assert!(
                        runner.has_registered_function(name).await,
                        "Function '{}' should be registered",
                        name
                    );
                }

                let unique_names: std::collections::HashSet<_> = func_names.iter().collect();
                prop_assert_eq!(
                    runner.registered_function_count().await,
                    unique_names.len()
                );

                prop_assert!(!runner.has_registered_function("__nonexistent__").await);

                Ok(())
            })?;
        }

        // Running named steps records operations and completes successfully.
        #[test]
        fn prop_operation_capture_completeness(
            num_steps in 1usize..=5,
            step_values in prop::collection::vec(1i32..100, 1..=5)
        ) {
            let rt = current_thread_runtime();
            rt.block_on(async {
                LocalDurableTestRunner::<String, String>::teardown_test_environment()
                    .await
                    .unwrap();

                let num_steps_to_perform = num_steps.min(step_values.len());
                let values = step_values.clone();

                let multi_step_handler = move |_input: String, ctx: DurableContext| {
                    let values = values.clone();
                    let num = num_steps_to_perform;
                    async move {
                        let mut results = Vec::new();
                        for i in 0..num {
                            let value = values.get(i).copied().unwrap_or(0);
                            let step_name = format!("step_{}", i);
                            let result = ctx.step_named(
                                &step_name,
                                |_| Ok(value * 2),
                                None
                            ).await?;
                            results.push(result);
                        }
                        Ok::<String, DurableError>(format!("completed {} steps", results.len()))
                    }
                };

                let mut runner = LocalDurableTestRunner::new(multi_step_handler);
                let result = runner.run("test".to_string()).await.unwrap();

                let operations = result.get_operations();

                prop_assert!(
                    !operations.is_empty() || num_steps_to_perform == 0,
                    "Operations should be captured when steps are performed"
                );

                prop_assert_eq!(result.get_status(), ExecutionStatus::Succeeded);

                Ok(())
            })?;
        }

        // With time skipping enabled, a long `wait` must not block for its
        // full wall-clock duration.
        #[test]
        fn prop_time_skipping_acceleration(
            wait_seconds in 5u64..=60
        ) {
            let rt = current_thread_runtime();

            rt.block_on(async {
                LocalDurableTestRunner::<String, String>::teardown_test_environment()
                    .await
                    .unwrap();

                LocalDurableTestRunner::<String, String>::setup_test_environment(
                    TestEnvironmentConfig {
                        skip_time: true,
                        checkpoint_delay: None,
                    }
                ).await.unwrap();

                // Run the checks in an inner future. A failing `prop_assert!`
                // returns early from that future only, so the teardown below
                // still runs and time skipping is not left enabled for other tests.
                let outcome: Result<(), TestCaseError> = async {
                    prop_assert!(
                        LocalDurableTestRunner::<String, String>::is_time_skipping_enabled(),
                        "Time skipping should be enabled after setup"
                    );

                    let wait_duration = wait_seconds;
                    let wait_handler = move |_input: String, ctx: DurableContext| {
                        async move {
                            ctx.wait(
                                durable_execution_sdk::Duration::from_seconds(wait_duration),
                                Some("test_wait")
                            ).await?;
                            Ok::<String, DurableError>("wait completed".to_string())
                        }
                    };

                    let mut runner = LocalDurableTestRunner::new(wait_handler);

                    let start_time = std::time::Instant::now();
                    let result = runner.run("test".to_string()).await.unwrap();
                    let elapsed = start_time.elapsed();

                    let max_allowed_seconds = wait_seconds.saturating_sub(1).max(1);
                    prop_assert!(
                        elapsed.as_secs() < max_allowed_seconds,
                        "Wall-clock time ({:?}) should be less than wait duration ({} seconds). \
                         Time skipping should prevent actual waiting.",
                        elapsed,
                        wait_seconds
                    );

                    prop_assert!(
                        result.get_status() == ExecutionStatus::Running ||
                        result.get_status() == ExecutionStatus::Succeeded,
                        "Execution should be Running (suspended) or Succeeded, got {:?}",
                        result.get_status()
                    );

                    Ok(())
                }
                .await;

                LocalDurableTestRunner::<String, String>::teardown_test_environment()
                    .await
                    .unwrap();

                outcome
            })?;
        }
    }
}