laddu_core/execution/
thread_pool.rs1#[cfg(feature = "rayon")]
4use std::sync::Arc;
5use std::sync::{
6 atomic::{AtomicUsize, Ordering},
7 OnceLock,
8};
9
10#[cfg(feature = "rayon")]
11use parking_lot::RwLock;
12
13#[cfg(feature = "rayon")]
14use crate::LadduError;
15use crate::LadduResult;
16
/// Process-wide default thread count. A stored value of `0` means "no default has been
/// set" (it is normalized to `None` when read back through `ThreadPoolManager`).
static GLOBAL_THREAD_COUNT: AtomicUsize = AtomicUsize::new(0);
18
/// Strategy for running a unit of (possibly parallel) work.
#[derive(Debug, Clone, Default)]
pub(crate) enum ThreadExecutor {
    /// Run the operation directly on the calling thread, without installing any pool.
    /// Parallel work inside it therefore uses whichever rayon pool is already current.
    #[default]
    Ambient,
    /// Run the operation inside an owned rayon pool. The pool sits behind an `Arc`, so
    /// cloning the executor is cheap and every clone refers to the same pool.
    #[cfg(feature = "rayon")]
    Dedicated(Arc<rayon::ThreadPool>),
}
30
31impl ThreadExecutor {
32 #[cfg(feature = "rayon")]
34 pub(crate) fn dedicated(n_threads: usize) -> LadduResult<Self> {
35 if n_threads == 0 {
36 return Err(LadduError::ExecutionContextError {
37 reason: "Dedicated thread pool size must be >= 1".into(),
38 });
39 }
40
41 Ok(Self::Dedicated(Arc::new(
42 rayon::ThreadPoolBuilder::new()
43 .num_threads(n_threads)
44 .build()?,
45 )))
46 }
47
48 #[cfg(feature = "rayon")]
50 pub(crate) fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
51 match self {
52 Self::Ambient => op(),
53 Self::Dedicated(pool) => pool.install(op),
54 }
55 }
56
57 #[allow(dead_code)]
59 #[cfg(not(feature = "rayon"))]
60 pub(crate) fn install<R>(&self, op: impl FnOnce() -> R) -> R {
61 op()
62 }
63}
64
/// Decides which executor runs a piece of work, based on a per-call thread request and a
/// process-wide default.
#[derive(Debug, Default)]
pub struct ThreadPoolManager {
    /// The most recently built dedicated executor, tagged with its thread count. Only one
    /// pool is cached at a time; a request for a different size replaces it.
    #[cfg(feature = "rayon")]
    pub(crate) dedicated_pool: RwLock<Option<(usize, ThreadExecutor)>>,
}
76
77impl ThreadPoolManager {
78 pub fn shared() -> &'static Self {
80 static THREAD_POOL_MANAGER: OnceLock<ThreadPoolManager> = OnceLock::new();
81 THREAD_POOL_MANAGER.get_or_init(Self::default)
82 }
83
84 pub fn set_global_thread_count(n_threads: usize) {
88 GLOBAL_THREAD_COUNT.store(n_threads, Ordering::Relaxed);
89 }
90
91 pub fn global_thread_count() -> Option<usize> {
95 Self::normalize_thread_request(Some(GLOBAL_THREAD_COUNT.load(Ordering::Relaxed)))
96 }
97
98 pub fn resolve_thread_request(requested_threads: Option<usize>) -> Option<usize> {
103 match requested_threads {
104 None | Some(0) => Self::global_thread_count(),
105 Some(n_threads) => Some(n_threads),
106 }
107 }
108
109 #[cfg(feature = "rayon")]
114 pub fn install<R: Send>(
115 &self,
116 requested_threads: Option<usize>,
117 op: impl FnOnce() -> R + Send,
118 ) -> LadduResult<R> {
119 match Self::resolve_thread_request(requested_threads) {
120 Some(n_threads) => Ok(self.executor_for_threads(n_threads)?.install(op)),
121 None => Ok(ThreadExecutor::default().install(op)),
122 }
123 }
124
125 #[cfg(not(feature = "rayon"))]
130 pub fn install<R>(
131 &self,
132 _requested_threads: Option<usize>,
133 op: impl FnOnce() -> R,
134 ) -> LadduResult<R> {
135 Ok(op())
136 }
137
138 fn normalize_thread_request(requested_threads: Option<usize>) -> Option<usize> {
139 requested_threads.filter(|&n_threads| n_threads > 0)
140 }
141
142 #[cfg(feature = "rayon")]
143 pub(crate) fn executor_for_threads(&self, n_threads: usize) -> LadduResult<ThreadExecutor> {
144 if let Some((cached_threads, executor)) = &*self.dedicated_pool.read() {
145 if *cached_threads == n_threads {
146 return Ok(executor.clone());
147 }
148 }
149
150 let executor = ThreadExecutor::dedicated(n_threads)?;
151 let mut dedicated_pool = self.dedicated_pool.write();
152 *dedicated_pool = Some((n_threads, executor.clone()));
153 Ok(executor)
154 }
155}