1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
//! Thread Pool
//! The implementation is taken from the [book](https://doc.rust-lang.org/book/ch20-02-multithreaded.html)

use std::{
    sync::{mpsc, Arc, Mutex},
    thread,
};

/// A unit of work for the pool: a boxed closure, called exactly once on a
/// worker thread (`FnOnce`), movable across threads (`Send`), and owning all
/// of its captures (`'static`).
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A fixed-size pool of worker threads fed jobs through an mpsc channel.
pub struct ThreadPool {
    // Kept so `Drop` can join every worker thread on shutdown.
    workers: Vec<Worker>,
    // `Option` so `Drop` can `take()` and drop the sender; disconnecting the
    // channel is what signals the workers to stop (their `recv` errors out).
    sender: Option<mpsc::Sender<Job>>,
}

/// One worker thread of the pool.
struct Worker {
    // Identifier used only for trace logging.
    id: usize,
    // `Option` so the handle can be `take()`n and joined in `ThreadPool::drop`.
    thread: Option<thread::JoinHandle<()>>,
}

impl Drop for ThreadPool {
    /// Shuts the pool down: closing the channel makes every worker's `recv`
    /// return `Err`, after which each worker thread is joined.
    fn drop(&mut self) {
        // Dropping the sender disconnects the channel; idle workers wake up
        // with an error and break out of their loops.
        drop(self.sender.take());

        for worker in self.workers.iter_mut() {
            if log::log_enabled!(log::Level::Trace) {
                log::trace!("Shutting down worker {}", worker.id);
            }
            // `take()` leaves `None` behind, so a double-drop cannot re-join.
            if let Some(handle) = worker.thread.take() {
                handle.join().unwrap();
            }
        }
    }
}

impl ThreadPool {
    /// Create a new ThreadPool.
    ///
    /// The size is the number of threads in the pool.
    ///
    /// # Panics
    ///
    /// The `new` function will panic if the size is zero.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();

        // All workers share a single receiver: the Mutex serializes `recv`
        // calls and the Arc gives each worker an owned handle to it.
        let receiver = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);

        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    /// Submit a job to be run on one of the pool's worker threads.
    ///
    /// # Panics
    ///
    /// Panics if the pool is shutting down (the sender has already been
    /// dropped) or if every worker has exited and the channel is closed.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);

        self.sender
            .as_ref()
            .expect("pool is shutting down")
            .send(job)
            .unwrap();
    }

    /// Blocks the caller until every job submitted *before* this call has
    /// finished executing. Jobs submitted concurrently by other threads may
    /// run before or after this returns.
    ///
    /// Implementation: enqueue one barrier job per worker. The channel is
    /// FIFO, so a worker can only reach a barrier job after all earlier jobs
    /// have been picked up, and a worker busy with a long job cannot park on
    /// the barrier until that job completes. Once all workers plus the caller
    /// rendezvous, every prior job is done. A worker that is already parked
    /// cannot take a second barrier job, so each worker consumes exactly one.
    ///
    /// NOTE(review): if a previously submitted job panics, its worker thread
    /// dies and the rendezvous can never complete — this call would then
    /// block forever. That matches the pool's general (lack of) panic
    /// handling, but is worth knowing.
    pub fn wait_for_completion(&self) {
        use std::sync::Barrier;

        // One slot per worker plus one for the calling thread.
        let barrier = Arc::new(Barrier::new(self.workers.len() + 1));
        for _ in 0..self.workers.len() {
            let barrier = Arc::clone(&barrier);
            self.execute(move || {
                barrier.wait();
            });
        }
        // Rendezvous with the workers; returns once all are parked here,
        // i.e. once all earlier jobs have run to completion.
        barrier.wait();
    }
}

impl Worker {
    /// Spawns a thread that repeatedly pulls jobs off the shared channel and
    /// runs them, exiting once the sending side has been dropped.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let handle = thread::spawn(move || loop {
            // The lock guard is a temporary inside this expression, so it is
            // released before the job runs — other workers can keep pulling
            // jobs while this one is busy.
            let job = match receiver.lock().unwrap().recv() {
                Ok(job) => job,
                Err(_) => {
                    // `recv` only fails when the channel is disconnected,
                    // i.e. the pool is shutting down.
                    if log::log_enabled!(log::Level::Trace) {
                        log::trace!("Worker {id} disconnected; shutting down.");
                    }
                    break;
                }
            };

            if log::log_enabled!(log::Level::Trace) {
                log::trace!("Worker {id} got a job; executing.");
            }

            job();
        });

        Worker {
            id,
            thread: Some(handle),
        }
    }
}

#[cfg(test)]
pub mod test {

    use std::time::Duration;

    use super::*;

    #[test]
    fn thread_pool_test() {
        use std::sync::atomic::{AtomicU64, Ordering};

        let total = Arc::new(AtomicU64::new(0));

        // The pool lives in an inner scope: dropping it joins every worker
        // thread, so all 100 jobs are guaranteed to have finished before the
        // assertion below runs.
        {
            let pool = ThreadPool::new(4);

            for n in 0..100u64 {
                let acc = Arc::clone(&total);
                pool.execute(move || {
                    thread::sleep(Duration::from_millis(20));
                    acc.fetch_add(n * n, Ordering::SeqCst);
                });
            }
        }

        // Sum of squares 0..=99: 99 * 100 * 199 / 6 = 328350.
        assert_eq!(total.load(Ordering::SeqCst), 328350);
    }
}