use num_cpus;
use crossbeam_channel::{ unbounded, Receiver, Sender };
use log::{ trace, warn };
use std::fmt;
use std::sync::atomic::{ AtomicI8, AtomicUsize, Ordering };
use std::sync::{ Arc, Condvar, Mutex };
use std::{ thread, time };
#[cfg(test)]
mod test;
/// Build a thread pool with the default configuration
/// (worker count taken from the number of logical CPUs).
pub fn lft_auto_config() -> ThreadPool {
    let builder = lft_builder();
    builder.build()
}
/// Start configuring a thread pool. Every option is initially unset and
/// falls back to a default inside `Builder::build`.
pub const fn lft_builder() -> Builder {
    // `Default::default()` is not usable in a `const fn`, so the empty
    // configuration is spelled out field by field.
    Builder {
        num_workers: None,
        worker_name: None,
        max_thread_count: None,
        thread_stack_size: None,
    }
}
/// Object-safe wrapper so a boxed closure can be invoked by value.
/// NOTE(review): since Rust 1.35, `Box<dyn FnOnce()>` is directly callable,
/// which would make this trait unnecessary — candidate for cleanup.
trait FnBox {
    fn call_box(self: Box<Self>);
}
impl<F: FnOnce()> FnBox for F {
    // Moves the closure out of the box and calls it exactly once.
    fn call_box(self: Box<F>) {
        (*self)()
    }
}
/// A boxed job, as sent over the per-worker channels.
type Thunk<'a> = Box<dyn FnBox + Send + 'a>;
/// Drop guard held by every worker thread: if the worker unwinds before
/// `cancel()` is called (i.e. a job panicked), `Drop` repairs the pool's
/// counters and respawns a replacement thread on the same slot.
struct Sentinel<'a> {
    shared_data: &'a Arc<ThreadPoolSharedData>,
    receiver: &'a Arc<Receiver<Thunk<'static>>>,
    num_jobs: &'a Arc<AtomicUsize>,
    thread_closing: &'a Arc<AtomicI8>,
    // Armed flag; `cancel()` clears it on a clean exit so Drop does nothing.
    active: bool,
}
impl<'a> Sentinel<'a> {
    /// Create an armed sentinel guarding one worker slot.
    fn new(
        shared_data: &'a Arc<ThreadPoolSharedData>,
        receiver: &'a Arc<Receiver<Thunk<'static>>>,
        num_jobs: &'a Arc<AtomicUsize>,
        thread_closing: &'a Arc<AtomicI8>,
    ) -> Sentinel<'a> {
        Sentinel {
            shared_data,
            receiver,
            num_jobs,
            thread_closing,
            active: true,
        }
    }

    /// Disarm the guard: the worker is exiting cleanly, so `Drop` must not
    /// respawn a replacement.
    fn cancel(mut self) {
        self.active = false;
    }
}
impl<'a> Drop for Sentinel<'a> {
    /// Runs only while still armed, i.e. when the worker unwound out of its
    /// loop without reaching `cancel()` — the panic-recovery path.
    fn drop(&mut self) {
        if self.active {
            // Undo the active_count increment made just before the job ran
            // (spawn_in_pool bumps it before `call_box`).
            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
            // Mark the slot closed while cleanup is in progress.
            self.thread_closing.store(3, Ordering::SeqCst);
            if thread::panicking() {
                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
            }
            if self.num_jobs.load(Ordering::Acquire) == 0 {
                self.shared_data.no_work_notify_all();
            }
            // Flip to "spawn pending" so the replacement thread's 1 -> 0
            // compare_exchange in spawn_in_pool succeeds.
            self.thread_closing.store(1, Ordering::SeqCst);
            // NOTE(review): the panicked job's per-slot `num_jobs` and the
            // shared `queued_count` are never decremented on this path —
            // verify that is intentional (join() may over-wait otherwise).
            spawn_in_pool(self.shared_data.clone(),
            self.receiver.clone(),
            self.num_jobs.clone(),
            self.thread_closing.clone())
        }
    }
}
/// Configuration accumulated by `lft_builder()`. `None` fields fall back to
/// defaults inside `build()` (workers = logical CPU count, max = workers).
#[derive(Clone, Default)]
pub struct Builder {
    // Initial number of live worker threads.
    num_workers: Option<usize>,
    // Upper bound on threads the pool may ever own.
    max_thread_count: Option<usize>,
    // Optional OS thread name applied to every worker.
    worker_name: Option<String>,
    // Optional per-thread stack size in bytes.
    thread_stack_size: Option<usize>,
}
impl Builder {
    /// Set the initial number of worker threads.
    ///
    /// # Panics
    /// Panics if `num_workers` is zero.
    pub fn num_workers(mut self, num_workers: usize) -> Builder {
        assert!(num_workers > 0);
        self.num_workers = Some(num_workers);
        self
    }

    /// Set the upper bound on the number of threads the pool may own.
    ///
    /// # Panics
    /// Panics if `max_thread_count` is zero.
    pub fn max_thread_count(mut self, max_thread_count: usize) -> Builder {
        assert!(max_thread_count > 0);
        self.max_thread_count = Some(max_thread_count);
        self
    }

    /// Set the OS thread name used for every worker spawned by the pool.
    pub fn worker_name<S: AsRef<str>>(mut self, name: S) -> Builder {
        self.worker_name = Some(name.as_ref().to_owned());
        self
    }

    /// Set the stack size (in bytes) for each worker thread.
    pub fn thread_stack_size(mut self, size: usize) -> Builder {
        self.thread_stack_size = Some(size);
        self
    }

    /// Consume the builder and construct the pool.
    ///
    /// One channel and one bookkeeping slot is created per *potential*
    /// thread (`max_thread_count` of them); only the first `num_workers`
    /// slots start alive (flag 1 = spawn pending), the rest are parked
    /// (flag 3) until `spawn_extra_one_worker` claims them.
    pub fn build(self) -> ThreadPool {
        let mut num_workers = self.num_workers.unwrap_or_else(num_cpus::get);
        // FIX: `unwrap_or_else(|| { num_workers })` was a redundant closure.
        let max_thread_count = self.max_thread_count.unwrap_or(num_workers);
        if max_thread_count < num_workers {
            // FIX: the message was a raw multi-line literal that embedded a
            // newline plus source indentation into the log output, and said
            // "works" instead of "workers".
            warn!(
                "Number of workers is larger than max thread number, shrinking \
                 the thread pool to max thread number {}.",
                max_thread_count
            );
            num_workers = max_thread_count;
        }
        // Per-slot state: queue-depth counter, lifecycle flag and a dedicated
        // job channel for every potential thread.
        let mut num_jobs_list: Vec<Arc<AtomicUsize>> = Vec::with_capacity(max_thread_count);
        let mut thread_closing_list: Vec<Arc<AtomicI8>> = Vec::with_capacity(max_thread_count);
        let mut sender_list: Vec<Sender<Thunk<'static>>> = Vec::with_capacity(max_thread_count);
        let mut receiver_list: Vec<Arc<Receiver<Thunk<'static>>>> =
            Vec::with_capacity(max_thread_count);
        for i in 0..max_thread_count {
            let (tx, rx) = unbounded::<Thunk<'static>>();
            num_jobs_list.push(Arc::new(AtomicUsize::new(0)));
            sender_list.push(tx);
            receiver_list.push(Arc::new(rx));
            if i < num_workers {
                thread_closing_list.push(Arc::new(AtomicI8::new(1))); // spawn pending
            } else {
                thread_closing_list.push(Arc::new(AtomicI8::new(3))); // parked slot
            }
        }
        let context = Arc::new(ThreadPoolContext {
            queued_count: num_jobs_list.clone(),
            thread_closing: thread_closing_list.clone(),
            senders: sender_list,
            receivers: receiver_list.clone(),
        });
        let shared_data = Arc::new(ThreadPoolSharedData {
            name: self.worker_name,
            empty_condvar: Condvar::new(),
            empty_trigger: Mutex::new(()),
            join_generation: AtomicUsize::new(0),
            queued_count: AtomicUsize::new(0),
            active_count: AtomicUsize::new(0),
            num_workers: AtomicUsize::new(num_workers),
            max_thread_count: AtomicUsize::new(max_thread_count),
            panic_count: AtomicUsize::new(0),
            stack_size: self.thread_stack_size,
        });
        let sleep_duration = time::Duration::from_millis(8);
        for i in 0..max_thread_count {
            spawn_in_pool(shared_data.clone(),
                          receiver_list[i].clone(),
                          num_jobs_list[i].clone(),
                          thread_closing_list[i].clone());
            // Wait until the worker reports running (flag 1 -> 0).
            // FIX: only active workers ever reach 0; a parked slot's thread
            // exits immediately with its flag still 3, so waiting on it as the
            // original did would spin forever whenever
            // max_thread_count > num_workers.
            if i < num_workers {
                while thread_closing_list[i].load(Ordering::SeqCst) != 0 {
                    thread::sleep(sleep_duration);
                }
            }
        }
        ThreadPool {
            shared_data,
            context,
        }
    }
}
/// State shared by every worker thread and every `ThreadPool` handle.
struct ThreadPoolSharedData {
    // Optional OS thread name for workers.
    name: Option<String>,
    // Mutex + condvar pair `join` blocks on until the pool drains.
    empty_trigger: Mutex<()>,
    empty_condvar: Condvar,
    // Bumped after each completed `join` round so joiners from different
    // rounds do not steal each other's wakeups.
    join_generation: AtomicUsize,
    // Jobs accepted but not yet finished, across the whole pool.
    queued_count: AtomicUsize,
    // Jobs currently executing.
    active_count: AtomicUsize,
    // Current number of live worker threads.
    num_workers: AtomicUsize,
    // Upper bound on worker threads (fixed at build time).
    max_thread_count: AtomicUsize,
    // Number of workers lost to a panicking job.
    panic_count: AtomicUsize,
    // Optional per-thread stack size in bytes.
    stack_size: Option<usize>,
}
impl ThreadPoolSharedData {
    /// True while any job is queued or currently executing.
    fn has_work(&self) -> bool {
        self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
    }

    /// Wake every thread blocked in `ThreadPool::join` once the pool drains.
    fn no_work_notify_all(&self) {
        if !self.has_work() {
            // Acquire and immediately drop the guard (the `*` dereferences the
            // guard, which is then discarded at the end of the statement).
            // This acts as a barrier against joiners that re-check `has_work`
            // while holding the lock, so the notify below cannot be missed.
            *self
                .empty_trigger
                .lock()
                .expect("Unable to notify all joining threads");
            self.empty_condvar.notify_all();
        }
    }
}
/// Per-slot scheduling state, indexed by slot id in 0..max_thread_count.
struct ThreadPoolContext {
    // Queue depth of each slot.
    queued_count: Vec<Arc<AtomicUsize>>,
    // Submission side of each slot's job channel.
    senders: Vec<Sender<Thunk<'static>>>,
    // Worker side of each slot's job channel.
    receivers: Vec<Arc<Receiver<Thunk<'static>>>>,
    // Lifecycle flag per slot: 0 running, 1 spawn pending, 2 draining,
    // 3 closed/parked — inferred from the transitions in this file.
    thread_closing: Vec<Arc<AtomicI8>>,
}
/// Cloneable handle to a load-balancing thread pool where every worker owns
/// a private job queue; `execute` routes each job to the least-loaded worker.
pub struct ThreadPool {
    shared_data: Arc<ThreadPoolSharedData>,
    context: Arc<ThreadPoolContext>,
}
impl ThreadPool {
    /// Submit a job to the least-loaded live worker.
    ///
    /// Scan all slots, skipping non-running ones; prefer an idle worker
    /// (queue depth 0), otherwise the shortest queue. If no live worker is
    /// found (e.g. all are closing), back off 10 ms and rescan.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let max_thread_count = self.shared_data.max_thread_count.load(Ordering::Relaxed);
        loop {
            // `max_thread_count + 1` is the "no candidate yet" sentinel.
            let mut target_thread_id = max_thread_count + 1;
            let mut min_jobs_counted: usize = 0;
            for i in 0..max_thread_count {
                if self.context.thread_closing[i].load(Ordering::Relaxed) > 0 {
                    continue; // slot not running
                }
                // FIX: load the queue depth once per slot; the original read
                // it up to three times with mixed orderings, so the compared
                // value and the recorded minimum could disagree.
                let queued = self.context.queued_count[i].load(Ordering::SeqCst);
                if queued == 0 {
                    target_thread_id = i;
                    break; // idle worker — best possible target
                }
                if target_thread_id > max_thread_count || queued < min_jobs_counted {
                    target_thread_id = i;
                    min_jobs_counted = queued;
                }
            }
            // Re-check the target is still alive before enqueueing.
            if target_thread_id < max_thread_count
                && self.context.thread_closing[target_thread_id].load(Ordering::SeqCst) == 0
            {
                self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
                self.context.queued_count[target_thread_id].fetch_add(1, Ordering::SeqCst);
                self.context.senders[target_thread_id]
                    .send(Box::new(job))
                    .expect("ThreadPool::execute unable to send job into queue.");
                break;
            }
            // No live worker right now; back off and retry.
            thread::sleep(time::Duration::from_millis(10));
        }
    }

    /// Jobs accepted but not yet finished, pool-wide.
    pub fn queued_count(&self) -> usize {
        self.shared_data.queued_count.load(Ordering::Relaxed)
    }

    /// Jobs currently being executed.
    pub fn active_count(&self) -> usize {
        self.shared_data.active_count.load(Ordering::SeqCst)
    }

    /// Maximum number of threads the pool may own.
    pub fn max_count(&self) -> usize {
        self.shared_data.max_thread_count.load(Ordering::Relaxed)
    }

    /// Current number of live worker threads.
    pub fn num_workers(&self) -> usize {
        self.shared_data.num_workers.load(Ordering::Relaxed)
    }

    /// Number of workers that panicked while running a job.
    pub fn panic_count(&self) -> usize {
        self.shared_data.panic_count.load(Ordering::Relaxed)
    }

    /// Revive one parked slot, growing the pool by one worker. Logs a
    /// warning and does nothing when already at `max_thread_count`.
    pub fn spawn_extra_one_worker(&self) {
        if self.shared_data.num_workers.load(Ordering::Acquire)
            >= self.shared_data.max_thread_count.load(Ordering::Relaxed)
        {
            warn!("Max thread number exceeded.");
            // FIX: the original had a bare `()` here — a no-op, not an early
            // return — so num_workers could be pushed past max_thread_count.
            return;
        }
        self.shared_data.num_workers.fetch_add(1, Ordering::SeqCst);
        let mut spawn_completed = false;
        while !spawn_completed {
            let max_thread_count = self.shared_data.max_thread_count.load(Ordering::Relaxed);
            for i in 0..max_thread_count {
                // Claim a parked slot (3) by marking it spawn-pending (1).
                if self.context.thread_closing[i].compare_exchange(3,
                                                                   1,
                                                                   Ordering::SeqCst,
                                                                   Ordering::Relaxed) == Ok(3) {
                    spawn_in_pool(self.shared_data.clone(),
                                  self.context.receivers[i].clone(),
                                  self.context.queued_count[i].clone(),
                                  self.context.thread_closing[i].clone());
                    spawn_completed = true;
                    break;
                }
            }
            if self.shared_data.num_workers.load(Ordering::SeqCst)
                >= self.shared_data.max_thread_count.load(Ordering::Relaxed)
            {
                warn!("Max thread number exceeded.");
                break;
            }
        }
    }

    /// Ask the least-loaded live worker to shut down once its queue drains.
    /// Logs a warning and does nothing when no workers remain.
    pub fn shutdown_one_worker(&self) {
        // FIX: the original had a bare `()` after the warning instead of a
        // return, so the `fetch_sub` below would wrap the usize counter
        // around from 0 to usize::MAX. (Also `<= 0` on a usize is just `== 0`.)
        if self.shared_data.num_workers.load(Ordering::SeqCst) == 0 {
            warn!("No thread to shutdown");
            return;
        }
        self.shared_data.num_workers.fetch_sub(1, Ordering::SeqCst);
        loop {
            let max_thread_count = self.shared_data.max_thread_count.load(Ordering::Relaxed);
            let mut target_thread_id = max_thread_count + 1;
            let mut min_num_of_jobs = 0;
            for i in 0..max_thread_count {
                if self.context.thread_closing[i].load(Ordering::Relaxed) > 0 {
                    continue; // not a live worker
                }
                // FIX: single load (the original compared one load and stored
                // a second, which could record a different value).
                let queued = self.context.queued_count[i].load(Ordering::Acquire);
                if target_thread_id > max_thread_count || min_num_of_jobs > queued {
                    target_thread_id = i;
                    min_num_of_jobs = queued;
                }
            }
            // Flag the chosen worker 0 -> 2 ("drain then exit").
            if target_thread_id < max_thread_count
                && self.context.thread_closing[target_thread_id].compare_exchange(0,
                                                                                  2,
                                                                                  Ordering::SeqCst,
                                                                                  Ordering::Relaxed) == Ok(0) {
                trace!("Closing thread id: {}.", target_thread_id);
                break;
            }
            if self.shared_data.num_workers.load(Ordering::SeqCst) == 0 {
                warn!("No thread to shutdown");
                break;
            }
        }
    }

    /// Block until the pool has no queued or running work.
    ///
    /// A generation counter distinguishes join "rounds" so that concurrent
    /// joiners from different rounds do not consume each other's wakeups.
    pub fn join(&self) {
        // Fast path: nothing in flight.
        if !self.shared_data.has_work() {
            return;
        }
        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
        let mut lock = self.shared_data.empty_trigger.lock().unwrap();
        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
            && self.shared_data.has_work()
        {
            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
        }
        // Advance the generation so later joiners start a fresh round.
        let _ = self.shared_data.join_generation.compare_exchange(
            generation,
            generation.wrapping_add(1),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
    }
}
impl Clone for ThreadPool {
fn clone(&self) -> ThreadPool {
ThreadPool {
shared_data: self.shared_data.clone(),
context: self.context.clone(),
}
}
}
impl fmt::Debug for ThreadPool {
    /// Debug output exposes the pool name plus its live counters.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("ThreadPool");
        dbg.field("name", &self.shared_data.name);
        dbg.field("queued_count", &self.queued_count());
        dbg.field("active_count", &self.active_count());
        dbg.field("max_count", &self.max_count());
        dbg.field("num_workers", &self.num_workers());
        dbg.finish()
    }
}
impl PartialEq for ThreadPool {
    /// Two handles are equal when they point at the same underlying pool.
    fn eq(&self, other: &ThreadPool) -> bool {
        let lhs = &self.shared_data;
        let rhs = &other.shared_data;
        Arc::ptr_eq(lhs, rhs)
    }
}
impl Eq for ThreadPool {}
/// Spawn an OS thread that services one slot's job channel.
///
/// `thread_closing` protocol (inferred from the transitions in this file —
/// confirm against the pool's design notes): 1 = spawn pending,
/// 0 = running, 2 = drain-then-exit requested, 3 = closed/parked slot.
fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>,
                 receiver: Arc<Receiver<Thunk<'static>>>,
                 num_jobs: Arc<AtomicUsize>,
                 thread_closing: Arc<AtomicI8>) {
    let mut builder = thread::Builder::new();
    // Apply optional pool-wide thread name and stack size.
    if let Some(ref name) = shared_data.name {
        builder = builder.name(name.clone());
    }
    if let Some(ref stack_size) = shared_data.stack_size {
        builder = builder.stack_size(stack_size.to_owned());
    }
    builder
        .spawn(move || {
            // Armed guard: if a job panics and unwinds past this scope, the
            // Sentinel's Drop respawns a replacement thread for this slot.
            let sentinel = Sentinel::new(&shared_data, &receiver, &num_jobs, &thread_closing);
            // Claim the slot: only a spawn-pending (1) slot becomes a running
            // (0) worker; a thread spawned on a closed/parked slot exits.
            if thread_closing.compare_exchange(1, 0, Ordering::SeqCst, Ordering::Relaxed) == Ok(1) {
                loop {
                    // Graceful shutdown: leave only after the queue drains.
                    if thread_closing.load(Ordering::SeqCst) == 2
                        && num_jobs.load(Ordering::SeqCst) == 0 {
                        break;
                    }
                    // Blocks until a job arrives or all senders are dropped.
                    let message = {
                        receiver.recv()
                    };
                    let job = match message {
                        Ok(job) => job,
                        Err(..) => break, // channel closed: pool torn down
                    };
                    shared_data.active_count.fetch_add(1, Ordering::SeqCst);
                    job.call_box();
                    // Bookkeeping after the job completes (panic path is
                    // handled by the Sentinel instead).
                    num_jobs.fetch_sub(1, Ordering::SeqCst);
                    shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
                    shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
                    if num_jobs.load(Ordering::SeqCst) == 0 {
                        shared_data.no_work_notify_all();
                    }
                }
                // Natural exit (flag still 0, e.g. channel closed): mark the
                // slot closed and drop the worker from the live count.
                if thread_closing.compare_exchange(0,
                                                   3,
                                                   Ordering::SeqCst,
                                                   Ordering::Relaxed) == Ok(0) {
                    shared_data.num_workers.fetch_sub(1, Ordering::SeqCst);
                } else {
                    // Requested shutdown (flag 2): acknowledge by closing the
                    // slot; num_workers was already decremented by
                    // shutdown_one_worker.
                    let _ = thread_closing.compare_exchange(2,
                                                            3,
                                                            Ordering::SeqCst,
                                                            Ordering::Relaxed);
                }
            }
            // Clean exit: disarm the guard so Drop does not respawn.
            sentinel.cancel();
        })
        .unwrap();
}