use std::{
sync::Arc,
time::Duration,
};
use qubit_progress::{
NoOpProgressReporter,
ProgressReporter,
};
use crate::{
RayonBatchExecutor,
RayonBatchExecutorBuildError,
};
/// Default prefix for worker thread names.
///
/// `RayonBatchExecutorBuilder::build` names each worker thread
/// `{prefix}-{index}`, so with this default the threads are called
/// `qubit-rayon-batch-0`, `qubit-rayon-batch-1`, and so on.
const DEFAULT_THREAD_NAME_PREFIX: &str = "qubit-rayon-batch";

/// Builder for a [`RayonBatchExecutor`].
///
/// Start from the default configuration (for example via
/// [`Default::default`]) and adjust settings with the chained setter methods.
/// Then call `build`, which validates the configuration and creates the
/// executor together with a freshly built Rayon thread pool.
///
/// The builder is `Clone`, so one configuration can be reused to build
/// several executors. `build` consumes the builder it is called on.
#[derive(Clone)]
pub struct RayonBatchExecutorBuilder {
    /// Number of worker threads in the Rayon pool. `build` rejects `0`.
    pub(crate) thread_count: usize,
    /// Threshold handed to the executor.
    ///
    /// NOTE(review): presumably batches smaller than this run sequentially
    /// instead of on the pool. Confirm against `RayonBatchExecutor`.
    pub(crate) sequential_threshold: usize,
    /// Interval handed to the executor.
    ///
    /// NOTE(review): presumably the minimum time between progress reports.
    /// Confirm against `RayonBatchExecutor`.
    pub(crate) report_interval: Duration,
    /// Progress reporter shared with the executor. Defaults to a no-op reporter.
    pub(crate) reporter: Arc<dyn ProgressReporter>,
    /// Prefix for worker thread names. Threads are named `{prefix}-{index}`.
    pub(crate) thread_name_prefix: String,
    /// Optional worker-thread stack size, forwarded to Rayon.
    ///
    /// `None` keeps Rayon's default. `Some(0)` is rejected by `build`.
    pub(crate) stack_size: Option<usize>,
}
impl RayonBatchExecutorBuilder {
    /// Sets the number of worker threads in the Rayon pool.
    ///
    /// `0` is accepted here but rejected by [`build`](Self::build) with
    /// [`RayonBatchExecutorBuildError::ZeroThreadCount`].
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn thread_count(mut self, thread_count: usize) -> Self {
        self.thread_count = thread_count;
        self
    }

    /// Sets the sequential threshold handed to the built executor.
    ///
    /// NOTE(review): presumably batches below this size run sequentially
    /// rather than on the pool. Confirm against `RayonBatchExecutor`.
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn sequential_threshold(mut self, sequential_threshold: usize) -> Self {
        self.sequential_threshold = sequential_threshold;
        self
    }

    /// Sets the report interval handed to the built executor.
    ///
    /// NOTE(review): presumably the minimum time between progress reports.
    /// Confirm against `RayonBatchExecutor`.
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn report_interval(mut self, report_interval: Duration) -> Self {
        self.report_interval = report_interval;
        self
    }

    /// Installs `reporter` as the progress reporter, wrapping it in an [`Arc`].
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn reporter<R>(mut self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        self.reporter = Arc::new(reporter);
        self
    }

    /// Installs an already shared progress reporter without re-wrapping it.
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn reporter_arc(mut self, reporter: Arc<dyn ProgressReporter>) -> Self {
        self.reporter = reporter;
        self
    }

    /// Replaces the current reporter with a [`NoOpProgressReporter`].
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn no_reporter(mut self) -> Self {
        self.reporter = Arc::new(NoOpProgressReporter);
        self
    }

    /// Sets the worker thread name prefix.
    ///
    /// Each worker thread is named `{prefix}-{index}`.
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn thread_name_prefix(mut self, thread_name_prefix: impl Into<String>) -> Self {
        self.thread_name_prefix = thread_name_prefix.into();
        self
    }

    /// Sets the stack size of each worker thread.
    ///
    /// The value is forwarded to Rayon's `ThreadPoolBuilder::stack_size`.
    /// `0` is accepted here but rejected by [`build`](Self::build) with
    /// [`RayonBatchExecutorBuildError::ZeroStackSize`].
    #[inline]
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn stack_size(mut self, stack_size: usize) -> Self {
        self.stack_size = Some(stack_size);
        self
    }

    /// Validates the configuration and builds the executor with its own
    /// Rayon thread pool.
    ///
    /// # Errors
    ///
    /// - [`RayonBatchExecutorBuildError::ZeroThreadCount`] if the thread
    ///   count is `0`.
    /// - [`RayonBatchExecutorBuildError::ZeroStackSize`] if a stack size of
    ///   `0` was configured.
    /// - The converted Rayon error if the thread pool itself fails to build.
    pub fn build(self) -> Result<RayonBatchExecutor, RayonBatchExecutorBuildError> {
        // Rayon interprets `num_threads(0)` as "choose automatically", so a
        // zero count must be rejected here rather than forwarded.
        if self.thread_count == 0 {
            return Err(RayonBatchExecutorBuildError::ZeroThreadCount);
        }
        if self.stack_size == Some(0) {
            return Err(RayonBatchExecutorBuildError::ZeroStackSize);
        }
        // The naming closure must own its data, and `self` is moved into the
        // executor below, so the prefix is cloned rather than borrowed.
        let prefix = self.thread_name_prefix.clone();
        let builder = rayon::ThreadPoolBuilder::new()
            .num_threads(self.thread_count)
            .thread_name(move |index| format!("{prefix}-{index}"));
        // Only override the stack size when one was configured, so Rayon's
        // default applies otherwise.
        let builder = match self.stack_size {
            Some(stack_size) => builder.stack_size(stack_size),
            None => builder,
        };
        let pool = builder.build()?;
        Ok(RayonBatchExecutor::new_with_rayon(pool, self))
    }
}
impl Default for RayonBatchExecutorBuilder {
    /// Returns a builder holding the executor's default configuration.
    ///
    /// - The thread count comes from `RayonBatchExecutor::default_thread_count()`.
    /// - The sequential threshold and report interval use the
    ///   `RayonBatchExecutor` default constants.
    /// - Progress reporting is a no-op.
    /// - Worker threads use the `"qubit-rayon-batch"` name prefix.
    /// - No stack size is set, so Rayon's default stack size applies.
    #[inline]
    fn default() -> Self {
        let thread_count = RayonBatchExecutor::default_thread_count();
        let sequential_threshold = RayonBatchExecutor::DEFAULT_SEQUENTIAL_THRESHOLD;
        let report_interval = RayonBatchExecutor::DEFAULT_REPORT_INTERVAL;
        // Coerced to `Arc<dyn ProgressReporter>` at the struct field below.
        let reporter = Arc::new(NoOpProgressReporter);
        let thread_name_prefix = String::from(DEFAULT_THREAD_NAME_PREFIX);
        Self {
            thread_count,
            sequential_threshold,
            report_interval,
            reporter,
            thread_name_prefix,
            stack_size: None,
        }
    }
}