1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use super::*;
use crossbeam::sync::WaitGroup;

/// Thin wrapper of rayon::ThreadPool which allows the use of [`ThreadPool::shutdown`] methpd
pub struct RayonThreadPool {
    inner: rayon::ThreadPool,
    wg: WaitGroup,
}

impl ThreadPool for RayonThreadPool {
    /// Create a new thread pool with the given number of threads
    ///
    /// # Examples
    ///
    /// ```
    /// use lib_wc::executors::{RayonThreadPool, ThreadPool};
    /// let tp = RayonThreadPool::new(4).unwrap();
    /// ```
    fn new(threads: usize) -> Result<Self> {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(threads)
            .build()
            .map_err(|e| ThreadPoolError::Message(format!("{e}")))?;
        Ok(RayonThreadPool {
            inner: pool,
            wg: WaitGroup::new(),
        })
    }

    /// Spawn a new task on the thread pool
    ///
    /// # Examples
    ///
    /// ```
    ///   use lib_wc::executors::{RayonThreadPool, ThreadPool};
    ///
    ///   let tp = RayonThreadPool::new(4).unwrap();
    ///
    ///   tp.spawn(|| {
    ///     println!("Hello from a thread!");
    ///   });
    /// ```
    fn spawn<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // wg is used to signal that pending work has been finished
        let wg = self.wg.clone();
        self.inner.spawn(move || {
            job();
            drop(wg);
        });
    }

    /// Wait for all currently running tasks to complete
    ///
    /// # Examples
    ///
    /// ```
    ///   use lib_wc::executors::{RayonThreadPool, ThreadPool};    
    ///   use std::sync::atomic::{AtomicUsize, Ordering};
    ///
    ///   static ATOMIC_COUNTER: AtomicUsize = AtomicUsize::new(0);
    ///
    ///   let tp = RayonThreadPool::new(4).unwrap();
    ///
    ///   for _ in 0..100 {
    ///     tp.spawn(|| {
    ///       ATOMIC_COUNTER.fetch_add(1, Ordering::Acquire);
    ///     });
    ///   }
    ///
    ///   tp.shutdown();
    ///
    ///   assert_eq!(ATOMIC_COUNTER.load(Ordering::Relaxed), 100);
    ///
    /// ```
    fn shutdown(self) {
        self.wg.wait();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_rayon_thread_pool() {
        let pool = RayonThreadPool::new(4).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..100 {
            let counter = counter.clone();
            pool.spawn(move || {
                counter.fetch_add(1, Ordering::Acquire);
            });
        }
        pool.shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_wait_no_jobs() {
        let pool = RayonThreadPool::new(4).unwrap();
        pool.shutdown();
    }
}