use std::thread::{self, JoinHandle};
use crossbeam::channel::{self as mpmc, Receiver, Sender};
use once_cell::sync::Lazy;
use crate::{
error::{Error, InvalidArgumentError},
sink::{OverflowPolicy, Task},
sync::*,
Result,
};
/// A fixed set of worker threads fed through a bounded task channel.
///
/// Create one with [`ThreadPool::new`] or [`ThreadPool::builder`].
pub struct ThreadPool {
    // Join handles of the spawned workers. Each slot stays `Some` until `Drop`
    // takes it in order to join the thread.
    threads: Vec<Option<JoinHandle<()>>>,
    // Producer side of the task queue. `Drop` takes it (leaving `None`) so the
    // workers' receive loops can terminate.
    sender: Option<Sender<Task>>,
}
/// Configures and creates a [`ThreadPool`]; obtain one via [`ThreadPool::builder`].
pub struct ThreadPoolBuilder {
    // Bound of the task queue; `build` rejects 0 with an error.
    capacity: usize,
    // Number of worker threads to spawn; `build` panics on 0.
    threads: usize,
}
/// State moved into a single worker thread.
struct Worker {
    // This worker's clone of the shared task-queue receiver.
    receiver: Receiver<Task>,
}
impl ThreadPool {
#[must_use]
pub fn builder() -> ThreadPoolBuilder {
ThreadPoolBuilder {
capacity: 8192,
threads: 1,
}
}
pub fn new() -> Result<Self> {
Self::builder().build()
}
pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
let sender = self.sender.as_ref().unwrap();
match overflow_policy {
OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
OverflowPolicy::DropIncoming => sender
.try_send(task)
.map_err(Error::from_crossbeam_try_send),
}
}
}
impl Drop for ThreadPool {
    /// Shuts the pool down and waits for every worker thread to exit.
    ///
    /// The pool's `Sender` is dropped first, which disconnects the channel.
    /// Per crossbeam's semantics, each worker's `recv` loop keeps yielding
    /// already-queued tasks and ends once the queue is empty and disconnected.
    /// The workers are then joined in spawn order.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread panicked. The exception is when the current
    /// thread is already unwinding, because a second panic inside `drop` would
    /// abort the whole process.
    fn drop(&mut self) {
        // Must happen before joining: otherwise workers stay blocked in `recv`
        // and the joins below never return.
        self.sender.take();

        for handle in self.threads.iter_mut().filter_map(Option::take) {
            let joined = handle.join();
            // Re-raising a worker panic while already panicking would abort, so
            // only surface it on the normal (non-unwinding) drop path.
            if !thread::panicking() {
                joined.expect("failed to join a thread from pool");
            }
        }
    }
}
impl ThreadPoolBuilder {
    /// Sets the bound of the task queue shared by the workers.
    pub fn capacity(&mut self, capacity: usize) -> &mut Self {
        self.capacity = capacity;
        self
    }

    /// Sets the number of worker threads to spawn.
    // Not part of the public API yet; within this file it is only used by the
    // unit tests, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn threads(&mut self, threads: usize) -> &mut Self {
        self.threads = threads;
        self
    }

    /// Spawns the worker threads and returns the ready-to-use pool.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` when the capacity is 0.
    ///
    /// # Panics
    ///
    /// Panics when the thread count is 0 (checked after the capacity).
    pub fn build(&self) -> Result<ThreadPool> {
        if self.capacity == 0 {
            let reason = "cannot be 0".to_string();
            return Err(Error::InvalidArgument(
                InvalidArgumentError::ThreadPoolCapacity(reason),
            ));
        }
        assert!(self.threads != 0, "threads of ThreadPool cannot be 0");

        let (sender, receiver) = mpmc::bounded(self.capacity);

        // Every worker owns its own clone of the receiver; the original one is
        // dropped when this function returns.
        let threads = (0..self.threads)
            .map(|_| {
                let receiver = receiver.clone();
                Some(thread::spawn(move || Worker { receiver }.run()))
            })
            .collect();

        Ok(ThreadPool {
            threads,
            sender: Some(sender),
        })
    }
}
impl Worker {
    /// Executes queued tasks one at a time until the channel is disconnected
    /// and drained.
    fn run(&self) {
        // `iter()` blocks on each message and stops exactly when `recv` would
        // return an error.
        for task in self.receiver.iter() {
            task.exec();
        }
    }
}
/// Returns the crate-wide default thread pool, creating it on demand.
///
/// Only a `Weak` handle is kept globally. When every returned `Arc` is gone,
/// the pool is dropped (joining its workers), and the next call builds a fresh one.
#[must_use]
pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
    static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));

    // The lock is held across construction so concurrent callers cannot each
    // build their own pool.
    let mut slot = POOL_WEAK.lock_expect();
    if let Some(existing) = slot.upgrade() {
        return existing;
    }

    let fresh = Arc::new(ThreadPool::builder().build().unwrap());
    *slot = Arc::downgrade(&fresh);
    fresh
}
#[cfg(test)]
mod tests {
    use super::*;

    // A zero queue capacity is reported as an error, not a panic.
    #[test]
    fn error_capacity_0() {
        assert!(matches!(
            ThreadPool::builder().capacity(0).build(),
            Err(Error::InvalidArgument(
                InvalidArgumentError::ThreadPoolCapacity(_)
            ))
        ));
    }

    // Zero worker threads is treated as a programming error and panics with
    // a specific message.
    #[test]
    #[should_panic(expected = "threads of ThreadPool cannot be 0")]
    fn panic_thread_0() {
        let _ = ThreadPool::builder().threads(0).build();
    }

    // A valid configuration builds, and dropping the pool shuts its workers
    // down without hanging or panicking.
    #[test]
    fn build_and_drop() {
        let pool = ThreadPool::builder().capacity(1).threads(2).build();
        assert!(pool.is_ok());
        drop(pool);
    }

    // While one handle is alive, later calls share the same pool.
    #[test]
    fn default_pool_is_shared() {
        let first = default_thread_pool();
        let second = default_thread_pool();
        assert!(Arc::ptr_eq(&first, &second));
    }
}