#[cfg(feature = "async")]
use futures::{
future::BoxFuture,
task::{waker_ref, ArcWake},
};
use futures_channel::oneshot;
use futures_executor::block_on;
use std::future::Future;
use std::option::Option;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Condvar, Mutex,
};
#[cfg(feature = "async")]
use std::task::Context;
use std::thread;
use std::time::Duration;
/// Number of bits in `usize` on the target platform.
const BITS: usize = std::mem::size_of::<usize>() * 8;
/// Maximum number of workers the pool can track: the total and the idle
/// worker counts are packed into one `usize` (half the bits each), so each
/// count is limited to `2^(BITS / 2) - 1`.
pub const MAX_SIZE: usize = (1 << (BITS / 2)) - 1;
/// Type-erased unit of work sent to worker threads over the channel.
type Job = Box<dyn FnOnce() + Send + 'static>;
/// A unit of work that can be submitted to the [`ThreadPool`], producing a
/// value of type `R`.
pub trait Task<R: Send>: Send {
    /// Execute the task, consuming it and returning its result.
    fn run(self) -> R;
    /// Convert the task into a plain boxed closure if it is backed by one.
    /// Returns `None` for tasks that are not simple closures (e.g. the
    /// async task wrapper), in which case callers fall back to `run`.
    fn into_fn(self) -> Option<Box<dyn FnOnce() -> R + Send + 'static>>;
    /// `true` if `into_fn` would return `Some` (the task is a plain closure).
    fn is_fn(&self) -> bool;
}
/// Blanket implementation: every `FnOnce() -> R` closure is a [`Task`].
impl<R, F> Task<R> for F
where
    R: Send,
    F: FnOnce() -> R + Send + 'static,
{
    fn run(self) -> R {
        self()
    }
    fn into_fn(self) -> Option<Box<dyn FnOnce() -> R + Send + 'static>> {
        // A closure can always be boxed directly.
        Some(Box::new(self))
    }
    fn is_fn(&self) -> bool {
        true
    }
}
/// Handle to the result of a task submitted with `evaluate`/`spawn_await`;
/// wraps the receiving end of a oneshot channel.
pub struct JoinHandle<T: Send> {
    pub receiver: oneshot::Receiver<T>,
}
impl<T: Send> JoinHandle<T> {
    /// Block the current thread until the result arrives. Returns
    /// `Err(Canceled)` if the sender was dropped without sending (e.g. the
    /// task panicked or was never run).
    pub fn try_await_complete(self) -> Result<T, oneshot::Canceled> {
        block_on(self.receiver)
    }
    /// Block until the result arrives.
    ///
    /// # Panics
    /// Panics if the oneshot channel was cancelled.
    pub fn await_complete(self) -> T {
        self.try_await_complete()
            .expect("could not receive message because channel was cancelled")
    }
}
/// A future being driven to completion on the pool. The future is stored
/// behind a mutex so the polling worker can take it out and put it back if
/// it is still pending.
#[cfg(feature = "async")]
struct AsyncTask {
    future: Mutex<Option<BoxFuture<'static, ()>>>,
    pool: ThreadPool,
}
#[cfg(feature = "async")]
impl ArcWake for AsyncTask {
    /// Waking re-submits the task to its pool so the future gets polled again.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let cloned_task = arc_self.clone();
        arc_self
            .pool
            .try_execute(cloned_task)
            .expect("failed to wake future because message could not be sent to pool");
    }
}
#[cfg(feature = "async")]
impl Task<()> for Arc<AsyncTask> {
    /// Poll the wrapped future once. If it is still pending it is put back
    /// into the slot so a later wake can poll it again; if it completed (or
    /// the slot was already empty) nothing is stored.
    fn run(self) {
        let mut future_slot = self.future.lock().expect("failed to acquire mutex");
        if let Some(mut future) = future_slot.take() {
            // The waker re-enqueues this task on the pool when triggered.
            let waker = waker_ref(&self);
            let context = &mut Context::from_waker(&*waker);
            if future.as_mut().poll(context).is_pending() {
                *future_slot = Some(future);
            }
        }
    }
    fn into_fn(self) -> Option<Box<dyn FnOnce() + Send + 'static>> {
        // Not a plain closure; must be executed through `Task::run`.
        None
    }
    fn is_fn(&self) -> bool {
        false
    }
}
/// Marker trait acting as a compile-time assertion that these types are `Send`.
trait ThreadSafe: Send {}
impl<R: Send> ThreadSafe for dyn Task<R> {}
impl<R: Send> ThreadSafe for JoinHandle<R> {}
impl ThreadSafe for ThreadPool {}
/// Cloneable handle to a thread pool. All clones share the same channel and
/// worker bookkeeping via `Arc`s; once every clone is dropped the sender is
/// gone, workers' receives fail, and the pool winds down.
#[derive(Clone)]
pub struct ThreadPool {
    // Number of workers kept alive indefinitely (they block on `recv`).
    core_size: usize,
    // Hard upper bound on concurrently live workers.
    max_size: usize,
    // Idle timeout after which non-core workers exit.
    keep_alive: Duration,
    channel_data: Arc<ChannelData>,
    worker_data: Arc<WorkerData>,
}
impl ThreadPool {
/// Create a pool with an auto-generated name of the form `rusty_pool_{n}`,
/// where `n` comes from a process-wide counter.
///
/// # Panics
/// Panics if `max_size` is 0, smaller than `core_size`, or greater than
/// [`MAX_SIZE`] (validated in `new_named`).
pub fn new(core_size: usize, max_size: usize, keep_alive: Duration) -> Self {
    static POOL_COUNTER: AtomicUsize = AtomicUsize::new(1);
    let name = format!(
        "rusty_pool_{}",
        POOL_COUNTER.fetch_add(1, Ordering::Relaxed)
    );
    ThreadPool::new_named(name, core_size, max_size, keep_alive)
}
/// Create a pool with the given name (used as the worker-thread name prefix).
///
/// # Panics
/// Panics if `max_size` is 0, smaller than `core_size`, or greater than
/// [`MAX_SIZE`] (both counts are bit-packed into half a `usize` each).
pub fn new_named(
    name: String,
    core_size: usize,
    max_size: usize,
    keep_alive: Duration,
) -> Self {
    // Validate before allocating the channel so an invalid configuration
    // fails fast without performing any allocation side effects first.
    if max_size == 0 || max_size < core_size {
        panic!("max_size must be greater than 0 and greater or equal to the core pool size");
    } else if max_size > MAX_SIZE {
        panic!(
            "max_size may not exceed {}, the maximum value that can be stored within half the bits of usize ({} -> {} bits in this case)",
            MAX_SIZE,
            BITS,
            BITS / 2
        );
    }
    let (sender, receiver) = crossbeam_channel::unbounded();
    let worker_data = WorkerData {
        pool_name: name,
        worker_count_data: WorkerCountData::default(),
        // Worker numbering starts at 1 for readable thread names.
        worker_number: AtomicUsize::new(1),
        join_notify_condvar: Condvar::new(),
        join_notify_mutex: Mutex::new(()),
        join_generation: AtomicUsize::new(0),
    };
    let channel_data = ChannelData { sender, receiver };
    Self {
        core_size,
        max_size,
        keep_alive,
        channel_data: Arc::new(channel_data),
        worker_data: Arc::new(worker_data),
    }
}
/// Current total number of live worker threads (busy and idle).
pub fn get_current_worker_count(&self) -> usize {
    self.worker_data.worker_count_data.get_total_worker_count()
}
/// Current number of workers waiting for work.
pub fn get_idle_worker_count(&self) -> usize {
    self.worker_data.worker_count_data.get_idle_worker_count()
}
/// Submit a task for execution.
///
/// # Panics
/// Panics if the pool's channel has been closed.
pub fn execute<T: Task<()> + 'static>(&self, task: T) {
    match self.try_execute(task) {
        Ok(()) => {}
        Err(_) => panic!("the channel of the thread pool has been closed"),
    }
}
/// Submit a task, returning an error instead of panicking if the channel is
/// closed. Plain closures are unwrapped via `into_fn` to avoid an extra
/// boxed indirection; other tasks are wrapped in a closure calling
/// `Task::run`.
pub fn try_execute<T: Task<()> + 'static>(
    &self,
    task: T,
) -> Result<(), crossbeam_channel::SendError<Job>> {
    if task.is_fn() {
        self.try_execute_task(
            task.into_fn()
                .expect("Task::into_fn returned None despite is_fn returning true"),
        )
    } else {
        self.try_execute_task(Box::new(move || {
            task.run();
        }))
    }
}
/// Submit a task and return a [`JoinHandle`] to its result.
///
/// # Panics
/// Panics if the pool's channel has been closed.
pub fn evaluate<R: Send + 'static, T: Task<R> + 'static>(&self, task: T) -> JoinHandle<R> {
    self.try_evaluate(task)
        .unwrap_or_else(|e| panic!("the channel of the thread pool has been closed: {:?}", e))
}
/// Submit a task and return a [`JoinHandle`] to its result, or an error if
/// the pool's channel has been closed.
pub fn try_evaluate<R: Send + 'static, T: Task<R> + 'static>(
    &self,
    task: T,
) -> Result<JoinHandle<R>, crossbeam_channel::SendError<Job>> {
    let (sender, receiver) = oneshot::channel::<R>();
    self.try_execute_task(Box::new(move || {
        // The receiver may already be gone; ignoring a failed send is fine.
        let _ = sender.send(task.run());
    }))
    .map(|()| JoinHandle { receiver })
}
/// Run the given future to completion on a pool thread (blocking that
/// worker via `block_on`) and return a handle to its output.
///
/// # Panics
/// Panics if the pool's channel has been closed.
pub fn complete<R: Send + 'static>(
    &self,
    future: impl Future<Output = R> + 'static + Send,
) -> JoinHandle<R> {
    self.evaluate(|| block_on(future))
}
/// Non-panicking variant of [`ThreadPool::complete`].
pub fn try_complete<R: Send + 'static>(
    &self,
    future: impl Future<Output = R> + 'static + Send,
) -> Result<JoinHandle<R>, crossbeam_channel::SendError<Job>> {
    self.try_evaluate(|| block_on(future))
}
/// Spawn a future onto the pool; workers poll it cooperatively, and its
/// waker re-enqueues it on this pool each time it is woken.
///
/// # Panics
/// Panics if the pool's channel has been closed.
#[cfg(feature = "async")]
pub fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
    let future_task = Arc::new(AsyncTask {
        future: Mutex::new(Some(Box::pin(future))),
        pool: self.clone(),
    });
    self.execute(future_task)
}
/// Non-panicking variant of [`ThreadPool::spawn`].
#[cfg(feature = "async")]
pub fn try_spawn(
    &self,
    future: impl Future<Output = ()> + 'static + Send,
) -> Result<(), crossbeam_channel::SendError<Job>> {
    let future_task = Arc::new(AsyncTask {
        future: Mutex::new(Some(Box::pin(future))),
        pool: self.clone(),
    });
    self.try_execute(future_task)
}
/// Spawn a future and return a [`JoinHandle`] for its output.
///
/// # Panics
/// Panics if the pool's channel has been closed.
#[cfg(feature = "async")]
pub fn spawn_await<R: Send + 'static>(
    &self,
    future: impl Future<Output = R> + 'static + Send,
) -> JoinHandle<R> {
    match self.try_spawn_await(future) {
        Ok(handle) => handle,
        Err(e) => panic!("the channel of the thread pool has been closed: {:?}", e),
    }
}
/// Non-panicking variant of [`ThreadPool::spawn_await`]: the future's output
/// is forwarded through a oneshot channel to the returned handle.
#[cfg(feature = "async")]
pub fn try_spawn_await<R: Send + 'static>(
    &self,
    future: impl Future<Output = R> + 'static + Send,
) -> Result<JoinHandle<R>, crossbeam_channel::SendError<Job>> {
    let (sender, receiver) = oneshot::channel::<R>();
    let join_handle = JoinHandle { receiver };
    self.try_spawn(async {
        let result = future.await;
        let _ignored_result = sender.send(result);
    })
    .map(|_| join_handle)
}
/// Core dispatch: decide whether to spawn a new worker for the task or to
/// enqueue it on the shared channel.
///
/// A core worker (no receive timeout) is spawned while the total count is
/// below `core_size`; past that, a temporary worker (with `keep_alive`
/// timeout) is spawned only while below `max_size` and no worker is idle.
/// Otherwise the task is queued for an existing worker.
#[inline]
fn try_execute_task(&self, task: Job) -> Result<(), crossbeam_channel::SendError<Job>> {
    let worker_count_data = &self.worker_data.worker_count_data;
    let mut worker_count_val = worker_count_data.worker_count.load(Ordering::Relaxed);
    let (mut curr_worker_count, idle_worker_count) = WorkerCountData::split(worker_count_val);
    let mut curr_idle_count = idle_worker_count;
    if curr_worker_count < self.core_size {
        // Try to reserve a slot below core_size; `witnessed` is the packed
        // value last observed by the CAS loop.
        let witnessed =
            worker_count_data.try_increment_worker_total(worker_count_val, self.core_size);
        // The increment succeeded either on the first try (witnessed equals
        // our expected value) or after retries still below the core limit.
        if witnessed == worker_count_val
            || WorkerCountData::get_total_count(witnessed) < self.core_size
        {
            let worker = Worker::new(
                self.channel_data.receiver.clone(),
                Arc::clone(&self.worker_data),
                // Core workers never time out.
                None,
            );
            worker.start(Some(task));
            return Ok(());
        }
        // Lost the race; continue with the freshly observed counts.
        curr_worker_count = WorkerCountData::get_total_count(witnessed);
        curr_idle_count = WorkerCountData::get_idle_count(witnessed);
        worker_count_val = witnessed;
    }
    // NOTE(review): `idle_worker_count` and `curr_idle_count` only differ
    // when the core branch above lost its race, so the `||` lets the stale
    // initial snapshot also trigger a spawn — confirm this is intentional
    // rather than simply `curr_idle_count == 0`.
    if curr_worker_count < self.max_size && (idle_worker_count == 0 || curr_idle_count == 0) {
        let witnessed =
            worker_count_data.try_increment_worker_total(worker_count_val, self.max_size);
        if witnessed == worker_count_val
            || WorkerCountData::get_total_count(witnessed) < self.max_size
        {
            let worker = Worker::new(
                self.channel_data.receiver.clone(),
                Arc::clone(&self.worker_data),
                // Temporary workers exit after `keep_alive` without work.
                Some(self.keep_alive),
            );
            worker.start(Some(task));
            return Ok(());
        }
    }
    // Pool is saturated (or all spawn races were lost): queue the task.
    self.send_task_to_channel(task)
}
/// Block until the pool is idle (no queued jobs and every worker idle).
pub fn join(&self) {
    self.inner_join(None);
}
/// Like [`ThreadPool::join`] but gives up after `time_out`.
pub fn join_timeout(&self, time_out: Duration) {
    self.inner_join(Some(time_out));
}
/// Consume this handle; once all clones are dropped the channel sender is
/// gone and workers exit after draining remaining work.
pub fn shutdown(self) {
    drop(self);
}
/// Shut down this handle and wait for the pool to become idle.
pub fn shutdown_join(self) {
    self.inner_shutdown_join(None);
}
/// Shut down this handle and wait up to `timeout` for the pool to idle.
pub fn shutdown_join_timeout(self, timeout: Duration) {
    self.inner_shutdown_join(Some(timeout));
}
/// The pool's name, used as the prefix for worker thread names.
pub fn get_name(&self) -> &str {
    &self.worker_data.pool_name
}
/// Eagerly spawn idle core workers until the total worker count reaches
/// `core_size`; useful to pre-warm the pool before submitting work.
pub fn start_core_threads(&self) {
    let worker_count_data = &self.worker_data.worker_count_data;
    let core_size = self.core_size;
    let mut curr_worker_count = worker_count_data.worker_count.load(Ordering::Relaxed);
    if WorkerCountData::get_total_count(curr_worker_count) >= core_size {
        return;
    }
    loop {
        // Each pre-started worker begins idle, so both halves of the
        // packed count are incremented together.
        let witnessed = worker_count_data.try_increment_worker_count(
            curr_worker_count,
            INCREMENT_TOTAL | INCREMENT_IDLE,
            core_size,
        );
        // `witnessed` is the value observed before a successful increment;
        // reaching the limit means no more slots were reserved.
        if WorkerCountData::get_total_count(witnessed) >= core_size {
            return;
        }
        let worker = Worker::new(
            self.channel_data.receiver.clone(),
            Arc::clone(&self.worker_data),
            // Core workers never time out.
            None,
        );
        worker.start(None);
        curr_worker_count = witnessed;
    }
}
/// Enqueue the job on the shared channel for an existing worker to pick up.
#[inline]
fn send_task_to_channel(&self, task: Job) -> Result<(), crossbeam_channel::SendError<Job>> {
    self.channel_data.sender.send(task)
}
/// Shared implementation behind `join`/`join_timeout`.
#[inline]
fn inner_join(&self, time_out: Option<Duration>) {
    ThreadPool::_do_join(&self.worker_data, &self.channel_data.receiver, time_out);
}
#[inline]
fn inner_shutdown_join(self, timeout: Option<Duration>) {
let current_worker_data = self.worker_data.clone();
let receiver = self.channel_data.receiver.clone();
drop(self);
ThreadPool::_do_join(¤t_worker_data, &receiver, timeout);
}
/// Wait (optionally bounded by `time_out`) until the pool is idle.
///
/// A generation counter guards against lost wakeups across consecutive
/// joins: the condvar predicate also releases the waiter when the
/// generation has advanced since it was sampled.
#[inline]
fn _do_join(
    current_worker_data: &Arc<WorkerData>,
    receiver: &crossbeam_channel::Receiver<Job>,
    time_out: Option<Duration>,
) {
    // Fast path: nothing queued and every worker idle.
    if ThreadPool::is_idle(current_worker_data, receiver) {
        return;
    }
    let join_generation = current_worker_data.join_generation.load(Ordering::SeqCst);
    let guard = current_worker_data
        .join_notify_mutex
        .lock()
        .expect("could not get join notify mutex lock");
    match time_out {
        Some(time_out) => {
            let _ret_guard = current_worker_data
                .join_notify_condvar
                .wait_timeout_while(guard, time_out, |_| {
                    // Keep waiting while still in our generation and the
                    // pool is not yet idle.
                    join_generation
                        == current_worker_data.join_generation.load(Ordering::Relaxed)
                        && !ThreadPool::is_idle(current_worker_data, receiver)
                })
                .expect("could not wait for join condvar");
        }
        None => {
            let _ret_guard = current_worker_data
                .join_notify_condvar
                .wait_while(guard, |_| {
                    join_generation
                        == current_worker_data.join_generation.load(Ordering::Relaxed)
                        && !ThreadPool::is_idle(current_worker_data, receiver)
                })
                .expect("could not wait for join condvar");
        }
    };
    // Advance the generation so other joiners of the same generation are
    // released too; only one CAS wins and losing is harmless.
    let _ = current_worker_data.join_generation.compare_exchange(
        join_generation,
        join_generation.wrapping_add(1),
        Ordering::SeqCst,
        Ordering::SeqCst,
    );
}
/// The pool is idle when every live worker is idle and no jobs are queued.
#[inline]
fn is_idle(
    current_worker_data: &Arc<WorkerData>,
    receiver: &crossbeam_channel::Receiver<Job>,
) -> bool {
    let (total, idle) = current_worker_data.worker_count_data.get_both();
    total == idle && receiver.is_empty()
}
}
/// Default pool: core size = number of logical CPUs, max size = twice that,
/// 60-second keep-alive for non-core workers.
impl Default for ThreadPool {
    fn default() -> Self {
        let num_cpus = num_cpus::get();
        // The original `max(num_cpus, num_cpus * 2)` always evaluated to
        // `num_cpus * 2` for usize values, so use that directly.
        ThreadPool::new(num_cpus, num_cpus * 2, Duration::from_secs(60))
    }
}
/// Fluent builder for [`ThreadPool`]; any unset field falls back to a
/// default in [`Builder::build`].
#[derive(Default)]
pub struct Builder {
    name: Option<String>,
    core_size: Option<usize>,
    max_size: Option<usize>,
    keep_alive: Option<Duration>,
}
impl Builder {
    /// Create a builder with no fields set.
    pub fn new() -> Builder {
        Builder::default()
    }
    /// Set the pool name (used as the worker-thread name prefix).
    pub fn name(mut self, name: String) -> Builder {
        self.name = Some(name);
        self
    }
    /// Set the number of core (non-expiring) workers.
    pub fn core_size(mut self, size: usize) -> Builder {
        self.core_size = Some(size);
        self
    }
    /// Set the maximum number of workers.
    pub fn max_size(mut self, size: usize) -> Builder {
        self.max_size = Some(size);
        self
    }
    /// Set the idle timeout for non-core workers.
    pub fn keep_alive(mut self, keep_alive: Duration) -> Builder {
        self.keep_alive = Some(keep_alive);
        self
    }
    /// Build the pool. Defaults: core size = CPU count (clamped by
    /// `max_size` and [`MAX_SIZE`]), max size = twice the core size
    /// (clamped by [`MAX_SIZE`]), keep-alive = 60 seconds.
    pub fn build(self) -> ThreadPool {
        use std::cmp::{max, min};
        let core_size = self.core_size.unwrap_or_else(|| {
            let num_cpus = num_cpus::get();
            if let Some(max_size) = self.max_size {
                min(MAX_SIZE, min(num_cpus, max_size))
            } else {
                min(MAX_SIZE, num_cpus)
            }
        });
        let max_size = self
            .max_size
            .unwrap_or_else(|| min(MAX_SIZE, max(core_size, core_size * 2)));
        let keep_alive = self.keep_alive.unwrap_or_else(|| Duration::from_secs(60));
        if let Some(name) = self.name {
            ThreadPool::new_named(name, core_size, max_size, keep_alive)
        } else {
            ThreadPool::new(core_size, max_size, keep_alive)
        }
    }
}
/// A single worker's configuration; `start` consumes it and spawns the OS
/// thread that runs the receive loop.
#[derive(Clone)]
struct Worker {
    receiver: crossbeam_channel::Receiver<Job>,
    worker_data: Arc<WorkerData>,
    // `Some(timeout)`: temporary worker that exits when idle too long;
    // `None`: core worker that blocks on `recv` indefinitely.
    keep_alive: Option<Duration>,
}
impl Worker {
    fn new(
        receiver: crossbeam_channel::Receiver<Job>,
        worker_data: Arc<WorkerData>,
        keep_alive: Option<Duration>,
    ) -> Self {
        Worker {
            receiver,
            worker_data,
            keep_alive,
        }
    }
    /// Spawn this worker's OS thread. `task` is an optional first job to
    /// run before entering the receive loop (the submission that created
    /// the worker). The caller has already adjusted the worker counts.
    fn start(self, task: Option<Job>) {
        let worker_name = format!(
            "{}_thread_{}",
            self.worker_data.pool_name,
            self.worker_data
                .worker_number
                .fetch_add(1, Ordering::Relaxed)
        );
        thread::Builder::new()
            .name(worker_name)
            .spawn(move || {
                // Sentinel restarts this worker if a job panics.
                let mut sentinel = Sentinel::new(&self);
                if let Some(task) = task {
                    self.exec_task_and_notify(&mut sentinel, task);
                }
                loop {
                    let received_task: Result<Job, _> = match self.keep_alive {
                        // Temporary worker: give up after the idle timeout.
                        Some(keep_alive) => self.receiver.recv_timeout(keep_alive).map_err(|_| ()),
                        // Core worker: block until work arrives or the
                        // channel disconnects.
                        None => self.receiver.recv().map_err(|_| ()),
                    };
                    match received_task {
                        Ok(task) => {
                            // This worker was counted idle while waiting.
                            self.worker_data.worker_count_data.decrement_worker_idle();
                            self.exec_task_and_notify(&mut sentinel, task);
                        }
                        Err(_) => {
                            // Timeout or disconnect: worker exits.
                            break;
                        }
                    }
                }
                // Remove this (idle) worker from both packed counts.
                self.worker_data.worker_count_data.decrement_both();
            })
            .expect("could not spawn thread");
    }
    /// Run one job, tracking the busy state on the sentinel so a panicking
    /// job can be detected in the sentinel's `Drop`.
    #[inline]
    fn exec_task_and_notify(&self, sentinel: &mut Sentinel, task: Job) {
        sentinel.is_working = true;
        task();
        sentinel.is_working = false;
        self.mark_idle_and_notify_joiners_if_no_work();
    }
    /// Mark this worker idle again and, if it was the last busy worker and
    /// the queue is empty, wake every thread blocked in `join`.
    #[inline]
    fn mark_idle_and_notify_joiners_if_no_work(&self) {
        let (old_total_count, old_idle_count) = self
            .worker_data
            .worker_count_data
            .increment_worker_idle_ret_both();
        if old_total_count == old_idle_count + 1 && self.receiver.is_empty() {
            // Holding the mutex ensures the notification cannot slip past a
            // joiner between its predicate check and its condvar wait.
            let _lock = self
                .worker_data
                .join_notify_mutex
                .lock()
                .expect("could not get join notify mutex lock");
            self.worker_data.join_notify_condvar.notify_all();
        }
    }
}
/// Panic guard on each worker thread's stack: if the thread unwinds while a
/// job is running, `Drop` restores the idle count and spawns a replacement
/// worker so the pool does not shrink below its intended size.
struct Sentinel<'s> {
    // Set while a job executes; tells `Drop` whether the idle count still
    // needs to be restored after a panic.
    is_working: bool,
    worker_ref: &'s Worker,
}
impl Sentinel<'_> {
    fn new(worker_ref: &Worker) -> Sentinel<'_> {
        Sentinel {
            is_working: false,
            worker_ref,
        }
    }
}
impl Drop for Sentinel<'_> {
    fn drop(&mut self) {
        // Only act on panics; a normal worker exit fixes its own counts.
        if thread::panicking() {
            if self.is_working {
                self.worker_ref.mark_idle_and_notify_joiners_if_no_work();
            }
            // Replace the dying worker with a fresh clone of its config.
            let worker = self.worker_ref.clone();
            worker.start(None);
        }
    }
}
/// Mask selecting the idle count in the low half of the packed word.
const WORKER_IDLE_MASK: usize = MAX_SIZE;
/// Adding this to the packed word increments the total count (high half).
const INCREMENT_TOTAL: usize = 1 << (BITS / 2);
/// Adding this to the packed word increments the idle count (low half).
const INCREMENT_IDLE: usize = 1;
/// Total and idle worker counts packed into a single atomic `usize`
/// (total in the high half, idle in the low half) so both can be read and
/// updated together in one atomic operation.
#[derive(Default)]
struct WorkerCountData {
    worker_count: AtomicUsize,
}
impl WorkerCountData {
/// Read the total (high-half) worker count.
fn get_total_worker_count(&self) -> usize {
    WorkerCountData::get_total_count(self.worker_count.load(Ordering::Relaxed))
}
/// Read the idle (low-half) worker count.
fn get_idle_worker_count(&self) -> usize {
    WorkerCountData::get_idle_count(self.worker_count.load(Ordering::Relaxed))
}
/// Read both counts from a single atomic load: `(total, idle)`.
fn get_both(&self) -> (usize, usize) {
    WorkerCountData::split(self.worker_count.load(Ordering::Relaxed))
}
/// Atomically increment both counts; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn increment_both(&self) -> (usize, usize) {
    let old_val = self
        .worker_count
        .fetch_add(INCREMENT_TOTAL | INCREMENT_IDLE, Ordering::Relaxed);
    WorkerCountData::split(old_val)
}
/// Atomically decrement both counts; returns the previous `(total, idle)`.
fn decrement_both(&self) -> (usize, usize) {
    let old_val = self
        .worker_count
        .fetch_sub(INCREMENT_TOTAL | INCREMENT_IDLE, Ordering::Relaxed);
    WorkerCountData::split(old_val)
}
/// CAS-increment the total count while it stays below `max_total`; returns
/// the last witnessed packed value (see `try_increment_worker_count`).
fn try_increment_worker_total(&self, expected: usize, max_total: usize) -> usize {
    self.try_increment_worker_count(expected, INCREMENT_TOTAL, max_total)
}
/// CAS loop adding `increment` to the packed count while the total count
/// stays below `max_total`.
///
/// Returns the packed value witnessed by the final CAS attempt: if it
/// equals the caller's original `expected`, the increment succeeded on the
/// first try; if its total is still below `max_total`, the increment
/// succeeded after retries; otherwise the limit was reached and the value
/// was left unchanged.
fn try_increment_worker_count(
    &self,
    mut expected: usize,
    increment: usize,
    max_total: usize,
) -> usize {
    loop {
        match self.worker_count.compare_exchange_weak(
            expected,
            expected + increment,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(witnessed) => return witnessed,
            // Give up once the total limit has been reached.
            Err(witnessed) if WorkerCountData::get_total_count(witnessed) >= max_total => {
                return witnessed
            }
            // Lost a race (or spurious weak-CAS failure): retry with the
            // freshly observed value.
            Err(witnessed) => expected = witnessed,
        }
    }
}
/// Increment the total count; returns the previous total.
#[allow(dead_code)]
fn increment_worker_total(&self) -> usize {
    let old_val = self
        .worker_count
        .fetch_add(INCREMENT_TOTAL, Ordering::Relaxed);
    WorkerCountData::get_total_count(old_val)
}
/// Increment the total count; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn increment_worker_total_ret_both(&self) -> (usize, usize) {
    let old_val = self
        .worker_count
        .fetch_add(INCREMENT_TOTAL, Ordering::Relaxed);
    WorkerCountData::split(old_val)
}
/// Decrement the total count; returns the previous total.
#[allow(dead_code)]
fn decrement_worker_total(&self) -> usize {
    let old_val = self
        .worker_count
        .fetch_sub(INCREMENT_TOTAL, Ordering::Relaxed);
    WorkerCountData::get_total_count(old_val)
}
/// Decrement the total count; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn decrement_worker_total_ret_both(&self) -> (usize, usize) {
    let old_val = self
        .worker_count
        .fetch_sub(INCREMENT_TOTAL, Ordering::Relaxed);
    WorkerCountData::split(old_val)
}
/// Increment the idle count; returns the previous idle count.
#[allow(dead_code)]
fn increment_worker_idle(&self) -> usize {
    let old_val = self
        .worker_count
        .fetch_add(INCREMENT_IDLE, Ordering::Relaxed);
    WorkerCountData::get_idle_count(old_val)
}
/// Increment the idle count; returns the previous `(total, idle)`.
fn increment_worker_idle_ret_both(&self) -> (usize, usize) {
    let old_val = self
        .worker_count
        .fetch_add(INCREMENT_IDLE, Ordering::Relaxed);
    WorkerCountData::split(old_val)
}
/// Decrement the idle count; returns the previous idle count.
fn decrement_worker_idle(&self) -> usize {
    let old_val = self
        .worker_count
        .fetch_sub(INCREMENT_IDLE, Ordering::Relaxed);
    WorkerCountData::get_idle_count(old_val)
}
/// Decrement the idle count; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn decrement_worker_idle_ret_both(&self) -> (usize, usize) {
    let old_val = self
        .worker_count
        .fetch_sub(INCREMENT_IDLE, Ordering::Relaxed);
    WorkerCountData::split(old_val)
}
#[inline]
fn split(val: usize) -> (usize, usize) {
let total_count = val >> (BITS / 2);
let idle_count = val & WORKER_IDLE_MASK;
(total_count, idle_count)
}
#[inline]
fn get_total_count(val: usize) -> usize {
val >> (BITS / 2)
}
#[inline]
fn get_idle_count(val: usize) -> usize {
val & WORKER_IDLE_MASK
}
}
/// Bookkeeping shared by every worker of a pool.
struct WorkerData {
    pool_name: String,
    // Packed total/idle worker counts.
    worker_count_data: WorkerCountData,
    // Monotonic counter used to number worker threads for their names.
    worker_number: AtomicUsize,
    // Condvar/mutex pair used to wake threads blocked in `join`.
    join_notify_condvar: Condvar,
    join_notify_mutex: Mutex<()>,
    // Bumped after each completed join so concurrent joiners of the same
    // generation are released together.
    join_generation: AtomicUsize,
}
/// The job channel, shared behind an `Arc` by every clone of the pool;
/// workers receive clones of `receiver`.
struct ChannelData {
    sender: crossbeam_channel::Sender<Job>,
    receiver: crossbeam_channel::Receiver<Job>,
}
#[cfg(test)]
mod tests {
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::thread;
use std::time::Duration;
use super::Builder;
use super::ThreadPool;
use super::WorkerCountData;
#[test]
fn it_works() {
let pool = ThreadPool::new(2, 10, Duration::from_secs(5));
let count = Arc::new(AtomicUsize::new(0));
let count1 = count.clone();
pool.execute(move || {
count1.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
let count2 = count.clone();
pool.execute(move || {
count2.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
let count3 = count.clone();
pool.execute(move || {
count3.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
let count4 = count.clone();
pool.execute(move || {
count4.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
thread::sleep(std::time::Duration::from_secs(20));
let count5 = count.clone();
pool.execute(move || {
count5.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
let count6 = count.clone();
pool.execute(move || {
count6.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
let count7 = count.clone();
pool.execute(move || {
count7.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
let count8 = count.clone();
pool.execute(move || {
count8.fetch_add(1, Ordering::Relaxed);
thread::sleep(std::time::Duration::from_secs(4));
});
thread::sleep(std::time::Duration::from_secs(20));
let count = count.load(Ordering::Relaxed);
let worker_count = pool.get_current_worker_count();
assert_eq!(count, 8);
assert_eq!(worker_count, 2);
assert_eq!(pool.get_idle_worker_count(), 2);
}
#[test]
#[ignore]
fn stress_test() {
let pool = Arc::new(ThreadPool::new(3, 50, Duration::from_secs(30)));
let counter = Arc::new(AtomicUsize::new(0));
for _ in 0..5 {
let pool_1 = pool.clone();
let clone = counter.clone();
pool.execute(move || {
for _ in 0..160 {
let clone = clone.clone();
pool_1.execute(move || {
clone.fetch_add(1, Ordering::Relaxed);
thread::sleep(Duration::from_secs(10));
});
}
thread::sleep(Duration::from_secs(20));
for _ in 0..160 {
let clone = clone.clone();
pool_1.execute(move || {
clone.fetch_add(1, Ordering::Relaxed);
thread::sleep(Duration::from_secs(10));
});
}
});
}
thread::sleep(Duration::from_secs(10));
assert_eq!(pool.get_current_worker_count(), 50);
pool.join();
assert_eq!(counter.load(Ordering::Relaxed), 1600);
thread::sleep(Duration::from_secs(31));
assert_eq!(pool.get_current_worker_count(), 3);
}
#[test]
fn test_join() {
let pool = ThreadPool::new(0, 1, Duration::from_secs(5));
let counter = Arc::new(AtomicUsize::new(0));
let clone_1 = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(5));
clone_1.fetch_add(1, Ordering::Relaxed);
});
let clone_2 = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(5));
clone_2.fetch_add(1, Ordering::Relaxed);
});
pool.join();
assert_eq!(counter.load(Ordering::Relaxed), 2);
}
#[test]
fn test_join_timeout() {
let pool = ThreadPool::new(0, 1, Duration::from_secs(5));
let counter = Arc::new(AtomicUsize::new(0));
let clone = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(10));
clone.fetch_add(1, Ordering::Relaxed);
});
pool.join_timeout(Duration::from_secs(5));
assert_eq!(counter.load(Ordering::Relaxed), 0);
pool.join();
assert_eq!(counter.load(Ordering::Relaxed), 1);
}
#[test]
fn test_shutdown() {
let pool = ThreadPool::new(1, 3, Duration::from_secs(5));
let counter = Arc::new(AtomicUsize::new(0));
let clone_1 = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(5));
clone_1.fetch_add(1, Ordering::Relaxed);
});
let clone_2 = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(5));
clone_2.fetch_add(1, Ordering::Relaxed);
});
let clone_3 = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(5));
clone_3.fetch_add(1, Ordering::Relaxed);
});
let clone_4 = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(5));
clone_4.fetch_add(1, Ordering::Relaxed);
});
pool.join_timeout(Duration::from_secs(2));
pool.shutdown();
thread::sleep(Duration::from_secs(5));
assert_eq!(counter.load(Ordering::Relaxed), 3);
}
#[should_panic(
expected = "max_size must be greater than 0 and greater or equal to the core pool size"
)]
#[test]
fn test_panic_on_0_max_pool_size() {
ThreadPool::new(0, 0, Duration::from_secs(2));
}
#[should_panic(
expected = "max_size must be greater than 0 and greater or equal to the core pool size"
)]
#[test]
fn test_panic_on_smaller_max_than_core_pool_size() {
ThreadPool::new(10, 4, Duration::from_secs(2));
}
#[should_panic(expected = "max_size may not exceed")]
#[test]
fn test_panic_on_max_size_exceeds_half_usize() {
ThreadPool::new(
10,
1 << ((std::mem::size_of::<usize>() * 8) / 2),
Duration::from_secs(2),
);
}
#[test]
fn test_empty_join() {
let pool = ThreadPool::new(3, 10, Duration::from_secs(10));
pool.join();
}
#[test]
fn test_join_when_complete() {
let pool = ThreadPool::new(3, 10, Duration::from_secs(5));
pool.execute(|| {
thread::sleep(Duration::from_millis(5000));
});
thread::sleep(Duration::from_millis(5000));
pool.join();
}
#[test]
fn test_full_usage() {
let pool = ThreadPool::new(5, 50, Duration::from_secs(10));
for _ in 0..100 {
pool.execute(|| {
thread::sleep(Duration::from_secs(30));
});
}
thread::sleep(Duration::from_secs(10));
assert_eq!(pool.get_current_worker_count(), 50);
pool.join();
thread::sleep(Duration::from_secs(15));
assert_eq!(pool.get_current_worker_count(), 5);
}
#[test]
fn test_shutdown_join() {
let pool = ThreadPool::new(1, 1, Duration::from_secs(5));
let counter = Arc::new(AtomicUsize::new(0));
let clone = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(10));
clone.fetch_add(1, Ordering::Relaxed);
});
pool.shutdown_join();
assert_eq!(counter.load(Ordering::Relaxed), 1);
}
#[test]
fn test_shutdown_join_timeout() {
let pool = ThreadPool::new(1, 1, Duration::from_secs(5));
let counter = Arc::new(AtomicUsize::new(0));
let clone = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(10));
clone.fetch_add(1, Ordering::Relaxed);
});
pool.shutdown_join_timeout(Duration::from_secs(5));
assert_eq!(counter.load(Ordering::Relaxed), 0);
}
#[test]
fn test_empty_shutdown_join() {
let pool = ThreadPool::new(1, 5, Duration::from_secs(5));
pool.shutdown_join();
}
#[test]
fn test_shutdown_core_pool() {
let pool = ThreadPool::new(5, 5, Duration::from_secs(1));
let counter = Arc::new(AtomicUsize::new(0));
let worker_data = pool.worker_data.clone();
for _ in 0..7 {
let clone = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(2));
clone.fetch_add(1, Ordering::Relaxed);
});
}
assert_eq!(pool.get_current_worker_count(), 5);
assert_eq!(pool.get_idle_worker_count(), 0);
pool.shutdown_join();
assert_eq!(counter.load(Ordering::Relaxed), 7);
thread::sleep(Duration::from_millis(50));
assert_eq!(worker_data.worker_count_data.get_total_worker_count(), 0);
assert_eq!(worker_data.worker_count_data.get_idle_worker_count(), 0);
}
#[test]
fn test_shutdown_idle_core_pool() {
let pool = ThreadPool::new(5, 5, Duration::from_secs(1));
let counter = Arc::new(AtomicUsize::new(0));
let worker_data = pool.worker_data.clone();
for _ in 0..5 {
let clone = counter.clone();
pool.execute(move || {
clone.fetch_add(1, Ordering::Relaxed);
});
}
pool.shutdown_join();
assert_eq!(counter.load(Ordering::Relaxed), 5);
thread::sleep(Duration::from_millis(50));
assert_eq!(worker_data.worker_count_data.get_total_worker_count(), 0);
assert_eq!(worker_data.worker_count_data.get_idle_worker_count(), 0);
}
#[test]
fn test_shutdown_on_complete() {
let pool = ThreadPool::new(3, 10, Duration::from_secs(5));
pool.execute(|| {
thread::sleep(Duration::from_millis(5000));
});
thread::sleep(Duration::from_millis(5000));
pool.shutdown_join();
}
#[test]
fn test_shutdown_after_complete() {
let pool = ThreadPool::new(3, 10, Duration::from_secs(5));
pool.execute(|| {
thread::sleep(Duration::from_millis(5000));
});
thread::sleep(Duration::from_millis(7000));
pool.shutdown_join();
}
#[test]
fn worker_count_test() {
let worker_count_data = WorkerCountData::default();
assert_eq!(worker_count_data.get_total_worker_count(), 0);
assert_eq!(worker_count_data.get_idle_worker_count(), 0);
worker_count_data.increment_both();
assert_eq!(worker_count_data.get_total_worker_count(), 1);
assert_eq!(worker_count_data.get_idle_worker_count(), 1);
for _ in 0..10 {
worker_count_data.increment_both();
}
assert_eq!(worker_count_data.get_total_worker_count(), 11);
assert_eq!(worker_count_data.get_idle_worker_count(), 11);
for _ in 0..15 {
worker_count_data.increment_worker_total();
}
for _ in 0..7 {
worker_count_data.increment_worker_idle();
}
assert_eq!(worker_count_data.get_total_worker_count(), 26);
assert_eq!(worker_count_data.get_idle_worker_count(), 18);
assert_eq!(worker_count_data.get_both(), (26, 18));
for _ in 0..5 {
worker_count_data.decrement_both();
}
assert_eq!(worker_count_data.get_total_worker_count(), 21);
assert_eq!(worker_count_data.get_idle_worker_count(), 13);
for _ in 0..13 {
worker_count_data.decrement_worker_total();
}
for _ in 0..4 {
worker_count_data.decrement_worker_idle();
}
assert_eq!(worker_count_data.get_total_worker_count(), 8);
assert_eq!(worker_count_data.get_idle_worker_count(), 9);
for _ in 0..456789 {
worker_count_data.increment_worker_total();
}
assert_eq!(worker_count_data.get_total_worker_count(), 456797);
assert_eq!(worker_count_data.get_idle_worker_count(), 9);
assert_eq!(worker_count_data.get_both(), (456797, 9));
for _ in 0..23456 {
worker_count_data.increment_worker_idle();
}
assert_eq!(worker_count_data.get_total_worker_count(), 456797);
assert_eq!(worker_count_data.get_idle_worker_count(), 23465);
for _ in 0..150000 {
worker_count_data.decrement_worker_total();
}
assert_eq!(worker_count_data.get_total_worker_count(), 306797);
assert_eq!(worker_count_data.get_idle_worker_count(), 23465);
for _ in 0..10000 {
worker_count_data.decrement_worker_idle();
}
assert_eq!(worker_count_data.get_total_worker_count(), 306797);
assert_eq!(worker_count_data.get_idle_worker_count(), 13465);
}
#[test]
fn test_try_increment_worker_total() {
let worker_count_data = WorkerCountData::default();
let witness = worker_count_data.try_increment_worker_total(0, 5);
assert_eq!(witness, 0);
assert_eq!(worker_count_data.get_total_worker_count(), 1);
assert_eq!(worker_count_data.get_idle_worker_count(), 0);
let witness = worker_count_data.try_increment_worker_total(0, 5);
assert_eq!(witness, 0x0000_0001_0000_0000);
assert_eq!(worker_count_data.get_total_worker_count(), 2);
assert_eq!(worker_count_data.get_idle_worker_count(), 0);
worker_count_data.try_increment_worker_total(2, 5);
worker_count_data.try_increment_worker_total(2, 5);
worker_count_data.try_increment_worker_total(4, 5);
worker_count_data.try_increment_worker_total(4, 5);
let witness = worker_count_data.try_increment_worker_total(2, 5);
assert_eq!(WorkerCountData::get_total_count(witness), 5);
assert_eq!(WorkerCountData::get_idle_count(witness), 0);
assert_eq!(worker_count_data.get_total_worker_count(), 5);
assert_eq!(worker_count_data.get_idle_worker_count(), 0);
let worker_count_data = Arc::new(worker_count_data);
let mut join_handles = Vec::with_capacity(5);
for _ in 0..5 {
let worker_count_data = worker_count_data.clone();
let join_handle = thread::spawn(move || {
for i in 0..5 {
worker_count_data.try_increment_worker_total(5 + i, 15);
}
});
join_handles.push(join_handle);
}
for join_handle in join_handles {
join_handle.join().unwrap();
}
assert_eq!(worker_count_data.get_total_worker_count(), 15);
assert_eq!(worker_count_data.get_idle_worker_count(), 0);
}
#[test]
fn test_join_enqueued_task() {
let pool = ThreadPool::new(3, 50, Duration::from_secs(20));
let counter = Arc::new(AtomicUsize::new(0));
for _ in 0..160 {
let clone = counter.clone();
pool.execute(move || {
thread::sleep(Duration::from_secs(10));
clone.fetch_add(1, Ordering::Relaxed);
});
}
thread::sleep(Duration::from_secs(5));
assert_eq!(pool.get_current_worker_count(), 50);
pool.join();
assert_eq!(counter.load(Ordering::Relaxed), 160);
thread::sleep(Duration::from_secs(21));
assert_eq!(pool.get_current_worker_count(), 3);
}
#[test]
fn test_panic_all() {
let pool = ThreadPool::new(3, 10, Duration::from_secs(2));
for _ in 0..10 {
pool.execute(|| {
panic!("test");
})
}
pool.join();
thread::sleep(Duration::from_secs(5));
assert_eq!(pool.get_current_worker_count(), 3);
assert_eq!(pool.get_idle_worker_count(), 3);
}
#[test]
fn test_panic_some() {
let pool = ThreadPool::new(3, 10, Duration::from_secs(5));
let counter = Arc::new(AtomicUsize::new(0));
for i in 0..10 {
let clone = counter.clone();
pool.execute(move || {
if i < 3 || i % 2 == 0 {
thread::sleep(Duration::from_secs(5));
clone.fetch_add(1, Ordering::Relaxed);
} else {
thread::sleep(Duration::from_secs(5));
panic!("test");
}
})
}
pool.join();
assert_eq!(counter.load(Ordering::Relaxed), 6);
assert_eq!(pool.get_current_worker_count(), 10);
assert_eq!(pool.get_idle_worker_count(), 10);
thread::sleep(Duration::from_secs(10));
assert_eq!(pool.get_current_worker_count(), 3);
assert_eq!(pool.get_idle_worker_count(), 3);
}
#[test]
fn test_panic_all_core_threads() {
let pool = ThreadPool::new(3, 3, Duration::from_secs(1));
let counter = Arc::new(AtomicUsize::new(0));
for _ in 0..3 {
pool.execute(|| {
panic!("test");
})
}
pool.join();
for i in 0..10 {
let clone = counter.clone();
pool.execute(move || {
if i < 3 || i % 2 == 0 {
clone.fetch_add(1, Ordering::Relaxed);
} else {
thread::sleep(Duration::from_secs(5));
panic!("test");
}
})
}
pool.join();
assert_eq!(counter.load(Ordering::Relaxed), 6);
assert_eq!(pool.get_current_worker_count(), 3);
assert_eq!(pool.get_idle_worker_count(), 3);
}
#[test]
fn test_drop_all_receivers() {
    // With core_size 0 every worker eventually expires. A pool whose worker
    // count has dropped to zero must still accept and run a new batch.
    let pool = ThreadPool::new(0, 3, Duration::from_secs(5));
    let counter = Arc::new(AtomicUsize::new(0));
    // submits three increment jobs; shared by both rounds of the test
    let submit_batch = |counter: &Arc<AtomicUsize>| {
        for _ in 0..3 {
            let counter = counter.clone();
            pool.execute(move || {
                counter.fetch_add(1, Ordering::Relaxed);
            })
        }
    };
    submit_batch(&counter);
    pool.join();
    assert_eq!(counter.load(Ordering::Relaxed), 3);
    // wait past the 5 s keep-alive so every worker shuts down
    thread::sleep(Duration::from_secs(10));
    assert_eq!(pool.get_current_worker_count(), 0);
    submit_batch(&counter);
    pool.join();
    assert_eq!(counter.load(Ordering::Relaxed), 6);
}
#[test]
fn test_evaluate() {
    // evaluate() returns a JoinHandle for the closure's result. fetch_add
    // returns the *previous* value, so the second add yields 1.
    let pool = ThreadPool::new(0, 3, Duration::from_secs(5));
    let count = AtomicUsize::new(0);
    let handle = pool.evaluate(move || {
        count.fetch_add(1, Ordering::Relaxed);
        thread::sleep(Duration::from_secs(5));
        count.fetch_add(1, Ordering::Relaxed)
    });
    assert_eq!(handle.await_complete(), 1);
}
#[test]
fn test_multiple_evaluate() {
    // Chain two evaluations: the second task blocks on the first task's
    // JoinHandle from inside the pool and builds on its result.
    let pool = ThreadPool::new(0, 3, Duration::from_secs(5));
    let count = AtomicUsize::new(0);
    let first = pool.evaluate(move || {
        for _ in 0..10000 {
            count.fetch_add(1, Ordering::Relaxed);
        }
        thread::sleep(Duration::from_secs(5));
        for _ in 0..10000 {
            count.fetch_add(1, Ordering::Relaxed);
        }
        count.load(Ordering::Relaxed)
    });
    let second = pool.evaluate(move || {
        // 20000 from the first task, plus 15000 + 20000 added here
        let mut total = first.await_complete();
        total += 15000;
        thread::sleep(Duration::from_secs(5));
        total += 20000;
        total
    });
    assert_eq!(second.await_complete(), 55000);
}
#[should_panic(expected = "could not receive message because channel was cancelled")]
#[test]
fn test_evaluate_panic() {
    // A panicking task drops its oneshot sender, so await_complete() must
    // panic with the "channel was cancelled" message from JoinHandle.
    let pool = Builder::new().core_size(5).max_size(50).build();
    let handle = pool.evaluate(|| {
        let x = 3;
        if x == 3 {
            panic!("expected panic")
        }
        x
    });
    handle.await_complete();
}
#[test]
fn test_complete_fut() {
    // complete() runs the given future to completion on the pool and hands
    // its output back through a JoinHandle.
    let pool = ThreadPool::new(0, 3, Duration::from_secs(5));
    async fn async_fn() -> i8 {
        8
    }
    let handle = pool.complete(async_fn());
    assert_eq!(handle.await_complete(), 8);
}
#[cfg(feature = "async")]
#[test]
fn test_spawn() {
    // spawn() polls the future on pool workers; join() must not return
    // until the future has completed and incremented the counter.
    let pool = ThreadPool::default();
    async fn add(x: i32, y: i32) -> i32 {
        x + y
    }
    async fn multiply(x: i32, y: i32) -> i32 {
        x * y
    }
    let count = Arc::new(AtomicUsize::new(0));
    let clone = count.clone();
    pool.spawn(async move {
        // a = 5, b = 7, c = 14, d = 15, e = 29
        let a = add(2, 3).await;
        let b = add(2, a).await;
        let c = multiply(2, b).await;
        let d = multiply(a, add(2, 1).await).await;
        let e = add(c, d).await;
        clone.fetch_add(e as usize, Ordering::Relaxed);
    });
    pool.join();
    assert_eq!(count.load(Ordering::Relaxed), 29);
}
#[cfg(feature = "async")]
#[test]
fn test_spawn_await() {
    // spawn_await() schedules the future on the pool and returns a
    // JoinHandle for its output.
    let pool = ThreadPool::default();
    async fn sub(x: i32, y: i32) -> i32 {
        x - y
    }
    async fn div(x: i32, y: i32) -> i32 {
        x / y
    }
    let handle = pool.spawn_await(async {
        // a = 110, b = 100 / 4 = 25, result = (25 - 5) / 5 = 4
        let a = sub(120, 10).await;
        let b = div(sub(a, 10).await, 4).await;
        div(sub(b, div(10, 2).await).await, 5).await
    });
    assert_eq!(handle.await_complete(), 4)
}
#[test]
fn test_drop_oneshot_receiver() {
    // Drop the JoinHandle before the task finishes: the worker's send into
    // the cancelled oneshot channel must not bring the worker down.
    let pool = Builder::new().core_size(1).max_size(1).build();
    let handle = pool.evaluate(|| {
        thread::sleep(Duration::from_secs(5));
        5
    });
    drop(handle);
    thread::sleep(Duration::from_secs(10));
    // NOTE(review): worker_number appears to be the index handed to the
    // next spawned worker, so 2 means exactly one worker was ever created
    // (i.e. no panic-respawn occurred) — confirm against the Worker impl
    let current_thread_index = pool.worker_data.worker_number.load(Ordering::Relaxed);
    assert_eq!(current_thread_index, 2);
}
#[test]
fn test_builder_max_size() {
    // Building with only max_size configured must not panic.
    let pool = Builder::new().max_size(1).build();
    drop(pool);
}
#[test]
fn test_multi_thread_join() {
    // Three jobs are queued up front, a fourth is submitted from another
    // thread while three separate threads all block in join(); every join()
    // must return only after all four jobs have run (count == 1+1+1+2).
    let pool = ThreadPool::default();
    let count = Arc::new(AtomicUsize::new(0));
    for _ in 0..3 {
        let count = count.clone();
        pool.execute(move || {
            thread::sleep(Duration::from_secs(10));
            count.fetch_add(1, Ordering::Relaxed);
        });
    }
    // a late submission racing against the join() calls below
    let submitter_pool = pool.clone();
    let submitter_count = count.clone();
    thread::spawn(move || {
        thread::sleep(Duration::from_secs(5));
        submitter_pool.execute(move || {
            thread::sleep(Duration::from_secs(15));
            submitter_count.fetch_add(2, Ordering::Relaxed);
        });
    });
    let joiners: Vec<_> = (0..3)
        .map(|_| {
            let pool = pool.clone();
            thread::spawn(move || pool.join())
        })
        .collect();
    for joiner in joiners {
        joiner.join().unwrap();
    }
    assert_eq!(count.load(Ordering::Relaxed), 5);
}
#[test]
fn test_start_core_threads() {
    // start_core_threads() eagerly spawns every core worker; with no work
    // queued they all sit idle.
    let pool = Builder::new().core_size(5).build();
    pool.start_core_threads();
    let (current, idle) = (
        pool.get_current_worker_count(),
        pool.get_idle_worker_count(),
    );
    assert_eq!(current, 5);
    assert_eq!(idle, 5);
}
#[test]
fn test_start_and_use_core_threads() {
    // With an effectively infinite keep-alive, pre-started core workers
    // must stay alive and be reused for submitted work.
    let pool = Builder::new()
        .core_size(5)
        .max_size(10)
        .keep_alive(Duration::from_secs(u64::MAX))
        .build();
    pool.start_core_threads();
    assert_eq!(pool.evaluate(|| 5 + 5).await_complete(), 10);
    assert_eq!(pool.get_current_worker_count(), 5);
}
}