poolio/lib.rs
1//! poolio is a thread pool implementation using only channels for concurrency.
2//!
3//! ## Design
4//!
5//! A poolio thread pool is essentially made up of a 'supervisor' thread and a specified number of 'worker' threads.
//! A worker's only purpose is executing jobs (in the form of closures), while the supervisor is responsible for everything else - most importantly, assigning the jobs it receives from outside the pool (via the public API) to workers.
//! To this end, the thread pool is set up so that the supervisor can communicate with each worker separately and concurrently.
//! Each new job goes to whichever worker next reports that it is idle, which keeps all workers busy while there is work to do.
9//! A single supervisor-worker communication cycle is roughly as follows:
10//! 1. The worker tells the supervisor its current status.
11//! 2. The supervisor decides what to tell the worker to do based on the current order-message from outside the pool and the worker's status.
12//! 3. The supervisor tells the worker what to do.
13//! 4. The worker attempts to perform the task assigned by the supervisor.
14//! 5. The worker tells the supervisor its current status.
15//!
16//! The following graphic illustrates the aforementioned communication model between a supervisor thread S and a worker thread W:
17//!
//! <pre>
//! W                                   S
//! |                                   |
//! send-status  O O O O O O O O O O    recv
//! .                                   .
//! .                                   m
//! .                                   .
//! recv         O O O O O O O O O O    send-message
//! .                                   .
//! e                                   .
//! .                                   .
//! send-status  O O O O O O O O O O    recv
//! *                                   *
//!
//! X | . . * : arrow starting at | and ending at * representing the control-flow of thread X
//! O O O O O : channel (data flows from the side that sends to the side that receives)
//! e         : execute job
//! m         : manage workers
//! </pre>
42//!
43//! ## Usage
44//!
45//! To use a poolio [`ThreadPool`], you simply set one up using the [`ThreadPool::new`] method and task the pool to run jobs using the [`ThreadPool::execute`] method.
46//!
47//! # Examples
48//!
49//! Setting up a pool to make a server multi-threaded:
50//!
51//! ```
52//! fn handle(req: usize) {
53//! println!("Handled!")
54//! }
55//!
56//! let server_requests = [1, 2, 3, 4, 5, 6, 7, 8, 9];
57//!
58//! let pool = poolio::ThreadPool::new(3, poolio::PanicSwitch::Kill).unwrap();
59//!
60//! for req in server_requests {
61//! pool.execute(move || {
62//! handle(req);
63//! });
64//! }
65//! ```
66
mod thread {
    //! This module is a wrapper for parts of the [`std::thread`] module to handle ownership issues when joining threads embedded in a larger data structure.
    //! It allows you to spawn threads that return a handle, which you can join normally even if the handle is part of a larger data structure.

    use std::any::Any;
    use std::thread;

    /// Wraps [`std::thread::JoinHandle<T>`] to allow for "stealing" the handle for joining.
    ///
    /// `Some` holds a joinable thread; `None` means the handle has already been taken by [`join`].
    pub type JoinHandle = Option<thread::JoinHandle<()>>;

    /// Wraps [`std::thread::spawn`] in an [`Option::Some`].
    #[inline]
    pub fn spawn<F>(f: F) -> JoinHandle
    where
        F: FnOnce() + Send + 'static,
    {
        Some(thread::spawn(f))
    }

    /// Takes the thread handle from the call site to pass it to [`std::thread::JoinHandle<T>::join`].
    /// - `thread` is a reference to the handle this function intends to take; it is left as `None`.
    ///
    /// # Panics
    ///
    /// Panics if the `thread` is `None` or if joining the thread fails (which occurs if the thread panicked).
    /// In the latter case the joined thread's panic message is re-raised when it is a string,
    /// so the original cause is not lost.
    pub fn join(thread: &mut JoinHandle) {
        match thread.take() {
            Some(handle) => {
                if let Err(payload) = handle.join() {
                    panic!("{}", payload_message(&*payload));
                }
            }
            None => panic!("Cannot join: no thread has been provided."),
        }
    }

    /// Extracts a human-readable message from a panic payload.
    ///
    /// `panic!` produces a `&'static str` payload for plain literals and a `String` for formatted
    /// messages; any other payload (e.g. from [`std::panic::panic_any`]) gets a generic description.
    fn payload_message(payload: &(dyn Any + Send + 'static)) -> String {
        if let Some(message) = payload.downcast_ref::<&str>() {
            (*message).to_string()
        } else if let Some(message) = payload.downcast_ref::<String>() {
            message.clone()
        } else {
            String::from("joined thread panicked with a non-string payload")
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_spawn() {
            assert!(spawn(|| {}).is_some());
        }

        #[test]
        fn test_join() {
            let mut thread = spawn(|| {});
            join(&mut thread);
            assert!(thread.is_none());
        }

        #[test]
        #[should_panic(expected = "Oh no!")]
        fn test_join_panic_some() {
            join(&mut spawn(|| panic!("Oh no!")));
        }

        #[test]
        #[should_panic(expected = "non-string payload")]
        fn test_join_panic_custom_payload() {
            join(&mut spawn(|| std::panic::panic_any(42_i32)));
        }

        #[test]
        #[should_panic]
        fn test_join_panic_none() {
            join(&mut None);
        }
    }
}
133
134use thread::JoinHandle;
135
136use std::fmt;
137use std::panic::UnwindSafe;
138
139use crossbeam::channel::unbounded as channel;
140use crossbeam::channel::Sender;
141
/// A boxed closure that the [`ThreadPool`] runs on one of its workers.
///
/// Jobs must be `UnwindSafe` so that a worker can catch their panics.
type Job = Box<dyn FnOnce() + UnwindSafe + Send + 'static>;

/// Orders passed down the pool: from the public API to the supervisor,
/// and from the supervisor to an individual worker.
enum Message {
    /// Order to run the contained job.
    NewJob(Job),
    /// Order to shut down after finishing the remaining jobs.
    Terminate,
}

impl fmt::Display for Message {
    /// Writes a bracketed label naming the variant; a job's closure is never shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Message::NewJob(_) => "[NewJob]",
            Message::Terminate => "[Terminate]",
        };
        f.write_str(label)
    }
}
161
/// Configuration for how the [`ThreadPool`] handles panics in jobs.
///
/// A panicked job is noticed when the supervisor next looks for an idle worker
/// or when the pool shuts down.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PanicSwitch {
    /// The pool stops handing out jobs, lets the jobs still running in parallel finish,
    /// and then kills (aborts) the entire process if a job panics.
    Kill,
    /// The pool ignores panicked jobs and simply respawns the affected threads.
    Respawn,
}
169
170/// Abstracts thread pools.
171pub struct ThreadPool {
172 /// Interface to the pool-controlling thread.
173 supervisor: Supervisor,
174}
175
176impl ThreadPool {
177 /// Sets up a new pool.
178 /// - `size` is the (non-zero) number of worker threads in the pool.
179 /// - `mode` is the setting for the panic switch.
180 ///
181 /// # Errors
182 ///
183 /// Returns an error if `size` is 0 (as a pool without worker threads is invalid).
184 ///
185 /// # Examples
186 ///
187 /// Setting up a pool with three worker threads in kill-mode:
188 ///
189 /// ```
190 /// let pool = poolio::ThreadPool::new(3, poolio::PanicSwitch::Kill).unwrap();
191 /// ```
192 pub fn new<'a>(size: usize, mode: PanicSwitch) -> Result<Self, &'a str> {
193 if size == 0 {
194 return Err("Setting up a pool with no workers is not allowed.");
195 }
196
197 let pool = Self {
198 supervisor: Supervisor::new(size, mode),
199 };
200 Ok(pool)
201 }
202
203 /// Runs a job in `self`.
204 /// - `f` is the job to be run, provided as a closure.
205 ///
206 /// # Panics
207 ///
208 /// Panics if the pool is unreachable.
209 ///
210 /// # Notes
211 ///
212 /// If `f` panics, the behavior is determined by the [`PanicSwitch`] setting of `self`.
213 ///
214 /// # Examples
215 ///
216 /// Setting up a pool and printing two strings concurrently:
217 ///
218 /// ```
219 /// let pool = poolio::ThreadPool::new(2, poolio::PanicSwitch::Kill).unwrap();
220 /// pool.execute(|| println!{"house"});
221 /// pool.execute(|| println!{"cat"});
222 /// ```
223 pub fn execute<F>(&self, f: F)
224 where
225 F: FnOnce() + UnwindSafe + Send + 'static,
226 {
227 let job = Box::new(f);
228
229 self.send(Message::NewJob(job));
230 }
231
232 /// Attempts to shut down `self` gracefully.
233 ///
234 /// # Panics
235 ///
236 /// Panics if:
237 /// 1. The pool is unreachable.
238 /// 2. Joining the threads causes a panic.
239 ///
240 /// # Notes
241 ///
242 /// Graceful shutdown ensures all remaining jobs are finished (except for panics in [`PanicSwitch::Kill`] mode).
243 fn terminate(&mut self) {
244 self.send(Message::Terminate);
245
246 thread::join(&mut self.supervisor.thread);
247 }
248
249 /// Wraps sending a [`Message`] to the pool.
250 ///
251 /// # Panics
252 ///
253 /// Panics if the receiver has already been deallocated.
254 fn send(&self, msg: Message) {
255 let panic_message = format!("Ordering {} failed. Pool is unreachable.", msg);
256
257 self.supervisor.orders_s.send(msg).expect(&panic_message);
258 }
259}
260
261impl Drop for ThreadPool {
262 /// Attempts to shut down `self` gracefully.
263 ///
264 /// # Panics
265 ///
266 /// Panics if:
267 /// 1. The pool is unreachable.
268 /// 2. Joining the threads causes a panic.
269 ///
270 /// Note: A panic during a drop will abort the entire process.
271 ///
272 /// # Notes
273 ///
274 /// Graceful shutdown ensures all remaining jobs are finished (except for panics in [`PanicSwitch::Kill`] mode).
275 fn drop(&mut self) {
276 self.terminate();
277 }
278}
279
/// Identifies a worker; it is the worker's index in the supervisor's worker list.
type StaffNumber = usize;

/// What a worker reports to the supervisor whenever it is not busy.
///
/// Each variant carries the reporting worker's staff number.
enum Status {
    /// The worker is ready for its next instruction.
    Idle(StaffNumber),
    /// The worker's last job panicked and the worker has stopped.
    Panic(StaffNumber),
}

impl fmt::Display for Status {
    /// Writes a bracketed, lower-case label for the state; the staff number is omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Status::Idle(_) => "[idle]",
            Status::Panic(_) => "[panic]",
        })
    }
}
297
298/// Abstracts supervisors.
299struct Supervisor {
300 /// Channel for sending orders.
301 orders_s: Sender<Message>,
302 /// Handle to join the supervisor thread.
303 thread: JoinHandle,
304}
305
306impl Supervisor {
307 /// Sets up a supervisor.
308 /// - `number_of_workers` is the number of workers to employ.
309 /// - `mode` configures behavior when workers report panicked jobs.
310 fn new(mut number_of_workers: usize, mode: PanicSwitch) -> Self {
311 // This channel is used by the pool to contact the supervisor.
312 let (orders_s, orders_r) = channel();
313
314 let thread = thread::spawn(move || {
315 // This channel is used by the workers to contact the supervisor.
316 let (statuses_s, statuses_r) = channel();
317
318 // Construct `number_of_workers` worker threads.
319 let mut workers = Vec::with_capacity(number_of_workers);
320 for id in 0..number_of_workers {
321 workers.push(Worker::new(id, statuses_s.clone()));
322 }
323
324 // Track how many jobs have panicked.
325 let mut panicked_jobs = 0;
326
327 // Keep running to distribute jobs among idle workers.
328 'distribute_jobs: while let Message::NewJob(job) = orders_r.recv().unwrap() {
329 'query_status: loop {
330 match statuses_r.recv().unwrap() {
331 Status::Idle(id) => {
332 workers[id]
333 .instructions_s
334 .send(Message::NewJob(job))
335 .unwrap();
336 break 'query_status;
337 }
338 Status::Panic(id) => {
339 thread::join(&mut workers[id].thread);
340 match mode {
341 PanicSwitch::Kill => {
342 panicked_jobs += 1;
343 number_of_workers -= 1;
344 break 'distribute_jobs;
345 }
346 PanicSwitch::Respawn => {
347 workers[id] = Worker::new(id, statuses_s.clone());
348 }
349 }
350 }
351 }
352 }
353 }
354
355 // Destruct all remaining worker threads.
356 while number_of_workers != 0 {
357 match statuses_r.recv().unwrap() {
358 Status::Idle(id) => {
359 workers[id].instructions_s.send(Message::Terminate).unwrap();
360 thread::join(&mut workers[id].thread);
361 }
362 Status::Panic(id) => {
363 thread::join(&mut workers[id].thread);
364 if matches!(mode, PanicSwitch::Kill) {
365 panicked_jobs += 1;
366 }
367 }
368 }
369 number_of_workers -= 1;
370 }
371
372 if panicked_jobs > 0 {
373 eprintln!("Aborting process: {} panicked jobs.", panicked_jobs);
374 std::process::abort();
375 }
376
377 // Ensure that `orders_r` lives as long as the thread to prevent reachability errors.
378 drop(orders_r);
379 });
380
381 Self { orders_s, thread }
382 }
383}
384
385/// Abstracts workers.
386struct Worker {
387 /// Channel for sending instructions.
388 instructions_s: Sender<Message>,
389 /// Handle to join the worker thread.
390 thread: JoinHandle,
391}
392
393impl Worker {
394 /// Sets up a new worker.
395 /// - `id` is the worker's staff number.
396 /// - `statuses_s` is where the worker reports its current status.
397 fn new(id: StaffNumber, statuses_s: Sender<Status>) -> Self {
398 // This channel is used by the supervisor to contact this worker.
399 let (instructions_s, instructions_r) = channel();
400
401 let thread = thread::spawn(move || {
402 // Report for duty.
403 statuses_s.send(Status::Idle(id)).unwrap();
404
405 // Keep running to execute jobs.
406 loop {
407 let message = instructions_r.recv().unwrap();
408
409 match message {
410 Message::NewJob(job) => match std::panic::catch_unwind(job) {
411 Ok(()) => {
412 statuses_s.send(Status::Idle(id)).unwrap();
413 }
414 Err(_) => {
415 statuses_s.send(Status::Panic(id)).unwrap();
416 break;
417 }
418 },
419 Message::Terminate => break,
420 }
421 }
422 });
423
424 Self {
425 instructions_s,
426 thread,
427 }
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    // Shared settings for the tests below. Also worth running with larger pools
    // (e.g. SIZE = 6, 12 or 36) and with `PanicSwitch::Kill`.
    const SIZE: usize = 2;
    const MODE: PanicSwitch = PanicSwitch::Respawn;
    const ID: StaffNumber = 0;

    #[test]
    fn test_threadpool_new_ok() {
        let outcome = ThreadPool::new(SIZE, MODE);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_threadpool_new_err() {
        let outcome = ThreadPool::new(0, MODE);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_threadpool_execute() {
        const ROUNDS: usize = 5;

        let pool = ThreadPool::new(SIZE, MODE).unwrap();
        let completed = Arc::new(AtomicUsize::new(0));

        for _ in 0..ROUNDS {
            // One round: SIZE counting jobs, plus a panicking job when the pool can survive it.
            for _ in 0..SIZE {
                let completed = Arc::clone(&completed);
                pool.execute(move || {
                    completed.fetch_add(1, Ordering::SeqCst);
                });
            }
            if matches!(MODE, PanicSwitch::Respawn) {
                pool.execute(|| panic!("Oh no!"));
            }
        }

        // Dropping the pool waits for the submitted jobs to finish.
        drop(pool);

        assert_eq!(ROUNDS * SIZE, completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_worker_thread_newjob() {
        let (statuses_s, statuses_r) = channel();
        let mut worker = Worker::new(ID, statuses_s);

        // A fresh worker reports for duty first.
        assert!(matches!(statuses_r.recv().unwrap(), Status::Idle(ID)));

        // A job that returns normally leaves the worker idle again.
        let ran = Arc::new(AtomicBool::new(false));
        let ran_in_job = Arc::clone(&ran);
        let good_job = Box::new(move || {
            ran_in_job.store(true, Ordering::SeqCst);
        });
        worker.instructions_s.send(Message::NewJob(good_job)).unwrap();
        assert!(matches!(statuses_r.recv().unwrap(), Status::Idle(ID)));
        assert!(ran.load(Ordering::SeqCst));

        // A panicking job is reported, after which the worker stops and can be joined.
        let doomed_job = Box::new(|| panic!("Oh no!"));
        worker.instructions_s.send(Message::NewJob(doomed_job)).unwrap();
        assert!(matches!(statuses_r.recv().unwrap(), Status::Panic(ID)));

        thread::join(&mut worker.thread);
    }

    #[test]
    fn test_worker_thread_terminate() {
        let (statuses_s, statuses_r) = channel();
        let mut worker = Worker::new(ID, statuses_s);

        // Wait for the worker to report for duty, then dismiss it.
        assert!(matches!(statuses_r.recv().unwrap(), Status::Idle(ID)));
        worker.instructions_s.send(Message::Terminate).unwrap();

        thread::join(&mut worker.thread);
    }
}