use std::{
panic::{
AssertUnwindSafe,
catch_unwind,
},
sync::{
Arc,
Mutex,
PoisonError,
mpsc::{
self,
Receiver,
},
},
thread,
time::Duration,
};
use qubit_batch::{
BatchExecutionError,
BatchExecutionState,
BatchExecutor,
BatchOutcome,
BatchTaskError,
SequentialBatchExecutor,
};
use qubit_function::Runnable;
use qubit_progress::{
Progress,
ProgressReporter,
RunningProgressPointHandle,
};
use rayon::ThreadPool as RayonThreadPool;
use crate::{
RayonBatchExecutorBuildError,
RayonBatchExecutorBuilder,
};
/// One unit of work handed from the feeding thread to a rayon worker.
///
/// The `index` travels with the task so that a failure or panic can be recorded
/// against the task's position in the input sequence.
struct RayonWorkItem<T> {
    /// The task to run.
    task: T,
    /// Zero-based position of `task` in the input iterator.
    index: usize,
}
/// Batch executor that runs large batches on a dedicated rayon thread pool.
///
/// Small batches (and single-threaded configurations) are delegated to a
/// `SequentialBatchExecutor`. Cloning is cheap: the thread pool and the progress
/// reporter are shared through `Arc`, so clones submit work to the same pool.
#[derive(Clone)]
pub struct RayonBatchExecutor {
/// Shared rayon pool that runs the worker jobs.
pool: Arc<RayonThreadPool>,
/// Configured worker count; at most `min(thread_count, count)` workers run per batch.
thread_count: usize,
/// Batches with `count <= sequential_threshold` run sequentially.
sequential_threshold: usize,
/// Interval handed to the progress machinery for periodic "running" reports.
report_interval: Duration,
/// Receiver of started/running/finished/failed progress events.
reporter: Arc<dyn ProgressReporter>,
}
impl RayonBatchExecutor {
    /// Default interval between periodic progress reports (five seconds).
    ///
    /// NOTE(review): whether the builder actually starts from this value is decided in
    /// `RayonBatchExecutorBuilder` — confirm there.
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Default batch size at or below which batches run sequentially (100 tasks).
    ///
    /// NOTE(review): whether the builder actually starts from this value is decided in
    /// `RayonBatchExecutorBuilder` — confirm there.
    pub const DEFAULT_SEQUENTIAL_THRESHOLD: usize = 100;

    /// Returns the parallelism reported by [`thread::available_parallelism`], or `1`
    /// when the platform cannot report it.
    #[inline]
    pub fn default_thread_count() -> usize {
        thread::available_parallelism().map_or(1, usize::from)
    }

    /// Returns a builder populated with its default settings.
    #[inline]
    pub fn builder() -> RayonBatchExecutorBuilder {
        RayonBatchExecutorBuilder::default()
    }

    /// Builds an executor with `thread_count` threads and every other setting taken
    /// from the default builder.
    ///
    /// # Errors
    ///
    /// Returns whatever [`RayonBatchExecutorBuildError`] the builder's `build` produces.
    #[inline]
    pub fn new(thread_count: usize) -> Result<Self, RayonBatchExecutorBuildError> {
        Self::builder().thread_count(thread_count).build()
    }

    /// Assembles an executor from an already-built rayon `pool` and the builder's
    /// remaining settings.
    #[inline]
    pub(crate) fn new_with_rayon(
        pool: RayonThreadPool,
        builder: RayonBatchExecutorBuilder,
    ) -> Self {
        let shared_pool = Arc::new(pool);
        Self {
            pool: shared_pool,
            thread_count: builder.thread_count,
            sequential_threshold: builder.sequential_threshold,
            report_interval: builder.report_interval,
            reporter: builder.reporter,
        }
    }

    /// Configured number of worker threads.
    #[inline]
    pub const fn thread_count(&self) -> usize {
        self.thread_count
    }

    /// Batch size at or below which execution is sequential.
    #[inline]
    pub const fn sequential_threshold(&self) -> usize {
        self.sequential_threshold
    }

    /// Interval used for periodic progress reports.
    #[inline]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Shared progress reporter that receives batch progress events.
    #[inline]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }
}
impl Default for RayonBatchExecutor {
#[inline]
fn default() -> Self {
Self::builder()
.build()
.expect("default rayon batch executor should build")
}
}
impl BatchExecutor for RayonBatchExecutor {
    /// Runs the `count` tasks yielded by `tasks`, in parallel on the rayon pool when
    /// that is worthwhile.
    ///
    /// The batch is delegated to a [`SequentialBatchExecutor`] (configured with this
    /// executor's report interval and reporter) when any of these holds:
    ///
    /// * `count <= sequential_threshold`;
    /// * `thread_count <= 1`;
    /// * the caller is already running on a thread of this executor's pool, e.g. a task
    ///   that submits a nested batch to the same (or a cloned) executor. See the comment
    ///   in the body for why this case must not take the parallel path.
    ///
    /// Otherwise `min(thread_count, count)` workers are spawned on the pool. The calling
    /// thread feeds them through a bounded channel with one slot per worker, so the
    /// input iterator is consumed lazily. Progress is reported as started, then
    /// periodically while running (from a scoped reporter thread), then finished or
    /// failed.
    ///
    /// # Errors
    ///
    /// * [`BatchExecutionError::CountShortfall`] if `tasks` yields fewer than `count`
    ///   items.
    /// * [`BatchExecutionError::CountExceeded`] if `tasks` yields more than `count`
    ///   items. Iteration stops at the first surplus item, which is never run.
    ///
    /// Both variants carry the batch outcome accumulated from the tasks that did run.
    ///
    /// # Panics
    ///
    /// Panics if the shared execution state is still referenced after every worker has
    /// finished. That would be an internal invariant violation.
    fn execute<T, E, I>(
        &self,
        tasks: I,
        count: usize,
    ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
    where
        I: IntoIterator<Item = T>,
        T: Runnable<E> + Send,
        E: Send,
    {
        // A pool thread that blocks on the bounded channel below cannot help run the
        // workers it spawned. If every pool thread ends up feeding a nested batch, nothing
        // drains any channel and the batch deadlocks, so nested calls run sequentially.
        let on_own_pool = self.pool.current_thread_index().is_some();
        if count <= self.sequential_threshold || self.thread_count <= 1 || on_own_pool {
            let sequential = SequentialBatchExecutor::new()
                .with_report_interval(self.report_interval)
                .with_reporter_arc(Arc::clone(&self.reporter));
            return sequential.execute(tasks, count);
        }
        let state = Arc::new(BatchExecutionState::new(count));
        let progress = Progress::new(self.reporter.as_ref(), self.report_interval);
        progress.report_started(state.progress_counters());
        let mut observed_count = 0usize;
        let worker_count = self.thread_count.min(count);
        thread::scope(|thread_scope| {
            let reporter_state = Arc::clone(&state);
            let running_progress = progress
                .spawn_running_reporter(thread_scope, move || reporter_state.progress_counters());
            let running_point_handle = running_progress.point_handle();
            self.pool.in_place_scope_fifo(|scope| {
                // One buffered slot per worker: the producer never runs far ahead of
                // the workers, so the input iterator is consumed on demand.
                let (work_sender, work_receiver) = mpsc::sync_channel(worker_count);
                let work_receiver = Arc::new(Mutex::new(work_receiver));
                for _ in 0..worker_count {
                    let worker_receiver = Arc::clone(&work_receiver);
                    let worker_state = Arc::clone(&state);
                    let running_point_handle = running_point_handle.clone();
                    scope.spawn_fifo(move |_| {
                        run_rayon_worker(worker_receiver, worker_state, running_point_handle);
                    });
                }
                // Only workers keep the receiver alive. If they all exit early, `send`
                // fails instead of blocking forever.
                drop(work_receiver);
                for task in tasks {
                    observed_count = state.record_task_observed();
                    if observed_count > count {
                        break;
                    }
                    let item = RayonWorkItem {
                        index: observed_count - 1,
                        task,
                    };
                    if work_sender.send(item).is_err() {
                        break;
                    }
                }
                // Closing the channel lets the workers' `recv` fail once it is drained.
                drop(work_sender);
            });
            running_progress.stop_and_join();
        });
        let state = Arc::into_inner(state)
            .expect("rayon batch execution state should have a single owner");
        if observed_count == count {
            let finished = progress.report_finished(state.progress_counters());
            return Ok(state.into_outcome(finished.elapsed()));
        }
        let failed = progress.report_failed(state.progress_counters());
        let outcome = state.into_outcome(failed.elapsed());
        if observed_count < count {
            Err(BatchExecutionError::CountShortfall {
                expected: count,
                actual: observed_count,
                outcome,
            })
        } else {
            Err(BatchExecutionError::CountExceeded {
                expected: count,
                observed_at_least: observed_count,
                outcome,
            })
        }
    }
}
/// Worker loop executed by each job spawned on the rayon pool.
///
/// Takes items one at a time from the shared receiver. Each item is run with
/// [`run_rayon_task`], and a progress point is reported after it completes. The loop
/// ends when `recv` fails, which happens once every sender is gone and no buffered
/// items remain.
fn run_rayon_worker<T, E>(
    work_receiver: Arc<Mutex<Receiver<RayonWorkItem<T>>>>,
    state: Arc<BatchExecutionState<E>>,
    progress_point_handle: RunningProgressPointHandle,
) where
    T: Runnable<E> + Send,
    E: Send,
{
    // The mutex guard lives only inside this closure. The lock is therefore released
    // before the task runs, and other workers can receive in the meantime. A poisoned
    // lock is recovered rather than propagated.
    let next_item = || {
        work_receiver
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .recv()
            .ok()
    };
    while let Some(RayonWorkItem { index, task }) = next_item() {
        run_rayon_task(&state, index, task);
        progress_point_handle.report();
    }
}
/// Runs one task and records its result in `state`.
///
/// The task is marked as started first. A panic raised by the task is caught and
/// recorded as a panicked task at `index`, so it does not unwind the worker. An `Err`
/// return is recorded as a failure at `index`.
fn run_rayon_task<T, E>(state: &BatchExecutionState<E>, index: usize, mut task: T)
where
    T: Runnable<E>,
    E: Send,
{
    state.record_task_started();
    match catch_unwind(AssertUnwindSafe(|| task.run())) {
        Ok(Ok(())) => state.record_task_succeeded(),
        Ok(Err(error)) => state.record_task_failed(index, error),
        Err(payload) => {
            let panic_error = BatchTaskError::from_panic_payload(payload.as_ref());
            state.record_task_panicked(index, panic_error);
        }
    }
}