use alloc::collections::VecDeque;
use alloc::format;
use alloc::string::ToString;
use alloc::vec::Vec;
use core::cell::Cell;
use core::cmp;
use core::future::Future;
use core::num::NonZero;
use core::pin::Pin;
use core::pin::pin;
use core::ptr;
use core::ptr::NonNull;
use core::task::Context;
use core::task::Poll;
use core::time::Duration;
use async_task::Runnable;
use async_task::Task;
use tracing::debug;
use tracing::trace;
use tracing::trace_span;
use crate::blocker::Blocker;
use crate::job::HeapJob;
use crate::job::JobQueue;
use crate::job::JobRef;
use crate::job::StackJob;
use crate::platform::*;
use crate::scope::Scope;
use crate::signal::Signal;
use crate::unwind;
/// A claim on a tenant slot in a [`ThreadPool`].
///
/// The pool's tenant slot holds only a [`Weak`] reference to `heartbeat`,
/// so dropping the `Lease` lets the heartbeat thread reclaim the slot.
pub struct Lease {
    /// The pool this lease was claimed from.
    thread_pool: &'static ThreadPool,
    /// Index of the tenant slot this lease occupies.
    index: usize,
    /// Set by the heartbeat thread; `Worker::join` checks it to decide
    /// when to promote a local job to the shared queue.
    heartbeat: Arc<AtomicBool>,
}
/// Target interval at which each tenant receives a heartbeat. The heartbeat
/// thread delivers one heartbeat per wakeup and sleeps
/// `HEARTBEAT_INTERVAL / num_occupied`, so every tenant is signaled roughly
/// once per interval.
#[cfg(not(feature = "shuttle"))]
pub const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
/// A heartbeat-scheduled thread pool.
///
/// Can be stored in a `static` (see [`ThreadPool::new`]); workers are
/// spawned on demand via the resize methods.
pub struct ThreadPool {
    /// All mutable pool state, behind a single mutex.
    state: Mutex<ThreadPoolState>,
    /// Signaled when a job is pushed onto the shared queue.
    job_is_ready: Condvar,
    /// Signaled when a new tenant registers with the pool.
    new_participant: Condvar,
}
/// Mutable pool state, guarded by `ThreadPool::state`.
struct ThreadPoolState {
    /// FIFO queue of promoted jobs any worker may claim.
    shared_jobs: VecDeque<JobRef>,
    /// Tenant slots; `None` marks a vacant slot available for reuse.
    tenants: Vec<Option<Tenant>>,
    /// Handles for the threads spawned and owned by the pool.
    managed_threads: ManagedThreads,
}
impl ThreadPoolState {
    /// Pops the oldest job from the shared queue, if any.
    fn claim_shared_job(&mut self) -> Option<JobRef> {
        self.shared_jobs.pop_front()
    }

    /// Registers a single new tenant and returns its lease.
    ///
    /// The first vacant slot is reused; if every slot is occupied a new
    /// one is appended. The returned lease's `index` always identifies the
    /// slot that now holds the tenant.
    fn claim_lease(&mut self, thread_pool: &'static ThreadPool) -> Lease {
        let heartbeat = Arc::new(AtomicBool::new(false));
        let tenant = Tenant {
            heartbeat: Arc::downgrade(&heartbeat),
        };
        for (index, occupant) in self.tenants.iter_mut().enumerate() {
            if occupant.is_none() {
                *occupant = Some(tenant);
                return Lease {
                    thread_pool,
                    index,
                    heartbeat,
                };
            }
        }
        // No vacancy: append a slot. Capture the index *before* pushing;
        // `tenants.len()` after the push points one past the new slot
        // (this was previously off by one).
        let index = self.tenants.len();
        self.tenants.push(Some(tenant));
        Lease {
            thread_pool,
            index,
            heartbeat,
        }
    }

    /// Registers `num` new tenants at once and returns their leases.
    ///
    /// Vacant slots are reused first; any remaining tenants get freshly
    /// appended slots.
    fn claim_leases(&mut self, thread_pool: &'static ThreadPool, num: usize) -> Vec<Lease> {
        let mut leases = Vec::with_capacity(num);
        for (index, occupant) in self.tenants.iter_mut().enumerate() {
            if leases.len() == num {
                return leases;
            }
            if occupant.is_none() {
                let heartbeat = Arc::new(AtomicBool::new(false));
                let tenant = Tenant {
                    heartbeat: Arc::downgrade(&heartbeat),
                };
                *occupant = Some(tenant);
                leases.push(Lease {
                    thread_pool,
                    index,
                    heartbeat,
                });
            }
        }
        while leases.len() != num {
            let heartbeat = Arc::new(AtomicBool::new(false));
            let tenant = Tenant {
                heartbeat: Arc::downgrade(&heartbeat),
            };
            // As in `claim_lease`: record the slot index before pushing so
            // the lease points at the slot the tenant actually occupies.
            let index = self.tenants.len();
            self.tenants.push(Some(tenant));
            leases.push(Lease {
                thread_pool,
                index,
                heartbeat,
            });
        }
        leases
    }
}
/// A participant slot tracked by the heartbeat thread.
struct Tenant {
    /// Weak handle to the tenant's heartbeat flag. When the owning
    /// `Lease` is dropped the upgrade fails and the heartbeat thread
    /// clears the slot.
    heartbeat: Weak<AtomicBool>,
}
/// Handles for the threads spawned and owned by the pool.
struct ManagedThreads {
    /// One entry per managed worker thread.
    workers: Vec<ManagedWorker>,
    /// The heartbeat thread, if one is currently running.
    heartbeat: Option<ThreadControl>,
}
/// A managed worker thread together with the lease slot it occupies.
struct ManagedWorker {
    /// Tenant-slot index of the worker's lease.
    index: usize,
    /// Halt flag and join handle for the worker thread.
    control: ThreadControl,
}
/// Control handles for a spawned thread.
struct ThreadControl {
    /// Set to request that the thread exit its loop.
    halt: Arc<AtomicBool>,
    /// Join handle for the thread.
    handle: JoinHandle<()>,
}
#[allow(clippy::new_without_default)]
impl ThreadPool {
pub const fn new() -> ThreadPool {
ThreadPool {
state: Mutex::new(ThreadPoolState {
shared_jobs: VecDeque::new(),
tenants: Vec::new(),
managed_threads: ManagedThreads {
workers: Vec::new(),
heartbeat: None,
},
}),
job_is_ready: Condvar::new(),
new_participant: Condvar::new(),
}
}
pub fn claim_lease(&'static self) -> Lease {
self.new_participant.notify_one();
let mut state = self.state.lock().unwrap();
state.claim_lease(self)
}
pub fn resize_to_available(&'static self) -> usize {
let available = available_parallelism().map(NonZero::get).unwrap_or(1);
let available = available.saturating_sub(2);
self.resize_to(available)
}
pub fn resize_to(&'static self, new_size: usize) -> usize {
self.resize(|_| new_size)
}
pub fn grow(&'static self, added_threads: usize) -> usize {
self.resize(|current_size| current_size + added_threads)
}
pub fn shrink(&'static self, terminated_threads: usize) -> usize {
self.resize(|current_size| current_size - terminated_threads)
}
pub fn populate(&'static self) -> usize {
self.resize(
|current_size| {
if current_size == 0 { 1 } else { current_size }
},
)
}
pub fn depopulate(&'static self) -> usize {
self.resize_to(0)
}
#[cold]
pub fn resize<F>(&'static self, get_size: F) -> usize
where
F: Fn(usize) -> usize,
{
debug!("starting threadpool resize");
trace!("locking state");
let mut state = self.state.lock().unwrap();
let current_size = state.managed_threads.workers.len();
let mut new_size = get_size(current_size);
trace!(
"attempting to resize thread pool from {} to {} thread(s)",
current_size, new_size
);
match new_size.cmp(¤t_size) {
cmp::Ordering::Equal => {
debug!("completed threadpool resize, size unchanged");
return current_size;
}
cmp::Ordering::Greater => {
trace!("locking worker leases");
let new_leases = state.claim_leases(self, new_size - current_size);
new_size = current_size + new_leases.len(); trace!("acquired leases for {} new threads", new_size);
#[cfg(not(feature = "shuttle"))]
if new_size > 0 && current_size == 0 {
debug!("spawning heartbeat runner");
let halt = Arc::new(AtomicBool::new(false));
let heartbeat_halt = halt.clone();
let handle = ThreadBuilder::new()
.name("heartbeat".to_string())
.spawn(move || {
heartbeat_loop(self, heartbeat_halt);
})
.unwrap();
let control = ThreadControl { halt, handle };
state.managed_threads.heartbeat = Some(control);
}
let barrier = Arc::new(Barrier::new(new_leases.len() + 1));
for lease in new_leases {
let index = lease.index;
debug!("spawning managed worker with index {}", index);
let halt = Arc::new(AtomicBool::new(false));
let worker_halt = halt.clone();
let worker_barrier = barrier.clone();
let handle = ThreadBuilder::new()
.name(format!("worker {index}"))
.spawn(move || {
managed_worker(lease, worker_halt, worker_barrier);
})
.unwrap();
let control = ThreadControl { halt, handle };
state
.managed_threads
.workers
.push(ManagedWorker { index, control });
}
drop(state);
barrier.wait();
}
cmp::Ordering::Less => {
if let Some(control) = state.managed_threads.heartbeat.take() {
control.halt.store(true, Ordering::Relaxed);
let _ = control.handle.join();
}
let terminating_workers = state.managed_threads.workers.split_off(new_size);
drop(state);
for worker in &terminating_workers {
worker.control.halt.store(true, Ordering::Relaxed);
}
self.job_is_ready.notify_all();
let own_lease = Worker::map_current(|worker| worker.lease.index);
for worker in terminating_workers {
if Some(worker.index) != own_lease {
let _ = worker.control.handle.join();
}
}
}
}
debug!("completed thread pool resize");
new_size
}
#[inline(always)]
pub fn id(&self) -> usize {
ptr::from_ref(self) as usize
}
#[inline(always)]
pub fn with_worker<F, R>(&'static self, f: F) -> R
where
F: FnOnce(&Worker) -> R,
{
Worker::with_current(|worker| match worker {
Some(worker) if worker.lease.thread_pool.id() == self.id() => f(worker),
_ => self.with_worker_cold(f),
})
}
#[cold]
fn with_worker_cold<F, R>(&'static self, f: F) -> R
where
F: FnOnce(&Worker) -> R,
{
let lease = self.state.lock().unwrap().claim_lease(self);
Worker::occupy(lease, f)
}
}
impl ThreadPool {
    /// Queues `f` to run on some worker of this pool.
    #[inline]
    pub fn spawn<F>(&'static self, f: F)
    where
        F: FnOnce(&Worker) + Send + 'static,
    {
        self.with_worker(|worker| worker.spawn(f));
    }

    /// Spawns `future` onto the pool, returning a [`Task`] handle for its
    /// output. The future is polled on pool workers each time it is woken.
    #[inline]
    pub fn spawn_future<F, T>(&'static self, future: F) -> Task<T>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Scheduling a wakeup = wrapping the `Runnable` in a `JobRef` and
        // pushing it onto a worker's local queue.
        let schedule = move |runnable: Runnable| {
            let job_pointer = runnable.into_raw();
            #[inline]
            fn execute_runnable(this: NonNull<()>, _worker: &Worker) {
                // SAFETY: `this` came from `Runnable::into_raw` in the
                // schedule closure below, and each such pointer is turned
                // back into a `Runnable` exactly once, here.
                let runnable = unsafe { Runnable::<()>::from_raw(this) };
                runnable.run();
            }
            // SAFETY: the raw job pointer stays valid until
            // `execute_runnable` consumes it — NOTE(review): confirm
            // against `JobRef::new_raw`'s documented contract.
            let job_ref = unsafe { JobRef::new_raw(job_pointer, execute_runnable) };
            self.with_worker(|worker| {
                worker.queue.push_back(job_ref);
            });
        };
        let (runnable, task) = async_task::spawn(future, schedule);
        // Schedule the first poll immediately.
        runnable.schedule();
        task
    }

    /// Like [`ThreadPool::spawn_future`], but for an async closure.
    #[inline]
    pub fn spawn_async<Fn, Fut, T>(&'static self, f: Fn) -> Task<T>
    where
        Fn: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let future = async move { f().await };
        self.spawn_future(future)
    }

    /// Drives `future` to completion on a worker of this pool, returning
    /// its output.
    #[inline]
    pub fn block_on<F, T>(&'static self, future: F) -> T
    where
        F: Future<Output = T> + Send,
        T: Send,
    {
        self.with_worker(|worker| worker.block_on(future))
    }

    /// Runs `a` and `b`, potentially in parallel, on this pool and returns
    /// both results.
    #[inline]
    pub fn join<A, B, RA, RB>(&'static self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce(&Worker) -> RA + Send,
        B: FnOnce(&Worker) -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.with_worker(|worker| worker.join(a, b))
    }

    /// Creates a job scope on a worker of this pool and passes it to `f`.
    #[inline]
    pub fn scope<'scope, F, T>(&'static self, f: F) -> T
    where
        F: FnOnce(&Scope<'scope>) -> T,
    {
        self.with_worker(|worker| worker.scope(f))
    }
}
thread_local! {
    // Pointer to the `Worker` currently occupying this thread, or null if
    // none. Maintained by `Worker::occupy`, which saves and restores the
    // previous value so occupations can nest.
    static WORKER_PTR: Cell<*const Worker> = const { Cell::new(ptr::null()) };
}
/// The per-thread execution context of a thread-pool participant.
pub struct Worker {
    /// Whether the currently executing job migrated here via the shared
    /// queue (as opposed to running on the queue it was pushed to).
    pub(crate) migrated: Cell<bool>,
    /// This worker's claim on a tenant slot in its pool.
    pub(crate) lease: Lease,
    /// This worker's local job queue.
    pub(crate) queue: JobQueue,
}
/// Outcome of a single attempt to find and run one job.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Yield {
    /// A job was found and executed.
    Executed,
    /// No job was available.
    Idle,
}
impl Worker {
    /// Calls `f` with the worker installed on the current thread, or
    /// returns `None` if the thread is not occupied by one.
    #[inline]
    pub fn map_current<F, R>(f: F) -> Option<R>
    where
        F: FnOnce(&Worker) -> R,
    {
        let worker_ptr = WORKER_PTR.with(Cell::get);
        if !worker_ptr.is_null() {
            // SAFETY: a non-null WORKER_PTR points to the `Worker` kept
            // alive on this thread's stack by the enclosing
            // `Worker::occupy` frame.
            Some(f(unsafe { &*worker_ptr }))
        } else {
            None
        }
    }

    /// Calls `f` with `Some(worker)` if the current thread is occupied by
    /// a worker, `None` otherwise.
    #[inline]
    pub fn with_current<F, R>(f: F) -> R
    where
        F: FnOnce(Option<&Worker>) -> R,
    {
        let worker_ptr = WORKER_PTR.with(Cell::get);
        if !worker_ptr.is_null() {
            // SAFETY: as in `map_current`, the pointee is kept alive by
            // the active `Worker::occupy` frame on this thread.
            f(Some(unsafe { &*worker_ptr }))
        } else {
            f(None)
        }
    }

    /// Installs a worker for `lease` on the current thread, runs `f`,
    /// drains any jobs left in the local queue, then restores the
    /// previously installed worker (occupations can nest).
    #[inline]
    pub fn occupy<F, R>(lease: Lease, f: F) -> R
    where
        F: FnOnce(&Worker) -> R,
    {
        trace!("occupying lease");
        let span = trace_span!("occupy", lease = lease.index);
        let _enter = span.enter();
        let worker = Worker {
            migrated: Cell::new(false),
            lease,
            queue: JobQueue::new(),
        };
        // Save the outer worker pointer so nested occupations restore it.
        let outer_ptr = WORKER_PTR.with(|ptr| ptr.replace(&worker));
        let result = f(&worker);
        // Run any jobs still queued locally before vacating the lease.
        while let Some(job_ref) = worker.queue.pop_front() {
            worker.execute(job_ref, false);
        }
        WORKER_PTR.with(|ptr| ptr.set(outer_ptr));
        trace!("vacating lease");
        result
    }

    /// Returns the index of this worker's tenant slot.
    #[inline]
    pub fn index(&self) -> usize {
        self.lease.index
    }

    /// Moves one job from the local queue to the pool's shared queue and
    /// wakes one sleeping worker. Called when a heartbeat is observed.
    #[cold]
    fn promote(&self) {
        let mut state = self.lease.thread_pool.state.lock().unwrap();
        if let Some(job) = self.queue.pop_front() {
            state.shared_jobs.push_back(job);
            self.lease.thread_pool.job_is_ready.notify_one();
        }
    }

    /// Waits for `signal` to deliver a value, executing other jobs in the
    /// meantime; falls back to a blocking receive once no work is left.
    #[inline]
    pub fn wait_for_signal<T>(&self, signal: &Signal<T>) -> T
    where
        T: Send,
    {
        loop {
            // SAFETY (review note): relies on `Signal`'s documented
            // try_recv/recv contract — confirm against `crate::signal`.
            if let Some(value) = unsafe { signal.try_recv() } {
                return value;
            }
            if self.yield_now() == Yield::Idle {
                return unsafe { signal.recv() };
            }
        }
    }

    /// Finds one job to run: newest local job first, then the shared
    /// queue. The returned flag is `true` when the job was migrated from
    /// the shared queue.
    #[inline]
    pub fn find_work(&self) -> Option<(JobRef, bool)> {
        self.queue
            .pop_back()
            .map(|job| (job, false))
            .or_else(|| self.claim_shared_job().map(|job| (job, true)))
    }

    /// Claims a job from the pool's shared queue, taking the state lock.
    #[cold]
    pub fn claim_shared_job(&self) -> Option<JobRef> {
        self.lease
            .thread_pool
            .state
            .lock()
            .unwrap()
            .claim_shared_job()
    }

    /// Runs one job from the local queue only, if available.
    #[inline]
    pub fn yield_local(&self) -> Yield {
        match self.queue.pop_back() {
            Some(job_ref) => {
                self.execute(job_ref, false);
                Yield::Executed
            }
            None => Yield::Idle,
        }
    }

    /// Runs one job from the local or shared queue, if available.
    #[inline]
    pub fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some((job_ref, migrated)) => {
                self.execute(job_ref, migrated);
                Yield::Executed
            }
            None => Yield::Idle,
        }
    }

    /// Whether the currently executing job migrated via the shared queue.
    #[inline]
    pub fn migrated(&self) -> bool {
        self.migrated.get()
    }

    /// Runs `job_ref` with the `migrated` flag set for its duration,
    /// restoring the previous flag afterwards.
    #[inline]
    fn execute(&self, job_ref: JobRef, migrated: bool) {
        let migrated = self.migrated.replace(migrated);
        job_ref.execute(self);
        self.migrated.set(migrated);
    }
}
impl Worker {
    /// Queues `f` on this worker's local queue.
    #[inline]
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce(&Worker) + Send + 'static,
    {
        let job = HeapJob::new(f);
        // SAFETY (review note): relies on `HeapJob::into_job_ref`'s
        // contract (heap job consumed exactly once) — confirm against
        // `crate::job`.
        let job_ref = unsafe { job.into_job_ref() };
        self.queue.push_back(job_ref);
    }

    /// Spawns `future` onto this worker's pool; see
    /// [`ThreadPool::spawn_future`].
    #[inline]
    pub fn spawn_future<F, T>(&self, future: F) -> Task<T>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.lease.thread_pool.spawn_future(future)
    }

    /// Spawns an async closure onto this worker's pool; see
    /// [`ThreadPool::spawn_async`].
    #[inline]
    pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
    where
        Fn: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.lease.thread_pool.spawn_async(f)
    }

    /// Drives `future` to completion on this worker, executing pool jobs
    /// while the future is pending and parking only when idle.
    #[inline]
    pub fn block_on<F, T>(&self, future: F) -> T
    where
        F: Future<Output = T> + Send,
        T: Send,
    {
        let blocker = Blocker::new();
        // SAFETY (review note): the waker borrows `blocker`, which lives
        // for the whole polling loop below — confirm against
        // `Blocker::as_waker`'s contract.
        let waker = unsafe { blocker.as_waker() };
        let mut ctx = Context::from_waker(&waker);
        let mut future = pin!(future);
        loop {
            match future.as_mut().poll(&mut ctx) {
                Poll::Pending => {
                    // Do useful work until the waker fires; block the
                    // thread only when no jobs are available.
                    while blocker.would_block() {
                        if self.yield_now() == Yield::Idle {
                            blocker.block();
                            break;
                        }
                    }
                }
                Poll::Ready(res) => return res,
            }
        }
    }

    /// Runs `a` and `b`, potentially in parallel, returning both results.
    ///
    /// `a` is pushed onto the local queue so another worker may claim it
    /// while this worker runs `b`. If a heartbeat arrived, one local job
    /// is promoted to the shared queue first.
    #[inline]
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce(&Worker) -> RA + Send,
        B: FnOnce(&Worker) -> RB + Send,
        RA: Send,
        RB: Send,
    {
        let stack_job = StackJob::new(a);
        // SAFETY (review note): the job ref must not outlive `stack_job`;
        // every path below either executes it inline or waits on its
        // signal before returning — confirm against `StackJob`'s contract.
        let job_ref = unsafe { stack_job.as_job_ref() };
        let job_ref_id = job_ref.id();
        self.queue.push_back(job_ref);
        // Heartbeat observed: share one local job, then clear the flag.
        if self.lease.heartbeat.load(Ordering::Relaxed) {
            self.promote();
            self.lease.heartbeat.store(false, Ordering::Relaxed);
        }
        let result_b = b(self);
        if let Some(job) = self.queue.pop_back() {
            if job.id() == job_ref_id {
                // Our job is still on top: it was neither stolen nor
                // promoted, so run `a` inline on this worker.
                let a = unsafe { stack_job.unwrap() };
                let result_a = a(self);
                return (result_a, result_b);
            }
            // A different job was on top; run it before waiting for `a`.
            self.execute(job, false);
        }
        // `a` is running (or queued) elsewhere: keep executing other work
        // until its result is signaled.
        let result_a = self.wait_for_signal(stack_job.signal());
        match result_a {
            Ok(result_a) => (result_a, result_b),
            // `a` panicked on another thread; re-raise the panic here.
            Err(error) => unwind::resume_unwinding(error),
        }
    }

    /// Creates a job scope and passes it to `f`.
    #[inline]
    pub fn scope<'scope, F, T>(&self, f: F) -> T
    where
        F: FnOnce(&Scope<'scope>) -> T,
    {
        // SAFETY (review note): the scope stays pinned on this stack frame
        // for the duration of `f` — confirm against `Scope::new`'s safety
        // contract.
        let scope = unsafe { pin!(Scope::new()) };
        let scope_ref = Pin::get_ref(scope.into_ref());
        f(scope_ref)
    }
}
/// Queues `f` on the current thread's pool worker.
///
/// # Panics
///
/// Panics when called from outside a thread pool.
pub fn spawn<F>(f: F)
where
    F: FnOnce(&Worker) + Send + 'static,
{
    Worker::map_current(|worker| worker.spawn(f))
        .expect("attempt to call `forte::spawn` from outside a thread pool");
}
/// Spawns `future` onto the current thread's pool, returning a task handle.
///
/// # Panics
///
/// Panics when called from outside a thread pool.
pub fn spawn_future<F, T>(future: F) -> Task<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    Worker::map_current(|worker| worker.spawn_future(future))
        .expect("attempt to call `forte::spawn_future` from outside a thread pool")
}
/// Spawns an async closure onto the current thread's pool.
///
/// # Panics
///
/// Panics when called from outside a thread pool.
pub fn spawn_async<Fn, Fut, T>(f: Fn) -> Task<T>
where
    Fn: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    Worker::map_current(|worker| worker.spawn_async(f))
        .expect("attempt to call `forte::spawn_async` from outside a thread pool")
}
/// Drives `future` to completion on the current thread's pool worker.
///
/// # Panics
///
/// Panics when called from outside a thread pool.
pub fn block_on<F, T>(future: F) -> T
where
    F: Future<Output = T> + Send,
    T: Send,
{
    Worker::map_current(|worker| worker.block_on(future))
        .expect("attempt to call `forte::block_on` from outside a thread pool")
}
/// Runs `a` and `b` on the current thread's pool, potentially in parallel,
/// and returns both results.
///
/// # Panics
///
/// Panics when called from outside a thread pool.
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
    A: FnOnce(&Worker) -> RA + Send,
    B: FnOnce(&Worker) -> RB + Send,
    RA: Send,
    RB: Send,
{
    Worker::map_current(|worker| worker.join(a, b))
        .expect("attempt to call `forte::join` from outside a thread pool")
}
/// Creates a job scope on the current thread's pool worker and passes it
/// to `f`.
///
/// # Panics
///
/// Panics when called from outside a thread pool.
pub fn scope<'scope, F, T>(f: F) -> T
where
    F: FnOnce(&Scope<'scope>) -> T,
{
    Worker::map_current(|worker| worker.scope(f))
        .expect("attempt to call `forte::scope` from outside a thread pool")
}
/// Body of a managed worker thread.
///
/// Waits at `barrier` so `ThreadPool::resize` can observe that every new
/// worker has started, then occupies its lease and loops: drain the local
/// queue first, then block on `job_is_ready` for shared jobs, until `halt`
/// is set.
fn managed_worker(lease: Lease, halt: Arc<AtomicBool>, barrier: Arc<Barrier>) {
    trace!("starting managed worker");
    barrier.wait();
    Worker::occupy(lease, |worker| {
        while !halt.load(Ordering::Relaxed) {
            // Prefer locally queued jobs; no state lock needed.
            if let Some(job) = worker.queue.pop_back() {
                worker.execute(job, false);
                continue;
            }
            let mut state = worker.lease.thread_pool.state.lock().unwrap();
            while !halt.load(Ordering::Relaxed) {
                if let Some(job) = state.claim_shared_job() {
                    // Release the lock before running the job.
                    drop(state);
                    // Jobs from the shared queue count as migrated.
                    worker.execute(job, true);
                    break;
                }
                // Sleep until another worker promotes a job.
                state = worker.lease.thread_pool.job_is_ready.wait(state).unwrap();
            }
        }
    });
    trace!("exiting managed worker");
}
/// Body of the dedicated heartbeat thread.
///
/// Round-robins over the tenant slots, setting exactly one tenant's
/// heartbeat flag per iteration, then sleeps `HEARTBEAT_INTERVAL /
/// num_occupied` so that each tenant is signaled roughly once per
/// [`HEARTBEAT_INTERVAL`]. Slots whose `Lease` was dropped are pruned
/// along the way. When no tenants exist, blocks on `new_participant`
/// until one registers.
#[cfg(not(feature = "shuttle"))]
fn heartbeat_loop(thread_pool: &'static ThreadPool, halt: Arc<AtomicBool>) {
    use std::thread;
    trace!("starting managed heartbeat thread");
    // Round-robin cursor: index of the next tenant due a heartbeat.
    let mut queued_to_heartbeat = 0;
    let mut state = thread_pool.state.lock().unwrap();
    while !halt.load(Ordering::Relaxed) {
        let num_slots = state.tenants.len();
        let mut num_occupied: u32 = 0;
        let mut sent_heartbeat = false;
        for i in 0..num_slots {
            // Start scanning at the cursor so heartbeats rotate fairly.
            let tenant_index = (queued_to_heartbeat + i) % num_slots;
            if let Some(tenant) = &mut state.tenants[tenant_index] {
                let Some(heartbeat) = tenant.heartbeat.upgrade() else {
                    // The lease was dropped; reclaim the slot.
                    state.tenants[tenant_index] = None;
                    continue;
                };
                if !sent_heartbeat {
                    heartbeat.store(true, Ordering::Relaxed);
                    sent_heartbeat = true;
                    queued_to_heartbeat = (tenant_index + 1) % num_slots;
                }
                // Count every live tenant, even those not signaled now.
                num_occupied += 1;
            }
        }
        if num_occupied > 0 {
            // Release the lock while sleeping so the pool can progress.
            drop(state);
            let sleep_interval = HEARTBEAT_INTERVAL / num_occupied;
            thread::sleep(sleep_interval);
            state = thread_pool.state.lock().unwrap();
        } else {
            // No tenants: park until something joins the pool.
            state = thread_pool.new_participant.wait(state).unwrap();
        }
    }
}
// These tests spawn real OS threads, so they are disabled under shuttle.
#[cfg(all(test, not(feature = "shuttle")))]
mod tests {
    use alloc::vec;
    use core::sync::atomic::AtomicU8;
    use super::*;

    /// A single `join` runs both closures exactly once.
    #[test]
    fn join_basic() {
        static THREAD_POOL: ThreadPool = ThreadPool::new();
        THREAD_POOL.populate();
        let mut a = 0;
        let mut b = 0;
        THREAD_POOL.join(|_| a += 1, |_| b += 1);
        assert_eq!(a, 1);
        assert_eq!(b, 1);
        THREAD_POOL.depopulate();
    }

    /// Deeply right-leaning recursive joins (one element split off per
    /// level) touch every element exactly once.
    #[test]
    fn join_long() {
        fn increment(worker: &Worker, slice: &mut [u32]) {
            match slice.len() {
                0 => (),
                1 => slice[0] += 1,
                _ => {
                    let (head, tail) = slice.split_at_mut(1);
                    worker.join(|_| head[0] += 1, |worker| increment(worker, tail));
                }
            }
        }
        static THREAD_POOL: ThreadPool = ThreadPool::new();
        THREAD_POOL.populate();
        let mut vals = [0; 1_024];
        THREAD_POOL.with_worker(|worker| increment(worker, &mut vals));
        assert_eq!(vals, [1; 1_024]);
        THREAD_POOL.depopulate();
    }

    /// Balanced divide-and-conquer joins over a large buffer touch every
    /// element exactly once.
    #[test]
    fn join_very_long() {
        fn increment(worker: &Worker, slice: &mut [u32]) {
            match slice.len() {
                0 => (),
                1 => slice[0] += 1,
                _ => {
                    let mid = slice.len() / 2;
                    let (left, right) = slice.split_at_mut(mid);
                    worker.join(
                        |worker| increment(worker, left),
                        |worker| increment(worker, right),
                    );
                }
            }
        }
        static THREAD_POOL: ThreadPool = ThreadPool::new();
        THREAD_POOL.populate();
        let mut vals = vec![0; 1_024 * 1_024];
        THREAD_POOL.with_worker(|worker| increment(worker, &mut vals));
        assert_eq!(vals, vec![1; 1_024 * 1_024]);
        THREAD_POOL.depopulate();
    }

    /// Many concurrent scope-spawned joins on a multi-worker pool all run
    /// to completion.
    #[test]
    fn concurrent_scopes() {
        const NUM_JOBS: u8 = 128;
        static THREAD_POOL: ThreadPool = ThreadPool::new();
        THREAD_POOL.resize_to(4);
        let a = AtomicU8::new(0);
        let b = AtomicU8::new(0);
        THREAD_POOL.scope(|scope| {
            for _ in 0..NUM_JOBS {
                scope.spawn(|_| {
                    THREAD_POOL.join(
                        |_| a.fetch_add(1, Ordering::Relaxed),
                        |_| b.fetch_add(1, Ordering::Relaxed),
                    );
                });
            }
        });
        assert_eq!(a.load(Ordering::Relaxed), NUM_JOBS);
        assert_eq!(b.load(Ordering::Relaxed), NUM_JOBS);
        THREAD_POOL.depopulate();
    }
}