use std::{
collections::HashSet,
time::Duration,
};
use crate::{
BatchOutcomeBuildError,
BatchTaskError,
BatchTaskFailure,
};
/// Staging area for the counters, elapsed time, and per-task failure records
/// from which a batch outcome is built.
///
/// Values are supplied through the chained setter methods and are only checked
/// for mutual consistency when `validate` or `build` is called.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchOutcomeBuilder<E> {
/// Total number of tasks in the batch. Validation uses it as the upper bound
/// for `completed_count` and as the exclusive upper bound for failure indexes.
pub(crate) task_count: usize,
/// Number of completed tasks. Validation requires it to equal
/// `succeeded_count + failed_count + panicked_count`.
pub(crate) completed_count: usize,
/// Number of tasks reported as succeeded.
pub(crate) succeeded_count: usize,
/// Number of failure records expected to carry `BatchTaskError::Failed`.
pub(crate) failed_count: usize,
/// Number of failure records expected to carry `BatchTaskError::Panicked`.
pub(crate) panicked_count: usize,
/// Elapsed time to report. Starts at `Duration::ZERO`; validation does not
/// inspect it.
pub(crate) elapsed: Duration,
/// Per-task failure records. `validate` sorts them by task index once all
/// checks have passed.
pub(crate) failures: Vec<BatchTaskFailure<E>>,
}
impl<E> BatchOutcomeBuilder<E> {
    /// Creates a builder for a batch of `task_count` tasks.
    ///
    /// Every counter starts at zero, the elapsed time is [`Duration::ZERO`],
    /// and no failure records are attached.
    #[inline]
    #[must_use]
    pub fn builder(task_count: usize) -> Self {
        Self {
            task_count,
            completed_count: 0,
            succeeded_count: 0,
            failed_count: 0,
            panicked_count: 0,
            elapsed: Duration::ZERO,
            failures: Vec::new(),
        }
    }

    /// Sets the number of completed tasks and returns the updated builder.
    #[inline]
    #[must_use]
    pub const fn completed_count(mut self, completed_count: usize) -> Self {
        self.completed_count = completed_count;
        self
    }

    /// Sets the number of succeeded tasks and returns the updated builder.
    #[inline]
    #[must_use]
    pub const fn succeeded_count(mut self, succeeded_count: usize) -> Self {
        self.succeeded_count = succeeded_count;
        self
    }

    /// Sets the number of failed tasks and returns the updated builder.
    #[inline]
    #[must_use]
    pub const fn failed_count(mut self, failed_count: usize) -> Self {
        self.failed_count = failed_count;
        self
    }

    /// Sets the number of panicked tasks and returns the updated builder.
    #[inline]
    #[must_use]
    pub const fn panicked_count(mut self, panicked_count: usize) -> Self {
        self.panicked_count = panicked_count;
        self
    }

    /// Sets the elapsed time and returns the updated builder.
    #[inline]
    #[must_use]
    pub const fn elapsed(mut self, elapsed: Duration) -> Self {
        self.elapsed = elapsed;
        self
    }

    /// Replaces the failure records and returns the updated builder.
    ///
    /// Any records set previously are dropped.
    #[inline]
    #[must_use]
    pub fn failures(mut self, failures: Vec<BatchTaskFailure<E>>) -> Self {
        self.failures = failures;
        self
    }

    /// Checks the configured values for consistency and orders the failure
    /// records.
    ///
    /// On success the builder is returned with `failures` sorted by ascending
    /// task index.
    ///
    /// # Errors
    ///
    /// Returns the first [`BatchOutcomeBuildError`] reported by the invariant
    /// checks, which reject:
    ///
    /// - overflow while summing `failed_count + panicked_count`, or while
    ///   adding `succeeded_count` to that sum;
    /// - a `completed_count` greater than `task_count`;
    /// - a `completed_count` that differs from
    ///   `succeeded_count + failed_count + panicked_count`;
    /// - a number of failure records that differs from
    ///   `failed_count + panicked_count`;
    /// - a failure index that is `>= task_count` or appears more than once;
    /// - failure records whose `Failed` / `Panicked` variants do not match
    ///   `failed_count` / `panicked_count`.
    #[inline]
    pub fn validate(mut self) -> Result<Self, BatchOutcomeBuildError> {
        validate_outcome_invariants(
            self.task_count,
            self.completed_count,
            self.succeeded_count,
            self.failed_count,
            self.panicked_count,
            &self.failures,
        )?;
        // Validation has just proven every index unique, so an unstable sort
        // produces the same order as a stable one without the temporary buffer.
        self.failures
            .sort_unstable_by_key(|failure| failure.index());
        Ok(self)
    }

    /// Validates the builder and hands it to `BatchOutcome::new`.
    ///
    /// # Errors
    ///
    /// Returns the same [`BatchOutcomeBuildError`] that
    /// [`validate`](Self::validate) would return for the current values.
    #[inline]
    pub fn build(self) -> Result<crate::BatchOutcome<E>, BatchOutcomeBuildError> {
        self.validate().map(crate::BatchOutcome::new)
    }
}
/// Verifies that the aggregate counters of a batch outcome agree with each
/// other and with the supplied failure records.
///
/// The checks run in this order and the first violation is returned:
///
/// 1. `failed_count + panicked_count` must not overflow
///    (`FailureCountOverflow`).
/// 2. `succeeded_count` plus that sum must not overflow
///    (`TerminalCountOverflow`).
/// 3. `completed_count` must not exceed `task_count`
///    (`CompletedCountExceeded`).
/// 4. The sum from step 2 must equal `completed_count`
///    (`TerminalCountMismatch`).
/// 5. `failures` must contain exactly `failed_count + panicked_count` records
///    (`FailureDetailCountMismatch`).
/// 6. The per-record checks of `validate_failure_details` must pass.
fn validate_outcome_invariants<E>(
    task_count: usize,
    completed_count: usize,
    succeeded_count: usize,
    failed_count: usize,
    panicked_count: usize,
    failures: &[BatchTaskFailure<E>],
) -> Result<(), BatchOutcomeBuildError> {
    // Both sums use checked arithmetic so absurd inputs surface as errors
    // instead of wrapping or panicking.
    let failure_count = match failed_count.checked_add(panicked_count) {
        Some(total) => total,
        None => {
            return Err(BatchOutcomeBuildError::FailureCountOverflow {
                failed_count,
                panicked_count,
            });
        }
    };
    let terminal_count = match succeeded_count.checked_add(failure_count) {
        Some(total) => total,
        None => {
            return Err(BatchOutcomeBuildError::TerminalCountOverflow {
                succeeded_count,
                failure_count,
            });
        }
    };

    let detail_count = failures.len();
    if completed_count > task_count {
        Err(BatchOutcomeBuildError::CompletedCountExceeded {
            task_count,
            completed_count,
        })
    } else if terminal_count != completed_count {
        Err(BatchOutcomeBuildError::TerminalCountMismatch {
            completed_count,
            terminal_count,
            succeeded_count,
            failed_count,
            panicked_count,
        })
    } else if detail_count != failure_count {
        Err(BatchOutcomeBuildError::FailureDetailCountMismatch {
            expected: failure_count,
            actual: detail_count,
        })
    } else {
        // The aggregate counters are consistent; now inspect each record.
        validate_failure_details(task_count, failed_count, panicked_count, failures)
    }
}
/// Checks every failure record against the batch size and the declared
/// per-variant counts.
///
/// Records are visited in slice order and the first violation is returned:
///
/// - `FailureIndexOutOfRange` when a record's index is `>= task_count`;
/// - `DuplicateFailureIndex` when an index has already been seen.
///
/// After all records pass, `FailureVariantCountMismatch` is returned unless
/// the number of `Failed` records equals `failed_count` and the number of
/// `Panicked` records equals `panicked_count`.
fn validate_failure_details<E>(
    task_count: usize,
    failed_count: usize,
    panicked_count: usize,
    failures: &[BatchTaskFailure<E>],
) -> Result<(), BatchOutcomeBuildError> {
    // Sized up front: at most one entry per record is ever inserted.
    let mut seen_indexes = HashSet::with_capacity(failures.len());
    let mut failed_seen = 0usize;
    let mut panicked_seen = 0usize;

    for record in failures {
        let index = record.index();
        if index >= task_count {
            return Err(BatchOutcomeBuildError::FailureIndexOutOfRange { index, task_count });
        }

        // `insert` reports `false` when the index was already present.
        let first_occurrence = seen_indexes.insert(index);
        if !first_occurrence {
            return Err(BatchOutcomeBuildError::DuplicateFailureIndex { index });
        }

        match record.error() {
            BatchTaskError::Failed(_) => failed_seen += 1,
            BatchTaskError::Panicked { .. } => panicked_seen += 1,
        }
    }

    let variants_match = failed_seen == failed_count && panicked_seen == panicked_count;
    if variants_match {
        Ok(())
    } else {
        Err(BatchOutcomeBuildError::FailureVariantCountMismatch {
            expected_failed: failed_count,
            actual_failed: failed_seen,
            expected_panicked: panicked_count,
            actual_panicked: panicked_seen,
        })
    }
}