1use crate::{UtilsError, UtilsResult};
7use scirs2_core::numeric::Zero;
8use std::collections::VecDeque;
9use std::sync::{Arc, Mutex};
10use std::thread;
11#[allow(non_snake_case)]
12#[cfg(test)]
13use std::time::Duration;
14
15#[derive(Debug)]
17pub struct ThreadPool {
18 workers: Vec<Worker>,
19 sender: Option<std::sync::mpsc::Sender<Job>>,
20 num_threads: usize,
21}
22
23type Job = Box<dyn FnOnce() + Send + 'static>;
24
25impl ThreadPool {
26 pub fn new(num_threads: usize) -> UtilsResult<Self> {
28 if num_threads == 0 {
29 return Err(UtilsError::InvalidParameter(
30 "Thread pool size must be greater than 0".to_string(),
31 ));
32 }
33
34 let (sender, receiver) = std::sync::mpsc::channel();
35 let receiver = Arc::new(Mutex::new(receiver));
36 let mut workers = Vec::with_capacity(num_threads);
37
38 for id in 0..num_threads {
39 workers.push(Worker::new(id, Arc::clone(&receiver))?);
40 }
41
42 Ok(ThreadPool {
43 workers,
44 sender: Some(sender),
45 num_threads,
46 })
47 }
48
49 pub fn with_cpu_cores() -> UtilsResult<Self> {
51 let num_cores = num_cpus::get();
52 Self::new(num_cores)
53 }
54
55 pub fn execute<F>(&self, f: F) -> UtilsResult<()>
57 where
58 F: FnOnce() + Send + 'static,
59 {
60 let job = Box::new(f);
61 self.sender
62 .as_ref()
63 .ok_or_else(|| {
64 UtilsError::InvalidParameter("Thread pool is shutting down".to_string())
65 })?
66 .send(job)
67 .map_err(|_| {
68 UtilsError::InvalidParameter("Failed to send job to thread pool".to_string())
69 })?;
70 Ok(())
71 }
72
73 pub fn thread_count(&self) -> usize {
75 self.num_threads
76 }
77
78 pub fn join(&mut self) {
80 drop(self.sender.take());
81 for worker in &mut self.workers {
82 if let Some(thread) = worker.thread.take() {
83 thread.join().unwrap();
84 }
85 }
86 }
87}
88
89impl Drop for ThreadPool {
90 fn drop(&mut self) {
91 drop(self.sender.take());
92 for worker in &mut self.workers {
93 if let Some(thread) = worker.thread.take() {
94 thread.join().unwrap();
95 }
96 }
97 }
98}
99
100#[derive(Debug)]
101struct Worker {
102 #[allow(dead_code)]
103 id: usize,
104 thread: Option<thread::JoinHandle<()>>,
105}
106
107impl Worker {
108 fn new(id: usize, receiver: Arc<Mutex<std::sync::mpsc::Receiver<Job>>>) -> UtilsResult<Self> {
109 let thread = thread::spawn(move || loop {
110 let job = receiver.lock().unwrap().recv();
111 match job {
112 Ok(job) => {
113 job();
114 }
115 Err(_) => {
116 break;
117 }
118 }
119 });
120
121 Ok(Worker {
122 id,
123 thread: Some(thread),
124 })
125 }
126}
127
128#[derive(Debug)]
130pub struct WorkStealingQueue<T> {
131 local_queue: Arc<Mutex<VecDeque<T>>>,
132 global_queue: Arc<Mutex<VecDeque<T>>>,
133 workers: Vec<Arc<Mutex<VecDeque<T>>>>,
134 worker_id: usize,
135}
136
137impl<T> WorkStealingQueue<T>
138where
139 T: Send + 'static + Clone,
140{
141 pub fn new(num_workers: usize) -> Self {
143 let global_queue = Arc::new(Mutex::new(VecDeque::new()));
144 let mut workers = Vec::with_capacity(num_workers);
145
146 for _ in 0..num_workers {
147 workers.push(Arc::new(Mutex::new(VecDeque::new())));
148 }
149
150 Self {
151 local_queue: Arc::clone(&workers[0]),
152 global_queue,
153 workers,
154 worker_id: 0,
155 }
156 }
157
158 pub fn push_local(&self, task: T) -> UtilsResult<()> {
160 self.local_queue
161 .lock()
162 .map_err(|_| {
163 UtilsError::InvalidParameter("Failed to acquire local queue lock".to_string())
164 })?
165 .push_back(task);
166 Ok(())
167 }
168
169 pub fn push_global(&self, task: T) -> UtilsResult<()> {
171 self.global_queue
172 .lock()
173 .map_err(|_| {
174 UtilsError::InvalidParameter("Failed to acquire global queue lock".to_string())
175 })?
176 .push_back(task);
177 Ok(())
178 }
179
180 pub fn pop_local(&self) -> UtilsResult<Option<T>> {
182 Ok(self
183 .local_queue
184 .lock()
185 .map_err(|_| {
186 UtilsError::InvalidParameter("Failed to acquire local queue lock".to_string())
187 })?
188 .pop_front())
189 }
190
191 pub fn steal_work(&self) -> UtilsResult<Option<T>> {
193 for (i, worker_queue) in self.workers.iter().enumerate() {
195 if i != self.worker_id {
196 if let Ok(mut queue) = worker_queue.try_lock() {
197 if let Some(task) = queue.pop_back() {
198 return Ok(Some(task));
199 }
200 }
201 }
202 }
203
204 if let Ok(mut global) = self.global_queue.try_lock() {
206 return Ok(global.pop_front());
207 }
208
209 Ok(None)
210 }
211
212 pub fn get_task(&self) -> UtilsResult<Option<T>> {
214 if let Some(task) = self.pop_local()? {
216 return Ok(Some(task));
217 }
218
219 self.steal_work()
221 }
222}
223
224pub struct ParallelIterator<T> {
226 items: Vec<T>,
227 chunk_size: usize,
228}
229
230impl<T> ParallelIterator<T>
231where
232 T: Send + 'static + Clone,
233{
234 pub fn new(items: Vec<T>) -> Self {
236 let chunk_size = (items.len() / num_cpus::get()).max(1);
237 Self { items, chunk_size }
238 }
239
240 pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
242 self.chunk_size = chunk_size.max(1);
243 self
244 }
245
246 pub fn map<F, R>(self, f: F) -> UtilsResult<Vec<R>>
248 where
249 F: Fn(T) -> R + Send + Sync + 'static,
250 R: Send + 'static + Clone,
251 {
252 let f = Arc::new(f);
253 let results = Arc::new(Mutex::new(Vec::with_capacity(self.items.len())));
254 let _thread_pool = ThreadPool::with_cpu_cores()?;
255
256 let chunks: Vec<_> = self
258 .items
259 .into_iter()
260 .collect::<Vec<_>>()
261 .chunks(self.chunk_size)
262 .map(|chunk| chunk.to_vec())
263 .collect();
264
265 let mut handles = Vec::new();
266
267 for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
268 let f_clone = Arc::clone(&f);
269 let results_clone = Arc::clone(&results);
270 let chunk_size = chunk.len();
271
272 let handle = thread::spawn(move || {
273 let mut chunk_results = Vec::with_capacity(chunk_size);
274 for item in chunk {
275 chunk_results.push(f_clone(item));
276 }
277
278 let mut results_lock = results_clone.lock().unwrap();
279 if results_lock.len() <= chunk_idx {
281 results_lock.resize_with(chunk_idx + 1, || Vec::new());
282 }
283 results_lock[chunk_idx] = chunk_results;
284 });
285
286 handles.push(handle);
287 }
288
289 for handle in handles {
291 handle.join().map_err(|_| {
292 UtilsError::InvalidParameter(
293 "Thread panicked during parallel execution".to_string(),
294 )
295 })?;
296 }
297
298 let results_lock = results.lock().unwrap();
300 let mut final_results = Vec::new();
301 for chunk_results in results_lock.iter() {
302 final_results.extend_from_slice(chunk_results);
303 }
304
305 Ok(final_results)
306 }
307
308 pub fn filter<F>(self, predicate: F) -> UtilsResult<Vec<T>>
310 where
311 F: Fn(&T) -> bool + Send + Sync + 'static,
312 T: Clone,
313 {
314 let predicate = Arc::new(predicate);
315 let results = Arc::new(Mutex::new(Vec::new()));
316 let _thread_pool = ThreadPool::with_cpu_cores()?;
317
318 let chunks: Vec<_> = self
320 .items
321 .into_iter()
322 .collect::<Vec<_>>()
323 .chunks(self.chunk_size)
324 .map(|chunk| chunk.to_vec())
325 .collect();
326
327 let mut handles = Vec::new();
328
329 for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
330 let predicate_clone = Arc::clone(&predicate);
331 let results_clone = Arc::clone(&results);
332
333 let handle = thread::spawn(move || {
334 let filtered: Vec<T> = chunk
335 .into_iter()
336 .filter(|item| predicate_clone(item))
337 .collect();
338
339 let mut results_lock = results_clone.lock().unwrap();
340 if results_lock.len() <= chunk_idx {
342 results_lock.resize_with(chunk_idx + 1, || Vec::new());
343 }
344 results_lock[chunk_idx] = filtered;
345 });
346
347 handles.push(handle);
348 }
349
350 for handle in handles {
352 handle.join().map_err(|_| {
353 UtilsError::InvalidParameter(
354 "Thread panicked during parallel execution".to_string(),
355 )
356 })?;
357 }
358
359 let results_lock = results.lock().unwrap();
361 let mut final_results = Vec::new();
362 for chunk_results in results_lock.iter() {
363 final_results.extend_from_slice(chunk_results);
364 }
365
366 Ok(final_results)
367 }
368}
369
370pub struct ParallelReducer;
372
373impl ParallelReducer {
374 pub fn reduce<T, F>(items: Vec<T>, initial: T, op: F) -> UtilsResult<T>
376 where
377 T: Send + Sync + Clone + 'static,
378 F: Fn(T, T) -> T + Send + Sync + 'static,
379 {
380 if items.is_empty() {
381 return Ok(initial);
382 }
383
384 let op = Arc::new(op);
385 let chunk_size = (items.len() / num_cpus::get()).max(1);
386
387 let chunks: Vec<_> = items
389 .chunks(chunk_size)
390 .map(|chunk| chunk.to_vec())
391 .collect();
392
393 let mut handles = Vec::new();
394 let mut partial_results = Vec::new();
395
396 for chunk in chunks.into_iter() {
397 let op_clone = Arc::clone(&op);
398 let initial_clone = initial.clone();
399
400 let handle = thread::spawn(move || {
401 chunk
402 .into_iter()
403 .fold(initial_clone, |acc, item| op_clone(acc, item))
404 });
405
406 handles.push(handle);
407 }
408
409 for handle in handles {
411 let result = handle.join().map_err(|_| {
412 UtilsError::InvalidParameter(
413 "Thread panicked during parallel reduction".to_string(),
414 )
415 })?;
416 partial_results.push(result);
417 }
418
419 Ok(partial_results
421 .into_iter()
422 .fold(initial, |acc, partial| op(acc, partial)))
423 }
424
425 pub fn sum<T>(items: Vec<T>) -> UtilsResult<T>
427 where
428 T: Send + Sync + Clone + std::ops::Add<Output = T> + Zero + 'static,
429 {
430 Self::reduce(items, T::zero(), |a, b| a + b)
431 }
432
433 pub fn min<T>(items: Vec<T>) -> UtilsResult<Option<T>>
435 where
436 T: Send + Sync + Clone + Ord + 'static,
437 {
438 if items.is_empty() {
439 return Ok(None);
440 }
441
442 let first = items[0].clone();
443 let result = Self::reduce(items, first, |a, b| if a < b { a } else { b })?;
444 Ok(Some(result))
445 }
446
447 pub fn max<T>(items: Vec<T>) -> UtilsResult<Option<T>>
449 where
450 T: Send + Sync + Clone + Ord + 'static,
451 {
452 if items.is_empty() {
453 return Ok(None);
454 }
455
456 let first = items[0].clone();
457 let result = Self::reduce(items, first, |a, b| if a > b { a } else { b })?;
458 Ok(Some(result))
459 }
460}
461
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A pool reports the thread count it was created with.
    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4).unwrap();
        assert_eq!(pool.thread_count(), 4);
    }

    /// Every submitted job runs exactly once.
    #[test]
    fn test_thread_pool_execution() {
        let pool = ThreadPool::new(2).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));
        let (done_tx, done_rx) = std::sync::mpsc::channel();

        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            let done_tx = done_tx.clone();
            pool.execute(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                let _ = done_tx.send(());
            })
            .unwrap();
        }

        // Wait for each job's completion signal instead of sleeping for a
        // fixed time; the timeout only bounds the wait if a job never runs.
        for _ in 0..10 {
            done_rx
                .recv_timeout(Duration::from_secs(5))
                .expect("job did not complete in time");
        }

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    /// Local tasks are served before global ones, and an empty queue yields `None`.
    #[test]
    fn test_work_stealing_queue() {
        let queue = WorkStealingQueue::new(4);

        queue.push_local(42).unwrap();
        queue.push_global(24).unwrap();

        assert_eq!(queue.get_task().unwrap(), Some(42));
        assert_eq!(queue.get_task().unwrap(), Some(24));
        assert_eq!(queue.get_task().unwrap(), None);
    }

    /// `map` returns results in input order.
    #[test]
    fn test_parallel_iterator_map() {
        let items = vec![1, 2, 3, 4, 5];
        let iter = ParallelIterator::new(items);

        let results = iter.map(|x| x * 2).unwrap();
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }

    /// `filter` keeps matching items in input order.
    #[test]
    fn test_parallel_iterator_filter() {
        let items = vec![1, 2, 3, 4, 5, 6];
        let iter = ParallelIterator::new(items);

        let results = iter.filter(|&x| x % 2 == 0).unwrap();
        assert_eq!(results, vec![2, 4, 6]);
    }

    #[test]
    fn test_parallel_reducer_sum() {
        let items = vec![1, 2, 3, 4, 5];
        let result = ParallelReducer::sum(items).unwrap();
        assert_eq!(result, 15);
    }

    #[test]
    fn test_parallel_reducer_min_max() {
        let items = vec![5, 2, 8, 1, 9, 3];

        let min_result = ParallelReducer::min(items.clone()).unwrap();
        assert_eq!(min_result, Some(1));

        let max_result = ParallelReducer::max(items).unwrap();
        assert_eq!(max_result, Some(9));
    }

    /// `min` and `max` return `None` for an empty input.
    #[test]
    fn test_parallel_reducer_empty() {
        let items: Vec<i32> = vec![];

        let min_result = ParallelReducer::min(items.clone()).unwrap();
        assert_eq!(min_result, None);

        let max_result = ParallelReducer::max(items).unwrap();
        assert_eq!(max_result, None);
    }

    #[test]
    fn test_thread_pool_with_cpu_cores() {
        let pool = ThreadPool::with_cpu_cores().unwrap();
        assert!(pool.thread_count() > 0);
        assert!(pool.thread_count() <= num_cpus::get());
    }
}