use std::{
    sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, PoisonError},
    thread,
};
8
/// A boxed, type-erased closure that can be moved to a worker thread and
/// called once. In a type alias the trait object's lifetime bound defaults to
/// `'static`, so the closure may not borrow non-`'static` data.
type Job = Box<dyn FnOnce() + Send>;
10
11pub struct ThreadPool {
12 workers: Vec<Worker>,
13 sender: Option<mpsc::Sender<Job>>,
14}
15
16struct Worker {
17 id: usize,
18 thread: Option<thread::JoinHandle<()>>,
19}
20
21impl Drop for ThreadPool {
22 fn drop(&mut self) {
23 drop(self.sender.take());
24
25 for worker in &mut self.workers {
26 if log::log_enabled!(log::Level::Trace) {
27 log::trace!("Shutting down worker {}", worker.id);
28 }
29 if let Some(thread) = worker.thread.take() {
30 thread.join().unwrap();
31 }
32 }
33 }
34}
35
36impl ThreadPool {
37 pub fn new(size: usize) -> ThreadPool {
45 assert!(size > 0);
46
47 let (sender, receiver) = mpsc::channel();
48
49 let receiver = Arc::new(Mutex::new(receiver));
50
51 let mut workers = Vec::with_capacity(size);
52
53 for id in 0..size {
54 workers.push(Worker::new(id, receiver.clone()));
55 }
56
57 ThreadPool {
58 workers,
59 sender: Some(sender),
60 }
61 }
62
63 pub fn execute<F>(&self, f: F)
64 where
65 F: FnOnce() + Send + 'static,
66 {
67 let job = Box::new(f);
68
69 self.sender.as_ref().unwrap().send(job).unwrap();
70 }
71
72 pub fn wait_for_completion(&self) {
74 todo!()
75 }
76}
77
78impl Worker {
79 fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
80 let thread = thread::spawn(move || loop {
81 let message = receiver.lock().unwrap().recv();
82
83 match message {
84 Ok(job) => {
85 if log::log_enabled!(log::Level::Trace) {
86 log::trace!("Worker {id} got a job; executing.");
87 }
88
89 job();
90 }
91 Err(_) => {
92 if log::log_enabled!(log::Level::Trace) {
93 log::trace!("Worker {id} disconnected; shutting down.");
94 }
95 break;
96 }
97 }
98 });
99
100 Worker {
101 id,
102 thread: Some(thread),
103 }
104 }
105}
106
#[cfg(test)]
pub mod test {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    use super::*;

    /// Submits 100 jobs to a 4-thread pool. Each job sleeps briefly and then
    /// adds `n * n` to a shared total. The pool is dropped, which joins the
    /// workers, and the total is checked against the sum of squares.
    #[test]
    fn thread_pool_test() {
        let sum = Arc::new(AtomicU64::new(0));

        let square_slowly = |value: u64| -> u64 {
            thread::sleep(Duration::from_millis(20));
            value * value
        };

        let pool = ThreadPool::new(4);
        for n in 0..100 {
            let sum = Arc::clone(&sum);
            pool.execute(move || {
                sum.fetch_add(square_slowly(n), Ordering::SeqCst);
            });
        }
        // Dropping the pool joins every worker, so all jobs have run by now.
        drop(pool);

        // 0² + 1² + … + 99² = 99 · 100 · 199 / 6
        assert_eq!(sum.load(Ordering::SeqCst), 328350);
    }
}