1use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use tokio::sync::{Mutex, Notify, Semaphore};
12
13use crate::config::CompletionConfig;
14use crate::error::{DurableError, ErrorObject};
15
/// Lock-free counters tracking the progress of one batch of tasks.
///
/// Every counter is an [`AtomicUsize`], so a shared `&ExecutionCounters`
/// (e.g. behind an [`Arc`]) can be updated from many tasks or threads at once.
#[derive(Debug)]
pub struct ExecutionCounters {
    /// Number of tasks in the batch; set by `new` and never modified afterwards.
    total_tasks: AtomicUsize,
    /// Tasks recorded via `complete_task`.
    success_count: AtomicUsize,
    /// Tasks recorded via `fail_task`.
    failure_count: AtomicUsize,
    /// Tasks recorded via `complete_task` or `fail_task` (successes + failures).
    completed_count: AtomicUsize,
    /// Tasks recorded via `suspend_task`; these are NOT part of `completed_count`.
    suspended_count: AtomicUsize,
}
61
62impl ExecutionCounters {
63 pub fn new(total_tasks: usize) -> Self {
65 Self {
66 total_tasks: AtomicUsize::new(total_tasks),
67 success_count: AtomicUsize::new(0),
68 failure_count: AtomicUsize::new(0),
69 completed_count: AtomicUsize::new(0),
70 suspended_count: AtomicUsize::new(0),
71 }
72 }
73
74 pub fn complete_task(&self) -> usize {
85 self.completed_count.fetch_add(1, Ordering::Relaxed);
87 self.success_count.fetch_add(1, Ordering::Relaxed) + 1
88 }
89
90 pub fn fail_task(&self) -> usize {
100 self.completed_count.fetch_add(1, Ordering::Relaxed);
102 self.failure_count.fetch_add(1, Ordering::Relaxed) + 1
103 }
104
105 pub fn suspend_task(&self) -> usize {
115 self.suspended_count.fetch_add(1, Ordering::Relaxed) + 1
117 }
118
119 pub fn total_tasks(&self) -> usize {
128 self.total_tasks.load(Ordering::Acquire)
130 }
131
132 pub fn success_count(&self) -> usize {
140 self.success_count.load(Ordering::Acquire)
142 }
143
144 pub fn failure_count(&self) -> usize {
152 self.failure_count.load(Ordering::Acquire)
154 }
155
156 pub fn completed_count(&self) -> usize {
164 self.completed_count.load(Ordering::Acquire)
166 }
167
168 pub fn suspended_count(&self) -> usize {
176 self.suspended_count.load(Ordering::Acquire)
178 }
179
180 pub fn pending_count(&self) -> usize {
182 let total = self.total_tasks();
183 let completed = self.completed_count();
184 let suspended = self.suspended_count();
185 total.saturating_sub(completed + suspended)
186 }
187
188 pub fn is_min_successful_reached(&self, min_successful: usize) -> bool {
194 self.success_count() >= min_successful
195 }
196
197 pub fn is_failure_tolerance_exceeded(&self, config: &CompletionConfig) -> bool {
203 let failures = self.failure_count();
204 let total = self.total_tasks();
205
206 if let Some(max_failures) = config.tolerated_failure_count {
208 if failures > max_failures {
209 return true;
210 }
211 }
212
213 if let Some(max_percentage) = config.tolerated_failure_percentage {
215 if total > 0 {
216 let failure_percentage = failures as f64 / total as f64;
217 if failure_percentage > max_percentage {
218 return true;
219 }
220 }
221 }
222
223 false
224 }
225
226 pub fn should_complete(&self, config: &CompletionConfig) -> Option<CompletionReason> {
235 let total = self.total_tasks();
236 let completed = self.completed_count();
237 let suspended = self.suspended_count();
238 let successes = self.success_count();
239
240 if let Some(min_successful) = config.min_successful {
242 if successes >= min_successful {
243 return Some(CompletionReason::MinSuccessfulReached);
244 }
245 }
246
247 if self.is_failure_tolerance_exceeded(config) {
249 return Some(CompletionReason::FailureToleranceExceeded);
250 }
251
252 if completed + suspended >= total {
254 if suspended > 0 && completed < total {
255 return Some(CompletionReason::Suspended);
256 }
257 return Some(CompletionReason::AllCompleted);
258 }
259
260 None
261 }
262
263 pub fn all_completed(&self) -> bool {
265 self.completed_count() >= self.total_tasks()
266 }
267
268 pub fn has_pending(&self) -> bool {
270 self.pending_count() > 0
271 }
272}
273
274impl Default for ExecutionCounters {
275 fn default() -> Self {
276 Self::new(0)
277 }
278}
279
/// Why a batch execution stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionReason {
    /// Every task finished (successfully or not). This is also the reason
    /// reported for an empty batch and the fallback when no criterion is met.
    AllCompleted,
    /// At least `min_successful` tasks succeeded.
    MinSuccessfulReached,
    /// Failures exceeded the tolerated count or percentage.
    FailureToleranceExceeded,
    /// Every task has either finished or suspended, and at least one suspended.
    Suspended,
}
292
293impl std::fmt::Display for CompletionReason {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 match self {
296 Self::AllCompleted => write!(f, "AllCompleted"),
297 Self::MinSuccessfulReached => write!(f, "MinSuccessfulReached"),
298 Self::FailureToleranceExceeded => write!(f, "FailureToleranceExceeded"),
299 Self::Suspended => write!(f, "Suspended"),
300 }
301 }
302}
303
/// Outcome of a single item in a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchItemStatus {
    /// The task returned `Ok`.
    Succeeded,
    /// The task returned an error (other than a suspension).
    Failed,
    /// The task was not run; `ConcurrentExecutor::execute` uses this when the
    /// batch had already met a completion criterion before the task started.
    Cancelled,
    /// No outcome recorded yet (initial state of every slot in `execute`).
    Pending,
    /// The task returned `DurableError::Suspend`.
    Suspended,
}
318
319impl BatchItemStatus {
320 pub fn is_success(&self) -> bool {
322 matches!(self, Self::Succeeded)
323 }
324
325 pub fn is_failure(&self) -> bool {
327 matches!(self, Self::Failed)
328 }
329
330 pub fn is_terminal(&self) -> bool {
332 matches!(self, Self::Succeeded | Self::Failed | Self::Cancelled)
333 }
334
335 pub fn is_pending(&self) -> bool {
337 matches!(self, Self::Pending | Self::Suspended)
338 }
339}
340
341impl std::fmt::Display for BatchItemStatus {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 match self {
344 Self::Succeeded => write!(f, "Succeeded"),
345 Self::Failed => write!(f, "Failed"),
346 Self::Cancelled => write!(f, "Cancelled"),
347 Self::Pending => write!(f, "Pending"),
348 Self::Suspended => write!(f, "Suspended"),
349 }
350 }
351}
352
/// The recorded outcome of one task in a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchItem<T> {
    /// Position of the task in the input task list.
    pub index: usize,
    /// Outcome of the task.
    pub status: BatchItemStatus,
    /// Task output; the constructors set it only for `Succeeded` items.
    pub result: Option<T>,
    /// Task error; the constructors set it only for `Failed` items.
    pub error: Option<ErrorObject>,
}
390
391impl<T> BatchItem<T> {
392 pub fn succeeded(index: usize, result: T) -> Self {
394 Self {
395 index,
396 status: BatchItemStatus::Succeeded,
397 result: Some(result),
398 error: None,
399 }
400 }
401
402 pub fn failed(index: usize, error: ErrorObject) -> Self {
404 Self {
405 index,
406 status: BatchItemStatus::Failed,
407 result: None,
408 error: Some(error),
409 }
410 }
411
412 pub fn cancelled(index: usize) -> Self {
414 Self {
415 index,
416 status: BatchItemStatus::Cancelled,
417 result: None,
418 error: None,
419 }
420 }
421
422 pub fn pending(index: usize) -> Self {
424 Self {
425 index,
426 status: BatchItemStatus::Pending,
427 result: None,
428 error: None,
429 }
430 }
431
432 pub fn suspended(index: usize) -> Self {
434 Self {
435 index,
436 status: BatchItemStatus::Suspended,
437 result: None,
438 error: None,
439 }
440 }
441
442 pub fn is_succeeded(&self) -> bool {
444 self.status.is_success()
445 }
446
447 pub fn is_failed(&self) -> bool {
449 self.status.is_failure()
450 }
451
452 pub fn get_result(&self) -> Option<&T> {
454 self.result.as_ref()
455 }
456
457 pub fn get_error(&self) -> Option<&ErrorObject> {
459 self.error.as_ref()
460 }
461}
462
/// Outcome of a whole batch: one item per task plus the reason it stopped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResult<T> {
    /// Per-task outcomes; `ConcurrentExecutor::execute` stores them in input order.
    pub items: Vec<BatchItem<T>>,
    /// Why the batch stopped.
    pub completion_reason: CompletionReason,
}
493
494impl<T> BatchResult<T> {
495 pub fn new(items: Vec<BatchItem<T>>, completion_reason: CompletionReason) -> Self {
497 Self {
498 items,
499 completion_reason,
500 }
501 }
502
503 pub fn empty() -> Self {
505 Self {
506 items: Vec::new(),
507 completion_reason: CompletionReason::AllCompleted,
508 }
509 }
510
511 pub fn succeeded(&self) -> Vec<&BatchItem<T>> {
513 self.items
514 .iter()
515 .filter(|item| item.is_succeeded())
516 .collect()
517 }
518
519 pub fn failed(&self) -> Vec<&BatchItem<T>> {
521 self.items.iter().filter(|item| item.is_failed()).collect()
522 }
523
524 pub fn get_results(&self) -> Result<Vec<&T>, DurableError> {
529 if self.completion_reason == CompletionReason::FailureToleranceExceeded {
530 if let Some(failed_item) = self.failed().first() {
532 if let Some(ref error) = failed_item.error {
533 return Err(DurableError::UserCode {
534 message: error.error_message.clone(),
535 error_type: error.error_type.clone(),
536 stack_trace: error.stack_trace.clone(),
537 });
538 }
539 }
540 return Err(DurableError::execution("Batch operation failed"));
541 }
542
543 Ok(self
544 .items
545 .iter()
546 .filter_map(|item| item.result.as_ref())
547 .collect())
548 }
549
550 pub fn success_count(&self) -> usize {
552 self.succeeded().len()
553 }
554
555 pub fn failure_count(&self) -> usize {
557 self.failed().len()
558 }
559
560 pub fn total_count(&self) -> usize {
562 self.items.len()
563 }
564
565 pub fn all_succeeded(&self) -> bool {
567 self.items.iter().all(|item| item.is_succeeded())
568 }
569
570 pub fn has_failures(&self) -> bool {
572 self.items.iter().any(|item| item.is_failed())
573 }
574
575 pub fn is_failure(&self) -> bool {
577 self.completion_reason == CompletionReason::FailureToleranceExceeded
578 }
579
580 pub fn is_success(&self) -> bool {
582 matches!(
583 self.completion_reason,
584 CompletionReason::AllCompleted | CompletionReason::MinSuccessfulReached
585 )
586 }
587}
588
589impl<T> Default for BatchResult<T> {
590 fn default() -> Self {
591 Self::empty()
592 }
593}
594
/// Runs a batch of async tasks concurrently via `tokio::spawn`, tracking
/// progress in shared [`ExecutionCounters`] and stopping early according to a
/// [`CompletionConfig`].
pub struct ConcurrentExecutor {
    /// Optional cap on simultaneously running tasks (`None` = unbounded).
    max_concurrency: Option<usize>,
    /// Early-completion criteria (min successes, failure tolerance).
    completion_config: CompletionConfig,
    /// Counters shared with every spawned task.
    counters: Arc<ExecutionCounters>,
    /// Receives `notify_waiters` whenever a recorded outcome completes the batch.
    completion_notify: Arc<Notify>,
    /// Permit pool enforcing `max_concurrency`; `None` when unbounded.
    semaphore: Option<Arc<Semaphore>>,
}
611
612impl ConcurrentExecutor {
613 pub fn new(
621 total_tasks: usize,
622 max_concurrency: Option<usize>,
623 completion_config: CompletionConfig,
624 ) -> Self {
625 let semaphore = max_concurrency.map(|n| Arc::new(Semaphore::new(n)));
626
627 Self {
628 max_concurrency,
629 completion_config,
630 counters: Arc::new(ExecutionCounters::new(total_tasks)),
631 completion_notify: Arc::new(Notify::new()),
632 semaphore,
633 }
634 }
635
636 pub fn counters(&self) -> &Arc<ExecutionCounters> {
638 &self.counters
639 }
640
641 pub fn completion_notify(&self) -> &Arc<Notify> {
643 &self.completion_notify
644 }
645
646 pub fn should_complete(&self) -> Option<CompletionReason> {
648 self.counters.should_complete(&self.completion_config)
649 }
650
651 pub fn record_success(&self) -> Option<CompletionReason> {
655 self.counters.complete_task();
656 let reason = self.should_complete();
657 if reason.is_some() {
658 self.completion_notify.notify_waiters();
659 }
660 reason
661 }
662
663 pub fn record_failure(&self) -> Option<CompletionReason> {
667 self.counters.fail_task();
668 let reason = self.should_complete();
669 if reason.is_some() {
670 self.completion_notify.notify_waiters();
671 }
672 reason
673 }
674
675 pub fn record_suspend(&self) -> Option<CompletionReason> {
679 self.counters.suspend_task();
680 let reason = self.should_complete();
681 if reason.is_some() {
682 self.completion_notify.notify_waiters();
683 }
684 reason
685 }
686
687 pub async fn execute<T, F, Fut>(self, tasks: Vec<F>) -> BatchResult<T>
697 where
698 T: Send + 'static,
699 F: FnOnce(usize) -> Fut + Send + 'static,
700 Fut: std::future::Future<Output = Result<T, DurableError>> + Send + 'static,
701 {
702 let total = tasks.len();
703 if total == 0 {
704 return BatchResult::empty();
705 }
706
707 let results: Arc<Mutex<Vec<BatchItem<T>>>> =
709 Arc::new(Mutex::new((0..total).map(BatchItem::pending).collect()));
710
711 let mut handles = Vec::with_capacity(total);
713
714 for (index, task) in tasks.into_iter().enumerate() {
715 let counters = self.counters.clone();
716 let completion_notify = self.completion_notify.clone();
717 let completion_config = self.completion_config.clone();
718 let results = results.clone();
719 let semaphore = self.semaphore.clone();
720
721 let handle = tokio::spawn(async move {
722 let _permit = if let Some(ref sem) = semaphore {
724 Some(sem.acquire().await.expect("Semaphore closed"))
725 } else {
726 None
727 };
728
729 if counters.should_complete(&completion_config).is_some() {
731 let mut results_guard = results.lock().await;
732 results_guard[index] = BatchItem::cancelled(index);
733 return;
734 }
735
736 let result = task(index).await;
738
739 let mut results_guard = results.lock().await;
741 match result {
742 Ok(value) => {
743 results_guard[index] = BatchItem::succeeded(index, value);
744 counters.complete_task();
745 }
746 Err(DurableError::Suspend { .. }) => {
747 results_guard[index] = BatchItem::suspended(index);
748 counters.suspend_task();
749 }
750 Err(error) => {
751 let error_obj = ErrorObject::from(&error);
752 results_guard[index] = BatchItem::failed(index, error_obj);
753 counters.fail_task();
754 }
755 }
756 drop(results_guard);
757
758 if counters.should_complete(&completion_config).is_some() {
760 completion_notify.notify_waiters();
761 }
762 });
763
764 handles.push(handle);
765 }
766
767 for handle in handles {
769 let _ = handle.await;
770 }
771
772 let final_results = Arc::try_unwrap(results)
774 .map_err(|_| "All handles should be done")
775 .unwrap()
776 .into_inner();
777
778 let completion_reason = self
779 .counters
780 .should_complete(&self.completion_config)
781 .unwrap_or(CompletionReason::AllCompleted);
782
783 BatchResult::new(final_results, completion_reason)
784 }
785}
786
787impl std::fmt::Debug for ConcurrentExecutor {
788 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
789 f.debug_struct("ConcurrentExecutor")
790 .field("max_concurrency", &self.max_concurrency)
791 .field("completion_config", &self.completion_config)
792 .field("counters", &self.counters)
793 .finish_non_exhaustive()
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    mod execution_counters_tests {
        use super::*;

        #[test]
        fn test_new() {
            let counters = ExecutionCounters::new(10);
            assert_eq!(counters.total_tasks(), 10);
            assert_eq!(counters.success_count(), 0);
            assert_eq!(counters.failure_count(), 0);
            assert_eq!(counters.completed_count(), 0);
            assert_eq!(counters.suspended_count(), 0);
            assert_eq!(counters.pending_count(), 10);
        }

        #[test]
        fn test_default_is_empty_batch() {
            let counters = ExecutionCounters::default();
            assert_eq!(counters.total_tasks(), 0);
            assert_eq!(counters.pending_count(), 0);
            assert!(!counters.has_pending());
        }

        #[test]
        fn test_complete_task() {
            let counters = ExecutionCounters::new(5);

            assert_eq!(counters.complete_task(), 1);
            assert_eq!(counters.success_count(), 1);
            assert_eq!(counters.completed_count(), 1);
            assert_eq!(counters.pending_count(), 4);

            assert_eq!(counters.complete_task(), 2);
            assert_eq!(counters.success_count(), 2);
            assert_eq!(counters.completed_count(), 2);
        }

        #[test]
        fn test_fail_task() {
            let counters = ExecutionCounters::new(5);

            assert_eq!(counters.fail_task(), 1);
            assert_eq!(counters.failure_count(), 1);
            assert_eq!(counters.completed_count(), 1);
            assert_eq!(counters.pending_count(), 4);
        }

        #[test]
        fn test_suspend_task() {
            let counters = ExecutionCounters::new(5);

            assert_eq!(counters.suspend_task(), 1);
            assert_eq!(counters.suspended_count(), 1);
            // Suspensions are not completions.
            assert_eq!(counters.completed_count(), 0);
            assert_eq!(counters.pending_count(), 4);
        }

        #[test]
        fn test_is_min_successful_reached() {
            let counters = ExecutionCounters::new(10);

            assert!(!counters.is_min_successful_reached(3));

            counters.complete_task();
            counters.complete_task();
            assert!(!counters.is_min_successful_reached(3));

            counters.complete_task();
            assert!(counters.is_min_successful_reached(3));
        }

        #[test]
        fn test_is_failure_tolerance_exceeded_count() {
            let counters = ExecutionCounters::new(10);
            let config = CompletionConfig {
                tolerated_failure_count: Some(2),
                ..Default::default()
            };

            counters.fail_task();
            counters.fail_task();
            // Reaching the limit exactly is still tolerated.
            assert!(!counters.is_failure_tolerance_exceeded(&config));

            counters.fail_task();
            assert!(counters.is_failure_tolerance_exceeded(&config));
        }

        #[test]
        fn test_is_failure_tolerance_exceeded_percentage() {
            let counters = ExecutionCounters::new(10);
            let config = CompletionConfig {
                tolerated_failure_percentage: Some(0.2),
                ..Default::default()
            };

            counters.fail_task();
            counters.fail_task();
            assert!(!counters.is_failure_tolerance_exceeded(&config));

            counters.fail_task();
            assert!(counters.is_failure_tolerance_exceeded(&config));
        }

        #[test]
        fn test_should_complete_min_successful() {
            let counters = ExecutionCounters::new(10);
            let config = CompletionConfig::with_min_successful(3);

            assert!(counters.should_complete(&config).is_none());

            counters.complete_task();
            counters.complete_task();
            assert!(counters.should_complete(&config).is_none());

            counters.complete_task();
            assert_eq!(
                counters.should_complete(&config),
                Some(CompletionReason::MinSuccessfulReached)
            );
        }

        #[test]
        fn test_should_complete_failure_tolerance() {
            let counters = ExecutionCounters::new(10);
            let config = CompletionConfig::all_successful();

            assert!(counters.should_complete(&config).is_none());

            counters.fail_task();
            assert_eq!(
                counters.should_complete(&config),
                Some(CompletionReason::FailureToleranceExceeded)
            );
        }

        #[test]
        fn test_should_complete_all_completed() {
            let counters = ExecutionCounters::new(3);
            let config = CompletionConfig::all_completed();

            counters.complete_task();
            counters.complete_task();
            assert!(counters.should_complete(&config).is_none());

            counters.complete_task();
            assert_eq!(
                counters.should_complete(&config),
                Some(CompletionReason::AllCompleted)
            );
        }

        #[test]
        fn test_should_complete_suspended() {
            let counters = ExecutionCounters::new(3);
            let config = CompletionConfig::all_completed();

            counters.complete_task();
            counters.complete_task();
            counters.suspend_task();

            assert_eq!(
                counters.should_complete(&config),
                Some(CompletionReason::Suspended)
            );
        }

        #[test]
        fn test_all_completed() {
            let counters = ExecutionCounters::new(3);

            assert!(!counters.all_completed());

            counters.complete_task();
            counters.complete_task();
            counters.complete_task();

            assert!(counters.all_completed());
        }

        #[test]
        fn test_has_pending() {
            let counters = ExecutionCounters::new(2);

            assert!(counters.has_pending());

            counters.complete_task();
            assert!(counters.has_pending());

            counters.complete_task();
            assert!(!counters.has_pending());
        }

        /// Many threads updating different counters must never lose an increment.
        #[test]
        fn test_concurrent_counter_updates() {
            use std::sync::Arc;
            use std::thread;

            let counters = Arc::new(ExecutionCounters::new(1000));
            let mut handles = vec![];

            for _ in 0..10 {
                let counters_clone = counters.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..50 {
                        counters_clone.complete_task();
                    }
                }));
            }

            for _ in 0..5 {
                let counters_clone = counters.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..50 {
                        counters_clone.fail_task();
                    }
                }));
            }

            for _ in 0..5 {
                let counters_clone = counters.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..50 {
                        counters_clone.suspend_task();
                    }
                }));
            }

            for handle in handles {
                handle.join().unwrap();
            }

            assert_eq!(counters.success_count(), 500);
            assert_eq!(counters.failure_count(), 250);
            assert_eq!(counters.suspended_count(), 250);
            assert_eq!(counters.completed_count(), 750);
        }

        /// Readers racing writers must see monotone counts and never observe more
        /// successes than completions (success is loaded before completed).
        #[test]
        fn test_concurrent_read_write_stress() {
            use std::sync::Arc;
            use std::thread;

            let counters = Arc::new(ExecutionCounters::new(10000));
            let mut handles = vec![];

            for _ in 0..5 {
                let counters_clone = counters.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..200 {
                        counters_clone.complete_task();
                    }
                }));
            }

            for _ in 0..5 {
                let counters_clone = counters.clone();
                handles.push(thread::spawn(move || {
                    let mut last_success = 0;
                    for _ in 0..1000 {
                        let current_success = counters_clone.success_count();
                        assert!(
                            current_success >= last_success,
                            "Success count decreased from {} to {}",
                            last_success,
                            current_success
                        );
                        last_success = current_success;

                        let completed = counters_clone.completed_count();
                        assert!(
                            completed >= current_success,
                            "Completed {} should be >= success {}",
                            completed,
                            current_success
                        );
                    }
                }));
            }

            for handle in handles {
                handle.join().unwrap();
            }

            assert_eq!(counters.success_count(), 1000);
            assert_eq!(counters.completed_count(), 1000);
        }
    }

    mod batch_item_tests {
        use super::*;

        #[test]
        fn test_succeeded() {
            let item = BatchItem::succeeded(0, 42);
            assert_eq!(item.index, 0);
            assert!(item.is_succeeded());
            assert!(!item.is_failed());
            assert_eq!(item.get_result(), Some(&42));
            assert!(item.get_error().is_none());
        }

        #[test]
        fn test_failed() {
            let error = ErrorObject::new("TestError", "test message");
            let item: BatchItem<i32> = BatchItem::failed(1, error);
            assert_eq!(item.index, 1);
            assert!(!item.is_succeeded());
            assert!(item.is_failed());
            assert!(item.get_result().is_none());
            assert!(item.get_error().is_some());
        }

        #[test]
        fn test_cancelled() {
            let item: BatchItem<i32> = BatchItem::cancelled(2);
            assert_eq!(item.index, 2);
            assert_eq!(item.status, BatchItemStatus::Cancelled);
        }

        #[test]
        fn test_pending() {
            let item: BatchItem<i32> = BatchItem::pending(3);
            assert_eq!(item.index, 3);
            assert_eq!(item.status, BatchItemStatus::Pending);
        }

        #[test]
        fn test_suspended() {
            let item: BatchItem<i32> = BatchItem::suspended(4);
            assert_eq!(item.index, 4);
            assert_eq!(item.status, BatchItemStatus::Suspended);
        }
    }

    mod display_tests {
        use super::*;

        #[test]
        fn test_completion_reason_display() {
            assert_eq!(CompletionReason::AllCompleted.to_string(), "AllCompleted");
            assert_eq!(
                CompletionReason::MinSuccessfulReached.to_string(),
                "MinSuccessfulReached"
            );
            assert_eq!(
                CompletionReason::FailureToleranceExceeded.to_string(),
                "FailureToleranceExceeded"
            );
            assert_eq!(CompletionReason::Suspended.to_string(), "Suspended");
        }

        #[test]
        fn test_batch_item_status_display_and_predicates() {
            assert_eq!(BatchItemStatus::Cancelled.to_string(), "Cancelled");
            assert_eq!(BatchItemStatus::Pending.to_string(), "Pending");

            assert!(BatchItemStatus::Succeeded.is_success());
            assert!(BatchItemStatus::Failed.is_failure());
            assert!(BatchItemStatus::Cancelled.is_terminal());
            assert!(!BatchItemStatus::Suspended.is_terminal());
            assert!(BatchItemStatus::Suspended.is_pending());
            assert!(BatchItemStatus::Pending.is_pending());
            assert!(!BatchItemStatus::Succeeded.is_pending());
        }
    }

    mod batch_result_tests {
        use super::*;

        #[test]
        fn test_empty() {
            let result: BatchResult<i32> = BatchResult::empty();
            assert!(result.items.is_empty());
            assert_eq!(result.completion_reason, CompletionReason::AllCompleted);
        }

        #[test]
        fn test_default_matches_empty() {
            let result: BatchResult<i32> = BatchResult::default();
            assert!(result.items.is_empty());
            assert!(result.is_success());
            assert!(result.all_succeeded());
        }

        #[test]
        fn test_succeeded() {
            let items = vec![
                BatchItem::succeeded(0, 1),
                BatchItem::succeeded(1, 2),
                BatchItem::failed(2, ErrorObject::new("Error", "msg")),
            ];
            let result = BatchResult::new(items, CompletionReason::AllCompleted);

            let succeeded = result.succeeded();
            assert_eq!(succeeded.len(), 2);
        }

        #[test]
        fn test_failed() {
            let items = vec![
                BatchItem::succeeded(0, 1),
                BatchItem::failed(1, ErrorObject::new("Error", "msg")),
            ];
            let result = BatchResult::new(items, CompletionReason::AllCompleted);

            let failed = result.failed();
            assert_eq!(failed.len(), 1);
        }

        #[test]
        fn test_get_results_success() {
            let items = vec![BatchItem::succeeded(0, 1), BatchItem::succeeded(1, 2)];
            let result = BatchResult::new(items, CompletionReason::AllCompleted);

            let results = result.get_results().unwrap();
            assert_eq!(results, vec![&1, &2]);
        }

        #[test]
        fn test_get_results_failure_tolerance_exceeded() {
            let items = vec![
                BatchItem::succeeded(0, 1),
                BatchItem::failed(1, ErrorObject::new("TestError", "test")),
            ];
            let result = BatchResult::new(items, CompletionReason::FailureToleranceExceeded);

            assert!(result.get_results().is_err());
        }

        #[test]
        fn test_counts() {
            let items = vec![
                BatchItem::succeeded(0, 1),
                BatchItem::succeeded(1, 2),
                BatchItem::failed(2, ErrorObject::new("Error", "msg")),
            ];
            let result = BatchResult::new(items, CompletionReason::AllCompleted);

            assert_eq!(result.success_count(), 2);
            assert_eq!(result.failure_count(), 1);
            assert_eq!(result.total_count(), 3);
        }

        #[test]
        fn test_all_succeeded() {
            let items = vec![BatchItem::succeeded(0, 1), BatchItem::succeeded(1, 2)];
            let result = BatchResult::new(items, CompletionReason::AllCompleted);
            assert!(result.all_succeeded());

            let items_with_failure = vec![
                BatchItem::succeeded(0, 1),
                BatchItem::failed(1, ErrorObject::new("Error", "msg")),
            ];
            let result_with_failure =
                BatchResult::new(items_with_failure, CompletionReason::AllCompleted);
            assert!(!result_with_failure.all_succeeded());
        }

        #[test]
        fn test_is_success() {
            let result: BatchResult<i32> = BatchResult::new(vec![], CompletionReason::AllCompleted);
            assert!(result.is_success());

            let result2: BatchResult<i32> =
                BatchResult::new(vec![], CompletionReason::MinSuccessfulReached);
            assert!(result2.is_success());

            let result3: BatchResult<i32> =
                BatchResult::new(vec![], CompletionReason::FailureToleranceExceeded);
            assert!(!result3.is_success());
        }
    }

    mod concurrent_executor_tests {
        use super::*;

        /// Boxed future produced by the type-erased test tasks.
        type BoxedFuture =
            std::pin::Pin<Box<dyn std::future::Future<Output = Result<i32, DurableError>> + Send>>;

        /// Type-erased task, so tasks with different bodies can share one `Vec`.
        type BoxedTask = Box<dyn FnOnce(usize) -> BoxedFuture + Send>;

        #[tokio::test]
        async fn test_execute_empty() {
            let executor = ConcurrentExecutor::new(0, None, CompletionConfig::all_completed());
            let tasks: Vec<BoxedTask> = vec![];
            let result = executor.execute(tasks).await;

            assert!(result.items.is_empty());
            assert_eq!(result.completion_reason, CompletionReason::AllCompleted);
        }

        #[tokio::test]
        async fn test_execute_all_success() {
            let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_completed());
            let tasks: Vec<_> = (0..3)
                .map(|i| move |_idx: usize| async move { Ok(i * 10) })
                .collect();

            let result = executor.execute(tasks).await;

            assert_eq!(result.total_count(), 3);
            assert_eq!(result.success_count(), 3);
            assert!(result.all_succeeded());
            assert_eq!(result.completion_reason, CompletionReason::AllCompleted);
            // Results come back in input order.
            assert_eq!(result.get_results().unwrap(), vec![&0, &10, &20]);
        }

        #[tokio::test]
        async fn test_execute_with_failures() {
            let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_completed());

            let tasks: Vec<BoxedTask> = vec![
                Box::new(|_idx: usize| Box::pin(async { Ok(1) }) as BoxedFuture),
                Box::new(|_idx: usize| {
                    Box::pin(async { Err(DurableError::execution("test error")) }) as BoxedFuture
                }),
                Box::new(|_idx: usize| Box::pin(async { Ok(3) }) as BoxedFuture),
            ];

            let result = executor.execute(tasks).await;

            assert_eq!(result.total_count(), 3);
            assert_eq!(result.success_count(), 2);
            assert_eq!(result.failure_count(), 1);
            assert_eq!(result.failed()[0].index, 1);
            // `all_completed` tolerates failures, so the batch runs to the end.
            assert_eq!(result.completion_reason, CompletionReason::AllCompleted);
        }

        #[tokio::test]
        async fn test_execute_min_successful() {
            let executor =
                ConcurrentExecutor::new(5, None, CompletionConfig::with_min_successful(2));
            let tasks: Vec<_> = (0..5)
                .map(|i| move |_idx: usize| async move { Ok(i) })
                .collect();

            let result = executor.execute(tasks).await;

            assert!(result.success_count() >= 2);
            assert_eq!(
                result.completion_reason,
                CompletionReason::MinSuccessfulReached
            );
        }

        /// With one permit, tasks run strictly one at a time, so exactly
        /// `min_successful` tasks run and the rest are cancelled.
        #[tokio::test]
        async fn test_execute_min_successful_cancels_remaining_tasks() {
            let executor =
                ConcurrentExecutor::new(5, Some(1), CompletionConfig::with_min_successful(2));
            let tasks: Vec<_> = (0..5)
                .map(|i| move |_idx: usize| async move { Ok(i) })
                .collect();

            let result = executor.execute(tasks).await;

            assert_eq!(
                result.completion_reason,
                CompletionReason::MinSuccessfulReached
            );
            assert_eq!(result.success_count(), 2);
            let cancelled = result
                .items
                .iter()
                .filter(|item| item.status == BatchItemStatus::Cancelled)
                .count();
            assert_eq!(cancelled, 3);
        }

        #[tokio::test]
        async fn test_execute_with_concurrency_limit() {
            let executor = ConcurrentExecutor::new(5, Some(2), CompletionConfig::all_completed());
            let tasks: Vec<_> = (0..5)
                .map(|i| move |_idx: usize| async move { Ok(i) })
                .collect();

            let result = executor.execute(tasks).await;

            assert_eq!(result.total_count(), 5);
            assert!(result.all_succeeded());
            assert_eq!(result.get_results().unwrap(), vec![&0, &1, &2, &3, &4]);
        }

        #[tokio::test]
        async fn test_record_success() {
            let executor =
                ConcurrentExecutor::new(3, None, CompletionConfig::with_min_successful(2));

            assert!(executor.record_success().is_none());
            assert_eq!(
                executor.record_success(),
                Some(CompletionReason::MinSuccessfulReached)
            );
        }

        #[tokio::test]
        async fn test_record_failure() {
            let executor = ConcurrentExecutor::new(3, None, CompletionConfig::all_successful());

            assert_eq!(
                executor.record_failure(),
                Some(CompletionReason::FailureToleranceExceeded)
            );
        }

        #[tokio::test]
        async fn test_record_suspend() {
            let executor = ConcurrentExecutor::new(1, None, CompletionConfig::all_completed());

            assert_eq!(
                executor.record_suspend(),
                Some(CompletionReason::Suspended)
            );
        }
    }
}
1374
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    mod completion_criteria_tests {
        use super::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            // `min_successful` triggers completion exactly when it is reached,
            // and not one success earlier.
            #[test]
            fn prop_min_successful_triggers_completion(
                total_tasks in 1usize..=50,
                min_successful_ratio in 0.1f64..=1.0,
            ) {
                let min_successful = ((total_tasks as f64 * min_successful_ratio).ceil() as usize)
                    .clamp(1, total_tasks);
                let config = CompletionConfig::with_min_successful(min_successful);
                let counters = ExecutionCounters::new(total_tasks);

                // Below the threshold fewer than `min_successful` (<= total) tasks
                // have finished and none failed, so no criterion can be met yet.
                for _ in 1..min_successful {
                    counters.complete_task();
                    prop_assert_eq!(
                        counters.should_complete(&config),
                        None,
                        "Should not complete before reaching min_successful"
                    );
                }

                counters.complete_task();
                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::MinSuccessfulReached),
                    "Should complete when min_successful is reached"
                );
            }

            // Failure tolerance triggers only once failures exceed the tolerated count.
            #[test]
            fn prop_failure_tolerance_exceeded_triggers_completion(
                total_tasks in 2usize..=50,
                tolerated_failures in 0usize..=10,
            ) {
                let config = CompletionConfig::with_failure_tolerance(tolerated_failures);
                let counters = ExecutionCounters::new(total_tasks);

                for i in 0..=tolerated_failures {
                    counters.fail_task();
                    if i < tolerated_failures {
                        let result = counters.should_complete(&config);
                        // Once every task has failed, AllCompleted is legitimate even
                        // though the tolerance is not exceeded yet.
                        prop_assert!(
                            result.is_none() || result == Some(CompletionReason::AllCompleted),
                            "Should not trigger failure tolerance until exceeded"
                        );
                    }
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::FailureToleranceExceeded),
                    "Should complete when failure tolerance is exceeded"
                );
            }

            // Any mix of successes and failures covering the batch yields AllCompleted.
            #[test]
            fn prop_all_completed_triggers_when_all_done(
                total_tasks in 1usize..=50,
                success_count in 0usize..=50,
            ) {
                let success_count = success_count.min(total_tasks);
                let failure_count = total_tasks - success_count;
                let config = CompletionConfig::all_completed();
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..success_count {
                    counters.complete_task();
                }

                for _ in 0..failure_count {
                    counters.fail_task();
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::AllCompleted),
                    "Should complete when all tasks are done"
                );
            }

            // At least one suspension among otherwise-finished tasks yields Suspended.
            #[test]
            fn prop_suspended_triggers_when_tasks_suspend(
                total_tasks in 2usize..=50,
                completed_count in 1usize..=49,
            ) {
                let completed_count = completed_count.min(total_tasks - 1);
                let suspended_count = total_tasks - completed_count;
                let config = CompletionConfig::all_completed();
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..completed_count {
                    counters.complete_task();
                }

                for _ in 0..suspended_count {
                    counters.suspend_task();
                }

                prop_assert_eq!(
                    counters.should_complete(&config),
                    Some(CompletionReason::Suspended),
                    "Should return Suspended when tasks are suspended"
                );
            }

            #[test]
            fn prop_success_count_accurate(
                total_tasks in 1usize..=100,
                successes in 0usize..=100,
            ) {
                let successes = successes.min(total_tasks);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..successes {
                    counters.complete_task();
                }

                prop_assert_eq!(
                    counters.success_count(),
                    successes,
                    "Success count should match number of complete_task calls"
                );
            }

            #[test]
            fn prop_failure_count_accurate(
                total_tasks in 1usize..=100,
                failures in 0usize..=100,
            ) {
                let failures = failures.min(total_tasks);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..failures {
                    counters.fail_task();
                }

                prop_assert_eq!(
                    counters.failure_count(),
                    failures,
                    "Failure count should match number of fail_task calls"
                );
            }

            #[test]
            fn prop_completed_count_is_sum(
                total_tasks in 2usize..=100,
                successes in 0usize..=50,
                failures in 0usize..=50,
            ) {
                let successes = successes.min(total_tasks / 2);
                let failures = failures.min(total_tasks - successes);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..successes {
                    counters.complete_task();
                }
                for _ in 0..failures {
                    counters.fail_task();
                }

                prop_assert_eq!(
                    counters.completed_count(),
                    successes + failures,
                    "Completed count should equal success + failure"
                );
            }

            #[test]
            fn prop_pending_count_accurate(
                total_tasks in 3usize..=100,
                successes in 0usize..=33,
                failures in 0usize..=33,
                suspends in 0usize..=33,
            ) {
                let successes = successes.min(total_tasks / 3);
                let failures = failures.min((total_tasks - successes) / 2);
                let suspends = suspends.min(total_tasks - successes - failures);
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..successes {
                    counters.complete_task();
                }
                for _ in 0..failures {
                    counters.fail_task();
                }
                for _ in 0..suspends {
                    counters.suspend_task();
                }

                let expected_pending = total_tasks - successes - failures - suspends;
                prop_assert_eq!(
                    counters.pending_count(),
                    expected_pending,
                    "Pending count should be total - completed - suspended"
                );
            }

            // The percentage limit is strict: equal to the tolerance is not exceeded.
            #[test]
            fn prop_failure_percentage_calculation(
                total_tasks in 1usize..=100,
                failures in 0usize..=100,
                tolerance_percentage in 0.0f64..=1.0,
            ) {
                let failures = failures.min(total_tasks);
                let config = CompletionConfig {
                    tolerated_failure_percentage: Some(tolerance_percentage),
                    ..Default::default()
                };
                let counters = ExecutionCounters::new(total_tasks);

                for _ in 0..failures {
                    counters.fail_task();
                }

                let actual_percentage = failures as f64 / total_tasks as f64;
                let exceeded = counters.is_failure_tolerance_exceeded(&config);

                if actual_percentage > tolerance_percentage {
                    prop_assert!(exceeded, "Should exceed tolerance when percentage is higher");
                } else {
                    prop_assert!(!exceeded, "Should not exceed tolerance when percentage is lower or equal");
                }
            }
        }
    }
}