#![warn(missing_docs)]
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering as AtomicOrdering};
use std::sync::{Arc, Condvar, Mutex};
use std::{process, thread};
type BoxedRunnable = Box<Runnable + Send + 'static>;
pub const NORMAL_PRIORITY: isize = 0;
/// A unit of work that a [`JobPool`] worker thread can execute exactly once.
pub trait Runnable {
    /// Consumes the boxed receiver and performs the work.
    fn run(self: Box<Self>);
}
/// Blanket impl: any zero-argument closure is a [`Runnable`].
impl<F: FnOnce()> Runnable for F {
    #[inline(always)]
    fn run(self: Box<Self>) {
        // Unbox the closure and invoke it once.
        (*self)()
    }
}
/// A queued unit of work paired with the priority used to order it
/// inside the pool's `BinaryHeap`.
struct Job {
    // The boxed work to execute.
    runnable: BoxedRunnable,
    // Heap ordering key; `BinaryHeap` is a max-heap, so higher values
    // are popped (and therefore run) first.
    priority: isize,
}
impl PartialOrd for Job {
    /// Delegates to [`Ord::cmp`] (the canonical form) so `PartialOrd`
    /// and `Ord` can never disagree about job ordering.
    fn partial_cmp(&self, other: &Job) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Job {
    /// Orders jobs by `priority` alone; since `BinaryHeap` is a max-heap,
    /// the highest-priority job is popped first.
    fn cmp(&self, other: &Job) -> Ordering {
        self.priority.cmp(&other.priority)
    }
}
impl PartialEq for Job {
    /// Two jobs compare equal when their priorities match; the runnable
    /// payload is deliberately ignored, consistent with `Ord`/`PartialOrd`.
    fn eq(&self, other: &Job) -> bool {
        self.priority == other.priority
    }
}

// `isize` equality is total, so the `Eq` marker is sound.
impl Eq for Job {}
/// Mutable bookkeeping shared by the pool and all worker threads,
/// always accessed behind an `Arc<Mutex<_>>`.
struct WorkerState {
    // Live workers keyed by their numeric id.
    workers: HashMap<usize, Worker>,
    // Join handles of workers that removed themselves when the pool
    // shrank; drained and joined during shutdown.
    removed_handles: Vec<Option<thread::JoinHandle<()>>>,
    // Number of workers currently executing a job.
    busy_workers: usize,
    // Counter used to mint ids for workers added by `JobPool::try_grow`.
    id_counter: usize,
}
/// One pool thread: its id plus the join handle, which is `take()`n
/// (left as `None`) once shutdown collects it for joining.
struct Worker {
    id: usize,
    handle: Option<thread::JoinHandle<()>>,
}
impl Worker {
    /// Spawns a named worker thread (`worker-{id}`) that repeatedly pops
    /// the highest-priority job from the shared queue and runs it.
    ///
    /// The thread exits when shutdown is flagged and the queue is empty,
    /// or when auto-grow mode decides this worker is surplus (see below).
    fn new(
        id: usize,
        worker_state: Arc<Mutex<WorkerState>>,
        job_queue: Arc<Mutex<BinaryHeap<Job>>>,
        condvar: Arc<Condvar>,
        min_size: Arc<usize>,
        max_size: Arc<AtomicUsize>,
        shutdown: Arc<AtomicBool>,
    ) -> Self {
        let builder = thread::Builder::new().name(format!("worker-{}", id));
        let handle = builder.spawn(move || loop {
            // Block until a job arrives or shutdown is requested; the
            // condvar is paired with the job-queue mutex, and re-checking
            // the predicate in a loop guards against spurious wakeups.
            let job = {
                let mut guard = job_queue.lock().unwrap();
                while guard.is_empty() && !shutdown.load(AtomicOrdering::SeqCst) {
                    guard = condvar.wait(guard).unwrap();
                }
                guard.pop()
            };
            // `None` is only possible when the queue is empty, i.e. the
            // wait loop ended because shutdown was flagged: exit cleanly.
            if job.is_none() {
                break;
            }
            let job = job.unwrap();
            {
                // Mark this worker busy; the cap keeps the invariant
                // busy_workers <= workers.len() that try_grow asserts.
                let mut guard = worker_state.lock().unwrap();
                if guard.busy_workers < guard.workers.len() {
                    guard.busy_workers += 1;
                }
            }
            // Auto-grow is active once JobPool::auto_grow raised max_size
            // above the initial (minimum) pool size.
            let auto_grow = max_size.load(AtomicOrdering::SeqCst) > *min_size;
            if auto_grow {
                JobPool::try_grow(
                    worker_state.clone(),
                    job_queue.clone(),
                    condvar.clone(),
                    min_size.clone(),
                    max_size.clone(),
                    shutdown.clone(),
                );
            }
            job.runnable.run();
            let mut guard = worker_state.lock().unwrap();
            if guard.busy_workers > 0 {
                guard.busy_workers -= 1;
            }
            if auto_grow {
                // Shrink: if the pool is above its minimum size and demand
                // has dropped, this worker removes itself from the map,
                // parks its join handle for shutdown to collect, and exits.
                if guard.workers.len() > *min_size && guard.busy_workers < *min_size {
                    let worker = guard.workers.remove(&id);
                    if let Some(worker) = worker {
                        if let Some(handle) = worker.handle {
                            guard.removed_handles.push(Some(handle));
                        }
                    }
                    break;
                }
            }
        });
        let handle = match handle {
            Ok(h) => Some(h),
            Err(e) => {
                // NOTE(review): a spawn failure aborts the whole process,
                // which is harsh for library code — consider propagating a
                // Result instead; left as-is since callers rely on it.
                eprintln!("Error: {}", e);
                process::exit(1);
            }
        };
        Self { id, handle }
    }
}
/// A priority-ordered thread pool.
///
/// Jobs are held in a max-heap keyed on their priority; `size` worker
/// threads drain it, and after [`JobPool::auto_grow`] the pool may
/// temporarily add workers up to `max_size`.
pub struct JobPool {
    // Initial — and minimum — number of workers.
    size: Arc<usize>,
    // Upper bound on workers once auto-grow is enabled.
    max_size: Arc<AtomicUsize>,
    worker_state: Arc<Mutex<WorkerState>>,
    job_queue: Arc<Mutex<BinaryHeap<Job>>>,
    // Wakes workers when a job arrives or shutdown begins.
    condvar: Arc<Condvar>,
    // One-way flag: set once by shutdown/shutdown_no_wait, never cleared.
    shutdown: Arc<AtomicBool>,
}
impl JobPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0.
    pub fn new(size: usize) -> Self {
        if size == 0 {
            panic!("size cannot be 0")
        }
        let job_queue = Arc::new(Mutex::new(BinaryHeap::new()));
        let condvar = Arc::new(Condvar::new());
        let max_size = Arc::new(AtomicUsize::new(size));
        let shutdown = Arc::new(AtomicBool::new(false));
        let size = Arc::new(size);
        let worker_state = Arc::new(Mutex::new(WorkerState {
            workers: HashMap::new(),
            removed_handles: Vec::new(),
            busy_workers: 0,
            // BUGFIX: start past the initial worker ids (0..size). The old
            // value of 0 let `try_grow` mint ids that collided with live
            // initial workers, silently replacing a map entry and leaking
            // that worker's join handle (its thread was never joined).
            id_counter: *size,
        }));
        {
            let mut guard = worker_state.lock().unwrap();
            for id in 0..*size {
                guard.workers.insert(
                    id,
                    Worker::new(
                        id,
                        worker_state.clone(),
                        job_queue.clone(),
                        condvar.clone(),
                        size.clone(),
                        max_size.clone(),
                        shutdown.clone(),
                    ),
                );
            }
        }
        Self {
            size,
            max_size,
            worker_state,
            job_queue,
            condvar,
            shutdown,
        }
    }

    /// Queues a job at [`NORMAL_PRIORITY`].
    ///
    /// # Panics
    ///
    /// Panics if the pool has already been shut down.
    pub fn queue<J>(&mut self, job: J)
    where
        J: Runnable + Send + 'static,
    {
        self.queue_job(job, NORMAL_PRIORITY);
    }

    /// Queues a job with an explicit priority; higher values run first.
    ///
    /// # Panics
    ///
    /// Panics if the pool has already been shut down.
    pub fn queue_with_priority<J>(&mut self, job: J, priority: isize)
    where
        J: Runnable + Send + 'static,
    {
        self.queue_job(job, priority);
    }

    // Shared path for both public queueing methods: push, wake one worker,
    // then attempt on-demand growth when auto-grow is enabled.
    fn queue_job<J>(&mut self, job: J, priority: isize)
    where
        J: Runnable + Send + 'static,
    {
        if self.shutdown.load(AtomicOrdering::SeqCst) {
            panic!("Error: this threadpool has been shutdown!");
        }
        self.push_new_job(job, priority);
        self.condvar.notify_one();
        // `auto_grow()` raises max_size above the initial size; only then
        // is on-demand growth attempted.
        if self.max_size.load(AtomicOrdering::SeqCst) > *self.size {
            Self::try_grow(
                self.worker_state.clone(),
                self.job_queue.clone(),
                self.condvar.clone(),
                self.size.clone(),
                self.max_size.clone(),
                self.shutdown.clone(),
            );
        }
    }

    // Adds one worker, but only when the pool is saturated (every worker
    // busy), below max_size, has more queued jobs than spare capacity, and
    // shutdown has not been requested. Called from `queue_job` and from
    // busy worker threads.
    fn try_grow(
        worker_state: Arc<Mutex<WorkerState>>,
        job_queue: Arc<Mutex<BinaryHeap<Job>>>,
        condvar: Arc<Condvar>,
        min_size: Arc<usize>,
        max_size: Arc<AtomicUsize>,
        shutdown: Arc<AtomicBool>,
    ) {
        let remaining_job_count = {
            let guard = job_queue.lock().unwrap();
            guard.len()
        };
        let mut guard = worker_state.lock().unwrap();
        let busy_workers = guard.busy_workers;
        let total_workers = guard.workers.len();
        let max_size_val = max_size.load(AtomicOrdering::SeqCst);
        assert!(total_workers <= max_size_val);
        assert!(busy_workers <= total_workers);
        // Don't grow while any worker is idle or the pool is already full.
        if busy_workers < total_workers || total_workers == max_size_val {
            return;
        }
        let available_workers = total_workers - busy_workers;
        if remaining_job_count <= available_workers {
            return;
        }
        if shutdown.load(AtomicOrdering::SeqCst) {
            return;
        }
        // Mint a fresh id; the counter starts past the initial ids (see
        // `new`), so a grown worker can never clobber a live map entry.
        let new_id = guard.id_counter;
        guard.id_counter += 1;
        guard.workers.insert(
            new_id,
            Worker::new(
                new_id,
                worker_state.clone(),
                job_queue.clone(),
                condvar.clone(),
                min_size.clone(),
                max_size.clone(),
                shutdown.clone(),
            ),
        );
    }

    // Boxes the job and pushes it onto the priority heap.
    fn push_new_job<J>(&mut self, job: J, priority: isize)
    where
        J: Runnable + Send + 'static,
    {
        let mut guard = self.job_queue.lock().unwrap();
        guard.push(Job {
            runnable: Box::new(job),
            priority,
        });
    }

    /// Enables on-demand growth of the pool up to `max_size` workers.
    ///
    /// # Panics
    ///
    /// Panics if `max_size` is not greater than the initial pool size.
    pub fn auto_grow(&mut self, max_size: usize) {
        if max_size <= *self.size {
            panic!("max_size must be greater than initial JobPool size");
        }
        self.max_size.store(max_size, AtomicOrdering::SeqCst);
    }

    /// Returns the number of workers currently executing a job.
    pub fn active_workers_count(&self) -> usize {
        let guard = self.worker_state.lock().unwrap();
        guard.busy_workers
    }

    // Takes every outstanding join handle — from live workers and from
    // workers that shrank away — pairing each with its worker id when
    // known. Shared by `shutdown` and `shutdown_no_wait`.
    fn take_handles(&mut self) -> Vec<(Option<usize>, thread::JoinHandle<()>)> {
        let mut guard = self.worker_state.lock().unwrap();
        let mut handles = Vec::with_capacity(guard.workers.len() + guard.removed_handles.len());
        for (_, worker) in &mut guard.workers {
            if let Some(handle) = worker.handle.take() {
                handles.push((Some(worker.id), handle));
            }
        }
        for slot in &mut guard.removed_handles {
            handles.push((None, slot.take().unwrap()));
        }
        handles
    }

    /// Signals all workers to stop and blocks until every worker thread
    /// has been joined. Idempotent: a second call returns immediately.
    pub fn shutdown(&mut self) {
        if self.has_shutdown() {
            return;
        }
        self.notify_shutdown();
        for (id, handle) in self.take_handles() {
            if let Err(e) = handle.join() {
                match id {
                    Some(id) => eprintln!("Error joining worker-{} thread: {:?}", id, e),
                    None => eprintln!("Error joining thread: {:?}", e),
                }
            }
        }
    }

    /// Signals all workers to stop and returns their join handles without
    /// waiting, so the caller can join them later. Returns `None` if the
    /// pool was already shut down.
    pub fn shutdown_no_wait(&mut self) -> Option<Vec<thread::JoinHandle<()>>> {
        if self.has_shutdown() {
            return None;
        }
        self.notify_shutdown();
        Some(self.take_handles().into_iter().map(|(_, h)| h).collect())
    }

    /// Returns `true` once `shutdown` or `shutdown_no_wait` has run.
    pub fn has_shutdown(&self) -> bool {
        self.shutdown.load(AtomicOrdering::SeqCst)
    }

    // Sets the one-way shutdown flag and wakes every waiting worker so
    // each can observe it and exit.
    fn notify_shutdown(&mut self) {
        self.shutdown.store(true, AtomicOrdering::SeqCst);
        self.condvar.notify_all();
    }
}
impl Drop for JobPool {
    /// Blocks until all workers are joined; a no-op if `shutdown` or
    /// `shutdown_no_wait` has already run.
    fn drop(&mut self) {
        self.shutdown();
    }
}