#![warn(missing_debug_implementations, missing_docs)]
use std::borrow::Cow;
use std::collections::VecDeque;
use std::mem;
use std::panic;
use std::sync::Arc;
use std::sync::{Condvar, Mutex};
use std::thread;
use std::time::Duration;
/// A pool of worker threads that run submitted jobs.
///
/// Worker threads are started lazily, up to a configurable maximum, when work
/// is submitted, and exit after staying idle for the configured timeout.
/// Cloning a `ThreadPool` is cheap and yields another handle to the same pool.
#[derive(Debug, Clone)]
pub struct ThreadPool {
    inner: Arc<Inner>,
}
/// State shared by every `ThreadPool` handle and every worker thread.
#[derive(Debug)]
struct Inner {
    /// Mutable bookkeeping; every `Locked` field is accessed under this mutex.
    locked: Mutex<Locked>,
    /// Idle workers wait on this; it is signalled when work is queued or a
    /// prune is requested.
    thread_condvar: Condvar,
    /// Signalled when the last busy worker finishes and the queue is empty.
    all_complete: Condvar,
    /// Name given to spawned worker threads (not applied under Miri).
    thread_name: Cow<'static, str>,
    /// Stack size for worker threads, when overridden by the builder.
    thread_stack_size: Option<usize>,
    /// How long an idle worker waits for work before exiting.
    idle_timeout: Duration,
}
/// Pool bookkeeping; only ever accessed through `Inner::locked`.
#[derive(Debug)]
struct Locked {
    /// Queued jobs, pushed at the back and taken from the front (FIFO).
    work: VecDeque<RawFunction>,
    /// Outstanding requests for workers to exit (set by `ThreadPool::prune`
    /// and the test-only `miri_shutdown`).
    to_prune: usize,
    /// Workers currently waiting on `Inner::thread_condvar`.
    sleeping_threads: usize,
    /// Workers currently running a job.
    workers: usize,
    /// How many more worker threads may be started before the configured
    /// maximum is reached.
    spawnable: usize,
    /// Handles of every started worker, kept so Miri tests can join them.
    #[cfg(miri)]
    join_handles: Vec<std::thread::JoinHandle<()>>,
}
impl Inner {
    /// Body of every worker thread.
    ///
    /// Repeatedly takes the job at the front of the queue and runs it with the
    /// lock released. When the queue is empty the worker exits if a prune
    /// request is pending or if its previous idle wait timed out; otherwise it
    /// sleeps on `thread_condvar` for up to `idle_timeout` (indefinitely under
    /// Miri). On exit it gives its slot back to `spawnable`.
    fn thread_loop(&self) {
        let mut locked = self.locked.lock().unwrap();
        // Set when the most recent idle wait ended because `idle_timeout`
        // elapsed. The worker acts on it only after re-checking the queue.
        // Between the timeout firing and the lock being reacquired, this
        // worker is still counted in `sleeping_threads`. So `spawn_raw` may
        // push a job and notify instead of spawning a thread. Exiting then
        // would leave that job with no worker to run it.
        let mut idle_expired = false;
        loop {
            if let Some(f) = locked.work.pop_front() {
                idle_expired = false;
                locked.workers += 1;
                drop(locked);
                // SAFETY: `f` was queued by `ThreadPool::spawn_raw`, whose
                // caller guarantees `f.run` may be called with `f.data` on a
                // worker thread. Popping it from the queue ensures it runs once.
                unsafe { (f.run)(f.data) };
                locked = self.locked.lock().unwrap();
                locked.workers -= 1;
                if locked.workers == 0 && locked.work.is_empty() {
                    self.all_complete.notify_all();
                }
            } else if locked.to_prune > 0 {
                locked.to_prune -= 1;
                break;
            } else if idle_expired {
                break;
            } else {
                locked.sleeping_threads += 1;
                idle_expired = if cfg!(miri) {
                    locked = self.thread_condvar.wait(locked).unwrap();
                    false
                } else {
                    let (new_locked, wait_res) = self
                        .thread_condvar
                        .wait_timeout(locked, self.idle_timeout)
                        .unwrap();
                    locked = new_locked;
                    wait_res.timed_out()
                };
                locked.sleeping_threads -= 1;
            }
        }
        locked.spawnable += 1;
    }

    /// Spawns one worker thread running `thread_loop`.
    ///
    /// Callers decrement `Locked::spawnable` before calling this; the worker
    /// gives the slot back when it exits. The thread is detached, except under
    /// Miri, where its handle is stored in `Locked::join_handles`.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to spawn the thread.
    fn start_thread(self: Arc<Self>) {
        #[cfg(miri)]
        let self_2 = Arc::clone(&self);
        let mut builder = thread::Builder::new();
        if cfg!(not(miri)) {
            builder = builder.name(self.thread_name.clone().into_owned());
        }
        if let Some(stack_size) = self.thread_stack_size {
            builder = builder.stack_size(stack_size);
        }
        let handle = builder
            .spawn(move || self.thread_loop())
            .expect("failed to spawn worker thread");
        #[cfg(miri)]
        {
            self_2.locked.lock().unwrap().join_handles.push(handle);
        }
        #[cfg(not(miri))]
        drop(handle);
    }
}
impl ThreadPool {
#[must_use]
pub fn new() -> Self {
Builder::new().build()
}
#[must_use]
pub fn builder() -> Builder {
Builder::new()
}
pub unsafe fn spawn_raw<T>(&self, data: *const T, run: unsafe fn(*const T)) {
let raw = RawFunction {
data: data.cast(),
run: mem::transmute::<unsafe fn(*const T), unsafe fn(*const ())>(run),
};
let mut locked = self.inner.locked.lock().unwrap();
locked.work.push_back(raw);
if locked.sleeping_threads == 0 {
if let Some(new_spawnable) = locked.spawnable.checked_sub(1) {
locked.spawnable = new_spawnable;
drop(locked);
self.inner.clone().start_thread();
}
} else {
self.inner.thread_condvar.notify_one();
}
}
pub fn spawn_boxed<F: FnOnce() + Send + 'static>(&self, f: Box<F>) {
unsafe {
self.spawn_raw(Box::into_raw(f), |f| {
struct AbortOnDrop;
impl Drop for AbortOnDrop {
fn drop(&mut self) {
std::process::abort();
}
}
let guard = AbortOnDrop;
let f = Box::from_raw(f as *mut F);
let _ = panic::catch_unwind(panic::AssertUnwindSafe(f));
mem::forget(guard);
})
};
}
pub fn wait_all_complete(&self) {
let locked = self.inner.locked.lock().unwrap();
let locked = self
.inner
.all_complete
.wait_while(locked, |locked| {
locked.workers != 0 || !locked.work.is_empty()
})
.unwrap();
drop(locked);
}
pub fn prune(&self) {
let mut locked = self.inner.locked.lock().unwrap();
locked.to_prune = locked.sleeping_threads;
self.inner.thread_condvar.notify_all();
}
#[cfg(test)]
fn miri_shutdown(&self) {
#[cfg(miri)]
{
let mut locked = self.inner.locked.lock().unwrap();
locked.to_prune = locked.sleeping_threads + locked.workers;
self.inner.thread_condvar.notify_all();
let join_handles = std::mem::take(&mut locked.join_handles);
drop(locked);
for handle in join_handles {
handle.join().unwrap();
}
}
}
}
impl Default for ThreadPool {
fn default() -> Self {
Self::new()
}
}
/// Configuration for a [`ThreadPool`].
///
/// Create one with `Builder::new` or `ThreadPool::builder`, adjust it with
/// the setter methods, and finish with `Builder::build`.
#[derive(Debug, Clone)]
pub struct Builder {
    /// Upper bound on simultaneously alive worker threads.
    max_threads: usize,
    /// Name given to each worker thread.
    thread_name: Cow<'static, str>,
    /// Stack size for worker threads; `None` keeps std's default.
    thread_stack_size: Option<usize>,
    /// How long an idle worker waits for work before exiting.
    idle_timeout: Duration,
}
impl Builder {
    /// Creates a builder with the default settings.
    ///
    /// The defaults are:
    /// - at most 512 worker threads;
    /// - threads named `blocking-worker`;
    /// - std's default stack size;
    /// - a 10-second idle timeout.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_threads: 512,
            thread_name: Cow::Borrowed("blocking-worker"),
            thread_stack_size: None,
            idle_timeout: Duration::from_secs(10),
        }
    }

    /// Sets the maximum number of worker threads that may be alive at once.
    ///
    /// With `0`, no worker is ever started. Submitted work then never runs,
    /// and `ThreadPool::wait_all_complete` blocks forever once any work has
    /// been submitted.
    #[must_use]
    pub fn max_threads(mut self, max_threads: usize) -> Self {
        self.max_threads = max_threads;
        self
    }

    /// Sets the name given to worker threads.
    ///
    /// The name is not applied when running under Miri.
    #[must_use]
    pub fn thread_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
        self.thread_name = name.into();
        self
    }

    /// Sets the stack size, in bytes, of worker threads.
    ///
    /// If this is never called, threads get std's default stack size.
    #[must_use]
    pub fn thread_stack_size(mut self, stack_size: usize) -> Self {
        self.thread_stack_size = Some(stack_size);
        self
    }

    /// Sets how long an idle worker thread waits for new work before exiting.
    ///
    /// Under Miri idle workers wait indefinitely and this value is unused.
    #[must_use]
    pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
        self.idle_timeout = idle_timeout;
        self
    }

    /// Creates the thread pool.
    ///
    /// No worker threads are started until work is submitted.
    #[must_use]
    pub fn build(self) -> ThreadPool {
        ThreadPool {
            inner: Arc::new(Inner {
                locked: Mutex::new(Locked {
                    work: VecDeque::new(),
                    to_prune: 0,
                    sleeping_threads: 0,
                    workers: 0,
                    spawnable: self.max_threads,
                    #[cfg(miri)]
                    join_handles: Vec::new(),
                }),
                thread_condvar: Condvar::new(),
                all_complete: Condvar::new(),
                thread_name: self.thread_name,
                thread_stack_size: self.thread_stack_size,
                idle_timeout: self.idle_timeout,
            }),
        }
    }
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
/// A type-erased job: `run` is to be called once with `data`.
///
/// Built by `ThreadPool::spawn_raw` from a typed `(data, run)` pair.
#[derive(Debug)]
struct RawFunction {
    /// Argument for `run`, cast from the caller's `*const T`.
    data: *const (),
    /// Callback, transmuted from the caller's `unsafe fn(*const T)`.
    run: unsafe fn(*const ()),
}

// SAFETY: raw pointers are not `Send`, so this has to be asserted manually. A
// `RawFunction` is only created by `ThreadPool::spawn_raw`. That function's
// caller promises that `run(data)` may be executed on another thread, so moving
// the pair to a worker thread is exactly what the contract permits.
unsafe impl Send for RawFunction {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
    use std::sync::Arc;

    /// Two jobs run concurrently on a pool capped at two threads, and the
    /// third job is queued until one of them finishes.
    ///
    /// `spawn_boxed` catches panics inside jobs, so an `assert!` inside a job
    /// can never fail the test by itself. Each job therefore bumps `passed`
    /// only after all of its own checks succeed, and the main thread asserts
    /// on that counter.
    #[test]
    fn more_work_than_threads() {
        let thread_pool = Builder::new().max_threads(2).build();
        let value = Arc::new(AtomicUsize::new(0));
        let passed = Arc::new(AtomicUsize::new(0));

        for _ in 0..2 {
            let value = Arc::clone(&value);
            let passed = Arc::clone(&passed);
            thread_pool.spawn_boxed(Box::new(move || {
                assert_eq!(value.load(SeqCst), 0);
                wait();
                assert!(value.fetch_add(1, SeqCst) < 2);
                passed.fetch_add(1, SeqCst);
            }));
        }
        {
            let value = Arc::clone(&value);
            let passed = Arc::clone(&passed);
            thread_pool.spawn_boxed(Box::new(move || {
                assert!(matches!(value.load(SeqCst), 1 | 2));
                wait();
                assert_eq!(value.load(SeqCst), 2);
                passed.fetch_add(1, SeqCst);
            }));
        }

        thread_pool.wait_all_complete();
        assert_eq!(value.load(SeqCst), 2);
        assert_eq!(passed.load(SeqCst), 3, "a pool job failed one of its checks");
        thread_pool.miri_shutdown();
    }
}
/// Test helper that stalls the calling thread so concurrent jobs overlap.
///
/// Normally it sleeps for one second; under Miri it yields 3_000 times
/// instead.
#[cfg(test)]
fn wait() {
    if !cfg!(miri) {
        thread::sleep(Duration::from_secs(1));
        return;
    }
    (0..3_000).for_each(|_| thread::yield_now());
}