use crossbeam::{channel, crossbeam_channel};
use std::option::Option;
use std::sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
mpsc, Arc, Condvar, Mutex,
};
use std::thread;
use std::time::Duration;
/// A unit of work submitted to the pool: a boxed closure that is sent to a
/// worker thread and executed at most once.
type Task = Box<dyn FnOnce() + Send + 'static>;
/// Error returned by `ThreadPool::execute` when a task cannot be accepted.
pub enum ExecuteError<T> {
    /// The underlying task channel is disconnected; the rejected payload is
    /// carried inside the send error.
    ChannelClosedError(channel::SendError<T>),
    /// The pool has been shut down and accepts no new tasks.
    ThreadPoolClosedError,
}

// Manual `Debug` (instead of `#[derive]`) so the impl exists even when `T`
// is not `Debug` — in practice `T` is a boxed closure, which never is.
// Without this, callers cannot `unwrap()`/`expect()` the `execute` result.
impl<T> std::fmt::Debug for ExecuteError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecuteError::ChannelClosedError(_) => f.write_str("ChannelClosedError(..)"),
            ExecuteError::ThreadPoolClosedError => f.write_str("ThreadPoolClosedError"),
        }
    }
}

impl<T> std::fmt::Display for ExecuteError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecuteError::ChannelClosedError(_) => f.write_str("task channel is closed"),
            ExecuteError::ThreadPoolClosedError => f.write_str("thread pool has been shut down"),
        }
    }
}

impl<T> std::error::Error for ExecuteError<T> {}
/// A thread pool that keeps `core_size` workers alive permanently and grows
/// up to `max_size` temporary workers under load; temporary workers exit
/// after idling for `keep_alive`.
pub struct ThreadPool {
    // Set to true by `shutdown`; `execute` rejects tasks once set.
    is_shutdown: AtomicBool,
    // Number of workers kept alive indefinitely (they never time out).
    core_size: usize,
    // Upper bound on concurrently live workers (core + temporary).
    max_size: usize,
    // Idle duration after which a temporary (non-core) worker exits.
    keep_alive: Duration,
    // Count of live worker threads; shared with every worker.
    worker_count: Arc<AtomicUsize>,
    // Count of workers currently blocked waiting for a task.
    idle_worker_count: Arc<AtomicUsize>,
    // Producer side of the unbounded task queue.
    sender: channel::Sender<Task>,
    // Consumer side, shared by all workers.
    // NOTE(review): crossbeam receivers are `Clone`, so the `Arc` wrapper is
    // presumably for cheap sharing without per-worker clone bookkeeping —
    // confirm intent before changing.
    receiver: Arc<channel::Receiver<Task>>,
    // Workers signal this condvar when the queue drains; `join` waits on it.
    join_notify_condvar: Arc<Condvar>,
    // Mutex paired with the condvar (guards no data, only the wait/notify).
    join_notify_mutex: Arc<Mutex<()>>,
}
impl ThreadPool {
    /// Creates a new pool with `core_size` permanent workers, at most
    /// `max_size` total workers, and a `keep_alive` idle timeout for
    /// temporary workers.
    ///
    /// # Panics
    /// Panics if `max_size` is 0 or smaller than `core_size`.
    pub fn new(core_size: usize, max_size: usize, keep_alive: Duration) -> Self {
        // Validate before allocating the channel so an invalid configuration
        // fails fast with no side effects (previously the channel was built
        // first and then discarded on panic).
        if max_size == 0 || max_size < core_size {
            panic!("max_size must be greater than 0 and greater or equal to the core pool size");
        }
        let (sender, receiver) = crossbeam_channel::unbounded();
        Self {
            is_shutdown: AtomicBool::new(false),
            core_size,
            max_size,
            keep_alive,
            worker_count: Arc::new(AtomicUsize::new(0)),
            idle_worker_count: Arc::new(AtomicUsize::new(0)),
            sender,
            receiver: Arc::new(receiver),
            join_notify_condvar: Arc::new(Condvar::new()),
            join_notify_mutex: Arc::new(Mutex::new(())),
        }
    }

    /// Returns the number of currently live worker threads.
    pub fn get_current_worker_count(&self) -> usize {
        self.worker_count.load(Ordering::SeqCst)
    }

    /// Returns the number of workers currently waiting for a task.
    pub fn get_idle_worker_count(&self) -> usize {
        self.idle_worker_count.load(Ordering::SeqCst)
    }

    /// Submits a task for asynchronous execution.
    ///
    /// After enqueueing, a new core worker is spawned if the pool is below
    /// `core_size`; otherwise a temporary worker is spawned if every worker
    /// is busy and the pool is below `max_size`.
    ///
    /// # Errors
    /// Returns `ThreadPoolClosedError` if the pool was shut down, or
    /// `ChannelClosedError` if the task channel is disconnected.
    pub fn execute<T: FnOnce() + Send + 'static>(
        &self,
        task: T,
    ) -> Result<(), ExecuteError<Box<dyn FnOnce() + Send + 'static>>> {
        if self.is_shutdown.load(Ordering::SeqCst) {
            return Err(ExecuteError::ThreadPoolClosedError);
        }
        self.sender
            .send(Box::new(task))
            .map_err(ExecuteError::ChannelClosedError)?;
        let curr_worker_count = self.worker_count.load(Ordering::SeqCst);
        if curr_worker_count < self.core_size {
            // Below the core size: always try to add a permanent worker.
            self.create_worker(true, |old_val| {
                old_val < self.core_size && !self.is_shutdown.load(Ordering::SeqCst)
            });
        } else if curr_worker_count < self.max_size
            && self.idle_worker_count.load(Ordering::SeqCst) == 0
        {
            // Every worker is busy: add a temporary worker, capped at max_size.
            self.create_worker(false, |old_val| {
                old_val < self.max_size
                    && self.idle_worker_count.load(Ordering::SeqCst) == 0
                    && !self.is_shutdown.load(Ordering::SeqCst)
            })
        }
        Ok(())
    }

    /// Blocks until a worker signals that the task queue has drained.
    /// Returns immediately if all workers are already idle.
    pub fn join(&self) {
        self.join_internal(None);
    }

    /// Like [`ThreadPool::join`], but gives up after `time_out` even if
    /// tasks are still queued or running.
    pub fn join_timeout(&self, time_out: Duration) {
        self.join_internal(Some(time_out));
    }

    /// Shuts the pool down: new tasks are rejected and, once `self` is
    /// dropped, the sender side of the queue closes so workers terminate
    /// after draining the remaining queued tasks.
    pub fn shutdown(self) {
        self.is_shutdown.store(true, Ordering::SeqCst);
        // Dropping `self` drops the task sender; workers observe the
        // disconnected channel and exit.
        drop(self);
    }

    /// Spawns a worker thread, then re-checks `recheck_condition` against the
    /// pre-increment worker count before sending it a "green light". A worker
    /// that loses this race is told to exit immediately and the count is
    /// rolled back — this is what bounds the pool under concurrent `execute`s.
    fn create_worker<T: FnOnce(usize) -> bool>(&self, is_core: bool, recheck_condition: T) {
        let (green_light_sender, green_light_receiver) = mpsc::channel();
        Worker::new(
            Arc::clone(&self.receiver),
            green_light_receiver,
            Arc::clone(&self.worker_count),
            Arc::clone(&self.idle_worker_count),
            !is_core,
            // Only temporary workers carry an idle timeout.
            if is_core { None } else { Some(self.keep_alive) },
            Arc::clone(&self.join_notify_condvar),
            Arc::clone(&self.join_notify_mutex),
        );
        let old_val = self.worker_count.fetch_add(1, Ordering::SeqCst);
        if recheck_condition(old_val) {
            green_light_sender
                .send(true)
                .expect("failed to send green light signal to worker");
        } else {
            // Lost the race (pool filled up or shut down meanwhile): cancel
            // the freshly spawned worker and undo the count increment.
            green_light_sender
                .send(false)
                .expect("failed to send denied green light signal to worker");
            self.worker_count.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Shared implementation of `join`/`join_timeout`.
    ///
    /// NOTE(review): the idle/worker-count check and the condvar wait are not
    /// performed atomically, and the wait is not re-checked in a loop — a
    /// notification arriving between the check and the wait, or a spurious
    /// wakeup, can make this block longer or return earlier than intended.
    /// Confirm whether stronger join semantics are required before relying
    /// on this in production.
    fn join_internal(&self, time_out: Option<Duration>) {
        let current_worker_count = self.worker_count.load(Ordering::SeqCst);
        let current_idle_count = self.idle_worker_count.load(Ordering::SeqCst);
        // Nothing in flight: every live worker is idle, so return at once.
        if current_idle_count == current_worker_count {
            return;
        }
        let guard = self
            .join_notify_mutex
            .lock()
            .expect("could not get join notify mutex lock");
        match time_out {
            Some(time_out) => {
                let _ret_lock = self
                    .join_notify_condvar
                    .wait_timeout(guard, time_out)
                    .expect("could not wait for join condvar");
            }
            None => {
                let _ret_lock = self
                    .join_notify_condvar
                    .wait(guard)
                    .expect("could not wait for join condvar");
            }
        };
    }
}
// Zero-sized marker type: a `Worker` owns no state itself; all worker state
// lives in the spawned thread's closure (see `Worker::new`).
struct Worker;
impl Worker {
    /// Spawns a worker thread that pulls tasks off the shared `receiver`.
    ///
    /// The thread first blocks on `green_light_receiver`: the pool sends
    /// `true` once it has confirmed (after incrementing the worker count)
    /// that this worker is still wanted, or `false` to cancel it before it
    /// touches any shared counter. Temporary workers (`can_timeout`) exit
    /// after `keep_alive` without work; core workers block indefinitely and
    /// exit only when the channel disconnects (pool drop/shutdown).
    ///
    /// NOTE(review): the `JoinHandle` from `thread::spawn` is discarded, so
    /// the thread is detached and panics inside tasks go unnoticed — confirm
    /// this is intended.
    fn new(
        receiver: Arc<channel::Receiver<Task>>,
        green_light_receiver: mpsc::Receiver<bool>,
        worker_count: Arc<AtomicUsize>,
        idle_worker_count: Arc<AtomicUsize>,
        can_timeout: bool,
        keep_alive: Option<Duration>,
        join_notify_condvar: Arc<Condvar>,
        join_notify_mutex: Arc<Mutex<()>>,
    ) -> Self {
        thread::spawn(move || {
            // Wait for the pool's go/no-go decision. Anything other than
            // Ok(true) — including a dropped sender — aborts the worker
            // before it changes any counter (the pool undoes worker_count).
            match green_light_receiver.recv() {
                Ok(true) => {
                }
                _ => {
                    return;
                }
            }
            // From here on the worker counts as idle until it picks up work.
            idle_worker_count.fetch_add(1, Ordering::SeqCst);
            loop {
                // Core workers block forever; temporary workers give up after
                // `keep_alive`. Both the timeout and a disconnect are mapped
                // to `Err(())`, which ends the loop below.
                let received_task: Result<Task, _> =
                    if can_timeout {
                        receiver
                            .recv_timeout(keep_alive.expect(
                                "keep_alive duration is NONE despite can_timeout being true",
                            ))
                            .map_err(|_| ())
                    } else {
                        receiver.recv().map_err(|_| ())
                    };
                match received_task {
                    Ok(task) => {
                        // Busy while running the task; idle again afterwards.
                        idle_worker_count.fetch_sub(1, Ordering::SeqCst);
                        task();
                        idle_worker_count.fetch_add(1, Ordering::SeqCst);
                        // Queue drained: wake any thread blocked in `join`.
                        // The mutex is taken (and immediately dropped) so the
                        // notify cannot slip between a joiner's lock and wait.
                        // NOTE(review): this fires whenever the queue is
                        // empty, even if other workers are still mid-task, so
                        // `join` may wake before all work finishes — confirm
                        // intended.
                        if receiver.is_empty() {
                            let _lock = join_notify_mutex
                                .lock()
                                .expect("could not get join notify mutex lock");
                            join_notify_condvar.notify_all();
                        }
                    }
                    Err(_) => {
                        // Timed out or channel disconnected: this worker is
                        // done, so it no longer counts as idle.
                        idle_worker_count.fetch_sub(1, Ordering::SeqCst);
                        break;
                    }
                }
            }
            // Thread exiting: remove it from the live-worker count.
            worker_count.fetch_sub(1, Ordering::SeqCst);
        });
        Worker
    }
}
#[allow(unused_must_use)]
#[cfg(test)]
mod tests {
    use super::ThreadPool;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use std::thread;
    use std::time::Duration;

    /// Submits `n` tasks that each increment `counter` immediately and then
    /// sleep for `sleep_secs` (the increment records the task *starting*).
    fn submit_count_then_sleep(
        pool: &ThreadPool,
        counter: &Arc<AtomicUsize>,
        n: usize,
        sleep_secs: u64,
    ) {
        for _ in 0..n {
            let clone = counter.clone();
            pool.execute(move || {
                clone.fetch_add(1, Ordering::SeqCst);
                thread::sleep(Duration::from_secs(sleep_secs));
            });
        }
    }

    /// Submits `n` tasks that sleep for `sleep_secs` first and increment
    /// `counter` last (the increment records the task *completing*).
    fn submit_sleep_then_count(
        pool: &ThreadPool,
        counter: &Arc<AtomicUsize>,
        n: usize,
        sleep_secs: u64,
    ) {
        for _ in 0..n {
            let clone = counter.clone();
            pool.execute(move || {
                thread::sleep(Duration::from_secs(sleep_secs));
                clone.fetch_add(1, Ordering::SeqCst);
            });
        }
    }

    #[test]
    fn it_works() {
        let pool = ThreadPool::new(2, 10, Duration::from_secs(5));
        let count = Arc::new(AtomicUsize::new(0));
        submit_count_then_sleep(&pool, &count, 4, 4);
        thread::sleep(Duration::from_secs(20));
        submit_count_then_sleep(&pool, &count, 4, 4);
        thread::sleep(Duration::from_secs(20));
        assert_eq!(count.load(Ordering::SeqCst), 8);
        // keep_alive has long expired, so every temporary worker should be
        // gone: only the 2 core workers remain, both idle.
        assert_eq!(pool.get_current_worker_count(), 2);
        assert_eq!(pool.get_idle_worker_count(), 2);
    }

    #[test]
    #[ignore]
    fn stress_test() {
        let pool = ThreadPool::new(3, 50, Duration::from_secs(30));
        let counter = Arc::new(AtomicUsize::new(0));
        submit_count_then_sleep(&pool, &counter, 160, 10);
        thread::sleep(Duration::from_secs(5));
        assert!(pool.get_current_worker_count() <= 50);
        thread::sleep(Duration::from_secs(20));
        submit_count_then_sleep(&pool, &counter, 160, 10);
        thread::sleep(Duration::from_secs(5));
        assert!(pool.get_current_worker_count() <= 50);
        thread::sleep(Duration::from_secs(200));
        assert_eq!(counter.load(Ordering::SeqCst), 320);
        // Pool should have shrunk back down to just the core workers.
        assert_eq!(pool.get_current_worker_count(), 3);
        assert_eq!(pool.get_idle_worker_count(), 3);
    }

    #[test]
    fn test_join() {
        let pool = ThreadPool::new(0, 1, Duration::from_secs(5));
        let counter = Arc::new(AtomicUsize::new(0));
        submit_sleep_then_count(&pool, &counter, 2, 5);
        // join must block until both queued tasks have completed.
        pool.join();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_join_timeout() {
        let pool = ThreadPool::new(0, 1, Duration::from_secs(5));
        let counter = Arc::new(AtomicUsize::new(0));
        submit_sleep_then_count(&pool, &counter, 1, 10);
        // The task needs 10s but we only wait 5s, so nothing has finished.
        pool.join_timeout(Duration::from_secs(5));
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_shutdown() {
        let pool = ThreadPool::new(1, 3, Duration::from_secs(5));
        let counter = Arc::new(AtomicUsize::new(0));
        submit_sleep_then_count(&pool, &counter, 4, 5);
        pool.join_timeout(Duration::from_secs(2));
        pool.shutdown();
        thread::sleep(Duration::from_secs(5));
        // Only the 3 tasks already picked up by workers have completed by
        // now; the 4th is still in flight on the sole remaining worker.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[should_panic(
        expected = "max_size must be greater than 0 and greater or equal to the core pool size"
    )]
    #[test]
    fn test_panic_on_0_max_pool_size() {
        ThreadPool::new(0, 0, Duration::from_secs(2));
    }

    #[should_panic(
        expected = "max_size must be greater than 0 and greater or equal to the core pool size"
    )]
    #[test]
    fn test_panic_on_smaller_max_than_core_pool_size() {
        // Fixed: previously this duplicated the max_size == 0 case and never
        // actually exercised core_size > max_size.
        ThreadPool::new(3, 2, Duration::from_secs(2));
    }

    #[test]
    fn test_empty_join() {
        // join on a pool with no tasks (all-idle) must return immediately.
        let pool = ThreadPool::new(3, 10, Duration::from_secs(10));
        pool.join();
    }

    #[test]
    fn test_join_when_complete() {
        let pool = ThreadPool::new(3, 10, Duration::from_secs(5));
        pool.execute(|| {
            thread::sleep(Duration::from_millis(5000));
        });
        thread::sleep(Duration::from_millis(5000));
        // The task is done and all workers idle; join must not block.
        pool.join();
    }
}