sklears_utils/parallel.rs

//! Parallel computing utilities for machine learning workloads
//!
//! This module provides utilities for parallel processing, including thread pool
//! management, work-stealing algorithms, and parallel iterator utilities.
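//!
//! A quick-start sketch (marked `ignore` because the import path assumes this
//! module is exposed publicly as `sklears_utils::parallel`):
//!
//! ```ignore
//! use sklears_utils::parallel::{ParallelIterator, ParallelReducer};
//!
//! let doubled = ParallelIterator::new(vec![1, 2, 3, 4]).map(|x| x * 2).unwrap();
//! assert_eq!(doubled, vec![2, 4, 6, 8]);
//!
//! let total = ParallelReducer::sum(doubled).unwrap();
//! assert_eq!(total, 20);
//! ```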

use crate::{UtilsError, UtilsResult};
use scirs2_core::numeric::Zero;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::thread;

/// Thread pool for parallel task execution
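///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` because the import path assumes this
/// module is exposed as `sklears_utils::parallel`):
///
/// ```ignore
/// use std::sync::atomic::{AtomicUsize, Ordering};
/// use std::sync::Arc;
/// use sklears_utils::parallel::ThreadPool;
///
/// let mut pool = ThreadPool::new(4).unwrap();
/// let counter = Arc::new(AtomicUsize::new(0));
/// for _ in 0..8 {
///     let counter = Arc::clone(&counter);
///     pool.execute(move || {
///         counter.fetch_add(1, Ordering::SeqCst);
///     })
///     .unwrap();
/// }
/// // `join` closes the job channel and waits for the workers to finish.
/// pool.join();
/// assert_eq!(counter.load(Ordering::SeqCst), 8);
/// ```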
#[derive(Debug)]
pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<std::sync::mpsc::Sender<Job>>,
    num_threads: usize,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

impl ThreadPool {
    /// Create a new thread pool with the specified number of threads
    pub fn new(num_threads: usize) -> UtilsResult<Self> {
        if num_threads == 0 {
            return Err(UtilsError::InvalidParameter(
                "Thread pool size must be greater than 0".to_string(),
            ));
        }

        let (sender, receiver) = std::sync::mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(num_threads);

        for id in 0..num_threads {
            workers.push(Worker::new(id, Arc::clone(&receiver))?);
        }

        Ok(ThreadPool {
            workers,
            sender: Some(sender),
            num_threads,
        })
    }

    /// Create a thread pool with one thread per available CPU core
    pub fn with_cpu_cores() -> UtilsResult<Self> {
        let num_cores = num_cpus::get();
        Self::new(num_cores)
    }

    /// Submit a job to the thread pool
    pub fn execute<F>(&self, f: F) -> UtilsResult<()>
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender
            .as_ref()
            .ok_or_else(|| {
                UtilsError::InvalidParameter("Thread pool is shutting down".to_string())
            })?
            .send(job)
            .map_err(|_| {
                UtilsError::InvalidParameter("Failed to send job to thread pool".to_string())
            })?;
        Ok(())
    }

    /// Get the number of worker threads
    pub fn thread_count(&self) -> usize {
        self.num_threads
    }

    /// Shut down the job channel and wait for all submitted jobs to complete
    ///
    /// After calling `join`, no further jobs can be submitted to this pool.
    pub fn join(&mut self) {
        // Dropping the sender closes the channel, so each worker exits once it
        // has finished its current job.
        drop(self.sender.take());
        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.take());
        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

#[derive(Debug)]
struct Worker {
    #[allow(dead_code)]
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<std::sync::mpsc::Receiver<Job>>>) -> UtilsResult<Self> {
        let thread = thread::spawn(move || loop {
            // The temporary lock guard is dropped at the end of this statement,
            // so the mutex is released before the job runs.
            let job = receiver.lock().unwrap().recv();
            match job {
                Ok(job) => {
                    job();
                }
                Err(_) => {
                    // The sender was dropped; no more jobs will arrive.
                    break;
                }
            }
        });

        Ok(Worker {
            id,
            thread: Some(thread),
        })
    }
}

/// Work-stealing deque for load balancing
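///
/// # Examples
///
/// A minimal sketch of pushing and draining tasks (marked `ignore`; the path
/// assumes this module is public as `sklears_utils::parallel`):
///
/// ```ignore
/// use sklears_utils::parallel::WorkStealingQueue;
///
/// let queue: WorkStealingQueue<u32> = WorkStealingQueue::new(4);
/// queue.push_local(1).unwrap();
/// queue.push_global(2).unwrap();
///
/// // Local work is taken first, then work is stolen from other workers'
/// // queues, and finally from the global queue.
/// assert_eq!(queue.get_task().unwrap(), Some(1));
/// assert_eq!(queue.get_task().unwrap(), Some(2));
/// assert_eq!(queue.get_task().unwrap(), None);
/// ```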
#[derive(Debug)]
pub struct WorkStealingQueue<T> {
    local_queue: Arc<Mutex<VecDeque<T>>>,
    global_queue: Arc<Mutex<VecDeque<T>>>,
    workers: Vec<Arc<Mutex<VecDeque<T>>>>,
    worker_id: usize,
}

impl<T> WorkStealingQueue<T>
where
    T: Send + 'static + Clone,
{
    /// Create a new work-stealing queue system
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is 0, since the first worker queue is used as
    /// this handle's local queue.
    pub fn new(num_workers: usize) -> Self {
        let global_queue = Arc::new(Mutex::new(VecDeque::new()));
        let mut workers = Vec::with_capacity(num_workers);

        for _ in 0..num_workers {
            workers.push(Arc::new(Mutex::new(VecDeque::new())));
        }

        Self {
            local_queue: Arc::clone(&workers[0]),
            global_queue,
            workers,
            worker_id: 0,
        }
    }

    /// Push a task to the local queue
    pub fn push_local(&self, task: T) -> UtilsResult<()> {
        self.local_queue
            .lock()
            .map_err(|_| {
                UtilsError::InvalidParameter("Failed to acquire local queue lock".to_string())
            })?
            .push_back(task);
        Ok(())
    }

    /// Push a task to the global queue
    pub fn push_global(&self, task: T) -> UtilsResult<()> {
        self.global_queue
            .lock()
            .map_err(|_| {
                UtilsError::InvalidParameter("Failed to acquire global queue lock".to_string())
            })?
            .push_back(task);
        Ok(())
    }

    /// Pop a task from the local queue
    pub fn pop_local(&self) -> UtilsResult<Option<T>> {
        Ok(self
            .local_queue
            .lock()
            .map_err(|_| {
                UtilsError::InvalidParameter("Failed to acquire local queue lock".to_string())
            })?
            .pop_front())
    }

    /// Steal work from other workers' queues
    pub fn steal_work(&self) -> UtilsResult<Option<T>> {
        // Try to steal from other workers' queues
        for (i, worker_queue) in self.workers.iter().enumerate() {
            if i != self.worker_id {
                if let Ok(mut queue) = worker_queue.try_lock() {
                    if let Some(task) = queue.pop_back() {
                        return Ok(Some(task));
                    }
                }
            }
        }

        // If no work was stolen, try the global queue
        if let Ok(mut global) = self.global_queue.try_lock() {
            return Ok(global.pop_front());
        }

        Ok(None)
    }

    /// Get the next task, trying the local queue first, then stealing
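    ///
    /// A sketch of a drain loop (marked `ignore`; assumes the queue was
    /// populated elsewhere, and `process` is a hypothetical task handler):
    ///
    /// ```ignore
    /// while let Some(task) = queue.get_task()? {
    ///     process(task); // hypothetical handler, not part of this module
    /// }
    /// ```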
    pub fn get_task(&self) -> UtilsResult<Option<T>> {
        // Try local queue first
        if let Some(task) = self.pop_local()? {
            return Ok(Some(task));
        }

        // If local queue is empty, try stealing
        self.steal_work()
    }
}

/// Parallel iterator-style utility for mapping and filtering a collection in chunks
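///
/// # Examples
///
/// A minimal sketch (marked `ignore`; the path assumes this module is public
/// as `sklears_utils::parallel`):
///
/// ```ignore
/// use sklears_utils::parallel::ParallelIterator;
///
/// let squares = ParallelIterator::new(vec![1, 2, 3, 4])
///     .with_chunk_size(2)
///     .map(|x| x * x)
///     .unwrap();
/// assert_eq!(squares, vec![1, 4, 9, 16]);
///
/// let evens = ParallelIterator::new(vec![1, 2, 3, 4, 5, 6])
///     .filter(|&x| x % 2 == 0)
///     .unwrap();
/// assert_eq!(evens, vec![2, 4, 6]);
/// ```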
pub struct ParallelIterator<T> {
    items: Vec<T>,
    chunk_size: usize,
}

impl<T> ParallelIterator<T>
where
    T: Send + 'static + Clone,
{
    /// Create a new parallel iterator
    pub fn new(items: Vec<T>) -> Self {
        let chunk_size = (items.len() / num_cpus::get()).max(1);
        Self { items, chunk_size }
    }

    /// Set the chunk size for parallel processing
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size.max(1);
        self
    }

    /// Map a function over the items in parallel; the results preserve the input order
    pub fn map<F, R>(self, f: F) -> UtilsResult<Vec<R>>
    where
        F: Fn(T) -> R + Send + Sync + 'static,
        R: Send + 'static + Clone,
    {
        let f = Arc::new(f);
        // Per-chunk results, indexed by chunk position so the output preserves input order
        let results: Arc<Mutex<Vec<Vec<R>>>> = Arc::new(Mutex::new(Vec::new()));

        // Split items into chunks
        let chunks: Vec<Vec<T>> = self
            .items
            .chunks(self.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let mut handles = Vec::new();

        for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
            let f_clone = Arc::clone(&f);
            let results_clone = Arc::clone(&results);
            let chunk_size = chunk.len();

            let handle = thread::spawn(move || {
                let mut chunk_results = Vec::with_capacity(chunk_size);
                for item in chunk {
                    chunk_results.push(f_clone(item));
                }

                let mut results_lock = results_clone.lock().unwrap();
                // Ensure there is a slot for this chunk's results
                if results_lock.len() <= chunk_idx {
                    results_lock.resize_with(chunk_idx + 1, Vec::new);
                }
                results_lock[chunk_idx] = chunk_results;
            });

            handles.push(handle);
        }

        // Wait for all threads to complete
        for handle in handles {
            handle.join().map_err(|_| {
                UtilsError::InvalidParameter(
                    "Thread panicked during parallel execution".to_string(),
                )
            })?;
        }

        // Collect results in chunk order
        let results_lock = results.lock().unwrap();
        let mut final_results = Vec::new();
        for chunk_results in results_lock.iter() {
            final_results.extend_from_slice(chunk_results);
        }

        Ok(final_results)
    }

    /// Filter items in parallel; the result preserves the input order
    pub fn filter<F>(self, predicate: F) -> UtilsResult<Vec<T>>
    where
        F: Fn(&T) -> bool + Send + Sync + 'static,
        T: Clone,
    {
        let predicate = Arc::new(predicate);
        // Per-chunk results, indexed by chunk position so the output preserves input order
        let results: Arc<Mutex<Vec<Vec<T>>>> = Arc::new(Mutex::new(Vec::new()));

        // Split items into chunks
        let chunks: Vec<Vec<T>> = self
            .items
            .chunks(self.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let mut handles = Vec::new();

        for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
            let predicate_clone = Arc::clone(&predicate);
            let results_clone = Arc::clone(&results);

            let handle = thread::spawn(move || {
                let filtered: Vec<T> = chunk
                    .into_iter()
                    .filter(|item| predicate_clone(item))
                    .collect();

                let mut results_lock = results_clone.lock().unwrap();
                // Ensure there is a slot for this chunk's results
                if results_lock.len() <= chunk_idx {
                    results_lock.resize_with(chunk_idx + 1, Vec::new);
                }
                results_lock[chunk_idx] = filtered;
            });

            handles.push(handle);
        }

        // Wait for all threads to complete
        for handle in handles {
            handle.join().map_err(|_| {
                UtilsError::InvalidParameter(
                    "Thread panicked during parallel execution".to_string(),
                )
            })?;
        }

        // Collect results in chunk order
        let results_lock = results.lock().unwrap();
        let mut final_results = Vec::new();
        for chunk_results in results_lock.iter() {
            final_results.extend_from_slice(chunk_results);
        }

        Ok(final_results)
    }
}

/// Parallel reduction operations
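///
/// # Examples
///
/// A minimal sketch (marked `ignore`; the path assumes this module is public
/// as `sklears_utils::parallel`):
///
/// ```ignore
/// use sklears_utils::parallel::ParallelReducer;
///
/// assert_eq!(ParallelReducer::sum(vec![1, 2, 3, 4, 5]).unwrap(), 15);
/// assert_eq!(ParallelReducer::min(vec![5, 2, 8]).unwrap(), Some(2));
/// assert_eq!(ParallelReducer::max(vec![5, 2, 8]).unwrap(), Some(8));
/// ```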
pub struct ParallelReducer;

impl ParallelReducer {
    /// Reduce a vector in parallel using the given operation
    ///
    /// `op` should be associative and `initial` should be its identity element,
    /// because `initial` is folded into every chunk and again into the final combine.
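    ///
    /// A sketch with a custom operation (marked `ignore`; the identity of
    /// multiplication is `1`, so chunking does not change the result):
    ///
    /// ```ignore
    /// let product = ParallelReducer::reduce(vec![1, 2, 3, 4], 1, |a, b| a * b).unwrap();
    /// assert_eq!(product, 24);
    /// ```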
    pub fn reduce<T, F>(items: Vec<T>, initial: T, op: F) -> UtilsResult<T>
    where
        T: Send + Sync + Clone + 'static,
        F: Fn(T, T) -> T + Send + Sync + 'static,
    {
        if items.is_empty() {
            return Ok(initial);
        }

        let op = Arc::new(op);
        let chunk_size = (items.len() / num_cpus::get()).max(1);

        // Split items into chunks
        let chunks: Vec<_> = items
            .chunks(chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let mut handles = Vec::new();
        let mut partial_results = Vec::new();

        for chunk in chunks.into_iter() {
            let op_clone = Arc::clone(&op);
            let initial_clone = initial.clone();

            let handle = thread::spawn(move || {
                chunk
                    .into_iter()
                    .fold(initial_clone, |acc, item| op_clone(acc, item))
            });

            handles.push(handle);
        }

        // Collect partial results
        for handle in handles {
            let result = handle.join().map_err(|_| {
                UtilsError::InvalidParameter(
                    "Thread panicked during parallel reduction".to_string(),
                )
            })?;
            partial_results.push(result);
        }

        // Reduce partial results
        Ok(partial_results
            .into_iter()
            .fold(initial, |acc, partial| op(acc, partial)))
    }

    /// Sum elements in parallel
    pub fn sum<T>(items: Vec<T>) -> UtilsResult<T>
    where
        T: Send + Sync + Clone + std::ops::Add<Output = T> + Zero + 'static,
    {
        Self::reduce(items, T::zero(), |a, b| a + b)
    }

    /// Find minimum element in parallel
    pub fn min<T>(items: Vec<T>) -> UtilsResult<Option<T>>
    where
        T: Send + Sync + Clone + Ord + 'static,
    {
        if items.is_empty() {
            return Ok(None);
        }

        let first = items[0].clone();
        let result = Self::reduce(items, first, |a, b| if a < b { a } else { b })?;
        Ok(Some(result))
    }

    /// Find maximum element in parallel
    pub fn max<T>(items: Vec<T>) -> UtilsResult<Option<T>>
    where
        T: Send + Sync + Clone + Ord + 'static,
    {
        if items.is_empty() {
            return Ok(None);
        }

        let first = items[0].clone();
        let result = Self::reduce(items, first, |a, b| if a > b { a } else { b })?;
        Ok(Some(result))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4).unwrap();
        assert_eq!(pool.thread_count(), 4);
    }

    #[test]
    fn test_thread_pool_execution() {
        let mut pool = ThreadPool::new(2).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            pool.execute(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        // Join the pool so every submitted job is guaranteed to have finished
        pool.join();

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_work_stealing_queue() {
        let queue = WorkStealingQueue::new(4);

        queue.push_local(42).unwrap();
        queue.push_global(24).unwrap();

        assert_eq!(queue.get_task().unwrap(), Some(42));
        assert_eq!(queue.get_task().unwrap(), Some(24));
        assert_eq!(queue.get_task().unwrap(), None);
    }

    #[test]
    fn test_parallel_iterator_map() {
        let items = vec![1, 2, 3, 4, 5];
        let iter = ParallelIterator::new(items);

        let results = iter.map(|x| x * 2).unwrap();
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }

    #[test]
    fn test_parallel_iterator_filter() {
        let items = vec![1, 2, 3, 4, 5, 6];
        let iter = ParallelIterator::new(items);

        let results = iter.filter(|&x| x % 2 == 0).unwrap();
        assert_eq!(results, vec![2, 4, 6]);
    }

    #[test]
    fn test_parallel_reducer_sum() {
        let items = vec![1, 2, 3, 4, 5];
        let result = ParallelReducer::sum(items).unwrap();
        assert_eq!(result, 15);
    }

    #[test]
    fn test_parallel_reducer_min_max() {
        let items = vec![5, 2, 8, 1, 9, 3];

        let min_result = ParallelReducer::min(items.clone()).unwrap();
        assert_eq!(min_result, Some(1));

        let max_result = ParallelReducer::max(items).unwrap();
        assert_eq!(max_result, Some(9));
    }

    #[test]
    fn test_parallel_reducer_empty() {
        let items: Vec<i32> = vec![];

        let min_result = ParallelReducer::min(items.clone()).unwrap();
        assert_eq!(min_result, None);

        let max_result = ParallelReducer::max(items).unwrap();
        assert_eq!(max_result, None);
    }

    #[test]
    fn test_thread_pool_with_cpu_cores() {
        let pool = ThreadPool::with_cpu_cores().unwrap();
        assert!(pool.thread_count() > 0);
        assert!(pool.thread_count() <= num_cpus::get());
    }
558}