/// Thin wrappers around `std::thread` that keep a handle in an `Option`, so a
/// thread can be joined through a `&mut` reference and the slot is left empty
/// afterwards.
mod thread {
    use std::panic::resume_unwind;
    use std::thread;

    /// A thread handle that is `Some` until the thread has been joined.
    pub type JoinHandle = Option<thread::JoinHandle<()>>;

    /// Spawns `f` on a new thread and returns its (populated) handle.
    #[inline]
    pub fn spawn<F>(f: F) -> JoinHandle
    where
        F: FnOnce() + Send + 'static,
    {
        Some(thread::spawn(f))
    }

    /// Blocks until the thread held in `thread` finishes, leaving `None` behind.
    ///
    /// # Panics
    ///
    /// * If `thread` is already `None` (never spawned, or already joined).
    /// * If the joined thread panicked. Its original payload is re-raised with
    ///   [`resume_unwind`], so callers (and `#[should_panic(expected = ..)]`)
    ///   see the real panic message instead of an opaque `Any { .. }`.
    pub fn join(thread: &mut JoinHandle) {
        // `take` empties the slot before joining, so it reads `None` even if we unwind below.
        match thread.take() {
            Some(handle) => {
                if let Err(payload) = handle.join() {
                    resume_unwind(payload);
                }
            }
            None => panic!("Cannot join: no thread has been provided."),
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_spawn() {
            assert!(spawn(|| {}).is_some());
        }

        #[test]
        fn test_join() {
            let mut thread = spawn(|| {});
            join(&mut thread);
            assert!(thread.is_none());
        }

        #[test]
        #[should_panic]
        fn test_join_panic_some() {
            join(&mut spawn(|| panic!("Oh no!")));
        }

        #[test]
        #[should_panic(expected = "Oh no!")]
        fn test_join_panic_keeps_payload() {
            join(&mut spawn(|| panic!("Oh no!")));
        }

        #[test]
        #[should_panic]
        fn test_join_panic_none() {
            join(&mut None);
        }
    }
}
use thread::JoinHandle;
use std::fmt;
use std::panic::UnwindSafe;
use crossbeam::channel::unbounded as channel;
use crossbeam::channel::Sender;
/// A one-shot unit of work. It must be `Send` to cross into a worker thread,
/// and `UnwindSafe` so the worker can run it under `catch_unwind`.
type Job = Box<dyn FnOnce() + UnwindSafe + Send + 'static>;

/// An order passed from the pool to the supervisor, or from the supervisor to a worker.
enum Message {
    /// Run the enclosed job.
    NewJob(Job),
    /// Stop accepting orders and shut down.
    Terminate,
}

impl fmt::Display for Message {
    /// Renders only the variant tag; the boxed job itself is opaque.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Message::NewJob(_) => "[NewJob]",
            Message::Terminate => "[Terminate]",
        };
        f.write_str(label)
    }
}
/// Policy the pool applies when a submitted job panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PanicSwitch {
    /// Stop dispatching jobs, let the remaining workers finish what they are
    /// running, then abort the whole process.
    Kill,
    /// Replace the worker whose job panicked with a fresh one and keep going.
    Respawn,
}
/// A fixed-size pool of worker threads that runs submitted closures.
///
/// Jobs are forwarded to a supervisor thread, which hands each one to the next
/// worker that reports itself idle. Dropping the pool sends a shutdown order
/// and blocks until the supervisor thread has been joined.
pub struct ThreadPool {
    supervisor: Supervisor,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// `mode` selects what happens when a job panics (see [`PanicSwitch`]).
    ///
    /// # Errors
    ///
    /// Returns an error message if `size` is zero.
    pub fn new<'a>(size: usize, mode: PanicSwitch) -> Result<Self, &'a str> {
        if size == 0 {
            return Err("Setting up a pool with no workers is not allowed.");
        }
        Ok(Self {
            supervisor: Supervisor::new(size, mode),
        })
    }

    /// Queues `f` to run on one of the pool's workers.
    ///
    /// # Panics
    ///
    /// Panics if the supervisor thread is no longer receiving orders.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + UnwindSafe + Send + 'static,
    {
        self.send(Message::NewJob(Box::new(f)));
    }

    /// Asks the supervisor to shut down and blocks until its thread is joined.
    fn terminate(&mut self) {
        self.send(Message::Terminate);
        thread::join(&mut self.supervisor.thread);
    }

    /// Sends `msg` to the supervisor, panicking if it is unreachable.
    fn send(&self, msg: Message) {
        // Build the diagnostic only on the failure path. The unsent message comes back
        // inside the error, so there is no need to pre-format it (as
        // `expect(&format!(..))` did) and allocate a String for every job.
        if let Err(err) = self.supervisor.orders_s.send(msg) {
            panic!("Ordering {} failed. Pool is unreachable.", err.into_inner());
        }
    }
}

impl Drop for ThreadPool {
    /// Shuts the pool down, waiting for the supervisor thread to finish.
    fn drop(&mut self) {
        self.terminate();
    }
}
/// Index of a worker within the supervisor's worker list.
type StaffNumber = usize;

/// A report sent from a worker to the supervisor.
enum Status {
    /// The worker is waiting for its next instruction.
    Idle(StaffNumber),
    /// The worker's last job panicked, and the worker thread has stopped.
    Panic(StaffNumber),
}

impl fmt::Display for Status {
    /// Renders only the status tag, not the staff number.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Status::Idle(..) => f.write_str("[idle]"),
            Status::Panic(..) => f.write_str("[panic]"),
        }
    }
}
/// Owns the supervisor thread, which receives orders sent by a `ThreadPool`
/// and dispatches each job to whichever worker next reports itself idle.
struct Supervisor {
    // Sending half of the order channel; the receiving half is moved into the
    // supervisor thread.
    orders_s: Sender<Message>,
    // Handle of the supervisor thread; `None` once it has been joined.
    thread: JoinHandle,
}

impl Supervisor {
    /// Spawns the supervisor thread, which in turn spawns `number_of_workers`
    /// workers and distributes jobs received on the order channel.
    ///
    /// The supervisor thread runs in two phases:
    /// 1. Dispatch: for each `NewJob` order, wait for a worker to report
    ///    `Idle` and hand it the job. A `Panic` report joins that worker and then
    ///    either respawns it (`Respawn`) or ends this phase (`Kill`). In the
    ///    `Kill` case, the job that was waiting for a worker is dropped without
    ///    being run. Any other order (`Terminate`) also ends this phase.
    /// 2. Drain: collect one final status from each remaining worker, send
    ///    `Terminate` to idle ones, and join every worker thread.
    ///
    /// If any job panicked in `Kill` mode, the process is aborted after draining.
    /// In `Respawn` mode, panics seen during the drain phase are not counted.
    fn new(mut number_of_workers: usize, mode: PanicSwitch) -> Self {
        let (orders_s, orders_r) = channel();
        let thread = thread::spawn(move || {
            // Workers report on this channel. The supervisor keeps `statuses_s`
            // so it can hand clones to respawned workers.
            let (statuses_s, statuses_r) = channel();
            let mut workers = Vec::with_capacity(number_of_workers);
            for id in 0..number_of_workers {
                workers.push(Worker::new(id, statuses_s.clone()));
            }
            // Number of panicked jobs that should trigger an abort (Kill mode only).
            let mut panicked_jobs = 0;
            // Phase 1 (dispatch): stops at the first non-`NewJob` order, or on a Kill-mode panic.
            'distribute_jobs: while let Message::NewJob(job) = orders_r.recv().unwrap() {
                // Keep reading statuses until some worker is free to take `job`.
                'query_status: loop {
                    match statuses_r.recv().unwrap() {
                        Status::Idle(id) => {
                            // This worker is blocked waiting for an instruction; give it the job.
                            workers[id]
                                .instructions_s
                                .send(Message::NewJob(job))
                                .unwrap();
                            break 'query_status;
                        }
                        Status::Panic(id) => {
                            // A worker exits its thread right after reporting a panic,
                            // so this join does not block indefinitely.
                            thread::join(&mut workers[id].thread);
                            match mode {
                                PanicSwitch::Kill => {
                                    // This worker is gone for good; `job` is dropped unrun.
                                    panicked_jobs += 1;
                                    number_of_workers -= 1;
                                    break 'distribute_jobs;
                                }
                                PanicSwitch::Respawn => {
                                    // Reuse the same slot; the new worker reports Idle on startup.
                                    workers[id] = Worker::new(id, statuses_s.clone());
                                }
                            }
                        }
                    }
                }
            }
            // Phase 2 (drain): each remaining worker sends exactly one more status,
            // either Idle (waiting for an instruction) or Panic (already exited).
            while number_of_workers != 0 {
                match statuses_r.recv().unwrap() {
                    Status::Idle(id) => {
                        workers[id].instructions_s.send(Message::Terminate).unwrap();
                        thread::join(&mut workers[id].thread);
                    }
                    Status::Panic(id) => {
                        thread::join(&mut workers[id].thread);
                        if matches!(mode, PanicSwitch::Kill) {
                            panicked_jobs += 1;
                        }
                    }
                }
                number_of_workers -= 1;
            }
            // Kill policy: any panicked job takes the whole process down.
            if panicked_jobs > 0 {
                eprintln!("Aborting process: {} panicked jobs.", panicked_jobs);
                std::process::abort();
            }
            // Explicit drop of the order receiver; it would also drop when the closure returns.
            drop(orders_r);
        });
        Self { orders_s, thread }
    }
}
/// A worker thread together with the channel used to send it instructions.
struct Worker {
    // Delivers `NewJob` / `Terminate` instructions to the worker thread.
    instructions_s: Sender<Message>,
    // Handle of the worker thread; `None` once it has been joined.
    thread: JoinHandle,
}

impl Worker {
    /// Spawns worker `id`, which reports its status on `statuses_s`.
    ///
    /// The worker announces `Idle(id)` on startup and after every job that
    /// returns normally. A job that panics is caught and reported as
    /// `Panic(id)`, and the worker thread then exits. A `Terminate` instruction
    /// also ends the thread. The thread panics if either channel is disconnected.
    fn new(id: StaffNumber, statuses_s: Sender<Status>) -> Self {
        let (instructions_s, instructions_r) = channel();
        let thread = thread::spawn(move || {
            let report = |status: Status| statuses_s.send(status).unwrap();
            report(Status::Idle(id));
            while let Message::NewJob(job) = instructions_r.recv().unwrap() {
                if std::panic::catch_unwind(job).is_err() {
                    // The panic was contained; tell the supervisor and retire this thread.
                    report(Status::Panic(id));
                    return;
                }
                report(Status::Idle(id));
            }
        });
        Worker {
            instructions_s,
            thread,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    // Pool size used by the pool-level tests.
    const SIZE: usize = 2;
    // Respawn mode, so panicking jobs do not abort the test process.
    const MODE: PanicSwitch = PanicSwitch::Respawn;
    // Staff number given to the stand-alone worker under test.
    const ID: StaffNumber = 0;

    /// A non-zero pool size is accepted.
    #[test]
    fn test_threadpool_new_ok() {
        let outcome = ThreadPool::new(SIZE, MODE);
        assert!(outcome.is_ok());
    }

    /// A pool with zero workers is rejected.
    #[test]
    fn test_threadpool_new_err() {
        let outcome = ThreadPool::new(0, MODE);
        assert!(outcome.is_err());
    }

    /// Every submitted job runs before drop returns, even when panicking jobs
    /// are interleaved with the counting jobs.
    #[test]
    fn test_threadpool_execute() {
        const ROUNDS: usize = 5;
        let pool = ThreadPool::new(SIZE, MODE).unwrap();
        let hits = Arc::new(AtomicUsize::new(0));
        for _ in 0..ROUNDS {
            (0..SIZE).for_each(|_| {
                let hits = Arc::clone(&hits);
                pool.execute(move || {
                    hits.fetch_add(1, Ordering::SeqCst);
                });
            });
            if let PanicSwitch::Respawn = MODE {
                pool.execute(|| panic!("Oh no!"));
            }
        }
        drop(pool);
        assert_eq!(ROUNDS * SIZE, hits.load(Ordering::SeqCst));
    }

    /// A worker reports Idle around a normal job and Panic after a panicking one.
    #[test]
    fn test_worker_thread_newjob() {
        let (status_tx, status_rx) = channel();
        let mut worker = Worker::new(ID, status_tx);
        assert!(matches!(status_rx.recv().unwrap(), Status::Idle(ID)));

        let ran = Arc::new(AtomicBool::new(false));
        let job = {
            let ran = Arc::clone(&ran);
            Box::new(move || {
                ran.store(true, Ordering::SeqCst);
            })
        };
        worker.instructions_s.send(Message::NewJob(job)).unwrap();
        assert!(matches!(status_rx.recv().unwrap(), Status::Idle(ID)));
        assert!(ran.load(Ordering::SeqCst));

        let doomed = Box::new(|| panic!("Oh no!"));
        worker.instructions_s.send(Message::NewJob(doomed)).unwrap();
        assert!(matches!(status_rx.recv().unwrap(), Status::Panic(ID)));
        thread::join(&mut worker.thread);
    }

    /// A Terminate instruction lets the worker thread be joined.
    #[test]
    fn test_worker_thread_terminate() {
        let (status_tx, status_rx) = channel();
        let mut worker = Worker::new(ID, status_tx);
        assert!(matches!(status_rx.recv().unwrap(), Status::Idle(ID)));
        worker.instructions_s.send(Message::Terminate).unwrap();
        thread::join(&mut worker.thread);
    }
}