use std::sync::mpsc::{channel, Sender, Receiver};
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
/// Object-safe stand-in for `FnOnce()` that can be invoked through a `Box`.
///
/// Before Rust 1.35 a `Box<dyn FnOnce()>` could not be called directly, so
/// every job is boxed as a `dyn FnBox` and consumed through
/// [`FnBox::call_box`] instead.
trait FnBox {
    /// Consumes the boxed closure and runs it exactly once.
    fn call_box(self: Box<Self>);
}

impl<F: FnOnce()> FnBox for F {
    fn call_box(self: Box<F>) {
        // Move the closure out of its box, then call it by value.
        (*self)()
    }
}

/// A boxed job that can be sent to a worker thread.
///
/// `'a` bounds whatever the closure borrows; the pool itself only ever uses
/// `Thunk<'static>`. `dyn` is spelled out because bare trait objects are
/// deprecated (and rejected outright on the 2021 edition).
type Thunk<'a> = Box<dyn FnBox + Send + 'a>;
/// Guard that lives on the stack of every pool worker thread.
///
/// While armed, dropping it (which happens if the worker thread unwinds
/// before reaching `cancel`) gives back the busy-counter slot the interrupted
/// job had taken and spawns a replacement worker, so the pool keeps its size.
/// A worker that finishes normally disarms the guard with [`Sentinel::cancel`].
struct Sentinel<'a> {
    jobs: &'a Arc<Mutex<Receiver<Thunk<'static>>>>,
    thread_counter: &'a Arc<AtomicUsize>,
    thread_count_max: &'a Arc<AtomicUsize>,
    // `true` until `cancel` runs; only an armed guard respawns a worker.
    active: bool,
}

impl<'a> Sentinel<'a> {
    /// Builds an armed guard that borrows the worker's shared pool state.
    fn new(
        jobs: &'a Arc<Mutex<Receiver<Thunk<'static>>>>,
        thread_counter: &'a Arc<AtomicUsize>,
        thread_count_max: &'a Arc<AtomicUsize>,
    ) -> Sentinel<'a> {
        Sentinel {
            jobs,
            thread_counter,
            thread_count_max,
            active: true,
        }
    }

    /// Disarms the guard and drops it without spawning a replacement.
    fn cancel(self) {
        let mut disarmed = self;
        disarmed.active = false;
        // `disarmed` goes out of scope here; `Drop` sees `active == false`.
    }
}

impl<'a> Drop for Sentinel<'a> {
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        // Still armed means the worker is unwinding. In practice that is a
        // panicking job, which skipped its own `fetch_sub`, so undo its
        // `fetch_add` before starting a replacement worker.
        self.thread_counter.fetch_sub(1, Ordering::SeqCst);
        spawn_in_pool(
            Arc::clone(self.jobs),
            Arc::clone(self.thread_counter),
            Arc::clone(self.thread_count_max),
        )
    }
}
/// A pool of worker threads that run boxed jobs pulled from a shared channel.
///
/// Cloning a `ThreadPool` is cheap. Clones share the same job queue and
/// counters, so jobs submitted through any clone run on the same workers.
/// `set_threads` on one clone changes the limit that every clone observes.
#[derive(Clone)]
pub struct ThreadPool {
    // Sending half of the job queue.
    jobs: Sender<Thunk<'static>>,
    // Receiving half, shared by all workers behind a mutex. The pool keeps its
    // own handle so `set_threads` can give it to newly spawned workers, and so
    // the queue always has a live receiver while the pool exists.
    job_receiver: Arc<Mutex<Receiver<Thunk<'static>>>>,
    // Number of workers currently executing a job.
    active_count: Arc<AtomicUsize>,
    // Current limit on busy workers, as set by `new` / `set_threads`.
    max_count: Arc<AtomicUsize>,
}

impl ThreadPool {
    /// Creates a pool and immediately spawns `threads` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `threads` is 0.
    pub fn new(threads: usize) -> ThreadPool {
        assert!(threads >= 1);
        let (tx, rx) = channel::<Thunk<'static>>();
        let job_receiver = Arc::new(Mutex::new(rx));
        let active_count = Arc::new(AtomicUsize::new(0));
        let max_count = Arc::new(AtomicUsize::new(threads));
        for _ in 0..threads {
            spawn_in_pool(job_receiver.clone(), active_count.clone(), max_count.clone());
        }
        // The pool keeps the original handles; the workers got clones above.
        ThreadPool {
            jobs: tx,
            job_receiver,
            active_count,
            max_count,
        }
    }

    /// Queues `job` to run on one of the pool's worker threads.
    ///
    /// Returns immediately; the job runs once an idle worker receives it. A
    /// job that panics does not permanently cost the pool a worker, because
    /// the worker's sentinel spawns a replacement.
    ///
    /// # Panics
    ///
    /// Panics only if the job queue has no receiver. That cannot happen while
    /// this value exists, because the pool holds its own receiver handle.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // `F: FnOnce() + Send + 'static` already implements `FnBox`, so it can
        // be boxed directly without an extra forwarding closure.
        self.jobs
            .send(Box::new(job))
            .expect("thread pool job queue has no receiver");
    }

    /// Returns the number of worker threads currently executing a job.
    pub fn active_count(&self) -> usize {
        self.active_count.load(Ordering::Relaxed)
    }

    /// Returns the current thread limit, as last set by [`ThreadPool::new`]
    /// or [`ThreadPool::set_threads`].
    pub fn max_count(&self) -> usize {
        self.max_count.load(Ordering::Relaxed)
    }

    /// Changes the pool's thread limit to `threads`.
    ///
    /// Raising the limit spawns the additional workers immediately. Lowering
    /// it only stores the new limit. Running workers are not interrupted; a
    /// worker retires only when, between jobs, it sees the busy count at or
    /// above the new limit.
    ///
    /// NOTE(review): because shrinking is lazy, idle surplus workers can
    /// outlive a shrink. A later grow still spawns `threads - old_limit` new
    /// workers, so the number of live threads may exceed `max_count()`.
    /// Tracking live workers separately from busy ones would fix this.
    ///
    /// # Panics
    ///
    /// Panics if `threads` is 0.
    pub fn set_threads(&mut self, threads: usize) {
        assert!(threads >= 1);
        let previous = self.max_count.swap(threads, Ordering::Release);
        if threads > previous {
            for _ in 0..(threads - previous) {
                spawn_in_pool(
                    self.job_receiver.clone(),
                    self.active_count.clone(),
                    self.max_count.clone(),
                );
            }
        }
    }
}
/// Starts one worker thread that serves jobs from the shared queue.
///
/// The worker keeps taking jobs while the busy count (`thread_counter`) is
/// below the limit (`thread_count_max`). It exits quietly once that check
/// fails or the queue is disconnected. If a job panics, the worker's
/// `Sentinel` is still armed when the thread unwinds. Its `Drop` impl then
/// repairs the busy count and spawns a replacement.
fn spawn_in_pool(jobs: Arc<Mutex<Receiver<Thunk<'static>>>>,
                 thread_counter: Arc<AtomicUsize>,
                 thread_count_max: Arc<AtomicUsize>) {
    thread::spawn(move || {
        let sentinel = Sentinel::new(&jobs, &thread_counter, &thread_count_max);

        // Checked before every receive. A worker that finds the busy count at
        // or above the limit retires, which is how the pool shrinks after the
        // limit is lowered.
        while thread_counter.load(Ordering::Acquire)
            < thread_count_max.load(Ordering::Relaxed)
        {
            // The mutex guard is a temporary of this statement, so the queue
            // is unlocked as soon as `recv` returns, before the job runs.
            let received = jobs.lock().unwrap().recv();
            let job = match received {
                Ok(job) => job,
                // Every sender has been dropped; no more jobs can arrive.
                Err(..) => break,
            };
            thread_counter.fetch_add(1, Ordering::SeqCst);
            job.call_box();
            thread_counter.fetch_sub(1, Ordering::SeqCst);
        }

        // Normal exit: disarm the sentinel so no replacement is spawned.
        sentinel.cancel();
    });
}
#[cfg(test)]
mod test {
    // Several tests submit jobs that loop forever on purpose to keep workers
    // busy; those threads are simply leaked when the test finishes. Tests that
    // sleep before asserting depend on timing to let workers pick up jobs.
    use super::ThreadPool;
    use std::sync::mpsc::channel;
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    const TEST_TASKS: usize = 4;

    /// Blocks the current thread for `ms` milliseconds.
    ///
    /// Local replacement for the deprecated `std::thread::sleep_ms`, so this
    /// module no longer needs `#![allow(deprecated)]`.
    fn sleep_ms(ms: u64) {
        thread::sleep(Duration::from_millis(ms));
    }

    #[test]
    fn test_set_threads_increasing() {
        let new_thread_amount = TEST_TASKS + 8;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || {
                loop {
                    sleep_ms(10000)
                }
            });
        }
        pool.set_threads(new_thread_amount);
        for _ in 0..(new_thread_amount - TEST_TASKS) {
            pool.execute(move || {
                loop {
                    sleep_ms(10000)
                }
            });
        }
        sleep_ms(1024);
        assert_eq!(pool.active_count(), new_thread_amount);
    }

    #[test]
    fn test_set_threads_decreasing() {
        let new_thread_amount = 2;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || {
                // Trivial job; the result is intentionally discarded.
                let _ = 1 + 1;
            });
        }
        pool.set_threads(new_thread_amount);
        for _ in 0..new_thread_amount {
            pool.execute(move || {
                loop {
                    sleep_ms(10000)
                }
            });
        }
        sleep_ms(1024);
        assert_eq!(pool.active_count(), new_thread_amount);
    }

    #[test]
    fn test_active_count() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || {
                loop {
                    sleep_ms(10000);
                }
            });
        }
        sleep_ms(1024);
        let active_count = pool.active_count();
        assert_eq!(active_count, TEST_TASKS);
        let initialized_count = pool.max_count();
        assert_eq!(initialized_count, TEST_TASKS);
    }

    #[test]
    fn test_works() {
        let pool = ThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    #[test]
    #[should_panic]
    fn test_zero_tasks_panic() {
        ThreadPool::new(0);
    }

    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || -> () { panic!() });
        }
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    #[test]
    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
        let pool = ThreadPool::new(TEST_TASKS);
        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                panic!("Ignore this panic, it should!");
            });
        }
        drop(pool);
        waiter.wait();
    }

    #[test]
    fn test_massive_task_creation() {
        let test_tasks = 4_200_000;
        let pool = ThreadPool::new(TEST_TASKS);
        let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let (tx, rx) = channel();
        for i in 0..test_tasks {
            let tx = tx.clone();
            let (b0, b1) = (b0.clone(), b1.clone());
            pool.execute(move || {
                if i < TEST_TASKS {
                    b0.wait();
                    b1.wait();
                }
                // Send failures are deliberately ignored.
                let _ = tx.send(1);
            });
        }
        b0.wait();
        assert_eq!(pool.active_count(), TEST_TASKS);
        b1.wait();
        assert_eq!(rx.iter().take(test_tasks).fold(0, |a, b| a + b), test_tasks);
        assert!(pool.active_count() <= 1);
    }

    #[test]
    fn test_shrink() {
        let test_tasks_begin = TEST_TASKS + 2;
        let mut pool = ThreadPool::new(test_tasks_begin);
        let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
        let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
        for _ in 0..test_tasks_begin {
            let (b0, b1) = (b0.clone(), b1.clone());
            pool.execute(move || {
                b0.wait();
                b1.wait();
            });
        }
        let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
        for _ in 0..TEST_TASKS {
            let (b2, b3) = (b2.clone(), b3.clone());
            pool.execute(move || {
                b2.wait();
                b3.wait();
            });
        }
        b0.wait();
        pool.set_threads(TEST_TASKS);
        assert_eq!(pool.active_count(), test_tasks_begin);
        b1.wait();
        b2.wait();
        assert_eq!(pool.active_count(), TEST_TASKS);
        b3.wait();
    }
}