#![deny(future_incompatible)]
#![deny(nonstandard_style)]
#![deny(rust_2018_idioms)]
use crate::{
task::{Job, Task, ThreadLocalJob, ThreadLocalTask},
threads::ThreadAllocationOutput,
util::ThreadLocalPointer,
};
use futures_intrusive::{
channel::shared::{oneshot_channel, ChannelReceiveFuture, OneshotReceiver},
sync::ManualResetEvent,
};
use futures_task::{Context, Poll};
use parking_lot::{Condvar, Mutex, RawMutex};
use priority_queue::PriorityQueue;
use slotmap::{DefaultKey, DenseSlotMap};
use std::{
any::Any,
future::Future,
panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
};
pub mod affinity;
mod error;
mod task;
pub mod threads;
mod util;
mod worker;
pub use error::*;
/// Priority attached to every spawned job. It is passed through unchanged as the ordering key
/// of the shared job queue's `PriorityQueue`.
pub type Priority = u32;
/// Integer type used for pool counts.
///
/// NOTE(review): not referenced in this file — confirm where and how it is used.
pub type PoolCount = u8;
/// Handle to a job spawned with [`Switchyard::spawn`] or [`Switchyard::spawn_local`].
///
/// Awaiting the handle yields the job's output. If the job panicked, polling the handle panics
/// with `"Job panicked!"`.
pub struct JoinHandle<T: 'static> {
    /// Receiving half of the oneshot channel the job sends its panic-caught result on.
    ///
    /// NOTE(review): never read directly. Presumably it is held so the receiver stays alive for
    /// as long as `receiver_future` is being polled — confirm against `futures_intrusive`.
    _receiver: OneshotReceiver<Result<T, Box<dyn Any + Send + 'static>>>,
    /// Future created by `_receiver.receive()`; polled by this type's `Future` impl.
    receiver_future: ChannelReceiveFuture<RawMutex, Result<T, Box<dyn Any + Send + 'static>>>,
}
impl<T: 'static> Future for JoinHandle<T> {
    type Output = T;

    /// Drives the underlying oneshot receive future.
    ///
    /// Resolves to the job's output once the job has sent it.
    ///
    /// # Panics
    ///
    /// Panics with `"Job panicked!"` when the job sent back a caught panic instead of a value.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `receiver_future` is treated as structurally pinned. It is only ever reached
        // through this pinned projection and is never moved out of the handle.
        let receive = unsafe { self.map_unchecked_mut(|handle| &mut handle.receiver_future) };
        match receive.poll(ctx) {
            Poll::Ready(Some(Ok(output))) => Poll::Ready(output),
            Poll::Ready(Some(Err(_))) => panic!("Job panicked!"),
            // `None` means the sending side went away without a value; the handle then never
            // resolves.
            // NOTE(review): the inner future has already completed here, so presumably no
            // wakeup is registered and an awaiting task stays pending indefinitely — confirm
            // this is the intended shutdown behavior.
            Poll::Ready(None) | Poll::Pending => Poll::Pending,
        }
    }
}
/// Future adapter that catches a panic raised while polling the wrapped future.
///
/// A caught panic is reported as `Err(payload)` instead of unwinding into the caller.
struct CatchUnwind<Fut>(Fut);

impl<Fut> CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    /// Wraps `future` so that panics during any of its polls are caught.
    fn new(future: Fut) -> CatchUnwind<Fut> {
        Self(future)
    }
}

impl<Fut> Future for CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    /// Polls the inner future once.
    ///
    /// Returns `Ready(Ok(value))` on completion, `Pending` while it is pending, and
    /// `Ready(Err(payload))` if that poll panicked.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the wrapped future is structurally pinned. It is only accessed through this
        // pinned projection and is never moved out of `self`.
        let inner = unsafe { self.map_unchecked_mut(|wrapper| &mut wrapper.0) };
        match catch_unwind(AssertUnwindSafe(move || inner.poll(cx))) {
            Ok(progress) => progress.map(Ok),
            Err(payload) => Poll::Ready(Err(payload)),
        }
    }
}
/// Queue of thread-local jobs. It has the same shape as [`Queue`]'s `waiting` / `inner` fields,
/// but holds `ThreadLocalTask` / `ThreadLocalJob` values.
///
/// NOTE(review): not constructed or used in this file. Presumably it is owned per worker in
/// the `worker` module — confirm.
struct ThreadLocalQueue<TD> {
    /// Slot map of thread-local tasks.
    /// NOTE(review): presumably tasks parked until woken — confirm.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<ThreadLocalTask<TD>>>>,
    /// Thread-local jobs keyed by their `u32` priority.
    inner: Mutex<PriorityQueue<ThreadLocalJob<TD>, u32>>,
}
/// A condition variable paired with a flag that `Queue::notify_one` consults to pick which
/// condvar to signal.
struct FlaggedCondvar {
    /// `Queue::notify_one` skips condvars whose flag is `true`. `Switchyard::new` initializes
    /// it to `true`.
    /// NOTE(review): presumably the owning worker clears it before blocking on `inner` and sets
    /// it again when it wakes — confirm in the `worker` module.
    running: AtomicBool,
    /// Condition variable signalled by `Queue::notify_one` / `Queue::notify_all`.
    inner: Condvar,
}
/// The job queue shared by every worker of a [`Switchyard`].
struct Queue<TD> {
    /// Slot map of tasks.
    /// NOTE(review): filled outside this file, presumably with tasks waiting to be woken —
    /// confirm.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<Task<TD>>>>,
    /// Jobs ready to run, keyed by their `u32` priority. `spawn` / `spawn_local` push here while
    /// holding this lock.
    inner: Mutex<PriorityQueue<Job<TD>, u32>>,
    /// One entry per worker thread. A worker's position in this vector is the
    /// `queue_local_index` that `Switchyard::new` passes to `worker::body`.
    condvars: Vec<FlaggedCondvar>,
}
impl<TD> Queue<TD> {
    /// Signals the first condvar (lowest index) whose `running` flag is `false`.
    ///
    /// If every flag reads `true`, nothing is signalled.
    fn notify_one(&self) {
        let first_idle = self
            .condvars
            .iter()
            .find(|flagged| !flagged.running.load(Ordering::Relaxed));
        if let Some(flagged) = first_idle {
            flagged.inner.notify_one();
        }
    }

    /// Signals every waiter on every condvar, regardless of its `running` flag.
    fn notify_all(&self) {
        self.condvars.iter().for_each(|flagged| {
            flagged.inner.notify_all();
        });
    }
}
/// State shared through an `Arc` between a [`Switchyard`] handle, its worker threads (via
/// `worker::body`) and its tasks (via `Task::new`).
struct Shared<TD> {
    /// Initialized to the number of worker threads and reported by
    /// `Switchyard::active_threads`.
    /// NOTE(review): presumably updated by workers as they go idle or busy — confirm.
    active_threads: AtomicUsize,
    /// Event awaited by `Switchyard::wait_for_idle` and reset on every spawn.
    /// NOTE(review): set outside this file, presumably when the pool runs out of work — confirm.
    idle_wait: ManualResetEvent,
    /// Incremented on every spawn and reported by `Switchyard::jobs`.
    /// NOTE(review): presumably decremented when a job completes — confirm.
    job_count: AtomicUsize,
    /// Set by `Switchyard::finish`. Once it is set, spawning another job panics.
    death_signal: AtomicBool,
    /// The job queue and per-worker condvars.
    queue: Queue<TD>,
}
/// A pool of worker threads that run spawned futures, each worker carrying thread-local data of
/// type `TD`.
pub struct Switchyard<TD: 'static> {
    /// State shared with the worker threads and tasks.
    shared: Arc<Shared<TD>>,
    /// Join handles of the worker threads; drained and joined by `finish`.
    threads: Vec<std::thread::JoinHandle<()>>,
    /// Raw pointers to each worker's `Arc<TD>`, received from the workers during `new`.
    ///
    /// `finish` only clears this vector; it neither dereferences nor frees the pointees.
    /// NOTE(review): the pointees are presumably owned by the worker threads — confirm.
    thread_local_data: Vec<*mut Arc<TD>>,
}
impl<TD: 'static> Switchyard<TD> {
pub fn new<TDFunc>(
thread_allocations: impl IntoIterator<Item = ThreadAllocationOutput>,
thread_local_data_creation: TDFunc,
) -> Result<Self, SwitchyardCreationError>
where
TDFunc: Fn() -> TD + Send + Sync + 'static,
{
let (thread_local_sender, thread_local_receiver) = std::sync::mpsc::channel();
let thread_local_data_creation_arc = Arc::new(thread_local_data_creation);
let allocation_vec: Vec<_> = thread_allocations.into_iter().collect();
let num_logical_cpus = num_cpus::get();
for allocation in allocation_vec.iter() {
if let Some(affin) = allocation.affinity {
if affin >= num_logical_cpus {
return Err(SwitchyardCreationError::InvalidAffinity {
affinity: affin,
total_threads: num_logical_cpus,
});
}
}
}
let mut shared = Arc::new(Shared {
queue: Queue {
waiting: Mutex::new(DenseSlotMap::new()),
inner: Mutex::new(PriorityQueue::new()),
condvars: Vec::new(),
},
active_threads: AtomicUsize::new(allocation_vec.len()),
idle_wait: ManualResetEvent::new(false),
job_count: AtomicUsize::new(0),
death_signal: AtomicBool::new(false),
});
let shared_guard = Arc::get_mut(&mut shared).unwrap();
let queue_local_indices: Vec<_> = allocation_vec
.iter()
.map(|_| {
let condvar_array = &mut shared_guard.queue.condvars;
let queue_local_index = condvar_array.len();
condvar_array.push(FlaggedCondvar {
inner: Condvar::new(),
running: AtomicBool::new(true),
});
queue_local_index
})
.collect();
let mut threads = Vec::with_capacity(allocation_vec.len());
for (mut thread_info, queue_local_index) in allocation_vec.into_iter().zip(queue_local_indices) {
let builder = std::thread::Builder::new();
let builder = if let Some(name) = thread_info.name.take() {
builder.name(name)
} else {
builder
};
let builder = if let Some(stack_size) = thread_info.stack_size.take() {
builder.stack_size(stack_size)
} else {
builder
};
threads.push(
builder
.spawn(worker::body::<TD, TDFunc>(
Arc::clone(&shared),
thread_info,
queue_local_index,
thread_local_sender.clone(),
thread_local_data_creation_arc.clone(),
))
.unwrap_or_else(|_| panic!("Could not spawn thread")),
);
}
drop(thread_local_sender);
let mut thread_local_data = Vec::with_capacity(threads.len());
while let Ok(ThreadLocalPointer(ptr)) = thread_local_receiver.recv() {
thread_local_data.push(ptr);
}
Ok(Self {
threads,
shared,
thread_local_data,
})
}
fn spawn_header(&self) {
assert!(
!self.shared.death_signal.load(Ordering::Acquire),
"finish() has been called on this Switchyard. No more jobs may be added."
);
self.shared.job_count.fetch_add(1, Ordering::AcqRel);
self.shared.idle_wait.reset();
}
pub fn spawn<Fut, T>(&self, priority: Priority, fut: Fut) -> JoinHandle<T>
where
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.spawn_header();
let (sender, receiver) = oneshot_channel();
let job = Job::Future(Task::new(
Arc::clone(&self.shared),
async move {
let _ = sender.send(CatchUnwind::new(std::panic::AssertUnwindSafe(fut)).await);
},
priority,
));
let queue: &Queue<TD> = &self.shared.queue;
let mut queue_guard = queue.inner.lock();
queue_guard.push(job, priority);
queue.notify_one();
drop(queue_guard);
JoinHandle {
receiver_future: receiver.receive(),
_receiver: receiver,
}
}
pub fn spawn_local<Func, Fut, T>(&self, priority: Priority, async_fn: Func) -> JoinHandle<T>
where
Func: FnOnce(Arc<TD>) -> Fut + Send + 'static,
Fut: Future<Output = T>,
T: Send + 'static,
{
self.spawn_header();
let (sender, receiver) = oneshot_channel();
let job = Job::Local(Box::new(move |td| {
Box::pin(async move {
let unwind_async_fn = AssertUnwindSafe(async_fn);
let unwind_td = AssertUnwindSafe(td);
let future = catch_unwind(move || AssertUnwindSafe(unwind_async_fn.0(unwind_td.0)));
let ret = match future {
Ok(fut) => CatchUnwind::new(AssertUnwindSafe(fut)).await,
Err(panic) => Err(panic),
};
let _ = sender.send(ret);
})
}));
let queue: &Queue<TD> = &self.shared.queue;
let mut queue_guard = queue.inner.lock();
queue_guard.push(job, priority);
queue.notify_one();
drop(queue_guard);
JoinHandle {
receiver_future: receiver.receive(),
_receiver: receiver,
}
}
pub async fn wait_for_idle(&self) {
self.shared.idle_wait.wait().await;
}
pub fn jobs(&self) -> usize {
self.shared.job_count.load(Ordering::Relaxed)
}
pub fn active_threads(&self) -> usize {
self.shared.active_threads.load(Ordering::Relaxed)
}
pub fn finish(&mut self) {
self.shared.death_signal.store(true, Ordering::Release);
let lock = self.shared.queue.inner.lock();
self.shared.queue.notify_all();
drop(lock);
self.thread_local_data.clear();
for thread in self.threads.drain(..) {
thread.join().unwrap();
}
}
}
impl<TD: 'static> Drop for Switchyard<TD> {
fn drop(&mut self) {
self.finish()
}
}
// SAFETY: NOTE(review): these manual impls are required because `thread_local_data` holds raw
// pointers, which are neither `Send` nor `Sync`, so the auto traits are not derived. The impls
// place no bounds on `TD`. Soundness therefore relies on two things: those pointers are never
// dereferenced through a `Switchyard` on a foreign thread, and `Shared<TD>` is safe to share for
// every `TD`. Confirm both, and consider whether `TD: Send`/`TD: Sync` bounds are needed.
unsafe impl<TD> Send for Switchyard<TD> {}
unsafe impl<TD> Sync for Switchyard<TD> {}