1use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use tokio::sync::{Mutex, Notify, Semaphore};
12
13use crate::config::CompletionConfig;
14use crate::error::{DurableError, ErrorObject};
15
/// Atomic tallies describing how far a batch of tasks has progressed.
///
/// Every field is an [`AtomicUsize`], so a single instance can be updated
/// through a shared reference (for example behind an `Arc`).
#[derive(Debug)]
pub struct ExecutionCounters {
    /// Number of tasks the batch was created with.
    total_tasks: AtomicUsize,
    /// Tasks recorded as successful.
    success_count: AtomicUsize,
    /// Tasks recorded as failed.
    failure_count: AtomicUsize,
    /// Tasks recorded as finished, whether they succeeded or failed.
    completed_count: AtomicUsize,
    /// Tasks recorded as suspended rather than finished.
    suspended_count: AtomicUsize,
}
61
62impl ExecutionCounters {
63 pub fn new(total_tasks: usize) -> Self {
65 Self {
66 total_tasks: AtomicUsize::new(total_tasks),
67 success_count: AtomicUsize::new(0),
68 failure_count: AtomicUsize::new(0),
69 completed_count: AtomicUsize::new(0),
70 suspended_count: AtomicUsize::new(0),
71 }
72 }
73
74 pub fn complete_task(&self) -> usize {
85 self.completed_count.fetch_add(1, Ordering::Relaxed);
87 self.success_count.fetch_add(1, Ordering::Relaxed) + 1
88 }
89
90 pub fn fail_task(&self) -> usize {
100 self.completed_count.fetch_add(1, Ordering::Relaxed);
102 self.failure_count.fetch_add(1, Ordering::Relaxed) + 1
103 }
104
105 pub fn suspend_task(&self) -> usize {
115 self.suspended_count.fetch_add(1, Ordering::Relaxed) + 1
117 }
118
119 pub fn total_tasks(&self) -> usize {
128 self.total_tasks.load(Ordering::Acquire)
130 }
131
132 pub fn success_count(&self) -> usize {
140 self.success_count.load(Ordering::Acquire)
142 }
143
144 pub fn failure_count(&self) -> usize {
152 self.failure_count.load(Ordering::Acquire)
154 }
155
156 pub fn completed_count(&self) -> usize {
164 self.completed_count.load(Ordering::Acquire)
166 }
167
168 pub fn suspended_count(&self) -> usize {
176 self.suspended_count.load(Ordering::Acquire)
178 }
179
180 pub fn pending_count(&self) -> usize {
182 let total = self.total_tasks();
183 let completed = self.completed_count();
184 let suspended = self.suspended_count();
185 total.saturating_sub(completed + suspended)
186 }
187
188 pub fn is_min_successful_reached(&self, min_successful: usize) -> bool {
194 self.success_count() >= min_successful
195 }
196
197 pub fn is_failure_tolerance_exceeded(&self, config: &CompletionConfig) -> bool {
203 let failures = self.failure_count();
204 let total = self.total_tasks();
205
206 if let Some(max_failures) = config.tolerated_failure_count {
208 if failures > max_failures {
209 return true;
210 }
211 }
212
213 if let Some(max_percentage) = config.tolerated_failure_percentage {
215 if total > 0 {
216 let failure_percentage = failures as f64 / total as f64;
217 if failure_percentage > max_percentage {
218 return true;
219 }
220 }
221 }
222
223 false
224 }
225
226 pub fn should_complete(&self, config: &CompletionConfig) -> Option<CompletionReason> {
235 let total = self.total_tasks();
236 let completed = self.completed_count();
237 let suspended = self.suspended_count();
238 let successes = self.success_count();
239
240 if let Some(min_successful) = config.min_successful {
242 if successes >= min_successful {
243 return Some(CompletionReason::MinSuccessfulReached);
244 }
245 }
246
247 if self.is_failure_tolerance_exceeded(config) {
249 return Some(CompletionReason::FailureToleranceExceeded);
250 }
251
252 if completed + suspended >= total {
254 if suspended > 0 && completed < total {
255 return Some(CompletionReason::Suspended);
256 }
257 return Some(CompletionReason::AllCompleted);
258 }
259
260 None
261 }
262
263 pub fn all_completed(&self) -> bool {
265 self.completed_count() >= self.total_tasks()
266 }
267
268 pub fn has_pending(&self) -> bool {
270 self.pending_count() > 0
271 }
272}
273
274impl Default for ExecutionCounters {
275 fn default() -> Self {
276 Self::new(0)
277 }
278}
279
280#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
282pub enum CompletionReason {
283 AllCompleted,
285 MinSuccessfulReached,
287 FailureToleranceExceeded,
289 Suspended,
291}
292
293impl std::fmt::Display for CompletionReason {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 match self {
296 Self::AllCompleted => write!(f, "AllCompleted"),
297 Self::MinSuccessfulReached => write!(f, "MinSuccessfulReached"),
298 Self::FailureToleranceExceeded => write!(f, "FailureToleranceExceeded"),
299 Self::Suspended => write!(f, "Suspended"),
300 }
301 }
302}
303
304#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
306pub enum BatchItemStatus {
307 Succeeded,
309 Failed,
311 Cancelled,
313 Pending,
315 Suspended,
317}
318
319impl BatchItemStatus {
320 pub fn is_success(&self) -> bool {
322 matches!(self, Self::Succeeded)
323 }
324
325 pub fn is_failure(&self) -> bool {
327 matches!(self, Self::Failed)
328 }
329
330 pub fn is_terminal(&self) -> bool {
332 matches!(self, Self::Succeeded | Self::Failed | Self::Cancelled)
333 }
334
335 pub fn is_pending(&self) -> bool {
337 matches!(self, Self::Pending | Self::Suspended)
338 }
339}
340
341impl std::fmt::Display for BatchItemStatus {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 match self {
344 Self::Succeeded => write!(f, "Succeeded"),
345 Self::Failed => write!(f, "Failed"),
346 Self::Cancelled => write!(f, "Cancelled"),
347 Self::Pending => write!(f, "Pending"),
348 Self::Suspended => write!(f, "Suspended"),
349 }
350 }
351}
352
353#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct BatchItem<T> {
381 pub index: usize,
383 pub status: BatchItemStatus,
385 pub result: Option<T>,
387 pub error: Option<ErrorObject>,
389}
390
391impl<T> BatchItem<T> {
392 pub fn succeeded(index: usize, result: T) -> Self {
394 Self {
395 index,
396 status: BatchItemStatus::Succeeded,
397 result: Some(result),
398 error: None,
399 }
400 }
401
402 pub fn failed(index: usize, error: ErrorObject) -> Self {
404 Self {
405 index,
406 status: BatchItemStatus::Failed,
407 result: None,
408 error: Some(error),
409 }
410 }
411
412 pub fn cancelled(index: usize) -> Self {
414 Self {
415 index,
416 status: BatchItemStatus::Cancelled,
417 result: None,
418 error: None,
419 }
420 }
421
422 pub fn pending(index: usize) -> Self {
424 Self {
425 index,
426 status: BatchItemStatus::Pending,
427 result: None,
428 error: None,
429 }
430 }
431
432 pub fn suspended(index: usize) -> Self {
434 Self {
435 index,
436 status: BatchItemStatus::Suspended,
437 result: None,
438 error: None,
439 }
440 }
441
442 pub fn is_succeeded(&self) -> bool {
444 self.status.is_success()
445 }
446
447 pub fn is_failed(&self) -> bool {
449 self.status.is_failure()
450 }
451
452 pub fn get_result(&self) -> Option<&T> {
454 self.result.as_ref()
455 }
456
457 pub fn get_error(&self) -> Option<&ErrorObject> {
459 self.error.as_ref()
460 }
461}
462
463#[derive(Debug, Clone, Serialize, Deserialize)]
487pub struct BatchResult<T> {
488 pub items: Vec<BatchItem<T>>,
490 pub completion_reason: CompletionReason,
492}
493
494impl<T> BatchResult<T> {
495 pub fn new(items: Vec<BatchItem<T>>, completion_reason: CompletionReason) -> Self {
497 Self {
498 items,
499 completion_reason,
500 }
501 }
502
503 pub fn empty() -> Self {
505 Self {
506 items: Vec::new(),
507 completion_reason: CompletionReason::AllCompleted,
508 }
509 }
510
511 pub fn succeeded(&self) -> Vec<&BatchItem<T>> {
513 self.items
514 .iter()
515 .filter(|item| item.is_succeeded())
516 .collect()
517 }
518
519 pub fn failed(&self) -> Vec<&BatchItem<T>> {
521 self.items.iter().filter(|item| item.is_failed()).collect()
522 }
523
524 pub fn get_results(&self) -> Result<Vec<&T>, DurableError> {
529 if self.completion_reason == CompletionReason::FailureToleranceExceeded {
530 if let Some(failed_item) = self.failed().first() {
532 if let Some(ref error) = failed_item.error {
533 return Err(DurableError::UserCode {
534 message: error.error_message.clone(),
535 error_type: error.error_type.clone(),
536 stack_trace: error.stack_trace.clone(),
537 });
538 }
539 }
540 return Err(DurableError::execution("Batch operation failed"));
541 }
542
543 Ok(self
544 .items
545 .iter()
546 .filter_map(|item| item.result.as_ref())
547 .collect())
548 }
549
550 pub fn success_count(&self) -> usize {
552 self.succeeded().len()
553 }
554
555 pub fn failure_count(&self) -> usize {
557 self.failed().len()
558 }
559
560 pub fn total_count(&self) -> usize {
562 self.items.len()
563 }
564
565 pub fn all_succeeded(&self) -> bool {
567 self.items.iter().all(|item| item.is_succeeded())
568 }
569
570 pub fn has_failures(&self) -> bool {
572 self.items.iter().any(|item| item.is_failed())
573 }
574
575 pub fn is_failure(&self) -> bool {
577 self.completion_reason == CompletionReason::FailureToleranceExceeded
578 }
579
580 pub fn is_success(&self) -> bool {
582 matches!(
583 self.completion_reason,
584 CompletionReason::AllCompleted | CompletionReason::MinSuccessfulReached
585 )
586 }
587}
588
589impl<T> Default for BatchResult<T> {
590 fn default() -> Self {
591 Self::empty()
592 }
593}
594
595pub struct ConcurrentExecutor {
600 max_concurrency: Option<usize>,
602 completion_config: CompletionConfig,
604 counters: Arc<ExecutionCounters>,
606 completion_notify: Arc<Notify>,
608 semaphore: Option<Arc<Semaphore>>,
610}
611
612impl ConcurrentExecutor {
613 pub fn new(
621 total_tasks: usize,
622 max_concurrency: Option<usize>,
623 completion_config: CompletionConfig,
624 ) -> Self {
625 let semaphore = max_concurrency.map(|n| Arc::new(Semaphore::new(n)));
626
627 Self {
628 max_concurrency,
629 completion_config,
630 counters: Arc::new(ExecutionCounters::new(total_tasks)),
631 completion_notify: Arc::new(Notify::new()),
632 semaphore,
633 }
634 }
635
636 pub fn counters(&self) -> &Arc<ExecutionCounters> {
638 &self.counters
639 }
640
641 pub fn completion_notify(&self) -> &Arc<Notify> {
643 &self.completion_notify
644 }
645
646 pub fn should_complete(&self) -> Option<CompletionReason> {
648 self.counters.should_complete(&self.completion_config)
649 }
650
651 pub fn record_success(&self) -> Option<CompletionReason> {
655 self.counters.complete_task();
656 let reason = self.should_complete();
657 if reason.is_some() {
658 self.completion_notify.notify_waiters();
659 }
660 reason
661 }
662
663 pub fn record_failure(&self) -> Option<CompletionReason> {
667 self.counters.fail_task();
668 let reason = self.should_complete();
669 if reason.is_some() {
670 self.completion_notify.notify_waiters();
671 }
672 reason
673 }
674
675 pub fn record_suspend(&self) -> Option<CompletionReason> {
679 self.counters.suspend_task();
680 let reason = self.should_complete();
681 if reason.is_some() {
682 self.completion_notify.notify_waiters();
683 }
684 reason
685 }
686
687 pub async fn execute<T, F, Fut>(self, tasks: Vec<F>) -> BatchResult<T>
697 where
698 T: Send + 'static,
699 F: FnOnce(usize) -> Fut + Send + 'static,
700 Fut: std::future::Future<Output = Result<T, DurableError>> + Send + 'static,
701 {
702 let total = tasks.len();
703 if total == 0 {
704 return BatchResult::empty();
705 }
706
707 let results: Arc<Mutex<Vec<BatchItem<T>>>> =
709 Arc::new(Mutex::new((0..total).map(BatchItem::pending).collect()));
710
711 let mut handles = Vec::with_capacity(total);
713
714 for (index, task) in tasks.into_iter().enumerate() {
715 let counters = self.counters.clone();
716 let completion_notify = self.completion_notify.clone();
717 let completion_config = self.completion_config.clone();
718 let results = results.clone();
719 let semaphore = self.semaphore.clone();
720
721 let handle = tokio::spawn(async move {
722 let _permit = if let Some(ref sem) = semaphore {
724 Some(sem.acquire().await.expect("Semaphore closed"))
725 } else {
726 None
727 };
728
729 if counters.should_complete(&completion_config).is_some() {
731 let mut results_guard = results.lock().await;
732 results_guard[index] = BatchItem::cancelled(index);
733 return;
734 }
735
736 let result = task(index).await;
738
739 let mut results_guard = results.lock().await;
741 match result {
742 Ok(value) => {
743 results_guard[index] = BatchItem::succeeded(index, value);
744 counters.complete_task();
745 }
746 Err(DurableError::Suspend { .. }) => {
747 results_guard[index] = BatchItem::suspended(index);
748 counters.suspend_task();
749 }
750 Err(error) => {
751 let error_obj = ErrorObject::from(&error);
752 results_guard[index] = BatchItem::failed(index, error_obj);
753 counters.fail_task();
754 }
755 }
756 drop(results_guard);
757
758 if counters.should_complete(&completion_config).is_some() {
760 completion_notify.notify_waiters();
761 }
762 });
763
764 handles.push(handle);
765 }
766
767 for handle in handles {
769 let _ = handle.await;
770 }
771
772 let final_results = Arc::try_unwrap(results)
774 .map_err(|_| "All handles should be done")
775 .unwrap()
776 .into_inner();
777
778 let completion_reason = self
779 .counters
780 .should_complete(&self.completion_config)
781 .unwrap_or(CompletionReason::AllCompleted);
782
783 BatchResult::new(final_results, completion_reason)
784 }
785}
786
787impl std::fmt::Debug for ConcurrentExecutor {
788 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
789 f.debug_struct("ConcurrentExecutor")
790 .field("max_concurrency", &self.max_concurrency)
791 .field("completion_config", &self.completion_config)
792 .field("counters", &self.counters)
793 .finish_non_exhaustive()
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit tests for the atomic tallies and the completion rules built on them.
    mod execution_counters_tests {
        use super::*;

        #[test]
        fn test_new() {
            let fresh = ExecutionCounters::new(10);
            assert_eq!(fresh.total_tasks(), 10);
            assert_eq!(fresh.success_count(), 0);
            assert_eq!(fresh.failure_count(), 0);
            assert_eq!(fresh.completed_count(), 0);
            assert_eq!(fresh.suspended_count(), 0);
            assert_eq!(fresh.pending_count(), 10);
        }

        #[test]
        fn test_complete_task() {
            let tally = ExecutionCounters::new(5);

            let first = tally.complete_task();
            assert_eq!(first, 1);
            assert_eq!(tally.success_count(), 1);
            assert_eq!(tally.completed_count(), 1);
            assert_eq!(tally.pending_count(), 4);

            let second = tally.complete_task();
            assert_eq!(second, 2);
            assert_eq!(tally.success_count(), 2);
            assert_eq!(tally.completed_count(), 2);
        }

        #[test]
        fn test_fail_task() {
            let tally = ExecutionCounters::new(5);

            let first = tally.fail_task();
            assert_eq!(first, 1);
            assert_eq!(tally.failure_count(), 1);
            assert_eq!(tally.completed_count(), 1);
            assert_eq!(tally.pending_count(), 4);
        }

        #[test]
        fn test_suspend_task() {
            let tally = ExecutionCounters::new(5);

            let first = tally.suspend_task();
            assert_eq!(first, 1);
            assert_eq!(tally.suspended_count(), 1);
            // A suspension is not a completion.
            assert_eq!(tally.completed_count(), 0);
            assert_eq!(tally.pending_count(), 4);
        }

        #[test]
        fn test_is_min_successful_reached() {
            let tally = ExecutionCounters::new(10);

            assert!(!tally.is_min_successful_reached(3));

            for _ in 0..2 {
                tally.complete_task();
            }
            assert!(!tally.is_min_successful_reached(3));

            tally.complete_task();
            assert!(tally.is_min_successful_reached(3));
        }

        #[test]
        fn test_is_failure_tolerance_exceeded_count() {
            let tally = ExecutionCounters::new(10);
            let tolerate_two = CompletionConfig {
                tolerated_failure_count: Some(2),
                ..Default::default()
            };

            for _ in 0..2 {
                tally.fail_task();
            }
            assert!(!tally.is_failure_tolerance_exceeded(&tolerate_two));

            tally.fail_task();
            assert!(tally.is_failure_tolerance_exceeded(&tolerate_two));
        }

        #[test]
        fn test_is_failure_tolerance_exceeded_percentage() {
            let tally = ExecutionCounters::new(10);
            let tolerate_fifth = CompletionConfig {
                tolerated_failure_percentage: Some(0.2),
                ..Default::default()
            };

            for _ in 0..2 {
                tally.fail_task();
            }
            assert!(!tally.is_failure_tolerance_exceeded(&tolerate_fifth));

            tally.fail_task();
            assert!(tally.is_failure_tolerance_exceeded(&tolerate_fifth));
        }

        #[test]
        fn test_should_complete_min_successful() {
            let tally = ExecutionCounters::new(10);
            let rules = CompletionConfig::with_min_successful(3);

            assert!(tally.should_complete(&rules).is_none());

            for _ in 0..2 {
                tally.complete_task();
            }
            assert!(tally.should_complete(&rules).is_none());

            tally.complete_task();
            assert_eq!(
                tally.should_complete(&rules),
                Some(CompletionReason::MinSuccessfulReached)
            );
        }

        #[test]
        fn test_should_complete_failure_tolerance() {
            let tally = ExecutionCounters::new(10);
            let rules = CompletionConfig::all_successful();

            assert!(tally.should_complete(&rules).is_none());

            tally.fail_task();
            assert_eq!(
                tally.should_complete(&rules),
                Some(CompletionReason::FailureToleranceExceeded)
            );
        }

        #[test]
        fn test_should_complete_all_completed() {
            let tally = ExecutionCounters::new(3);
            let rules = CompletionConfig::all_completed();

            for _ in 0..2 {
                tally.complete_task();
            }
            assert!(tally.should_complete(&rules).is_none());

            tally.complete_task();
            assert_eq!(
                tally.should_complete(&rules),
                Some(CompletionReason::AllCompleted)
            );
        }

        #[test]
        fn test_should_complete_suspended() {
            let tally = ExecutionCounters::new(3);
            let rules = CompletionConfig::all_completed();

            for _ in 0..2 {
                tally.complete_task();
            }
            tally.suspend_task();

            assert_eq!(
                tally.should_complete(&rules),
                Some(CompletionReason::Suspended)
            );
        }

        #[test]
        fn test_all_completed() {
            let tally = ExecutionCounters::new(3);

            assert!(!tally.all_completed());

            for _ in 0..3 {
                tally.complete_task();
            }

            assert!(tally.all_completed());
        }

        #[test]
        fn test_has_pending() {
            let tally = ExecutionCounters::new(2);

            assert!(tally.has_pending());

            tally.complete_task();
            assert!(tally.has_pending());

            tally.complete_task();
            assert!(!tally.has_pending());
        }

        /// Many threads update the counters at once; no update may be lost.
        #[test]
        fn test_concurrent_counter_updates() {
            use std::sync::Arc;
            use std::thread;

            let shared = Arc::new(ExecutionCounters::new(1000));
            let mut workers = vec![];

            // 10 threads x 50 successes.
            workers.extend((0..10).map(|_| {
                let tally = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..50 {
                        tally.complete_task();
                    }
                })
            }));

            // 5 threads x 50 failures.
            workers.extend((0..5).map(|_| {
                let tally = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..50 {
                        tally.fail_task();
                    }
                })
            }));

            // 5 threads x 50 suspensions.
            workers.extend((0..5).map(|_| {
                let tally = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..50 {
                        tally.suspend_task();
                    }
                })
            }));

            for worker in workers {
                worker.join().unwrap();
            }

            assert_eq!(shared.success_count(), 500);
            assert_eq!(shared.failure_count(), 250);
            assert_eq!(shared.suspended_count(), 250);
            assert_eq!(shared.completed_count(), 750);
        }

        /// Readers running alongside writers must see monotonic, consistent tallies.
        #[test]
        fn test_concurrent_read_write_stress() {
            use std::sync::Arc;
            use std::thread;

            let shared = Arc::new(ExecutionCounters::new(10000));
            let mut workers = vec![];

            // Writers: 5 threads x 200 successes.
            workers.extend((0..5).map(|_| {
                let tally = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..200 {
                        tally.complete_task();
                    }
                })
            }));

            // Readers: 5 threads x 1000 observations.
            workers.extend((0..5).map(|_| {
                let tally = Arc::clone(&shared);
                thread::spawn(move || {
                    let mut last_success = 0;
                    for _ in 0..1000 {
                        let current_success = tally.success_count();
                        assert!(
                            current_success >= last_success,
                            "Success count decreased from {} to {}",
                            last_success,
                            current_success
                        );
                        last_success = current_success;

                        let completed = tally.completed_count();
                        assert!(
                            completed >= current_success,
                            "Completed {} should be >= success {}",
                            completed,
                            current_success
                        );
                    }
                })
            }));

            for worker in workers {
                worker.join().unwrap();
            }

            assert_eq!(shared.success_count(), 1000);
            assert_eq!(shared.completed_count(), 1000);
        }
    }

    /// Unit tests for the `BatchItem` constructors and accessors.
    mod batch_item_tests {
        use super::*;

        #[test]
        fn test_succeeded() {
            let entry = BatchItem::succeeded(0, 42);
            assert_eq!(entry.index, 0);
            assert!(entry.is_succeeded());
            assert!(!entry.is_failed());
            assert_eq!(entry.get_result(), Some(&42));
            assert!(entry.get_error().is_none());
        }

        #[test]
        fn test_failed() {
            let cause = ErrorObject::new("TestError", "test message");
            let entry: BatchItem<i32> = BatchItem::failed(1, cause);
            assert_eq!(entry.index, 1);
            assert!(!entry.is_succeeded());
            assert!(entry.is_failed());
            assert!(entry.get_result().is_none());
            assert!(entry.get_error().is_some());
        }

        #[test]
        fn test_cancelled() {
            let entry: BatchItem<i32> = BatchItem::cancelled(2);
            assert_eq!(entry.index, 2);
            assert_eq!(entry.status, BatchItemStatus::Cancelled);
        }

        #[test]
        fn test_pending() {
            let entry: BatchItem<i32> = BatchItem::pending(3);
            assert_eq!(entry.index, 3);
            assert_eq!(entry.status, BatchItemStatus::Pending);
        }

        #[test]
        fn test_suspended() {
            let entry: BatchItem<i32> = BatchItem::suspended(4);
            assert_eq!(entry.index, 4);
            assert_eq!(entry.status, BatchItemStatus::Suspended);
        }
    }

    /// Unit tests for the `BatchResult` views, counts and status helpers.
    mod batch_result_tests {
        use super::*;

        #[test]
        fn test_empty() {
            let outcome: BatchResult<i32> = BatchResult::empty();
            assert!(outcome.items.is_empty());
            assert_eq!(outcome.completion_reason, CompletionReason::AllCompleted);
        }

        #[test]
        fn test_succeeded() {
            let outcome = BatchResult::new(
                vec![
                    BatchItem::succeeded(0, 1),
                    BatchItem::succeeded(1, 2),
                    BatchItem::failed(2, ErrorObject::new("Error", "msg")),
                ],
                CompletionReason::AllCompleted,
            );

            let winners = outcome.succeeded();
            assert_eq!(winners.len(), 2);
        }

        #[test]
        fn test_failed() {
            let outcome = BatchResult::new(
                vec![
                    BatchItem::succeeded(0, 1),
                    BatchItem::failed(1, ErrorObject::new("Error", "msg")),
                ],
                CompletionReason::AllCompleted,
            );

            let losers = outcome.failed();
            assert_eq!(losers.len(), 1);
        }

        #[test]
        fn test_get_results_success() {
            let outcome = BatchResult::new(
                vec![BatchItem::succeeded(0, 1), BatchItem::succeeded(1, 2)],
                CompletionReason::AllCompleted,
            );

            let values = outcome.get_results().unwrap();
            assert_eq!(values, vec![&1, &2]);
        }

        #[test]
        fn test_get_results_failure_tolerance_exceeded() {
            let outcome = BatchResult::new(
                vec![
                    BatchItem::succeeded(0, 1),
                    BatchItem::failed(1, ErrorObject::new("TestError", "test")),
                ],
                CompletionReason::FailureToleranceExceeded,
            );

            assert!(outcome.get_results().is_err());
        }

        #[test]
        fn test_counts() {
            let outcome = BatchResult::new(
                vec![
                    BatchItem::succeeded(0, 1),
                    BatchItem::succeeded(1, 2),
                    BatchItem::failed(2, ErrorObject::new("Error", "msg")),
                ],
                CompletionReason::AllCompleted,
            );

            assert_eq!(outcome.success_count(), 2);
            assert_eq!(outcome.failure_count(), 1);
            assert_eq!(outcome.total_count(), 3);
        }

        #[test]
        fn test_all_succeeded() {
            let clean = BatchResult::new(
                vec![BatchItem::succeeded(0, 1), BatchItem::succeeded(1, 2)],
                CompletionReason::AllCompleted,
            );
            assert!(clean.all_succeeded());

            let mixed = BatchResult::new(
                vec![
                    BatchItem::succeeded(0, 1),
                    BatchItem::failed(1, ErrorObject::new("Error", "msg")),
                ],
                CompletionReason::AllCompleted,
            );
            assert!(!mixed.all_succeeded());
        }

        #[test]
        fn test_is_success() {
            let all_done: BatchResult<i32> =
                BatchResult::new(vec![], CompletionReason::AllCompleted);
            assert!(all_done.is_success());

            let enough: BatchResult<i32> =
                BatchResult::new(vec![], CompletionReason::MinSuccessfulReached);
            assert!(enough.is_success());

            let too_many_failures: BatchResult<i32> =
                BatchResult::new(vec![], CompletionReason::FailureToleranceExceeded);
            assert!(!too_many_failures.is_success());
        }
    }

    /// Tests that drive `ConcurrentExecutor` on the tokio runtime.
    mod concurrent_executor_tests {
        use super::*;

        /// Boxed future returned by a type-erased task.
        type BoxedFuture<T> =
            std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send>>;

        /// Type-erased task, so closures of different types fit in one `Vec`.
        type TaskFn<T> = Box<dyn FnOnce(usize) -> BoxedFuture<T> + Send>;

        #[tokio::test]
        async fn test_execute_empty() {
            let executor = ConcurrentExecutor::new(0, None, CompletionConfig::all_completed());
            let no_tasks: Vec<TaskFn<i32>> = Vec::new();
            let outcome = executor.execute(no_tasks).await;

            assert!(outcome.items.is_empty());
            assert_eq!(outcome.completion_reason, CompletionReason::AllCompleted);
        }

        #[tokio::test]
        async fn test_execute_all_success() {
            let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_completed());
            let mut tasks = Vec::new();
            for i in 0..3 {
                tasks.push(move |_idx: usize| async move { Ok(i * 10) });
            }

            let outcome = executor.execute(tasks).await;

            assert_eq!(outcome.total_count(), 3);
            assert_eq!(outcome.success_count(), 3);
            assert!(outcome.all_succeeded());
        }

        #[tokio::test]
        async fn test_execute_with_failures() {
            let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_completed());

            let first: TaskFn<i32> =
                Box::new(|_idx: usize| Box::pin(async { Ok(1) }) as BoxedFuture<i32>);
            let second: TaskFn<i32> = Box::new(|_idx: usize| {
                Box::pin(async { Err(DurableError::execution("test error")) }) as BoxedFuture<i32>
            });
            let third: TaskFn<i32> =
                Box::new(|_idx: usize| Box::pin(async { Ok(3) }) as BoxedFuture<i32>);

            let outcome = executor.execute(vec![first, second, third]).await;

            assert_eq!(outcome.total_count(), 3);
            assert_eq!(outcome.success_count(), 2);
            assert_eq!(outcome.failure_count(), 1);
        }

        #[tokio::test]
        async fn test_execute_min_successful() {
            let executor =
                ConcurrentExecutor::new(5, None, CompletionConfig::with_min_successful(2));
            let mut tasks = Vec::new();
            for i in 0..5 {
                tasks.push(move |_idx: usize| async move { Ok(i) });
            }

            let outcome = executor.execute(tasks).await;

            assert!(outcome.success_count() >= 2);
        }

        #[tokio::test]
        async fn test_execute_with_concurrency_limit() {
            let executor = ConcurrentExecutor::new(5, Some(2), CompletionConfig::all_completed());
            let mut tasks = Vec::new();
            for i in 0..5 {
                tasks.push(move |_idx: usize| async move { Ok(i) });
            }

            let outcome = executor.execute(tasks).await;

            assert_eq!(outcome.total_count(), 5);
            assert!(outcome.all_succeeded());
        }

        #[tokio::test]
        async fn test_record_success() {
            let executor =
                ConcurrentExecutor::new(3, None, CompletionConfig::with_min_successful(2));

            let after_first = executor.record_success();
            assert!(after_first.is_none());

            let after_second = executor.record_success();
            assert_eq!(after_second, Some(CompletionReason::MinSuccessfulReached));
        }

        #[tokio::test]
        async fn test_record_failure() {
            let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_successful());

            let after_first = executor.record_failure();
            assert_eq!(after_first, Some(CompletionReason::FailureToleranceExceeded));
        }
    }
}
1366
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Property tests for the completion criteria and the counter arithmetic.
    mod completion_criteria_tests {
        use super::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            /// Completion is reported as `MinSuccessfulReached` exactly when the
            /// configured number of successes has been recorded.
            #[test]
            fn prop_min_successful_triggers_completion(
                total_tasks in 1usize..=50,
                min_successful_ratio in 0.1f64..=1.0,
            ) {
                let scaled = (total_tasks as f64 * min_successful_ratio).ceil() as usize;
                let min_successful = scaled.max(1).min(total_tasks);
                let config = CompletionConfig::with_min_successful(min_successful);
                let counters = ExecutionCounters::new(total_tasks);

                for recorded in 1..=min_successful {
                    counters.complete_task();
                    if recorded < min_successful {
                        prop_assert!(
                            counters.should_complete(&config).is_none() ||
                            counters.should_complete(&config) == Some(CompletionReason::MinSuccessfulReached),
                            "Should not complete before reaching min_successful"
                        );
                    } else {
                        prop_assert_eq!(
                            counters.should_complete(&config),
                            Some(CompletionReason::MinSuccessfulReached),
                            "Should complete when min_successful is reached"
                        );
                    }
                }
            }

            /// One failure more than tolerated reports `FailureToleranceExceeded`.
            #[test]
            fn prop_failure_tolerance_exceeded_triggers_completion(
                total_tasks in 2usize..=50,
                tolerated_failures in 0usize..=10,
            ) {
                let config = CompletionConfig::with_failure_tolerance(tolerated_failures);
                let counters = ExecutionCounters::new(total_tasks);

                for recorded in 1..=tolerated_failures + 1 {
                    counters.fail_task();
                    if recorded <= tolerated_failures {
                        let verdict = counters.should_complete(&config);
                        prop_assert!(
                            verdict.is_none() || verdict == Some(CompletionReason::AllCompleted),
                            "Should not trigger failure tolerance until exceeded"
                        );
                    }
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::FailureToleranceExceeded),
                    "Should complete when failure tolerance is exceeded"
                );
            }

            /// Any mix of successes and failures covering every task reports `AllCompleted`.
            #[test]
            fn prop_all_completed_triggers_when_all_done(
                total_tasks in 1usize..=50,
                success_count in 0usize..=50,
            ) {
                let success_count = success_count.min(total_tasks);
                let failure_count = total_tasks - success_count;
                let config = CompletionConfig::all_completed();
                let counters = ExecutionCounters::new(total_tasks);

                (0..success_count).for_each(|_| {
                    counters.complete_task();
                });
                (0..failure_count).for_each(|_| {
                    counters.fail_task();
                });

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::AllCompleted),
                    "Should complete when all tasks are done"
                );
            }

            /// When the remaining tasks all suspend, the batch reports `Suspended`.
            #[test]
            fn prop_suspended_triggers_when_tasks_suspend(
                total_tasks in 2usize..=50,
                completed_count in 1usize..=49,
            ) {
                let completed_count = completed_count.min(total_tasks - 1);
                let suspended_count = total_tasks - completed_count;
                let config = CompletionConfig::all_completed();
                let counters = ExecutionCounters::new(total_tasks);

                (0..completed_count).for_each(|_| {
                    counters.complete_task();
                });
                (0..suspended_count).for_each(|_| {
                    counters.suspend_task();
                });

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::Suspended),
                    "Should return Suspended when tasks are suspended"
                );
            }

            /// The success tally equals the number of `complete_task` calls.
            #[test]
            fn prop_success_count_accurate(
                total_tasks in 1usize..=100,
                successes in 0usize..=100,
            ) {
                let successes = successes.min(total_tasks);
                let counters = ExecutionCounters::new(total_tasks);

                (0..successes).for_each(|_| {
                    counters.complete_task();
                });

                prop_assert_eq!(
                    counters.success_count(),
                    successes,
                    "Success count should match number of complete_task calls"
                );
            }

            /// The failure tally equals the number of `fail_task` calls.
            #[test]
            fn prop_failure_count_accurate(
                total_tasks in 1usize..=100,
                failures in 0usize..=100,
            ) {
                let failures = failures.min(total_tasks);
                let counters = ExecutionCounters::new(total_tasks);

                (0..failures).for_each(|_| {
                    counters.fail_task();
                });

                prop_assert_eq!(
                    counters.failure_count(),
                    failures,
                    "Failure count should match number of fail_task calls"
                );
            }

            /// The completed tally is the sum of successes and failures.
            #[test]
            fn prop_completed_count_is_sum(
                total_tasks in 2usize..=100,
                successes in 0usize..=50,
                failures in 0usize..=50,
            ) {
                let successes = successes.min(total_tasks / 2);
                let failures = failures.min(total_tasks - successes);
                let counters = ExecutionCounters::new(total_tasks);

                (0..successes).for_each(|_| {
                    counters.complete_task();
                });
                (0..failures).for_each(|_| {
                    counters.fail_task();
                });

                prop_assert_eq!(
                    counters.completed_count(),
                    successes + failures,
                    "Completed count should equal success + failure"
                );
            }

            /// Pending is what remains after completed and suspended tasks.
            #[test]
            fn prop_pending_count_accurate(
                total_tasks in 3usize..=100,
                successes in 0usize..=33,
                failures in 0usize..=33,
                suspends in 0usize..=33,
            ) {
                let successes = successes.min(total_tasks / 3);
                let failures = failures.min((total_tasks - successes) / 2);
                let suspends = suspends.min(total_tasks - successes - failures);
                let counters = ExecutionCounters::new(total_tasks);

                (0..successes).for_each(|_| {
                    counters.complete_task();
                });
                (0..failures).for_each(|_| {
                    counters.fail_task();
                });
                (0..suspends).for_each(|_| {
                    counters.suspend_task();
                });

                let expected_pending = total_tasks - successes - failures - suspends;
                prop_assert_eq!(
                    counters.pending_count(),
                    expected_pending,
                    "Pending count should be total - completed - suspended"
                );
            }

            /// The percentage limit trips exactly when failures / total is strictly above it.
            #[test]
            fn prop_failure_percentage_calculation(
                total_tasks in 1usize..=100,
                failures in 0usize..=100,
                tolerance_percentage in 0.0f64..=1.0,
            ) {
                let failures = failures.min(total_tasks);
                let config = CompletionConfig {
                    tolerated_failure_percentage: Some(tolerance_percentage),
                    ..Default::default()
                };
                let counters = ExecutionCounters::new(total_tasks);

                (0..failures).for_each(|_| {
                    counters.fail_task();
                });

                let actual_percentage = failures as f64 / total_tasks as f64;
                let exceeded = counters.is_failure_tolerance_exceeded(&config);

                if actual_percentage > tolerance_percentage {
                    prop_assert!(exceeded, "Should exceed tolerance when percentage is higher");
                } else {
                    prop_assert!(!exceeded, "Should not exceed tolerance when percentage is lower or equal");
                }
            }
        }
    }
}