1pub struct ThreadPool {
2 inner: threadpool::ThreadPool,
3 pub use_num_threads: usize,
5}
6
7impl ThreadPool {
8 pub fn with_num_threads(num_threads: usize) -> ThreadPool {
9 let inner = threadpool::Builder::new().num_threads(num_threads).build();
10
11 ThreadPool {
12 inner,
13 use_num_threads: num_threads,
14 }
15 }
16
17 pub fn new() -> ThreadPool {
18 let inner = threadpool::Builder::new().build();
19
20 ThreadPool {
21 use_num_threads: inner.max_count(),
22 inner,
23 }
24 }
25
26 pub fn max_count(&self) -> usize {
28 self.inner.max_count()
29 }
30
31 pub fn broadcast<F>(&self, start: usize, end: usize, f: F)
34 where
35 F: FnOnce(usize) + Send + 'static + Copy,
36 {
37 let num_threads = self.use_num_threads;
38
39 let len = end - start;
40
41 let mut left = len % num_threads;
42 let main = len - left;
43
44 let job_size = main / num_threads;
45
46 let mut prev_end = 0;
47 for _ in 0..num_threads {
48 let mut now_end = prev_end + job_size;
49 if left > 0 {
50 now_end += 1;
51 left -= 1;
52 }
53
54 self.inner.execute(move || {
55 for j in prev_end..now_end {
56 f(start + j);
57 }
58 });
59
60 prev_end = now_end;
61 }
62
63 self.inner.join();
64 }
65}
66
67impl Default for ThreadPool {
68 fn default() -> ThreadPool {
69 ThreadPool::new()
70 }
71}
72
/// A `Copy + Send` wrapper around a `*const T`. It lets read-only pointers be
/// moved into `'static + Send + Copy` jobs such as those given to
/// `ThreadPool::broadcast`.
///
/// Wrapping makes nothing safe. Whoever dereferences the pointer must keep the
/// pointee alive, and free of concurrent writers, until every job using it is done.
pub struct UnsafePtr<T: ?Sized>(pub *const T);

/// A `Copy + Send` wrapper around a `*mut T`; the mutable counterpart of
/// [`UnsafePtr`].
///
/// Wrapping makes nothing safe. Whoever dereferences the pointer must keep the
/// pointee alive and ensure concurrent jobs never touch the same element.
pub struct UnsafeMutPtr<T: ?Sized>(pub *mut T);

// Implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy`
// bounds. The wrapper would then not be `Copy` for pointees like `String` or
// `[u8]`, even though the raw pointer inside always is.
impl<T: ?Sized> Clone for UnsafePtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: ?Sized> Copy for UnsafePtr<T> {}

impl<T: ?Sized> Clone for UnsafeMutPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: ?Sized> Copy for UnsafeMutPtr<T> {}

impl<T: ?Sized> UnsafePtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Prefer this over `.0` inside `move` closures. Under edition 2021's
    /// disjoint closure captures, `p.0` captures only the raw pointer, which is
    /// not `Send`; a method call captures the whole wrapper.
    pub fn as_ptr(self) -> *const T {
        self.0
    }
}

impl<T: ?Sized> UnsafeMutPtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Prefer this over `.0` inside `move` closures. Under edition 2021's
    /// disjoint closure captures, `p.0` captures only the raw pointer, which is
    /// not `Send`; a method call captures the whole wrapper.
    pub fn as_mut_ptr(self) -> *mut T {
        self.0
    }
}

// SAFETY: these types only carry an address across threads. Every dereference is
// already `unsafe`, and the code doing it is responsible for lifetime and
// aliasing. This is a deliberate escape hatch.
unsafe impl<T: ?Sized> Send for UnsafePtr<T> {}
unsafe impl<T: ?Sized> Send for UnsafeMutPtr<T> {}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vec_add() {
        const LEN: usize = 1131;

        // a = [0, 1, 2, ...] and b = [1131, 1130, ...], so every a[i] + b[i] is 1131.
        let a: Vec<f64> = (0..LEN).map(|i| i as f64).collect();
        let b: Vec<f64> = a.iter().map(|&x| 1131. - x).collect();
        let mut c = vec![0.; LEN];

        let pool = ThreadPool::default();

        let lhs = UnsafePtr(a.as_ptr());
        let rhs = UnsafePtr(b.as_ptr());
        let out = UnsafeMutPtr(c.as_mut_ptr());

        pool.broadcast(0, LEN, move |i| {
            // SAFETY: `broadcast` hands each `i` in `0..LEN` to exactly one job
            // and joins before returning. So every access is in bounds, writes
            // to `c` never overlap, and `a`, `b`, `c` outlive all jobs.
            unsafe {
                *out.0.add(i) = *lhs.0.add(i) + *rhs.0.add(i);
            }
        });

        for ((&x, &y), &sum) in a.iter().zip(&b).zip(&c) {
            assert_eq!(sum, x + y);
        }
    }
}