use std::fmt;
use std::sync::mpsc::{channel, Sender, Receiver};
use std::sync::{Arc, Mutex, Condvar};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread::{Builder, panicking};
/// Object-safe stand-in for `FnOnce()`.
///
/// On older compilers a `Box<FnOnce()>` cannot be invoked directly,
/// because calling an `FnOnce` closure moves it out of the box.
/// `call_box` works around this by taking `self: Box<Self>` and moving
/// the closure out inside the method body.
trait FnBox {
    fn call_box(self: Box<Self>);
}

impl<F: FnOnce()> FnBox for F {
    fn call_box(self: Box<F>) {
        // Move the closure out of the box and invoke it exactly once.
        (*self)()
    }
}

/// A boxed job, as produced by `ThreadPool::execute` and consumed by the
/// worker threads.
type Thunk<'a> = Box<FnBox + Send + 'a>;
/// Guard held by each worker thread for the duration of its run loop.
///
/// If the worker unwinds while the guard is still armed (a job panicked),
/// the guard's `Drop` impl repairs the pool's counters and spawns a
/// replacement thread. A worker that exits its loop normally disarms the
/// guard by calling `cancel` first.
struct Sentinel<'a> {
    // Shared pool state, borrowed from the `Arc` owned by the worker closure.
    shared_data: &'a Arc<ThreadPoolSharedData>,
    // `true` while armed; set to `false` by `cancel` on clean shutdown.
    active: bool,
}
impl<'a> Sentinel<'a> {
    /// Build a sentinel that is armed from the moment of construction.
    fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
        Sentinel {
            active: true,
            shared_data: shared_data,
        }
    }

    /// Disarm the sentinel by consuming it; its `Drop` impl then does
    /// nothing, signalling that the worker shut down cleanly.
    fn cancel(mut self) {
        self.active = false;
    }
}
impl<'a> Drop for Sentinel<'a> {
    fn drop(&mut self) {
        if self.active {
            // The worker died without calling `cancel`: undo the "active"
            // slot it still occupies so `join` can observe an idle pool.
            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
            // Record the failure only when this drop runs during an unwind,
            // i.e. the job itself panicked.
            if panicking() {
                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
            }
            // Wake any `join`ers in case this was the last outstanding job.
            self.shared_data.no_work_notify_all();
            // Replace the dead worker so the pool keeps its thread count.
            spawn_in_pool(self.shared_data.clone())
        }
    }
}
/// State shared between every `ThreadPool` handle and all worker threads.
struct ThreadPoolSharedData {
    /// Optional name assigned to each spawned worker thread.
    name: Option<String>,
    /// Receiving half of the job queue; a worker takes this lock to `recv`
    /// one job at a time.
    job_receiver: Mutex<Receiver<Thunk<'static>>>,
    /// Mutex paired with `empty_condvar`; `ThreadPool::join` waits on it.
    empty_trigger: Mutex<()>,
    /// Signalled by workers when the pool may have become idle.
    empty_condvar: Condvar,
    /// Number of jobs sent but not yet picked up by a worker.
    queued_count: AtomicUsize,
    /// Number of jobs currently executing.
    active_count: AtomicUsize,
    /// Target number of worker threads (adjustable via `set_num_threads`).
    max_thread_count: AtomicUsize,
    /// Number of jobs that ended by panicking.
    panic_count: AtomicUsize,
}
impl ThreadPoolSharedData {
    /// `true` while at least one job is queued or running.
    fn has_work(&self) -> bool {
        self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
    }

    /// Wake every thread blocked in `ThreadPool::join`, but only if the
    /// pool is currently idle.
    fn no_work_notify_all(&self) {
        if !self.has_work() {
            // Acquire and immediately drop `empty_trigger`. This is not a
            // no-op: `join` re-checks `has_work` while holding this mutex
            // before waiting, so taking the lock here orders this thread
            // after that check — a joiner cannot slip between its check and
            // its `wait` without this thread blocking until it does wait.
            *self.empty_trigger.lock().unwrap();
            self.empty_condvar.notify_all();
        }
    }
}
/// A cloneable handle to a pool of worker threads.
///
/// Cloning produces another handle to the *same* pool: clones share the
/// job queue and all counters.
#[derive(Clone)]
pub struct ThreadPool {
    /// Sending half of the job queue consumed by the workers.
    jobs: Sender<Thunk<'static>>,
    /// Counters and synchronisation state shared with the workers.
    shared_data: Arc<ThreadPoolSharedData>,
}
impl ThreadPool {
    /// Create a pool with `num_threads` unnamed worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero.
    pub fn new(num_threads: usize) -> ThreadPool {
        ThreadPool::new_pool(None, num_threads)
    }

    /// Create a pool whose worker threads are all named `name`.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero.
    pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
        ThreadPool::new_pool(Some(name), num_threads)
    }

    /// Deprecated alias for [`ThreadPool::with_name`].
    #[inline(always)]
    #[deprecated]
    pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
        ThreadPool::with_name(name, num_threads)
    }

    /// Shared constructor: builds the job channel and shared state, then
    /// spawns the initial set of workers.
    #[inline]
    fn new_pool(name: Option<String>, num_threads: usize) -> ThreadPool {
        assert!(num_threads >= 1);
        let (tx, rx) = channel::<Thunk<'static>>();
        let shared_data = Arc::new(ThreadPoolSharedData {
            name: name,
            job_receiver: Mutex::new(rx),
            empty_condvar: Condvar::new(),
            empty_trigger: Mutex::new(()),
            queued_count: AtomicUsize::new(0),
            active_count: AtomicUsize::new(0),
            max_thread_count: AtomicUsize::new(num_threads),
            panic_count: AtomicUsize::new(0),
        });
        for _ in 0..num_threads {
            spawn_in_pool(shared_data.clone());
        }
        ThreadPool {
            jobs: tx,
            shared_data: shared_data,
        }
    }

    /// Queue `job` for execution on some worker thread.
    ///
    /// # Panics
    ///
    /// Panics if the job cannot be sent, i.e. every worker has exited and
    /// the receiving half of the queue is gone.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Count the job as queued *before* sending it, so `has_work` never
        // under-reports while the job is in flight on the channel.
        self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
        self.jobs
            .send(Box::new(job))
            .expect("ThreadPool::execute unable to send job into queue.");
    }

    /// Number of jobs waiting in the queue.
    pub fn queued_count(&self) -> usize {
        self.shared_data.queued_count.load(Ordering::Relaxed)
    }

    /// Number of jobs currently being executed.
    pub fn active_count(&self) -> usize {
        self.shared_data.active_count.load(Ordering::SeqCst)
    }

    /// Current target number of worker threads.
    pub fn max_count(&self) -> usize {
        self.shared_data.max_thread_count.load(Ordering::Relaxed)
    }

    /// Number of jobs that have panicked so far.
    pub fn panic_count(&self) -> usize {
        self.shared_data.panic_count.load(Ordering::Relaxed)
    }

    /// Deprecated alias for [`ThreadPool::set_num_threads`].
    #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
    pub fn set_threads(&mut self, num_threads: usize) {
        self.set_num_threads(num_threads)
    }

    /// Change the target number of worker threads.
    ///
    /// Growing spawns the additional workers immediately; shrinking takes
    /// effect lazily as workers observe the lower limit between jobs and
    /// exit (see the check at the top of the worker loop in
    /// `spawn_in_pool`).
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero.
    pub fn set_num_threads(&mut self, num_threads: usize) {
        assert!(num_threads >= 1);
        let prev_num_threads = self.shared_data.max_thread_count.swap(
            num_threads,
            Ordering::Release,
        );
        // `checked_sub` yields `None` when shrinking, so no threads are
        // spawned in that case.
        if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
            for _ in 0..num_spawn {
                spawn_in_pool(self.shared_data.clone());
            }
        }
    }

    /// Block the calling thread until there are no queued or running jobs.
    ///
    /// Other handles may keep submitting work while this call blocks; it
    /// returns at a moment when the pool was observed idle.
    pub fn join(&self) {
        // Outer check skips taking the lock entirely when already idle.
        while self.shared_data.has_work() {
            let mut lock = self.shared_data.empty_trigger.lock().unwrap();
            while self.shared_data.has_work() {
                // `wait` atomically releases the mutex; workers acquire and
                // release it in `no_work_notify_all` before notifying, so
                // this re-check-under-lock cannot miss a wake-up.
                lock = self.shared_data.empty_condvar.wait(lock).unwrap();
            }
        }
    }
}
impl fmt::Debug for ThreadPool {
    /// Render the pool's name and counters, e.g.
    /// `ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ThreadPool");
        out.field("name", &self.shared_data.name);
        out.field("queued_count", &self.queued_count());
        out.field("active_count", &self.active_count());
        out.field("max_count", &self.max_count());
        out.finish()
    }
}
/// Spawn one worker thread that serves jobs from `shared_data` until the
/// pool shrinks below this worker's slot or the job channel is closed.
fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
    let mut builder = Builder::new();
    // Propagate the pool's name (if any) to the OS thread.
    if let Some(ref name) = shared_data.name {
        builder = builder.name(name.clone());
    }
    builder
        .spawn(move || {
            // Armed guard: if a job panics below, the sentinel's Drop impl
            // fixes the counters and spawns a replacement worker.
            let sentinel = Sentinel::new(&shared_data);
            loop {
                // Re-read the target size each iteration so that
                // `set_num_threads` can shrink the pool: surplus workers
                // exit here before taking another job.
                let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
                let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
                if thread_counter_val >= max_thread_count_val {
                    break;
                }
                // Hold the receiver lock only while fetching a single job;
                // the scope drops the guard before the job runs.
                let message = {
                    let lock = shared_data
                        .job_receiver
                        .lock()
                        .expect("Worker thread unable to lock job_receiver");
                    lock.recv()
                };
                let job = match message {
                    Ok(job) => job,
                    // All senders dropped: the pool handle is gone, so this
                    // worker shuts down.
                    Err(..) => break,
                };
                // Flip the job from "queued" to "active" before running it,
                // keeping `has_work` true throughout.
                shared_data.active_count.fetch_add(1, Ordering::SeqCst);
                shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
                job.call_box();
                shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
                // This may have been the last job: wake any `join`ers.
                shared_data.no_work_notify_all();
            }
            // Clean exit: prevent the sentinel's Drop from respawning us.
            sentinel.cancel();
        })
        .unwrap();
}
#[cfg(test)]
mod test {
    use super::ThreadPool;
    use std::sync::{Arc, Barrier};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::{sync_channel, channel};
    use std::thread::{self, sleep};
    use std::time::Duration;

    /// Worker count used by most of the tests below.
    const TEST_TASKS: usize = 4;

    // Growing the pool mid-flight should raise the number of concurrently
    // active jobs to the new limit.
    #[test]
    fn test_set_num_threads_increasing() {
        let new_thread_amount = TEST_TASKS + 8;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || sleep(Duration::from_secs(23)));
        }
        // NOTE: timing-based; assumes 1s is enough for workers to pick up jobs.
        sleep(Duration::from_secs(1));
        assert_eq!(pool.active_count(), TEST_TASKS);
        pool.set_num_threads(new_thread_amount);
        for _ in 0..(new_thread_amount - TEST_TASKS) {
            pool.execute(move || sleep(Duration::from_secs(23)));
        }
        sleep(Duration::from_secs(1));
        assert_eq!(pool.active_count(), new_thread_amount);
        pool.join();
    }

    // Shrinking the pool should cap concurrency at the new, lower limit.
    #[test]
    fn test_set_num_threads_decreasing() {
        let new_thread_amount = 2;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || { 1 + 1; });
        }
        pool.set_num_threads(new_thread_amount);
        for _ in 0..new_thread_amount {
            pool.execute(move || sleep(Duration::from_secs(23)));
        }
        sleep(Duration::from_secs(1));
        assert_eq!(pool.active_count(), new_thread_amount);
        pool.join();
    }

    // With more jobs than workers, active_count tops out at the pool size.
    #[test]
    fn test_active_count() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..2 * TEST_TASKS {
            pool.execute(move || loop {
                sleep(Duration::from_secs(10))
            });
        }
        sleep(Duration::from_secs(1));
        let active_count = pool.active_count();
        assert_eq!(active_count, TEST_TASKS);
        let initialized_count = pool.max_count();
        assert_eq!(initialized_count, TEST_TASKS);
    }

    // Smoke test: every submitted job runs exactly once.
    #[test]
    fn test_works() {
        let pool = ThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || { tx.send(1).unwrap(); });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    // A zero-sized pool violates the constructor's assert.
    #[test]
    #[should_panic]
    fn test_zero_tasks_panic() {
        ThreadPool::new(0);
    }

    // Panicking jobs are counted, and the pool keeps working afterwards
    // (the Sentinel respawns the dead workers).
    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || panic!("Ignore this panic, it must!"));
        }
        pool.join();
        assert_eq!(pool.panic_count(), TEST_TASKS);
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || { tx.send(1).unwrap(); });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    // Dropping the pool while jobs are about to panic must not panic the
    // dropping thread.
    #[test]
    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
        let pool = ThreadPool::new(TEST_TASKS);
        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                panic!("Ignore this panic, it should!");
            });
        }
        drop(pool);
        waiter.wait();
    }

    // Stress test: millions of jobs, all delivered; pool idle afterwards.
    #[test]
    fn test_massive_task_creation() {
        let test_tasks = 4_200_000;
        let pool = ThreadPool::new(TEST_TASKS);
        let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let (tx, rx) = channel();
        for i in 0..test_tasks {
            let tx = tx.clone();
            let (b0, b1) = (b0.clone(), b1.clone());
            pool.execute(move || {
                // Only the first TEST_TASKS jobs rendezvous on the barriers.
                if i < TEST_TASKS {
                    b0.wait();
                    b1.wait();
                }
                tx.send(1).is_ok();
            });
        }
        b0.wait();
        assert_eq!(pool.active_count(), TEST_TASKS);
        b1.wait();
        assert_eq!(rx.iter().take(test_tasks).fold(0, |a, b| a + b), test_tasks);
        pool.join();
        let atomic_active_count = pool.active_count();
        assert!(
            atomic_active_count == 0,
            "atomic_active_count: {}",
            atomic_active_count
        );
    }

    // Shrinking only takes effect after the in-flight jobs complete.
    #[test]
    fn test_shrink() {
        let test_tasks_begin = TEST_TASKS + 2;
        let mut pool = ThreadPool::new(test_tasks_begin);
        let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
        let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
        for _ in 0..test_tasks_begin {
            let (b0, b1) = (b0.clone(), b1.clone());
            pool.execute(move || {
                b0.wait();
                b1.wait();
            });
        }
        let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
        for _ in 0..TEST_TASKS {
            let (b2, b3) = (b2.clone(), b3.clone());
            pool.execute(move || {
                b2.wait();
                b3.wait();
            });
        }
        b0.wait();
        pool.set_num_threads(TEST_TASKS);
        assert_eq!(pool.active_count(), test_tasks_begin);
        b1.wait();
        b2.wait();
        assert_eq!(pool.active_count(), TEST_TASKS);
        b3.wait();
    }

    // Worker threads carry the pool's name, including respawns after panic.
    #[test]
    fn test_name() {
        let name = "test";
        let mut pool = ThreadPool::with_name(name.to_owned(), 2);
        let (tx, rx) = sync_channel(0);
        // Initial workers are named.
        for _ in 0..2 {
            let tx = tx.clone();
            pool.execute(move || {
                let name = thread::current().name().unwrap().to_owned();
                tx.send(name).unwrap();
            });
        }
        // A worker spawned by set_num_threads is named.
        pool.set_num_threads(3);
        let tx_clone = tx.clone();
        pool.execute(move || {
            let name = thread::current().name().unwrap().to_owned();
            tx_clone.send(name).unwrap();
            panic!();
        });
        // A worker respawned by the Sentinel after the panic is named too.
        pool.execute(move || {
            let name = thread::current().name().unwrap().to_owned();
            tx.send(name).unwrap();
        });
        for thread_name in rx.iter().take(4) {
            assert_eq!(name, thread_name);
        }
    }

    // Exact Debug output for idle, named and busy pools.
    #[test]
    fn test_debug() {
        let pool = ThreadPool::new(4);
        let debug = format!("{:?}", pool);
        assert_eq!(
            debug,
            "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
        );
        let pool = ThreadPool::with_name("hello".into(), 4);
        let debug = format!("{:?}", pool);
        assert_eq!(
            debug,
            "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
        );
        let pool = ThreadPool::new(4);
        pool.execute(move || sleep(Duration::from_secs(5)));
        sleep(Duration::from_secs(1));
        let debug = format!("{:?}", pool);
        assert_eq!(
            debug,
            "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
        );
    }

    // join() can be called more than once on the same pool.
    #[test]
    fn test_repeate_join() {
        let pool = ThreadPool::with_name("repeate join test".into(), 8);
        let test_count = Arc::new(AtomicUsize::new(0));
        for _ in 0..42 {
            let test_count = test_count.clone();
            pool.execute(move || {
                sleep(Duration::from_secs(2));
                test_count.fetch_add(1, Ordering::Release);
            });
        }
        println!("{:?}", pool);
        pool.join();
        assert_eq!(42, test_count.load(Ordering::Acquire));
        for _ in 0..42 {
            let test_count = test_count.clone();
            pool.execute(move || {
                sleep(Duration::from_secs(2));
                test_count.fetch_add(1, Ordering::Relaxed);
            });
        }
        pool.join();
        assert_eq!(84, test_count.load(Ordering::Relaxed));
    }

    // Cross-pool joins: jobs in pool1 join pool0 while pool0 is also joined
    // from the test thread; everything must still complete.
    #[test]
    fn test_multi_join() {
        use std::sync::mpsc::TryRecvError::*;
        // Debug hook; intentionally discards its input.
        fn error(_s: String) {
        }
        let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
        let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
        let (tx, rx) = channel();
        for i in 0..8 {
            let pool1 = pool1.clone();
            let pool0_ = pool0.clone();
            let tx = tx.clone();
            pool0.execute(move || {
                pool1.execute(move || {
                    error(format!("p1: {} -=- {:?}\n", i, pool0_));
                    pool0_.join();
                    error(format!("p1: send({})\n", i));
                    tx.send(i).expect("send i from pool1 -> main");
                });
                error(format!("p0: {}\n", i));
            });
        }
        drop(tx);
        // Nothing can have been sent yet: senders are still blocked on joins.
        assert_eq!(rx.try_recv(), Err(Empty));
        error(format!("{:?}\n{:?}\n", pool0, pool1));
        pool0.join();
        error(format!("pool0.join() complete =-= {:?}", pool1));
        pool1.join();
        error("pool1.join() complete\n".into());
        assert_eq!(
            rx.iter().fold(0, |acc, i| acc + i),
            0 + 1 + 2 + 3 + 4 + 5 + 6 + 7
        );
    }

    // Joining a pool that never received work returns immediately.
    #[test]
    fn test_empty_pool() {
        let pool = ThreadPool::new(4);
        pool.join();
        assert!(true);
    }

    // join() terminates even while another thread keeps submitting work.
    #[test]
    fn test_no_fun_or_joy() {
        fn sleepy_function() {
            sleep(Duration::from_secs(6));
        }
        let pool = ThreadPool::with_name("no fun or joy".into(), 8);
        pool.execute(sleepy_function);
        let p_t = pool.clone();
        thread::spawn(move || {
            (0..23).map(|_| p_t.execute(sleepy_function)).count();
        });
        pool.join();
    }
}