use std::panic::resume_unwind;
use std::sync::Arc;
use std::sync::mpsc::{
self,
RecvTimeoutError,
};
use std::thread;
use std::time::{
Duration,
Instant,
};
use qubit_function::Runnable;
use qubit_progress::{
Progress,
model::{
ProgressCounters,
ProgressPhase,
},
reporter::ProgressReporter,
};
use crate::BatchExecutionError;
use crate::BatchOutcome;
use crate::execute::{
BatchExecutionState,
BatchExecutor,
SequentialBatchExecutor,
};
use crate::utils::run_scoped_parallel;
use super::ParallelBatchExecutorBuildError;
use super::ParallelBatchExecutorBuilder;
use super::indexed_task::run_parallel_task;
/// Message sent over the progress channel to the progress-reporting thread.
enum ProgressLoopSignal {
/// Sent after a task has run (only when the report interval is zero) so the
/// progress loop can emit a running report if one is due.
RunningPoint,
/// Sent once `run_scoped_parallel` has returned; tells the progress loop to exit.
Stop,
}
/// Batch executor that runs large batches through `run_scoped_parallel` and
/// hands small batches to a [`SequentialBatchExecutor`].
///
/// Cloning shares the reporter: it is held behind an [`Arc`].
#[derive(Clone)]
pub struct ParallelBatchExecutor {
/// Configured thread count; the worker count used for a batch is
/// `min(thread_count, count)`, and a value of `1` or less forces sequential
/// execution.
pub(crate) thread_count: usize,
/// Batches whose declared `count` is at or below this value are executed
/// sequentially.
pub(crate) sequential_threshold: usize,
/// Interval between running-progress reports. A zero interval makes the
/// progress loop wait for per-task signals instead of waking on a timer.
pub(crate) report_interval: Duration,
/// Shared reporter that receives the progress events.
pub(crate) reporter: Arc<dyn ProgressReporter>,
}
impl ParallelBatchExecutor {
pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);
pub const DEFAULT_SEQUENTIAL_THRESHOLD: usize = 100;
#[inline]
pub fn default_thread_count() -> usize {
thread::available_parallelism()
.map(usize::from)
.unwrap_or(1)
}
#[inline]
pub fn builder() -> ParallelBatchExecutorBuilder {
ParallelBatchExecutorBuilder::default()
}
#[inline]
pub fn new(thread_count: usize) -> Result<Self, ParallelBatchExecutorBuildError> {
Self::builder().thread_count(thread_count).build()
}
#[inline]
pub const fn thread_count(&self) -> usize {
self.thread_count
}
#[inline]
pub const fn sequential_threshold(&self) -> usize {
self.sequential_threshold
}
#[inline]
pub const fn report_interval(&self) -> Duration {
self.report_interval
}
#[inline]
pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
&self.reporter
}
fn sequential_executor(&self) -> SequentialBatchExecutor {
SequentialBatchExecutor::new()
.with_report_interval(self.report_interval)
.with_reporter_arc(Arc::clone(&self.reporter))
}
}
impl Default for ParallelBatchExecutor {
fn default() -> Self {
Self::builder()
.build()
.expect("default parallel batch executor should build")
}
}
impl BatchExecutor for ParallelBatchExecutor {
    /// Executes `tasks`, in parallel when the batch is large enough.
    ///
    /// When `count <= sequential_threshold` or `thread_count <= 1`, the batch
    /// is delegated to the sequential executor built from this executor's
    /// reporter and report interval. Otherwise the tasks are handed to
    /// `run_scoped_parallel` with a worker count of `min(thread_count, count)`
    /// while a scoped thread runs the progress loop.
    ///
    /// A `Started` report is emitted before any task runs. Exactly one final
    /// report follows: `Finished` when the count returned by
    /// `run_scoped_parallel` equals `count`, `Failed` otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`BatchExecutionError::CountShortfall`] when
    /// `run_scoped_parallel` returns fewer than `count`, and
    /// [`BatchExecutionError::CountExceeded`] when it returns more. Both
    /// variants carry the collected outcome.
    ///
    /// # Panics
    ///
    /// Re-raises a panic from the progress-reporting thread, and panics if the
    /// shared execution state still has other owners after the scope ends.
    fn execute<T, E, I>(
        &self,
        tasks: I,
        count: usize,
    ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
    where
        I: IntoIterator<Item = T>,
        T: Runnable<E> + Send,
        E: Send,
    {
        let run_sequentially = self.thread_count <= 1 || count <= self.sequential_threshold;
        if run_sequentially {
            return self.sequential_executor().execute(tasks, count);
        }

        let state = Arc::new(BatchExecutionState::new(count));
        let progress = Progress::new(self.reporter.as_ref(), self.report_interval);
        progress.report_with_elapsed(
            ProgressPhase::Started,
            state.progress_counters(),
            Duration::ZERO,
        );
        let started_at = progress.started_at();
        let worker_count = self.thread_count.min(count);
        // With a zero interval the progress loop has no timer, so every
        // finished task must wake it explicitly.
        let signal_each_task = self.report_interval.is_zero();

        let actual_count = thread::scope(|scope| {
            let (signal_tx, signal_rx) = mpsc::channel();

            let loop_reporter = Arc::clone(&self.reporter);
            let loop_state = Arc::clone(&state);
            let loop_interval = self.report_interval;
            let reporter_thread = scope.spawn(move || {
                run_progress_loop(
                    loop_reporter,
                    loop_state,
                    started_at,
                    loop_interval,
                    signal_rx,
                );
            });

            let observed_state = Arc::clone(&state);
            let task_state = Arc::clone(&state);
            let task_tx = signal_tx.clone();
            let processed = run_scoped_parallel(
                tasks,
                count,
                worker_count,
                move || observed_state.record_task_observed(),
                move |index, task| {
                    run_parallel_task(&task_state, index, task);
                    if signal_each_task {
                        // A send error only means the progress loop is gone;
                        // task execution must not fail because of that.
                        let _ = task_tx.send(ProgressLoopSignal::RunningPoint);
                    }
                },
            );

            // Ask the progress loop to stop, then surface any panic it hit.
            let _ = signal_tx.send(ProgressLoopSignal::Stop);
            match reporter_thread.join() {
                Ok(()) => processed,
                Err(payload) => resume_unwind(payload),
            }
        });

        let elapsed = progress.elapsed();
        let outcome = Arc::into_inner(state)
            .expect("parallel batch execution state should have a single owner")
            .into_outcome(elapsed);

        let final_phase = if actual_count == count {
            ProgressPhase::Finished
        } else {
            ProgressPhase::Failed
        };
        progress.report_with_elapsed(
            final_phase,
            outcome_progress_counters(&outcome),
            outcome.elapsed(),
        );

        match actual_count.cmp(&count) {
            std::cmp::Ordering::Less => Err(BatchExecutionError::CountShortfall {
                expected: count,
                actual: actual_count,
                outcome,
            }),
            std::cmp::Ordering::Greater => Err(BatchExecutionError::CountExceeded {
                expected: count,
                observed_at_least: actual_count,
                outcome,
            }),
            std::cmp::Ordering::Equal => Ok(outcome),
        }
    }
}
/// Builds the progress counters for a final report from a finished outcome.
///
/// The outcome's task count becomes the total; completed, succeeded and failed
/// counts are copied over from the outcome's corresponding accessors.
fn outcome_progress_counters<E>(outcome: &BatchOutcome<E>) -> ProgressCounters {
    let total = outcome.task_count();
    let completed = outcome.completed_count();
    let succeeded = outcome.succeeded_count();
    let failed = outcome.failure_count();

    ProgressCounters::new(Some(total))
        .with_completed_count(completed)
        .with_succeeded_count(succeeded)
        .with_failed_count(failed)
}
/// Body of the progress-reporting thread.
///
/// Waits on `signal_receiver` and, for every `RunningPoint` signal or wait
/// timeout, asks the progress tracker to emit a running report if one is due,
/// using the counters currently held by `state`. Returns on a `Stop` signal or
/// once the channel is disconnected.
///
/// The tracker is created from `start`, the instant supplied by the caller.
fn run_progress_loop<E>(
    reporter: Arc<dyn ProgressReporter>,
    state: Arc<BatchExecutionState<E>>,
    start: Instant,
    report_interval: Duration,
    signal_receiver: mpsc::Receiver<ProgressLoopSignal>,
) {
    let mut progress = Progress::from_start(reporter.as_ref(), report_interval, start);
    loop {
        match receive_progress_signal(&signal_receiver, report_interval) {
            // Nothing more to report: either we were told to stop or no
            // sender is left to tell us anything.
            ProgressLoopWait::Signal(ProgressLoopSignal::Stop)
            | ProgressLoopWait::Disconnected => return,
            // A task finished or the interval elapsed: both are reporting
            // opportunities.
            ProgressLoopWait::Signal(ProgressLoopSignal::RunningPoint)
            | ProgressLoopWait::Timeout => {
                progress.report_running_if_due(state.progress_counters());
            }
        }
    }
}
/// Waits for the next progress signal and classifies the result.
///
/// With a zero `report_interval` the call blocks until a signal arrives or all
/// senders are dropped, so it never yields [`ProgressLoopWait::Timeout`].
/// Otherwise it waits at most `report_interval` and reports a timeout when no
/// signal arrived in that window.
fn receive_progress_signal(
    signal_receiver: &mpsc::Receiver<ProgressLoopSignal>,
    report_interval: Duration,
) -> ProgressLoopWait {
    // Normalise both receive modes to one error type so they share a single
    // classification below; a plain `recv` can only fail by disconnection.
    let received = if report_interval.is_zero() {
        signal_receiver
            .recv()
            .map_err(|_| RecvTimeoutError::Disconnected)
    } else {
        signal_receiver.recv_timeout(report_interval)
    };

    match received {
        Ok(signal) => ProgressLoopWait::Signal(signal),
        Err(RecvTimeoutError::Timeout) => ProgressLoopWait::Timeout,
        Err(RecvTimeoutError::Disconnected) => ProgressLoopWait::Disconnected,
    }
}
/// Result of one wait on the progress signal channel.
enum ProgressLoopWait {
/// A signal was received from a sender.
Signal(ProgressLoopSignal),
/// No signal arrived within the report interval.
Timeout,
/// The channel is disconnected, so no further signal can arrive.
Disconnected,
}