#![feature(unsafe_destructor)]
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{channel, Sender, Receiver};
use std::thread;
/// Adapter that lets a boxed one-shot callable be invoked through a trait
/// object: the receiver is the `Box` itself, so the callable can be consumed.
trait FnBox<Arg, Ret> {
    /// Consumes the boxed callable and invokes it exactly once with `arg`.
    fn call_box(self: Box<Self>, arg: Arg) -> Ret;
}

/// Every single-argument `FnOnce` is automatically an `FnBox`.
impl<Arg, Ret, Func> FnBox<Arg, Ret> for Func
    where Func: FnOnce(Arg) -> Ret
{
    fn call_box(self: Box<Self>, arg: Arg) -> Ret {
        // Move the callable out of its box so it can be called by value.
        let callee = *self;
        callee(arg)
    }
}
/// A boxed, sendable, one-shot job taking `()` and returning `()`; this is the
/// message type carried by the pool's job channel.
type Thunk = Box<FnBox<(), ()> + Send + 'static>;
/// Guard created at the start of each worker thread.
///
/// If it is dropped while still active (i.e. `cancel` was never called, as
/// happens when a job panics and the worker unwinds), its destructor spawns a
/// replacement worker on the same job queue.
struct Sentinel<'a> {
    /// Shared job queue that a replacement worker would be attached to.
    jobs: &'a Arc<Mutex<Receiver<Thunk>>>,
    /// `true` until `cancel` is called; checked by the destructor.
    active: bool
}

impl<'a> Sentinel<'a> {
    /// Creates an armed sentinel watching over `jobs`.
    fn new(jobs: &'a Arc<Mutex<Receiver<Thunk>>>) -> Sentinel<'a> {
        Sentinel { jobs: jobs, active: true }
    }

    /// Disarms and consumes the sentinel so that no replacement worker is
    /// spawned when it goes out of scope.
    fn cancel(self) {
        let mut disarmed = self;
        disarmed.active = false;
        // `disarmed` is dropped here; the destructor now sees `active == false`.
    }
}

// NOTE(review): this attribute relies on the crate-level `unsafe_destructor`
// feature gate; confirm the targeted toolchain still accepts it.
#[unsafe_destructor]
impl<'a> Drop for Sentinel<'a> {
    /// Spawns a replacement worker unless the sentinel was cancelled.
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        spawn_in_pool(self.jobs.clone())
    }
}
/// A fixed-size pool of worker threads that run submitted jobs.
pub struct ThreadPool {
    /// Sending half of the job channel; workers share the receiving half.
    jobs: Sender<Thunk>
}

impl ThreadPool {
    /// Creates a pool and immediately spawns `threads` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `threads` is zero.
    pub fn new(threads: usize) -> ThreadPool {
        assert!(threads >= 1);

        let (sender, receiver) = channel::<Thunk>();
        // All workers pull from one receiver, so it is shared behind a mutex.
        let queue = Arc::new(Mutex::new(receiver));

        for _ in 0..threads {
            let worker_queue = queue.clone();
            spawn_in_pool(worker_queue);
        }

        ThreadPool { jobs: sender }
    }

    /// Queues `job` to be run by one of the pool's worker threads.
    ///
    /// # Panics
    ///
    /// Panics if the job cannot be sent over the pool's channel.
    pub fn execute<F>(&self, job: F)
        where F : FnOnce() + Send + 'static
    {
        // Adapt the zero-argument closure to the unit-argument `Thunk` shape.
        let thunk: Thunk = Box::new(move |()| job());
        self.jobs.send(thunk).unwrap();
    }
}
/// Spawns one worker thread that pulls jobs from `jobs` and runs them until
/// receiving fails, which ends the loop.
///
/// A `Sentinel` is armed for the lifetime of the worker and only cancelled
/// when the loop ends normally, so leaving the thread any other way triggers
/// the sentinel's destructor.
fn spawn_in_pool(jobs: Arc<Mutex<Receiver<Thunk>>>) {
    thread::spawn(move || {
        let watchdog = Sentinel::new(&jobs);

        loop {
            // The lock guard is a temporary of this statement only, so the
            // mutex is released before the job below starts running.
            let received = jobs.lock().unwrap().recv();

            let job = match received {
                Ok(job) => job,
                Err(..) => break
            };
            job.call_box(());
        }

        watchdog.cancel();
    });
}
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::mpsc::channel;
    use std::sync::{Arc, Barrier};

    /// Number of worker threads and jobs used by every test.
    const TEST_TASKS: usize = 4;

    /// Every submitted job runs: each sends `1`, and the sum must equal the
    /// number of jobs.
    #[test]
    fn test_works() {
        let pool = ThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move|| {
                tx.send(1).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    /// Constructing a pool with zero threads must panic.
    // `should_fail` was renamed to `should_panic`; the old name is no longer
    // a recognised attribute.
    #[test]
    #[should_panic]
    fn test_zero_tasks_panic() {
        ThreadPool::new(0);
    }

    /// After as many panicking jobs as there are workers, the pool must still
    /// run a full batch of well-behaved jobs.
    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move|| -> () { panic!() });
        }
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move|| {
                tx.send(1).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    /// Jobs that panic after the pool handle has been dropped must not bring
    /// down the thread that dropped it.
    #[test]
    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
        let pool = ThreadPool::new(TEST_TASKS);
        // One extra party so this thread can release the jobs after the drop.
        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move|| {
                waiter.wait();
                panic!();
            });
        }
        drop(pool);
        waiter.wait();
    }
}