use num_cpus;
use crossbeam_channel::{ unbounded, Receiver, Sender };
use std::fmt;
use std::sync::atomic::{ AtomicUsize, Ordering };
use std::sync::{ Arc, Condvar, Mutex };
use std::thread;
#[cfg(test)]
mod test;
/// Builds a `SingleQueueThreadpool` with all defaults: one worker per
/// logical CPU, unnamed worker threads, platform-default stack size.
pub fn single_queue_threadpool_auto_config() -> SingleQueueThreadpool {
    let builder = single_queue_threadpool_builder();
    builder.build()
}
/// Returns a builder used to configure and construct a `SingleQueueThreadpool`.
///
/// Every setting starts unset; `build()` substitutes a default for anything
/// the caller does not specify.
pub const fn single_queue_threadpool_builder() -> SingleQueueThreadpoolBuilder {
    SingleQueueThreadpoolBuilder { num_workers: None, worker_name: None, thread_stack_size: None }
}
// Workaround for invoking a boxed `FnOnce` closure: `Box<dyn FnOnce()>` was
// not directly callable before Rust 1.35, so `call_box` takes `self: Box<Self>`,
// moves the closure out of the box, and calls it exactly once.
trait FnBox {
fn call_box(self: Box<Self>);
}
impl<F: FnOnce()> FnBox for F {
// Moves the closure out of its box and invokes it.
fn call_box(self: Box<F>) {
(*self)()
}
}
// A boxed job submitted to the pool; `Send` so it can cross to worker threads.
type Thunk<'a> = Box<dyn FnBox + Send + 'a>;
// RAII guard held by each worker thread. If the worker unwinds while still
// armed (`active == true`) — normally because a job panicked — its `Drop`
// impl repairs the pool's counters and spawns a replacement worker.
// `cancel()` disarms it for a clean exit.
struct Sentinel<'a> {
// Shared pool state the guard repairs on an unexpected exit.
shared_data: &'a Arc<SingleQueueThreadpoolSharedData>,
// True until `cancel()` is called; gates the cleanup in `Drop`.
active: bool,
}
impl<'a> Sentinel<'a> {
    /// Creates an armed sentinel watching `shared_data`.
    fn new(shared_data: &'a Arc<SingleQueueThreadpoolSharedData>) -> Sentinel<'a> {
        Sentinel { shared_data, active: true }
    }

    /// Disarms the sentinel so its `Drop` impl performs no cleanup.
    fn cancel(mut self) {
        self.active = false;
    }
}
impl<'a> Drop for Sentinel<'a> {
// Runs when a worker thread exits. If the sentinel is still armed, the
// worker died unexpectedly (normally a panicking job): restore the
// bookkeeping and replace the lost thread.
fn drop(&mut self) {
if self.active {
// The dying worker had incremented `active_count` for the job it
// was running; undo that increment.
self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
if thread::panicking() {
self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
}
// Wake any `join()` callers in case this was the last outstanding work.
self.shared_data.no_work_notify_all();
// Keep the pool at full strength by spawning a replacement worker.
spawn_in_pool(self.shared_data.clone())
}
}
}
// Builder for `SingleQueueThreadpool`; any field left `None` falls back to a
// default at `build()` time (worker count = number of logical CPUs, unnamed
// threads, platform-default stack size).
#[derive(Clone, Default)]
pub struct SingleQueueThreadpoolBuilder {
// Number of worker threads; defaults to `num_cpus::get()`.
num_workers: Option<usize>,
// Name given to every worker thread (visible in debuggers/panic messages).
worker_name: Option<String>,
// Stack size in bytes for each worker thread.
thread_stack_size: Option<usize>,
}
impl SingleQueueThreadpoolBuilder {
    /// Sets how many worker threads the pool will start with.
    ///
    /// # Panics
    /// Panics if `num_workers` is zero.
    pub fn num_workers(mut self, num_workers: usize) -> SingleQueueThreadpoolBuilder {
        assert!(num_workers > 0);
        self.num_workers = Some(num_workers);
        self
    }

    /// Sets the name assigned to every worker thread.
    pub fn worker_name<S: AsRef<str>>(mut self, name: S) -> SingleQueueThreadpoolBuilder {
        self.worker_name = Some(name.as_ref().to_owned());
        self
    }

    /// Sets the stack size (in bytes) for each worker thread.
    pub fn thread_stack_size(mut self, size: usize) -> SingleQueueThreadpoolBuilder {
        self.thread_stack_size = Some(size);
        self
    }

    /// Consumes the builder, spawns the workers, and returns the pool handle.
    pub fn build(self) -> SingleQueueThreadpool {
        let (sender, receiver) = unbounded::<Thunk<'static>>();
        // Fall back to one worker per logical CPU when unspecified.
        let worker_count = self.num_workers.unwrap_or_else(num_cpus::get);
        let shared_data = Arc::new(SingleQueueThreadpoolSharedData {
            name: self.worker_name,
            job_receiver: Mutex::new(receiver),
            empty_condvar: Condvar::new(),
            empty_trigger: Mutex::new(()),
            join_generation: AtomicUsize::new(0),
            queued_count: AtomicUsize::new(0),
            active_count: AtomicUsize::new(0),
            max_thread_count: AtomicUsize::new(worker_count),
            panic_count: AtomicUsize::new(0),
            stack_size: self.thread_stack_size,
        });
        for _ in 0..worker_count {
            spawn_in_pool(Arc::clone(&shared_data));
        }
        SingleQueueThreadpool {
            jobs: sender,
            shared_data,
        }
    }
}
// State shared between the pool handle(s) and all worker threads.
struct SingleQueueThreadpoolSharedData {
// Optional name applied to every worker thread.
name: Option<String>,
// Single shared job queue; workers take this lock to block on `recv()`.
job_receiver: Mutex<Receiver<Thunk<'static>>>,
// Mutex paired with `empty_condvar`; `join()` waits while holding it.
empty_trigger: Mutex<()>,
// Signaled whenever the pool may have become idle.
empty_condvar: Condvar,
// Bumped once per completed `join()` "round" so joiners from different
// rounds do not consume each other's wakeups.
join_generation: AtomicUsize,
// Jobs submitted but not yet picked up by a worker.
queued_count: AtomicUsize,
// Jobs currently executing on worker threads.
active_count: AtomicUsize,
// Target number of worker threads (adjustable via `set_num_workers`).
max_thread_count: AtomicUsize,
// Number of jobs that panicked while running.
panic_count: AtomicUsize,
// Optional per-thread stack size in bytes.
stack_size: Option<usize>,
}
impl SingleQueueThreadpoolSharedData {
// True while any job is queued or currently running.
fn has_work(&self) -> bool {
self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
}
// Wakes all threads blocked in `join()` once the pool is idle.
fn no_work_notify_all(&self) {
if !self.has_work() {
// Acquire and immediately drop `empty_trigger`. This synchronizes
// with a joiner that has re-checked `has_work()` but not yet parked
// in `Condvar::wait`, so the notify below cannot be lost.
*self
.empty_trigger
.lock()
.expect("Unable to notify all joining threads");
self.empty_condvar.notify_all();
}
}
}
// Handle to a thread pool backed by one shared job queue. Cloning yields
// another sender into the same pool; workers exit once every handle has been
// dropped and the channel disconnects.
pub struct SingleQueueThreadpool {
// Sending side of the job queue.
jobs: Sender<Thunk<'static>>,
// Counters and synchronization state shared with the workers.
shared_data: Arc<SingleQueueThreadpoolSharedData>,
}
impl SingleQueueThreadpool {
    /// Submits `job` to be run on one of the pool's worker threads.
    ///
    /// # Panics
    /// Panics if the job queue is disconnected, which can only happen if
    /// every worker thread has exited.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Count the job as queued before sending so `has_work()` can never
        // miss it between the send and a worker picking it up.
        self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
        self.jobs
            .send(Box::new(job))
            .expect("SingleQueueThreadpool::execute unable to send job into queue.");
    }

    /// Number of jobs submitted but not yet started.
    pub fn queued_count(&self) -> usize {
        self.shared_data.queued_count.load(Ordering::Relaxed)
    }

    /// Number of jobs currently executing.
    pub fn active_count(&self) -> usize {
        self.shared_data.active_count.load(Ordering::SeqCst)
    }

    /// Target number of worker threads.
    pub fn max_count(&self) -> usize {
        self.shared_data.max_thread_count.load(Ordering::Relaxed)
    }

    /// Number of jobs that have panicked since the pool was created.
    pub fn panic_count(&self) -> usize {
        self.shared_data.panic_count.load(Ordering::Relaxed)
    }

    /// Changes the target number of worker threads.
    ///
    /// Growing the pool spawns the additional threads immediately; shrinking
    /// lets surplus workers exit after finishing their current job.
    ///
    /// # Panics
    /// Panics if `num_workers` is zero.
    pub fn set_num_workers(&self, num_workers: usize) {
        assert!(num_workers >= 1);
        let prev_num_workers = self
            .shared_data
            .max_thread_count
            .swap(num_workers, Ordering::Release);
        // `checked_sub` yields `None` when shrinking — nothing to spawn then.
        if let Some(num_spawn) = num_workers.checked_sub(prev_num_workers) {
            for _ in 0..num_spawn {
                spawn_in_pool(self.shared_data.clone());
            }
        }
    }

    /// Blocks until all queued and running jobs have finished.
    pub fn join(&self) {
        // Fast path: skip locking when the pool is already idle.
        if !self.shared_data.has_work() {
            return;
        }
        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
        let mut lock = self.shared_data.empty_trigger.lock().unwrap();
        // Wait until the pool drains, or until another joiner from this same
        // generation already observed the drain and bumped the generation.
        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
            && self.shared_data.has_work()
        {
            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
        }
        // Advance the generation exactly once per drain; losing this race to
        // another joiner is fine, hence the ignored result.
        let _ = self.shared_data.join_generation.compare_exchange(
            generation,
            generation.wrapping_add(1),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
    }
}
impl Clone for SingleQueueThreadpool {
fn clone(&self) -> SingleQueueThreadpool {
SingleQueueThreadpool {
jobs: self.jobs.clone(),
shared_data: self.shared_data.clone(),
}
}
}
impl fmt::Debug for SingleQueueThreadpool {
    /// Formats the pool's name and its current counter snapshot.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("SingleQueueThreadpool");
        dbg.field("name", &self.shared_data.name);
        dbg.field("queued_count", &self.queued_count());
        dbg.field("active_count", &self.active_count());
        dbg.field("max_count", &self.max_count());
        dbg.finish()
    }
}
impl PartialEq for SingleQueueThreadpool {
    /// Two handles are equal iff they point at the same underlying pool
    /// (pointer identity of the shared state, not value comparison).
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.shared_data, &other.shared_data)
    }
}
impl Eq for SingleQueueThreadpool {}
/// Spawns one worker thread for the pool described by `shared_data`.
///
/// The worker repeatedly pulls jobs from the shared queue until either the
/// queue disconnects (all pool handles dropped) or the pool has been shrunk
/// below the number of active workers. A `Sentinel` guard respawns a
/// replacement thread if a job panics.
///
/// # Panics
/// Panics if the OS refuses to spawn a thread.
fn spawn_in_pool(shared_data: Arc<SingleQueueThreadpoolSharedData>) {
    let mut builder = thread::Builder::new();
    if let Some(ref name) = shared_data.name {
        builder = builder.name(name.clone());
    }
    // `usize` is `Copy`, so bind by value — no `ref` + `to_owned` needed.
    if let Some(stack_size) = shared_data.stack_size {
        builder = builder.stack_size(stack_size);
    }
    builder
        .spawn(move || {
            // Armed for the whole loop; disarmed only on the clean exit below,
            // so a panicking job triggers the sentinel's repair-and-respawn.
            let sentinel = Sentinel::new(&shared_data);
            loop {
                // Exit if the pool was shrunk via `set_num_workers`.
                let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
                let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
                if thread_counter_val >= max_thread_count_val {
                    break;
                }
                // Hold the receiver lock only while blocking for the next job.
                let message = {
                    let lock = shared_data
                        .job_receiver
                        .lock()
                        .expect("Worker thread unable to lock job_receiver");
                    lock.recv()
                };
                let job = match message {
                    Ok(job) => job,
                    // Disconnected: every pool handle (sender) was dropped.
                    Err(..) => break,
                };
                // Mark the job active before un-counting it as queued so
                // `has_work()` never observes it as neither queued nor active.
                shared_data.active_count.fetch_add(1, Ordering::SeqCst);
                shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
                job.call_box();
                shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
                // Wake joiners if that was the last outstanding job.
                shared_data.no_work_notify_all();
            }
            sentinel.cancel();
        })
        .expect("SingleQueueThreadpool failed to spawn a worker thread");
}