use std::{
num::NonZeroUsize,
panic::resume_unwind,
sync::{
Arc,
mpsc,
},
thread,
time::{
Duration,
Instant,
},
};
use qubit_function::{
ArcConsumer,
Consumer,
};
use qubit_progress::{
Progress,
model::ProgressPhase,
reporter::{
NoOpProgressReporter,
ProgressReporter,
},
};
use crate::process::{
BatchProcessError,
BatchProcessResult,
BatchProcessState,
BatchProcessor,
};
use crate::utils::run_scoped_parallel;
/// Message sent over the progress channel to the progress-reporting thread.
enum ProgressLoopSignal {
/// Sent by a worker after it finishes an item; workers only send this when
/// the configured report interval is zero.
RunningPoint,
/// Sent once the parallel run has returned, telling the progress loop to exit.
Stop,
}
/// Batch processor that hands every item to a shared consumer from multiple
/// worker threads while a separate thread reports progress.
pub struct ParallelBatchProcessor<Item> {
/// Consumer invoked with a reference to each item.
consumer: ArcConsumer<Item>,
/// Upper bound on the number of workers; the effective worker count is the
/// smaller of this value and the declared item count.
thread_count: NonZeroUsize,
/// Interval handed to the progress tracker and used as the wait timeout of
/// the progress loop. A zero interval makes workers signal the progress
/// loop after every processed item instead of relying on timeouts.
report_interval: Duration,
/// Reporter that receives progress events.
reporter: Arc<dyn ProgressReporter>,
}
/// Construction, builder-style configuration and accessors.
impl<Item> ParallelBatchProcessor<Item> {
    /// Report interval installed by [`Self::new`]: five seconds.
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates a processor around `consumer`.
    ///
    /// The consumer is converted with `into_arc()`. The thread count comes
    /// from [`Self::default_thread_count`], the report interval is
    /// [`Self::DEFAULT_REPORT_INTERVAL`], and the reporter is a
    /// [`NoOpProgressReporter`].
    ///
    /// # Panics
    ///
    /// Panics if [`Self::default_thread_count`] returns zero.
    #[inline]
    pub fn new<C>(consumer: C) -> Self
    where
        C: Consumer<Item> + Send + Sync + 'static,
    {
        let consumer = consumer.into_arc();
        let thread_count = NonZeroUsize::new(Self::default_thread_count())
            .expect("default parallel processor thread count should be non-zero");
        Self {
            consumer,
            thread_count,
            report_interval: Self::DEFAULT_REPORT_INTERVAL,
            reporter: Arc::new(NoOpProgressReporter),
        }
    }

    /// Returns the parallelism reported by [`thread::available_parallelism`],
    /// or `1` when that query fails.
    #[inline]
    pub fn default_thread_count() -> usize {
        thread::available_parallelism().map_or(1, NonZeroUsize::get)
    }

    /// Returns this processor with its thread count replaced by `thread_count`.
    #[inline]
    pub const fn with_thread_count(mut self, thread_count: NonZeroUsize) -> Self {
        self.thread_count = thread_count;
        self
    }

    /// Returns this processor with `reporter` installed, wrapping it in an
    /// [`Arc`] first.
    #[inline]
    pub fn with_reporter<R>(self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        let shared: Arc<dyn ProgressReporter> = Arc::new(reporter);
        self.with_reporter_arc(shared)
    }

    /// Returns this processor with the already shared `reporter` installed.
    #[inline]
    pub fn with_reporter_arc(mut self, reporter: Arc<dyn ProgressReporter>) -> Self {
        self.reporter = reporter;
        self
    }

    /// Returns this processor with its report interval replaced by
    /// `report_interval`.
    #[inline]
    pub fn with_report_interval(mut self, report_interval: Duration) -> Self {
        self.report_interval = report_interval;
        self
    }

    /// Configured thread count as a plain `usize`.
    #[inline]
    pub const fn thread_count(&self) -> usize {
        self.thread_count.get()
    }

    /// Configured report interval.
    #[inline]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Borrows the configured progress reporter.
    #[inline]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }

    /// Borrows the consumer that receives each item.
    #[inline]
    pub const fn consumer(&self) -> &ArcConsumer<Item> {
        &self.consumer
    }

    /// Consumes the processor and hands back its consumer.
    #[inline]
    pub fn into_consumer(self) -> ArcConsumer<Item> {
        let Self { consumer, .. } = self;
        consumer
    }
}
/// [`BatchProcessor`] implementation that processes items on worker threads.
impl<Item> BatchProcessor<Item> for ParallelBatchProcessor<Item>
where
    Item: Send,
{
    /// Error produced when the observed item count differs from the declared one.
    type Error = BatchProcessError;

    /// Processes `items`, which the caller declares to contain `count` items.
    ///
    /// A `Started` event is reported first. When `count` is zero nothing is
    /// processed; the iterator is only probed for a first item, which is
    /// recorded as observed if present. Otherwise the work is delegated to
    /// `process_non_empty`. Finally a `Finished` or `Failed` event is
    /// reported together with the elapsed time stored in the result.
    ///
    /// # Errors
    ///
    /// * [`BatchProcessError::CountShortfall`] when fewer than `count` items
    ///   were observed.
    /// * [`BatchProcessError::CountExceeded`] when more than `count` items
    ///   were observed.
    fn process<I>(&mut self, items: I, count: usize) -> Result<BatchProcessResult, Self::Error>
    where
        I: IntoIterator<Item = Item>,
    {
        let state = Arc::new(BatchProcessState::new(count));
        let progress = Progress::new(self.reporter.as_ref(), self.report_interval);
        progress.report_with_elapsed(
            ProgressPhase::Started,
            state.progress_counters(),
            Duration::ZERO,
        );
        let start = progress.started_at();

        if count == 0 {
            // An empty batch was declared: one surplus item is enough to
            // detect that the declaration was wrong.
            if items.into_iter().next().is_some() {
                state.record_item_observed();
            }
        } else {
            self.process_non_empty(items, count, Arc::clone(&state), start);
        }

        let result = state.to_direct_result(start.elapsed());
        let observed = state.observed_count();
        let comparison = observed.cmp(&count);

        // Only an exact match between observed and declared counts succeeds.
        let phase = match comparison {
            std::cmp::Ordering::Equal => ProgressPhase::Finished,
            std::cmp::Ordering::Less | std::cmp::Ordering::Greater => ProgressPhase::Failed,
        };
        progress.report_with_elapsed(phase, state.progress_counters(), result.elapsed());

        match comparison {
            std::cmp::Ordering::Less => Err(BatchProcessError::CountShortfall {
                expected: count,
                actual: observed,
                result,
            }),
            std::cmp::Ordering::Greater => Err(BatchProcessError::CountExceeded {
                expected: count,
                observed_at_least: observed,
                result,
            }),
            std::cmp::Ordering::Equal => Ok(result),
        }
    }
}
/// Internal helpers for the parallel run.
impl<Item> ParallelBatchProcessor<Item>
where
    Item: Send,
{
    /// Runs the consumer over `items` on worker threads while a dedicated
    /// scoped thread executes the progress loop.
    ///
    /// * `items` - source of the items to process.
    /// * `count` - declared number of items; also caps the worker count.
    /// * `state` - shared counters updated as items are observed, started
    ///   and processed.
    /// * `start` - instant the progress loop measures from.
    ///
    /// If the progress thread panicked, its panic payload is re-raised on the
    /// calling thread once the parallel run has returned.
    fn process_non_empty<I>(
        &self,
        items: I,
        count: usize,
        state: Arc<BatchProcessState>,
        start: Instant,
    ) where
        I: IntoIterator<Item = Item>,
    {
        let interval = self.report_interval;
        // With a zero interval the progress loop blocks on the channel, so
        // workers have to wake it after every processed item.
        let notify_per_item = interval.is_zero();
        let workers = count.min(self.thread_count.get());

        thread::scope(|scope| {
            let (signal_tx, signal_rx) = mpsc::channel();

            let loop_reporter = Arc::clone(&self.reporter);
            let loop_state = Arc::clone(&state);
            let reporter_thread = scope.spawn(move || {
                run_progress_loop(loop_reporter, loop_state, start, interval, signal_rx);
            });

            let observed_counter = Arc::clone(&state);
            let item_counter = Arc::clone(&state);
            let shared_consumer = self.consumer.clone();
            let worker_tx = signal_tx.clone();

            // NOTE(review): the first callback presumably fires once per item
            // pulled from `items`, the second once per item on a worker -
            // confirm against `run_scoped_parallel`.
            run_scoped_parallel(
                items,
                count,
                workers,
                move || observed_counter.record_item_observed(),
                move |_index, item| {
                    item_counter.record_item_started();
                    shared_consumer.accept(&item);
                    item_counter.record_item_processed();
                    if notify_per_item {
                        // A send error only means the progress loop is gone.
                        let _ = worker_tx.send(ProgressLoopSignal::RunningPoint);
                    }
                },
            );

            // Ask the progress loop to exit; a send error is ignored.
            let _ = signal_tx.send(ProgressLoopSignal::Stop);
            match reporter_thread.join() {
                Ok(()) => {}
                Err(payload) => resume_unwind(payload),
            }
        });
    }
}
/// Body of the progress-reporting thread.
///
/// Builds a [`Progress`] from `reporter`, `report_interval` and `start`, then
/// repeatedly waits on `signal_receiver`. A `RunningPoint` signal or a wait
/// timeout calls `report_running_if_due` with the current counters of
/// `state`; a `Stop` signal or a disconnected channel ends the loop.
fn run_progress_loop(
    reporter: Arc<dyn ProgressReporter>,
    state: Arc<BatchProcessState>,
    start: Instant,
    report_interval: Duration,
    signal_receiver: mpsc::Receiver<ProgressLoopSignal>,
) {
    let mut tracker = Progress::from_start(reporter.as_ref(), report_interval, start);
    loop {
        match receive_progress_signal(&signal_receiver, report_interval) {
            ProgressLoopWait::Signal(ProgressLoopSignal::Stop)
            | ProgressLoopWait::Disconnected => return,
            ProgressLoopWait::Signal(ProgressLoopSignal::RunningPoint)
            | ProgressLoopWait::Timeout => {
                tracker.report_running_if_due(state.progress_counters());
            }
        }
    }
}
/// Waits for the next event on the progress channel.
///
/// With a zero `report_interval` the call blocks until a signal arrives or
/// every sender is gone. Otherwise it waits at most `report_interval` and
/// reports a timeout when nothing arrived in that time.
fn receive_progress_signal(
    signal_receiver: &mpsc::Receiver<ProgressLoopSignal>,
    report_interval: Duration,
) -> ProgressLoopWait {
    if report_interval.is_zero() {
        signal_receiver
            .recv()
            .map_or(ProgressLoopWait::Disconnected, ProgressLoopWait::Signal)
    } else {
        match signal_receiver.recv_timeout(report_interval) {
            Ok(signal) => ProgressLoopWait::Signal(signal),
            Err(mpsc::RecvTimeoutError::Timeout) => ProgressLoopWait::Timeout,
            Err(mpsc::RecvTimeoutError::Disconnected) => ProgressLoopWait::Disconnected,
        }
    }
}
/// Outcome of one wait on the progress channel.
enum ProgressLoopWait {
/// A signal was received from the channel.
Signal(ProgressLoopSignal),
/// No signal arrived before the report interval elapsed.
Timeout,
/// The channel reported that it is disconnected.
Disconnected,
}