1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/// Worker pool used to run [`ThreadPool::broadcast`] jobs; a thin wrapper around
/// `threadpool::ThreadPool`.
pub struct ThreadPool {
    /// Number of jobs [`ThreadPool::broadcast`] divides its index range into. The constructors
    /// initialise it to the worker count; it is public so callers can change the fan-out later.
    pub use_num_threads: usize,
    /// The underlying pool that actually executes the queued jobs.
    inner: threadpool::ThreadPool,
}
impl ThreadPool {
    /// Creates a pool with exactly `num_threads` workers and splits every
    /// [`broadcast`](ThreadPool::broadcast) into that many jobs.
    pub fn with_num_threads(num_threads: usize) -> ThreadPool {
        let inner = threadpool::Builder::new().num_threads(num_threads).build();
        ThreadPool {
            inner,
            use_num_threads: num_threads,
        }
    }

    /// Creates a pool sized by `threadpool::Builder`'s default and splits every
    /// [`broadcast`](ThreadPool::broadcast) into one job per worker.
    pub fn new() -> ThreadPool {
        let inner = threadpool::Builder::new().build();
        ThreadPool {
            use_num_threads: inner.max_count(),
            inner,
        }
    }

    /// Returns the worker count reported by the underlying `threadpool::ThreadPool`.
    pub fn max_count(&self) -> usize {
        self.inner.max_count()
    }

    /// Calls `f(i)` exactly once for every `i` in `start..end`, then blocks until the underlying
    /// pool is idle.
    ///
    /// The range is cut into at most `use_num_threads` contiguous chunks whose lengths differ by
    /// at most one (the leading chunks absorb the remainder). Each chunk runs as one pool job that
    /// calls its own copy of `f` on the chunk's indices in increasing order.
    ///
    /// An empty or inverted range (`end <= start`) does nothing, matching how `start..end`
    /// iterates. A `use_num_threads` of 0 is treated as 1.
    ///
    /// # Panics
    ///
    /// Panics on the calling thread if any call to `f` panicked on a worker. The remaining
    /// indices of that chunk are left unprocessed.
    pub fn broadcast<F>(&self, start: usize, end: usize, f: F)
    where
        F: FnOnce(usize) + Send + 'static + Copy,
    {
        // `end - start` would underflow for an inverted range; treat it as empty instead.
        let len = end.saturating_sub(start);
        if len == 0 {
            return;
        }
        // Guard against division by zero (the field is public). Never queue more jobs than there
        // are indices, because the extra jobs would only receive empty chunks.
        let num_jobs = self.use_num_threads.max(1).min(len);
        let chunk_len = len / num_jobs;
        let mut remainder = len % num_jobs;

        // Each job increments this only after its whole chunk has run. A job whose `f` panicked
        // therefore never counts, which lets us surface the panic after `join`; the pool itself
        // does not propagate job panics to the caller of `join`.
        let finished = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let mut lo = start;
        for _ in 0..num_jobs {
            let mut hi = lo + chunk_len;
            if remainder > 0 {
                hi += 1;
                remainder -= 1;
            }
            let finished = std::sync::Arc::clone(&finished);
            self.inner.execute(move || {
                for i in lo..hi {
                    f(i);
                }
                finished.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            });
            lo = hi;
        }
        self.inner.join();

        let completed = finished.load(std::sync::atomic::Ordering::SeqCst);
        assert!(
            completed == num_jobs,
            "ThreadPool::broadcast: {} of {} jobs panicked",
            num_jobs - completed,
            num_jobs
        );
    }
}
impl Default for ThreadPool {
fn default() -> ThreadPool {
ThreadPool::new()
}
}
/// A `Copy + Send` wrapper around a `*const T`. It hands read-only access to caller-owned data
/// to [`ThreadPool::broadcast`] jobs.
///
/// `Clone` and `Copy` are implemented by hand rather than derived. A derive would add a
/// `T: Copy` bound, although copying the pointer never copies the pointee.
///
/// Constructing, copying and sending the wrapper is safe; dereferencing the pointer is not.
/// The caller must keep the pointee alive and in place for as long as any copy is used. The
/// caller must also ensure that accesses from different threads do not race.
pub struct UnsafePtr<T>(pub *const T);

/// A `Copy + Send` wrapper around a `*mut T`. It hands write access to caller-owned data to
/// [`ThreadPool::broadcast`] jobs.
///
/// The same caveats as for [`UnsafePtr`] apply. In particular, concurrent writes through copies
/// of one wrapper must target disjoint locations.
pub struct UnsafeMutPtr<T>(pub *mut T);

impl<T> UnsafePtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Because this takes `self` by value, a closure that calls it captures the whole (`Send`)
    /// wrapper. Without it, Rust 2021 closures would capture only the `*const T` field, which is
    /// not `Send`.
    pub fn as_ptr(self) -> *const T {
        self.0
    }
}

impl<T> UnsafeMutPtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Because this takes `self` by value, a closure that calls it captures the whole (`Send`)
    /// wrapper. Without it, Rust 2021 closures would capture only the `*mut T` field, which is
    /// not `Send`.
    pub fn as_mut_ptr(self) -> *mut T {
        self.0
    }
}

impl<T> Clone for UnsafePtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UnsafePtr<T> {}

impl<T> Clone for UnsafeMutPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UnsafeMutPtr<T> {}

// SAFETY: the wrappers only carry an address to another thread. Every dereference happens in
// caller-written `unsafe` code, which is responsible for lifetime and data-race freedom.
unsafe impl<T> Send for UnsafePtr<T> {}
// SAFETY: see the `UnsafePtr` impl above; the same reasoning applies to the mutable wrapper.
unsafe impl<T> Send for UnsafeMutPtr<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_vec_add() {
        const LEN: usize = 1131;
        // a[i] = i and b[i] = 1131 - i, so every element-wise sum is exactly 1131.0.
        let a: Vec<f64> = (0..LEN).map(|i| i as f64).collect();
        let b: Vec<f64> = a.iter().map(|&x| 1131. - x).collect();
        let mut c = vec![0.0_f64; LEN];

        let threadpool = ThreadPool::default();
        let aptr = UnsafePtr(a.as_ptr());
        let bptr = UnsafePtr(b.as_ptr());
        let cptr = UnsafeMutPtr(c.as_mut_ptr());
        threadpool.broadcast(0, LEN, move |i| {
            // Mention each wrapper as a whole so the closure captures the `Send` wrappers rather
            // than only their raw-pointer fields. Rust 2021 closures capture disjoint fields, and
            // raw pointers are not `Send`.
            let _ = (&aptr, &bptr, &cptr);
            // SAFETY: `broadcast` hands each `i` in `0..LEN` to exactly one call, so the writes
            // into `c` never overlap. All three buffers hold `LEN` elements and outlive
            // `broadcast`, which joins the pool before returning.
            unsafe {
                *cptr.0.add(i) = *aptr.0.add(i) + *bptr.0.add(i);
            }
        });

        for ((&x, &y), &z) in a.iter().zip(&b).zip(&c) {
            assert_eq!(z, x + y);
        }
    }

    #[test]
    fn broadcast_visits_offset_range_exactly_once() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        static INDEX_SUM: AtomicUsize = AtomicUsize::new(0);
        let pool = ThreadPool::with_num_threads(4);
        pool.broadcast(10, 21, |i| {
            CALLS.fetch_add(1, Ordering::SeqCst);
            INDEX_SUM.fetch_add(i, Ordering::SeqCst);
        });
        assert_eq!(CALLS.load(Ordering::SeqCst), 11);
        assert_eq!(INDEX_SUM.load(Ordering::SeqCst), (10..21).sum::<usize>());
    }

    #[test]
    fn broadcast_with_more_threads_than_items() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        static INDEX_SUM: AtomicUsize = AtomicUsize::new(0);
        let pool = ThreadPool::with_num_threads(8);
        pool.broadcast(0, 3, |i| {
            CALLS.fetch_add(1, Ordering::SeqCst);
            INDEX_SUM.fetch_add(i, Ordering::SeqCst);
        });
        assert_eq!(CALLS.load(Ordering::SeqCst), 3);
        assert_eq!(INDEX_SUM.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn broadcast_over_empty_range_calls_nothing() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let pool = ThreadPool::with_num_threads(2);
        pool.broadcast(7, 7, |_| {
            CALLS.fetch_add(1, Ordering::SeqCst);
        });
        assert_eq!(CALLS.load(Ordering::SeqCst), 0);
    }
}