slave_pool/lib.rs
//! Simple thread pool
//!
//! # Usage
//!
//! ```rust
//! use slave_pool::ThreadPool;
//! const SECOND: core::time::Duration = core::time::Duration::from_secs(1);
//!
//! static POOL: ThreadPool = ThreadPool::new();
//!
//! POOL.set_threads(8).expect("Should start threads"); // Tell the pool how many threads you want
//!
//! let mut handles = Vec::new();
//! for idx in 0..8 {
//!     handles.push(POOL.spawn_handle(move || {
//!         std::thread::sleep(SECOND);
//!         idx
//!     }));
//! }
//!
//! POOL.set_threads(0).expect("Should stop threads"); // Tell it to shut down all threads
//!
//! for (idx, handle) in handles.drain(..).enumerate() {
//!     assert_eq!(handle.wait().unwrap(), idx) // Even though we told it to shut down all threads, it will finish the queued jobs first
//! }
//!
//! let handle = POOL.spawn_handle(|| {});
//! assert!(handle.wait_timeout(SECOND).is_err()); // All threads are shut down now
//!
//! POOL.set_threads(1).expect("Should start thread"); // But let's add one back
//!
//! assert!(handle.wait().is_ok());
//!
//! let handle = POOL.spawn_handle(|| panic!("Oh no!")); // We can panic, if we want
//!
//! assert!(handle.wait().is_err()); // In that case we'll get an error, but the thread will be fine
//!
//! let handle = POOL.spawn_handle(|| {});
//!
//! POOL.set_threads(0).expect("Should stop threads");
//!
//! assert!(handle.wait().is_ok());
//! ```

#![warn(missing_docs)]
#![cfg_attr(feature = "cargo-clippy", allow(clippy::style))]

use std::{thread, io};

use core::{time, fmt};
use core::sync::atomic::{Ordering, AtomicUsize, AtomicU16};

mod utils;
mod spin;
mod oneshot;

#[derive(Debug)]
///Describes possible reasons for a join to fail
pub enum JoinError {
    ///Job wasn't finished and has been aborted.
    Disconnect,
    ///Timeout expired, job continues.
    Timeout,
    ///Job was already consumed.
    ///
    ///Only possible if the handle has already finished successfully via `wait_timeout`
    ///or by polling it as a future through a reference.
    AlreadyConsumed,
}

///Handle to a job, allowing to wait for it to finish.
///
///It provides methods to block the current thread until the job finishes.
///Alternatively the handle implements `Future`, allowing it to be used in an async context.
///
///Note that it is undesirable for it to be awaited from multiple threads,
///therefore `Clone` is not implemented, even though it would be possible.
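///
///Since the handle implements `Future`, it can also be awaited.
///A minimal sketch of that (it assumes the `futures` crate for `block_on`, which is not
///a dependency of this crate, hence the example is marked `ignore`):
///
///```rust,ignore
///use slave_pool::ThreadPool;
///
///static POOL: ThreadPool = ThreadPool::new();
///
///POOL.set_threads(1).expect("Should start thread");
///
/// //Awaiting the handle yields the same `Result` as `wait`.
///let result = futures::executor::block_on(POOL.spawn_handle(|| 2 + 2));
///assert_eq!(result.unwrap(), 4);
///```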
pub struct JobHandle<T> {
    inner: oneshot::Receiver<T>
}

impl<T> fmt::Debug for JobHandle<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "JobHandle")
    }
}

impl<T> JobHandle<T> {
    #[inline]
    ///Waits for the job to finish, blocking indefinitely.
    pub fn wait(self) -> Result<T, JoinError> {
        self.inner.recv()
    }

    #[inline]
    ///Waits for the job to finish, blocking for at most the specified timeout.
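    ///
    ///A small sketch of telling a timeout apart from a finished job;
    ///it is marked `no_run` to avoid timing-sensitive assertions in doc tests:
    ///
    ///```rust,no_run
    ///use slave_pool::{ThreadPool, JoinError};
    ///
    ///static POOL: ThreadPool = ThreadPool::new();
    ///POOL.set_threads(1).expect("Should start thread");
    ///
    ///let handle = POOL.spawn_handle(|| std::thread::sleep(core::time::Duration::from_secs(2)));
    ///match handle.wait_timeout(core::time::Duration::from_millis(10)) {
    ///    //The job is still running; the handle can be waited on again later.
    ///    Err(JoinError::Timeout) => (),
    ///    other => panic!("Unexpected result: {:?}", other),
    ///}
    ///```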
    pub fn wait_timeout(&self, timeout: time::Duration) -> Result<T, JoinError> {
        self.inner.recv_timeout(timeout)
    }
}

impl<T> core::future::Future for JobHandle<T> {
    type Output = Result<T, JoinError>;

    #[inline]
    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
        //Project the pin onto the only field; `inner` is never moved out of the handle.
        let inner = unsafe {
            self.map_unchecked_mut(|this| &mut this.inner)
        };

        core::future::Future::poll(inner, cx)
    }
}

enum Message {
    Execute(Box<dyn FnOnce() + Send + 'static>),
    Shutdown,
}

struct State {
    send: crossbeam_channel::Sender<Message>,
    recv: crossbeam_channel::Receiver<Message>,
}

//SAFETY: `state` is only written once, inside `init_lock.call_once`, and only read afterwards.
unsafe impl Sync for ThreadPool {}

///Thread pool that allows changing the number of threads at runtime.
///
///On `Drop` it instructs threads to shut down, but doesn't wait for them to finish.
///
///# Note
///
///The pool doesn't implement any sort of flow control.
///If workers are busy, a message will remain in the queue until some thread can take it.
///
///# Clone
///
///The thread pool intentionally doesn't implement `Clone`.
///If you want to share it, do so via a global variable or by placing it on the heap.
///It is thread safe, so concurrent access is allowed.
///
///# Panic
///
///Each thread wraps execution of a job in `catch_unwind` to ensure that the thread is not aborted
///on panic.
pub struct ThreadPool {
    stack_size: AtomicUsize,
    thread_num: AtomicU16,
    thread_num_lock: spin::Lock,
    name: &'static str,
    init_lock: std::sync::Once,
    //Option is fine as the extra size comes from padding, so it
    //doesn't increase the overall size, but when changing the layout
    //consider switching to MaybeUninit
    state: core::cell::Cell<Option<State>>,
}

impl ThreadPool {
    ///Creates a new thread pool with default parameters.
    pub const fn new() -> Self {
        Self::with_defaults("", 0)
    }

    ///Creates a new instance, specifying all parameters: the thread name and the stack size.
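    ///
    ///A small sketch of a named pool with a custom stack size
    ///(the name and the size below are arbitrary examples):
    ///
    ///```rust,no_run
    ///use slave_pool::ThreadPool;
    ///
    ///static POOL: ThreadPool = ThreadPool::with_defaults("my-worker", 1024 * 1024);
    ///
    ///POOL.set_threads(2).expect("Should start threads");
    ///
    /// //Worker threads inherit the configured name.
    ///let handle = POOL.spawn_handle(|| std::thread::current().name().map(String::from));
    ///assert_eq!(handle.wait().unwrap().as_deref(), Some("my-worker"));
    ///```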
    pub const fn with_defaults(name: &'static str, stack_size: usize) -> Self {
        Self {
            stack_size: AtomicUsize::new(stack_size),
            thread_num: AtomicU16::new(0),
            thread_num_lock: spin::Lock::new(),
            name,
            init_lock: std::sync::Once::new(),
            state: core::cell::Cell::new(None),
        }
    }

    fn get_state(&self) -> &State {
        //Lazily create the channel, exactly once.
        self.init_lock.call_once(|| {
            let (send, recv) = crossbeam_channel::unbounded();
            self.state.set(Some(State {
                send,
                recv,
            }))
        });

        match unsafe { &*self.state.as_ptr() } {
            Some(state) => state,
            None => unreach!(),
        }
    }

    #[inline]
    ///Sets the stack size to use for new worker threads.
    ///
    ///By default (a value of `0`) the stack size of Rust's stdlib is used,
    ///but setting this overrides it, allowing customization.
    ///
    ///Returns the previous stack size.
    ///
    ///This setting takes effect only when creating new threads.
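    ///
    ///A small sketch (the sizes below are arbitrary examples):
    ///
    ///```rust,no_run
    ///use slave_pool::ThreadPool;
    ///
    ///static POOL: ThreadPool = ThreadPool::new();
    ///
    /// //The first worker uses the stdlib default stack size.
    ///POOL.set_threads(1).expect("Should start thread");
    ///
    /// //Only threads created after this call get the new stack size.
    ///POOL.set_stack_size(2 * 1024 * 1024);
    ///POOL.set_threads(2).expect("Should start threads");
    ///```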
    pub fn set_stack_size(&self, stack_size: usize) -> usize {
        self.stack_size.swap(stack_size, Ordering::AcqRel)
    }

    ///Sets the number of workers, starting new threads if it is greater than the previous number.
    ///
    ///If it is less, the extra threads are shut down.
    ///Returns the previous number of threads.
    ///
    ///By default, when the pool is created, no threads are started.
    ///
    ///If any thread fails to start, the function returns immediately with an error.
    ///
    ///# Note
    ///
    ///Calls to this method are serialized: under the hood it locks out
    ///any other attempt to change the number of threads until it is done.
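    ///
    ///A small sketch of scaling the pool up and then down;
    ///the returned value is the previous number of threads:
    ///
    ///```rust,no_run
    ///use slave_pool::ThreadPool;
    ///
    ///static POOL: ThreadPool = ThreadPool::new();
    ///
    ///assert_eq!(POOL.set_threads(4).expect("Should start threads"), 0); //0 -> 4 workers
    ///assert_eq!(POOL.set_threads(1).expect("Should succeed"), 4); //4 -> 1, three workers are told to shut down
    ///```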
    pub fn set_threads(&self, thread_num: u16) -> io::Result<u16> {
        let mut _guard = self.thread_num_lock.lock();
        let old_thread_num = self.thread_num.load(Ordering::Relaxed);
        self.thread_num.store(thread_num, Ordering::Relaxed);

        if old_thread_num > thread_num {
            let state = self.get_state();

            let shutdown_num = old_thread_num - thread_num;
            for _ in 0..shutdown_num {
                if state.send.send(Message::Shutdown).is_err() {
                    break;
                }
            }
        } else if thread_num > old_thread_num {
            let create_num = thread_num - old_thread_num;
            let stack_size = self.stack_size.load(Ordering::Acquire);
            let state = self.get_state();

            for num in 0..create_num {
                let recv = state.recv.clone();

                let builder = match self.name {
                    "" => thread::Builder::new(),
                    name => thread::Builder::new().name(name.to_owned()),
                };

                let builder = match stack_size {
                    0 => builder,
                    stack_size => builder.stack_size(stack_size),
                };

                //Worker loop: take jobs off the shared queue until told to shut down
                //or until the channel is disconnected.
                let result = builder.spawn(move || loop {
                    match recv.recv() {
                        Ok(Message::Execute(job)) => {
                            //TODO: for some reason the boxed closure has no UnwindSafe impl, wonder why?
                            let job = std::panic::AssertUnwindSafe(job);
                            let _ = std::panic::catch_unwind(|| (job.0)());
                        },
                        Ok(Message::Shutdown) | Err(_) => break,
                    }
                });

                match result {
                    Ok(_) => (),
                    Err(error) => {
                        //Record how many threads actually exist before bailing out.
                        self.thread_num.store(old_thread_num + num, Ordering::Relaxed);
                        return Err(error);
                    }
                }
            }
        }

        Ok(old_thread_num)
    }

    ///Schedules a new execution, sending it over to one of the workers.
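    ///
    ///A small fire-and-forget sketch; the `std::sync::mpsc` channel below is only there
    ///to observe completion, it is not part of the pool's API:
    ///
    ///```rust,no_run
    ///use slave_pool::ThreadPool;
    ///
    ///static POOL: ThreadPool = ThreadPool::new();
    ///
    ///POOL.set_threads(1).expect("Should start thread");
    ///
    ///let (send, recv) = std::sync::mpsc::channel();
    ///POOL.spawn(move || {
    ///    //`spawn` doesn't return the job's result, so report it through the channel instead.
    ///    let _ = send.send(1 + 1);
    ///});
    ///assert_eq!(recv.recv().unwrap(), 2);
    ///```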
    pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
        let state = self.get_state();
        let _ = state.send.send(Message::Execute(Box::new(job)));
    }

    ///Schedules execution, returning a handle that allows to await and receive its result.
    pub fn spawn_handle<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(&self, job: F) -> JobHandle<R> {
        let (send, recv) = oneshot::oneshot();
        let job = move || {
            let _ = send.send(job());
        };
        let _ = self.get_state().send.send(Message::Execute(Box::new(job)));

        JobHandle {
            inner: recv
        }
    }
}

impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadPool {{ threads: {} }}", self.thread_num.load(Ordering::Relaxed))
    }
}