use std::{
num::NonZeroUsize,
panic::AssertUnwindSafe,
ptr::NonNull,
sync::{
atomic::{AtomicUsize, Ordering},
mpsc, Mutex, PoisonError,
},
thread::Thread,
};
use crate::util::{defer, sync::SyncWrap};
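
/// The global thread pool on which benchmarks broadcast their tasks.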
pub(crate) static BENCH_POOL: ThreadPool = ThreadPool::new();
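
/// A pool of lazily spawned worker threads for running a task concurrently on
/// all of them.
///
/// Each worker is fed over a zero-capacity (rendezvous) channel, so `send`
/// blocks until the worker has accepted the task. The pool only ever grows;
/// threads are reused across broadcasts until `drop_threads` disconnects them.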
pub(crate) struct ThreadPool {
threads: Mutex<Vec<mpsc::SyncSender<Task>>>,
}
impl ThreadPool {
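    /// Creates an empty pool; worker threads are spawned on first use.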
const fn new() -> Self {
Self { threads: Mutex::new(Vec::new()) }
}
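
    /// Runs `task` on `aux_threads + 1` threads (including the caller) and
    /// appends each invocation's result to `vec` at its thread's index.
    ///
    /// Illustrative example (mirrors the `extend` test below):
    ///
    /// ```ignore
    /// let mut results = Vec::new();
    /// pool.par_extend(&mut results, 3, |index| index * 2);
    /// assert_eq!(results, [Some(0), Some(2), Some(4), Some(6)]);
    /// ```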
#[inline]
pub fn par_extend<T, F>(&self, vec: &mut Vec<Option<T>>, aux_threads: usize, task: F)
where
F: Sync + Fn(usize) -> T,
T: Sync + Send,
{
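        // SAFETY: Every new slot is initialized to `None` before `set_len`,
        // so `vec` stays valid even if a task panics, and each thread writes
        // through `ptr` to a distinct index within the reserved region.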
unsafe {
let old_len = vec.len();
let additional = aux_threads + 1;
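
            // Reserve the new slots and pre-fill them with `None` so the
            // length can be set before any results are written.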
vec.reserve_exact(additional);
vec.spare_capacity_mut().iter_mut().for_each(|val| {
val.write(None);
});
vec.set_len(old_len + additional);
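
            // Wrap the raw pointer so the `Sync` closure can capture it and
            // each thread can write `Some(result)` into its own slot.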
let ptr = SyncWrap::new(vec.as_mut_ptr().add(old_len));
self.broadcast(aux_threads, move |index| {
ptr.add(index).write(Some(task(index)));
});
}
}
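
    /// Runs `task` once on each of `aux_threads` pool threads and on the
    /// calling thread, passing each invocation its thread ID (0 = caller).
    ///
    /// Does not return until every thread has finished the task.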
#[inline]
pub fn broadcast<F>(&self, aux_threads: usize, task: F)
where
F: Sync + Fn(usize),
{
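        // SAFETY: `TaskShared` lives on this stack frame and remains valid
        // for all threads because `broadcast_task` does not return until
        // `ref_count` reaches zero.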
unsafe {
let task = TaskShared::new(aux_threads, task);
let task = Task { shared: NonNull::from(&task).cast() };
self.broadcast_task(aux_threads, task);
}
}
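
    /// Sends `task` to `aux_threads` pool threads, runs it on the calling
    /// thread, then waits for all threads to finish it.
    ///
    /// # Safety
    ///
    /// `task.shared` must point to a live `TaskShared` whose `ref_count` is
    /// `aux_threads`.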
unsafe fn broadcast_task(&self, aux_threads: usize, task: Task) {
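        // Send the task to the auxiliary threads, spawning them as needed.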
if aux_threads > 0 {
let threads = &mut *self.threads.lock().unwrap_or_else(PoisonError::into_inner);
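            // Grow the pool if it has fewer threads than requested.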
if let Some(additional) = NonZeroUsize::new(aux_threads.saturating_sub(threads.len())) {
spawn(additional, threads);
}
for thread in &threads[..aux_threads] {
thread.send(task).unwrap();
}
}
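
        // Run the task on the calling thread (thread ID 0), catching any
        // panic so that we still wait for the auxiliary threads below.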
let main_result = std::panic::catch_unwind(AssertUnwindSafe(|| task.run(0)));
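
        // Wait until all auxiliary threads have finished. `TaskShared` lives
        // on this stack frame, so returning (or unwinding) while it is still
        // referenced would leave the workers with dangling pointers.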
while task.shared.as_ref().ref_count.load(Ordering::Acquire) > 0 {
std::thread::park();
}

        // All auxiliary threads are done with `TaskShared`, so it is now safe
        // to unwind: re-raise the main thread's panic instead of silently
        // dropping it.
        if let Err(panic) = main_result {
            std::panic::resume_unwind(panic);
        }
}
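
    /// Drops all worker senders, disconnecting the channels so every worker's
    /// `recv` fails and its thread exits cleanly.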
pub fn drop_threads(&self) {
*self.threads.lock().unwrap_or_else(PoisonError::into_inner) = Default::default();
}
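
    /// The number of auxiliary threads currently in the pool (test-only).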
#[cfg(test)]
fn aux_thread_count(&self) -> usize {
self.threads.lock().unwrap_or_else(PoisonError::into_inner).len()
}
}
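
/// A copyable, type-erased handle to a `TaskShared<F>` living on the
/// broadcasting thread's stack.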
#[derive(Clone, Copy)]
struct Task {
shared: NonNull<TaskShared<()>>,
}
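
// SAFETY: `Task` may be sent to and shared with worker threads because the
// closure behind it is `Sync` and the broadcasting thread keeps the
// `TaskShared` alive until `ref_count` reaches zero.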
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
impl Task {
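    /// Runs the type-erased closure with the given thread ID.
    ///
    /// # Safety
    ///
    /// The `TaskShared` behind `self.shared` must still be alive.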
#[inline]
unsafe fn run(&self, thread_id: usize) {
let shared_ptr = self.shared.as_ptr();
let shared = &*shared_ptr;
(shared.task_fn_ptr)(shared_ptr.cast(), thread_id);
}
}
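
/// State shared between all threads participating in one broadcast.
///
/// `#[repr(C)]` fixes the field order so that every field before `task_fn`
/// sits at the same offset for any `F`, which is what makes it sound to read
/// a `TaskShared<F>` through a type-erased `TaskShared<()>` pointer.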
#[repr(C)]
struct TaskShared<F> {
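    /// The broadcasting thread, unparked by the last auxiliary thread to finish.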
main_thread: Thread,
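    /// The number of auxiliary threads still running the task.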
ref_count: AtomicUsize,
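    /// Monomorphized entry point that downcasts `task` and calls `task_fn`.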
task_fn_ptr: unsafe fn(task: *const TaskShared<()>, thread: usize),
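    /// The caller's closure; its concrete type is erased by `task_fn_ptr`.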
task_fn: F,
}
impl<F> TaskShared<F> {
#[inline]
fn new(aux_threads: usize, task_fn: F) -> Self
where
F: Sync + Fn(usize),
{
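        // Trampoline: recover the concrete `TaskShared<F>` from the erased
        // pointer and invoke the stored closure.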
unsafe fn call<F>(task: *const TaskShared<()>, thread: usize)
where
F: Fn(usize),
{
let task_fn = &(*task.cast::<TaskShared<F>>()).task_fn;
task_fn(thread);
}
Self {
main_thread: std::thread::current(),
ref_count: AtomicUsize::new(aux_threads),
task_fn_ptr: call::<F>,
task_fn,
}
}
}
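
/// Spawns `additional` worker threads, appending their senders to `threads`.
///
/// Thread IDs start at 1 because ID 0 refers to the broadcasting thread.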
#[cold]
fn spawn(additional: NonZeroUsize, threads: &mut Vec<mpsc::SyncSender<Task>>) {
let next_thread_id = threads.len() + 1;
threads.extend((next_thread_id..(next_thread_id + additional.get())).map(|thread_id| {
let (sender, receiver) = mpsc::sync_channel::<Task>(0);
let work = move || {
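            // Abort the process if a panic escapes the loop below (e.g. if
            // dropping a caught panic payload itself panics); the guard is
            // disarmed with `mem::forget` on clean shutdown.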
let panic_guard = defer(|| std::process::abort());
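
            // Serve tasks until the pool drops the sending half of the channel.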
while let Ok(task) = receiver.recv() {
let result =
std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { task.run(thread_id) }));
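
                // Clone the main thread's handle *before* decrementing the
                // count: once `ref_count` hits zero, the broadcaster may
                // return and invalidate `TaskShared`.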
unsafe {
let main_thread = task.shared.as_ref().main_thread.clone();
if task.shared.as_ref().ref_count.fetch_sub(1, Ordering::Release) == 1 {
main_thread.unpark();
}
}
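
                // Swallow any panic from the task; if dropping the payload
                // panics, the guard above aborts the process.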
drop(result);
}
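
            // Clean shutdown: disarm the abort guard.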
std::mem::forget(panic_guard);
};
std::thread::Builder::new()
.name(format!("divan-{thread_id}"))
.spawn(work)
.expect("failed to spawn thread");
sender
}));
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extend() {
static TEST_POOL: ThreadPool = ThreadPool::new();
fn test(aux_threads: usize, final_aux_threads: usize) {
let total_threads = aux_threads + 1;
let mut results = Vec::new();
let expected = (0..total_threads).map(Some).collect::<Vec<_>>();
TEST_POOL.par_extend(&mut results, aux_threads, |index| index);
assert_eq!(results, expected);
assert_eq!(TEST_POOL.aux_thread_count(), final_aux_threads);
}
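
        // The pool grows to the largest request so far and never shrinks.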
test(0, 0);
test(1, 1);
test(2, 2);
test(3, 3);
test(4, 4);
test(8, 8);
test(4, 8);
test(0, 8);
TEST_POOL.drop_threads();
}
#[test]
fn broadcast_sleep() {
use std::time::Duration;
static TEST_POOL: ThreadPool = ThreadPool::new();
TEST_POOL.broadcast(10, |thread_id| {
if thread_id > 0 {
std::thread::sleep(Duration::from_millis(10));
}
});
TEST_POOL.drop_threads();
}
#[test]
fn broadcast_thread_id() {
static TEST_POOL: ThreadPool = ThreadPool::new();
let main_thread = std::thread::current().id();
TEST_POOL.broadcast(10, |thread_id| {
let is_main = main_thread == std::thread::current().id();
assert_eq!(is_main, thread_id == 0);
});
TEST_POOL.drop_threads();
}
}
#[cfg(feature = "internal_benches")]
mod benches {
use super::*;
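
    /// Auxiliary thread counts to benchmark: 0..=16, plus the machine's
    /// available parallelism when it exceeds 16.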
fn aux_thread_counts() -> impl Iterator<Item = usize> {
let mut available_parallelism = std::thread::available_parallelism().ok().map(|n| n.get());
let range = 0..=16;
if let Some(n) = available_parallelism {
if range.contains(&n) {
available_parallelism = None;
}
}
range.chain(available_parallelism)
}
#[crate::bench(crate = crate, args = aux_thread_counts())]
fn broadcast(bencher: crate::Bencher, aux_threads: usize) {
let pool = ThreadPool::new();
let benched = move || pool.broadcast(aux_threads, crate::black_box_drop);
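
        // Run once up front so thread spawning isn't included in the timing.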
benched();
bencher.bench(benched);
}
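
    // Uses a fresh pool per sample (`sample_size = 1`), so thread spawning is
    // included in the measurement.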
#[crate::bench(crate = crate, args = aux_thread_counts(), sample_size = 1)]
fn broadcast_once(bencher: crate::Bencher, aux_threads: usize) {
bencher
.with_inputs(ThreadPool::new)
.bench_refs(|pool| pool.broadcast(aux_threads, crate::black_box_drop));
}
}