#[cfg(feature = "rayon")]
use std::sync::Arc;
use std::sync::{
atomic::{AtomicUsize, Ordering},
OnceLock,
};
#[cfg(feature = "rayon")]
use parking_lot::RwLock;
#[cfg(feature = "rayon")]
use crate::LadduError;
use crate::LadduResult;
/// Process-wide default thread count, written by
/// `ThreadPoolManager::set_global_thread_count`.
///
/// A stored `0` means "no default configured": `ThreadPoolManager::global_thread_count`
/// filters it out and reports `None`.
static GLOBAL_THREAD_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Decides where a closure is executed: directly in the caller's current
/// context, or inside a dedicated rayon thread pool.
#[derive(Debug, Clone, Default)]
pub(crate) enum ThreadExecutor {
    /// Invoke the closure as-is, without switching to any pool.
    #[default]
    Ambient,
    /// Invoke the closure through the wrapped rayon pool. The pool sits behind
    /// an `Arc`, so cloning the executor shares the same pool.
    #[cfg(feature = "rayon")]
    Dedicated(Arc<rayon::ThreadPool>),
}

impl ThreadExecutor {
    /// Creates an executor backed by a freshly built rayon pool configured
    /// with `n_threads` threads.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::ExecutionContextError`] when `n_threads` is zero,
    /// and propagates any error produced by the rayon pool builder.
    #[cfg(feature = "rayon")]
    pub(crate) fn dedicated(n_threads: usize) -> LadduResult<Self> {
        match n_threads {
            0 => Err(LadduError::ExecutionContextError {
                reason: "Dedicated thread pool size must be >= 1".into(),
            }),
            workers => {
                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(workers)
                    .build()?;
                Ok(Self::Dedicated(Arc::new(pool)))
            }
        }
    }

    /// Runs `op` on this executor and hands back its result.
    ///
    /// `Dedicated` forwards to the pool's `install`; `Ambient` simply calls
    /// `op` in place.
    #[cfg(feature = "rayon")]
    pub(crate) fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
        match self {
            Self::Dedicated(pool) => pool.install(op),
            Self::Ambient => op(),
        }
    }

    /// Rayon-free counterpart of `install`: there is no pool to switch to, so
    /// `op` is called in place and its result returned.
    // NOTE(review): the `dead_code` allowance suggests some rayon-free builds
    // never call this — confirm before removing it.
    #[allow(dead_code)]
    #[cfg(not(feature = "rayon"))]
    pub(crate) fn install<R>(&self, op: impl FnOnce() -> R) -> R {
        op()
    }
}
/// Hands out thread executors and exposes the process-wide default thread count.
///
/// With the `rayon` feature enabled the manager keeps the most recently built
/// dedicated pool together with its size, so consecutive requests for the same
/// thread count reuse one pool instead of spawning a new one each time.
#[derive(Debug, Default)]
pub struct ThreadPoolManager {
    /// Single-slot cache holding `(thread count, executor)` for the last
    /// dedicated pool that was built; `None` until the first pool is created.
    #[cfg(feature = "rayon")]
    pub(crate) dedicated_pool: RwLock<Option<(usize, ThreadExecutor)>>,
}

impl ThreadPoolManager {
    /// Returns the process-wide manager, creating it on first use.
    pub fn shared() -> &'static Self {
        static THREAD_POOL_MANAGER: OnceLock<ThreadPoolManager> = OnceLock::new();
        THREAD_POOL_MANAGER.get_or_init(Self::default)
    }

    /// Sets the process-wide default thread count.
    ///
    /// Passing `0` clears the default: [`Self::global_thread_count`] will
    /// report `None` afterwards.
    pub fn set_global_thread_count(n_threads: usize) {
        GLOBAL_THREAD_COUNT.store(n_threads, Ordering::Relaxed);
    }

    /// Returns the process-wide default thread count, or `None` when no
    /// positive default is currently stored.
    pub fn global_thread_count() -> Option<usize> {
        Self::normalize_thread_request(Some(GLOBAL_THREAD_COUNT.load(Ordering::Relaxed)))
    }

    /// Resolves a per-call thread request against the global default.
    ///
    /// A positive explicit request wins. `None` and `Some(0)` both fall back
    /// to [`Self::global_thread_count`], which may itself be `None`.
    pub fn resolve_thread_request(requested_threads: Option<usize>) -> Option<usize> {
        // Reuse the shared "zero means unset" rule rather than re-encoding it here.
        Self::normalize_thread_request(requested_threads).or_else(Self::global_thread_count)
    }

    /// Runs `op` on the executor selected by the resolved thread request and
    /// returns its result.
    ///
    /// When the request resolves to a positive count, `op` runs inside the
    /// dedicated pool of that size; otherwise it runs on the default
    /// (ambient) executor.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the dedicated pool.
    #[cfg(feature = "rayon")]
    pub fn install<R: Send>(
        &self,
        requested_threads: Option<usize>,
        op: impl FnOnce() -> R + Send,
    ) -> LadduResult<R> {
        let executor = match Self::resolve_thread_request(requested_threads) {
            Some(n_threads) => self.executor_for_threads(n_threads)?,
            None => ThreadExecutor::default(),
        };
        Ok(executor.install(op))
    }

    /// Rayon-free variant: the thread request is ignored and `op` is called
    /// in place. Always returns `Ok`.
    #[cfg(not(feature = "rayon"))]
    pub fn install<R>(
        &self,
        _requested_threads: Option<usize>,
        op: impl FnOnce() -> R,
    ) -> LadduResult<R> {
        Ok(op())
    }

    /// Maps a zero thread count to `None`; positive counts pass through unchanged.
    fn normalize_thread_request(requested_threads: Option<usize>) -> Option<usize> {
        requested_threads.filter(|&n_threads| n_threads > 0)
    }

    /// Returns a clone of the cached executor when the slot holds a pool of
    /// exactly `n_threads` threads.
    #[cfg(feature = "rayon")]
    fn cached_executor(
        slot: &Option<(usize, ThreadExecutor)>,
        n_threads: usize,
    ) -> Option<ThreadExecutor> {
        slot.as_ref()
            .filter(|(cached_threads, _)| *cached_threads == n_threads)
            .map(|(_, executor)| executor.clone())
    }

    /// Returns an executor backed by a dedicated pool of `n_threads` threads,
    /// reusing the cached pool when its size matches and replacing it otherwise.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`ThreadExecutor::dedicated`] (for example
    /// when `n_threads` is zero). The cache is left untouched in that case.
    #[cfg(feature = "rayon")]
    pub(crate) fn executor_for_threads(&self, n_threads: usize) -> LadduResult<ThreadExecutor> {
        // Fast path under the shared lock. The read guard is a temporary of
        // this statement, so it is released before the write lock is requested.
        let cached = Self::cached_executor(&self.dedicated_pool.read(), n_threads);
        if let Some(executor) = cached {
            return Ok(executor);
        }

        let mut slot = self.dedicated_pool.write();
        // Another caller may have filled the slot between the read lock being
        // released and the write lock being granted. Checking again here keeps
        // concurrent requests for the same size from each spawning a pool.
        if let Some(executor) = Self::cached_executor(&slot, n_threads) {
            return Ok(executor);
        }

        let executor = ThreadExecutor::dedicated(n_threads)?;
        *slot = Some((n_threads, executor.clone()));
        Ok(executor)
    }
}