use std::{
num::NonZeroUsize,
sync::Arc,
thread,
time::Duration,
};
use qubit_function::{
ArcConsumer,
Consumer,
};
use qubit_progress::{
Progress,
reporter::ProgressReporter,
};
use crate::process::{
BatchProcessError,
BatchProcessResult,
BatchProcessState,
BatchProcessor,
};
use crate::utils::run_scoped_parallel;
use super::parallel_batch_processor_builder::ParallelBatchProcessorBuilder;
/// Batch processor that hands every item of a batch to a shared consumer.
///
/// Batches whose declared count is at or below `sequential_threshold` are
/// consumed on the calling thread. Larger batches are spread across scoped
/// worker threads.
pub struct ParallelBatchProcessor<Item> {
    /// Per-item consumer; cloned into the worker closure on the parallel path.
    pub(crate) consumer: ArcConsumer<Item>,
    /// Upper bound on worker threads; the effective worker count is also capped
    /// by the declared batch size.
    pub(crate) thread_count: NonZeroUsize,
    /// Largest declared count that is still processed on the calling thread.
    pub(crate) sequential_threshold: usize,
    /// Interval handed to `Progress` for pacing running progress reports.
    pub(crate) report_interval: Duration,
    /// Receiver of started / running / finished / failed progress events.
    pub(crate) reporter: Arc<dyn ProgressReporter>,
}
impl<Item> ParallelBatchProcessor<Item> {
pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);
pub const DEFAULT_SEQUENTIAL_THRESHOLD: usize = 100;
#[inline]
pub fn new<C>(consumer: C) -> Self
where
C: Consumer<Item> + Send + Sync + 'static,
{
Self::builder(consumer)
.build()
.expect("default parallel batch processor should build")
}
#[inline]
pub fn builder<C>(consumer: C) -> ParallelBatchProcessorBuilder<Item>
where
C: Consumer<Item> + Send + Sync + 'static,
{
ParallelBatchProcessorBuilder::new(consumer)
}
#[inline]
pub fn default_thread_count() -> usize {
thread::available_parallelism()
.map(usize::from)
.unwrap_or(1)
}
#[inline]
pub const fn thread_count(&self) -> usize {
self.thread_count.get()
}
#[inline]
pub const fn sequential_threshold(&self) -> usize {
self.sequential_threshold
}
#[inline]
pub const fn report_interval(&self) -> Duration {
self.report_interval
}
#[inline]
pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
&self.reporter
}
#[inline]
pub const fn consumer(&self) -> &ArcConsumer<Item> {
&self.consumer
}
#[inline]
pub fn into_consumer(self) -> ArcConsumer<Item> {
self.consumer
}
}
impl<Item> BatchProcessor<Item> for ParallelBatchProcessor<Item>
where
    Item: Send,
{
    type Error = BatchProcessError;

    /// Processes `items`, which the caller declares to contain exactly `count` elements.
    ///
    /// - An empty declared batch only peeks at the first element to detect an overrun.
    /// - Batches of at most `sequential_threshold` items run on the calling thread.
    /// - Larger batches are distributed across scoped worker threads.
    ///
    /// A "started" report is emitted before any work, and a "finished" or "failed"
    /// report once the work is done.
    ///
    /// # Errors
    ///
    /// * [`BatchProcessError::CountShortfall`] when fewer than `count` items were observed.
    /// * [`BatchProcessError::CountExceeded`] when more than `count` items were observed.
    ///   The reported figure is a lower bound because intake may stop once the overrun
    ///   is detected.
    fn process_with_count<I>(
        &mut self,
        items: I,
        count: usize,
    ) -> Result<BatchProcessResult, Self::Error>
    where
        I: IntoIterator<Item = Item>,
    {
        let state = Arc::new(BatchProcessState::new(count));
        let mut progress = Progress::new(self.reporter.as_ref(), self.report_interval);
        progress.report_started(state.progress_counters());

        if count == 0 {
            // Nothing is expected: a single extra element is enough to flag an
            // overrun, so the iterator is not drained.
            if items.into_iter().next().is_some() {
                state.record_item_observed();
            }
        } else if count > self.sequential_threshold {
            self.process_parallel_non_empty(items, count, Arc::clone(&state), &progress);
        } else {
            self.process_sequential(items, count, state.as_ref(), &mut progress);
        }

        // All processing (including any scoped workers) has completed here, so
        // the observed total is final.
        let observed = state.observed_count();
        if observed == count {
            let finished = progress.report_finished(state.progress_counters());
            return Ok(state.to_direct_result(finished.elapsed()));
        }

        let failed = progress.report_failed(state.progress_counters());
        let result = state.to_direct_result(failed.elapsed());
        let error = if observed < count {
            BatchProcessError::CountShortfall {
                expected: count,
                actual: observed,
                result,
            }
        } else {
            BatchProcessError::CountExceeded {
                expected: count,
                observed_at_least: observed,
                result,
            }
        };
        Err(error)
    }
}
impl<Item> ParallelBatchProcessor<Item>
where
    Item: Send,
{
    /// Feeds items to the consumer one at a time on the calling thread.
    ///
    /// Every pulled item is first recorded as observed. Once the observed total
    /// exceeds `count`, intake stops and that surplus item is dropped without
    /// reaching the consumer.
    ///
    /// After each processed item, running progress is offered through
    /// `report_running_if_due`; its return value is ignored.
    fn process_sequential<I>(
        &self,
        items: I,
        count: usize,
        state: &BatchProcessState,
        progress: &mut Progress<'_>,
    ) where
        I: IntoIterator<Item = Item>,
    {
        // `take_while` pulls one item per predicate call, so the observation
        // counter advances exactly once per pulled item, including the overrun.
        let admitted = items
            .into_iter()
            .take_while(|_| state.record_item_observed() <= count);
        for entry in admitted {
            state.record_item_started();
            self.consumer.accept(&entry);
            state.record_item_processed();
            let _ = progress.report_running_if_due(state.progress_counters());
        }
    }

    /// Distributes a non-empty batch across scoped worker threads.
    ///
    /// At most `min(thread_count, count)` workers are requested. A background
    /// running reporter, spawned from `progress`, samples the shared counters,
    /// and every worker pings its point handle after finishing an item. The
    /// reporter is stopped and joined before the scope closes.
    fn process_parallel_non_empty<I>(
        &self,
        items: I,
        count: usize,
        state: Arc<BatchProcessState>,
        progress: &Progress<'_>,
    ) where
        I: IntoIterator<Item = Item>,
    {
        let worker_count = count.min(self.thread_count.get());
        thread::scope(|scope| {
            let sampled_state = Arc::clone(&state);
            let running_progress =
                progress.spawn_running_reporter(scope, move || sampled_state.progress_counters());
            let point_handle = running_progress.point_handle();
            let intake_state = Arc::clone(&state);
            let consumer = self.consumer.clone();
            // NOTE(review): if a worker panic unwinds out of `run_scoped_parallel`,
            // `stop_and_join` below is skipped; whether the reporter thread still
            // terminates then depends on the running-progress handle's drop
            // behaviour — confirm.
            run_scoped_parallel(
                items,
                count,
                worker_count,
                move || intake_state.record_item_observed(),
                move |_position, entry| {
                    state.record_item_started();
                    consumer.accept(&entry);
                    state.record_item_processed();
                    point_handle.report();
                },
            );
            running_progress.stop_and_join();
        });
    }
}