1use std::collections::HashSet;
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31
32use serde::de::DeserializeOwned;
33use serde::Serialize;
34use tokio::sync::{Mutex, RwLock};
35
36use durable_execution_sdk::{
37 DurableContext, DurableError, DurableServiceClient, Operation, OperationStatus, OperationType,
38};
39
40use super::scheduler::{QueueScheduler, Scheduler};
41use super::types::ExecutionId;
42use super::worker_manager::CheckpointWorkerManager;
43use crate::operation::CallbackSender;
44use crate::operation_handle::{OperationHandle, OperationMatcher};
45use crate::types::{ExecutionStatus, Invocation, TestResultError};
46
/// Configuration for time skipping during test execution.
#[derive(Debug, Clone, Default)]
pub struct SkipTimeConfig {
    /// When `true`, the orchestrator fast-forwards the tokio clock with
    /// `tokio::time::advance` for pending timers instead of leaving the clock
    /// untouched. Defaults to `false` (via `Default`).
    pub enabled: bool,
}
53
/// Outcome of running a handler through the test orchestrator.
#[derive(Debug)]
pub struct TestExecutionResult<T> {
    /// Overall status of the execution (succeeded, failed or still running).
    pub status: ExecutionStatus,
    /// Handler return value; the constructors set it only for `success`.
    pub result: Option<T>,
    /// Failure details; the constructors set it only for `failure`.
    pub error: Option<TestResultError>,
    /// Operations captured for this execution.
    pub operations: Vec<Operation>,
    /// One entry per handler invocation recorded by the orchestrator.
    pub invocations: Vec<Invocation>,
    /// Identifier of the execution this result belongs to.
    pub execution_id: String,
}
70
71impl<T> TestExecutionResult<T> {
72 pub fn success(result: T, operations: Vec<Operation>, execution_id: String) -> Self {
74 Self {
75 status: ExecutionStatus::Succeeded,
76 result: Some(result),
77 error: None,
78 operations,
79 invocations: Vec::new(),
80 execution_id,
81 }
82 }
83
84 pub fn failure(
86 error: TestResultError,
87 operations: Vec<Operation>,
88 execution_id: String,
89 ) -> Self {
90 Self {
91 status: ExecutionStatus::Failed,
92 result: None,
93 error: Some(error),
94 operations,
95 invocations: Vec::new(),
96 execution_id,
97 }
98 }
99
100 pub fn running(operations: Vec<Operation>, execution_id: String) -> Self {
102 Self {
103 status: ExecutionStatus::Running,
104 result: None,
105 error: None,
106 operations,
107 invocations: Vec::new(),
108 execution_id,
109 }
110 }
111
112 pub fn with_invocation(mut self, invocation: Invocation) -> Self {
114 self.invocations.push(invocation);
115 self
116 }
117}
118
/// In-memory store of operations with lookup indexes by id and by name.
#[derive(Debug, Default)]
pub struct OperationStorage {
    /// All stored operations, in insertion order.
    operations: Vec<Operation>,
    /// Maps an operation id to its position in `operations`.
    operations_by_id: std::collections::HashMap<String, usize>,
    /// Maps an operation name to the positions in `operations` of every
    /// operation carrying that name (unnamed operations are not indexed).
    operations_by_name: std::collections::HashMap<String, Vec<usize>>,
}
129
130impl OperationStorage {
131 pub fn new() -> Self {
133 Self::default()
134 }
135
136 pub fn add_operation(&mut self, operation: Operation) {
138 let index = self.operations.len();
139 let id = operation.operation_id.clone();
140 let name = operation.name.clone();
141
142 self.operations.push(operation);
143 self.operations_by_id.insert(id, index);
144
145 if let Some(name) = name {
146 self.operations_by_name.entry(name).or_default().push(index);
147 }
148 }
149
150 pub fn update_operation(&mut self, operation: Operation) {
152 let id = operation.operation_id.clone();
153 if let Some(&index) = self.operations_by_id.get(&id) {
154 self.operations[index] = operation;
155 } else {
156 self.add_operation(operation);
157 }
158 }
159
160 pub fn get_by_id(&self, id: &str) -> Option<&Operation> {
162 self.operations_by_id
163 .get(id)
164 .and_then(|&idx| self.operations.get(idx))
165 }
166
167 pub fn get_all(&self) -> &[Operation] {
169 &self.operations
170 }
171
172 pub fn clear(&mut self) {
174 self.operations.clear();
175 self.operations_by_id.clear();
176 self.operations_by_name.clear();
177 }
178
179 pub fn len(&self) -> usize {
181 self.operations.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.operations.is_empty()
187 }
188}
189
/// Type-erased handler: takes the input and a `DurableContext` and returns a
/// boxed, pinned, `Send` future resolving to `Result<O, DurableError>`.
/// The closure itself must be `Send + Sync`.
pub type BoxedHandler<I, O> = Box<
    dyn Fn(I, DurableContext) -> Pin<Box<dyn Future<Output = Result<O, DurableError>> + Send>>
        + Send
        + Sync,
>;
196
/// Drives a handler through one durable execution in tests: starts the
/// execution on the checkpoint server, invokes the handler, and re-invokes it
/// while operations are pending.
pub struct TestExecutionOrchestrator<I, O>
where
    I: DeserializeOwned + Send + Serialize + 'static,
    O: Serialize + DeserializeOwned + Send + 'static,
{
    /// Type-erased user handler, called once per invocation.
    handler: BoxedHandler<I, O>,
    /// Shared operation store; cleared when an execution starts and refreshed
    /// from the operations returned by the checkpoint server.
    operation_storage: Arc<RwLock<OperationStorage>>,
    /// Checkpoint worker used for API requests and for fetching operations;
    /// also handed to each `ExecutionState`.
    checkpoint_api: Arc<CheckpointWorkerManager>,
    /// Time-skipping configuration.
    skip_time_config: SkipTimeConfig,
    /// Scheduler for deferred invocations (a `QueueScheduler` when built via `new`).
    scheduler: Box<dyn Scheduler>,
    /// Ids of operations that were observed as still pending.
    pending_operations: HashSet<String>,
    /// `true` while the handler future is being awaited.
    invocation_active: Arc<AtomicBool>,
    /// Execution id returned by the checkpoint server, once an execution started.
    execution_id: Option<ExecutionId>,
    /// Checkpoint token from the most recent start/invocation response.
    checkpoint_token: Option<String>,
    /// Set once the execution reached a final success or failure.
    execution_complete: Arc<AtomicBool>,
    /// Slot for the final handler result; reset to `None` when an execution starts.
    final_result: Arc<Mutex<Option<Result<O, DurableError>>>>,
    /// Operation handles that are populated whenever operations are fetched.
    registered_handles: Vec<OperationHandle>,
    /// Snapshot of the latest operations, shared with the registered handles' owner.
    shared_operations: Arc<RwLock<Vec<Operation>>>,
    /// Optional callback sender installed into the registered handles.
    callback_sender: Option<Arc<dyn CallbackSender>>,
}
239
240impl<I, O> TestExecutionOrchestrator<I, O>
241where
242 I: DeserializeOwned + Send + Serialize + Clone + 'static,
243 O: Serialize + DeserializeOwned + Send + 'static,
244{
245 pub fn new<F, Fut>(
254 handler: F,
255 operation_storage: Arc<RwLock<OperationStorage>>,
256 checkpoint_api: Arc<CheckpointWorkerManager>,
257 skip_time_config: SkipTimeConfig,
258 ) -> Self
259 where
260 F: Fn(I, DurableContext) -> Fut + Send + Sync + 'static,
261 Fut: Future<Output = Result<O, DurableError>> + Send + 'static,
262 {
263 let boxed_handler = Box::new(move |input: I, ctx: DurableContext| {
264 let fut = handler(input, ctx);
265 Box::pin(fut) as Pin<Box<dyn Future<Output = Result<O, DurableError>> + Send>>
266 });
267
268 let scheduler: Box<dyn Scheduler> = Box::new(QueueScheduler::new());
270
271 Self {
272 handler: boxed_handler,
273 operation_storage,
274 checkpoint_api,
275 skip_time_config,
276 scheduler,
277 pending_operations: HashSet::new(),
278 invocation_active: Arc::new(AtomicBool::new(false)),
279 execution_id: None,
280 checkpoint_token: None,
281 execution_complete: Arc::new(AtomicBool::new(false)),
282 final_result: Arc::new(Mutex::new(None)),
283 registered_handles: Vec::new(),
284 shared_operations: Arc::new(RwLock::new(Vec::new())),
285 callback_sender: None,
286 }
287 }
288
289 pub fn with_handles(
297 mut self,
298 handles: Vec<OperationHandle>,
299 shared_operations: Arc<RwLock<Vec<Operation>>>,
300 callback_sender: Option<Arc<dyn CallbackSender>>,
301 ) -> Self {
302 if let Some(ref sender) = callback_sender {
306 for handle in &handles {
307 let sender_clone = Arc::clone(sender);
308 if let Ok(mut guard) = handle.callback_sender.try_write() {
309 *guard = Some(sender_clone);
310 }
311 }
312 }
313 self.registered_handles = handles;
314 self.shared_operations = shared_operations;
315 self.callback_sender = callback_sender;
316 self
317 }
318
319 pub fn is_time_skipping_enabled(&self) -> bool {
321 self.skip_time_config.enabled
322 }
323
324 pub fn execution_id(&self) -> Option<&str> {
326 self.execution_id.as_deref()
327 }
328
329 pub fn checkpoint_token(&self) -> Option<&str> {
331 self.checkpoint_token.as_deref()
332 }
333
334 pub fn is_execution_complete(&self) -> bool {
336 self.execution_complete.load(Ordering::SeqCst)
337 }
338
339 pub fn is_invocation_active(&self) -> bool {
341 self.invocation_active.load(Ordering::SeqCst)
342 }
343}
344
345impl<I, O> TestExecutionOrchestrator<I, O>
346where
347 I: DeserializeOwned + Send + Serialize + Clone + 'static,
348 O: Serialize + DeserializeOwned + Send + 'static,
349{
350 pub async fn execute_handler(
364 &mut self,
365 payload: I,
366 ) -> Result<TestExecutionResult<O>, crate::error::TestError> {
367 use super::types::{ApiType, StartDurableExecutionRequest};
368 use durable_execution_sdk::lambda::InitialExecutionState;
369 use durable_execution_sdk::state::ExecutionState;
370
371 self.operation_storage.write().await.clear();
373 self.pending_operations.clear();
374 self.execution_complete.store(false, Ordering::SeqCst);
375 *self.final_result.lock().await = None;
376
377 let payload_json = serde_json::to_string(&payload)?;
379
380 let invocation_id = uuid::Uuid::new_v4().to_string();
382 let start_request = StartDurableExecutionRequest {
383 invocation_id: invocation_id.clone(),
384 payload: Some(payload_json),
385 };
386 let start_payload = serde_json::to_string(&start_request)?;
387
388 let start_response = self
389 .checkpoint_api
390 .send_api_request(ApiType::StartDurableExecution, start_payload)
391 .await?;
392
393 if let Some(error) = start_response.error {
394 return Err(crate::error::TestError::CheckpointServerError(error));
395 }
396
397 let invocation_result: super::InvocationResult =
398 serde_json::from_str(&start_response.payload.ok_or_else(|| {
399 crate::error::TestError::CheckpointServerError(
400 "Empty response from checkpoint server".to_string(),
401 )
402 })?)?;
403
404 self.execution_id = Some(invocation_result.execution_id.clone());
405 self.checkpoint_token = Some(invocation_result.checkpoint_token.clone());
406
407 let execution_arn = invocation_result.execution_id.clone();
408 let checkpoint_token = invocation_result.checkpoint_token.clone();
409
410 let initial_state = InitialExecutionState::new();
412
413 let execution_state = Arc::new(ExecutionState::new(
415 &execution_arn,
416 &checkpoint_token,
417 initial_state,
418 self.checkpoint_api.clone(),
419 ));
420
421 let ctx = DurableContext::new(execution_state.clone());
423
424 let start_time = chrono::Utc::now();
426 let mut invocation = Invocation::with_start(start_time);
427
428 self.invocation_active.store(true, Ordering::SeqCst);
430 let handler_result = (self.handler)(payload.clone(), ctx).await;
431 self.invocation_active.store(false, Ordering::SeqCst);
432
433 let end_time = chrono::Utc::now();
435 invocation = invocation.with_end(end_time);
436
437 let operations = match self.checkpoint_api.get_operations(&execution_arn, "").await {
439 Ok(response) => {
440 let mut storage = self.operation_storage.write().await;
441 for op in &response.operations {
442 storage.update_operation(op.clone());
443 }
444 response.operations
445 }
446 Err(_) => Vec::new(),
447 };
448
449 self.populate_handles(&operations).await;
451
452 match handler_result {
454 Ok(result) => {
455 self.execution_complete.store(true, Ordering::SeqCst);
456 let mut test_result =
457 TestExecutionResult::success(result, operations, execution_arn);
458 test_result.invocations.push(invocation);
459 Ok(test_result)
460 }
461 Err(error) => {
462 if error.is_suspend() {
464 let test_result = self
466 .handle_pending_execution(payload, execution_arn, invocation)
467 .await?;
468 Ok(test_result)
469 } else {
470 self.execution_complete.store(true, Ordering::SeqCst);
471 let error_obj = durable_execution_sdk::ErrorObject::from(&error);
472 let test_error = TestResultError::new(error_obj.error_type, error.to_string());
473 let mut test_result =
474 TestExecutionResult::failure(test_error.clone(), operations, execution_arn);
475 test_result
476 .invocations
477 .push(invocation.with_error(test_error));
478 Ok(test_result)
479 }
480 }
481 }
482 }
483
484 async fn handle_pending_execution(
486 &mut self,
487 payload: I,
488 execution_arn: String,
489 initial_invocation: Invocation,
490 ) -> Result<TestExecutionResult<O>, crate::error::TestError> {
491 let mut invocations = vec![initial_invocation];
492 let mut iteration_count = 0;
493 const MAX_ITERATIONS: usize = 100; loop {
496 iteration_count += 1;
497 if iteration_count > MAX_ITERATIONS {
498 return Err(crate::error::TestError::CheckpointServerError(
499 "Maximum iteration count exceeded".to_string(),
500 ));
501 }
502
503 let mut operations = match self.checkpoint_api.get_operations(&execution_arn, "").await
505 {
506 Ok(response) => {
507 let mut storage = self.operation_storage.write().await;
508 for op in &response.operations {
509 storage.update_operation(op.clone());
510 }
511 response.operations
512 }
513 Err(_) => Vec::new(),
514 };
515
516 self.populate_handles(&operations).await;
518
519 let process_result = self.process_operations(&operations, &execution_arn);
521
522 match process_result {
523 ProcessOperationsResult::ExecutionSucceeded(result_str) => {
524 self.execution_complete.store(true, Ordering::SeqCst);
525 if let Ok(result) = serde_json::from_str::<O>(&result_str) {
526 let mut test_result =
527 TestExecutionResult::success(result, operations, execution_arn);
528 test_result.invocations = invocations;
529 return Ok(test_result);
530 }
531 let mut test_result = TestExecutionResult::running(operations, execution_arn);
533 test_result.invocations = invocations;
534 return Ok(test_result);
535 }
536 ProcessOperationsResult::ExecutionFailed(error) => {
537 self.execution_complete.store(true, Ordering::SeqCst);
538 let mut test_result =
539 TestExecutionResult::failure(error, operations, execution_arn);
540 test_result.invocations = invocations;
541 return Ok(test_result);
542 }
543 ProcessOperationsResult::NoPendingOperations => {
544 let mut test_result = TestExecutionResult::running(operations, execution_arn);
546 test_result.invocations = invocations;
547 return Ok(test_result);
548 }
549 ProcessOperationsResult::ShouldReinvoke(advance_time_ms) => {
550 if advance_time_ms.is_none() {
554 if !self.registered_handles.is_empty() {
558 const MAX_CALLBACK_POLLS: usize = 6000;
561 let mut callback_poll_count = 0;
562
563 loop {
564 callback_poll_count += 1;
565 if callback_poll_count > MAX_CALLBACK_POLLS {
566 return Err(crate::error::TestError::CheckpointServerError(
567 format!(
568 "Callback polling timed out after {} iterations (~{}s). \
569 Pending operations: {:?}",
570 MAX_CALLBACK_POLLS,
571 MAX_CALLBACK_POLLS * 50 / 1000,
572 self.pending_operations
573 ),
574 ));
575 }
576
577 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
578
579 let poll_operations = match self
581 .checkpoint_api
582 .get_operations(&execution_arn, "")
583 .await
584 {
585 Ok(response) => response.operations,
586 Err(e) => {
587 tracing::warn!(
588 attempt = callback_poll_count,
589 error = %e,
590 "Failed to fetch operations during callback polling"
591 );
592 continue;
593 }
594 };
595
596 self.populate_handles(&poll_operations).await;
598
599 let all_callbacks_done =
601 self.pending_operations.iter().all(|op_id| {
602 poll_operations.iter().any(|op| {
603 &op.operation_id == op_id
604 && op.operation_type == OperationType::Callback
605 && matches!(
606 op.status,
607 OperationStatus::Succeeded
608 | OperationStatus::Failed
609 | OperationStatus::Cancelled
610 )
611 })
612 });
613
614 if all_callbacks_done {
615 operations = poll_operations;
616 break;
617 }
618 }
619 } else {
621 let mut test_result =
623 TestExecutionResult::running(operations, execution_arn);
624 test_result.invocations = invocations;
625 return Ok(test_result);
626 }
627 }
628
629 if let Some(advance_ms) = advance_time_ms {
631 if advance_ms > 0 && self.skip_time_config.enabled {
632 tokio::time::advance(tokio::time::Duration::from_millis(advance_ms))
633 .await;
634 }
635 }
636
637 if advance_time_ms.is_some() {
642 for op in &operations {
643 if op.operation_type == OperationType::Wait
644 && op.status == OperationStatus::Started
645 {
646 let mut updated_operation = op.clone();
651 updated_operation.status = OperationStatus::Succeeded;
652 updated_operation.end_timestamp =
653 Some(chrono::Utc::now().timestamp_millis());
654
655 let update_request = super::types::UpdateCheckpointDataRequest {
656 execution_id: execution_arn.clone(),
657 operation_id: op.operation_id.clone(),
658 operation_data: updated_operation,
659 payload: None,
660 error: None,
661 };
662
663 let payload = serde_json::to_string(&update_request)?;
664 let _ = self
665 .checkpoint_api
666 .send_api_request(
667 super::types::ApiType::UpdateCheckpointData,
668 payload,
669 )
670 .await;
671 }
672 }
673 }
674 }
675 }
676
677 let new_invocation_id = uuid::Uuid::new_v4().to_string();
679 let start_invocation_request = super::types::StartInvocationRequest {
680 execution_id: execution_arn.clone(),
681 invocation_id: new_invocation_id.clone(),
682 };
683 let start_payload = serde_json::to_string(&start_invocation_request)?;
684
685 let start_response = self
686 .checkpoint_api
687 .send_api_request(super::types::ApiType::StartInvocation, start_payload)
688 .await?;
689
690 if let Some(error) = start_response.error {
691 return Err(crate::error::TestError::CheckpointServerError(error));
692 }
693
694 let invocation_result: super::InvocationResult =
695 serde_json::from_str(&start_response.payload.ok_or_else(|| {
696 crate::error::TestError::CheckpointServerError(
697 "Empty response from checkpoint server".to_string(),
698 )
699 })?)?;
700
701 self.checkpoint_token = Some(invocation_result.checkpoint_token.clone());
702
703 use durable_execution_sdk::lambda::InitialExecutionState;
706 use durable_execution_sdk::state::ExecutionState;
707
708 let current_operations: Vec<Operation> = invocation_result
710 .operation_events
711 .iter()
712 .map(|e| e.operation.clone())
713 .collect();
714 let initial_state = InitialExecutionState::with_operations(current_operations);
715 let execution_state = Arc::new(ExecutionState::new(
716 &execution_arn,
717 &invocation_result.checkpoint_token,
718 initial_state,
719 self.checkpoint_api.clone(),
720 ));
721
722 let ctx = DurableContext::new(execution_state);
723
724 let start_time = chrono::Utc::now();
726 let mut invocation = Invocation::with_start(start_time);
727
728 self.invocation_active.store(true, Ordering::SeqCst);
730 let handler_result = (self.handler)(payload.clone(), ctx).await;
731 self.invocation_active.store(false, Ordering::SeqCst);
732
733 let end_time = chrono::Utc::now();
734 invocation = invocation.with_end(end_time);
735
736 match handler_result {
737 Ok(result) => {
738 self.execution_complete.store(true, Ordering::SeqCst);
739 invocations.push(invocation);
740
741 let final_operations =
743 match self.checkpoint_api.get_operations(&execution_arn, "").await {
744 Ok(response) => response.operations,
745 Err(_) => Vec::new(),
746 };
747
748 self.populate_handles(&final_operations).await;
750
751 let mut test_result =
752 TestExecutionResult::success(result, final_operations, execution_arn);
753 test_result.invocations = invocations;
754 return Ok(test_result);
755 }
756 Err(error) => {
757 if error.is_suspend() {
758 invocations.push(invocation);
760 continue;
761 } else {
762 self.execution_complete.store(true, Ordering::SeqCst);
763 let error_obj = durable_execution_sdk::ErrorObject::from(&error);
764 let test_error =
765 TestResultError::new(error_obj.error_type, error.to_string());
766 invocations.push(invocation.with_error(test_error.clone()));
767
768 let final_operations =
769 match self.checkpoint_api.get_operations(&execution_arn, "").await {
770 Ok(response) => response.operations,
771 Err(_) => Vec::new(),
772 };
773
774 self.populate_handles(&final_operations).await;
776
777 let mut test_result = TestExecutionResult::failure(
778 test_error,
779 final_operations,
780 execution_arn,
781 );
782 test_result.invocations = invocations;
783 return Ok(test_result);
784 }
785 }
786 }
787 }
788 }
789
790 fn process_operations(
804 &mut self,
805 operations: &[Operation],
806 execution_id: &str,
807 ) -> ProcessOperationsResult {
808 if let Some(exec_result) = self.handle_execution_update(operations) {
810 return exec_result;
811 }
812
813 let mut has_pending_operations = false;
815 let mut earliest_scheduled_time: Option<i64> = None;
816
817 for operation in operations {
819 let result = self.process_operation(operation, execution_id);
820
821 match result {
822 OperationProcessResult::Pending(scheduled_time) => {
823 has_pending_operations = true;
824 if let Some(time) = scheduled_time {
825 match earliest_scheduled_time {
826 None => earliest_scheduled_time = Some(time),
827 Some(current) if time < current => earliest_scheduled_time = Some(time),
828 _ => {}
829 }
830 }
831 }
832 OperationProcessResult::Completed => {
833 }
835 OperationProcessResult::NotApplicable => {
836 }
838 }
839 }
840
841 if !has_pending_operations {
842 return ProcessOperationsResult::NoPendingOperations;
843 }
844
845 let advance_time_ms = if let Some(end_ts) = earliest_scheduled_time {
850 let now_ms = chrono::Utc::now().timestamp_millis();
851 if end_ts > now_ms {
852 Some((end_ts - now_ms) as u64)
853 } else {
854 Some(0)
855 }
856 } else {
857 None
858 };
859
860 ProcessOperationsResult::ShouldReinvoke(advance_time_ms)
861 }
862
863 fn process_operation(
876 &mut self,
877 operation: &Operation,
878 execution_id: &str,
879 ) -> OperationProcessResult {
880 if operation.status.is_terminal() {
882 return OperationProcessResult::Completed;
883 }
884
885 match operation.operation_type {
886 OperationType::Wait => self.handle_wait_update(operation, execution_id),
887 OperationType::Step => self.handle_step_update(operation, execution_id),
888 OperationType::Callback => self.handle_callback_update(operation, execution_id),
889 OperationType::Execution => {
890 OperationProcessResult::NotApplicable
892 }
893 OperationType::Invoke | OperationType::Context => {
894 OperationProcessResult::NotApplicable
902 }
903 }
904 }
905
906 fn handle_wait_update(
920 &mut self,
921 operation: &Operation,
922 _execution_id: &str,
923 ) -> OperationProcessResult {
924 if operation.status != OperationStatus::Started {
926 return OperationProcessResult::Completed;
927 }
928
929 self.pending_operations
931 .insert(operation.operation_id.clone());
932
933 let scheduled_end_timestamp = operation
935 .wait_details
936 .as_ref()
937 .and_then(|details| details.scheduled_end_timestamp);
938
939 OperationProcessResult::Pending(scheduled_end_timestamp)
940 }
941
942 fn handle_step_update(
956 &mut self,
957 operation: &Operation,
958 _execution_id: &str,
959 ) -> OperationProcessResult {
960 if operation.status != OperationStatus::Pending
962 && operation.status != OperationStatus::Started
963 {
964 return OperationProcessResult::Completed;
965 }
966
967 self.pending_operations
969 .insert(operation.operation_id.clone());
970
971 let next_attempt_timestamp = operation
973 .step_details
974 .as_ref()
975 .and_then(|details| details.next_attempt_timestamp);
976
977 OperationProcessResult::Pending(next_attempt_timestamp)
978 }
979
980 fn handle_callback_update(
994 &mut self,
995 operation: &Operation,
996 _execution_id: &str,
997 ) -> OperationProcessResult {
998 if operation.status != OperationStatus::Started {
1000 self.pending_operations.remove(&operation.operation_id);
1002 return OperationProcessResult::Completed;
1003 }
1004
1005 self.pending_operations
1007 .insert(operation.operation_id.clone());
1008
1009 OperationProcessResult::Pending(None)
1012 }
1013
1014 async fn populate_handles(&self, operations: &[Operation]) {
1027 {
1029 let mut shared_ops = self.shared_operations.write().await;
1030 shared_ops.clear();
1031 shared_ops.extend(operations.iter().cloned());
1032 }
1033
1034 for handle in &self.registered_handles {
1036 let matched_op = match &handle.matcher {
1037 OperationMatcher::ByName(name) => operations
1038 .iter()
1039 .find(|op| op.name.as_deref() == Some(name)),
1040 OperationMatcher::ByIndex(index) => operations.get(*index),
1041 OperationMatcher::ById(id) => operations.iter().find(|op| op.operation_id == *id),
1042 OperationMatcher::ByNameAndIndex(name, index) => operations
1043 .iter()
1044 .filter(|op| op.name.as_deref() == Some(name))
1045 .nth(*index),
1046 };
1047
1048 if let Some(op) = matched_op {
1049 {
1051 let mut inner = handle.inner.write().await;
1052 *inner = Some(op.clone());
1053 }
1054
1055 let _ = handle.status_tx.send(Some(op.status));
1057 }
1058 }
1059 }
1060
1061 pub fn schedule_invocation_at_timestamp(
1073 &mut self,
1074 timestamp_ms: i64,
1075 execution_id: &str,
1076 operation_id: &str,
1077 ) {
1078 let checkpoint_api = Arc::clone(&self.checkpoint_api);
1079 let execution_id_owned = execution_id.to_string();
1080 let operation_id_owned = operation_id.to_string();
1081 let skip_time_enabled = self.skip_time_config.enabled;
1082
1083 let timestamp = chrono::DateTime::from_timestamp_millis(timestamp_ms)
1085 .map(|dt| dt.with_timezone(&chrono::Utc));
1086
1087 let update_checkpoint: super::scheduler::CheckpointUpdateFn = Box::new(move || {
1089 let checkpoint_api = checkpoint_api;
1090 let execution_id = execution_id_owned;
1091 let operation_id = operation_id_owned;
1092
1093 Box::pin(async move {
1094 if skip_time_enabled {
1096 let now_ms = chrono::Utc::now().timestamp_millis();
1097 let target_ms = timestamp_ms;
1098 if target_ms > now_ms {
1099 let advance_duration =
1100 tokio::time::Duration::from_millis((target_ms - now_ms) as u64);
1101 tokio::time::advance(advance_duration).await;
1102 }
1103 }
1104
1105 let mut updated_operation = Operation::new(&operation_id, OperationType::Wait);
1107 updated_operation.status = OperationStatus::Succeeded;
1108 updated_operation.end_timestamp = Some(chrono::Utc::now().timestamp_millis());
1109
1110 let update_request = super::types::UpdateCheckpointDataRequest {
1111 execution_id: execution_id.clone(),
1112 operation_id: operation_id.clone(),
1113 operation_data: updated_operation,
1114 payload: None,
1115 error: None,
1116 };
1117
1118 let payload = serde_json::to_string(&update_request)
1119 .map_err(crate::error::TestError::SerializationError)?;
1120
1121 let response = checkpoint_api
1122 .send_api_request(super::types::ApiType::UpdateCheckpointData, payload)
1123 .await?;
1124
1125 if let Some(error) = response.error {
1126 return Err(crate::error::TestError::CheckpointServerError(error));
1127 }
1128
1129 Ok(())
1130 })
1131 });
1132
1133 let start_invocation: super::scheduler::BoxedAsyncFn = Box::new(|| {
1135 Box::pin(async {
1136 })
1139 });
1140
1141 let on_error: super::scheduler::ErrorHandler = Box::new(|error| {
1143 tracing::error!("Error during scheduled invocation: {:?}", error);
1144 });
1145
1146 self.scheduler.schedule_function(
1148 start_invocation,
1149 on_error,
1150 timestamp,
1151 Some(update_checkpoint),
1152 );
1153 }
1154
1155 pub fn schedule_invocation_with_update(
1164 &mut self,
1165 timestamp: Option<chrono::DateTime<chrono::Utc>>,
1166 update_checkpoint: Option<super::scheduler::CheckpointUpdateFn>,
1167 ) {
1168 let skip_time_enabled = self.skip_time_config.enabled;
1169
1170 let wrapped_update: Option<super::scheduler::CheckpointUpdateFn> = if skip_time_enabled {
1172 if let Some(ts) = timestamp {
1173 let original_update = update_checkpoint;
1174 Some(Box::new(move || {
1175 Box::pin(async move {
1176 let now = chrono::Utc::now();
1178 if ts > now {
1179 let duration = (ts - now).to_std().unwrap_or_default();
1180 tokio::time::advance(duration).await;
1181 }
1182
1183 if let Some(update_fn) = original_update {
1185 update_fn().await?;
1186 }
1187
1188 Ok(())
1189 })
1190 }))
1191 } else {
1192 update_checkpoint
1193 }
1194 } else {
1195 update_checkpoint
1196 };
1197
1198 let start_invocation: super::scheduler::BoxedAsyncFn = Box::new(|| {
1200 Box::pin(async {
1201 })
1203 });
1204
1205 let on_error: super::scheduler::ErrorHandler = Box::new(|error| {
1207 tracing::error!("Error during scheduled invocation: {:?}", error);
1208 });
1209
1210 self.scheduler
1212 .schedule_function(start_invocation, on_error, timestamp, wrapped_update);
1213 }
1214
    /// Returns whether the scheduler reports at least one scheduled function
    /// (delegates to `Scheduler::has_scheduled_function`).
    pub fn has_scheduled_functions(&self) -> bool {
        self.scheduler.has_scheduled_function()
    }
1223
1224 pub async fn invoke_handler(
1243 &mut self,
1244 payload: I,
1245 execution_id: &str,
1246 is_initial: bool,
1247 ) -> Result<InvokeHandlerResult<O>, crate::error::TestError> {
1248 use super::types::{ApiType, StartDurableExecutionRequest, StartInvocationRequest};
1249 use durable_execution_sdk::lambda::InitialExecutionState;
1250 use durable_execution_sdk::state::ExecutionState;
1251
1252 if self.skip_time_config.enabled && self.invocation_active.load(Ordering::SeqCst) {
1254 return Err(crate::error::TestError::CheckpointServerError(
1255 "Concurrent invocation detected in time-skip mode. Only one invocation can be active at a time.".to_string(),
1256 ));
1257 }
1258
1259 let invocation_id = uuid::Uuid::new_v4().to_string();
1261 let checkpoint_token = if is_initial {
1262 let payload_json = serde_json::to_string(&payload)?;
1264 let start_request = StartDurableExecutionRequest {
1265 invocation_id: invocation_id.clone(),
1266 payload: Some(payload_json),
1267 };
1268 let start_payload = serde_json::to_string(&start_request)?;
1269
1270 let start_response = self
1271 .checkpoint_api
1272 .send_api_request(ApiType::StartDurableExecution, start_payload)
1273 .await?;
1274
1275 if let Some(error) = start_response.error {
1276 return Err(crate::error::TestError::CheckpointServerError(error));
1277 }
1278
1279 let invocation_result: super::InvocationResult =
1280 serde_json::from_str(&start_response.payload.ok_or_else(|| {
1281 crate::error::TestError::CheckpointServerError(
1282 "Empty response from checkpoint server".to_string(),
1283 )
1284 })?)?;
1285
1286 self.execution_id = Some(invocation_result.execution_id.clone());
1288 self.checkpoint_token = Some(invocation_result.checkpoint_token.clone());
1289
1290 invocation_result.checkpoint_token
1291 } else {
1292 let start_invocation_request = StartInvocationRequest {
1294 execution_id: execution_id.to_string(),
1295 invocation_id: invocation_id.clone(),
1296 };
1297 let start_payload = serde_json::to_string(&start_invocation_request)?;
1298
1299 let start_response = self
1300 .checkpoint_api
1301 .send_api_request(ApiType::StartInvocation, start_payload)
1302 .await?;
1303
1304 if let Some(error) = start_response.error {
1305 return Err(crate::error::TestError::CheckpointServerError(error));
1306 }
1307
1308 let invocation_result: super::InvocationResult =
1309 serde_json::from_str(&start_response.payload.ok_or_else(|| {
1310 crate::error::TestError::CheckpointServerError(
1311 "Empty response from checkpoint server".to_string(),
1312 )
1313 })?)?;
1314
1315 self.checkpoint_token = Some(invocation_result.checkpoint_token.clone());
1317
1318 invocation_result.checkpoint_token
1319 };
1320
1321 let initial_state = InitialExecutionState::new();
1323 let execution_state = Arc::new(ExecutionState::new(
1324 execution_id,
1325 &checkpoint_token,
1326 initial_state,
1327 self.checkpoint_api.clone(),
1328 ));
1329
1330 let ctx = DurableContext::new(execution_state.clone());
1332
1333 let start_time = chrono::Utc::now();
1335 let mut invocation = Invocation::with_start(start_time);
1336
1337 self.invocation_active.store(true, Ordering::SeqCst);
1339
1340 let handler_result = (self.handler)(payload.clone(), ctx).await;
1342
1343 self.invocation_active.store(false, Ordering::SeqCst);
1345
1346 let end_time = chrono::Utc::now();
1348 invocation = invocation.with_end(end_time);
1349
1350 let operations = match self.checkpoint_api.get_operations(execution_id, "").await {
1352 Ok(response) => {
1353 let mut storage = self.operation_storage.write().await;
1354 for op in &response.operations {
1355 storage.update_operation(op.clone());
1356 }
1357 response.operations
1358 }
1359 Err(_) => Vec::new(),
1360 };
1361
1362 match handler_result {
1364 Ok(result) => {
1365 self.execution_complete.store(true, Ordering::SeqCst);
1367 Ok(InvokeHandlerResult::Succeeded {
1368 result,
1369 operations,
1370 invocation,
1371 })
1372 }
1373 Err(error) => {
1374 if error.is_suspend() {
1375 let process_result = self.process_operations(&operations, execution_id);
1377
1378 match process_result {
1379 ProcessOperationsResult::ExecutionSucceeded(result_str) => {
1380 self.execution_complete.store(true, Ordering::SeqCst);
1381 if let Ok(result) = serde_json::from_str::<O>(&result_str) {
1382 Ok(InvokeHandlerResult::Succeeded {
1383 result,
1384 operations,
1385 invocation,
1386 })
1387 } else {
1388 Ok(InvokeHandlerResult::Pending {
1389 operations,
1390 invocation,
1391 should_reinvoke: false,
1392 advance_time_ms: None,
1393 })
1394 }
1395 }
1396 ProcessOperationsResult::ExecutionFailed(test_error) => {
1397 self.execution_complete.store(true, Ordering::SeqCst);
1398 Ok(InvokeHandlerResult::Failed {
1399 error: test_error,
1400 operations,
1401 invocation,
1402 })
1403 }
1404 ProcessOperationsResult::NoPendingOperations => {
1405 Ok(InvokeHandlerResult::Pending {
1407 operations,
1408 invocation,
1409 should_reinvoke: false,
1410 advance_time_ms: None,
1411 })
1412 }
1413 ProcessOperationsResult::ShouldReinvoke(advance_time_ms) => {
1414 Ok(InvokeHandlerResult::Pending {
1416 operations,
1417 invocation,
1418 should_reinvoke: true,
1419 advance_time_ms,
1420 })
1421 }
1422 }
1423 } else {
1424 self.execution_complete.store(true, Ordering::SeqCst);
1426 let error_obj = durable_execution_sdk::ErrorObject::from(&error);
1427 let test_error = TestResultError::new(error_obj.error_type, error.to_string());
1428 let invocation_with_error = invocation.with_error(test_error.clone());
1429 Ok(InvokeHandlerResult::Failed {
1430 error: test_error,
1431 operations,
1432 invocation: invocation_with_error,
1433 })
1434 }
1435 }
1436 }
1437 }
1438
/// Flushes the scheduler's timers by delegating to `flush_timers`.
///
/// NOTE(review): the unit tests in this file expect
/// `has_scheduled_functions()` to be `false` afterwards, i.e. queued
/// invocations are discarded rather than executed — confirm against the
/// `Scheduler` implementation.
pub fn flush_scheduled_functions(&mut self) {
    self.scheduler.flush_timers();
}
1445
/// Asks the scheduler to process its next scheduled function and returns the
/// scheduler's answer unchanged.
///
/// NOTE(review): presumably `true` means a function was processed and `false`
/// means nothing was queued — confirm against `Scheduler::process_next`.
pub async fn process_next_scheduled(&mut self) -> bool {
    self.scheduler.process_next().await
}
1454
1455 fn handle_execution_update(&self, operations: &[Operation]) -> Option<ProcessOperationsResult> {
1468 let execution_op = operations
1470 .iter()
1471 .find(|op| op.operation_type == OperationType::Execution)?;
1472
1473 match execution_op.status {
1474 OperationStatus::Succeeded => {
1475 let result_str = execution_op.result.clone().unwrap_or_default();
1476 Some(ProcessOperationsResult::ExecutionSucceeded(result_str))
1477 }
1478 OperationStatus::Failed => {
1479 let error = if let Some(err) = &execution_op.error {
1480 TestResultError::new(err.error_type.clone(), err.error_message.clone())
1481 } else {
1482 TestResultError::new("ExecutionFailed", "Execution failed")
1483 };
1484 Some(ProcessOperationsResult::ExecutionFailed(error))
1485 }
1486 _ => None,
1487 }
1488 }
1489}
1490
/// Outcome of inspecting an execution's operations after the handler
/// suspended.
#[derive(Debug)]
pub enum ProcessOperationsResult {
    /// The execution finished successfully. Carries the result string taken
    /// from the `Execution` operation (empty when the operation had none);
    /// the invoking code deserializes it as JSON into the handler's output
    /// type.
    ExecutionSucceeded(String),
    /// The execution finished with the given error.
    ExecutionFailed(TestResultError),
    /// Nothing is left to wait for; the invoking code reports the invocation
    /// as pending without requesting a re-invocation.
    NoPendingOperations,
    /// The handler should be invoked again. The payload is forwarded as the
    /// `advance_time_ms` of `InvokeHandlerResult::Pending`.
    ShouldReinvoke(Option<u64>),
}
1505
/// Result of processing one operation.
///
/// NOTE(review): the code that produces these values is outside this part of
/// the file; the variant descriptions below are inferred from the names and
/// from the unit tests — confirm against the producers.
#[derive(Debug)]
pub enum OperationProcessResult {
    /// The operation is still pending. The unit tests refer to the optional
    /// payload as a timestamp (presumably epoch milliseconds — TODO confirm).
    Pending(Option<i64>),
    /// The operation has completed.
    Completed,
    /// The operation required no processing (presumably because its type is
    /// not handled — TODO confirm).
    NotApplicable,
}
1518
/// Result of a single handler invocation, as reported to the orchestrator's
/// caller.
#[derive(Debug)]
pub enum InvokeHandlerResult<T> {
    /// The execution produced a final value.
    Succeeded {
        /// The handler's output (or the deserialized execution result).
        result: T,
        /// Operations fetched from the checkpoint API after the handler ran.
        operations: Vec<Operation>,
        /// Record of this invocation.
        invocation: Invocation,
    },
    /// The execution ended with an error.
    Failed {
        /// The error that terminated the execution.
        error: TestResultError,
        /// Operations fetched from the checkpoint API after the handler ran.
        operations: Vec<Operation>,
        /// Record of this invocation.
        invocation: Invocation,
    },
    /// The handler suspended and the execution has not reached a final state.
    Pending {
        /// Operations fetched from the checkpoint API after the handler ran.
        operations: Vec<Operation>,
        /// Record of this invocation.
        invocation: Invocation,
        /// Whether the handler should be invoked again.
        should_reinvoke: bool,
        /// Optional amount of time, in milliseconds, to advance before the
        /// next invocation.
        advance_time_ms: Option<u64>,
    },
}
1554
1555#[cfg(test)]
1556mod tests {
1557 use super::*;
1558 use durable_execution_sdk::{ErrorObject, StepDetails, WaitDetails};
1559
/// A default `SkipTimeConfig` leaves time skipping disabled.
#[test]
fn test_skip_time_config_default() {
    let SkipTimeConfig { enabled } = SkipTimeConfig::default();
    assert!(!enabled);
}
1565
/// A freshly created `OperationStorage` holds no operations.
#[test]
fn test_operation_storage_new() {
    let empty_storage = OperationStorage::new();
    assert!(empty_storage.is_empty());
    assert_eq!(0, empty_storage.len());
}
1572
/// An operation added to the storage is counted and can be looked up by id.
#[test]
fn test_operation_storage_add_and_get() {
    let mut storage = OperationStorage::new();
    storage.add_operation(Operation::new("op-1", OperationType::Step));

    assert_eq!(storage.len(), 1);
    assert!(storage.get_by_id("op-1").is_some());
}
1583
/// Updating an operation that is already stored replaces it in place instead
/// of adding a second entry.
#[test]
fn test_operation_storage_update() {
    let mut storage = OperationStorage::new();

    // Store the operation in its initial `Started` state.
    let mut original = Operation::new("op-1", OperationType::Step);
    original.status = OperationStatus::Started;
    storage.add_operation(original);

    // Push an update for the same id with a terminal status.
    let mut replacement = Operation::new("op-1", OperationType::Step);
    replacement.status = OperationStatus::Succeeded;
    storage.update_operation(replacement);

    assert_eq!(storage.len(), 1);
    let stored = storage.get_by_id("op-1").unwrap();
    assert_eq!(stored.status, OperationStatus::Succeeded);
}
1603
/// `TestExecutionResult::success` records the value and a `Succeeded` status
/// with no error.
#[test]
fn test_test_execution_result_success() {
    let outcome: TestExecutionResult<String> =
        TestExecutionResult::success(String::from("test"), Vec::new(), String::from("exec-1"));

    assert_eq!(outcome.status, ExecutionStatus::Succeeded);
    assert_eq!(outcome.result, Some(String::from("test")));
    assert!(outcome.error.is_none());
}
1612
/// `TestExecutionResult::failure` records the error and a `Failed` status
/// with no result value.
#[test]
fn test_test_execution_result_failure() {
    let outcome: TestExecutionResult<String> = TestExecutionResult::failure(
        TestResultError::new("TestError", "test error"),
        Vec::new(),
        String::from("exec-1"),
    );

    assert_eq!(outcome.status, ExecutionStatus::Failed);
    assert!(outcome.result.is_none());
    assert!(outcome.error.is_some());
}
1622
/// `TestExecutionResult::running` carries neither a result nor an error.
#[test]
fn test_test_execution_result_running() {
    let outcome: TestExecutionResult<String> =
        TestExecutionResult::running(Vec::new(), String::from("exec-1"));

    assert_eq!(outcome.status, ExecutionStatus::Running);
    assert!(outcome.result.is_none());
    assert!(outcome.error.is_none());
}
1631
/// The `ExecutionSucceeded` variant hands back the string it was built with.
#[test]
fn test_process_operations_result_execution_succeeded() {
    let outcome = ProcessOperationsResult::ExecutionSucceeded(String::from("test result"));

    if let ProcessOperationsResult::ExecutionSucceeded(payload) = outcome {
        assert_eq!(payload, "test result");
    } else {
        panic!("Expected ExecutionSucceeded");
    }
}
1641
/// The `ExecutionFailed` variant hands back the error it was built with.
#[test]
fn test_process_operations_result_execution_failed() {
    let outcome =
        ProcessOperationsResult::ExecutionFailed(TestResultError::new("TestError", "test error"));

    if let ProcessOperationsResult::ExecutionFailed(failure) = outcome {
        assert_eq!(failure.error_type, Some(String::from("TestError")));
    } else {
        panic!("Expected ExecutionFailed");
    }
}
1653
/// The `NoPendingOperations` variant can be constructed and matched.
#[test]
fn test_process_operations_result_no_pending() {
    let outcome = ProcessOperationsResult::NoPendingOperations;
    let is_no_pending = matches!(outcome, ProcessOperationsResult::NoPendingOperations);
    assert!(is_no_pending);
}
1662
/// The `ShouldReinvoke` variant hands back the advance time it was built with.
#[test]
fn test_process_operations_result_should_reinvoke() {
    let outcome = ProcessOperationsResult::ShouldReinvoke(Some(1000));

    if let ProcessOperationsResult::ShouldReinvoke(Some(advance_ms)) = outcome {
        assert_eq!(advance_ms, 1000);
    } else {
        panic!("Expected ShouldReinvoke with time");
    }
}
1671
/// `Pending(Some(_))` hands back the timestamp it was built with.
#[test]
fn test_operation_process_result_pending_with_timestamp() {
    let outcome = OperationProcessResult::Pending(Some(1234567890));

    if let OperationProcessResult::Pending(Some(timestamp)) = outcome {
        assert_eq!(timestamp, 1234567890);
    } else {
        panic!("Expected Pending with timestamp");
    }
}
1681
/// `Pending(None)` matches as a pending result without a timestamp.
#[test]
fn test_operation_process_result_pending_without_timestamp() {
    let outcome = OperationProcessResult::Pending(None);

    if !matches!(outcome, OperationProcessResult::Pending(None)) {
        panic!("Expected Pending without timestamp");
    }
}
1690
/// The `Completed` variant can be constructed and matched.
#[test]
fn test_operation_process_result_completed() {
    let outcome = OperationProcessResult::Completed;
    let is_completed = matches!(outcome, OperationProcessResult::Completed);
    assert!(is_completed);
}
1696
/// The `NotApplicable` variant can be constructed and matched.
#[test]
fn test_operation_process_result_not_applicable() {
    let outcome = OperationProcessResult::NotApplicable;
    let is_not_applicable = matches!(outcome, OperationProcessResult::NotApplicable);
    assert!(is_not_applicable);
}
1702
1703 #[test]
1705 fn test_handle_execution_update_succeeded() {
1706 let mut exec_op = Operation::new("exec-1", OperationType::Execution);
1709 exec_op.status = OperationStatus::Succeeded;
1710 exec_op.result = Some("\"success\"".to_string());
1711
1712 let operations = vec![exec_op];
1713
1714 let execution_op = operations
1716 .iter()
1717 .find(|op| op.operation_type == OperationType::Execution);
1718
1719 assert!(execution_op.is_some());
1720 let exec = execution_op.unwrap();
1721 assert_eq!(exec.status, OperationStatus::Succeeded);
1722 assert_eq!(exec.result, Some("\"success\"".to_string()));
1723 }
1724
1725 #[test]
1726 fn test_handle_execution_update_failed() {
1727 let mut exec_op = Operation::new("exec-1", OperationType::Execution);
1728 exec_op.status = OperationStatus::Failed;
1729 exec_op.error = Some(ErrorObject {
1730 error_type: "TestError".to_string(),
1731 error_message: "Test error message".to_string(),
1732 stack_trace: None,
1733 });
1734
1735 let operations = vec![exec_op];
1736
1737 let execution_op = operations
1738 .iter()
1739 .find(|op| op.operation_type == OperationType::Execution);
1740
1741 assert!(execution_op.is_some());
1742 let exec = execution_op.unwrap();
1743 assert_eq!(exec.status, OperationStatus::Failed);
1744 assert!(exec.error.is_some());
1745 }
1746
1747 #[test]
1748 fn test_handle_execution_update_still_running() {
1749 let mut exec_op = Operation::new("exec-1", OperationType::Execution);
1750 exec_op.status = OperationStatus::Started;
1751
1752 let operations = vec![exec_op];
1753
1754 let execution_op = operations
1755 .iter()
1756 .find(|op| op.operation_type == OperationType::Execution);
1757
1758 assert!(execution_op.is_some());
1759 let exec = execution_op.unwrap();
1760 assert_eq!(exec.status, OperationStatus::Started);
1761 }
1762
/// A started wait operation keeps the scheduled end timestamp it was given.
#[test]
fn test_wait_operation_started_with_timestamp() {
    let mut wait_op = Operation::new("wait-1", OperationType::Wait);
    wait_op.status = OperationStatus::Started;
    wait_op.wait_details = Some(WaitDetails {
        scheduled_end_timestamp: Some(1234567890000),
    });

    let details = wait_op.wait_details.as_ref();
    assert!(details.is_some());
    assert_eq!(
        details.unwrap().scheduled_end_timestamp,
        Some(1234567890000)
    );
}
1777
/// A wait operation marked `Succeeded` has a terminal status.
#[test]
fn test_wait_operation_completed() {
    let mut finished_wait = Operation::new("wait-1", OperationType::Wait);
    finished_wait.status = OperationStatus::Succeeded;

    let terminal = finished_wait.status.is_terminal();
    assert!(terminal);
}
1786
/// A pending step keeps the retry metadata (attempt number and next attempt
/// timestamp) stored in its step details.
#[test]
fn test_step_operation_pending_retry() {
    let retry_details = StepDetails {
        result: None,
        attempt: Some(1),
        next_attempt_timestamp: Some(1234567890000),
        error: None,
        payload: None,
    };

    let mut step_op = Operation::new("step-1", OperationType::Step);
    step_op.status = OperationStatus::Pending;
    step_op.step_details = Some(retry_details);

    let stored = step_op.step_details.as_ref();
    assert!(stored.is_some());
    let stored = stored.unwrap();
    assert_eq!(stored.next_attempt_timestamp, Some(1234567890000));
    assert_eq!(stored.attempt, Some(1));
}
1806
/// A step operation marked `Succeeded` has a terminal status.
#[test]
fn test_step_operation_succeeded() {
    let success_details = StepDetails {
        result: Some("\"result\"".to_string()),
        attempt: Some(0),
        next_attempt_timestamp: None,
        error: None,
        payload: None,
    };

    let mut step_op = Operation::new("step-1", OperationType::Step);
    step_op.status = OperationStatus::Succeeded;
    step_op.step_details = Some(success_details);

    let terminal = step_op.status.is_terminal();
    assert!(terminal);
}
1822
/// A callback operation in `Started` state is not terminal.
#[test]
fn test_callback_operation_started() {
    let mut pending_callback = Operation::new("callback-1", OperationType::Callback);
    pending_callback.status = OperationStatus::Started;

    assert_eq!(pending_callback.status, OperationStatus::Started);
    let terminal = pending_callback.status.is_terminal();
    assert!(!terminal);
}
1833
/// A callback operation marked `Succeeded` has a terminal status.
#[test]
fn test_callback_operation_succeeded() {
    let mut finished_callback = Operation::new("callback-1", OperationType::Callback);
    finished_callback.status = OperationStatus::Succeeded;

    let terminal = finished_callback.status.is_terminal();
    assert!(terminal);
}
1842
/// An operation created as `Wait` reports the `Wait` operation type.
#[test]
fn test_operation_type_dispatch_wait() {
    let kind = Operation::new("op-1", OperationType::Wait).operation_type;
    assert_eq!(kind, OperationType::Wait);
}
1849
/// An operation created as `Step` reports the `Step` operation type.
#[test]
fn test_operation_type_dispatch_step() {
    let kind = Operation::new("op-1", OperationType::Step).operation_type;
    assert_eq!(kind, OperationType::Step);
}
1855
/// An operation created as `Callback` reports the `Callback` operation type.
#[test]
fn test_operation_type_dispatch_callback() {
    let kind = Operation::new("op-1", OperationType::Callback).operation_type;
    assert_eq!(kind, OperationType::Callback);
}
1861
/// An operation created as `Execution` reports the `Execution` operation type.
#[test]
fn test_operation_type_dispatch_execution() {
    let kind = Operation::new("op-1", OperationType::Execution).operation_type;
    assert_eq!(kind, OperationType::Execution);
}
1867
/// An operation created as `Invoke` reports the `Invoke` operation type.
#[test]
fn test_operation_type_dispatch_invoke() {
    let kind = Operation::new("op-1", OperationType::Invoke).operation_type;
    assert_eq!(kind, OperationType::Invoke);
}
1873
/// An operation created as `Context` reports the `Context` operation type.
#[test]
fn test_operation_type_dispatch_context() {
    let kind = Operation::new("op-1", OperationType::Context).operation_type;
    assert_eq!(kind, OperationType::Context);
}
1879
/// With a single started wait, the earliest scheduled end time is that wait's
/// own timestamp.
#[test]
fn test_earliest_scheduled_time_single_wait() {
    let mut wait_op = Operation::new("wait-1", OperationType::Wait);
    wait_op.status = OperationStatus::Started;
    wait_op.wait_details = Some(WaitDetails {
        scheduled_end_timestamp: Some(1000),
    });

    let operations = vec![wait_op];

    // Smallest scheduled end timestamp among started wait operations.
    let earliest: Option<i64> = operations
        .iter()
        .filter(|op| {
            op.operation_type == OperationType::Wait && op.status == OperationStatus::Started
        })
        .filter_map(|op| op.wait_details.as_ref())
        .filter_map(|details| details.scheduled_end_timestamp)
        .min();

    assert_eq!(earliest, Some(1000));
}
1908
/// With several started waits, the earliest scheduled end time is the smallest
/// of their timestamps, regardless of the order they appear in.
#[test]
fn test_earliest_scheduled_time_multiple_waits() {
    // Deliberately unsorted so the minimum is not the first element.
    let operations: Vec<Operation> = [("wait-1", 2000), ("wait-2", 1000), ("wait-3", 3000)]
        .into_iter()
        .map(|(id, end_ts)| {
            let mut wait_op = Operation::new(id, OperationType::Wait);
            wait_op.status = OperationStatus::Started;
            wait_op.wait_details = Some(WaitDetails {
                scheduled_end_timestamp: Some(end_ts),
            });
            wait_op
        })
        .collect();

    // Smallest scheduled end timestamp among started wait operations.
    let earliest: Option<i64> = operations
        .iter()
        .filter(|op| {
            op.operation_type == OperationType::Wait && op.status == OperationStatus::Started
        })
        .filter_map(|op| op.wait_details.as_ref())
        .filter_map(|details| details.scheduled_end_timestamp)
        .min();

    assert_eq!(earliest, Some(1000));
}
1948
1949 #[tokio::test]
1951 async fn test_schedule_invocation_at_timestamp_schedules_function() {
1952 use super::*;
1953 use std::sync::Arc;
1954 use tokio::sync::RwLock;
1955
1956 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
1958
1959 let handler =
1961 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
1962
1963 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
1964 let mut orchestrator = TestExecutionOrchestrator::new(
1965 handler,
1966 operation_storage,
1967 checkpoint_api,
1968 SkipTimeConfig { enabled: false },
1969 );
1970
1971 let future_timestamp = chrono::Utc::now().timestamp_millis() + 1000;
1973 orchestrator.schedule_invocation_at_timestamp(future_timestamp, "exec-1", "wait-1");
1974
1975 assert!(orchestrator.has_scheduled_functions());
1977 }
1978
1979 #[tokio::test]
1980 async fn test_schedule_invocation_with_update_schedules_function() {
1981 use super::*;
1982 use std::sync::Arc;
1983 use tokio::sync::RwLock;
1984
1985 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
1987
1988 let handler =
1990 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
1991
1992 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
1993 let mut orchestrator = TestExecutionOrchestrator::new(
1994 handler,
1995 operation_storage,
1996 checkpoint_api,
1997 SkipTimeConfig { enabled: false },
1998 );
1999
2000 orchestrator.schedule_invocation_with_update(None, None);
2002
2003 assert!(orchestrator.has_scheduled_functions());
2005 }
2006
2007 #[tokio::test]
2008 async fn test_flush_scheduled_functions_clears_queue() {
2009 use super::*;
2010 use std::sync::Arc;
2011 use tokio::sync::RwLock;
2012
2013 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
2015
2016 let handler =
2018 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
2019
2020 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
2021 let mut orchestrator = TestExecutionOrchestrator::new(
2022 handler,
2023 operation_storage,
2024 checkpoint_api,
2025 SkipTimeConfig { enabled: false },
2026 );
2027
2028 let future_timestamp = chrono::Utc::now().timestamp_millis() + 1000;
2030 orchestrator.schedule_invocation_at_timestamp(future_timestamp, "exec-1", "wait-1");
2031 orchestrator.schedule_invocation_at_timestamp(future_timestamp + 1000, "exec-1", "wait-2");
2032
2033 assert!(orchestrator.has_scheduled_functions());
2035
2036 orchestrator.flush_scheduled_functions();
2038
2039 assert!(!orchestrator.has_scheduled_functions());
2041 }
2042
2043 #[tokio::test]
2044 async fn test_process_next_scheduled_processes_function() {
2045 use super::*;
2046 use std::sync::Arc;
2047 use tokio::sync::RwLock;
2048
2049 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
2051
2052 let handler =
2054 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
2055
2056 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
2057 let mut orchestrator = TestExecutionOrchestrator::new(
2058 handler,
2059 operation_storage,
2060 checkpoint_api,
2061 SkipTimeConfig { enabled: false },
2062 );
2063
2064 orchestrator.schedule_invocation_with_update(None, None);
2066
2067 let processed = orchestrator.process_next_scheduled().await;
2069 assert!(processed);
2070
2071 assert!(!orchestrator.has_scheduled_functions());
2073 }
2074
2075 #[tokio::test]
2076 async fn test_schedule_invocation_with_time_skipping_enabled() {
2077 use super::*;
2078 use std::sync::Arc;
2079 use tokio::sync::RwLock;
2080
2081 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
2083
2084 let handler =
2086 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
2087
2088 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
2089 let mut orchestrator = TestExecutionOrchestrator::new(
2090 handler,
2091 operation_storage,
2092 checkpoint_api,
2093 SkipTimeConfig { enabled: true },
2094 );
2095
2096 assert!(orchestrator.is_time_skipping_enabled());
2098
2099 let future_timestamp = chrono::Utc::now().timestamp_millis() + 5000;
2101 orchestrator.schedule_invocation_at_timestamp(future_timestamp, "exec-1", "wait-1");
2102
2103 assert!(orchestrator.has_scheduled_functions());
2105 }
2106
/// The `Succeeded` variant exposes the result and operations it was built with.
#[test]
fn test_invoke_handler_result_succeeded() {
    let outcome: InvokeHandlerResult<String> = InvokeHandlerResult::Succeeded {
        result: String::from("test result"),
        operations: Vec::new(),
        invocation: Invocation::with_start(chrono::Utc::now()),
    };

    if let InvokeHandlerResult::Succeeded {
        result, operations, ..
    } = outcome
    {
        assert_eq!(result, "test result");
        assert!(operations.is_empty());
    } else {
        panic!("Expected Succeeded variant");
    }
}
2127
/// The `Failed` variant exposes the error and operations it was built with.
#[test]
fn test_invoke_handler_result_failed() {
    let outcome: InvokeHandlerResult<String> = InvokeHandlerResult::Failed {
        error: TestResultError::new("TestError", "test error message"),
        operations: Vec::new(),
        invocation: Invocation::with_start(chrono::Utc::now()),
    };

    if let InvokeHandlerResult::Failed {
        error, operations, ..
    } = outcome
    {
        assert_eq!(error.error_type, Some(String::from("TestError")));
        assert!(operations.is_empty());
    } else {
        panic!("Expected Failed variant");
    }
}
2148
/// A `Pending` result built with a re-invoke request keeps both the flag and
/// the advance time.
#[test]
fn test_invoke_handler_result_pending_with_reinvoke() {
    let outcome: InvokeHandlerResult<String> = InvokeHandlerResult::Pending {
        operations: Vec::new(),
        invocation: Invocation::with_start(chrono::Utc::now()),
        should_reinvoke: true,
        advance_time_ms: Some(5000),
    };

    if let InvokeHandlerResult::Pending {
        should_reinvoke,
        advance_time_ms,
        ..
    } = outcome
    {
        assert!(should_reinvoke);
        assert_eq!(advance_time_ms, Some(5000));
    } else {
        panic!("Expected Pending variant");
    }
}
2171
/// A `Pending` result built without a re-invoke request keeps the flag cleared
/// and carries no advance time.
#[test]
fn test_invoke_handler_result_pending_without_reinvoke() {
    let outcome: InvokeHandlerResult<String> = InvokeHandlerResult::Pending {
        operations: Vec::new(),
        invocation: Invocation::with_start(chrono::Utc::now()),
        should_reinvoke: false,
        advance_time_ms: None,
    };

    if let InvokeHandlerResult::Pending {
        should_reinvoke,
        advance_time_ms,
        ..
    } = outcome
    {
        assert!(!should_reinvoke);
        assert_eq!(advance_time_ms, None);
    } else {
        panic!("Expected Pending variant");
    }
}
2194
2195 #[tokio::test]
2197 async fn test_invoke_handler_creates_orchestrator_state() {
2198 use super::*;
2199 use std::sync::Arc;
2200 use tokio::sync::RwLock;
2201
2202 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
2204
2205 let handler =
2207 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
2208
2209 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
2210 let orchestrator = TestExecutionOrchestrator::new(
2211 handler,
2212 operation_storage,
2213 checkpoint_api,
2214 SkipTimeConfig { enabled: false },
2215 );
2216
2217 assert!(orchestrator.execution_id().is_none());
2219 assert!(orchestrator.checkpoint_token().is_none());
2220 assert!(!orchestrator.is_execution_complete());
2221 assert!(!orchestrator.is_invocation_active());
2222 }
2223
2224 #[tokio::test]
2225 async fn test_invoke_handler_time_skip_mode_prevents_concurrent() {
2226 use super::*;
2227 use std::sync::Arc;
2228 use tokio::sync::RwLock;
2229
2230 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
2232
2233 let handler =
2235 |_input: String, _ctx: DurableContext| async move { Ok("result".to_string()) };
2236
2237 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
2238 let orchestrator = TestExecutionOrchestrator::new(
2239 handler,
2240 operation_storage,
2241 checkpoint_api,
2242 SkipTimeConfig { enabled: true },
2243 );
2244
2245 assert!(orchestrator.is_time_skipping_enabled());
2247
2248 assert!(!orchestrator.is_invocation_active());
2250 }
2251
2252 #[tokio::test]
2253 async fn test_invoke_handler_tracks_invocation_active_state() {
2254 use super::*;
2255 use std::sync::atomic::{AtomicBool, Ordering};
2256 use std::sync::Arc;
2257 use tokio::sync::RwLock;
2258
2259 let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
2261
2262 let was_active = Arc::new(AtomicBool::new(false));
2264 let was_active_clone = Arc::clone(&was_active);
2265
2266 let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));
2268 let orchestrator = TestExecutionOrchestrator::new(
2269 move |_input: String, _ctx: DurableContext| {
2270 let was_active = Arc::clone(&was_active_clone);
2271 async move {
2272 was_active.store(true, Ordering::SeqCst);
2275 Ok("result".to_string())
2276 }
2277 },
2278 operation_storage,
2279 checkpoint_api,
2280 SkipTimeConfig { enabled: false },
2281 );
2282
2283 assert!(!orchestrator.is_invocation_active());
2285 }
2286}
2287
2288#[cfg(test)]
2292mod property_tests {
2293 use super::*;
2294 use durable_execution_sdk::{OperationType, WaitDetails};
2295 use proptest::prelude::*;
2296
/// Proptest strategy yielding a wait duration in the inclusive range 1..=60.
/// The property tests in this module interpret the value as seconds.
fn wait_duration_strategy() -> impl Strategy<Value = u64> {
    1u64..=60
}
2301
/// Proptest strategy yielding between one and three wait durations, each
/// drawn from `wait_duration_strategy`.
fn multiple_wait_durations_strategy() -> impl Strategy<Value = Vec<u64>> {
    prop::collection::vec(wait_duration_strategy(), 1..=3)
}
2306
2307 proptest! {
// Property: with time skipping enabled, processing a single started wait
// operation either asks for a re-invocation with an advance time of roughly
// the wait duration, or reports no pending operations; in both cases the wait
// is tracked as pending.
#[test]
fn prop_wait_operation_completion(wait_seconds in wait_duration_strategy()) {
    // Each case builds its own single-threaded runtime to drive the async body.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    rt.block_on(async {
        // The wait is scheduled to end `wait_seconds` seconds from now.
        let now_ms = chrono::Utc::now().timestamp_millis();
        let scheduled_end_ms = now_ms + (wait_seconds as i64 * 1000);

        let mut wait_op = Operation::new("wait-test", OperationType::Wait);
        wait_op.status = OperationStatus::Started;
        wait_op.wait_details = Some(WaitDetails {
            scheduled_end_timestamp: Some(scheduled_end_ms),
        });

        // Sanity-check the fixture before handing it to the orchestrator.
        prop_assert!(wait_op.wait_details.is_some());
        let details = wait_op.wait_details.as_ref().unwrap();
        prop_assert_eq!(details.scheduled_end_timestamp, Some(scheduled_end_ms));

        let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
        let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));

        let handler = |_input: String, _ctx: DurableContext| async move {
            Ok("result".to_string())
        };

        let mut orchestrator = TestExecutionOrchestrator::new(
            handler,
            operation_storage.clone(),
            checkpoint_api,
            SkipTimeConfig { enabled: true },
        );

        prop_assert!(orchestrator.is_time_skipping_enabled());

        let operations = vec![wait_op.clone()];
        let result = orchestrator.process_operations(&operations, "exec-test");

        match result {
            ProcessOperationsResult::ShouldReinvoke(advance_time_ms) => {
                prop_assert!(
                    advance_time_ms.is_some(),
                    "Should have advance time when time skipping is enabled"
                );

                if let Some(advance_ms) = advance_time_ms {
                    // Allow one second of slack either way, since wall-clock
                    // time passes between building the fixture and processing.
                    let expected_min = (wait_seconds as u64).saturating_sub(1) * 1000;
                    let expected_max = (wait_seconds as u64 + 1) * 1000;
                    prop_assert!(
                        advance_ms >= expected_min && advance_ms <= expected_max,
                        "Advance time {} should be approximately {} seconds ({}ms - {}ms)",
                        advance_ms, wait_seconds, expected_min, expected_max
                    );
                }
            }
            // Also accepted: the orchestrator reports nothing to re-invoke for.
            ProcessOperationsResult::NoPendingOperations => {
            }
            other => {
                prop_assert!(
                    false,
                    "Expected ShouldReinvoke or NoPendingOperations, got {:?}",
                    other
                );
            }
        }

        prop_assert!(
            orchestrator.pending_operations.contains("wait-test"),
            "Wait operation should be tracked as pending"
        );

        Ok(())
    })?;
}
2411
// Property: with several started waits and time skipping enabled, any advance
// time requested is roughly the shortest wait duration, and every wait is
// tracked as pending.
#[test]
fn prop_wait_operation_completion_multiple_waits(
    wait_durations in multiple_wait_durations_strategy()
) {
    // Each case builds its own single-threaded runtime to drive the async body.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    rt.block_on(async {
        let now_ms = chrono::Utc::now().timestamp_millis();

        // One started wait per generated duration, named "wait-<index>".
        let mut operations = Vec::new();
        for (i, &duration) in wait_durations.iter().enumerate() {
            let scheduled_end_ms = now_ms + (duration as i64 * 1000);
            let mut wait_op = Operation::new(&format!("wait-{}", i), OperationType::Wait);
            wait_op.status = OperationStatus::Started;
            wait_op.wait_details = Some(WaitDetails {
                scheduled_end_timestamp: Some(scheduled_end_ms),
            });
            operations.push(wait_op);
        }

        let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
        let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));

        let handler = |_input: String, _ctx: DurableContext| async move {
            Ok("result".to_string())
        };

        let mut orchestrator = TestExecutionOrchestrator::new(
            handler,
            operation_storage,
            checkpoint_api,
            SkipTimeConfig { enabled: true },
        );

        let result = orchestrator.process_operations(&operations, "exec-test");

        // The shortest wait determines the expected advance time.
        let min_duration = wait_durations.iter().min().copied().unwrap_or(0);

        match result {
            ProcessOperationsResult::ShouldReinvoke(advance_time_ms) => {
                if let Some(advance_ms) = advance_time_ms {
                    // Allow one second of slack either way, since wall-clock
                    // time passes between building the fixture and processing.
                    let expected_min = min_duration.saturating_sub(1) * 1000;
                    let expected_max = (min_duration + 1) * 1000;
                    prop_assert!(
                        advance_ms >= expected_min && advance_ms <= expected_max,
                        "Advance time {} should be approximately {} seconds (min duration)",
                        advance_ms, min_duration
                    );
                }
            }
            // Also accepted: the orchestrator reports nothing to re-invoke for.
            ProcessOperationsResult::NoPendingOperations => {
            }
            other => {
                prop_assert!(
                    false,
                    "Expected ShouldReinvoke or NoPendingOperations, got {:?}",
                    other
                );
            }
        }

        // Every generated wait must be tracked, not just the earliest one.
        for (i, _) in wait_durations.iter().enumerate() {
            let op_id = format!("wait-{}", i);
            prop_assert!(
                orchestrator.pending_operations.contains(&op_id),
                "Wait operation {} should be tracked as pending",
                op_id
            );
        }

        Ok(())
    })?;
}
2503
// Property: a wait operation that has already succeeded never triggers a
// re-invocation and is not tracked as pending, whatever its scheduled end.
#[test]
fn prop_wait_operation_completion_already_completed(wait_seconds in wait_duration_strategy()) {
    // Each case builds its own single-threaded runtime to drive the async body.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    rt.block_on(async {
        let now_ms = chrono::Utc::now().timestamp_millis();
        let scheduled_end_ms = now_ms + (wait_seconds as i64 * 1000);

        // Terminal status, even though the scheduled end is still in the future.
        let mut wait_op = Operation::new("wait-completed", OperationType::Wait);
        wait_op.status = OperationStatus::Succeeded;
        wait_op.wait_details = Some(WaitDetails {
            scheduled_end_timestamp: Some(scheduled_end_ms),
        });
        wait_op.end_timestamp = Some(now_ms);

        let checkpoint_api = CheckpointWorkerManager::get_instance(None).unwrap();
        let operation_storage = Arc::new(RwLock::new(OperationStorage::new()));

        let handler = |_input: String, _ctx: DurableContext| async move {
            Ok("result".to_string())
        };

        let mut orchestrator = TestExecutionOrchestrator::new(
            handler,
            operation_storage,
            checkpoint_api,
            SkipTimeConfig { enabled: true },
        );

        let operations = vec![wait_op];
        let result = orchestrator.process_operations(&operations, "exec-test");

        match result {
            // The only accepted outcome for an already-completed wait.
            ProcessOperationsResult::NoPendingOperations => {
            }
            ProcessOperationsResult::ShouldReinvoke(_) => {
                prop_assert!(
                    false,
                    "Completed wait operation should not trigger re-invocation"
                );
            }
            other => {
                prop_assert!(
                    false,
                    "Expected NoPendingOperations for completed wait, got {:?}",
                    other
                );
            }
        }

        prop_assert!(
            !orchestrator.pending_operations.contains("wait-completed"),
            "Completed wait operation should not be tracked as pending"
        );

        Ok(())
    })?;
}
2577 }
2578}