use std::{
panic::{
AssertUnwindSafe,
catch_unwind,
},
sync::Arc,
time::Duration,
};
use qubit_function::Runnable;
use qubit_progress::{
Progress,
reporter::{
NoOpProgressReporter,
ProgressReporter,
},
};
use crate::{
BatchExecutionError,
BatchOutcome,
execute::{
BatchExecutionState,
BatchExecutor,
panic_payload_to_error,
},
};
/// Batch executor that runs tasks one at a time on the calling thread.
///
/// Progress is sent to `reporter`; `report_interval` is handed to the progress
/// tracker (presumably to throttle periodic "running" reports — see
/// `Progress::new`). Cloning is cheap: the reporter is shared through an
/// [`Arc`] and the interval is a plain [`Duration`].
#[derive(Clone)]
pub struct SequentialBatchExecutor {
    /// Interval passed to the progress tracker for periodic reports.
    report_interval: Duration,
    /// Sink for progress reports; shared between clones.
    reporter: Arc<dyn ProgressReporter>,
}

impl std::fmt::Debug for SequentialBatchExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn ProgressReporter` is not known to implement `Debug`, so the
        // reporter is elided rather than printed.
        f.debug_struct("SequentialBatchExecutor")
            .field("report_interval", &self.report_interval)
            .finish_non_exhaustive()
    }
}
impl SequentialBatchExecutor {
    /// Report interval used when none is configured explicitly (five seconds).
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates an executor with the default configuration.
    ///
    /// Equivalent to [`SequentialBatchExecutor::default`].
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns this executor with `reporter` installed as its progress reporter.
    ///
    /// The reporter is wrapped in a fresh [`Arc`]. Use
    /// [`Self::with_reporter_arc`] to install a reporter that is already shared.
    #[inline]
    #[must_use]
    pub fn with_reporter<R>(self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        self.with_reporter_arc(Arc::new(reporter))
    }

    /// Returns this executor with an already shared progress reporter installed.
    ///
    /// All other settings are kept unchanged.
    #[inline]
    #[must_use]
    pub fn with_reporter_arc(self, reporter: Arc<dyn ProgressReporter>) -> Self {
        Self { reporter, ..self }
    }

    /// Returns this executor with `report_interval` as its report interval.
    ///
    /// All other settings are kept unchanged.
    #[inline]
    #[must_use]
    pub fn with_report_interval(self, report_interval: Duration) -> Self {
        Self {
            report_interval,
            ..self
        }
    }

    /// Returns the interval this executor passes to its progress tracker.
    #[inline]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Returns the progress reporter that receives this executor's reports.
    #[inline]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }
}
impl Default for SequentialBatchExecutor {
#[inline]
fn default() -> Self {
Self {
report_interval: Self::DEFAULT_REPORT_INTERVAL,
reporter: Arc::new(NoOpProgressReporter),
}
}
}
impl BatchExecutor for SequentialBatchExecutor {
fn execute<T, E, I>(
&self,
tasks: I,
count: usize,
) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
where
I: IntoIterator<Item = T>,
T: Runnable<E> + Send,
E: Send,
{
let state = BatchExecutionState::new(count);
let mut progress = Progress::new(self.reporter.as_ref(), self.report_interval);
progress.report_started(state.progress_counters());
let mut actual_count = 0;
for task in tasks {
actual_count = state.record_task_observed();
if actual_count > count {
let failed = progress.report_failed(state.progress_counters());
let outcome = state.into_outcome(failed.elapsed());
return Err(BatchExecutionError::CountExceeded {
expected: count,
observed_at_least: actual_count,
outcome,
});
}
let mut task = task;
state.record_task_started();
match catch_unwind(AssertUnwindSafe(|| task.run())) {
Ok(Ok(())) => state.record_task_succeeded(),
Ok(Err(error)) => state.record_task_failed(actual_count - 1, error),
Err(payload) => state.record_task_panicked(
actual_count - 1,
panic_payload_to_error(payload.as_ref()),
),
}
let _ = progress.report_running_if_due(state.progress_counters());
}
if actual_count < count {
let failed = progress.report_failed(state.progress_counters());
Err(BatchExecutionError::CountShortfall {
expected: count,
actual: actual_count,
outcome: state.into_outcome(failed.elapsed()),
})
} else {
let finished = progress.report_finished(state.progress_counters());
Ok(state.into_outcome(finished.elapsed()))
}
}
}