1use std::{
2 num::NonZeroUsize,
3 thread::{self, JoinHandle},
4};
5
6use crossbeam::channel::{self as mpmc, Receiver, Sender};
7use once_cell::sync::Lazy;
8
9use crate::{
10 error::Error,
11 sink::{OverflowPolicy, Task},
12 sync::*,
13 Result,
14};
15
/// A pool of worker threads that executes tasks submitted through `assign_task`.
///
/// The inner state is kept behind an `ArcSwapOption` so that `destroy` can atomically
/// detach it (by swapping in `None`) while other threads may still be submitting tasks.
pub struct ThreadPool(ArcSwapOption<ThreadPoolInner>);
38
/// State owned by a live [`ThreadPool`]: the worker handles and the task channel's sender.
struct ThreadPoolInner {
    /// Join handles of the worker threads; each one is `take`n when `destroy` joins it.
    threads: Vec<Option<JoinHandle<()>>>,
    /// Sending side of the task channel. `destroy` drops it (via `take`) to disconnect the
    /// channel, after which the workers exit once the queued tasks have been processed.
    sender: Option<Sender<Task>>,
}
43
/// Hook invoked on worker threads. It is reference-counted so that
/// [`ThreadPoolBuilder::build`] can hand a clone to every worker it spawns.
type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
45
/// Builder for [`ThreadPool`]; obtain one with [`ThreadPool::builder`].
#[allow(missing_docs)]
pub struct ThreadPoolBuilder {
    /// Capacity of the bounded task channel shared by all workers.
    capacity: NonZeroUsize,
    /// Number of worker threads spawned by `build`.
    threads: NonZeroUsize,
    /// Hook run at the start of every worker thread, before it processes tasks.
    on_thread_spawn: Option<Callback>,
    /// Hook run on every worker thread after its task loop ends.
    on_thread_finish: Option<Callback>,
}
53
/// Per-thread worker state: the receiving end of the task channel.
/// `ThreadPoolBuilder::build` creates one per spawned thread.
struct Worker {
    receiver: Receiver<Task>,
}
57
58impl ThreadPool {
59 #[must_use]
71 pub fn builder() -> ThreadPoolBuilder {
72 ThreadPoolBuilder {
73 capacity: NonZeroUsize::new(8192).unwrap(),
74 threads: NonZeroUsize::new(1).unwrap(),
75 on_thread_spawn: None,
76 on_thread_finish: None,
77 }
78 }
79
80 pub fn new() -> Result<Self> {
83 Self::builder().build()
84 }
85
86 pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
87 let inner = self.0.load();
88 if let Some(inner) = inner.as_ref() {
89 let sender = inner.sender.as_ref().unwrap();
90
91 match overflow_policy {
92 OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
93 OverflowPolicy::DropIncoming => sender
94 .try_send(task)
95 .map_err(Error::from_crossbeam_try_send),
96 }
97 } else {
98 Ok(())
104 }
105 }
106
107 pub(super) fn destroy(&self) {
108 if let Some(inner) = self.0.swap(None) {
109 if let Some(mut inner) = Arc::into_inner(inner) {
118 inner.sender.take();
121
122 for thread in &mut inner.threads {
123 if let Some(thread) = thread.take() {
124 thread.join().expect("failed to join a thread from pool");
125 }
126 }
127 }
128 }
129 }
130}
131
132impl Drop for ThreadPool {
133 fn drop(&mut self) {
134 self.destroy();
135 }
136}
137
138impl ThreadPoolBuilder {
139 #[must_use]
147 pub fn capacity(&mut self, capacity: NonZeroUsize) -> &mut Self {
148 self.capacity = capacity;
149 self
150 }
151
152 #[must_use]
155 #[allow(dead_code)]
156 fn threads(&mut self, threads: NonZeroUsize) -> &mut Self {
157 self.threads = threads;
158 self
159 }
160
161 #[must_use]
165 pub fn on_thread_spawn<F>(&mut self, f: F) -> &mut Self
166 where
167 F: Fn() + Send + Sync + 'static,
168 {
169 self.on_thread_spawn = Some(Arc::new(f));
170 self
171 }
172
173 #[must_use]
176 pub fn on_thread_finish<F>(&mut self, f: F) -> &mut Self
177 where
178 F: Fn() + Send + Sync + 'static,
179 {
180 self.on_thread_finish = Some(Arc::new(f));
181 self
182 }
183
184 pub fn build(&self) -> Result<ThreadPool> {
186 let (sender, receiver) = mpmc::bounded(self.capacity.get());
187
188 let mut threads = Vec::new();
189 threads.resize_with(self.threads.get(), || {
190 let receiver = receiver.clone();
191 let on_thread_spawn = self.on_thread_spawn.clone();
192 let on_thread_finish = self.on_thread_finish.clone();
193
194 Some(thread::spawn(move || {
195 if let Some(f) = on_thread_spawn {
196 f();
197 }
198
199 Worker { receiver }.run();
200
201 if let Some(f) = on_thread_finish {
202 f();
203 }
204 }))
205 });
206
207 Ok(ThreadPool(ArcSwapOption::new(Some(Arc::new(
208 ThreadPoolInner {
209 threads,
210 sender: Some(sender),
211 },
212 )))))
213 }
214
215 pub fn build_arc(&self) -> Result<Arc<ThreadPool>> {
219 self.build().map(Arc::new)
220 }
221}
222
223impl Worker {
224 fn run(&self) {
225 while let Ok(task) = self.receiver.recv() {
226 task.exec();
227 }
228 }
229}
230
231#[must_use]
232pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
233 static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));
234
235 let mut pool_weak = POOL_WEAK.lock_expect();
236
237 match pool_weak.upgrade() {
238 Some(pool) => pool,
239 None => {
240 let pool = ThreadPool::builder().build_arc().unwrap();
241 *pool_weak = Arc::downgrade(&pool);
242 pool
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use std::{thread::sleep, time::Duration};

    use super::*;

    /// Regression test: `destroy` must not panic while other threads still hold a
    /// reference to the inner pool state, e.g. while blocked inside `assign_task`.
    #[test]
    fn inner_arc_multiple_strong_refs() {
        // A channel of capacity 1 makes it easy to get a `Block` submitter stuck.
        let thread_pool = ThreadPool::builder()
            .capacity(1.try_into().unwrap())
            .build_arc()
            .unwrap();

        // Test-only task; presumably sleeps for `sleep` when executed, keeping the single
        // worker busy — TODO confirm against `Task::__ForTestUse`.
        let task = || Task::__ForTestUse {
            sleep: Some(Duration::from_secs(1)),
        };

        // Occupies the worker (or the single channel slot).
        thread_pool
            .assign_task(task(), OverflowPolicy::Block)
            .unwrap();

        let (first_blocked_assign, second_blocked_assign, destroy, third_assign) =
            std::thread::scope(|s| {
                // With the worker busy and capacity 1, at least one of these two `Block`
                // submissions waits inside `assign_task`. While waiting, it holds a load
                // guard that keeps the inner `Arc` alive.
                let first_blocked_assign = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool
                            .assign_task(task(), OverflowPolicy::Block)
                            .unwrap();
                    }
                });
                let second_blocked_assign = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool
                            .assign_task(task(), OverflowPolicy::Block)
                            .unwrap();
                    }
                });
                // Give both submitters time to reach `assign_task` before destroying.
                sleep(Duration::from_millis(200));
                // Runs while the inner state is still referenced by a blocked submitter,
                // so `destroy` cannot take unique ownership; it must still return cleanly.
                let destroy = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool.destroy();
                    }
                });
                // Scheduling decides whether this runs before or after `destroy` swaps the
                // inner state out; either way the submission must succeed.
                let third_assign = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool
                            .assign_task(task(), OverflowPolicy::Block)
                            .unwrap();
                    }
                });
                (
                    first_blocked_assign.join(),
                    second_blocked_assign.join(),
                    destroy.join(),
                    third_assign.join(),
                )
            });
        // A panic in any spawned thread surfaces here as an `Err`.
        first_blocked_assign.unwrap();
        second_blocked_assign.unwrap();
        destroy.unwrap();
        third_assign.unwrap();
    }
}