use alloc::{collections::VecDeque, sync::Arc};
use core::{
cell::{OnceCell, UnsafeCell},
cmp,
future::Future,
num::NonZero,
pin::Pin,
ptr,
ptr::NonNull,
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
task::{Context, Poll},
time::Duration,
};
use std::{thread, thread_local};
use async_task::{Runnable, Task};
use crossbeam_queue::SegQueue;
use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use crate::{
job::{HeapJob, JobRef, StackJob},
latch::{AtomicLatch, Latch, LockLatch, Probe, SetOnWake, WakeLatch},
scope::*,
util::{CallOnDrop, Slot, XorShift64Star},
};
/// Upper bound on the number of worker threads: `ThreadPool::resize` clamps
/// every requested size to this, and `ThreadPool.threads` is a fixed array of
/// this length.
pub const MAX_THREADS: usize = 32;
/// A fixed-capacity work-stealing thread pool. Constructed `const` (see
/// [`ThreadPool::new`]) so it can live in a `static`; the public methods all
/// take `&'static self`.
pub struct ThreadPool {
    // Per-worker slots; only the first `state.running_threads` entries are live.
    threads: [CachePadded<ThreadInfo>; MAX_THREADS],
    // Global injector queue for jobs submitted from outside the pool's workers.
    queue: SegQueue<JobRef>,
    // Pool-wide coordination state (resizing, heartbeat thread, activity flag).
    state: CachePadded<ThreadPoolState>,
    // Count of logically "active" jobs; see `mark_active` / `mark_inactive`.
    active_tally: CachePadded<AtomicUsize>,
}
/// Coordination state shared by all workers and the heartbeat thread.
struct ThreadPoolState {
    // Number of currently running workers (the live prefix of `threads`).
    running_threads: AtomicUsize,
    // Set while a resize is in progress.
    // NOTE(review): written but never read in this file — presumably a
    // guard/debug flag; confirm whether other modules observe it.
    is_resizing: Mutex<bool>,
    // Lifecycle control for the dedicated heartbeat thread.
    heartbeat_control: ThreadControl,
    // True while `active_tally` is non-zero; paired with `activity_changed`
    // for `wait_until_inactive`.
    is_active: Mutex<bool>,
    // Signalled whenever `is_active` flips.
    activity_changed: Condvar,
}
/// Per-worker shared state (cache-padded at the array level to limit false
/// sharing between workers).
struct ThreadInfo {
    // Raised by the heartbeat thread; the worker consumes it in `tick` to
    // decide when to promote a local job into `shared_job`.
    heartbeat: AtomicBool,
    // Single-job mailbox that other workers can steal from (`claim_shared`).
    shared_job: Slot<JobRef>,
    // Start/terminate and sleep/wake signalling for this worker.
    control: ThreadControl,
}
/// Start/stop and sleep/wake signalling for a single managed thread.
struct ThreadControl {
    // True while the thread is parked (see `WorkerThread::run_until_cold`);
    // guarded by the mutex so wake-ups cannot be lost.
    is_sleeping: Mutex<bool>,
    // Signalled by `wake` to unpark a sleeping thread.
    awakened: Condvar,
    // True between `post_ready_status` and `post_termination_status`.
    is_running: Mutex<bool>,
    // Signalled on every `is_running` transition.
    synchronized: Condvar,
    // Latch raised by `halt` to ask the thread to exit its main loop.
    should_terminate: AtomicLatch,
}
// Interior-mutable `const` used purely as an initializer template: every use
// site gets its own fresh copy (hence the clippy allow — the lint warns about
// accidentally sharing one instance, which is exactly what we rely on NOT
// happening here).
#[allow(clippy::declare_interior_mutable_const)]
const THREAD_CONTROL: ThreadControl = ThreadControl {
    is_sleeping: Mutex::new(false),
    awakened: Condvar::new(),
    is_running: Mutex::new(false),
    synchronized: Condvar::new(),
    should_terminate: AtomicLatch::new(),
};
// Initializer template for the `ThreadPool::new` worker array; like
// `THREAD_CONTROL`, each array element is a distinct copy.
#[allow(clippy::declare_interior_mutable_const)]
const THREAD_INFO: CachePadded<ThreadInfo> = CachePadded::new(ThreadInfo {
    heartbeat: AtomicBool::new(false),
    shared_job: Slot::empty(),
    control: THREAD_CONTROL,
});
#[allow(clippy::new_without_default)]
impl ThreadPool {
pub const fn new() -> ThreadPool {
ThreadPool {
threads: [THREAD_INFO; MAX_THREADS],
queue: SegQueue::new(),
state: CachePadded::new(ThreadPoolState {
running_threads: AtomicUsize::new(0),
heartbeat_control: THREAD_CONTROL,
is_resizing: Mutex::new(false),
is_active: Mutex::new(false),
activity_changed: Condvar::new(),
}),
active_tally: CachePadded::new(AtomicUsize::new(0)),
}
}
pub fn resize_to_available(&'static self) -> usize {
let available = thread::available_parallelism()
.map(NonZero::get)
.unwrap_or(1);
self.resize_to(available)
}
pub fn resize_to(&'static self, new_size: usize) -> usize {
self.resize(|_| new_size)
}
pub fn grow(&'static self, added_threads: usize) -> usize {
self.resize(|current_size| current_size + added_threads)
}
pub fn shrink(&'static self, terminated_threads: usize) -> usize {
self.resize(|current_size| current_size - terminated_threads)
}
pub fn populate(&'static self) -> usize {
self.resize(
|current_size| {
if current_size == 0 {
1
} else {
current_size
}
},
)
}
pub fn depopulate(&'static self) -> usize {
self.resize_to(0)
}
pub fn resize<F>(&'static self, get_size: F) -> usize
where
F: Fn(usize) -> usize,
{
if WorkerThread::with(|worker_thread| worker_thread.is_some()) {
return self.state.running_threads.load(Ordering::Acquire);
}
let mut is_resizing = self.state.is_resizing.lock();
*is_resizing = true;
let current_size = self.state.running_threads.load(Ordering::Acquire);
let new_size = usize::min(get_size(current_size), MAX_THREADS);
if new_size == current_size {
*is_resizing = false;
return current_size;
}
self.state
.running_threads
.store(new_size, Ordering::Release);
match new_size.cmp(¤t_size) {
cmp::Ordering::Equal => {}
cmp::Ordering::Greater => {
for i in current_size..new_size {
self.threads[i].control.run(move || {
unsafe { main_loop(self, i) }
});
}
for i in new_size..current_size {
self.threads[i].control.await_ready();
}
if current_size == 0 {
self.state
.heartbeat_control
.run(move || heartbeat_loop(self));
}
}
cmp::Ordering::Less => {
for i in new_size..current_size {
self.threads[i].control.halt();
}
for i in new_size..current_size {
self.threads[i].control.await_termination();
}
if new_size == 0 {
self.state.heartbeat_control.halt();
self.state.heartbeat_control.await_termination();
}
}
}
*is_resizing = false;
new_size
}
pub fn id(&'static self) -> usize {
ptr::from_ref(self) as usize
}
pub fn inject_or_push(&'static self, job_ref: JobRef) {
WorkerThread::with(|worker_thread| match worker_thread {
Some(worker_thread) if worker_thread.thread_pool().id() == self.id() => {
worker_thread.push(job_ref);
}
_ => self.inject(job_ref),
});
}
pub fn inject(&'static self, job_ref: JobRef) {
self.queue.push(job_ref);
if self.active_tally.load(Ordering::Relaxed) == 0 {
self.wake_any(1);
}
}
pub fn pop(&'static self) -> Option<JobRef> {
self.queue.pop()
}
pub fn in_worker<F, T>(&'static self, f: F) -> T
where
F: FnOnce(&WorkerThread, bool) -> T + Send,
T: Send,
{
WorkerThread::with(|worker_thread| match worker_thread {
None => self.in_worker_cold(f),
Some(worker_thread) => {
if worker_thread.thread_pool.id() != self.id() {
self.in_worker_cross(worker_thread, f)
} else {
f(worker_thread, false)
}
}
})
}
#[cold]
fn in_worker_cold<F, T>(&'static self, f: F) -> T
where
F: FnOnce(&WorkerThread, bool) -> T + Send,
T: Send,
{
thread_local!(static LOCK_LATCH: LockLatch = const { LockLatch::new() });
let _ = self.populate();
LOCK_LATCH.with(|latch| {
let mut result = None;
let job = StackJob::new(|| {
WorkerThread::with(|worker_thread| {
let worker_thread = worker_thread.unwrap();
result = Some(f(worker_thread, true));
unsafe { Latch::set(latch) };
});
});
let job_ref = unsafe { job.as_job_ref() };
self.inject(job_ref);
latch.wait_and_reset();
result.unwrap()
})
}
fn in_worker_cross<F, T>(&'static self, current_thread: &WorkerThread, f: F) -> T
where
F: FnOnce(&WorkerThread, bool) -> T + Send,
T: Send,
{
let _ = self.populate();
let latch = WakeLatch::new(current_thread);
let mut result = None;
let job = StackJob::new(|| {
WorkerThread::with(|worker_thread| {
let worker_thread = worker_thread.unwrap();
result = Some(f(worker_thread, true));
unsafe { Latch::set(&latch) };
});
});
let job_ref = unsafe { job.as_job_ref() };
self.inject(job_ref);
current_thread.run_until(&latch);
result.unwrap()
}
pub fn wake_any(&'static self, num_to_wake: usize) -> usize {
if num_to_wake > 0 {
let mut num_woken = 0;
let num_threads = self.state.running_threads.load(Ordering::Relaxed);
for index in 0..num_threads {
if self.wake_thread(index) {
num_woken += 1;
if num_to_wake == num_woken {
return num_woken;
}
}
}
num_woken
} else {
0
}
}
pub fn wake_thread(&'static self, index: usize) -> bool {
self.threads[index].control.wake()
}
pub fn mark_active(&'static self) {
if self.active_tally.fetch_add(1, Ordering::AcqRel) == 0 {
let mut is_active = self.state.is_active.lock();
*is_active = true;
self.state.activity_changed.notify_all();
};
}
pub fn mark_inactive(&'static self) {
if self.active_tally.fetch_sub(1, Ordering::AcqRel) == 1 {
let mut is_active = self.state.is_active.lock();
*is_active = false;
self.state.activity_changed.notify_all();
}
}
pub fn wait_until_inactive(&'static self) {
let mut is_active = self.state.is_active.lock();
while *is_active {
self.state.activity_changed.wait(&mut is_active);
}
}
}
impl ThreadControl {
    /// Spawns a detached thread running `f`.
    /// NOTE(review): the `JoinHandle` is dropped; shutdown is coordinated via
    /// `await_termination` rather than `join`.
    fn run<F>(&'static self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        thread::spawn(f);
    }

    /// Blocks until the controlled thread has called `post_ready_status`.
    fn await_ready(&'static self) {
        let mut is_running = self.is_running.lock();
        while !*is_running {
            self.synchronized.wait(&mut is_running);
        }
    }

    /// Called by the controlled thread once its setup is complete.
    fn post_ready_status(&'static self) {
        let mut is_running = self.is_running.lock();
        *is_running = true;
        self.synchronized.notify_all();
    }

    /// Wakes the controlled thread if it is sleeping; returns `true` when a
    /// sleeper was woken. The flag is cleared under the mutex so the sleeper's
    /// `while *is_sleeping` loop observes the wake exactly once.
    fn wake(&'static self) -> bool {
        let mut is_sleeping = self.is_sleeping.lock();
        if *is_sleeping {
            *is_sleeping = false;
            self.awakened.notify_one();
            true
        } else {
            false
        }
    }

    /// Raises the termination latch and wakes the thread so it can observe it.
    fn halt(&'static self) {
        unsafe { Latch::set(&self.should_terminate) }
        self.wake();
    }

    /// Blocks until the controlled thread has called `post_termination_status`.
    fn await_termination(&'static self) {
        let mut is_running = self.is_running.lock();
        while *is_running {
            self.synchronized.wait(&mut is_running);
        }
    }

    /// Called by the controlled thread on exit. Resets the latch first so this
    /// control slot can be reused by a later `resize`.
    fn post_termination_status(&'static self) {
        self.should_terminate.reset();
        let mut is_running = self.is_running.lock();
        *is_running = false;
        self.synchronized.notify_all();
    }
}
impl ThreadPool {
    /// Runs `f` on the pool as a fire-and-forget job. The pool is marked
    /// active until the job completes.
    pub fn spawn<F>(&'static self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.mark_active();
        let job = HeapJob::new(|| {
            f();
            self.mark_inactive();
        });
        // The closure is 'static, matching `into_static_job_ref`'s contract.
        let job_ref = unsafe { job.into_static_job_ref() };
        self.inject_or_push(job_ref);
    }

    /// Spawns `future` on the pool, returning a `Task` handle to its output.
    /// The pool is marked active until the future finishes (or is dropped).
    pub fn spawn_future<F, T>(&'static self, future: F) -> Task<T>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.mark_active();
        // Balance `mark_active` whenever the future is destroyed, completed or not.
        let future = async move {
            let _guard = CallOnDrop(|| self.mark_inactive());
            future.await
        };
        // Every wake re-schedules the runnable as a plain job on this pool.
        let schedule = move |runnable: Runnable| {
            let job_ref = unsafe {
                // Smuggle the Runnable through a raw pointer and poll it when
                // the job executes.
                // NOTE(review): the metadata type is erased to `()` on the way
                // back; presumably sound because only `run()` is called —
                // confirm against async_task's raw-pointer contract.
                JobRef::new_raw(runnable.into_raw().as_ptr(), |this| {
                    let this = NonNull::new_unchecked(this.cast_mut());
                    let runnable = Runnable::<()>::from_raw(this);
                    runnable.run();
                })
            };
            self.inject_or_push(job_ref);
        };
        let (runnable, task) = async_task::spawn(future, schedule);
        // Schedule once so the future gets its first poll.
        runnable.schedule();
        task
    }

    /// Convenience wrapper: calls `f` lazily inside the spawned future.
    pub fn spawn_async<Fn, Fut, T>(&'static self, f: Fn) -> Task<T>
    where
        Fn: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let future = async move { f().await };
        self.spawn_future(future)
    }

    /// Drives `future` to completion on a worker of this pool, running other
    /// pool work between polls instead of blocking the thread.
    pub fn block_on<F, T>(&'static self, mut future: F) -> T
    where
        F: Future<Output = T> + Send,
        T: Send,
    {
        // The future never moves after this point; it lives on this stack frame.
        let mut future = unsafe { Pin::new_unchecked(&mut future) };
        self.in_worker(|worker_thread, _| {
            // Waker that sets a latch, releasing `run_until` below.
            let wake = SetOnWake::new(WakeLatch::new(worker_thread));
            let ctx_waker = Arc::clone(&wake).into();
            let mut ctx = Context::from_waker(&ctx_waker);
            loop {
                match future.as_mut().poll(&mut ctx) {
                    Poll::Ready(res) => return res,
                    Poll::Pending => {
                        // Do useful work until the waker fires, then poll again.
                        worker_thread.run_until(wake.latch());
                        wake.latch().reset();
                    }
                }
            }
        })
    }

    /// Runs `a` and `b`, potentially in parallel, and returns both results.
    /// `b` is pushed onto the local queue (stealable via promotion) while the
    /// calling worker runs `a`; if nobody took it, the worker runs `b` inline.
    pub fn join<A, B, RA, RB>(&'static self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.in_worker(|worker_thread, _| {
            let mut status_b = None;
            // Set when `b` completes, whether here or on another worker.
            let latch_b = WakeLatch::new(worker_thread);
            let job_b = StackJob::new(|| {
                status_b = Some(b());
                unsafe { Latch::set(&latch_b) };
            });
            // The stack job outlives its use: the loop below does not return
            // until `latch_b` is set.
            let job_b_ref = unsafe { job_b.as_job_ref() };
            let job_b_ref_id = job_b_ref.id();
            worker_thread.push(job_b_ref);
            let status_a = a();
            while !latch_b.probe() {
                if let Some(job) = worker_thread.pop() {
                    if job.id() == job_b_ref_id {
                        // `b` was never taken: run it inline on this worker.
                        worker_thread.tick();
                        job_b.run_inline();
                        break;
                    }
                    // Some other local job surfaced first; run it.
                    worker_thread.execute(job);
                } else {
                    // Local queue empty: `b` is elsewhere; help with other
                    // work until it finishes.
                    worker_thread.run_until(&latch_b);
                }
            }
            (status_a, status_b.unwrap())
        })
    }

    /// Creates a scope in which borrowed-data tasks can be spawned; blocks
    /// until all of them complete before returning `f`'s result.
    pub fn scope<'scope, F, T>(&'static self, f: F) -> T
    where
        F: FnOnce(&Scope<'scope>) -> T + Send,
        T: Send,
    {
        self.in_worker(|owner_thread, _| {
            // `complete` waits for all scoped work before the borrows expire.
            unsafe {
                let scope = Scope::<'scope>::new(owner_thread);
                let outcome = f(&scope);
                scope.complete(owner_thread);
                outcome
            }
        })
    }
}
/// Thread-local state for a single pool worker.
pub struct WorkerThread {
    // Local LIFO job queue; only ever touched from this worker's own thread
    // (hence the `UnsafeCell` — see `get_queue`).
    queue: UnsafeCell<VecDeque<JobRef>>,
    // The pool this worker belongs to.
    thread_pool: &'static ThreadPool,
    // This worker's index into `thread_pool.threads`.
    index: usize,
    // PRNG used to randomize the starting point of steal scans.
    rng: XorShift64Star,
}
thread_local! {
    // Initialized once by `main_loop` on pool worker threads; stays empty on
    // every other thread, which is how `WorkerThread::with` distinguishes
    // workers from outsiders.
    static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const { CachePadded::new(OnceCell::new()) };
}
impl WorkerThread {
    /// This worker's shared `ThreadInfo` slot in the pool.
    #[inline]
    fn thread_info(&self) -> &ThreadInfo {
        &self.thread_pool.threads[self.index]
    }

    /// Returns a mutable reference to the local job queue.
    ///
    /// # Safety
    /// Only one reference may be live at a time; callers must not hold the
    /// returned reference across anything that can call `get_queue` again on
    /// this worker.
    #[inline]
    #[allow(clippy::mut_from_ref)]
    unsafe fn get_queue(&self) -> &mut VecDeque<JobRef> {
        unsafe { &mut *self.queue.get() }
    }

    /// Calls `f` with the current thread's worker state, or `None` when the
    /// current thread is not a pool worker.
    pub fn with<R>(f: impl FnOnce(Option<&Self>) -> R) -> R {
        WORKER_THREAD_STATE.with(|worker_thread| f(worker_thread.get()))
    }

    /// The pool this worker belongs to.
    #[inline]
    pub fn thread_pool(&self) -> &'static ThreadPool {
        self.thread_pool
    }

    /// This worker's index within its pool.
    #[inline]
    pub fn index(&self) -> usize {
        self.index
    }

    /// Pushes a job onto the front of the local queue (LIFO order).
    #[inline]
    pub fn push(&self, job: JobRef) {
        let local_queue = unsafe { self.get_queue() };
        local_queue.push_front(job);
    }

    /// Pops the most recently pushed local job, if any.
    #[inline]
    pub fn pop(&self) -> Option<JobRef> {
        let local_queue = unsafe { self.get_queue() };
        local_queue.pop_front()
    }

    /// Claims a shared job: this worker's own slot first, then the other
    /// running workers' slots, scanned from a random starting index so steal
    /// pressure spreads across the pool.
    #[inline]
    pub fn claim_shared(&self) -> Option<JobRef> {
        if let Some(job) = self.thread_info().shared_job.take() {
            return Some(job);
        }
        let threads = self.thread_pool.threads.as_slice();
        let num_threads = self
            .thread_pool
            .state
            .running_threads
            .load(Ordering::Relaxed);
        let start = self.rng.next_usize(num_threads);
        (start..num_threads)
            .chain(0..start)
            .filter(move |&i| i != self.index())
            .find_map(|i| threads[i].shared_job.take())
    }

    /// Promotes the *oldest* local job into this worker's shared slot so other
    /// workers can take it, waking one sleeper. If the slot is already
    /// occupied the job is put back where it was.
    #[cold]
    fn promote(&self) {
        let local_queue = unsafe { self.get_queue() };
        if let Some(job) = local_queue.pop_back() {
            if let Some(job) = self.thread_info().shared_job.put(job) {
                // Slot was full: `put` returned our job; restore it.
                local_queue.push_back(job);
            } else {
                self.thread_pool.wake_any(1);
            }
        }
    }

    /// Consumes a pending heartbeat, if any, and promotes a job in response.
    /// The compare-exchange clears the flag so each heartbeat causes at most
    /// one promotion.
    #[inline]
    pub fn tick(&self) {
        if self
            .thread_info()
            .heartbeat
            .compare_exchange(true, false, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            self.promote();
        }
    }

    /// Ticks, then runs `job`.
    #[inline]
    pub fn execute(&self, job: JobRef) {
        self.tick();
        job.execute();
    }

    /// Runs jobs until `latch` is set; returns immediately if it already is.
    #[inline]
    pub fn run_until<L: Probe>(&self, latch: &L) {
        if !latch.probe() {
            self.run_until_cold(latch);
        }
    }

    /// Slow path of `run_until`: execute whatever work can be found, and park
    /// on the worker's condvar when there is none. The latch is re-checked
    /// *after* taking the sleep mutex so a `wake` racing with `Latch::set`
    /// cannot be lost.
    #[cold]
    fn run_until_cold<L: Probe>(&self, latch: &L) {
        while !latch.probe() {
            if let Some(job) = self.find_work() {
                self.execute(job);
                continue;
            }
            let control = &self.thread_info().control;
            let mut is_sleeping = control.is_sleeping.lock();
            if latch.probe() {
                return;
            }
            *is_sleeping = true;
            while *is_sleeping {
                control.awakened.wait(&mut is_sleeping);
            }
        }
    }

    /// Finds the next job to run: local queue, then shared slots, then the
    /// pool's global queue.
    #[inline]
    pub fn find_work(&self) -> Option<JobRef> {
        self.pop()
            .or_else(|| self.claim_shared())
            .or_else(|| self.thread_pool().pop())
    }
}
/// Body of a pool worker thread: installs the thread-local `WorkerThread`
/// state, reports ready, runs jobs until halted, then returns leftover work to
/// the global queue and reports termination.
///
/// # Safety
/// NOTE(review): no contract is documented here; presumably `index` must be a
/// unique, in-bounds worker slot with no other thread running this loop for
/// it — confirm against `ThreadPool::resize`, the only caller in this file.
unsafe fn main_loop(thread_pool: &'static ThreadPool, index: usize) {
    let control = &thread_pool.threads[index].control;
    WORKER_THREAD_STATE.with(|worker_thread| {
        let worker_thread = worker_thread.get_or_init(|| WorkerThread {
            index,
            thread_pool,
            queue: UnsafeCell::new(VecDeque::with_capacity(32)),
            rng: XorShift64Star::new(index as u64 + 1),
        });
        control.post_ready_status();
        // Work until `halt` raises the termination latch.
        worker_thread.run_until(&control.should_terminate);
        // Shutting down: push unfinished local jobs back to the global queue
        // so no work is lost when this worker disappears.
        let local_queue = unsafe { worker_thread.get_queue() };
        for job in local_queue.drain(..) {
            thread_pool.inject(job);
        }
        // Likewise rescue a job still parked in this worker's shared slot.
        if let Some(job) = worker_thread.thread_info().shared_job.take() {
            thread_pool.inject(job);
        }
        control.post_termination_status();
    });
}
/// Body of the dedicated heartbeat thread: round-robins over the running
/// workers, raising one worker's `heartbeat` flag per iteration and pacing
/// itself so every running worker is visited roughly once per period. Exits
/// when halted or once the pool has no running workers left.
fn heartbeat_loop(thread_pool: &'static ThreadPool) {
    // Target period over which each running worker receives one heartbeat.
    const PERIOD: Duration = Duration::from_micros(50);
    let control = &thread_pool.state.heartbeat_control;
    control.post_ready_status();
    let mut target = 0;
    loop {
        if control.should_terminate.probe() {
            break;
        }
        let worker_count = thread_pool.state.running_threads.load(Ordering::Relaxed);
        if worker_count == 0 {
            break;
        }
        // Wrap around once we run off the end of the live worker prefix.
        if target >= worker_count {
            target = 0;
            continue;
        }
        thread_pool.threads[target]
            .heartbeat
            .store(true, Ordering::Relaxed);
        target += 1;
        // Sleep for one slice of the period; `halt` can cut the wait short by
        // notifying `awakened`.
        let slice = PERIOD / worker_count as u32;
        let mut running_guard = control.is_running.lock();
        control.awakened.wait_for(&mut running_guard, slice);
    }
    control.post_termination_status();
}