use std::{
panic::{
AssertUnwindSafe,
catch_unwind,
},
sync::Arc,
time::{
Duration,
Instant,
},
};
use qubit_function::Runnable;
use crate::{
BatchExecutionError,
BatchExecutionResult,
BatchTaskError,
BatchTaskFailure,
error::panic_payload_to_error,
progress::{
NoOpProgressReporter,
ProgressReporter,
},
};
use super::BatchExecutor;
/// Batch executor that runs every task one after another on the calling thread.
///
/// Cloning is cheap: the reporter is shared through an [`Arc`] and the interval is a
/// plain [`Duration`].
#[derive(Clone)]
pub struct SequentialBatchExecutor {
    /// Receives start, periodic progress, and finish notifications for each batch.
    reporter: Arc<dyn ProgressReporter>,
    /// Minimum time that must elapse between two consecutive progress reports.
    report_interval: Duration,
}
impl SequentialBatchExecutor {
    /// Report interval used by [`SequentialBatchExecutor::new`] and [`Default`]:
    /// five seconds.
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates an executor with [`Self::DEFAULT_REPORT_INTERVAL`] and a
    /// [`NoOpProgressReporter`]; equivalent to [`Default::default`].
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the progress reporter, taking ownership of `reporter`.
    ///
    /// Consumes `self` and returns the reconfigured executor; the report interval is
    /// kept unchanged.
    #[inline]
    #[must_use = "builder methods consume the executor and return the reconfigured one"]
    pub fn with_reporter<R>(self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        self.with_reporter_arc(Arc::new(reporter))
    }

    /// Replaces the progress reporter with one that is already shared via [`Arc`].
    ///
    /// Consumes `self` and returns the reconfigured executor; the report interval is
    /// kept unchanged.
    #[inline]
    #[must_use = "builder methods consume the executor and return the reconfigured one"]
    pub fn with_reporter_arc(self, reporter: Arc<dyn ProgressReporter>) -> Self {
        Self { reporter, ..self }
    }

    /// Sets the minimum time between two consecutive progress reports.
    ///
    /// Consumes `self` and returns the reconfigured executor; the reporter is kept
    /// unchanged.
    #[inline]
    #[must_use = "builder methods consume the executor and return the reconfigured one"]
    pub fn with_report_interval(self, report_interval: Duration) -> Self {
        Self {
            report_interval,
            ..self
        }
    }

    /// Returns the configured minimum time between progress reports.
    #[inline]
    #[must_use]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Returns the shared progress reporter.
    #[inline]
    #[must_use]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }
}
impl Default for SequentialBatchExecutor {
fn default() -> Self {
Self {
report_interval: Self::DEFAULT_REPORT_INTERVAL,
reporter: Arc::new(NoOpProgressReporter),
}
}
}
impl BatchExecutor for SequentialBatchExecutor {
    /// Runs the tasks yielded by `tasks` one at a time on the calling thread.
    ///
    /// Every task runs under [`catch_unwind`], so a panicking task is recorded in the
    /// result instead of propagating. The reporter's calls happen in this order:
    ///
    /// * `start` is called before the first task.
    /// * `process` is called at most once per report interval, after a task completes.
    /// * `finish` is called exactly once before this method returns.
    ///
    /// # Errors
    ///
    /// * [`BatchExecutionError::CountExceeded`] when `tasks` yields more than `count`
    ///   items. The surplus item is not run; the attached result covers the first
    ///   `count` tasks.
    /// * [`BatchExecutionError::CountShortfall`] when `tasks` yields fewer than `count`
    ///   items. The attached result covers every task that was yielded.
    fn execute<T, E, I>(
        &self,
        tasks: I,
        count: usize,
    ) -> Result<BatchExecutionResult<E>, BatchExecutionError<E>>
    where
        I: IntoIterator<Item = T>,
        T: Runnable<E> + Send,
        E: Send,
    {
        self.reporter.start(count);
        let start = Instant::now();
        // `Instant + Duration` panics when the sum is not representable. A very large
        // interval such as `Duration::MAX` triggers that panic. `checked_add` maps this
        // case to `None`, meaning no progress report is ever due.
        let mut next_progress = start.checked_add(self.report_interval);
        let mut completed_count = 0;
        let mut succeeded_count = 0;
        let mut failed_count = 0;
        let mut panicked_count = 0;
        let mut failures = Vec::new();
        let mut actual_count = 0;
        for task in tasks {
            if actual_count == count {
                // The iterator yielded a task beyond the declared `count`; stop
                // without running it.
                let result = build_result(
                    count,
                    completed_count,
                    succeeded_count,
                    failed_count,
                    panicked_count,
                    start.elapsed(),
                    failures,
                );
                self.reporter.finish(count, result.elapsed());
                return Err(BatchExecutionError::CountExceeded {
                    expected: count,
                    observed_at_least: count + 1,
                    result,
                });
            }
            execute_one_task(
                task,
                actual_count,
                &mut completed_count,
                &mut succeeded_count,
                &mut failed_count,
                &mut panicked_count,
                &mut failures,
            );
            actual_count += 1;
            maybe_report_progress(
                self.reporter.as_ref(),
                count,
                completed_count,
                start,
                self.report_interval,
                &mut next_progress,
            );
        }
        let result = build_result(
            count,
            completed_count,
            succeeded_count,
            failed_count,
            panicked_count,
            start.elapsed(),
            failures,
        );
        self.reporter.finish(count, result.elapsed());
        if actual_count < count {
            Err(BatchExecutionError::CountShortfall {
                expected: count,
                actual: actual_count,
                result,
            })
        } else {
            Ok(result)
        }
    }
}

/// Runs a single task, catching any panic, and records its outcome.
///
/// Every outcome increments `completed_count`. Exactly one of `succeeded_count`,
/// `failed_count`, or `panicked_count` is also incremented. Errors and panics are
/// appended to `failures`, tagged with `index`, the task's zero-based position in
/// the input sequence.
fn execute_one_task<T, E>(
    mut task: T,
    index: usize,
    completed_count: &mut usize,
    succeeded_count: &mut usize,
    failed_count: &mut usize,
    panicked_count: &mut usize,
    failures: &mut Vec<BatchTaskFailure<E>>,
) where
    T: Runnable<E>,
{
    // After a panic, `task` is only dropped and never inspected again. A
    // half-updated task state therefore cannot leak into later logic.
    match catch_unwind(AssertUnwindSafe(|| task.run())) {
        Ok(Ok(())) => {
            *completed_count += 1;
            *succeeded_count += 1;
        }
        Ok(Err(error)) => {
            *completed_count += 1;
            *failed_count += 1;
            failures.push(BatchTaskFailure::new(index, BatchTaskError::Failed(error)));
        }
        Err(payload) => {
            *completed_count += 1;
            *panicked_count += 1;
            failures.push(BatchTaskFailure::new(
                index,
                panic_payload_to_error(payload.as_ref()),
            ));
        }
    }
}

/// Emits a progress report if the next reporting deadline has passed.
///
/// `next_progress` is the deadline for the next report. `None` means the deadline
/// is not representable as an [`Instant`] because the interval is too large; in that
/// case no report is emitted. After a report, the deadline moves to
/// `report_interval` past the current time, or to `None` if that sum would overflow.
fn maybe_report_progress(
    reporter: &dyn ProgressReporter,
    total_count: usize,
    completed_count: usize,
    start: Instant,
    report_interval: Duration,
    next_progress: &mut Option<Instant>,
) {
    let deadline = match *next_progress {
        Some(deadline) => deadline,
        None => return,
    };
    let now = Instant::now();
    if now < deadline {
        return;
    }
    // NOTE(review): the second argument is always 0. This presumably means "tasks in
    // flight", which is zero between tasks in a sequential run. Confirm against
    // `ProgressReporter::process`.
    reporter.process(total_count, 0, completed_count, now.duration_since(start));
    *next_progress = now.checked_add(report_interval);
}
/// Assembles a [`BatchExecutionResult`] from the counters gathered by the executor.
///
/// This is a thin wrapper over `BatchExecutionResult::from_validated_parts`; every
/// argument is forwarded in the same order.
fn build_result<E>(
    total: usize,
    completed: usize,
    succeeded: usize,
    failed: usize,
    panicked: usize,
    elapsed: Duration,
    failures: Vec<BatchTaskFailure<E>>,
) -> BatchExecutionResult<E> {
    BatchExecutionResult::<E>::from_validated_parts(
        total, completed, succeeded, failed, panicked, elapsed, failures,
    )
}