#![warn(missing_docs)]
use std::sync::Arc;
#[doc(no_inline)]
pub use async_executor::Task;
/// Configuration used to construct a [`ThreadPool`].
///
/// Use `ThreadPoolDescriptor::default()` and struct-update syntax to override
/// only the fields you care about.
pub struct ThreadPoolDescriptor {
/// Number of worker threads to spawn.
pub num_threads: usize,
/// Stack size, in bytes, for each worker thread.
pub stack_size: usize,
/// Base name for the worker threads; the worker index is appended, e.g. `"Thread pool (0)"`.
pub thread_name: String,
/// Optional callback run on each worker thread when it starts; receives the worker index.
pub start_handler: Option<Box<dyn Fn(usize) + Send + Sync>>,
/// Optional callback run on each worker thread, receiving the worker index;
/// intended for per-thread teardown logic.
pub exit_handler: Option<Box<dyn Fn(usize) + Send + Sync>>,
}
impl Default for ThreadPoolDescriptor {
fn default() -> Self {
Self {
num_threads: 2,
stack_size: 2 * 1024 * 1024,
thread_name: "Thread pool".to_owned(),
start_handler: None,
exit_handler: None,
}
}
}
/// Shared state that owns the worker threads; dropping the last handle to it
/// shuts the pool down (see the `Drop` impl below).
#[derive(Debug)]
struct ThreadPoolInner {
// Join handles for every worker thread, joined on drop.
threads: Vec<std::thread::JoinHandle<()>>,
// Sender half of the shutdown channel; it is never sent on — closing it is
// the shutdown signal the workers wait for.
shutdown_tx: async_channel::Sender<()>,
}
impl Drop for ThreadPoolInner {
    /// Signals every worker to stop, then joins them all.
    fn drop(&mut self) {
        // Closing the channel makes each worker's pending `recv()` resolve,
        // which lets its executor loop return.
        self.shutdown_tx.close();
        for handle in std::mem::take(&mut self.threads) {
            let outcome = handle.join();
            // Re-raise a worker panic — unless we are already unwinding, in
            // which case a second panic would abort the process.
            if !std::thread::panicking() {
                outcome.expect("the task thread panicked while executing");
            }
        }
    }
}
/// A pool of OS threads that cooperatively drive a shared async executor.
///
/// Cloning is cheap (both fields are `Arc`s); the worker threads are signalled
/// and joined when the last clone referencing the inner state is dropped.
#[derive(Debug, Clone)]
pub struct ThreadPool {
// Executor shared between this handle and every worker thread.
executor: Arc<async_executor::Executor<'static>>,
// Worker threads plus the shutdown channel; see `ThreadPoolInner`.
inner: Arc<ThreadPoolInner>,
}
impl ThreadPool {
    /// Spawns `descriptor.num_threads` worker threads, each driving the pool's
    /// shared executor until the pool is dropped.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::Error`] if an OS thread cannot be spawned.
    pub fn new(descriptor: ThreadPoolDescriptor) -> Result<Self, std::io::Error> {
        // Nothing is ever sent on this channel; `ThreadPoolInner::drop` closes
        // it, which wakes every worker's pending `recv()` with an error.
        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
        let executor = Arc::new(async_executor::Executor::new());
        let mut threads = Vec::with_capacity(descriptor.num_threads);
        let descriptor = Arc::new(descriptor);
        for i in 0..descriptor.num_threads {
            let thread_descriptor = descriptor.clone();
            let thread_executor = Arc::clone(&executor);
            let thread_name = format!("{} ({})", descriptor.thread_name, i);
            let thread_shutdown_rx = shutdown_rx.clone();
            let mut thread_builder = std::thread::Builder::new().name(thread_name);
            thread_builder = thread_builder.stack_size(descriptor.stack_size);
            let thread = thread_builder.spawn(move || {
                if let Some(start_handler) = &thread_descriptor.start_handler {
                    start_handler(i)
                }
                // `Executor::run` drives queued tasks until the wrapped future
                // completes; `recv()` only resolves (with `Err`) once the
                // channel is closed on drop, so this blocks for the pool's
                // lifetime. `unwrap_err` asserts the channel was closed rather
                // than sent on.
                let shutdown_future = thread_executor.run(thread_shutdown_rx.recv());
                futures_lite::future::block_on(shutdown_future).unwrap_err();
                // BUGFIX: the exit handler must run *after* the executor loop
                // finishes. Futures are lazy, so invoking it before `block_on`
                // (as the previous code did) executed it at thread startup,
                // not at shutdown.
                if let Some(exit_handler) = &thread_descriptor.exit_handler {
                    exit_handler(i)
                }
            })?;
            threads.push(thread)
        }
        Ok(Self {
            executor,
            inner: Arc::new(ThreadPoolInner {
                threads,
                shutdown_tx,
            }),
        })
    }
    /// Runs `s`, allowing it to spawn tasks that borrow from the enclosing
    /// scope, then blocks until every spawned task has completed and returns
    /// their outputs in spawn order.
    ///
    /// While waiting, the calling thread also `try_tick`s the executor so
    /// progress is made even if all workers are busy.
    pub fn scope<'scope, S, R>(&self, s: S) -> Vec<R>
    where
        S: FnOnce(&mut Scope<'scope, R>) + 'scope + Send,
        R: Send + 'static,
    {
        let executor = &*self.executor;
        // SAFETY: shortening the executor reference to `'scope` is sound
        // because every task spawned through the scope is driven to completion
        // before this function returns, so no task outlives the borrow.
        let executor: &'scope async_executor::Executor = unsafe { std::mem::transmute(executor) };
        let mut scope = Scope {
            executor,
            spawned_tasks: Vec::new(),
        };
        s(&mut scope);
        if scope.spawned_tasks.is_empty() {
            Vec::new()
        } else if scope.spawned_tasks.len() == 1 {
            // Single task: block on it directly; the workers drive the executor.
            vec![futures_lite::future::block_on(&mut scope.spawned_tasks[0])]
        } else {
            let mut futures = async move {
                let mut future_results = Vec::with_capacity(scope.spawned_tasks.len());
                for task in scope.spawned_tasks {
                    future_results.push(task.await);
                }
                future_results
            };
            // SAFETY: `futures` is a stack local that is never moved again
            // before being polled to completion below.
            let futures = unsafe { core::pin::Pin::new_unchecked(&mut futures) };
            let futures: std::pin::Pin<&mut dyn futures_lite::Future<Output = Vec<R>>> = futures;
            // SAFETY: extending the lifetime to 'static is acceptable here
            // because the future is polled to completion and dropped before
            // `scope` returns, so it never actually outlives the borrowed data.
            let mut futures: std::pin::Pin<
                &'static mut (dyn futures_lite::Future<Output = Vec<R>> + 'static),
            > = unsafe { std::mem::transmute(futures) };
            loop {
                if let Some(result) =
                    futures_lite::future::block_on(futures_lite::future::poll_once(&mut futures))
                {
                    break result;
                };
                // Not done yet: help the pool along by running one queued task
                // on this thread, if any is ready.
                self.executor.try_tick();
            }
        }
    }
    /// Spawns a `'static` future onto the pool's executor and returns its
    /// [`Task`] handle.
    pub fn spawn<T>(
        &self,
        future: impl futures_lite::Future<Output = T> + Send + 'static,
    ) -> async_executor::Task<T>
    where
        T: Send + 'static,
    {
        self.executor.spawn(future)
    }
}
/// A scope handed to the closure passed to [`ThreadPool::scope`], used to
/// spawn tasks whose results are collected and returned by that call.
#[derive(Debug)]
pub struct Scope<'scope, R> {
// Executor (lifetime-narrowed to `'scope`) that tasks are spawned onto.
executor: &'scope async_executor::Executor<'scope>,
// Tasks spawned so far; awaited in order by `ThreadPool::scope`.
spawned_tasks: Vec<async_executor::Task<R>>,
}
impl<'scope, T: Send + 'scope> Scope<'scope, T> {
    /// Spawns a future on the scope's executor. Its output is collected and
    /// returned by [`ThreadPool::scope`] in spawn order.
    pub fn spawn<Fut: futures_lite::Future<Output = T> + 'scope + Send>(&mut self, f: Fut) {
        self.spawned_tasks.push(self.executor.spawn(f));
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicI32, Ordering};
    use super::*;

    /// Scoped tasks may borrow stack data and their results come back in order.
    #[test]
    pub fn test_scoped_spawn() {
        let pool = ThreadPool::new(ThreadPoolDescriptor::default()).unwrap();
        let boxed = Box::new(100);
        let boxed_ref = &*boxed;
        let counter = Arc::new(AtomicI32::new(0));
        let outputs = pool.scope(|scope| {
            for _ in 0..100 {
                let local_counter = Arc::clone(&counter);
                scope.spawn(async move {
                    if *boxed_ref != 100 {
                        panic!("expected 100")
                    } else {
                        local_counter.fetch_add(1, Ordering::Relaxed);
                        *boxed_ref
                    }
                });
            }
        });
        outputs.iter().for_each(|output| assert_eq!(*output, 100));
        assert_eq!(outputs.len(), 100);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    /// Start/exit handlers fire once per worker thread.
    #[test]
    pub fn test_custom_handler() {
        let start_counter = Arc::new(AtomicI32::new(0));
        let thread_start_counter = Arc::clone(&start_counter);
        let exit_counter = Arc::new(AtomicI32::new(0));
        let thread_exit_counter = Arc::clone(&exit_counter);
        // `let _ = ...` does not bind the pool, so it is dropped immediately:
        // all five workers start and shut down before the assertions below.
        let _ = ThreadPool::new(ThreadPoolDescriptor {
            num_threads: 5,
            start_handler: Some(Box::new(move |_| {
                thread_start_counter.fetch_add(1, Ordering::SeqCst);
            })),
            exit_handler: Some(Box::new(move |_| {
                thread_exit_counter.fetch_add(1, Ordering::SeqCst);
            })),
            ..Default::default()
        })
        .unwrap();
        std::thread::sleep(std::time::Duration::from_millis(50));
        assert_eq!(start_counter.load(Ordering::SeqCst), 5);
        assert_eq!(exit_counter.load(Ordering::SeqCst), 5);
    }

    /// A detachable task spawned on the pool resolves to its output.
    #[test]
    pub fn test_task_spawn() {
        let pool = ThreadPool::new(ThreadPoolDescriptor::default()).unwrap();
        let task = pool.spawn(async { 42 });
        assert_eq!(futures_lite::future::block_on(task), 42);
    }
}