scirs2_ndimage/threading.rs

use scirs2_core::parallel_ops::*;
use std::sync::{Arc, Mutex, OnceLock};

/// Global, lazily initialized thread pool configuration shared across the crate.
static THREAD_POOL_CONFIG: OnceLock<Arc<Mutex<ThreadPoolConfig>>> = OnceLock::new();

/// Configuration for the worker thread pool.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// Number of worker threads; `None` uses the parallel backend's default.
    pub num_threads: Option<usize>,
    /// Stack size per worker thread, in bytes.
    pub stack_size: Option<usize>,
    /// Prefix used when naming worker threads.
    pub thread_name_prefix: String,
    /// Whether to pin worker threads to CPU cores.
    pub pin_threads: bool,
}

impl Default for ThreadPoolConfig {
    fn default() -> Self {
        Self {
            num_threads: None,
            stack_size: Some(8 * 1024 * 1024),
            thread_name_prefix: "scirs2-worker".to_string(),
            pin_threads: false,
        }
    }
}

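/// Initializes the global thread pool configuration. Fails if the pool was
/// already initialized.
///
/// A minimal usage sketch (marked `ignore` because it touches process-global
/// state):
///
/// ```ignore
/// init_thread_pool(ThreadPoolConfig {
///     num_threads: Some(8),
///     ..Default::default()
/// })
/// .expect("thread pool is configured once at startup");
/// ```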
#[allow(dead_code)]
pub fn init_thread_pool(config: ThreadPoolConfig) -> Result<(), String> {
    THREAD_POOL_CONFIG
        .set(Arc::new(Mutex::new(config)))
        .map_err(|_| "Thread pool already initialized".to_string())
}

/// Returns the current global configuration, or the default if the pool was
/// never initialized.
#[allow(dead_code)]
pub fn get_thread_pool_config() -> ThreadPoolConfig {
    THREAD_POOL_CONFIG
        .get()
        .map(|config| config.lock().unwrap().clone())
        .unwrap_or_default()
}

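/// Applies `update_fn` to the live configuration. Errors if the pool was
/// never initialized.
///
/// A minimal sketch, assuming `init_thread_pool` succeeded earlier:
///
/// ```ignore
/// update_thread_pool_config(|cfg| cfg.num_threads = Some(16))
///     .expect("pool was initialized at startup");
/// ```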
#[allow(dead_code)]
pub fn update_thread_pool_config<F>(update_fn: F) -> Result<(), String>
where
    F: FnOnce(&mut ThreadPoolConfig),
{
    if let Some(config) = THREAD_POOL_CONFIG.get() {
        let mut config = config.lock().unwrap();
        update_fn(&mut *config);
        Ok(())
    } else {
        Err("Thread pool not initialized".to_string())
    }
}

/// Per-thread metadata describing a worker's place in the pool.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    /// Index of this worker within the pool.
    pub thread_id: usize,
    /// Total number of workers in the pool.
    pub num_workers: usize,
    /// CPU core this worker is pinned to, if any.
    pub cpu_affinity: Option<usize>,
}

thread_local! {
    static WORKER_INFO: std::cell::RefCell<Option<WorkerInfo>> = const { std::cell::RefCell::new(None) };
}

/// Returns the worker metadata recorded for the current thread, if any.
#[allow(dead_code)]
pub fn current_worker_info() -> Option<WorkerInfo> {
    WORKER_INFO.with(|info| info.borrow().clone())
}

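/// Records `info` in the current thread's thread-local slot.
///
/// A sketch of how a pool could tag each worker at spawn time (the values
/// shown are illustrative):
///
/// ```ignore
/// set_worker_info(WorkerInfo {
///     thread_id: 0,
///     num_workers: 8,
///     cpu_affinity: None,
/// });
/// assert_eq!(current_worker_info().unwrap().thread_id, 0);
/// ```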
#[allow(dead_code)]
pub fn set_worker_info(info: WorkerInfo) {
    WORKER_INFO.with(|cell| {
        *cell.borrow_mut() = Some(info);
    });
}

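/// Extension methods for parallel iterators that need pool-level control.
///
/// Intended usage once implemented for a backend iterator (`data` and
/// `process` are placeholders):
///
/// ```ignore
/// data.par_iter()
///     .with_threads(4)
///     .with_thread_init(|| println!("worker starting"))
///     .for_each(|x| process(x));
/// ```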
#[allow(dead_code)]
pub trait ParallelIteratorExt: ParallelIterator {
    /// Limits this iterator to at most `num_threads` worker threads.
    fn with_threads(self, num_threads: usize) -> Self;

    /// Runs `init` once on each worker thread before it processes items.
    fn with_thread_init<F>(self, init: F) -> Self
    where
        F: Fn() + Send + Sync + 'static;
}

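/// Parallel helpers for mutating array data in place.
///
/// Intended usage once implemented for an array type (`image` is a
/// placeholder):
///
/// ```ignore
/// image.par_map_inplace(|px| *px = px.clamp(0.0, 1.0));
/// image.par_chunks_mut(4096, |row| row.iter_mut().for_each(|px| *px *= 0.5));
/// ```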
#[allow(dead_code)]
pub trait ThreadPoolArrayExt<T, D> {
    /// Applies `f` to every element in parallel, mutating in place.
    fn par_map_inplace<F>(&mut self, f: F)
    where
        F: Fn(&mut T) + Send + Sync;

    /// Splits the data into mutable chunks of `chunk_size` elements and
    /// applies `f` to each chunk in parallel.
    fn par_chunks_mut<F>(&mut self, chunk_size: usize, f: F)
    where
        F: Fn(&mut [T]) + Send + Sync;
}

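/// Captures the global configuration for a batch of parallel operations.
///
/// A minimal sketch:
///
/// ```ignore
/// let ctx = ThreadPoolContext::new();
/// let sum = ctx.execute_parallel(|| (0..1_000u64).sum::<u64>());
/// let sum4 = ctx.execute_with_threads(4, || (0..1_000u64).sum::<u64>());
/// assert_eq!(sum, sum4);
/// ```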
#[allow(dead_code)]
pub struct ThreadPoolContext {
    config: ThreadPoolConfig,
}

impl ThreadPoolContext {
    /// Creates a context that snapshots the current global configuration.
    pub fn new() -> Self {
        Self {
            config: get_thread_pool_config(),
        }
    }

    /// Executes `operation` within this context. It currently runs on the
    /// caller's thread; pool scoping is delegated to the parallel backend.
    pub fn execute_parallel<F, R>(&self, operation: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        operation()
    }

    /// Executes `operation` with a requested thread count. The request is
    /// currently advisory: the backend's pool size is not changed per call.
    pub fn execute_with_threads<F, R>(&self, _num_threads: usize, operation: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        // Snapshot the backend's thread count; kept for future restore logic.
        let _prev_threads = num_threads();
        operation()
    }
}

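/// Tracks a target thread count that grows under high load and shrinks under
/// low load, clamped to `[min_threads, max_threads]`.
///
/// A minimal sketch of the feedback loop (load values are illustrative and
/// would normally come from runtime measurements):
///
/// ```ignore
/// let pool = AdaptiveThreadPool::new(2, 8);
/// pool.adjust_threads(0.9); // above the threshold: grow by one
/// pool.adjust_threads(0.2); // below half the threshold: shrink by one
/// assert_eq!(pool.current_thread_count(), 2);
/// ```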
#[allow(dead_code)]
pub struct AdaptiveThreadPool {
    min_threads: usize,
    max_threads: usize,
    current_threads: Arc<Mutex<usize>>,
    load_threshold: f64,
}

impl AdaptiveThreadPool {
    pub fn new(min_threads: usize, max_threads: usize) -> Self {
        Self {
            min_threads,
            max_threads,
            current_threads: Arc::new(Mutex::new(min_threads)),
            load_threshold: 0.8,
        }
    }

    /// Grows the target count by one when load exceeds the threshold, and
    /// shrinks it by one when load drops below half the threshold.
    pub fn adjust_threads(&self, current_load: f64) {
        let mut threads = self.current_threads.lock().unwrap();

        if current_load > self.load_threshold && *threads < self.max_threads {
            *threads = (*threads + 1).min(self.max_threads);
        } else if current_load < self.load_threshold * 0.5 && *threads > self.min_threads {
            *threads = (*threads - 1).max(self.min_threads);
        }
    }

    /// Returns the current target thread count.
    pub fn current_thread_count(&self) -> usize {
        *self.current_threads.lock().unwrap()
    }
}

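/// A set of per-worker LIFO queues where a `pop` on an empty queue steals
/// work from another queue.
///
/// A minimal sketch with two queues:
///
/// ```ignore
/// let queue: WorkStealingQueue<i32> = WorkStealingQueue::new(2);
/// queue.push(0, 1);
/// queue.push(0, 2);
/// assert_eq!(queue.pop(0), Some(2)); // local LIFO pop
/// assert_eq!(queue.pop(1), Some(1)); // steals from queue 0
/// ```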
#[allow(dead_code)]
pub struct WorkStealingQueue<T> {
    queues: Vec<Arc<Mutex<Vec<T>>>>,
}

impl<T: Send> WorkStealingQueue<T> {
    pub fn new(num_queues: usize) -> Self {
        let queues = (0..num_queues)
            .map(|_| Arc::new(Mutex::new(Vec::new())))
            .collect();

        Self { queues }
    }

    /// Pushes `item` onto the queue at `queue_id`; out-of-range ids are a no-op.
    pub fn push(&self, queue_id: usize, item: T) {
        if let Some(queue) = self.queues.get(queue_id) {
            queue.lock().unwrap().push(item);
        }
    }

    /// Pops from this worker's own queue first (LIFO); if it is empty, tries
    /// to steal an item from another queue.
    pub fn pop(&self, queue_id: usize) -> Option<T> {
        // Fast path: take the most recently pushed local item.
        if let Some(queue) = self.queues.get(queue_id) {
            if let Some(item) = queue.lock().unwrap().pop() {
                return Some(item);
            }
        }

        // Local queue empty: steal from the first other queue that has work.
        for (i, queue) in self.queues.iter().enumerate() {
            if i != queue_id {
                if let Some(item) = queue.lock().unwrap().pop() {
                    return Some(item);
                }
            }
        }

        None
    }
}

/// Applies the global configuration to the core parallel backend.
#[allow(dead_code)]
pub fn configure_parallel_ops() {
    let config = get_thread_pool_config();

    if let Some(num_threads) = config.num_threads {
        // Placeholder: forwarding the configured thread count to
        // scirs2_core::parallel_ops is not wired up yet.
        let _ = num_threads;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_pool_config() {
        let config = ThreadPoolConfig {
            num_threads: Some(4),
            stack_size: Some(4 * 1024 * 1024),
            thread_name_prefix: "test-worker".to_string(),
            pin_threads: true,
        };

        assert_eq!(config.num_threads, Some(4));
        assert_eq!(config.stack_size, Some(4 * 1024 * 1024));
        assert_eq!(config.thread_name_prefix, "test-worker");
        assert!(config.pin_threads);
    }

    #[test]
    fn test_adaptive_thread_pool() {
        let pool = AdaptiveThreadPool::new(2, 8);

        assert_eq!(pool.current_thread_count(), 2);

        // High load above the 0.8 threshold grows the pool by one thread.
        pool.adjust_threads(0.9);
        assert_eq!(pool.current_thread_count(), 3);

        // Load below half the threshold shrinks it back down.
        pool.adjust_threads(0.3);
        assert_eq!(pool.current_thread_count(), 2);
    }

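    // A small added test: round-trips the thread-local worker metadata.
    // (Assumes no other code has set worker info on this test thread.)
    #[test]
    fn test_worker_info_roundtrip() {
        set_worker_info(WorkerInfo {
            thread_id: 1,
            num_workers: 4,
            cpu_affinity: None,
        });

        let info = current_worker_info().expect("info was just set on this thread");
        assert_eq!(info.thread_id, 1);
        assert_eq!(info.num_workers, 4);
    }
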
    #[test]
    fn test_work_stealing_queue() {
        let queue: WorkStealingQueue<i32> = WorkStealingQueue::new(2);

        queue.push(0, 1);
        queue.push(0, 2);

        // Local pop is LIFO: the most recent push comes back first.
        assert_eq!(queue.pop(0), Some(2));

        // Queue 1 is empty, so it steals the remaining item from queue 0.
        assert_eq!(queue.pop(1), Some(1));

        // Both queues are now drained.
        assert_eq!(queue.pop(0), None);
        assert_eq!(queue.pop(1), None);
    }
}