use std::cell::UnsafeCell;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
/// Splits a thread budget of `n` into `(reader, matcher)` thread counts.
///
/// The reader share is `n / 3` rounded up and the matcher share is `2 * n / 3`
/// rounded down. Each share is then raised to at least 1, so for `n < 3` the
/// two counts can add up to more than `n`.
#[must_use]
pub fn partition_threads(n: usize) -> (usize, usize) {
    let reader_share = n.div_ceil(3).max(1);
    let matcher_share = ((2 * n) / 3).max(1);
    (reader_share, matcher_share)
}
/// A boxed, run-once unit of work that can be handed to another thread.
type Job = Box<dyn FnOnce() + Send + 'static>;
/// State shared between a `ThreadPool` handle and all of its worker threads.
struct SharedState {
/// Pending jobs plus the shutdown flag, guarded by a single mutex.
queue: Mutex<QueueInner>,
/// Signalled when jobs are queued and when shutdown is requested; idle workers wait on it.
job_available: Condvar,
}
/// Mutex-protected contents of the pool's job queue.
struct QueueInner {
/// Jobs waiting to run; pushed at the back and popped from the front (FIFO).
jobs: VecDeque<Job>,
/// Set when the pool is dropped; a worker exits once this is true and `jobs` is empty.
shutdown: bool,
}
/// A fixed-size pool of worker threads fed from one shared FIFO job queue.
///
/// Worker threads are started by [`ThreadPool::new`]. Dropping the pool
/// requests shutdown and joins every worker.
pub struct ThreadPool {
    /// Job queue and wake-up condition variable shared with the workers.
    shared: Arc<SharedState>,
    /// Join handles of the worker threads.
    workers: Vec<thread::JoinHandle<()>>,
    /// Worker count requested at construction time.
    num_threads: usize,
}

impl ThreadPool {
    /// Creates a pool and immediately starts `num_threads` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero.
    #[must_use]
    pub fn new(num_threads: usize) -> Self {
        assert!(num_threads > 0, "ThreadPool requires at least 1 thread");
        let shared = Arc::new(SharedState {
            queue: Mutex::new(QueueInner {
                jobs: VecDeque::new(),
                shutdown: false,
            }),
            job_available: Condvar::new(),
        });
        let mut workers = Vec::with_capacity(num_threads);
        for _ in 0..num_threads {
            let worker_shared = Arc::clone(&shared);
            workers.push(thread::spawn(move || worker_loop(&worker_shared)));
        }
        Self {
            shared,
            workers,
            num_threads,
        }
    }

    /// Returns the number of worker threads this pool was created with.
    #[inline]
    #[must_use]
    pub fn num_threads(&self) -> usize {
        self.num_threads
    }

    /// Queues a single job and wakes one worker to run it.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex has been poisoned.
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        {
            let mut queue = self.shared.queue.lock().unwrap();
            queue.jobs.push_back(Box::new(f));
        }
        // Notify after releasing the lock so the woken worker can take it at once.
        self.shared.job_available.notify_one();
    }

    /// Queues every job produced by `jobs`, then wakes as many workers as are useful.
    ///
    /// The iterator is fully drained *before* the queue lock is taken, so
    /// arbitrary caller code inside a lazy iterator never runs under the lock.
    /// That keeps a slow iterator from stalling the workers, keeps a panicking
    /// one from poisoning the mutex, and lets an iterator that itself calls
    /// [`ThreadPool::spawn`] on this pool proceed instead of deadlocking.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex has been poisoned.
    pub fn spawn_batch<I>(&self, jobs: I)
    where
        I: IntoIterator<Item = Box<dyn FnOnce() + Send + 'static>>,
    {
        let batch: Vec<_> = jobs.into_iter().collect();
        let count = batch.len();
        if count == 0 {
            // Nothing queued, nobody to wake.
            return;
        }
        {
            let mut queue = self.shared.queue.lock().unwrap();
            queue.jobs.extend(batch);
        }
        if count >= self.num_threads {
            // At least one job per worker: wake everyone in a single call.
            self.shared.job_available.notify_all();
        } else {
            for _ in 0..count {
                self.shared.job_available.notify_one();
            }
        }
    }
}
impl Drop for ThreadPool {
    /// Requests shutdown, wakes every worker and joins them.
    ///
    /// Jobs that are already queued are still executed, because workers only
    /// exit once the queue is empty and the shutdown flag is set.
    fn drop(&mut self) {
        {
            // Never panic inside `drop`: a poisoned mutex only means some thread
            // panicked while holding the lock. Recover the guard so shutdown is
            // still signalled; panicking here could abort the process if this
            // drop runs during unwinding.
            let mut queue = self
                .shared
                .queue
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            queue.shutdown = true;
        }
        self.shared.job_available.notify_all();
        for handle in self.workers.drain(..) {
            // A worker that died from a panic yields `Err`; there is nothing
            // useful to do with it during teardown.
            let _ = handle.join();
        }
    }
}
/// Body of every pool worker thread: run queued jobs until shutdown.
///
/// The worker sleeps on `job_available` while the queue is empty and shutdown
/// has not been requested. Queued jobs are always drained first, so the
/// function returns only when the queue is empty *and* the shutdown flag is
/// set.
///
/// A job that panics is isolated with `catch_unwind`: the panic is still
/// reported by the panic hook, but the worker survives and keeps serving the
/// queue instead of silently shrinking the pool.
///
/// # Panics
///
/// Panics if the queue mutex has been poisoned.
fn worker_loop(shared: &SharedState) {
    loop {
        let next_job = {
            let idle = shared.queue.lock().unwrap();
            let mut queue = shared
                .job_available
                .wait_while(idle, |q| q.jobs.is_empty() && !q.shutdown)
                .unwrap();
            // `None` here means: nothing left to run and shutdown was requested.
            queue.jobs.pop_front()
        }; // Lock released here, so the job runs without blocking other workers.

        let Some(runnable) = next_job else {
            return;
        };

        // The job is consumed by the call and the payload is discarded, so no
        // state observed across the unwind boundary can be left inconsistent.
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(runnable));
    }
}
/// A single result cell with interior mutability and a large alignment.
///
/// In this file each worker of `parallel_work_queue` stores its final
/// accumulator into its own `Slot`, and the coordinator takes the values out
/// after all workers have finished.
///
/// NOTE(review): the per-architecture alignments below look like cache-line
/// padding intended to keep neighbouring slots from sharing a line (the values
/// match those used by `crossbeam-utils`' `CachePadded`) — confirm.
#[cfg_attr(
    any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "arm64ec",
        target_arch = "powerpc64",
    ),
    repr(align(128))
)]
#[cfg_attr(
    any(
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips32r6",
        target_arch = "mips64",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "hexagon",
    ),
    repr(align(32))
)]
#[cfg_attr(target_arch = "m68k", repr(align(16)))]
#[cfg_attr(target_arch = "s390x", repr(align(256)))]
#[cfg_attr(
    not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "arm64ec",
        target_arch = "powerpc64",
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips32r6",
        target_arch = "mips64",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "hexagon",
        target_arch = "m68k",
        target_arch = "s390x",
    )),
    repr(align(64))
)]
struct Slot<R> {
    /// The stored value; `None` until something has been written.
    value: UnsafeCell<Option<R>>,
}

// SAFETY: these impls are sound only while users of `Slot` never touch the same
// slot from two threads at once. The type itself does no synchronisation; the
// caller must provide it.
unsafe impl<R: Send> Send for Slot<R> {}
unsafe impl<R: Send> Sync for Slot<R> {}

impl<R> Slot<R> {
    /// Creates an empty slot.
    fn new() -> Self {
        let empty: Option<R> = None;
        Slot {
            value: UnsafeCell::new(empty),
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn parallel_work_queue<S, T, R, P, M, I, W, G>(
pool: &ThreadPool,
num_workers: usize,
items: &Arc<S>,
chunk_size: usize,
identity: I,
process_chunk: P,
reduce: M,
prepare: W,
merge: G,
) where
S: AsRef<[T]> + Send + Sync + ?Sized + 'static,
T: Send + Sync + 'static,
R: Send + 'static,
P: Fn(usize, &[T]) -> R + Send + Sync + 'static,
M: Fn(&mut R, R) + Send + Sync + 'static,
I: Fn() -> R + Send + Sync + 'static,
W: Fn(&mut R) + Send + Sync + 'static,
G: FnOnce(Vec<R>),
{
let items_slice: &[T] = AsRef::<[T]>::as_ref(&**items);
let total = items_slice.len();
if total == 0 {
merge(Vec::new());
return;
}
let num_chunks = total.div_ceil(chunk_size);
let next_chunk = Arc::new(AtomicUsize::new(0));
let slots: Arc<Vec<Slot<R>>> = Arc::new((0..num_workers).map(|_| Slot::new()).collect());
let remaining = Arc::new(AtomicCounter::new(num_workers));
remaining.set_waiter();
let process_chunk = Arc::new(process_chunk);
let reduce = Arc::new(reduce);
let prepare = Arc::new(prepare);
let identity = Arc::new(identity);
let jobs: Vec<Box<dyn FnOnce() + Send + 'static>> = (0..num_workers)
.map(|worker_id| {
let w_items = Arc::clone(items);
let w_next_chunk = Arc::clone(&next_chunk);
let w_slots: Arc<Vec<Slot<R>>> = Arc::clone(&slots);
let w_remaining = Arc::clone(&remaining);
let w_process_chunk = Arc::clone(&process_chunk);
let w_reduce = Arc::clone(&reduce);
let w_prepare = Arc::clone(&prepare);
let w_identity = Arc::clone(&identity);
let job: Box<dyn FnOnce() + Send + 'static> = Box::new(move || {
let local_acc = {
let mut local_acc = w_identity();
loop {
let chunk_idx = w_next_chunk.fetch_add(1, Ordering::Relaxed);
if chunk_idx >= num_chunks {
break;
}
let start = chunk_idx * chunk_size;
let end = total.min(start + chunk_size);
let slice: &[T] = AsRef::<[T]>::as_ref(&*w_items);
let partial = w_process_chunk(start, &slice[start..end]);
w_reduce(&mut local_acc, partial);
}
w_prepare(&mut local_acc);
local_acc
};
unsafe { *w_slots[worker_id].value.get() = Some(local_acc) };
drop(w_slots);
w_remaining.dec_and_notify();
});
job
})
.collect();
pool.spawn_batch(jobs);
remaining.wait_for_zero();
if let Some(slots) = Arc::into_inner(slots) {
let results: Vec<R> = slots.into_iter().filter_map(|slot| slot.value.into_inner()).collect();
merge(results);
} else {
log::error!("More than one ref to the slots remaining after workers exit. This SHOULD NOT happen.");
}
}
/// A countdown that unparks one registered thread when it reaches zero.
///
/// Usage in this file: create it with the number of outstanding jobs, call
/// `set_waiter` on the thread that will wait, let every job call
/// `dec_and_notify` once, and block in `wait_for_zero` on the waiting thread.
struct AtomicCounter {
    /// Number of decrements still outstanding.
    count: AtomicUsize,
    /// Handle of the thread to unpark on the final decrement, if registered.
    waiter: UnsafeCell<Option<thread::Thread>>,
}

// SAFETY: `waiter` is accessed without synchronisation, so these impls are
// sound only if `set_waiter` has completed before the counter is shared with
// any thread that may call `dec_and_notify`.
unsafe impl Send for AtomicCounter {}
unsafe impl Sync for AtomicCounter {}

impl AtomicCounter {
    /// Creates a counter that expects `n` decrements and has no waiter yet.
    fn new(n: usize) -> Self {
        let count = AtomicUsize::new(n);
        let waiter = UnsafeCell::new(None);
        Self { count, waiter }
    }

    /// Registers the calling thread as the one to unpark at zero.
    fn set_waiter(&self) {
        let me = thread::current();
        // SAFETY: must not run concurrently with `dec_and_notify`; see the note
        // on the `Sync` impl.
        unsafe { *self.waiter.get() = Some(me) };
    }

    /// Decrements the counter; the call that brings it to zero unparks the waiter.
    fn dec_and_notify(&self) {
        let before = self.count.fetch_sub(1, Ordering::AcqRel);
        debug_assert!(before > 0, "AtomicCounter decremented below zero — double-decrement bug?");
        if before != 1 {
            return;
        }
        // SAFETY: only reads `waiter`, which is no longer written once the
        // counter is shared.
        let registered = unsafe { &*self.waiter.get() };
        if let Some(handle) = registered {
            handle.unpark();
        }
    }

    /// Parks the calling thread until the counter is observed at zero.
    ///
    /// The counter is re-checked after every wake-up, so spurious unparks are harmless.
    fn wait_for_zero(&self) {
        loop {
            if self.count.load(Ordering::Acquire) == 0 {
                break;
            }
            thread::park();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the split for small inputs and checks that it sums to `n` from 3 upwards.
    #[test]
    fn partition_threads_split() {
        assert_eq!(partition_threads(1), (1, 1));
        assert_eq!(partition_threads(2), (1, 1));
        assert_eq!(partition_threads(3), (1, 2));
        assert_eq!(partition_threads(6), (2, 4));
        assert_eq!(partition_threads(8), (3, 5));
        assert_eq!(partition_threads(9), (3, 6));
        for n in 3..=64 {
            let (r, m) = partition_threads(n);
            assert_eq!(r + m, n, "partition_threads({n}) = ({r}, {m}) does not sum to {n}");
            assert!(r >= 1);
            assert!(m >= 1);
        }
    }

    /// A single spawned closure runs.
    #[test]
    fn spawn_runs_closure() {
        let pool = ThreadPool::new(2);
        let flag = Arc::new(AtomicUsize::new(0));
        let flag2 = Arc::clone(&flag);
        pool.spawn(move || {
            flag2.store(42, Ordering::SeqCst);
        });
        // Dropping the pool drains the queue and joins the workers, so the
        // assertion below does not depend on timing.
        drop(pool);
        assert_eq!(flag.load(Ordering::SeqCst), 42);
    }

    /// Every job of a batch larger than the pool runs exactly once.
    #[test]
    fn spawn_batch_runs_all() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));
        let jobs: Vec<Box<dyn FnOnce() + Send + 'static>> = (0..10)
            .map(|_| {
                let c = Arc::clone(&counter);
                let job: Box<dyn FnOnce() + Send + 'static> = Box::new(move || {
                    c.fetch_add(1, Ordering::SeqCst);
                });
                job
            })
            .collect();
        pool.spawn_batch(jobs);
        // Join the workers instead of sleeping for a fixed time.
        drop(pool);
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    /// Sums 1..=1000 across four workers.
    #[test]
    fn parallel_work_queue_sums() {
        let pool = ThreadPool::new(4);
        let items: Arc<[u64]> = (1..=1000u64).collect::<Vec<_>>().into();
        let mut result = 0u64;
        parallel_work_queue(
            &pool,
            4,
            &items,
            64,
            || 0u64,
            |_start, chunk| chunk.iter().sum::<u64>(),
            |acc, partial| *acc += partial,
            |_| {},
            |worker_results| {
                for partial in worker_results {
                    result += partial;
                }
            },
        );
        assert_eq!(result, 500_500);
    }

    /// Empty input still calls `merge`, with no worker results.
    #[test]
    fn parallel_work_queue_empty() {
        let pool = ThreadPool::new(2);
        let items: Arc<[u64]> = Arc::from(Vec::<u64>::new().into_boxed_slice());
        let mut result = Vec::<u64>::new();
        parallel_work_queue(
            &pool,
            2,
            &items,
            64,
            Vec::<u64>::new,
            |_start, chunk| chunk.to_vec(),
            |acc, mut partial| acc.append(&mut partial),
            |_| {},
            |worker_results| {
                for partial in worker_results {
                    result.extend(partial);
                }
            },
        );
        assert!(result.is_empty());
    }

    /// One worker on a one-thread pool processes every chunk.
    #[test]
    fn parallel_work_queue_single_thread() {
        let pool = ThreadPool::new(1);
        let items: Arc<[i32]> = (0..100i32).collect::<Vec<_>>().into();
        let mut result = 0i32;
        parallel_work_queue(
            &pool,
            1,
            &items,
            10,
            || 0i32,
            |_start, chunk| chunk.iter().sum::<i32>(),
            |acc, partial| *acc += partial,
            |_| {},
            |worker_results| {
                for partial in worker_results {
                    result += partial;
                }
            },
        );
        assert_eq!(result, (0..100).sum::<i32>());
    }

    /// Dropping the pool waits for a job that is still running.
    #[test]
    fn pool_drop_joins_threads() {
        let flag = Arc::new(AtomicUsize::new(0));
        {
            let pool = ThreadPool::new(2);
            let f = Arc::clone(&flag);
            pool.spawn(move || {
                // Deliberate delay: the flag is only set if `drop` really joins.
                std::thread::sleep(std::time::Duration::from_millis(30));
                f.store(1, Ordering::SeqCst);
            });
        }
        assert_eq!(flag.load(Ordering::SeqCst), 1);
    }

    /// More workers than chunks: idle workers contribute the identity value.
    #[test]
    fn parallel_work_queue_many_workers_few_chunks() {
        let pool = ThreadPool::new(8);
        let items: Arc<[u64]> = (1..=10u64).collect::<Vec<_>>().into();
        let mut result = 0u64;
        parallel_work_queue(
            &pool,
            8,
            &items,
            5,
            || 0u64,
            |_start, chunk| chunk.iter().sum::<u64>(),
            |acc, partial| *acc += partial,
            |_| {},
            |worker_results| {
                for partial in worker_results {
                    result += partial;
                }
            },
        );
        assert_eq!(result, 55);
    }

    /// Coordinating from a non-pool thread completes on a one-thread pool.
    #[test]
    fn parallel_work_queue_single_thread_pool_no_deadlock() {
        let (tx, rx) = std::sync::mpsc::channel();
        let pool = Arc::new(ThreadPool::new(1));
        let items: Arc<[u64]> = (1..=100u64).collect::<Vec<_>>().into();
        let pool_coord = Arc::clone(&pool);
        std::thread::spawn(move || {
            parallel_work_queue(
                &pool_coord,
                1,
                &items,
                10,
                || 0u64,
                |_start, chunk| chunk.iter().sum::<u64>(),
                |acc, partial| *acc += partial,
                |_| {},
                |worker_results| {
                    let _ = tx.send(worker_results.into_iter().sum::<u64>());
                },
            );
        });
        // The timeout turns a hang into a test failure instead of a stuck run.
        let result = rx
            .recv_timeout(std::time::Duration::from_secs(5))
            .expect("deadlock or timeout");
        assert_eq!(result, 5050);
    }
}