1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/// A thin wrapper around [`threadpool::ThreadPool`] that adds [`ThreadPool::broadcast`],
/// a blocking "parallel for" over an index range.
pub struct ThreadPool {
    inner: threadpool::ThreadPool,
    /// Number of jobs [`ThreadPool::broadcast`] splits a range into.
    ///
    /// Initialized to the size of the underlying pool (`num_threads` for
    /// [`ThreadPool::with_num_threads`], [`ThreadPool::max_count`] for
    /// [`ThreadPool::new`]). It may be changed afterwards; a value of 0 is
    /// treated as 1 by `broadcast`.
    pub use_num_threads: usize,
}

impl ThreadPool {
    /// Creates a pool with exactly `num_threads` worker threads and sets
    /// `use_num_threads` to the same value.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is 0 (the `threadpool` builder rejects an empty pool).
    pub fn with_num_threads(num_threads: usize) -> ThreadPool {
        let inner = threadpool::Builder::new().num_threads(num_threads).build();

        ThreadPool {
            inner,
            use_num_threads: num_threads,
        }
    }

    /// Creates a pool with the `threadpool` builder's default thread count and
    /// sets `use_num_threads` to that count.
    pub fn new() -> ThreadPool {
        let inner = threadpool::Builder::new().build();

        ThreadPool {
            use_num_threads: inner.max_count(),
            inner,
        }
    }

    /// Returns the number of worker threads in the underlying pool.
    ///
    /// This can differ from `use_num_threads` if that field was changed.
    pub fn max_count(&self) -> usize {
        self.inner.max_count()
    }

    /// Calls `f(i)` for every `i` in `start..end` and blocks until all calls finish.
    ///
    /// The range is split into `min(use_num_threads, end - start)` consecutive
    /// chunks whose sizes differ by at most one (the first `len % jobs` chunks
    /// receive one extra index); each chunk is executed as one job on the pool.
    /// An empty or reversed range (`end <= start`) does nothing, and a
    /// `use_num_threads` of 0 is treated as 1.
    ///
    /// Waiting is done with `threadpool::ThreadPool::join`, so this also waits
    /// for any other jobs queued on the same pool.
    ///
    /// # Panics
    ///
    /// After all jobs have finished, panics if any worker thread panicked while
    /// this call was running. Without this check a panic inside `f` would be
    /// absorbed by the pool and part of the range would be silently skipped.
    pub fn broadcast<F>(&self, start: usize, end: usize, f: F)
    where
        F: FnOnce(usize) + Send + 'static + Copy,
    {
        // `end < start` would underflow; treat it like an empty range.
        let len = end.saturating_sub(start);
        if len == 0 {
            return;
        }

        // At least one job (avoids dividing by zero) and never more jobs than
        // indices (extra jobs would be empty).
        let jobs = self.use_num_threads.clamp(1, len);
        let base = len / jobs;
        let extra = len % jobs;

        let panics_before = self.inner.panic_count();

        let mut chunk_start = start;
        for job in 0..jobs {
            // The first `extra` chunks absorb the remainder, one index each.
            let chunk_end = chunk_start + base + usize::from(job < extra);

            self.inner.execute(move || {
                for i in chunk_start..chunk_end {
                    f(i);
                }
            });

            chunk_start = chunk_end;
        }

        self.inner.join();

        if self.inner.panic_count() != panics_before {
            panic!(
                "ThreadPool::broadcast: a job panicked; not every index in {}..{} was processed",
                start, end
            );
        }
    }
}

impl Default for ThreadPool {
    /// Equivalent to [`ThreadPool::new`].
    fn default() -> ThreadPool {
        ThreadPool::new()
    }
}

/// A `*const T` that may be moved into jobs running on other threads.
///
/// Raw pointers are not `Send`, so a closure capturing one cannot be handed to a
/// thread pool. This wrapper opts out of that check. Nothing about the pointee
/// (lifetime, aliasing, whether `T` may be accessed from another thread) is
/// verified; every dereference must be justified by the caller.
///
/// Closures should call [`UnsafePtr::get`] rather than reading `.0`. With
/// edition-2021 disjoint captures, a closure that only mentions `ptr.0` captures
/// the bare raw pointer and is therefore not `Send`.
pub struct UnsafePtr<T>(pub *const T);

/// A `*mut T` that may be moved into jobs running on other threads.
///
/// Same caveats as [`UnsafePtr`]: the caller is responsible for the pointee
/// outliving every use and for concurrent writes never aliasing.
pub struct UnsafeMutPtr<T>(pub *mut T);

impl<T> UnsafePtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Takes `self` by value so a closure calling it captures the whole
    /// (`Send`) wrapper instead of the raw-pointer field.
    pub fn get(self) -> *const T {
        self.0
    }
}

impl<T> UnsafeMutPtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Takes `self` by value so a closure calling it captures the whole
    /// (`Send`) wrapper instead of the raw-pointer field.
    pub fn get(self) -> *mut T {
        self.0
    }
}

// Clone/Copy are implemented by hand: `#[derive]` would require `T: Clone` /
// `T: Copy`, but copying the pointer never copies the pointee.
impl<T> Clone for UnsafePtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UnsafePtr<T> {}

impl<T> Clone for UnsafeMutPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UnsafeMutPtr<T> {}

// SAFETY: these impls only let the pointer *value* cross threads. Soundness of
// any dereference (pointee still alive, no data races, `T` safe to touch from
// another thread) is the responsibility of the code that dereferences it.
unsafe impl<T> Send for UnsafePtr<T> {}
// SAFETY: as above; additionally, callers must ensure writes through copies of
// the same pointer on different threads never target the same location.
unsafe impl<T> Send for UnsafeMutPtr<T> {}

#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise `c = a + b` via `broadcast`; every output slot must be written.
    #[test]
    fn test_vec_add() {
        const LEN: usize = 1131;

        let a: Vec<f64> = (0..LEN).map(|i| i as f64).collect();
        let b: Vec<f64> = a.iter().map(|&x| 1131. - x).collect();
        let mut c = vec![0.; LEN];

        let threadpool = ThreadPool::default();

        let aptr = UnsafePtr(a.as_ptr());
        let bptr = UnsafePtr(b.as_ptr());
        let cptr = UnsafeMutPtr(c.as_mut_ptr());

        threadpool.broadcast(0, LEN, move |i| {
            // Rebind the wrappers so the closure captures them whole. Under
            // edition-2021 disjoint captures, mentioning only `cptr.0` would
            // capture the raw pointer itself, which is not `Send`.
            let (pa, pb, pc) = (aptr, bptr, cptr);
            // SAFETY: `i < LEN` and all three buffers hold `LEN` elements; the
            // buffers outlive the jobs because `broadcast` joins before
            // returning; each index is handed to exactly one job, so the writes
            // through `pc` never alias.
            unsafe {
                *pc.0.add(i) = *pa.0.add(i) + *pb.0.add(i);
            }
        });

        for ((&x, &y), &z) in a.iter().zip(&b).zip(&c) {
            assert_eq!(z, x + y);
        }
    }
}