1#![warn(missing_docs)]
47#![allow(clippy::style)]
48
49use std::{thread, io, sync};
50use core::{time, fmt, ops, future, pin, task};
51use core::sync::atomic::{Ordering, AtomicUsize, AtomicU16};
52
53mod utils;
54mod spin;
55pub mod oneshot;
56
/// Error returned when the result of a job cannot be collected from its [`JobHandle`].
#[derive(Debug, PartialEq, Eq)]
pub enum JoinError {
    /// The job's result channel closed without a value being delivered.
    ///
    /// NOTE(review): presumably this happens when the job panicked (its result sender
    /// is dropped during unwinding) — confirm against `oneshot`.
    Disconnect,
    /// The result was already taken out of the handle by an earlier call.
    ///
    /// NOTE(review): meaning inferred from the name — confirm against `oneshot`.
    AlreadyConsumed,
}
67
68#[repr(transparent)]
69pub struct JobHandle<T> {
77 inner: oneshot::Receiver<T>
78}
79
80impl<T> fmt::Debug for JobHandle<T> {
81 #[inline(always)]
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 write!(f, "JobHandle")
84 }
85}
86
87impl<T> JobHandle<T> {
88 #[inline(always)]
89 pub fn is_finished(&self) -> bool {
91 self.inner.is_ready()
92 }
93
94 #[inline]
95 pub fn try_wait(&self) -> Result<Option<T>, JoinError> {
97 self.inner.try_recv()
98 }
99
100 #[inline]
101 pub fn wait(self) -> Result<T, JoinError> {
103 self.inner.recv()
104 }
105
106 #[inline]
107 pub fn wait_timeout(&self, timeout: time::Duration) -> Result<Option<T>, JoinError> {
109 self.inner.recv_timeout(timeout)
110 }
111}
112
113impl<T> future::Future for JobHandle<T> {
114 type Output = Result<T, JoinError>;
115
116 #[inline]
117 fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
118 let inner = pin::Pin::new(&mut self.get_mut().inner);
119
120 future::Future::poll(inner, cx)
121 }
122}
123
124enum Message {
125 Execute(Box<dyn FnOnce() + Send + 'static>),
126 Shutdown(Option<oneshot::Sender<()>>),
127}
128
129const _: () = {
131 assert!(core::mem::size_of::<Box<dyn FnOnce() + Send + 'static>>() == 16);
132 assert!(core::mem::size_of::<oneshot::Sender<()>>() == 8);
133 assert!(core::mem::size_of::<Message>() == 16);
134};
135
/// Newtype around [`sync::mpsc::Receiver`] so one receiver can be shared by every worker
/// thread through an [`sync::Arc`].
struct Receiver<T>(pub sync::mpsc::Receiver<T>);

// SAFETY: the only operation this file performs through a shared `&Receiver` is
// `mpsc::Receiver::recv`, called concurrently by the worker threads. `std` declares
// `mpsc::Receiver` as `!Sync`, but since Rust 1.67 it wraps an MPMC channel whose
// receiving side tolerates concurrent callers.
// NOTE(review): that is an implementation detail rather than a documented guarantee;
// `Mutex<mpsc::Receiver<T>>` would be the guaranteed-sound alternative.
unsafe impl<T: Send> Sync for Receiver<T> {}

// SAFETY: `mpsc::Receiver<T>` is already `Send` for `T: Send`; this only restates it.
unsafe impl<T: Send> Send for Receiver<T> {}

impl<T> ops::Deref for Receiver<T> {
    type Target = sync::mpsc::Receiver<T>;

    /// Exposes the wrapped `mpsc::Receiver`, so `recv` can be called on the newtype directly.
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
153
154struct State {
155 send: sync::mpsc::Sender<Message>,
156 recv: sync::Arc<Receiver<Message>>,
157}
158
159#[derive(Clone)]
160struct ThreadBuilder {
161 idx: u16,
162 name: &'static str,
163 stack_size: usize,
164 receiver: sync::Arc<Receiver<Message>>,
165}
166
167impl ThreadBuilder {
168 pub fn spawn(self) -> Result<thread::JoinHandle<()>, io::Error> {
169 let mut result = thread::Builder::new();
170 if !self.name.is_empty() {
171 result = result.name(format!("{}-{}", self.name, self.idx))
172 }
173 if self.stack_size != 0 {
174 result = result.stack_size(self.stack_size)
175 }
176 let recv = self.receiver.clone();
177
178 let mut guard = ThreadGuard(Some(self));
180 let worker_fn = move || loop {
181 match recv.recv() {
182 Ok(Message::Execute(job)) => {
183 job();
184 },
185 Ok(Message::Shutdown(Some(notifier))) => {
186 guard.0.take();
187 let _ = notifier.send(());
188 break;
189 }
190 Ok(Message::Shutdown(None)) | Err(_) => {
191 guard.0.take();
192 break;
193 },
194 }
195 };
196
197 result.spawn(worker_fn)
198 }
199}
200
201#[repr(transparent)]
202struct ThreadGuard(Option<ThreadBuilder>);
203
204impl Drop for ThreadGuard {
205 fn drop(&mut self) {
206 if thread::panicking() {
207 if let Some(builder) = self.0.take() {
208 let _ = builder.spawn();
210 }
211 }
212 }
213}
214
215pub struct ThreadPool {
235 stack_size: AtomicUsize,
236 thread_num: AtomicU16,
237 thread_num_lock: spin::Lock,
238 name: &'static str,
239 once_state: std::sync::OnceLock<State>,
240}
241
242impl ThreadPool {
243 #[inline(always)]
244 pub const fn new() -> Self {
246 Self::with_defaults("", 0)
247 }
248
249 #[inline(always)]
250 pub const fn with_defaults(name: &'static str, stack_size: usize) -> Self {
252 Self {
253 stack_size: AtomicUsize::new(stack_size),
254 thread_num: AtomicU16::new(0),
255 thread_num_lock: spin::Lock::new(),
256 name,
257 once_state: std::sync::OnceLock::new(),
258 }
259 }
260
261 fn get_state(&self) -> &State {
262 self.once_state.get_or_init(|| {
263 let (send, recv) = sync::mpsc::channel();
264 State {
265 send,
266 recv: sync::Arc::new(Receiver(recv)),
267 }
268 })
269 }
270
271 #[inline]
272 pub fn set_stack_size(&self, stack_size: usize) -> usize {
279 self.stack_size.swap(stack_size, Ordering::AcqRel)
280 }
281
282 pub fn set_threads(&self, thread_num: u16) -> io::Result<u16> {
296 let _guard = self.thread_num_lock.lock();
297 let old_thread_num = self.thread_num.swap(thread_num, Ordering::Relaxed);
298
299 if old_thread_num > thread_num {
300 let state = self.get_state();
301
302 let shutdown_num = old_thread_num.saturating_sub(thread_num);
303 for _ in 0..shutdown_num {
304 if state.send.send(Message::Shutdown(None)).is_err() {
305 break;
306 }
307 }
308
309 } else if thread_num > old_thread_num {
310 let create_num = thread_num.saturating_sub(old_thread_num);
311 let state = self.get_state();
312
313 for num in 0..create_num {
314 let builder = ThreadBuilder {
315 idx: num,
316 stack_size: self.stack_size.load(Ordering::Relaxed),
317 name: self.name,
318 receiver: state.recv.clone(),
319 };
320
321 match builder.spawn() {
322 Ok(_) => (),
323 Err(error) => {
324 self.thread_num.store(old_thread_num.saturating_add(num), Ordering::Relaxed);
325 return Err(error);
326 }
327 }
328 }
329 }
330
331 Ok(old_thread_num)
332 }
333
334 pub fn shutdown(&mut self) {
338 let _guard = self.thread_num_lock.lock();
339 let old_thread_num = self.thread_num.swap(0, Ordering::Relaxed);
340
341 {
342 let state = self.get_state();
343
344 for _ in 0..old_thread_num {
345 if state.send.send(Message::Shutdown(None)).is_err() {
346 break;
347 }
348 }
349 }
350
351 let _ = self.once_state.take();
353 }
354
355 pub fn shutdown_and_join(&mut self) {
359 let _guard = self.thread_num_lock.lock();
360 let old_thread_num = self.thread_num.swap(0, Ordering::Relaxed);
361
362 let mut joiners = Vec::new();
363 {
364 let state = self.get_state();
365
366 for _ in 0..old_thread_num {
367 let (sender, receiver) = oneshot::oneshot();
368 if state.send.send(Message::Shutdown(Some(sender))).is_err() {
369 break;
370 }
371 joiners.push(receiver);
372 }
373 }
374
375 for receiver in joiners {
376 let _ = receiver.recv();
377 }
378 let _ = self.once_state.take();
380 }
381
382 pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
384 let state = self.get_state();
385 let _ = state.send.send(Message::Execute(Box::new(job)));
386 }
387
388 pub fn spawn_handle<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(&self, job: F) -> JobHandle<R> {
390 let (send, recv) = oneshot::oneshot();
391 let job = move || {
392 let _ = send.send((job)());
393 };
394 let _ = self.get_state().send.send(Message::Execute(Box::new(job)));
395
396 JobHandle {
397 inner: recv
398 }
399 }
400}
401
402impl Drop for ThreadPool {
403 #[inline(always)]
404 fn drop(&mut self) {
405 self.shutdown();
406 }
407}
408
409impl fmt::Debug for ThreadPool {
410 #[inline(always)]
411 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
412 fmt.write_fmt(format_args!("ThreadPool {{ threads: {} }}", self.thread_num.load(Ordering::Relaxed)))
413 }
414}
415
416unsafe impl Sync for ThreadPool {}
417