use std::{
fmt,
future::Future,
ops::{Range, RangeInclusive, RangeTo, RangeToInclusive},
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
Condvar,
Mutex,
},
thread,
time::{Duration, Instant},
};
use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
use once_cell::sync::Lazy;
use crate::{
error::PoolFullError,
task::{Coroutine, Task},
worker::{Listener, Worker},
};
// Counter type for the monotonically increasing task statistics
// (completed/panicked counts). On targets with native 64-bit atomics we use
// `AtomicU64` so the counters effectively never wrap; otherwise fall back to
// `AtomicU32`. NOTE(review): `threadfin_has_atomic64` is presumably emitted
// by a build script probing the target — confirm in build.rs.
#[cfg(threadfin_has_atomic64)]
type AtomicCounter = std::sync::atomic::AtomicU64;
#[cfg(not(threadfin_has_atomic64))]
type AtomicCounter = std::sync::atomic::AtomicU32;
/// A constraint on the number of threads in a thread pool, expressed as an
/// inclusive minimum and an inclusive maximum thread count.
///
/// Implemented for `usize` (a fixed size), the standard range types, and
/// [`PerCore`], so pool sizes can be written as `4`, `1..8`, `1..=8`, `..8`,
/// `..=8`, or `PerCore(1..2)`.
pub trait SizeConstraint {
    /// Minimum number of threads the pool keeps alive.
    fn min(&self) -> usize;
    /// Maximum number of threads the pool may grow to.
    fn max(&self) -> usize;
}
/// A plain number pins both bounds to the same value: the pool neither
/// shrinks below nor grows beyond it.
impl SizeConstraint for usize {
    fn min(&self) -> usize {
        self.to_owned()
    }

    fn max(&self) -> usize {
        self.to_owned()
    }
}
/// Half-open range: `start` is the minimum.
///
/// NOTE(review): the exclusive `end` is used directly as the *inclusive*
/// maximum (so `1..8` allows up to 8 threads). This matches how the crate's
/// own default of `PerCore(1..2)` is interpreted, so it is intentional.
impl SizeConstraint for Range<usize> {
    fn min(&self) -> usize {
        let Range { start, end: _ } = self;
        *start
    }

    fn max(&self) -> usize {
        let Range { start: _, end } = self;
        *end
    }
}
/// Inclusive range: both endpoints are used as-is.
impl SizeConstraint for RangeInclusive<usize> {
    fn min(&self) -> usize {
        self.start().to_owned()
    }

    fn max(&self) -> usize {
        self.end().to_owned()
    }
}
/// Open-start range (`..n`): no minimum, so idle pools may shrink to zero
/// threads; `end` is used as the inclusive maximum (see `Range` impl).
impl SizeConstraint for RangeTo<usize> {
    fn min(&self) -> usize {
        usize::MIN
    }

    fn max(&self) -> usize {
        let RangeTo { end } = self;
        *end
    }
}
/// Open-start inclusive range (`..=n`): no minimum, maximum of `end`.
impl SizeConstraint for RangeToInclusive<usize> {
    fn min(&self) -> usize {
        usize::MIN
    }

    fn max(&self) -> usize {
        let RangeToInclusive { end } = self;
        *end
    }
}
/// Wrapper that scales an inner size constraint by the number of available
/// CPU cores (e.g. `PerCore(1..2)` on a 4-core machine means 4..8 threads).
pub struct PerCore<T>(pub T);

// Logical CPU count, computed lazily exactly once and clamped to at least 1
// so `PerCore` can never produce a zero maximum from a non-zero constraint.
static CORE_COUNT: Lazy<usize> = Lazy::new(|| num_cpus::get().max(1));
impl<T> From<T> for PerCore<T> {
fn from(size: T) -> Self {
Self(size)
}
}
/// Delegates to the wrapped constraint and multiplies both bounds by the
/// (lazily computed) CPU core count.
impl<T: SizeConstraint> SizeConstraint for PerCore<T> {
    fn min(&self) -> usize {
        self.0.min() * *CORE_COUNT
    }

    fn max(&self) -> usize {
        self.0.max() * *CORE_COUNT
    }
}
/// A builder for configuring and creating a [`ThreadPool`].
#[derive(Debug)]
pub struct Builder {
    // Name assigned to each worker thread; `None` leaves the OS default.
    name: Option<String>,
    // Pre-validated (min, max) thread counts; `None` selects a per-core
    // default when `build` runs.
    size: Option<(usize, usize)>,
    // Worker thread stack size in bytes; `None` uses the std default.
    stack_size: Option<usize>,
    // Maximum number of queued tasks; `None` means an unbounded queue.
    queue_limit: Option<usize>,
    // Maximum number of tasks a single worker may run concurrently.
    worker_concurrency_limit: usize,
    // How long an idle thread above the minimum count stays alive.
    keep_alive: Duration,
}
impl Default for Builder {
fn default() -> Self {
Self {
name: None,
size: None,
stack_size: None,
queue_limit: None,
worker_concurrency_limit: 16,
keep_alive: Duration::from_secs(60),
}
}
}
impl Builder {
pub fn name<T: Into<String>>(mut self, name: T) -> Self {
let name = name.into();
if name.as_bytes().contains(&0) {
panic!("thread pool name must not contain null bytes");
}
self.name = Some(name);
self
}
pub fn size<S: SizeConstraint>(mut self, size: S) -> Self {
let (min, max) = (size.min(), size.max());
if min > max {
panic!("thread pool minimum size cannot be larger than maximum size");
}
if max == 0 {
panic!("thread pool maximum size must be non-zero");
}
self.size = Some((min, max));
self
}
pub fn stack_size(mut self, size: usize) -> Self {
self.stack_size = Some(size);
self
}
pub fn queue_limit(mut self, limit: usize) -> Self {
self.queue_limit = Some(limit);
self
}
pub fn keep_alive(mut self, duration: Duration) -> Self {
self.keep_alive = duration;
self
}
pub fn worker_concurrency_limit(mut self, limit: usize) -> Self {
self.worker_concurrency_limit = limit;
self
}
pub fn build(self) -> ThreadPool {
let size = self.size.unwrap_or_else(|| {
let size = PerCore(1..2);
(size.min(), size.max())
});
let shared = Shared {
min_threads: size.0,
max_threads: size.1,
thread_count: Default::default(),
running_tasks_count: Default::default(),
completed_tasks_count: Default::default(),
panicked_tasks_count: Default::default(),
keep_alive: self.keep_alive,
shutdown_cvar: Condvar::new(),
};
let pool = ThreadPool {
thread_name: self.name,
stack_size: self.stack_size,
concurrency_limit: self.worker_concurrency_limit,
queue: self.queue_limit.map(bounded).unwrap_or_else(unbounded),
immediate_queue: bounded(0),
shared: Arc::new(shared),
};
for _ in 0..size.0 {
let result = pool.spawn_thread(None);
assert!(result.is_ok());
}
pool
}
}
/// A thread pool for running multiple tasks on a configurable group of
/// worker threads.
pub struct ThreadPool {
    // Name given to every worker thread, if configured.
    thread_name: Option<String>,
    // Worker thread stack size in bytes, if configured.
    stack_size: Option<usize>,
    // Per-worker concurrent task limit, passed to each `Worker`.
    concurrency_limit: usize,
    // Main task queue. Both halves are kept: the sender to submit work, the
    // receiver to clone into each newly spawned worker.
    queue: (Sender<Coroutine>, Receiver<Coroutine>),
    // Zero-capacity rendezvous channel (created with `bounded(0)`) used to
    // hand a task directly to a worker that is ready to receive right now.
    immediate_queue: (Sender<Coroutine>, Receiver<Coroutine>),
    // State shared with all worker threads.
    shared: Arc<Shared>,
}
impl Default for ThreadPool {
fn default() -> Self {
Self::new()
}
}
impl ThreadPool {
    /// Create a thread pool with the default configuration.
    #[inline]
    pub fn new() -> Self {
        Self::builder().build()
    }

    /// Get a builder for creating a customized thread pool.
    #[inline]
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Get the number of worker threads currently alive in the pool.
    pub fn threads(&self) -> usize {
        *self.shared.thread_count.lock().unwrap()
    }

    /// Get the number of tasks waiting in the main queue for a free worker.
    #[inline]
    pub fn queued_tasks(&self) -> usize {
        // The sender reports how many messages are buffered in the channel.
        self.queue.0.len()
    }

    /// Get the number of tasks currently being executed.
    #[inline]
    pub fn running_tasks(&self) -> usize {
        self.shared.running_tasks_count.load(Ordering::Relaxed)
    }

    /// Get the total number of tasks completed since the pool was created,
    /// including tasks that panicked.
    #[inline]
    // `.into()` widens the counter to u64; it is a no-op (hence the allow)
    // when `AtomicCounter` is `AtomicU64`.
    #[allow(clippy::useless_conversion)]
    pub fn completed_tasks(&self) -> u64 {
        self.shared.completed_tasks_count.load(Ordering::Relaxed).into()
    }

    /// Get the total number of tasks that panicked while running.
    #[inline]
    // `.into()` widens the counter to u64; it is a no-op (hence the allow)
    // when `AtomicCounter` is `AtomicU64`.
    #[allow(clippy::useless_conversion)]
    pub fn panicked_tasks(&self) -> u64 {
        self.shared.panicked_tasks_count.load(Ordering::SeqCst).into()
    }

    /// Submit a closure to the pool, returning a [`Task`] handle for its
    /// result. Blocks if the queue is bounded and currently full.
    pub fn execute<T, F>(&self, closure: F) -> Task<T>
    where
        T: Send + 'static,
        F: FnOnce() -> T + Send + 'static,
    {
        let (task, coroutine) = Task::from_closure(closure);
        self.execute_coroutine(coroutine);
        task
    }

    /// Submit a future to the pool, returning a [`Task`] handle for its
    /// output. Blocks if the queue is bounded and currently full.
    pub fn execute_future<T, F>(&self, future: F) -> Task<T>
    where
        T: Send + 'static,
        F: Future<Output = T> + Send + 'static,
    {
        let (task, coroutine) = Task::from_future(future);
        self.execute_coroutine(coroutine);
        task
    }

    /// Try to submit a closure without blocking. If the pool is saturated,
    /// the caller's closure is handed back inside a [`PoolFullError`].
    pub fn try_execute<T, F>(&self, closure: F) -> Result<Task<T>, PoolFullError<F>>
    where
        T: Send + 'static,
        F: FnOnce() -> T + Send + 'static,
    {
        let (task, coroutine) = Task::from_closure(closure);
        self.try_execute_coroutine(coroutine)
            .map(|_| task)
            // Recover the original closure from the rejected coroutine.
            .map_err(|coroutine| PoolFullError(coroutine.into_inner_closure()))
    }

    /// Try to submit a future without blocking. If the pool is saturated,
    /// the caller's future is handed back inside a [`PoolFullError`].
    pub fn try_execute_future<T, F>(&self, future: F) -> Result<Task<T>, PoolFullError<F>>
    where
        T: Send + 'static,
        F: Future<Output = T> + Send + 'static,
    {
        let (task, coroutine) = Task::from_future(future);
        self.try_execute_coroutine(coroutine)
            .map(|_| task)
            .map_err(|coroutine| PoolFullError(coroutine.into_inner_future()))
    }

    /// Submit a coroutine, falling back to a blocking send on the main queue
    /// when the non-blocking paths are exhausted.
    fn execute_coroutine(&self, coroutine: Coroutine) {
        if let Err(coroutine) = self.try_execute_coroutine(coroutine) {
            // Cannot fail: the channel is never disconnected while `self`
            // also owns the receiving half.
            self.queue.0.send(coroutine).unwrap();
        }
    }

    /// Try to submit a coroutine without blocking, handing it back in `Err`
    /// on failure. The strategy, in order:
    /// 1. offer it on the zero-capacity immediate channel, which succeeds
    ///    only if an idle worker is ready to receive right now;
    /// 2. otherwise spawn a new thread seeded with the task (fails when the
    ///    pool is at its maximum size);
    /// 3. otherwise push it onto the main queue without blocking (fails when
    ///    a queue limit is set and the queue is full).
    fn try_execute_coroutine(&self, coroutine: Coroutine) -> Result<(), Coroutine> {
        if let Err(e) = self.immediate_queue.0.try_send(coroutine) {
            // The channel cannot be disconnected while `self` holds a
            // receiver; a failed try_send here can only mean "full".
            debug_assert!(!e.is_disconnected());
            if let Err(e) = self.spawn_thread(Some(e.into_inner())) {
                // spawn_thread only errors with the Some(task) it was given,
                // so this unwrap cannot panic.
                if let Err(e) = self.queue.0.try_send(e.unwrap()) {
                    return Err(e.into_inner());
                }
            }
        }
        Ok(())
    }

    /// Shut down the pool and block until all worker threads have exited.
    pub fn join(self) {
        self.join_internal(None);
    }

    /// Shut down the pool, waiting at most `timeout` for threads to exit.
    /// Returns `false` if the timeout elapsed first.
    pub fn join_timeout(self, timeout: Duration) -> bool {
        self.join_deadline(Instant::now() + timeout)
    }

    /// Shut down the pool, waiting until `deadline` at the latest for
    /// threads to exit. Returns `false` if the deadline passed first.
    pub fn join_deadline(self, deadline: Instant) -> bool {
        self.join_internal(Some(deadline))
    }

    /// Common join logic; a `None` deadline means wait forever.
    fn join_internal(self, deadline: Option<Instant>) -> bool {
        // Drop our sender so the main queue disconnects once drained;
        // presumably the worker run loop exits when it observes the
        // disconnect (see worker.rs — TODO confirm).
        drop(self.queue.0);
        let mut thread_count = self.shared.thread_count.lock().unwrap();
        // Loop to tolerate spurious condvar wakeups.
        while *thread_count > 0 {
            if let Some(deadline) = deadline {
                if let Some(timeout) = deadline.checked_duration_since(Instant::now()) {
                    thread_count = self
                        .shared
                        .shutdown_cvar
                        .wait_timeout(thread_count, timeout)
                        .unwrap()
                        .0;
                } else {
                    // The deadline already passed and threads remain.
                    return false;
                }
            } else {
                thread_count = self.shared.shutdown_cvar.wait(thread_count).unwrap();
            }
        }
        true
    }

    /// Attempt to spawn one more worker thread, optionally seeded with an
    /// initial task to run first.
    ///
    /// Returns the task back in `Err` if the pool is already at its maximum
    /// size.
    fn spawn_thread(&self, initial_task: Option<Coroutine>) -> Result<(), Option<Coroutine>> {
        // Forwards worker lifecycle events into the shared pool counters.
        struct WorkerListener {
            shared: Arc<Shared>,
        }

        impl Listener for WorkerListener {
            fn on_task_started(&mut self) {
                self.shared
                    .running_tasks_count
                    .fetch_add(1, Ordering::Relaxed);
            }

            fn on_task_completed(&mut self, panicked: bool) {
                self.shared
                    .running_tasks_count
                    .fetch_sub(1, Ordering::Relaxed);
                self.shared
                    .completed_tasks_count
                    .fetch_add(1, Ordering::Relaxed);
                if panicked {
                    self.shared
                        .panicked_tasks_count
                        .fetch_add(1, Ordering::SeqCst);
                }
            }

            // Permit an idle worker to exit only while the pool is above its
            // configured minimum size.
            fn on_idle(&mut self) -> bool {
                *self.shared.thread_count.lock().unwrap() > self.shared.min_threads
            }
        }

        impl Drop for WorkerListener {
            // Runs when the worker thread ends (including after a panic):
            // retire the thread from the count and wake any joiners waiting
            // on the condvar.
            fn drop(&mut self) {
                if let Ok(mut count) = self.shared.thread_count.lock() {
                    *count = count.saturating_sub(1);
                    self.shared.shutdown_cvar.notify_all();
                }
            }
        }

        let mut thread_count = self.shared.thread_count.lock().unwrap();

        // Respect the configured maximum; hand the task back to the caller.
        if *thread_count >= self.shared.max_threads {
            return Err(initial_task);
        }

        let mut builder = thread::Builder::new();

        if let Some(name) = self.thread_name.as_ref() {
            builder = builder.name(name.clone());
        }

        if let Some(size) = self.stack_size {
            builder = builder.stack_size(size);
        }

        // Count the thread before it starts, under the lock, so concurrent
        // spawners cannot overshoot the maximum; the listener's Drop undoes
        // this increment when the thread exits.
        *thread_count += 1;

        let worker = Worker::new(
            initial_task,
            self.queue.1.clone(),
            self.immediate_queue.1.clone(),
            self.concurrency_limit,
            self.shared.keep_alive,
            WorkerListener {
                shared: self.shared.clone(),
            },
        );

        // Release the lock before the (potentially slow) OS thread spawn.
        drop(thread_count);

        // NOTE(review): this panics if the OS refuses to spawn a thread,
        // even though the thread count was already incremented — the worker
        // closure (and thus the listener) is dropped on failure, which
        // should undo the count, but the panic itself may be undesirable.
        builder.spawn(move || worker.run()).unwrap();

        Ok(())
    }
}
impl fmt::Debug for ThreadPool {
    /// Formats the pool's live task statistics (queued, running, completed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut debug = f.debug_struct("ThreadPool");
        debug.field("queued_tasks", &self.queued_tasks());
        debug.field("running_tasks", &self.running_tasks());
        debug.field("completed_tasks", &self.completed_tasks());
        debug.finish()
    }
}
/// Pool state shared between the `ThreadPool` handle and its workers.
struct Shared {
    // Number of threads the pool keeps alive even when idle.
    min_threads: usize,
    // Hard upper bound on the number of worker threads.
    max_threads: usize,
    // Live worker-thread count. A Mutex rather than an atomic so it can be
    // paired with `shutdown_cvar` for shutdown notification.
    thread_count: Mutex<usize>,
    // Number of tasks currently executing.
    running_tasks_count: AtomicUsize,
    // Total tasks finished since pool creation (includes panicked tasks).
    completed_tasks_count: AtomicCounter,
    // Total tasks that panicked while running.
    panicked_tasks_count: AtomicCounter,
    // Idle timeout for threads above `min_threads`.
    keep_alive: Duration,
    // Signalled by a worker's listener Drop whenever a thread exits; waited
    // on by `join`.
    shutdown_cvar: Condvar,
}