1use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use tokio::sync::{Mutex, Notify, Semaphore};
12
13use crate::config::CompletionConfig;
14use crate::error::{DurableError, ErrorObject};
15
/// Atomic tallies describing the progress of a batch of tasks.
///
/// Every field is an [`AtomicUsize`], so all updates go through `&self`.
#[derive(Debug)]
pub struct ExecutionCounters {
    /// Number of tasks in the batch; set by `new` and only read afterwards in this file.
    total_tasks: AtomicUsize,
    /// Number of tasks recorded as successful.
    success_count: AtomicUsize,
    /// Number of tasks recorded as failed.
    failure_count: AtomicUsize,
    /// Number of tasks that finished either way (bumped by both success and failure).
    completed_count: AtomicUsize,
    /// Number of tasks recorded as suspended; these are not counted as completed.
    suspended_count: AtomicUsize,
}
66
67impl ExecutionCounters {
68 pub fn new(total_tasks: usize) -> Self {
70 Self {
71 total_tasks: AtomicUsize::new(total_tasks),
72 success_count: AtomicUsize::new(0),
73 failure_count: AtomicUsize::new(0),
74 completed_count: AtomicUsize::new(0),
75 suspended_count: AtomicUsize::new(0),
76 }
77 }
78
79 pub fn complete_task(&self) -> usize {
90 self.completed_count.fetch_add(1, Ordering::Relaxed);
92 self.success_count.fetch_add(1, Ordering::Relaxed) + 1
93 }
94
95 pub fn fail_task(&self) -> usize {
105 self.completed_count.fetch_add(1, Ordering::Relaxed);
107 self.failure_count.fetch_add(1, Ordering::Relaxed) + 1
108 }
109
110 pub fn suspend_task(&self) -> usize {
120 self.suspended_count.fetch_add(1, Ordering::Relaxed) + 1
122 }
123
124 pub fn total_tasks(&self) -> usize {
133 self.total_tasks.load(Ordering::Acquire)
135 }
136
137 pub fn success_count(&self) -> usize {
145 self.success_count.load(Ordering::Acquire)
147 }
148
149 pub fn failure_count(&self) -> usize {
157 self.failure_count.load(Ordering::Acquire)
159 }
160
161 pub fn completed_count(&self) -> usize {
169 self.completed_count.load(Ordering::Acquire)
171 }
172
173 pub fn suspended_count(&self) -> usize {
181 self.suspended_count.load(Ordering::Acquire)
183 }
184
185 pub fn pending_count(&self) -> usize {
187 let total = self.total_tasks();
188 let completed = self.completed_count();
189 let suspended = self.suspended_count();
190 total.saturating_sub(completed + suspended)
191 }
192
193 pub fn is_min_successful_reached(&self, min_successful: usize) -> bool {
199 self.success_count() >= min_successful
200 }
201
202 pub fn is_failure_tolerance_exceeded(&self, config: &CompletionConfig) -> bool {
208 let failures = self.failure_count();
209 let total = self.total_tasks();
210
211 if let Some(max_failures) = config.tolerated_failure_count {
213 if failures > max_failures {
214 return true;
215 }
216 }
217
218 if let Some(max_percentage) = config.tolerated_failure_percentage {
220 if total > 0 {
221 let failure_percentage = failures as f64 / total as f64;
222 if failure_percentage > max_percentage {
223 return true;
224 }
225 }
226 }
227
228 false
229 }
230
231 pub fn should_complete(&self, config: &CompletionConfig) -> Option<CompletionReason> {
240 let total = self.total_tasks();
241 let completed = self.completed_count();
242 let suspended = self.suspended_count();
243 let successes = self.success_count();
244
245 if let Some(min_successful) = config.min_successful {
247 if successes >= min_successful {
248 return Some(CompletionReason::MinSuccessfulReached);
249 }
250 }
251
252 if self.is_failure_tolerance_exceeded(config) {
254 return Some(CompletionReason::FailureToleranceExceeded);
255 }
256
257 if completed + suspended >= total {
259 if suspended > 0 && completed < total {
260 return Some(CompletionReason::Suspended);
261 }
262 return Some(CompletionReason::AllCompleted);
263 }
264
265 None
266 }
267
268 pub fn all_completed(&self) -> bool {
270 self.completed_count() >= self.total_tasks()
271 }
272
273 pub fn has_pending(&self) -> bool {
275 self.pending_count() > 0
276 }
277}
278
279impl Default for ExecutionCounters {
280 fn default() -> Self {
281 Self::new(0)
282 }
283}
284
/// Why a batch stopped, as reported by `ExecutionCounters::should_complete`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionReason {
    /// Every task finished (succeeded or failed) with none left suspended.
    AllCompleted,
    /// The configured minimum number of successes was reached.
    MinSuccessfulReached,
    /// The recorded failures exceeded a configured tolerance.
    FailureToleranceExceeded,
    /// All tasks are accounted for, but at least one is suspended.
    Suspended,
}
297
298impl std::fmt::Display for CompletionReason {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 match self {
301 Self::AllCompleted => write!(f, "AllCompleted"),
302 Self::MinSuccessfulReached => write!(f, "MinSuccessfulReached"),
303 Self::FailureToleranceExceeded => write!(f, "FailureToleranceExceeded"),
304 Self::Suspended => write!(f, "Suspended"),
305 }
306 }
307}
308
/// Outcome state of a single item in a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchItemStatus {
    /// The task returned a value.
    Succeeded,
    /// The task returned an error other than a suspension.
    Failed,
    /// The task was skipped because the batch had already met a completion condition.
    Cancelled,
    /// No outcome has been recorded for the item yet.
    Pending,
    /// The task reported a suspension instead of finishing.
    Suspended,
}
323
324impl BatchItemStatus {
325 pub fn is_success(&self) -> bool {
327 matches!(self, Self::Succeeded)
328 }
329
330 pub fn is_failure(&self) -> bool {
332 matches!(self, Self::Failed)
333 }
334
335 pub fn is_terminal(&self) -> bool {
337 matches!(self, Self::Succeeded | Self::Failed | Self::Cancelled)
338 }
339
340 pub fn is_pending(&self) -> bool {
342 matches!(self, Self::Pending | Self::Suspended)
343 }
344}
345
346impl std::fmt::Display for BatchItemStatus {
347 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348 match self {
349 Self::Succeeded => write!(f, "Succeeded"),
350 Self::Failed => write!(f, "Failed"),
351 Self::Cancelled => write!(f, "Cancelled"),
352 Self::Pending => write!(f, "Pending"),
353 Self::Suspended => write!(f, "Suspended"),
354 }
355 }
356}
357
/// One slot of a batch: the item's position, its status, and its payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchItem<T> {
    /// Position of the item within the batch.
    pub index: usize,
    /// Current outcome state of the item.
    pub status: BatchItemStatus,
    /// Value produced by the task; the constructors in this file set it only for `Succeeded`.
    pub result: Option<T>,
    /// Error reported by the task; the constructors in this file set it only for `Failed`.
    pub error: Option<ErrorObject>,
}
395
396impl<T> BatchItem<T> {
397 pub fn succeeded(index: usize, result: T) -> Self {
399 Self {
400 index,
401 status: BatchItemStatus::Succeeded,
402 result: Some(result),
403 error: None,
404 }
405 }
406
407 pub fn failed(index: usize, error: ErrorObject) -> Self {
409 Self {
410 index,
411 status: BatchItemStatus::Failed,
412 result: None,
413 error: Some(error),
414 }
415 }
416
417 pub fn cancelled(index: usize) -> Self {
419 Self {
420 index,
421 status: BatchItemStatus::Cancelled,
422 result: None,
423 error: None,
424 }
425 }
426
427 pub fn pending(index: usize) -> Self {
429 Self {
430 index,
431 status: BatchItemStatus::Pending,
432 result: None,
433 error: None,
434 }
435 }
436
437 pub fn suspended(index: usize) -> Self {
439 Self {
440 index,
441 status: BatchItemStatus::Suspended,
442 result: None,
443 error: None,
444 }
445 }
446
447 pub fn is_succeeded(&self) -> bool {
449 self.status.is_success()
450 }
451
452 pub fn is_failed(&self) -> bool {
454 self.status.is_failure()
455 }
456
457 pub fn get_result(&self) -> Option<&T> {
459 self.result.as_ref()
460 }
461
462 pub fn get_error(&self) -> Option<&ErrorObject> {
464 self.error.as_ref()
465 }
466}
467
/// Outcome of a whole batch: one [`BatchItem`] per slot plus the reason the batch stopped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResult<T> {
    /// Per-item outcomes.
    pub items: Vec<BatchItem<T>>,
    /// Why the batch was considered finished.
    pub completion_reason: CompletionReason,
}
503
504impl<T> BatchResult<T> {
505 pub fn new(items: Vec<BatchItem<T>>, completion_reason: CompletionReason) -> Self {
507 Self {
508 items,
509 completion_reason,
510 }
511 }
512
513 pub fn empty() -> Self {
515 Self {
516 items: Vec::new(),
517 completion_reason: CompletionReason::AllCompleted,
518 }
519 }
520
521 pub fn succeeded(&self) -> Vec<&BatchItem<T>> {
523 self.items
524 .iter()
525 .filter(|item| item.is_succeeded())
526 .collect()
527 }
528
529 pub fn failed(&self) -> Vec<&BatchItem<T>> {
531 self.items.iter().filter(|item| item.is_failed()).collect()
532 }
533
534 pub fn get_results(&self) -> Result<Vec<&T>, DurableError> {
539 if self.completion_reason == CompletionReason::FailureToleranceExceeded {
540 if let Some(failed_item) = self.failed().first() {
542 if let Some(ref error) = failed_item.error {
543 return Err(DurableError::UserCode {
544 message: error.error_message.clone(),
545 error_type: error.error_type.clone(),
546 stack_trace: error.stack_trace.clone(),
547 });
548 }
549 }
550 return Err(DurableError::execution("Batch operation failed"));
551 }
552
553 Ok(self
554 .items
555 .iter()
556 .filter_map(|item| item.result.as_ref())
557 .collect())
558 }
559
560 pub fn success_count(&self) -> usize {
562 self.succeeded().len()
563 }
564
565 pub fn failure_count(&self) -> usize {
567 self.failed().len()
568 }
569
570 pub fn total_count(&self) -> usize {
572 self.items.len()
573 }
574
575 pub fn all_succeeded(&self) -> bool {
577 self.items.iter().all(|item| item.is_succeeded())
578 }
579
580 pub fn has_failures(&self) -> bool {
582 self.items.iter().any(|item| item.is_failed())
583 }
584
585 pub fn is_failure(&self) -> bool {
587 self.completion_reason == CompletionReason::FailureToleranceExceeded
588 }
589
590 pub fn is_success(&self) -> bool {
592 matches!(
593 self.completion_reason,
594 CompletionReason::AllCompleted | CompletionReason::MinSuccessfulReached
595 )
596 }
597}
598
599impl<T> Default for BatchResult<T> {
600 fn default() -> Self {
601 Self::empty()
602 }
603}
604
/// Runs a batch of async tasks, tracking progress and stopping conditions.
pub struct ConcurrentExecutor {
    /// Requested concurrency limit as passed to `new`; `None` means no semaphore is used.
    max_concurrency: Option<usize>,
    /// Rules that decide when the batch counts as finished.
    completion_config: CompletionConfig,
    /// Shared progress counters, cloned into every spawned task.
    counters: Arc<ExecutionCounters>,
    /// Woken via `notify_waiters` whenever a recorded outcome meets a completion condition.
    completion_notify: Arc<Notify>,
    /// Permit pool built from `max_concurrency`; each spawned task acquires a permit before running.
    semaphore: Option<Arc<Semaphore>>,
}
627
628impl ConcurrentExecutor {
629 pub fn new(
637 total_tasks: usize,
638 max_concurrency: Option<usize>,
639 completion_config: CompletionConfig,
640 ) -> Self {
641 let semaphore = max_concurrency.map(|n| Arc::new(Semaphore::new(n)));
642
643 Self {
644 max_concurrency,
645 completion_config,
646 counters: Arc::new(ExecutionCounters::new(total_tasks)),
647 completion_notify: Arc::new(Notify::new()),
648 semaphore,
649 }
650 }
651
652 pub fn counters(&self) -> &Arc<ExecutionCounters> {
654 &self.counters
655 }
656
657 pub fn completion_notify(&self) -> &Arc<Notify> {
659 &self.completion_notify
660 }
661
662 pub fn should_complete(&self) -> Option<CompletionReason> {
664 self.counters.should_complete(&self.completion_config)
665 }
666
667 pub fn record_success(&self) -> Option<CompletionReason> {
671 self.counters.complete_task();
672 let reason = self.should_complete();
673 if reason.is_some() {
674 self.completion_notify.notify_waiters();
675 }
676 reason
677 }
678
679 pub fn record_failure(&self) -> Option<CompletionReason> {
683 self.counters.fail_task();
684 let reason = self.should_complete();
685 if reason.is_some() {
686 self.completion_notify.notify_waiters();
687 }
688 reason
689 }
690
691 pub fn record_suspend(&self) -> Option<CompletionReason> {
695 self.counters.suspend_task();
696 let reason = self.should_complete();
697 if reason.is_some() {
698 self.completion_notify.notify_waiters();
699 }
700 reason
701 }
702
703 pub async fn execute<T, F, Fut>(self, tasks: Vec<F>) -> BatchResult<T>
713 where
714 T: Send + 'static,
715 F: FnOnce(usize) -> Fut + Send + 'static,
716 Fut: std::future::Future<Output = Result<T, DurableError>> + Send + 'static,
717 {
718 let total = tasks.len();
719 if total == 0 {
720 return BatchResult::empty();
721 }
722
723 let results: Arc<Mutex<Vec<BatchItem<T>>>> =
725 Arc::new(Mutex::new((0..total).map(BatchItem::pending).collect()));
726
727 let mut handles = Vec::with_capacity(total);
729
730 for (index, task) in tasks.into_iter().enumerate() {
731 let counters = self.counters.clone();
732 let completion_notify = self.completion_notify.clone();
733 let completion_config = self.completion_config.clone();
734 let results = results.clone();
735 let semaphore = self.semaphore.clone();
736
737 let handle = tokio::spawn(async move {
738 let _permit = if let Some(ref sem) = semaphore {
740 Some(sem.acquire().await.expect("Semaphore closed"))
741 } else {
742 None
743 };
744
745 if counters.should_complete(&completion_config).is_some() {
747 let mut results_guard = results.lock().await;
748 results_guard[index] = BatchItem::cancelled(index);
749 return;
750 }
751
752 let result = task(index).await;
754
755 let mut results_guard = results.lock().await;
757 match result {
758 Ok(value) => {
759 results_guard[index] = BatchItem::succeeded(index, value);
760 counters.complete_task();
761 }
762 Err(DurableError::Suspend { .. }) => {
763 results_guard[index] = BatchItem::suspended(index);
764 counters.suspend_task();
765 }
766 Err(error) => {
767 let error_obj = ErrorObject::from(&error);
768 results_guard[index] = BatchItem::failed(index, error_obj);
769 counters.fail_task();
770 }
771 }
772 drop(results_guard);
773
774 if counters.should_complete(&completion_config).is_some() {
776 completion_notify.notify_waiters();
777 }
778 });
779
780 handles.push(handle);
781 }
782
783 for handle in handles {
785 let _ = handle.await;
786 }
787
788 let final_results = Arc::try_unwrap(results)
790 .map_err(|_| "All handles should be done")
791 .unwrap()
792 .into_inner();
793
794 let completion_reason = self
795 .counters
796 .should_complete(&self.completion_config)
797 .unwrap_or(CompletionReason::AllCompleted);
798
799 BatchResult::new(final_results, completion_reason)
800 }
801}
802
803impl std::fmt::Debug for ConcurrentExecutor {
804 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
805 f.debug_struct("ConcurrentExecutor")
806 .field("max_concurrency", &self.max_concurrency)
807 .field("completion_config", &self.completion_config)
808 .field("counters", &self.counters)
809 .finish_non_exhaustive()
810 }
811}
812
813#[cfg(test)]
814mod tests {
815 use super::*;
816
817 mod execution_counters_tests {
818 use super::*;
819
820 #[test]
821 fn test_new() {
822 let counters = ExecutionCounters::new(10);
823 assert_eq!(counters.total_tasks(), 10);
824 assert_eq!(counters.success_count(), 0);
825 assert_eq!(counters.failure_count(), 0);
826 assert_eq!(counters.completed_count(), 0);
827 assert_eq!(counters.suspended_count(), 0);
828 assert_eq!(counters.pending_count(), 10);
829 }
830
831 #[test]
832 fn test_complete_task() {
833 let counters = ExecutionCounters::new(5);
834
835 assert_eq!(counters.complete_task(), 1);
836 assert_eq!(counters.success_count(), 1);
837 assert_eq!(counters.completed_count(), 1);
838 assert_eq!(counters.pending_count(), 4);
839
840 assert_eq!(counters.complete_task(), 2);
841 assert_eq!(counters.success_count(), 2);
842 assert_eq!(counters.completed_count(), 2);
843 }
844
845 #[test]
846 fn test_fail_task() {
847 let counters = ExecutionCounters::new(5);
848
849 assert_eq!(counters.fail_task(), 1);
850 assert_eq!(counters.failure_count(), 1);
851 assert_eq!(counters.completed_count(), 1);
852 assert_eq!(counters.pending_count(), 4);
853 }
854
855 #[test]
856 fn test_suspend_task() {
857 let counters = ExecutionCounters::new(5);
858
859 assert_eq!(counters.suspend_task(), 1);
860 assert_eq!(counters.suspended_count(), 1);
861 assert_eq!(counters.completed_count(), 0);
862 assert_eq!(counters.pending_count(), 4);
863 }
864
865 #[test]
866 fn test_is_min_successful_reached() {
867 let counters = ExecutionCounters::new(10);
868
869 assert!(!counters.is_min_successful_reached(3));
870
871 counters.complete_task();
872 counters.complete_task();
873 assert!(!counters.is_min_successful_reached(3));
874
875 counters.complete_task();
876 assert!(counters.is_min_successful_reached(3));
877 }
878
879 #[test]
880 fn test_is_failure_tolerance_exceeded_count() {
881 let counters = ExecutionCounters::new(10);
882 let config = CompletionConfig {
883 tolerated_failure_count: Some(2),
884 ..Default::default()
885 };
886
887 counters.fail_task();
888 counters.fail_task();
889 assert!(!counters.is_failure_tolerance_exceeded(&config));
890
891 counters.fail_task();
892 assert!(counters.is_failure_tolerance_exceeded(&config));
893 }
894
895 #[test]
896 fn test_is_failure_tolerance_exceeded_percentage() {
897 let counters = ExecutionCounters::new(10);
898 let config = CompletionConfig {
899 tolerated_failure_percentage: Some(0.2),
900 ..Default::default()
901 };
902
903 counters.fail_task();
904 counters.fail_task();
905 assert!(!counters.is_failure_tolerance_exceeded(&config));
906
907 counters.fail_task();
908 assert!(counters.is_failure_tolerance_exceeded(&config));
909 }
910
911 #[test]
912 fn test_should_complete_min_successful() {
913 let counters = ExecutionCounters::new(10);
914 let config = CompletionConfig::with_min_successful(3);
915
916 assert!(counters.should_complete(&config).is_none());
917
918 counters.complete_task();
919 counters.complete_task();
920 assert!(counters.should_complete(&config).is_none());
921
922 counters.complete_task();
923 assert_eq!(
924 counters.should_complete(&config),
925 Some(CompletionReason::MinSuccessfulReached)
926 );
927 }
928
929 #[test]
930 fn test_should_complete_failure_tolerance() {
931 let counters = ExecutionCounters::new(10);
932 let config = CompletionConfig::all_successful();
933
934 assert!(counters.should_complete(&config).is_none());
935
936 counters.fail_task();
937 assert_eq!(
938 counters.should_complete(&config),
939 Some(CompletionReason::FailureToleranceExceeded)
940 );
941 }
942
943 #[test]
944 fn test_should_complete_all_completed() {
945 let counters = ExecutionCounters::new(3);
946 let config = CompletionConfig::all_completed();
947
948 counters.complete_task();
949 counters.complete_task();
950 assert!(counters.should_complete(&config).is_none());
951
952 counters.complete_task();
953 assert_eq!(
954 counters.should_complete(&config),
955 Some(CompletionReason::AllCompleted)
956 );
957 }
958
959 #[test]
960 fn test_should_complete_suspended() {
961 let counters = ExecutionCounters::new(3);
962 let config = CompletionConfig::all_completed();
963
964 counters.complete_task();
965 counters.complete_task();
966 counters.suspend_task();
967
968 assert_eq!(
969 counters.should_complete(&config),
970 Some(CompletionReason::Suspended)
971 );
972 }
973
974 #[test]
975 fn test_all_completed() {
976 let counters = ExecutionCounters::new(3);
977
978 assert!(!counters.all_completed());
979
980 counters.complete_task();
981 counters.complete_task();
982 counters.complete_task();
983
984 assert!(counters.all_completed());
985 }
986
987 #[test]
988 fn test_has_pending() {
989 let counters = ExecutionCounters::new(2);
990
991 assert!(counters.has_pending());
992
993 counters.complete_task();
994 assert!(counters.has_pending());
995
996 counters.complete_task();
997 assert!(!counters.has_pending());
998 }
999
1000 #[test]
1005 fn test_concurrent_counter_updates() {
1006 use std::sync::Arc;
1007 use std::thread;
1008
1009 let counters = Arc::new(ExecutionCounters::new(1000));
1010 let mut handles = vec![];
1011
1012 for _ in 0..10 {
1014 let counters_clone = counters.clone();
1015 handles.push(thread::spawn(move || {
1016 for _ in 0..50 {
1017 counters_clone.complete_task();
1018 }
1019 }));
1020 }
1021
1022 for _ in 0..5 {
1024 let counters_clone = counters.clone();
1025 handles.push(thread::spawn(move || {
1026 for _ in 0..50 {
1027 counters_clone.fail_task();
1028 }
1029 }));
1030 }
1031
1032 for _ in 0..5 {
1034 let counters_clone = counters.clone();
1035 handles.push(thread::spawn(move || {
1036 for _ in 0..50 {
1037 counters_clone.suspend_task();
1038 }
1039 }));
1040 }
1041
1042 for handle in handles {
1044 handle.join().unwrap();
1045 }
1046
1047 assert_eq!(counters.success_count(), 500);
1052 assert_eq!(counters.failure_count(), 250);
1053 assert_eq!(counters.suspended_count(), 250);
1054 assert_eq!(counters.completed_count(), 750);
1056 }
1057
1058 #[test]
1063 fn test_concurrent_read_write_stress() {
1064 use std::sync::Arc;
1065 use std::thread;
1066
1067 let counters = Arc::new(ExecutionCounters::new(10000));
1068 let mut handles = vec![];
1069
1070 for _ in 0..5 {
1072 let counters_clone = counters.clone();
1073 handles.push(thread::spawn(move || {
1074 for _ in 0..200 {
1075 counters_clone.complete_task();
1076 }
1077 }));
1078 }
1079
1080 for _ in 0..5 {
1082 let counters_clone = counters.clone();
1083 handles.push(thread::spawn(move || {
1084 let mut last_success = 0;
1085 for _ in 0..1000 {
1086 let current_success = counters_clone.success_count();
1087 assert!(
1089 current_success >= last_success,
1090 "Success count decreased from {} to {}",
1091 last_success,
1092 current_success
1093 );
1094 last_success = current_success;
1095
1096 let completed = counters_clone.completed_count();
1098 assert!(
1099 completed >= current_success,
1100 "Completed {} should be >= success {}",
1101 completed,
1102 current_success
1103 );
1104 }
1105 }));
1106 }
1107
1108 for handle in handles {
1110 handle.join().unwrap();
1111 }
1112
1113 assert_eq!(counters.success_count(), 1000); assert_eq!(counters.completed_count(), 1000);
1116 }
1117 }
1118
1119 mod batch_item_tests {
1120 use super::*;
1121
1122 #[test]
1123 fn test_succeeded() {
1124 let item = BatchItem::succeeded(0, 42);
1125 assert_eq!(item.index, 0);
1126 assert!(item.is_succeeded());
1127 assert!(!item.is_failed());
1128 assert_eq!(item.get_result(), Some(&42));
1129 assert!(item.get_error().is_none());
1130 }
1131
1132 #[test]
1133 fn test_failed() {
1134 let error = ErrorObject::new("TestError", "test message");
1135 let item: BatchItem<i32> = BatchItem::failed(1, error);
1136 assert_eq!(item.index, 1);
1137 assert!(!item.is_succeeded());
1138 assert!(item.is_failed());
1139 assert!(item.get_result().is_none());
1140 assert!(item.get_error().is_some());
1141 }
1142
1143 #[test]
1144 fn test_cancelled() {
1145 let item: BatchItem<i32> = BatchItem::cancelled(2);
1146 assert_eq!(item.index, 2);
1147 assert_eq!(item.status, BatchItemStatus::Cancelled);
1148 }
1149
1150 #[test]
1151 fn test_pending() {
1152 let item: BatchItem<i32> = BatchItem::pending(3);
1153 assert_eq!(item.index, 3);
1154 assert_eq!(item.status, BatchItemStatus::Pending);
1155 }
1156
1157 #[test]
1158 fn test_suspended() {
1159 let item: BatchItem<i32> = BatchItem::suspended(4);
1160 assert_eq!(item.index, 4);
1161 assert_eq!(item.status, BatchItemStatus::Suspended);
1162 }
1163 }
1164
1165 mod batch_result_tests {
1166 use super::*;
1167
1168 #[test]
1169 fn test_empty() {
1170 let result: BatchResult<i32> = BatchResult::empty();
1171 assert!(result.items.is_empty());
1172 assert_eq!(result.completion_reason, CompletionReason::AllCompleted);
1173 }
1174
1175 #[test]
1176 fn test_succeeded() {
1177 let items = vec![
1178 BatchItem::succeeded(0, 1),
1179 BatchItem::succeeded(1, 2),
1180 BatchItem::failed(2, ErrorObject::new("Error", "msg")),
1181 ];
1182 let result = BatchResult::new(items, CompletionReason::AllCompleted);
1183
1184 let succeeded = result.succeeded();
1185 assert_eq!(succeeded.len(), 2);
1186 }
1187
1188 #[test]
1189 fn test_failed() {
1190 let items = vec![
1191 BatchItem::succeeded(0, 1),
1192 BatchItem::failed(1, ErrorObject::new("Error", "msg")),
1193 ];
1194 let result = BatchResult::new(items, CompletionReason::AllCompleted);
1195
1196 let failed = result.failed();
1197 assert_eq!(failed.len(), 1);
1198 }
1199
1200 #[test]
1201 fn test_get_results_success() {
1202 let items = vec![BatchItem::succeeded(0, 1), BatchItem::succeeded(1, 2)];
1203 let result = BatchResult::new(items, CompletionReason::AllCompleted);
1204
1205 let results = result.get_results().unwrap();
1206 assert_eq!(results, vec![&1, &2]);
1207 }
1208
1209 #[test]
1210 fn test_get_results_failure_tolerance_exceeded() {
1211 let items = vec![
1212 BatchItem::succeeded(0, 1),
1213 BatchItem::failed(1, ErrorObject::new("TestError", "test")),
1214 ];
1215 let result = BatchResult::new(items, CompletionReason::FailureToleranceExceeded);
1216
1217 assert!(result.get_results().is_err());
1218 }
1219
1220 #[test]
1221 fn test_counts() {
1222 let items = vec![
1223 BatchItem::succeeded(0, 1),
1224 BatchItem::succeeded(1, 2),
1225 BatchItem::failed(2, ErrorObject::new("Error", "msg")),
1226 ];
1227 let result = BatchResult::new(items, CompletionReason::AllCompleted);
1228
1229 assert_eq!(result.success_count(), 2);
1230 assert_eq!(result.failure_count(), 1);
1231 assert_eq!(result.total_count(), 3);
1232 }
1233
1234 #[test]
1235 fn test_all_succeeded() {
1236 let items = vec![BatchItem::succeeded(0, 1), BatchItem::succeeded(1, 2)];
1237 let result = BatchResult::new(items, CompletionReason::AllCompleted);
1238 assert!(result.all_succeeded());
1239
1240 let items_with_failure = vec![
1241 BatchItem::succeeded(0, 1),
1242 BatchItem::failed(1, ErrorObject::new("Error", "msg")),
1243 ];
1244 let result_with_failure =
1245 BatchResult::new(items_with_failure, CompletionReason::AllCompleted);
1246 assert!(!result_with_failure.all_succeeded());
1247 }
1248
1249 #[test]
1250 fn test_is_success() {
1251 let result: BatchResult<i32> = BatchResult::new(vec![], CompletionReason::AllCompleted);
1252 assert!(result.is_success());
1253
1254 let result2: BatchResult<i32> =
1255 BatchResult::new(vec![], CompletionReason::MinSuccessfulReached);
1256 assert!(result2.is_success());
1257
1258 let result3: BatchResult<i32> =
1259 BatchResult::new(vec![], CompletionReason::FailureToleranceExceeded);
1260 assert!(!result3.is_success());
1261 }
1262 }
1263
1264 mod concurrent_executor_tests {
1265 use super::*;
1266
1267 #[tokio::test]
1268 async fn test_execute_empty() {
1269 let executor = ConcurrentExecutor::new(0, None, CompletionConfig::all_completed());
1270 let tasks: Vec<
1271 Box<
1272 dyn FnOnce(
1273 usize,
1274 ) -> std::pin::Pin<
1275 Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
1276 > + Send,
1277 >,
1278 > = vec![];
1279 let result = executor.execute(tasks).await;
1280
1281 assert!(result.items.is_empty());
1282 assert_eq!(result.completion_reason, CompletionReason::AllCompleted);
1283 }
1284
1285 #[tokio::test]
1286 async fn test_execute_all_success() {
1287 let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_completed());
1288 let tasks: Vec<_> = (0..3)
1289 .map(|i| move |_idx: usize| async move { Ok(i * 10) })
1290 .collect();
1291
1292 let result = executor.execute(tasks).await;
1293
1294 assert_eq!(result.total_count(), 3);
1295 assert_eq!(result.success_count(), 3);
1296 assert!(result.all_succeeded());
1297 }
1298
1299 #[tokio::test]
1300 async fn test_execute_with_failures() {
1301 let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_completed());
1302
1303 let tasks: Vec<
1305 Box<
1306 dyn FnOnce(
1307 usize,
1308 ) -> std::pin::Pin<
1309 Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
1310 > + Send,
1311 >,
1312 > = vec![
1313 Box::new(|_idx: usize| {
1314 Box::pin(async { Ok(1) })
1315 as std::pin::Pin<
1316 Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
1317 >
1318 }),
1319 Box::new(|_idx: usize| {
1320 Box::pin(async { Err(DurableError::execution("test error")) })
1321 as std::pin::Pin<
1322 Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
1323 >
1324 }),
1325 Box::new(|_idx: usize| {
1326 Box::pin(async { Ok(3) })
1327 as std::pin::Pin<
1328 Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>,
1329 >
1330 }),
1331 ];
1332
1333 let result = executor.execute(tasks).await;
1334
1335 assert_eq!(result.total_count(), 3);
1336 assert_eq!(result.success_count(), 2);
1337 assert_eq!(result.failure_count(), 1);
1338 }
1339
1340 #[tokio::test]
1341 async fn test_execute_min_successful() {
1342 let executor =
1343 ConcurrentExecutor::new(5, None, CompletionConfig::with_min_successful(2));
1344 let tasks: Vec<_> = (0..5)
1345 .map(|i| move |_idx: usize| async move { Ok(i) })
1346 .collect();
1347
1348 let result = executor.execute(tasks).await;
1349
1350 assert!(result.success_count() >= 2);
1352 }
1353
1354 #[tokio::test]
1355 async fn test_execute_with_concurrency_limit() {
1356 let executor = ConcurrentExecutor::new(5, Some(2), CompletionConfig::all_completed());
1357 let tasks: Vec<_> = (0..5)
1358 .map(|i| move |_idx: usize| async move { Ok(i) })
1359 .collect();
1360
1361 let result = executor.execute(tasks).await;
1362
1363 assert_eq!(result.total_count(), 5);
1364 assert!(result.all_succeeded());
1365 }
1366
1367 #[tokio::test]
1368 async fn test_record_success() {
1369 let executor =
1370 ConcurrentExecutor::new(3, None, CompletionConfig::with_min_successful(2));
1371
1372 assert!(executor.record_success().is_none());
1373 assert_eq!(
1374 executor.record_success(),
1375 Some(CompletionReason::MinSuccessfulReached)
1376 );
1377 }
1378
1379 #[tokio::test]
1380 async fn test_record_failure() {
1381 let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_successful());
1382
1383 assert_eq!(
1384 executor.record_failure(),
1385 Some(CompletionReason::FailureToleranceExceeded)
1386 );
1387 }
1388 }
1389}
1390
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    // Property tests for the completion rules and the counter arithmetic of
    // `ExecutionCounters`.
    mod completion_criteria_tests {
        use super::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            // The batch must report nothing before the min_successful-th success
            // and MinSuccessfulReached exactly when it is reached.
            #[test]
            fn prop_min_successful_triggers_completion(
                total_tasks in 1usize..=50,
                min_successful_ratio in 0.1f64..=1.0,
            ) {
                let min_successful = ((total_tasks as f64 * min_successful_ratio).ceil() as usize)
                    .clamp(1, total_tasks);
                let config = CompletionConfig::with_min_successful(min_successful);
                let counters = ExecutionCounters::new(total_tasks);

                for successes_so_far in 1..=min_successful {
                    counters.complete_task();
                    let outcome = counters.should_complete(&config);

                    if successes_so_far < min_successful {
                        // Below the threshold (and below the total), so no
                        // completion reason may be reported yet.
                        prop_assert!(
                            outcome.is_none(),
                            "Should not complete before reaching min_successful"
                        );
                    } else {
                        prop_assert_eq!(
                            outcome,
                            Some(CompletionReason::MinSuccessfulReached),
                            "Should complete when min_successful is reached"
                        );
                    }
                }
            }

            // One failure more than tolerated must report FailureToleranceExceeded.
            #[test]
            fn prop_failure_tolerance_exceeded_triggers_completion(
                total_tasks in 2usize..=50,
                tolerated_failures in 0usize..=10,
            ) {
                let config = CompletionConfig::with_failure_tolerance(tolerated_failures);
                let counters = ExecutionCounters::new(total_tasks);

                for i in 0..=tolerated_failures {
                    counters.fail_task();
                    if i < tolerated_failures {
                        // The failures may already cover the whole batch, in which
                        // case AllCompleted is acceptable here.
                        let result = counters.should_complete(&config);
                        prop_assert!(
                            result.is_none() || result == Some(CompletionReason::AllCompleted),
                            "Should not trigger failure tolerance until exceeded"
                        );
                    }
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::FailureToleranceExceeded),
                    "Should complete when failure tolerance is exceeded"
                );
            }

            // Any mix of successes and failures that covers the batch is AllCompleted.
            #[test]
            fn prop_all_completed_triggers_when_all_done(
                total_tasks in 1usize..=50,
                success_count in 0usize..=50,
            ) {
                let success_count = success_count.min(total_tasks);
                let failure_count = total_tasks - success_count;
                let config = CompletionConfig::all_completed();
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..success_count {
                    counters.complete_task();
                }

                for _ in 0..failure_count {
                    counters.fail_task();
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::AllCompleted),
                    "Should complete when all tasks are done"
                );
            }

            // When the remainder of the batch suspends, the reason is Suspended.
            #[test]
            fn prop_suspended_triggers_when_tasks_suspend(
                total_tasks in 2usize..=50,
                completed_count in 1usize..=49,
            ) {
                // Keep at least one task free to suspend.
                let completed_count = completed_count.min(total_tasks - 1);
                let suspended_count = total_tasks - completed_count;
                let config = CompletionConfig::all_completed();
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..completed_count {
                    counters.complete_task();
                }

                for _ in 0..suspended_count {
                    counters.suspend_task();
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::Suspended),
                    "Should return Suspended when tasks are suspended"
                );
            }

            #[test]
            fn prop_success_count_accurate(
                total_tasks in 1usize..=100,
                successes in 0usize..=100,
            ) {
                let successes = successes.min(total_tasks);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..successes {
                    counters.complete_task();
                }

                prop_assert_eq!(
                    counters.success_count(),
                    successes,
                    "Success count should match number of complete_task calls"
                );
            }

            #[test]
            fn prop_failure_count_accurate(
                total_tasks in 1usize..=100,
                failures in 0usize..=100,
            ) {
                let failures = failures.min(total_tasks);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..failures {
                    counters.fail_task();
                }

                prop_assert_eq!(
                    counters.failure_count(),
                    failures,
                    "Failure count should match number of fail_task calls"
                );
            }

            #[test]
            fn prop_completed_count_is_sum(
                total_tasks in 2usize..=100,
                successes in 0usize..=50,
                failures in 0usize..=50,
            ) {
                let successes = successes.min(total_tasks / 2);
                let failures = failures.min(total_tasks - successes);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..successes {
                    counters.complete_task();
                }
                for _ in 0..failures {
                    counters.fail_task();
                }

                prop_assert_eq!(
                    counters.completed_count(),
                    successes + failures,
                    "Completed count should equal success + failure"
                );
            }

            #[test]
            fn prop_pending_count_accurate(
                total_tasks in 3usize..=100,
                successes in 0usize..=33,
                failures in 0usize..=33,
                suspends in 0usize..=33,
            ) {
                // Clamp so that the three groups never exceed the batch size.
                let successes = successes.min(total_tasks / 3);
                let failures = failures.min((total_tasks - successes) / 2);
                let suspends = suspends.min(total_tasks - successes - failures);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..successes {
                    counters.complete_task();
                }
                for _ in 0..failures {
                    counters.fail_task();
                }
                for _ in 0..suspends {
                    counters.suspend_task();
                }

                let expected_pending = total_tasks - successes - failures - suspends;
                prop_assert_eq!(
                    counters.pending_count(),
                    expected_pending,
                    "Pending count should be total - completed - suspended"
                );
            }

            // The percentage limit is exceeded exactly when failures / total is
            // strictly greater than the tolerance.
            #[test]
            fn prop_failure_percentage_calculation(
                total_tasks in 1usize..=100,
                failures in 0usize..=100,
                tolerance_percentage in 0.0f64..=1.0,
            ) {
                let failures = failures.min(total_tasks);
                let config = CompletionConfig {
                    tolerated_failure_percentage: Some(tolerance_percentage),
                    ..Default::default()
                };
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..failures {
                    counters.fail_task();
                }

                let actual_percentage = failures as f64 / total_tasks as f64;
                let exceeded = counters.is_failure_tolerance_exceeded(&config);

                if actual_percentage > tolerance_percentage {
                    prop_assert!(exceeded, "Should exceed tolerance when percentage is higher");
                } else {
                    prop_assert!(!exceeded, "Should not exceed tolerance when percentage is lower or equal");
                }
            }
        }
    }
}