use std::sync::mpsc::{channel, Sender, Receiver};
use std::sync::{Arc, Mutex};
use std::thread;
/// A callable that can be invoked exactly once while owned by a `Box`.
///
/// This gives boxed, type-erased `FnOnce()` closures a method that takes the
/// box by value, so the closure inside can be consumed when it is run.
trait FnBox {
    /// Consumes the box and runs the callable it holds.
    fn call_box(self: Box<Self>);
}
/// Every `FnOnce()` closure or function is an `FnBox`: calling `call_box`
/// moves the callable out of its box and invokes it.
impl<F> FnBox for F
    where F: FnOnce()
{
    fn call_box(self: Box<Self>) {
        // Move the closure out of the heap allocation, then consume it.
        let job = *self;
        job()
    }
}
type Thunk<'a> = Box<FnBox + Send + 'a>;
/// Guard held by a worker thread for the duration of its run loop.
///
/// It borrows the three shared handles the worker was started with so that,
/// if it is dropped while still `active`, its `Drop` implementation can fix up
/// the counter and start a replacement worker with the same handles.
struct Sentinel<'a> {
    /// Shared receiving end of the job channel.
    jobs: &'a Arc<Mutex<Receiver<Thunk<'static>>>>,
    /// Shared counter that workers increment before running a job and
    /// decrement after it returns.
    thread_counter: &'a Arc<Mutex<usize>>,
    /// Shared upper bound that workers compare the counter against.
    thread_count_max: &'a Arc<Mutex<usize>>,
    /// `true` until `cancel` is called; decides whether `Drop` does anything.
    active: bool,
}
impl<'a> Sentinel<'a> {
    /// Builds an active sentinel that borrows the worker's shared handles.
    fn new(jobs: &'a Arc<Mutex<Receiver<Thunk<'static>>>>,
           thread_counter: &'a Arc<Mutex<usize>>,
           thread_count_max: &'a Arc<Mutex<usize>>)
           -> Sentinel<'a> {
        Sentinel {
            jobs,
            thread_counter,
            thread_count_max,
            active: true,
        }
    }

    /// Disarms the sentinel.
    ///
    /// Takes `self` by value: the flag is cleared and the sentinel is then
    /// dropped at the end of this call, so its `Drop` sees `active == false`
    /// and does nothing.
    fn cancel(mut self) {
        self.active = false;
    }
}
impl<'a> Drop for Sentinel<'a> {
    /// Replaces the worker if the sentinel was never cancelled.
    ///
    /// The worker loop in `spawn_in_pool` cancels its sentinel after leaving
    /// the loop normally, so reaching this with `active == true` means the
    /// worker unwound instead (presumably because the job it was running
    /// panicked, i.e. after the counter had been incremented).
    fn drop(&mut self) {
        if !self.active {
            return;
        }

        // Undo the increment whose matching decrement the worker never
        // reached. The lock guard is released at the end of this statement,
        // before the replacement thread is spawned.
        *self.thread_counter.lock().unwrap() -= 1;

        spawn_in_pool(Arc::clone(self.jobs),
                      Arc::clone(self.thread_counter),
                      Arc::clone(self.thread_count_max))
    }
}
/// A pool of worker threads that run submitted closures.
///
/// Every field is a `Sender` or an `Arc`, so cloning a `ThreadPool` produces
/// another handle to the same job channel and the same shared counters.
#[derive(Clone)]
pub struct ThreadPool {
    /// Sending half of the job channel; `execute` pushes boxed jobs into it.
    jobs: Sender<Thunk<'static>>,
    /// Receiving half of the job channel, shared with the workers. Kept here
    /// so `set_threads` can hand it to newly spawned workers.
    job_receiver: Arc<Mutex<Receiver<Thunk<'static>>>>,
    /// Counter the workers increment before running a job and decrement
    /// after it returns.
    active_count: Arc<Mutex<usize>>,
    /// Configured thread limit that workers compare the counter against.
    max_count: Arc<Mutex<usize>>,
}
impl ThreadPool {
    /// Creates a pool and immediately spawns `threads` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `threads` is zero.
    pub fn new(threads: usize) -> ThreadPool {
        assert!(threads >= 1);

        let (tx, rx) = channel::<Thunk<'static>>();
        let rx = Arc::new(Mutex::new(rx));
        let active_count = Arc::new(Mutex::new(0));
        let max_count = Arc::new(Mutex::new(threads));

        for _ in 0..threads {
            spawn_in_pool(rx.clone(), active_count.clone(), max_count.clone());
        }

        ThreadPool {
            jobs: tx,
            // The pool keeps the last handle itself; no extra clone needed.
            job_receiver: rx,
            active_count,
            max_count,
        }
    }

    /// Queues `job` to be run by one of the worker threads.
    ///
    /// # Panics
    ///
    /// Panics if sending on the job channel fails.
    pub fn execute<F>(&self, job: F)
        where F: FnOnce() + Send + 'static
    {
        // `F` already implements `FnBox` through the blanket impl, so it can
        // be boxed directly without wrapping it in another closure.
        self.jobs.send(Box::new(job)).unwrap();
    }

    /// Returns the current value of the shared running-job counter, which
    /// workers increment before running a job and decrement afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the counter's mutex is poisoned.
    pub fn active_count(&self) -> usize {
        *self.active_count.lock().unwrap()
    }

    /// Returns the currently configured thread limit.
    ///
    /// # Panics
    ///
    /// Panics if the limit's mutex is poisoned.
    pub fn max_count(&self) -> usize {
        *self.max_count.lock().unwrap()
    }

    /// Changes the thread limit to `threads`.
    ///
    /// When the limit grows, one new worker is spawned for each unit of
    /// increase. When it shrinks, no thread is stopped here: workers re-read
    /// the limit before each receive and leave their loop once the
    /// running-job counter is no longer below it.
    ///
    /// # Panics
    ///
    /// Panics if `threads` is zero or if the limit's mutex is poisoned.
    pub fn set_threads(&mut self, threads: usize) {
        assert!(threads >= 1);

        // Read the old limit and store the new one under a single lock
        // acquisition. Clones of the pool share `max_count`, so doing this in
        // two separate lock/unlock steps would let concurrent callers observe
        // the same old value and spawn too many workers. The guard is dropped
        // before any thread is spawned.
        let previous_max = {
            let mut max = self.max_count.lock().unwrap();
            let previous = *max;
            *max = threads;
            previous
        };

        // Zero iterations when the limit did not grow.
        for _ in 0..threads.saturating_sub(previous_max) {
            spawn_in_pool(self.job_receiver.clone(),
                          self.active_count.clone(),
                          self.max_count.clone());
        }
    }
}
/// Spawns one worker thread; the join handle is discarded.
///
/// The worker repeatedly:
/// 1. reads the running-job counter and the limit, and leaves the loop if the
///    counter is not below the limit;
/// 2. blocks on the shared receiver for the next job, leaving the loop if the
///    channel reports an error;
/// 3. increments the counter, runs the job, and decrements the counter.
///
/// A `Sentinel` is created before the loop and cancelled after it, so it only
/// fires if the thread unwinds out of the loop.
fn spawn_in_pool(jobs: Arc<Mutex<Receiver<Thunk<'static>>>>,
                 thread_counter: Arc<Mutex<usize>>,
                 thread_count_max: Arc<Mutex<usize>>) {
    thread::spawn(move || {
        let sentinel = Sentinel::new(&jobs, &thread_counter, &thread_count_max);

        loop {
            // Copy both values out; each guard is released at the end of its
            // own statement so neither mutex stays locked.
            let running = *thread_counter.lock().unwrap();
            let limit = *thread_count_max.lock().unwrap();
            if running >= limit {
                break;
            }

            // Bind the result in its own `let` so the receiver lock is
            // released as soon as a message arrives, not held while the job
            // runs.
            let message = jobs.lock().unwrap().recv();
            let job = match message {
                Ok(job) => job,
                Err(..) => break,
            };

            *thread_counter.lock().unwrap() += 1;
            job.call_box();
            *thread_counter.lock().unwrap() -= 1;
        }

        sentinel.cancel();
    });
}
#[cfg(test)]
mod test {
    use super::ThreadPool;
    use std::sync::mpsc::channel;
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    /// Number of workers and jobs used by most tests.
    const TEST_TASKS: usize = 4;

    /// Job body that never returns, keeping its worker busy for the rest of
    /// the test process.
    fn block_forever() {
        loop {
            thread::sleep(Duration::from_millis(10000));
        }
    }

    /// Gives the workers time to pick up the jobs queued so far.
    fn let_workers_settle() {
        thread::sleep(Duration::from_millis(1000));
    }

    /// Raising the limit lets more jobs run at once.
    #[test]
    fn test_set_threads_increasing() {
        let new_thread_amount = 6;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(block_forever);
        }
        pool.set_threads(new_thread_amount);
        for _ in 0..(new_thread_amount - TEST_TASKS) {
            pool.execute(block_forever);
        }
        let_workers_settle();
        assert_eq!(pool.active_count(), new_thread_amount);
    }

    /// After lowering the limit, the lowered number of blocking jobs is
    /// reported as active.
    #[test]
    fn test_set_threads_decreasing() {
        let new_thread_amount = 2;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            // A job that returns immediately.
            pool.execute(move || {});
        }
        pool.set_threads(new_thread_amount);
        for _ in 0..new_thread_amount {
            pool.execute(block_forever);
        }
        let_workers_settle();
        assert_eq!(pool.active_count(), new_thread_amount);
    }

    /// `active_count` and `max_count` both report `TEST_TASKS` once every
    /// worker is busy.
    #[test]
    fn test_active_count() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(block_forever);
        }
        let_workers_settle();
        let active_count = pool.active_count();
        assert_eq!(active_count, TEST_TASKS);
        let initialized_count = pool.max_count();
        assert_eq!(initialized_count, TEST_TASKS);
    }

    /// Every submitted job runs and reports back.
    #[test]
    fn test_works() {
        let pool = ThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    /// A pool with zero threads is rejected.
    #[test]
    #[should_panic]
    fn test_zero_tasks_panic() {
        ThreadPool::new(0);
    }

    /// Jobs submitted after a round of panicking jobs still run.
    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || -> () { panic!() });
        }
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    /// Jobs that panic only after the pool handle has been dropped must not
    /// bring down the test thread.
    #[test]
    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
        let pool = ThreadPool::new(TEST_TASKS);
        // One extra participant: the test thread releases the jobs itself.
        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                panic!();
            });
        }
        drop(pool);
        waiter.wait();
    }
}