Skip to main content

poolio/
lib.rs

1//! poolio is a thread pool implementation using only channels for concurrency.
2//!
3//! ## Design
4//!
5//! A poolio thread pool is essentially made up of a 'supervisor' thread and a specified number of 'worker' threads.
//! A worker's only purpose is executing jobs (in the form of closures), while the supervisor is responsible for everything else - most importantly, receiving jobs from outside the pool via the public API and assigning them to workers.
7//! To this end, the thread pool is set up so that the supervisor can communicate with each worker separately and concurrently.
8//! This ensures that each worker remains equally busy.
9//! A single supervisor-worker communication cycle is roughly as follows:
10//! 1. The worker tells the supervisor its current status.
11//! 2. The supervisor decides what to tell the worker to do based on the current order-message from outside the pool and the worker's status.
12//! 3. The supervisor tells the worker what to do.
13//! 4. The worker attempts to perform the task assigned by the supervisor.
14//! 5. The worker tells the supervisor its current status.
15//!
16//! The following graphic illustrates the aforementioned communication model between a supervisor thread S and a worker thread W:
17//!
18//! <pre>
19//!    W
20//!    _
21//!    .
22//!    .
23//!    send-status
24//!    .   O
25//!    .     O
26//!    .       O                 send-message
27//!    .         O                   O
28//!    .           O               O
29//!    recv         recv         O
30//!   * .  O       O  . .      O
31//!  .   .   O   O   .   .   O
32//! .     e    O    m     recv . . | S
33//!  .   .   O   O   .   *
34//!   . .  O       O  . .
35//!    send-status  send-message
36//!
37//! X | . . * : arrow starting at | and ending at * representing the control-flow of thread X
38//! O O O O O : channel
39//! e : execute job
40//! m : manage workers
41//! </pre>
42//!
43//! ## Usage
44//!
45//! To use a poolio [`ThreadPool`], you simply set one up using the [`ThreadPool::new`] method and task the pool to run jobs using the [`ThreadPool::execute`] method.
46//!
47//! # Examples
48//!
49//! Setting up a pool to make a server multi-threaded:
50//!
51//! ```
52//! fn handle(req: usize) {
53//!     println!("Handled!")
54//! }
55//!
56//! let server_requests = [1, 2, 3, 4, 5, 6, 7, 8, 9];
57//!
58//! let pool = poolio::ThreadPool::new(3, poolio::PanicSwitch::Kill).unwrap();
59//!
60//! for req in server_requests {
61//!     pool.execute(move || {
62//!         handle(req);
63//!     });
64//! }
65//! ```
66
mod thread {
    //! A thin wrapper around parts of [`std::thread`] that sidesteps ownership issues when joining a thread whose handle is embedded in a larger data structure.
    //! Handles are stored in an [`Option`], so joining only needs a mutable reference: the handle is taken out and `None` is left in its place.

    use std::thread;

    /// A [`std::thread::JoinHandle<T>`] kept in an [`Option`] so that it can be "stolen" for joining.
    pub type JoinHandle = Option<thread::JoinHandle<()>>;

    /// Runs `f` on a new thread via [`std::thread::spawn`] and returns the handle wrapped in [`Option::Some`].
    #[inline]
    pub fn spawn<F>(f: F) -> JoinHandle
    where
        F: FnOnce() + Send + 'static,
    {
        let handle = thread::spawn(f);
        Some(handle)
    }

    /// Takes the handle out of `thread` (leaving `None` behind) and joins it with [`std::thread::JoinHandle<T>::join`].
    /// - `thread` is a reference to the handle this function intends to take.
    ///
    /// # Panics
    ///
    /// Panics if `thread` is `None` or if joining fails (which occurs if the thread panicked).
    pub fn join(thread: &mut JoinHandle) {
        // Resolve the "no handle" case first so the join itself reads straight through.
        let handle = match thread.take() {
            Some(handle) => handle,
            None => panic!("Cannot join: no thread has been provided."),
        };

        // A failed join carries the payload of the panic that ended the thread; re-raise it here.
        if let Err(payload) = handle.join() {
            panic!("{:?}", payload);
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_spawn() {
            let handle = spawn(|| {});
            assert!(handle.is_some());
        }

        #[test]
        fn test_join() {
            let mut handle = spawn(|| {});
            join(&mut handle);
            // Joining must have consumed the handle.
            assert!(handle.is_none());
        }

        #[test]
        #[should_panic]
        fn test_join_panic_some() {
            let mut doomed = spawn(|| panic!("Oh no!"));
            join(&mut doomed);
        }

        #[test]
        #[should_panic]
        fn test_join_panic_none() {
            let mut missing: JoinHandle = None;
            join(&mut missing);
        }
    }
}
133
134use thread::JoinHandle;
135
136use std::fmt;
137use std::panic::UnwindSafe;
138
139use crossbeam::channel::unbounded as channel;
140use crossbeam::channel::Sender;
141
/// A unit of work for the [`ThreadPool`]: a boxed closure that is called at most once and can be sent to another thread.
type Job = Box<dyn FnOnce() + Send + UnwindSafe + 'static>;

/// Orders that can be given to the [`ThreadPool`].
enum Message {
    /// Order to execute the contained job.
    NewJob(Job),
    /// Order to finish the remaining jobs and then shut down.
    Terminate,
}

impl fmt::Display for Message {
    /// Writes a short bracketed tag naming the kind of order; the job itself is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let tag = match self {
            Message::NewJob(_) => "[NewJob]",
            Message::Terminate => "[Terminate]",
        };
        f.write_str(tag)
    }
}
161
/// Configuration for how the [`ThreadPool`] handles panics in jobs.
///
/// The type is a plain, field-less switch, so it is cheap to copy, can be compared, and can be printed for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PanicSwitch {
    /// The pool finishes parallel running jobs and then kills the entire process if a job panics.
    Kill,
    /// The pool ignores panicked jobs and simply respawns the affected threads.
    Respawn,
}
169
170/// Abstracts thread pools.
171pub struct ThreadPool {
172    /// Interface to the pool-controlling thread.
173    supervisor: Supervisor,
174}
175
176impl ThreadPool {
177    /// Sets up a new pool.
178    /// - `size` is the (non-zero) number of worker threads in the pool.
179    /// - `mode` is the setting for the panic switch.
180    ///
181    /// # Errors
182    ///
183    /// Returns an error if `size` is 0 (as a pool without worker threads is invalid).
184    ///
185    /// # Examples
186    ///
187    /// Setting up a pool with three worker threads in kill-mode:
188    ///
189    /// ```
190    /// let pool = poolio::ThreadPool::new(3, poolio::PanicSwitch::Kill).unwrap();
191    /// ```
192    pub fn new<'a>(size: usize, mode: PanicSwitch) -> Result<Self, &'a str> {
193        if size == 0 {
194            return Err("Setting up a pool with no workers is not allowed.");
195        }
196
197        let pool = Self {
198            supervisor: Supervisor::new(size, mode),
199        };
200        Ok(pool)
201    }
202
203    /// Runs a job in `self`.
204    /// - `f` is the job to be run, provided as a closure.
205    ///
206    /// # Panics
207    ///
208    /// Panics if the pool is unreachable.
209    ///
210    /// # Notes
211    ///
212    /// If `f` panics, the behavior is determined by the [`PanicSwitch`] setting of `self`.
213    ///
214    /// # Examples
215    ///
216    /// Setting up a pool and printing two strings concurrently:
217    ///
218    /// ```
219    /// let pool = poolio::ThreadPool::new(2, poolio::PanicSwitch::Kill).unwrap();
220    /// pool.execute(|| println!{"house"});
221    /// pool.execute(|| println!{"cat"});
222    /// ```
223    pub fn execute<F>(&self, f: F)
224    where
225        F: FnOnce() + UnwindSafe + Send + 'static,
226    {
227        let job = Box::new(f);
228
229        self.send(Message::NewJob(job));
230    }
231
232    /// Attempts to shut down `self` gracefully.
233    ///
234    /// # Panics
235    ///
236    /// Panics if:
237    /// 1. The pool is unreachable.
238    /// 2. Joining the threads causes a panic.
239    ///
240    /// # Notes
241    ///
242    /// Graceful shutdown ensures all remaining jobs are finished (except for panics in [`PanicSwitch::Kill`] mode).
243    fn terminate(&mut self) {
244        self.send(Message::Terminate);
245
246        thread::join(&mut self.supervisor.thread);
247    }
248
249    /// Wraps sending a [`Message`] to the pool.
250    ///
251    /// # Panics
252    ///
253    /// Panics if the receiver has already been deallocated.
254    fn send(&self, msg: Message) {
255        let panic_message = format!("Ordering {} failed. Pool is unreachable.", msg);
256
257        self.supervisor.orders_s.send(msg).expect(&panic_message);
258    }
259}
260
261impl Drop for ThreadPool {
262    /// Attempts to shut down `self` gracefully.
263    ///
264    /// # Panics
265    ///
266    /// Panics if:
267    /// 1. The pool is unreachable.
268    /// 2. Joining the threads causes a panic.
269    ///
270    /// Note: A panic during a drop will abort the entire process.
271    ///
272    /// # Notes
273    ///
274    /// Graceful shutdown ensures all remaining jobs are finished (except for panics in [`PanicSwitch::Kill`] mode).
275    fn drop(&mut self) {
276        self.terminate();
277    }
278}
279
/// The numeric identifier a worker is known by.
type StaffNumber = usize;

/// What a worker reports about itself when it is not busy; each report carries the worker's staff number.
enum Status {
    /// The worker is ready for an instruction.
    Idle(StaffNumber),
    /// The worker's job panicked.
    Panic(StaffNumber),
}

impl fmt::Display for Status {
    /// Writes a short bracketed tag naming the state; the staff number is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let tag = match self {
            Status::Idle(_) => "[idle]",
            Status::Panic(_) => "[panic]",
        };
        f.write_str(tag)
    }
}
297
298/// Abstracts supervisors.
299struct Supervisor {
300    /// Channel for sending orders.
301    orders_s: Sender<Message>,
302    /// Handle to join the supervisor thread.
303    thread: JoinHandle,
304}
305
306impl Supervisor {
307    /// Sets up a supervisor.
308    /// - `number_of_workers` is the number of workers to employ.
309    /// - `mode` configures behavior when workers report panicked jobs.
310    fn new(mut number_of_workers: usize, mode: PanicSwitch) -> Self {
311        // This channel is used by the pool to contact the supervisor.
312        let (orders_s, orders_r) = channel();
313
314        let thread = thread::spawn(move || {
315            // This channel is used by the workers to contact the supervisor.
316            let (statuses_s, statuses_r) = channel();
317
318            // Construct `number_of_workers` worker threads.
319            let mut workers = Vec::with_capacity(number_of_workers);
320            for id in 0..number_of_workers {
321                workers.push(Worker::new(id, statuses_s.clone()));
322            }
323
324            // Track how many jobs have panicked.
325            let mut panicked_jobs = 0;
326
327            // Keep running to distribute jobs among idle workers.
328            'distribute_jobs: while let Message::NewJob(job) = orders_r.recv().unwrap() {
329                'query_status: loop {
330                    match statuses_r.recv().unwrap() {
331                        Status::Idle(id) => {
332                            workers[id]
333                                .instructions_s
334                                .send(Message::NewJob(job))
335                                .unwrap();
336                            break 'query_status;
337                        }
338                        Status::Panic(id) => {
339                            thread::join(&mut workers[id].thread);
340                            match mode {
341                                PanicSwitch::Kill => {
342                                    panicked_jobs += 1;
343                                    number_of_workers -= 1;
344                                    break 'distribute_jobs;
345                                }
346                                PanicSwitch::Respawn => {
347                                    workers[id] = Worker::new(id, statuses_s.clone());
348                                }
349                            }
350                        }
351                    }
352                }
353            }
354
355            // Destruct all remaining worker threads.
356            while number_of_workers != 0 {
357                match statuses_r.recv().unwrap() {
358                    Status::Idle(id) => {
359                        workers[id].instructions_s.send(Message::Terminate).unwrap();
360                        thread::join(&mut workers[id].thread);
361                    }
362                    Status::Panic(id) => {
363                        thread::join(&mut workers[id].thread);
364                        if matches!(mode, PanicSwitch::Kill) {
365                            panicked_jobs += 1;
366                        }
367                    }
368                }
369                number_of_workers -= 1;
370            }
371
372            if panicked_jobs > 0 {
373                eprintln!("Aborting process: {} panicked jobs.", panicked_jobs);
374                std::process::abort();
375            }
376
377            // Ensure that `orders_r` lives as long as the thread to prevent reachability errors.
378            drop(orders_r);
379        });
380
381        Self { orders_s, thread }
382    }
383}
384
385/// Abstracts workers.
386struct Worker {
387    /// Channel for sending instructions.
388    instructions_s: Sender<Message>,
389    /// Handle to join the worker thread.
390    thread: JoinHandle,
391}
392
393impl Worker {
394    /// Sets up a new worker.
395    /// - `id` is the worker's staff number.
396    /// - `statuses_s` is where the worker reports its current status.
397    fn new(id: StaffNumber, statuses_s: Sender<Status>) -> Self {
398        // This channel is used by the supervisor to contact this worker.
399        let (instructions_s, instructions_r) = channel();
400
401        let thread = thread::spawn(move || {
402            // Report for duty.
403            statuses_s.send(Status::Idle(id)).unwrap();
404
405            // Keep running to execute jobs.
406            loop {
407                let message = instructions_r.recv().unwrap();
408
409                match message {
410                    Message::NewJob(job) => match std::panic::catch_unwind(job) {
411                        Ok(()) => {
412                            statuses_s.send(Status::Idle(id)).unwrap();
413                        }
414                        Err(_) => {
415                            statuses_s.send(Status::Panic(id)).unwrap();
416                            break;
417                        }
418                    },
419                    Message::Terminate => break,
420                }
421            }
422        });
423
424        Self {
425            instructions_s,
426            thread,
427        }
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    // Settings shared by the tests below.
    // Other pool sizes worth trying: 6, 12 and 36.
    const SIZE: usize = 2;
    // The alternative mode is `PanicSwitch::Kill`.
    const MODE: PanicSwitch = PanicSwitch::Respawn;
    const ID: StaffNumber = 0;

    #[test]
    fn test_threadpool_new_ok() {
        let outcome = ThreadPool::new(SIZE, MODE);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_threadpool_new_err() {
        let outcome = ThreadPool::new(0, MODE);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_threadpool_execute() {
        const N: usize = 5;

        let pool = ThreadPool::new(SIZE, MODE).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..N {
            // One round: `SIZE` counting jobs ...
            for _ in 0..SIZE {
                let counter = Arc::clone(&counter);
                pool.execute(move || {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            }
            // ... followed by a panicking job, but only in respawn-mode.
            if matches!(MODE, PanicSwitch::Respawn) {
                pool.execute(|| panic!("Oh no!"));
            }
        }

        // Dropping the pool shuts it down before the count is checked.
        drop(pool);

        assert_eq!(N * SIZE, counter.load(Ordering::SeqCst));
    }

    #[test]
    fn test_worker_thread_newjob() {
        let (statuses_s, statuses_r) = channel();
        let mut worker = Worker::new(ID, statuses_s);

        // The worker reports for duty first.
        assert!(matches!(statuses_r.recv().unwrap(), Status::Idle(ID)));

        // A job that completes normally leaves the worker idle again.
        let done = Arc::new(AtomicBool::new(false));
        let done_in_job = Arc::clone(&done);
        let job = Box::new(move || {
            done_in_job.store(true, Ordering::SeqCst);
        });
        worker.instructions_s.send(Message::NewJob(job)).unwrap();
        assert!(matches!(statuses_r.recv().unwrap(), Status::Idle(ID)));
        assert!(done.load(Ordering::SeqCst));

        // A panicking job is reported as such.
        let job = Box::new(|| panic!("Oh no!"));
        worker.instructions_s.send(Message::NewJob(job)).unwrap();
        assert!(matches!(statuses_r.recv().unwrap(), Status::Panic(ID)));

        thread::join(&mut worker.thread);
    }

    #[test]
    fn test_worker_thread_terminate() {
        let (statuses_s, statuses_r) = channel();
        let mut worker = Worker::new(ID, statuses_s);

        // The worker reports for duty first.
        assert!(matches!(statuses_r.recv().unwrap(), Status::Idle(ID)));

        worker.instructions_s.send(Message::Terminate).unwrap();

        thread::join(&mut worker.thread);
    }
}