use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};
/// Cloneable handle to a pool of dedicated worker threads, each driving its
/// own current-thread Tokio runtime so that `!Send` futures can be spawned
/// onto them (see `spawn_pinned` / `spawn_pinned_by_idx`).
#[derive(Clone)]
pub struct LocalPoolHandle {
    // Shared so every clone addresses the same workers; the workers live as
    // long as any clone of the handle does.
    pool: Arc<LocalPool>,
}
impl LocalPoolHandle {
    /// Create a pool of `pool_size` worker threads, each running a
    /// current-thread Tokio runtime with its own `LocalSet`.
    ///
    /// # Panics
    ///
    /// Panics if `pool_size` is zero, or if a worker runtime cannot be built.
    #[track_caller]
    pub fn new(pool_size: usize) -> LocalPoolHandle {
        // Actionable message instead of the bare "assertion failed" default;
        // `#[track_caller]` makes the panic point at the caller's call site.
        assert!(pool_size > 0, "pool_size must be at least 1");
        let workers = (0..pool_size)
            .map(|_| LocalWorkerHandle::new_worker())
            .collect();
        let pool = Arc::new(LocalPool { workers });
        LocalPoolHandle { pool }
    }

    /// Number of worker threads in the pool (the `pool_size` given to `new`).
    #[inline]
    pub fn num_threads(&self) -> usize {
        self.pool.workers.len()
    }

    /// Snapshot of each worker's in-flight task count, indexed the same way
    /// as `spawn_pinned_by_idx`. Counts are read one at a time, so the
    /// snapshot may be mutually inconsistent under concurrent spawns.
    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
        self.pool
            .workers
            .iter()
            .map(|worker| worker.task_count.load(Ordering::SeqCst))
            // The annotated return type makes a turbofish here redundant.
            .collect()
    }

    /// Spawn a (possibly `!Send`) future onto the least-burdened worker.
    ///
    /// `create_task` is sent to the worker thread and invoked there, so the
    /// future it builds never crosses threads and may be `!Send`.
    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
    }

    /// Like `spawn_pinned`, but pins the task to the worker at index `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.num_threads()`.
    #[track_caller]
    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
    }
}
impl Debug for LocalPoolHandle {
    /// Formats as the bare type name; worker internals are intentionally
    /// kept opaque.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "LocalPoolHandle")
    }
}
/// Strategy for choosing which worker receives a spawned task.
enum WorkerChoice {
    /// Pick the worker with the lowest current task count.
    LeastBurdened,
    /// Pick the worker at this index (out-of-range panics at lookup).
    ByIdx(usize),
}
/// Shared state behind `LocalPoolHandle`: the fixed set of workers.
struct LocalPool {
    // Boxed slice — the worker set is fixed at construction time.
    workers: Box<[LocalWorkerHandle]>,
}
impl LocalPool {
    /// Route `create_task` to a worker and return a `JoinHandle` that
    /// resolves on the *caller's* runtime with the task's output.
    ///
    /// The returned handle belongs to a proxy task: the real future runs on
    /// the worker's `LocalSet`, and dropping/aborting the proxy aborts the
    /// remote task via `AbortGuard`.
    #[track_caller]
    fn spawn_pinned<F, Fut>(
        &self,
        create_task: F,
        worker_choice: WorkerChoice,
    ) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        // One-shot channel over which the worker sends back the local
        // JoinHandle of the spawned task.
        let (sender, receiver) = oneshot::channel();
        // Selecting a worker also bumps its task count; the guard decrements
        // it when the proxy task below finishes, for any reason.
        let (worker, job_guard) = match worker_choice {
            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
        };
        let worker_spawner = worker.spawner.clone();
        // Proxy task spawned on the worker's runtime *handle* so callers can
        // await it; it relays the closure to the worker thread and then
        // forwards the result.
        worker.runtime_handle.spawn(async move {
            // Move the guard in so the count drops when the proxy ends.
            let _job_guard = job_guard;
            // If the proxy is dropped/aborted, this guard's Drop aborts the
            // remote task.
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let _abort_guard = AbortGuard(abort_handle);
            // Runs on the worker thread: spawns the (possibly !Send) future
            // onto the worker's LocalSet.
            let spawn_task = Box::new(move || {
                let join_handle =
                    spawn_local(
                        async move { Abortable::new(create_task(), abort_registration).await },
                    );
                // send() fails only if the receiver is gone, i.e. the proxy
                // was dropped — abort the freshly spawned task right away.
                if let Err(join_handle) = sender.send(join_handle) {
                    join_handle.abort()
                }
            });
            if let Err(e) = worker_spawner.send(spawn_task) {
                // Channel closed => the worker thread is gone; unrecoverable.
                panic!("Failed to send job to worker: {e}");
            }
            let join_handle = match receiver.await {
                Ok(handle) => handle,
                Err(e) => {
                    panic!("Worker failed to send join handle: {e}");
                }
            };
            let join_result = join_handle.await;
            match join_result {
                Ok(Ok(output)) => output,
                // Aborted: only our AbortGuard can abort the inner task, and
                // that guard is still alive at this point.
                Ok(Err(_)) => {
                    unreachable!(
                        "Reaching this branch means this task was previously \
                        aborted but it continued running anyways"
                    )
                }
                Err(e) => {
                    if e.is_panic() {
                        // Re-surface the worker task's panic on the caller's
                        // JoinHandle, as if it panicked locally.
                        std::panic::resume_unwind(e.into_panic());
                    } else if e.is_cancelled() {
                        panic!("spawn_pinned task was canceled: {e}");
                    } else {
                        panic!("spawn_pinned task failed: {e}");
                    }
                }
            }
        })
    }

    /// Pick the worker with the fewest in-flight tasks and atomically reserve
    /// a slot on it (lock-free CAS retry loop).
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                // `LocalPoolHandle::new` asserts pool_size > 0, so `workers`
                // can never be empty. (Message fixed: a single worker is a
                // valid pool — the invariant is "at least one", not "more
                // than one".)
                .expect("There must be at least one worker");
            // CAS so that a concurrent spawner that raced us forces a rescan
            // instead of silently over-committing this worker.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }

    /// Reserve a slot on the worker at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds; `#[track_caller]` attributes the
    /// panic to the caller.
    #[track_caller]
    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
        let worker = &self.workers[idx];
        worker.task_count.fetch_add(1, Ordering::SeqCst);
        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
    }
}
/// RAII reservation on a worker's in-flight task counter: created when a job
/// is assigned to a worker, and releases (decrements) the count when dropped.
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Release the reservation. The fetch_sub stays outside the
        // debug_assert! so it is never compiled out in release builds.
        let prev = self.0.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(prev >= 1);
    }
}
/// RAII abort: cancels the associated `Abortable` task when dropped.
///
/// The proxy task in `LocalPool::spawn_pinned` holds one of these; if the
/// proxy is dropped or aborted before completion, this Drop fires and aborts
/// the task running on the worker's LocalSet. On normal completion the abort
/// fires too — NOTE(review): presumably a no-op once the Abortable future has
/// already finished; confirm against futures-util's AbortHandle docs.
struct AbortGuard(AbortHandle);
impl Drop for AbortGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}
/// A boxed closure, executed on the worker thread, that spawns the actual
/// task onto the worker's `LocalSet` (see `spawn_task` in
/// `LocalPool::spawn_pinned`).
type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

/// Per-thread worker state held by the pool.
struct LocalWorkerHandle {
    // Handle to the worker's current-thread runtime; proxy tasks are spawned
    // on it from the pool side.
    runtime_handle: tokio::runtime::Handle,
    // Channel over which the pool sends spawn closures to the worker thread.
    spawner: UnboundedSender<PinnedFutureSpawner>,
    // Number of tasks currently assigned to this worker; incremented at
    // assignment, decremented by JobCountGuard when a task finishes.
    task_count: Arc<AtomicUsize>,
}
impl LocalWorkerHandle {
    /// Create one worker: build a current-thread Tokio runtime and spawn a
    /// dedicated OS thread that services spawn requests via `Self::run`.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be built.
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        // Keep a handle for spawning proxy tasks; the runtime itself moves
        // into the worker thread below.
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);
        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }

    /// Worker thread body: execute spawn closures until every sender (i.e.
    /// every pool handle) is dropped, then wind down.
    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        // All !Send tasks for this worker live on this LocalSet.
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            // recv() yields None once all senders are dropped, ending the
            // service loop. Each closure calls spawn_local internally.
            while let Some(spawn_task) = task_receiver.recv().await {
                (spawn_task)();
            }
        });
        // Drain phase: repeatedly yield to the runtime so in-flight proxy
        // tasks can make progress and release their JobCountGuards; stop once
        // a full yield pass leaves the count unchanged (no further progress).
        // NOTE(review): this only drives `runtime`, not `local_set` — tasks
        // still pending on the LocalSet at this point are presumably dropped
        // below rather than run to completion; confirm intended semantics.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }
        // Explicit drop order: the LocalSet (and whatever tasks remain on it)
        // goes before the runtime it was driven by.
        drop(local_set);
        drop(runtime);
    }
}