use std::cell::{Cell, UnsafeCell};
use std::collections::VecDeque;
use std::hint::unreachable_unchecked;
use std::iter::repeat_with;
use std::mem::MaybeUninit;
use std::ops::ControlFlow;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering, fence};
use std::thread::{self, JoinHandle, available_parallelism};
use paralight::iter::{Accumulator, ExactSizeAccumulator, GenericThreadPool, SourceCleanup};
use smallvec::SmallVec;
use crate::util::*;
use crate::{OwnedTask, SharedTask, TaskInner};
// Process-wide pool. Initialized at most once: either explicitly through
// `ThreadPoolBuilder::build_global`, or with default settings on first use in
// `ThreadPool::with_current`.
static GLOBAL_POOL: spin::Once<ThreadPool> = spin::Once::new();
thread_local! {
// Advertise mask currently installed on this thread (null when none is installed).
// Swapped in and out by `install_local_advertise_mask`.
static LOCAL_ADVERTISE_MASK: Cell<*const AtomicBits> = const { Cell::new(std::ptr::null()) };
// Pool registered as "current" on this thread (null when none is registered).
static LOCAL_POOL: Cell<*const ThreadPool> = const { Cell::new(std::ptr::null()) };
// Per-thread backing storage handed out by `with_local_advertise_mask` when no mask is
// installed yet; it is replaced there when its capacity is too small.
static ADVERTISE_MASK_STORE: UnsafeCell<AtomicBits> = UnsafeCell::new(AtomicBits::default());
}
// Callback that spawns one worker thread. `ThreadPool::new` calls it with the worker
// index and the closure the worker must run.
type ThreadSpawnerFn = dyn FnMut(usize, Box<dyn FnOnce() + Send>) -> JoinHandle<()>;
/// Builder that collects the settings used to create a [`ThreadPool`].
pub struct ThreadPoolBuilder {
/// Spin budget passed to `spin_wait` by threads that recently found work.
idle_spin_cycles: usize,
/// Requested number of job slots; sizes the bit sets created in `ThreadPoolState::new`.
max_jobs: usize,
/// Number of worker threads spawned by `ThreadPool::new`.
num_threads: usize,
/// Callback used to spawn each worker thread.
spawn_handler: Box<ThreadSpawnerFn>,
}
impl ThreadPoolBuilder {
    /// Consumes the builder and creates a standalone [`ThreadPool`].
    pub fn build(self) -> ThreadPool {
        ThreadPool::new(self)
    }

    /// Consumes the builder and installs the resulting pool as the global pool.
    ///
    /// # Panics
    ///
    /// Panics when the global pool has already been initialized, because the
    /// settings of this builder could no longer take effect.
    pub fn build_global(self) {
        // `call_once` only runs the closure for the very first initializer, so the
        // flag tells us whether *this* builder produced the global pool.
        let mut initialized_here = false;
        GLOBAL_POOL.call_once(|| {
            initialized_here = true;
            self.build()
        });
        assert!(
            initialized_here,
            "ThreadPoolBuilder::build_global called after global pool was already active"
        );
    }

    /// Sets the spin budget used by threads before they go to sleep.
    pub fn idle_spin_cycles(mut self, idle_spin_cycles: usize) -> Self {
        self.idle_spin_cycles = idle_spin_cycles;
        self
    }

    /// Sets the number of worker threads the pool will spawn.
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Replaces the callback used to spawn worker threads.
    ///
    /// The callback receives the worker index and the closure the new thread has to run.
    pub fn spawn_handler<F>(mut self, spawn: F) -> Self
    where
        F: FnMut(usize, Box<dyn FnOnce() + Send>) -> JoinHandle<()> + 'static,
    {
        self.spawn_handler = Box::new(spawn);
        self
    }
}
impl Default for ThreadPoolBuilder {
    /// Derives the defaults from the parallelism reported by the system:
    /// eight job slots per hardware thread and one worker per hardware thread
    /// except one, but always at least one worker.
    fn default() -> Self {
        // `None` when the available parallelism cannot be determined.
        let detected_parallelism = available_parallelism().map(usize::from).ok();
        let job_slots = 8 * detected_parallelism.unwrap_or(1);
        let worker_threads = detected_parallelism
            .unwrap_or_default()
            .saturating_sub(1)
            .max(1);
        Self {
            idle_spin_cycles: 3000,
            max_jobs: job_slots,
            num_threads: worker_threads,
            spawn_handler: Box::new(|_, body| thread::spawn(body)),
        }
    }
}
/// Handle to a set of worker threads that share one `ThreadPoolState`.
pub struct ThreadPool {
/// Join handles of the spawned workers. `None` for the per-worker handles created
/// inside `ThreadPool::new`, which must not join anything when dropped.
join_handles: Option<Vec<JoinHandle<()>>>,
/// State shared with every worker thread.
state: Arc<ThreadPoolState>,
}
impl ThreadPool {
/// Inline capacity of the `SmallVec` buffers that collect one output per work unit.
const OUTPUT_BUFFER_CAPACITY: usize = 256;
/// Creates the shared state and spawns `builder.num_threads` workers through the
/// builder's spawn handler.
fn new(mut builder: ThreadPoolBuilder) -> Self {
let state = Arc::new(ThreadPoolState::new(&builder));
let mut join_handles = Vec::with_capacity(builder.num_threads);
for i in 0..builder.num_threads {
let state_cloned = state.clone();
join_handles.push((builder.spawn_handler)(
i,
Box::new(move || {
// Every worker owns a handle without join handles, so that code running on the
// worker can reach the pool through `LOCAL_POOL`.
let thread_pool = Self {
join_handles: None,
state: state_cloned,
};
LOCAL_POOL.set(&thread_pool);
// Worker loop; `thread_pool` stays alive until the loop returns, so the pointer
// stored above is valid for as long as the worker runs jobs.
thread_pool.state.join()
}),
));
}
Self {
join_handles: Some(join_handles),
state,
}
}
/// Runs `f` with this pool registered as the calling thread's current pool and
/// returns the result of `f`.
///
/// The whole call runs inside `abort_on_panic`. The assertion rejects calls made
/// while a pool is already registered on this thread.
pub fn install<R>(&self, f: impl FnOnce() -> R) -> R {
abort_on_panic(|| {
assert!(
LOCAL_POOL.get().is_null(),
"cannot call install recursively"
);
LOCAL_POOL.set(self);
let result = f();
// Unregister again so the pointer never outlives the borrow of `self`.
LOCAL_POOL.set(std::ptr::null());
result
})
}
/// Spawns `f` as an [`OwnedTask`] on this pool.
pub fn spawn_owned<T: 'static + Send>(
&self,
f: impl 'static + FnOnce() -> T + Send,
) -> OwnedTask<T> {
OwnedTask::spawn(&self.state, f)
}
/// Spawns `f` as a [`SharedTask`] on this pool.
pub fn spawn_shared<T: 'static + Send + Sync>(
&self,
f: impl 'static + FnOnce() -> T + Send,
) -> SharedTask<T> {
SharedTask::spawn(&self.state, f)
}
/// Calls `f` with the pool registered on the calling thread. When none is registered,
/// the global pool is used (and built with default settings on first use) and is
/// registered for the duration of the call.
pub(crate) fn with_current<R>(f: impl FnOnce(&ThreadPool) -> R) -> R {
abort_on_panic(|| unsafe {
let mut pool_ptr = LOCAL_POOL.get();
if pool_ptr.is_null() {
pool_ptr = GLOBAL_POOL.call_once(|| ThreadPoolBuilder::default().build());
LOCAL_POOL.set(pool_ptr);
let result = f(&*pool_ptr);
LOCAL_POOL.set(std::ptr::null());
result
} else {
// A non-null pointer was stored by `install`, `with_current` or the worker start-up
// code, all of which keep the pool alive while the pointer is registered.
f(&*pool_ptr)
}
})
}
/// Returns the number of worker threads the pool was built with.
pub fn num_threads(&self) -> usize {
self.state.num_threads
}
/// Runs `oper_a` and `oper_b` as the two work units of a single job and returns
/// both results.
pub fn join<A, B, RA, RB>(&self, oper_a: A, oper_b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
RA: Send,
RB: Send,
{
unsafe {
// The closures are `FnOnce` but have to be called from an `Fn(usize)`; they are
// parked in `MaybeUninit` and moved out by the unit that runs them.
let oper_a_holder = MaybeUninit::new(oper_a);
let oper_b_holder = MaybeUninit::new(oper_b);
let result_a = UnsafeCell::new(MaybeUninit::uninit());
let result_b = UnsafeCell::new(MaybeUninit::uninit());
self.state.invoke_sync_unchecked(
|i| match i {
0 => {
(*result_a.get()).write(oper_a_holder.assume_init_read()());
}
1 => {
(*result_b.get()).write(oper_b_holder.assume_init_read()());
}
// Only units 0 and 1 exist because the job is invoked with `times == 2`.
_ => unreachable_unchecked(),
},
2,
);
(
result_a.into_inner().assume_init(),
result_b.into_inner().assume_init(),
)
}
}
/// Returns a paralight thread-pool adapter that schedules every item as its own work unit.
pub fn split_per_item(&self) -> impl '_ + GenericThreadPool {
SplitPerItem(self)
}
/// Returns a paralight thread-pool adapter whose work units cover `chunk_size` items
/// each (a `chunk_size` of zero is treated as one).
pub fn split_per(&self, chunk_size: usize) -> impl '_ + GenericThreadPool {
SplitPer {
chunk_units_calculator: move |x| (chunk_size.max(1), x.div_ceil(chunk_size.max(1))),
pool: self,
}
}
/// Returns a paralight thread-pool adapter that divides the input into `chunks` work
/// units of equal (rounded-up) size (a `chunks` of zero is treated as one).
pub fn split_by(&self, chunks: usize) -> impl '_ + GenericThreadPool {
SplitPer {
chunk_units_calculator: move |x| (x.div_ceil(chunks.max(1)), chunks.max(1)),
pool: self,
}
}
/// Like [`ThreadPool::split_by`] with one work unit per worker thread plus one more.
pub fn split_by_threads(&self) -> impl '_ + GenericThreadPool {
self.split_by(self.num_threads() + 1)
}
}
impl Drop for ThreadPool {
    /// Asks the workers to stop and waits for every one of them to exit.
    ///
    /// Handles without join handles (the ones living on the worker threads)
    /// have nothing to shut down.
    fn drop(&mut self) {
        let Some(workers) = self.join_handles.as_mut() else {
            return;
        };
        // Raise the stop flag first, then wake sleeping workers so they observe it.
        self.state.should_stop.store(true, Ordering::Relaxed);
        self.state.on_change.notify();
        workers.drain(..).for_each(|worker| {
            // A worker that panicked is ignored; there is nothing left to clean up.
            let _ = worker.join();
        });
    }
}
/// [`GenericThreadPool`] adapter that schedules every input item as its own work unit.
struct SplitPerItem<'a>(&'a ThreadPool);

// SAFETY(review): the safety contract of `GenericThreadPool` is defined by paralight and
// is not visible in this file. This impl relies on `invoke_sync_unchecked` running the
// closure exactly once for every index in `0..input_len` before it returns — TODO confirm
// that this matches the trait's requirements.
unsafe impl<'a> GenericThreadPool for SplitPerItem<'a> {
    /// Processes every index in `0..input_len` as an independent accumulator and folds
    /// the finalized outputs with `reduce`.
    ///
    /// A `ControlFlow::Break` from `process_item` is treated like `Continue`: every item
    /// is processed, so the source cleanup hook is never needed.
    ///
    /// For an empty input no work unit runs; the result is then `finalize(init())`
    /// instead of a panic.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        _: &(impl SourceCleanup + Sync),
    ) -> Output {
        unsafe {
            let mut output = SmallVec::<[Output; ThreadPool::OUTPUT_BUFFER_CAPACITY]>::new();
            output.reserve_exact(input_len);
            let output_buffer = output.as_mut_ptr();
            // Every work unit writes its own, distinct element of the reserved buffer.
            self.0.state.invoke_sync_unchecked(
                |i| {
                    output_buffer
                        .add(i)
                        .write(finalize(match process_item(init(), i) {
                            ControlFlow::Break(x) | ControlFlow::Continue(x) => x,
                        }));
                },
                input_len,
            );
            // All `input_len` elements were initialized by the work units above.
            output.set_len(input_len);
            output
                .into_iter()
                .reduce(reduce)
                // Empty input: fall back to the output of an untouched accumulator.
                .unwrap_or_else(|| finalize(init()))
        }
    }

    /// Accumulates every index in `0..input_len` on its own (as the one-element range
    /// `i..i + 1`) and hands the per-item results, in index order, to `reduce`.
    fn iter_pipeline<Output, Accum: Send>(
        self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        _: &(impl SourceCleanup + Sync),
    ) -> Output {
        unsafe {
            let mut output = SmallVec::<[Accum; ThreadPool::OUTPUT_BUFFER_CAPACITY]>::new();
            output.reserve_exact(input_len);
            let output_buffer = output.as_mut_ptr();
            // Every work unit writes its own, distinct element of the reserved buffer.
            self.0.state.invoke_sync_unchecked(
                |i| {
                    output_buffer.add(i).write(accum.accumulate(i..i + 1));
                },
                input_len,
            );
            // All `input_len` elements were initialized by the work units above.
            output.set_len(input_len);
            reduce.accumulate_exact(output.into_iter())
        }
    }
}
/// [`GenericThreadPool`] adapter that splits the input into contiguous chunks.
///
/// `chunk_units_calculator` maps the input length to `(chunk_size, work_units)`.
struct SplitPer<'a, F: Fn(usize) -> (usize, usize)> {
    chunk_units_calculator: F,
    pool: &'a ThreadPool,
}

// SAFETY(review): the safety contract of `GenericThreadPool` is defined by paralight and
// is not visible in this file. This impl relies on `invoke_sync_unchecked` running the
// closure exactly once for every work unit before it returns, and on the calculator
// returning chunks that cover `0..input_len` — TODO confirm against the trait's requirements.
unsafe impl<'a, F: Fn(usize) -> (usize, usize)> GenericThreadPool for SplitPer<'a, F> {
    /// Processes the input chunk by chunk, one accumulator per chunk, and folds the
    /// finalized outputs with `reduce`.
    ///
    /// Once any `process_item` call returns `ControlFlow::Break`, the remaining items of
    /// every chunk are skipped and passed to `cleanup` instead.
    ///
    /// When the calculator yields no work unit (for example an empty input with
    /// `split_per`), the result is `finalize(init())` instead of a panic.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        unsafe {
            let (chunk_size, work_units) = (self.chunk_units_calculator)(input_len);
            let mut output = SmallVec::<[Output; ThreadPool::OUTPUT_BUFFER_CAPACITY]>::new();
            output.reserve_exact(work_units);
            let break_early = AtomicBool::new(false);
            let output_buffer = output.as_mut_ptr();
            self.pool.state.invoke_sync_unchecked(
                |i| {
                    // Clamp so that trailing chunks past the input are well-formed empty ranges.
                    let start = (chunk_size * i).min(input_len);
                    let end = (start + chunk_size).min(input_len);
                    let mut accumulator = init();
                    for j in start..end {
                        if break_early.load(Ordering::Relaxed) {
                            // Another chunk asked to stop: release the untouched items.
                            cleanup.cleanup_item_range(j..end);
                            break;
                        }
                        match process_item(accumulator, j) {
                            ControlFlow::Break(x) => {
                                accumulator = x;
                                break_early.store(true, Ordering::Release);
                                // Item `j` was consumed; only the rest needs cleanup.
                                cleanup.cleanup_item_range(j + 1..end);
                                break;
                            }
                            ControlFlow::Continue(x) => accumulator = x,
                        };
                    }
                    output_buffer.add(i).write(finalize(accumulator));
                },
                work_units,
            );
            // All `work_units` elements were initialized by the work units above.
            output.set_len(work_units);
            output
                .into_iter()
                .reduce(reduce)
                // No work unit at all: fall back to the output of an untouched accumulator.
                .unwrap_or_else(|| finalize(init()))
        }
    }

    /// Accumulates each chunk's index range and hands the per-chunk results, in chunk
    /// order, to `reduce`.
    fn iter_pipeline<Output, Accum: Send>(
        self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        _: &(impl SourceCleanup + Sync),
    ) -> Output {
        unsafe {
            let (chunk_size, work_units) = (self.chunk_units_calculator)(input_len);
            let mut output = SmallVec::<[Accum; ThreadPool::OUTPUT_BUFFER_CAPACITY]>::new();
            output.reserve_exact(work_units);
            let output_buffer = output.as_mut_ptr();
            self.pool.state.invoke_sync_unchecked(
                |i| {
                    // Clamp so that trailing chunks past the input are well-formed empty ranges.
                    let start = (chunk_size * i).min(input_len);
                    let end = (start + chunk_size).min(input_len);
                    output_buffer.add(i).write(accum.accumulate(start..end));
                },
                work_units,
            );
            // All `work_units` elements were initialized by the work units above.
            output.set_len(work_units);
            reduce.accumulate_exact(output.into_iter())
        }
    }
}
/// State shared between a pool handle and all of its worker threads.
pub(crate) struct ThreadPoolState {
/// Spin budget passed to `spin_wait` by a thread that recently found work.
idle_spin_cycles: usize,
/// One bit per job slot; the mask worker threads search in `help_global_jobs`.
global_advertise_mask: AtomicBits,
/// Job slots, indexed by bit position in `running_jobs` and the advertise masks.
jobs: Vec<JobSlot>,
/// Number of worker threads the pool was built with.
num_threads: usize,
/// Notified when a job or task is published, when a job's last unit completes,
/// and when the pool is dropped.
on_change: Event,
/// One bit per job slot; set while the slot is reserved (see `reserve_job_slot`).
running_jobs: AtomicBits,
/// Set when the owning `ThreadPool` is dropped to make the workers leave their loop.
should_stop: AtomicBool,
/// Queue of spawned tasks waiting for a worker.
tasks: spin::Mutex<VecDeque<Arc<dyn TaskInner>>>,
}
impl ThreadPoolState {
    /// Creates the shared state for a pool configured by `builder`.
    pub fn new(builder: &ThreadPoolBuilder) -> Self {
        let global_advertise_mask = AtomicBits::new(builder.max_jobs);
        let running_jobs = AtomicBits::new(builder.max_jobs);
        Self {
            global_advertise_mask,
            idle_spin_cycles: builder.idle_spin_cycles,
            // One slot per bit of the bit sets, so every index they yield is valid here.
            jobs: repeat_with(JobSlot::default)
                .take(running_jobs.len())
                .collect(),
            num_threads: builder.num_threads,
            on_change: Event::new(),
            running_jobs,
            should_stop: AtomicBool::new(false),
            tasks: spin::Mutex::new(VecDeque::new()),
        }
    }

    /// Removes `task` from the queue of pending tasks if it is still queued.
    ///
    /// Tasks are compared by the address of their allocation. The removed entry is
    /// replaced by the most recently queued task, so queue order is not preserved.
    pub fn cancel_task<U: ?Sized>(&self, task: &Arc<U>) {
        let mut tasks = self.tasks.lock();
        if let Some(index) = tasks
            .iter()
            .position(|x| Arc::as_ptr(x).cast::<()>() == Arc::as_ptr(task).cast::<()>())
        {
            tasks.swap_remove_back(index);
        }
    }

    /// Like [`Self::invoke`], but accepts a closure that is not `Sync`.
    ///
    /// # Safety
    ///
    /// `f` may be called from several threads at once; the caller must guarantee that
    /// this is sound for everything `f` captures.
    pub unsafe fn invoke_sync_unchecked(&self, f: impl Fn(usize), times: usize) {
        struct AssertSync<F>(F);
        impl<F> AssertSync<F> {
            // Going through a method makes the closure below capture the whole wrapper
            // rather than only its (non-`Sync`) field.
            pub unsafe fn get(&self) -> &F {
                &self.0
            }
        }
        unsafe impl<F> Sync for AssertSync<F> {}
        let f_sync = AssertSync(f);
        self.invoke(move |i| unsafe { (f_sync.get())(i) }, times)
    }

    /// Calls `f(i)` once for every `i` in `0..times` and returns when all calls finished.
    ///
    /// Nothing happens for zero units, a single unit runs inline on the calling thread,
    /// and anything larger is published as a parallel job.
    pub fn invoke(&self, f: impl Fn(usize) + Sync, times: usize) {
        match times {
            0 => {}
            1 => f(0),
            _ => self.invoke_parallel_job(f, times),
        }
    }

    /// Queues `task` and notifies waiting threads.
    pub fn push_task(&self, task: Arc<dyn TaskInner>) {
        self.tasks.lock().push_back(task);
        self.on_change.notify();
    }

    /// Worker loop: helps with advertised jobs first, then checks the stop flag, then
    /// runs queued tasks, and otherwise waits for a change notification.
    fn join(&self) {
        let mut listener = self.on_change.listen();
        let mut spin_before_sleep = false;
        loop {
            if self.help_global_jobs() {
                spin_before_sleep = true;
            } else if self.should_stop.load(Ordering::Relaxed) {
                return;
            } else if let Some(task) = self.pop_task() {
                spin_before_sleep |= task.run();
            } else {
                // Spin only after recent activity; otherwise wait without a spin budget.
                let spin_cycles = if spin_before_sleep {
                    self.idle_spin_cycles
                } else {
                    0
                };
                spin_before_sleep = !listener.spin_wait(spin_cycles);
            }
        }
    }

    /// Takes the oldest queued task, if any.
    fn pop_task(&self) -> Option<Arc<dyn TaskInner>> {
        self.tasks.lock().pop_front()
    }

    /// Helps with globally advertised jobs until none hands out a unit any more.
    /// Returns whether at least one unit was run.
    fn help_global_jobs(&self) -> bool {
        let mut ran_item = false;
        while self.help_one_job(&self.global_advertise_mask, true) {
            ran_item = true;
        }
        ran_item
    }

    /// Tries to reserve a unit from one of the jobs advertised in `search_mask` and runs
    /// units of that job until it has none left.
    ///
    /// With `change_advertise_mask` set, the job's own search mask is installed as the
    /// thread-local advertise mask while running; otherwise `search_mask` is installed.
    /// Returns `false` when no advertised job had a unit available.
    fn help_one_job(&self, search_mask: &AtomicBits, change_advertise_mask: bool) -> bool {
        for index in search_mask.iter_ones() {
            let slot = &self.jobs[index];
            let reserved_unit = slot.available_units.fetch_sub(1, Ordering::Relaxed) - 1;
            if reserved_unit < 0 {
                // All units of this job are already taken.
                continue;
            } else {
                // Pairs with the release store of `available_units` made after the
                // descriptor pointer was written.
                fence(Ordering::Acquire);
                let descriptor = unsafe { &*slot.descriptor.get().read().cast::<JobDescriptor>() };
                let new_advertise_mask = if change_advertise_mask {
                    descriptor.search_mask
                } else {
                    search_mask
                };
                let locally_completed = install_local_advertise_mask(new_advertise_mask, || {
                    (descriptor.func)(JobInvocation {
                        available_units: &slot.available_units,
                        clear_masks: descriptor.clear_masks,
                        reserved_unit,
                        slot: index,
                    })
                });
                let remaining = descriptor
                    .incomplete_units
                    .fetch_sub(locally_completed, Ordering::Release)
                    - locally_completed;
                // The descriptor lives on the job owner's stack, and the owner may return
                // as soon as the count reaches zero, so it is not touched past this point.
                if remaining == 0 {
                    self.on_change.notify();
                }
                return true;
            }
        }
        false
    }

    /// Runs `f` for `0..times` as a parallel job, or sequentially on the calling thread
    /// when no job slot is free.
    fn invoke_parallel_job(&self, f: impl Fn(usize) + Sync, times: usize) {
        if let Some(index) = self.reserve_job_slot() {
            unsafe {
                self.invoke_parallel_job_at_slot(f, times, index);
            }
            unsafe {
                self.release_job_slot(index);
            }
        } else {
            for i in 0..times {
                f(i);
            }
        }
    }

    /// Publishes the job in slot `index`, takes part in it, and returns once all units
    /// have completed.
    ///
    /// # Safety
    ///
    /// Presumably `index` must be a slot reserved through `reserve_job_slot` that is not
    /// released before this call returns; this is how the only caller in this file uses it.
    unsafe fn invoke_parallel_job_at_slot(
        &self,
        f: impl Fn(usize) + Sync,
        times: usize,
        index: usize,
    ) {
        with_local_advertise_mask(self.global_advertise_mask.len(), |search_mask| {
            let func = Self::create_job_func(f);
            let slot = &self.jobs[index];
            let times = times as i64;
            let advertise_masks = [&self.global_advertise_mask, search_mask];
            let descriptor = JobDescriptor {
                clear_masks: &advertise_masks,
                func: &func,
                incomplete_units: AtomicI64::new(times),
                search_mask,
            };
            // Publish the descriptor before any unit becomes available.
            unsafe { *slot.descriptor.get() = &descriptor as *const _ as *const _ };
            for mask in &advertise_masks {
                mask.set(index, true, Ordering::Relaxed);
            }
            // The owner keeps unit `times - 1` for itself; the others become available.
            slot.available_units.store(times - 1, Ordering::Release);
            self.on_change.notify();
            let locally_completed = func(JobInvocation {
                available_units: &slot.available_units,
                clear_masks: &advertise_masks,
                reserved_unit: times - 1,
                slot: index,
            });
            if locally_completed < times {
                // Other threads took units; account for ours, then help and wait.
                let mut listener = self.on_change.listen();
                let mut spin_before_sleep = true;
                let remaining = descriptor
                    .incomplete_units
                    .fetch_sub(locally_completed, Ordering::Relaxed)
                    - locally_completed;
                if 0 < remaining {
                    while 0 < descriptor.incomplete_units.load(Ordering::Relaxed) {
                        if self.help_one_job(search_mask, false) {
                            spin_before_sleep = true;
                        } else {
                            let spin_cycles = if spin_before_sleep {
                                self.idle_spin_cycles
                            } else {
                                0
                            };
                            spin_before_sleep = !listener.spin_wait(spin_cycles);
                        }
                    }
                }
                // Pairs with the release decrements made by the helping threads.
                fence(Ordering::Acquire);
            }
        });
    }

    /// Marks slot `index` as free again.
    ///
    /// # Safety
    ///
    /// Presumably the slot must have been reserved by the caller and its job must have
    /// finished — TODO confirm the exact contract.
    unsafe fn release_job_slot(&self, index: usize) {
        self.running_jobs.set(index, false, Ordering::Release);
    }

    /// Reserves a free job slot and returns its index, or `None` when all are in use.
    fn reserve_job_slot(&self) -> Option<usize> {
        for index in self.running_jobs.iter_zeroes() {
            // `set` returns the previous bit: only the thread that flipped it owns the slot.
            if !self.running_jobs.set(index, true, Ordering::Relaxed) {
                fence(Ordering::Acquire);
                return Some(index);
            }
        }
        None
    }

    /// Wraps `f` into the function run by every thread taking part in a job.
    ///
    /// The returned function runs the unit it was handed, keeps reserving further units
    /// from `available_units` until none are left, and returns how many units it ran.
    fn create_job_func(f: impl Fn(usize) + Sync) -> impl Fn(JobInvocation) -> i64 {
        move |invocation| {
            // Explicitly `i64`: a trailing `as i64` cast does not constrain the literal,
            // which would make the counter an `i32` that overflows on very large jobs.
            let mut locally_completed: i64 = 0;
            let mut next_unit = invocation.reserved_unit;
            loop {
                if next_unit < 0 {
                    break;
                } else if next_unit == 0 {
                    // Last unit handed out: stop advertising the job.
                    for mask in invocation.clear_masks {
                        mask.set(invocation.slot, false, Ordering::Relaxed);
                    }
                }
                f(next_unit as usize);
                locally_completed += 1;
                next_unit = invocation.available_units.fetch_sub(1, Ordering::Relaxed) - 1;
            }
            locally_completed
        }
    }
}
// SAFETY(review): these impls are needed at least because `JobSlot` stores a raw pointer
// inside an `UnsafeCell`. Presumably soundness rests on the slot protocol: the descriptor
// pointer is written before units are published and only read after a unit was reserved —
// TODO confirm, including for the other fields.
unsafe impl Send for ThreadPoolState {}
unsafe impl Sync for ThreadPoolState {}
/// Shared slot through which one running job hands out its work units.
#[derive(Debug, Default)]
struct JobSlot {
/// Number of units still up for grabs. Threads reserve a unit by decrementing this
/// counter; a negative result means no unit was left.
pub available_units: AtomicI64,
/// Type-erased pointer to the `JobDescriptor` of the job currently using the slot.
pub descriptor: UnsafeCell<*const ()>,
}
/// Description of a running job, created on the stack of the thread that invoked it.
struct JobDescriptor<'a> {
/// Masks in which the job is advertised; its bit is cleared in each of them when the
/// last unit is handed out.
pub clear_masks: &'a [&'a AtomicBits],
/// Function every participating thread runs; returns the number of units it completed.
pub func: &'a dyn Fn(JobInvocation) -> i64,
/// Number of units that have not been reported as completed yet.
pub incomplete_units: AtomicI64,
/// Mask that helping threads install as their local advertise mask while running
/// units of this job.
pub search_mask: &'a AtomicBits,
}
/// Arguments handed to a job function when a thread starts taking part in a job.
#[derive(Debug)]
struct JobInvocation<'a> {
/// Counter of the job's slot from which further units are reserved.
pub available_units: &'a AtomicI64,
/// Masks whose bit for `slot` is cleared when unit 0 is reached.
pub clear_masks: &'a [&'a AtomicBits],
/// Unit already reserved for the calling thread; run first.
pub reserved_unit: i64,
/// Index of the job's slot.
pub slot: usize,
}
/// Runs `f` while `mask` is installed as the calling thread's advertise mask and
/// reinstates the previously installed mask afterwards.
///
/// The call runs inside `abort_on_panic`.
fn install_local_advertise_mask<R>(mask: &AtomicBits, f: impl FnOnce() -> R) -> R {
    abort_on_panic(|| {
        // Swap the new mask in and remember what was installed before.
        let saved = LOCAL_ADVERTISE_MASK.replace(mask);
        let outcome = f();
        LOCAL_ADVERTISE_MASK.set(saved);
        outcome
    })
}
/// Calls `f` with the advertise mask of the calling thread, which stays installed as the
/// thread-local mask for the duration of the call.
///
/// When a mask is already installed it is reused; otherwise the thread's own storage is
/// used and grown to `capacity` bits if necessary. The call runs inside `abort_on_panic`,
/// as does the assertion that the chosen mask holds at least `capacity` bits.
fn with_local_advertise_mask<R>(capacity: usize, f: impl FnOnce(&AtomicBits) -> R) -> R {
abort_on_panic(|| {
let previous_local_advertise_mask = LOCAL_ADVERTISE_MASK.get();
let mask = if previous_local_advertise_mask.is_null() {
unsafe {
// No mask installed: hand out this thread's own storage, replacing it first when
// it is too small. The reference outlives the `with` call via the raw pointer.
&*ADVERTISE_MASK_STORE.with(|x| {
let array = &mut *x.get();
if array.len() < capacity {
*array = AtomicBits::new(capacity);
}
x.get()
})
}
} else {
// A mask is already installed on this thread: reuse it.
unsafe { &*previous_local_advertise_mask }
};
// An already installed mask is never resized, so it must be large enough as it is.
assert!(
capacity <= mask.len(),
"attempted to recursively access local advertise mask with different capacity"
);
install_local_advertise_mask(mask, || f(mask))
})
}
/// Fixed-capacity bit set whose bits can be read and updated atomically.
///
/// The bits live in reference-counted 64-bit words, so clones share the same storage.
#[derive(Clone, Debug, Default)]
struct AtomicBits(Arc<[AtomicU64]>);

impl AtomicBits {
    /// Number of bits stored in every backing word.
    const ELEMENT_BITS: usize = u64::BITS as usize;

    /// Creates a bit set with all bits cleared and room for at least `capacity` bits.
    ///
    /// The capacity is rounded up to a whole number of 64-bit words.
    pub fn new(capacity: usize) -> Self {
        let word_count = capacity.div_ceil(Self::ELEMENT_BITS);
        Self(repeat_with(AtomicU64::default).take(word_count).collect())
    }

    /// Iterates over the indices of the set bits in ascending order.
    ///
    /// Every word is read once with relaxed ordering when the iterator reaches it, so
    /// concurrent updates may or may not be observed.
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        self.scan(0)
    }

    /// Returns the number of bits in the set (always a multiple of 64).
    pub fn len(&self) -> usize {
        self.0.len() * Self::ELEMENT_BITS
    }

    /// Iterates over the indices of the cleared bits in ascending order.
    ///
    /// Every word is read once with relaxed ordering when the iterator reaches it, so
    /// concurrent updates may or may not be observed.
    pub fn iter_zeroes(&self) -> impl Iterator<Item = usize> + '_ {
        self.scan(u64::MAX)
    }

    /// Atomically sets bit `index` to `value` using the memory ordering `order` and
    /// returns the previous state of the bit.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than [`Self::len`].
    pub fn set(&self, index: usize, value: bool, order: Ordering) -> bool {
        let (word_index, bit) = Self::element_bit(index);
        let mask = 1u64 << bit;
        let word = &self.0[word_index];
        let previous = if value {
            word.fetch_or(mask, order)
        } else {
            word.fetch_and(!mask, order)
        };
        previous & mask != 0
    }

    /// Splits a bit index into the index of its word and the bit position inside it.
    fn element_bit(index: usize) -> (usize, usize) {
        (index / Self::ELEMENT_BITS, index % Self::ELEMENT_BITS)
    }

    /// Creates an iterator over the bits that are set after XOR-ing every word with
    /// `invert_mask` (`0` selects set bits, `u64::MAX` selects cleared bits).
    fn scan(&self, invert_mask: u64) -> AtomicBitsIter<'_> {
        AtomicBitsIter {
            base_index: 0,
            bits: self,
            current_value: 0,
            invert_mask,
            next_element: 0,
        }
    }
}

/// Iterator over the selected bit indices of an [`AtomicBits`].
struct AtomicBitsIter<'a> {
    /// Index of the first bit of the word held in `current_value`.
    base_index: usize,
    /// Bit set being scanned.
    bits: &'a AtomicBits,
    /// Bits of the current word that still have to be yielded.
    current_value: u64,
    /// XOR-ed onto every loaded word; selects set (`0`) or cleared (`u64::MAX`) bits.
    invert_mask: u64,
    /// Index of the next word to load.
    next_element: usize,
}

impl Iterator for AtomicBitsIter<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Load words until one has a selected bit left, or the set is exhausted.
        while self.current_value == 0 {
            let word = self.bits.0.get(self.next_element)?;
            self.current_value = word.load(Ordering::Relaxed) ^ self.invert_mask;
            self.base_index = AtomicBits::ELEMENT_BITS * self.next_element;
            self.next_element += 1;
        }
        let bit = self.current_value.trailing_zeros();
        // Clear the lowest selected bit; `current_value` is non-zero here.
        self.current_value &= self.current_value - 1;
        Some(self.base_index + bit as usize)
    }
}