blasoxide_mt/
lib.rs

1pub struct ThreadPool {
2    inner: threadpool::ThreadPool,
3    /// number of threads to use. Initialized to number of available threads in the threadpool
4    pub use_num_threads: usize,
5}
6
7impl ThreadPool {
8    pub fn with_num_threads(num_threads: usize) -> ThreadPool {
9        let inner = threadpool::Builder::new().num_threads(num_threads).build();
10
11        ThreadPool {
12            inner,
13            use_num_threads: num_threads,
14        }
15    }
16
17    pub fn new() -> ThreadPool {
18        let inner = threadpool::Builder::new().build();
19
20        ThreadPool {
21            use_num_threads: inner.max_count(),
22            inner,
23        }
24    }
25
26    /// returns the number of threads in the threadpool
27    pub fn max_count(&self) -> usize {
28        self.inner.max_count()
29    }
30
31    /// this function broadcasts consecutive parts of given range to all threads
32    /// and waits for all threads to finish before returning.
33    pub fn broadcast<F>(&self, start: usize, end: usize, f: F)
34    where
35        F: FnOnce(usize) + Send + 'static + Copy,
36    {
37        let num_threads = self.use_num_threads;
38
39        let len = end - start;
40
41        let mut left = len % num_threads;
42        let main = len - left;
43
44        let job_size = main / num_threads;
45
46        let mut prev_end = 0;
47        for _ in 0..num_threads {
48            let mut now_end = prev_end + job_size;
49            if left > 0 {
50                now_end += 1;
51                left -= 1;
52            }
53
54            self.inner.execute(move || {
55                for j in prev_end..now_end {
56                    f(start + j);
57                }
58            });
59
60            prev_end = now_end;
61        }
62
63        self.inner.join();
64    }
65}
66
67impl Default for ThreadPool {
68    fn default() -> ThreadPool {
69        ThreadPool::new()
70    }
71}
72
73#[derive(Clone, Copy)]
74pub struct UnsafePtr<T>(pub *const T);
75
76#[derive(Clone, Copy)]
77pub struct UnsafeMutPtr<T>(pub *mut T);
78
79unsafe impl<T> Send for UnsafePtr<T> {}
80unsafe impl<T> Send for UnsafeMutPtr<T> {}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds operands where `a[i] = i` and `b[i] = 1131 - i`, so every
    /// element-wise sum is exactly 1131.0.
    fn operands(len: usize) -> (Vec<f64>, Vec<f64>) {
        let a: Vec<f64> = (0..len).map(|i| i as f64).collect();
        let b: Vec<f64> = a.iter().map(|&x| 1131.0 - x).collect();
        (a, b)
    }

    /// Computes `c[i] = a[i] + b[i]` for every `i` in `start..end` on `pool`
    /// and returns `c`; entries outside the range stay 0.
    fn parallel_add(pool: &ThreadPool, a: &[f64], b: &[f64], start: usize, end: usize) -> Vec<f64> {
        assert_eq!(a.len(), b.len());
        assert!(end <= a.len());
        let mut c = vec![0.0; a.len()];

        let aptr = UnsafePtr(a.as_ptr());
        let bptr = UnsafePtr(b.as_ptr());
        let cptr = UnsafeMutPtr(c.as_mut_ptr());

        pool.broadcast(start, end, move |i| {
            // SAFETY: `i < end <= len` for all three buffers (asserted above).
            // `broadcast` hands each index to exactly one job, so writes to
            // `c` never alias. `broadcast` joins the pool before returning, so
            // `a`, `b` and `c` outlive every access.
            unsafe {
                *cptr.0.add(i) = *aptr.0.add(i) + *bptr.0.add(i);
            }
        });

        c
    }

    #[test]
    fn test_vec_add() {
        const LEN: usize = 1131;
        let (a, b) = operands(LEN);
        let c = parallel_add(&ThreadPool::default(), &a, &b, 0, LEN);

        for ((&x, &y), &z) in a.iter().zip(&b).zip(&c) {
            assert_eq!(z, x + y);
        }
    }

    #[test]
    fn test_more_threads_than_items() {
        let (a, b) = operands(3);
        let c = parallel_add(&ThreadPool::with_num_threads(4), &a, &b, 0, 3);
        assert_eq!(c, vec![1131.0; 3]);
    }

    #[test]
    fn test_offset_range() {
        let (a, b) = operands(10);
        let c = parallel_add(&ThreadPool::with_num_threads(3), &a, &b, 3, 10);
        assert!(c[..3].iter().all(|&x| x == 0.0));
        assert!(c[3..].iter().all(|&x| x == 1131.0));
    }

    #[test]
    fn test_empty_range() {
        let (a, b) = operands(5);
        let c = parallel_add(&ThreadPool::with_num_threads(2), &a, &b, 2, 2);
        assert_eq!(c, vec![0.0; 5]);
    }
}