#![allow(dead_code)]
use futures::{ channel::oneshot };
use futures_lite::{ future::block_on };
use crossbeam_channel;
use crate::prelude::全局_CPU数量;
use futures::{
future::BoxFuture,
task::{waker_ref, ArcWake},
};
use std::future::Future;
use std::option::Option;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Condvar, Mutex,
};
use std::task::Context;
use std::thread;
use std::time::Duration;
/// Width of `usize` in bits on the compilation target.
const BITS: usize = usize::BITS as usize;
/// Largest pool size accepted by `ThreadPool::new_named`: a value with every bit of the lower
/// half of a `usize` set.
pub const MAX_SIZE: usize = usize::MAX >> (BITS / 2);
/// A boxed closure that can be sent to another thread and called exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;
/// A unit of work that can be moved to another thread and produces a value of type `R`.
pub trait Task<R: Send>: Send {
    /// Consumes the task and returns its result.
    fn run(self) -> R;
    /// Converts the task into a boxed closure, or returns `None` when the implementor has no
    /// closure representation.
    fn into_fn(self) -> Option<Box<dyn FnOnce() -> R + Send + 'static>>;
    /// Reports whether `into_fn` is expected to return `Some`.
    fn is_fn(&self) -> bool;
}

/// Every sendable `'static` closure returning `R` is a task.
impl<R, F> Task<R> for F
where
    R: Send,
    F: FnOnce() -> R + Send + 'static,
{
    /// Calls the closure.
    fn run(self) -> R {
        (self)()
    }

    /// Boxes the closure; always `Some`.
    fn into_fn(self) -> Option<Box<dyn FnOnce() -> R + Send + 'static>> {
        let boxed: Box<dyn FnOnce() -> R + Send + 'static> = Box::new(self);
        Some(boxed)
    }

    /// Closures are always convertible, so this is always `true`.
    fn is_fn(&self) -> bool {
        true
    }
}
/// Handle to the result of a task submitted to the pool.
pub struct JoinHandle<T: Send> {
    /// Receiving half of the oneshot channel the task's result is sent on.
    pub receiver: oneshot::Receiver<T>,
}

impl<T: Send> JoinHandle<T> {
    /// Blocks the calling thread until the receiver resolves.
    ///
    /// # Errors
    /// Returns `oneshot::Canceled` when the receiver resolves without a value.
    pub fn try_await_complete(self) -> Result<T, oneshot::Canceled> {
        let JoinHandle { receiver } = self;
        block_on(receiver)
    }

    /// Blocks the calling thread until the result is available and returns it.
    ///
    /// # Panics
    /// Panics if `try_await_complete` reports that the channel was cancelled.
    pub fn await_complete(self) -> T {
        let outcome = self.try_await_complete();
        outcome.expect("could not receive message because channel was cancelled")
    }
}
/// A spawned future together with the pool that polls it.
struct AsyncTask {
/// The future still to be driven. `run` takes it out while polling and puts it back only if
/// the poll returned pending.
future: Mutex<Option<BoxFuture<'static, ()>>>,
/// Pool the task is re-submitted to by `wake_by_ref`.
pool: ThreadPool,
}
impl ArcWake for AsyncTask {
    /// Re-submits a clone of the task to its pool so that it gets polled again.
    ///
    /// # Panics
    /// Panics if the pool refuses the task.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let resubmission = arc_self.pool.try_execute(Arc::clone(arc_self));
        resubmission.expect("failed to wake future because message could not be sent to pool");
    }
}
impl Task<()> for Arc<AsyncTask> {
    /// Polls the stored future once, using this task as the waker.
    ///
    /// Does nothing when the slot is empty. A future that is still pending is stored back
    /// for the next wake-up; otherwise the slot stays empty.
    ///
    /// # Panics
    /// Panics if the future mutex cannot be acquired.
    fn run(self) {
        let mut slot = self.future.lock().expect("failed to acquire mutex");
        let mut pending = match slot.take() {
            Some(future) => future,
            None => return,
        };
        let waker = waker_ref(&self);
        let mut context = Context::from_waker(&*waker);
        let still_pending = pending.as_mut().poll(&mut context).is_pending();
        if still_pending {
            *slot = Some(pending);
        }
    }

    /// An async task has no closure form.
    fn into_fn(self) -> Option<Box<dyn FnOnce() + Send + 'static>> {
        None
    }

    /// Always `false`, so `try_execute` wraps the task in a closure that calls `run`.
    fn is_fn(&self) -> bool {
        false
    }
}
/// Marker trait that only `Send` types can implement.
/// NOTE(review): the impls below presumably serve as compile-time assertions that these types
/// are `Send`; nothing else in this file uses the trait — confirm.
trait ThreadSafe: Send {}
impl<R: Send> ThreadSafe for dyn Task<R> {}
impl<R: Send> ThreadSafe for JoinHandle<R> {}
impl ThreadSafe for ThreadPool {}
/// Cloneable handle to a pool of worker threads.
///
/// Clones share the same channel and worker bookkeeping through the two `Arc` fields.
#[derive(Clone)]
pub struct ThreadPool {
/// Worker count below which submitting a task always starts a new worker; those workers are
/// started without a keep-alive.
core_size: usize,
/// Upper bound on the worker count used when deciding whether to start another worker.
max_size: usize,
/// Receive timeout handed to workers that are started beyond the core size.
keep_alive: Duration,
/// Sender and receiver of the shared job channel.
channel_data: Arc<ChannelData>,
/// Pool name, worker counters and join synchronisation state.
worker_data: Arc<WorkerData>,
}
impl ThreadPool {
pub fn new(core_size: usize, max_size: usize, keep_alive: Duration) -> Self {
static POOL_COUNTER: AtomicUsize = AtomicUsize::new(1);
let name = format!(
"rusty_pool_{}",
POOL_COUNTER.fetch_add(1, Ordering::Relaxed)
);
ThreadPool::new_named(name, core_size, max_size, keep_alive)
}
/// Creates a pool with the given name and size limits. No worker thread is started here.
///
/// # Panics
/// Panics if `max_size` is 0, smaller than `core_size`, or larger than `MAX_SIZE`.
pub fn new_named(
    name: String,
    core_size: usize,
    max_size: usize,
    keep_alive: Duration,
) -> Self {
    let (sender, receiver) = crossbeam_channel::unbounded();

    if max_size == 0 || max_size < core_size {
        panic!("max_size must be greater than 0 and greater or equal to the core pool size");
    }
    if max_size > MAX_SIZE {
        panic!(
            "max_size may not exceed {}, the maximum value that can be stored within half the bits of usize ({} -> {} bits in this case)",
            MAX_SIZE,
            BITS,
            BITS / 2
        );
    }

    Self {
        core_size,
        max_size,
        keep_alive,
        channel_data: Arc::new(ChannelData { sender, receiver }),
        worker_data: Arc::new(WorkerData {
            pool_name: name,
            worker_count_data: WorkerCountData::default(),
            // Worker threads are numbered starting at 1.
            worker_number: AtomicUsize::new(1),
            join_notify_condvar: Condvar::new(),
            join_notify_mutex: Mutex::new(()),
            join_generation: AtomicUsize::new(0),
        }),
    }
}
/// Returns the total number of workers currently recorded in the pool's counter.
pub fn get_current_worker_count(&self) -> usize {
self.worker_data.worker_count_data.get_total_worker_count()
}
/// Returns the number of workers currently recorded as idle in the pool's counter.
pub fn get_idle_worker_count(&self) -> usize {
self.worker_data.worker_count_data.get_idle_worker_count()
}
/// Submits `task` to the pool without returning a handle.
///
/// # Panics
/// Panics if `try_execute` returns an error.
pub fn execute<T: Task<()> + 'static>(&self, task: T) {
    match self.try_execute(task) {
        Ok(()) => {}
        Err(_) => panic!("the channel of the thread pool has been closed"),
    }
}
/// Submits `task` to the pool, returning an error instead of panicking when the job cannot
/// be sent to the channel.
///
/// Tasks that report `is_fn()` are converted with `into_fn`; all others are wrapped in a
/// closure that calls `run`.
///
/// # Panics
/// Panics if `is_fn()` returned `true` but `into_fn()` returned `None`.
pub fn try_execute<T: Task<()> + 'static>(
    &self,
    task: T,
) -> Result<(), crossbeam_channel::SendError<Job>> {
    let job: Job = if task.is_fn() {
        task.into_fn()
            .expect("Task::into_fn returned None despite is_fn returning true")
    } else {
        Box::new(move || task.run())
    };
    self.try_execute_task(job)
}
/// Submits `task` and returns a handle to its result.
///
/// # Panics
/// Panics if `try_evaluate` returns an error.
pub fn evaluate<R: Send + 'static, T: Task<R> + 'static>(&self, task: T) -> JoinHandle<R> {
    self.try_evaluate(task)
        .unwrap_or_else(|e| panic!("the channel of the thread pool has been closed: {:?}", e))
}
/// Submits `task` and returns a handle to its result, or the send error if the job could
/// not be handed to the pool.
///
/// The job runs the task and sends the result over a oneshot channel; a failed send (for
/// example because the handle was dropped) is ignored.
pub fn try_evaluate<R: Send + 'static, T: Task<R> + 'static>(
    &self,
    task: T,
) -> Result<JoinHandle<R>, crossbeam_channel::SendError<Job>> {
    let (result_tx, result_rx) = oneshot::channel::<R>();
    let job: Job = Box::new(move || {
        let output = task.run();
        let _ignored_result = result_tx.send(output);
    });
    self.try_execute_task(job)?;
    Ok(JoinHandle {
        receiver: result_rx,
    })
}
/// Submits a job that drives `future` to completion with `block_on` and returns a handle to
/// its output.
///
/// # Panics
/// Panics if the underlying `evaluate` call panics.
pub fn complete<R: Send + 'static>(
&self,
future: impl Future<Output = R> + 'static + Send,
) -> JoinHandle<R> {
self.evaluate(|| block_on(future))
}
/// Like `complete`, but returns the error from `try_evaluate` instead of panicking.
pub fn try_complete<R: Send + 'static>(
&self,
future: impl Future<Output = R> + 'static + Send,
) -> Result<JoinHandle<R>, crossbeam_channel::SendError<Job>> {
self.try_evaluate(|| block_on(future))
}
pub fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
let future_task = Arc::new(AsyncTask {
future: Mutex::new(Some(Box::pin(future))),
pool: self.clone(),
});
self.execute(future_task)
}
pub fn try_spawn(
&self,
future: impl Future<Output = ()> + 'static + Send,
) -> Result<(), crossbeam_channel::SendError<Job>> {
let future_task = Arc::new(AsyncTask {
future: Mutex::new(Some(Box::pin(future))),
pool: self.clone(),
});
self.try_execute(future_task)
}
/// Spawns `future` on the pool and returns a handle to its output.
///
/// # Panics
/// Panics if `try_spawn_await` returns an error.
pub fn spawn_await<R: Send + 'static>(
    &self,
    future: impl Future<Output = R> + 'static + Send,
) -> JoinHandle<R> {
    self.try_spawn_await(future)
        .unwrap_or_else(|e| panic!("the channel of the thread pool has been closed: {:?}", e))
}
/// Spawns `future` on the pool and returns a handle to its output, or the error from
/// `try_spawn` if the task could not be submitted.
///
/// The spawned wrapper awaits `future` and sends the output over a oneshot channel; a
/// failed send is ignored.
pub fn try_spawn_await<R: Send + 'static>(
    &self,
    future: impl Future<Output = R> + 'static + Send,
) -> Result<JoinHandle<R>, crossbeam_channel::SendError<Job>> {
    let (result_tx, result_rx) = oneshot::channel::<R>();
    let forward_output = async move {
        let output = future.await;
        let _ignored_result = result_tx.send(output);
    };
    self.try_spawn(forward_output)?;
    Ok(JoinHandle {
        receiver: result_rx,
    })
}
/// Runs `task` on a newly started worker or, failing that, queues it on the shared channel.
///
/// A worker without keep-alive is started while the total worker count is below `core_size`.
/// A worker with `keep_alive` is started while the total is below `max_size` and no idle
/// worker was observed. Otherwise the task is sent to the channel; only that path can return
/// an error.
#[inline]
fn try_execute_task(&self, task: Job) -> Result<(), crossbeam_channel::SendError<Job>> {
let worker_count_data = &self.worker_data.worker_count_data;
// Snapshot of the packed counter: total in the upper half, idle in the lower half.
let mut worker_count_val = worker_count_data.worker_count.load(Ordering::Relaxed);
let (mut curr_worker_count, idle_worker_count) = WorkerCountData::split(worker_count_val);
let mut curr_idle_count = idle_worker_count;
if curr_worker_count < self.core_size {
let witnessed =
worker_count_data.try_increment_worker_total(worker_count_val, self.core_size);
// The increment took effect if the first exchange matched our snapshot, or if the value
// replaced by a retried exchange was still below the core size.
if witnessed == worker_count_val
|| WorkerCountData::get_total_count(witnessed) < self.core_size
{
let worker = Worker::new(
self.channel_data.receiver.clone(),
Arc::clone(&self.worker_data),
None,
);
worker.start(Some(task));
return Ok(());
}
// The core size was reached concurrently; continue with the value that was witnessed.
curr_worker_count = WorkerCountData::get_total_count(witnessed);
curr_idle_count = WorkerCountData::get_idle_count(witnessed);
worker_count_val = witnessed;
}
// Grow beyond the core size only if no idle worker was seen in the initial snapshot or in
// the value witnessed above.
if curr_worker_count < self.max_size && (idle_worker_count == 0 || curr_idle_count == 0) {
let witnessed =
worker_count_data.try_increment_worker_total(worker_count_val, self.max_size);
// Same success test as above, against the max size.
if witnessed == worker_count_val
|| WorkerCountData::get_total_count(witnessed) < self.max_size
{
let worker = Worker::new(
self.channel_data.receiver.clone(),
Arc::clone(&self.worker_data),
Some(self.keep_alive),
);
worker.start(Some(task));
return Ok(());
}
}
// No worker was started, so hand the task to the existing workers.
self.send_task_to_channel(task)
}
/// Blocks the calling thread without a timeout until the join wait in `_do_join` finishes.
pub fn join(&self) {
self.inner_join(None);
}
/// Like `join`, but waits for at most `time_out`.
pub fn join_timeout(&self, time_out: Duration) {
self.inner_join(Some(time_out));
}
/// Consumes and drops this handle to the pool.
pub fn shutdown(self) {
drop(self);
}
/// Drops this handle, then blocks without a timeout in the same join wait as `join`.
pub fn shutdown_join(self) {
self.inner_shutdown_join(None);
}
/// Like `shutdown_join`, but waits for at most `timeout`.
pub fn shutdown_join_timeout(self, timeout: Duration) {
self.inner_shutdown_join(Some(timeout));
}
/// Returns the name of the pool.
pub fn get_name(&self) -> &str {
&self.worker_data.pool_name
}
/// Starts workers without an initial task until the total worker count reaches `core_size`.
///
/// Each worker is recorded as both total and idle before its thread is started. Returns
/// immediately if the core size has already been reached.
pub fn start_core_threads(&self) {
let worker_count_data = &self.worker_data.worker_count_data;
let core_size = self.core_size;
let mut curr_worker_count = worker_count_data.worker_count.load(Ordering::Relaxed);
if WorkerCountData::get_total_count(curr_worker_count) >= core_size {
return;
}
loop {
// Reserve a slot for one more idle worker.
let witnessed = worker_count_data.try_increment_worker_count(
curr_worker_count,
INCREMENT_TOTAL | INCREMENT_IDLE,
core_size,
);
// A witnessed total at or above the core size means there is nothing left to start.
if WorkerCountData::get_total_count(witnessed) >= core_size {
return;
}
let worker = Worker::new(
self.channel_data.receiver.clone(),
Arc::clone(&self.worker_data),
None,
);
worker.start(None);
// Continue from the value that was replaced. The counter has advanced since, so the next
// exchange is expected to fail once and pick up the live value before incrementing.
curr_worker_count = witnessed;
}
}
/// Sends `task` to the shared job channel and passes any send error back to the caller.
#[inline]
fn send_task_to_channel(&self, task: Job) -> Result<(), crossbeam_channel::SendError<Job>> {
    let sender = &self.channel_data.sender;
    sender.send(task)
}
/// Runs the join wait in `_do_join` with this pool's worker data and receiver.
#[inline]
fn inner_join(&self, time_out: Option<Duration>) {
ThreadPool::_do_join(&self.worker_data, &self.channel_data.receiver, time_out);
}
#[inline]
fn inner_shutdown_join(self, timeout: Option<Duration>) {
let current_worker_data = self.worker_data.clone();
let receiver = self.channel_data.receiver.clone();
drop(self);
ThreadPool::_do_join(¤t_worker_data, &receiver, timeout);
}
/// Blocks until the pool is idle, the join generation changes, or the optional timeout
/// elapses.
///
/// Returns immediately when the pool is already idle. After waiting, it tries to advance
/// the join generation by one.
/// NOTE(review): advancing the generation presumably releases other joiners from the same
/// round the next time they re-check — confirm.
///
/// # Panics
/// Panics if the join mutex cannot be locked or the condvar wait fails.
#[inline]
fn _do_join(
current_worker_data: &Arc<WorkerData>,
receiver: &crossbeam_channel::Receiver<Job>,
time_out: Option<Duration>,
) {
if ThreadPool::is_idle(current_worker_data, receiver) {
return;
}
// Remember the generation this joiner belongs to before taking the lock.
let join_generation = current_worker_data.join_generation.load(Ordering::SeqCst);
let guard = current_worker_data
.join_notify_mutex
.lock()
.expect("could not get join notify mutex lock");
match time_out {
Some(time_out) => {
let _ret_guard = current_worker_data
.join_notify_condvar
// Keep waiting while the generation is unchanged and the pool is still busy.
.wait_timeout_while(guard, time_out, |_| {
join_generation
== current_worker_data.join_generation.load(Ordering::Relaxed)
&& !ThreadPool::is_idle(current_worker_data, receiver)
})
.expect("could not wait for join condvar");
}
None => {
let _ret_guard = current_worker_data
.join_notify_condvar
// Same condition as above, without a timeout.
.wait_while(guard, |_| {
join_generation
== current_worker_data.join_generation.load(Ordering::Relaxed)
&& !ThreadPool::is_idle(current_worker_data, receiver)
})
.expect("could not wait for join condvar");
}
};
// Only one exchange from this generation can succeed; a failed exchange is ignored.
let _ = current_worker_data.join_generation.compare_exchange(
join_generation,
join_generation.wrapping_add(1),
Ordering::SeqCst,
Ordering::SeqCst,
);
}
/// Returns `true` when every counted worker is idle and the job channel is empty.
#[inline]
fn is_idle(
    current_worker_data: &Arc<WorkerData>,
    receiver: &crossbeam_channel::Receiver<Job>,
) -> bool {
    let (total, idle) = current_worker_data.worker_count_data.get_both();
    if idle != total {
        return false;
    }
    receiver.is_empty()
}
}
impl Default for ThreadPool {
    /// Builds a pool whose core size is the value of the global `全局_CPU数量`, whose max
    /// size is twice that, and whose keep-alive is 60 seconds.
    fn default() -> Self {
        let core_size = *全局_CPU数量;
        let max_size = std::cmp::max(core_size, core_size * 2);
        let keep_alive = Duration::from_secs(60);
        ThreadPool::new(core_size, max_size, keep_alive)
    }
}
/// Collects optional settings for constructing a `ThreadPool`; unset fields get defaults in
/// `build`.
#[derive(Default)]
pub struct Builder {
/// Pool name; when unset, `build` uses `ThreadPool::new`, which generates one.
name: Option<String>,
/// Requested core size.
core_size: Option<usize>,
/// Requested max size.
max_size: Option<usize>,
/// Requested keep-alive; defaults to 60 seconds in `build`.
keep_alive: Option<Duration>,
}
impl Builder {
    /// Creates a builder with no settings specified.
    pub fn new() -> Builder {
        Builder::default()
    }

    /// Sets the name of the pool.
    pub fn name(mut self, name: String) -> Builder {
        self.name = Some(name);
        self
    }

    /// Sets the core size of the pool.
    pub fn core_size(mut self, size: usize) -> Builder {
        self.core_size = Some(size);
        self
    }

    /// Sets the max size of the pool.
    pub fn max_size(mut self, size: usize) -> Builder {
        self.max_size = Some(size);
        self
    }

    /// Sets the keep-alive passed to the pool.
    pub fn keep_alive(mut self, keep_alive: Duration) -> Builder {
        self.keep_alive = Some(keep_alive);
        self
    }

    /// Builds the pool, filling in unset values.
    ///
    /// * `core_size` defaults to `num_cpus::get()`, capped by `max_size` (if set) and by
    ///   `MAX_SIZE`.
    /// * `max_size` defaults to twice the core size, capped by `MAX_SIZE`.
    /// * `keep_alive` defaults to 60 seconds.
    ///
    /// # Panics
    /// Panics if the resulting sizes are rejected by `ThreadPool::new_named`.
    pub fn build(self) -> ThreadPool {
        use std::cmp::min;

        let core_size = self.core_size.unwrap_or_else(|| {
            let num_cpus = num_cpus::get();
            if let Some(max_size) = self.max_size {
                min(MAX_SIZE, min(num_cpus, max_size))
            } else {
                min(MAX_SIZE, num_cpus)
            }
        });
        // Saturate instead of overflowing: an absurdly large explicit core size then reaches
        // the size validation in `new_named` rather than an arithmetic overflow panic in
        // debug builds.
        let max_size = self
            .max_size
            .unwrap_or_else(|| min(MAX_SIZE, core_size.saturating_mul(2)));
        let keep_alive = self.keep_alive.unwrap_or_else(|| Duration::from_secs(60));

        if let Some(name) = self.name {
            ThreadPool::new_named(name, core_size, max_size, keep_alive)
        } else {
            ThreadPool::new(core_size, max_size, keep_alive)
        }
    }
}
/// State owned by one worker thread.
#[derive(Clone)]
struct Worker {
/// Receiving end of the job channel.
receiver: crossbeam_channel::Receiver<Job>,
/// Bookkeeping shared by all workers of the pool.
worker_data: Arc<WorkerData>,
/// Receive timeout after which the worker exits; `None` means it waits without a timeout.
keep_alive: Option<Duration>,
}
impl Worker {
/// Bundles the given parts into a `Worker`; no thread is started until `start` is called.
fn new(
receiver: crossbeam_channel::Receiver<Job>,
worker_data: Arc<WorkerData>,
keep_alive: Option<Duration>,
) -> Self {
Worker {
receiver,
worker_data,
keep_alive,
}
}
/// Spawns the thread backing this worker and runs `task` first when one is given.
///
/// The thread is named `<pool_name>_thread_<n>`, with `<n>` taken from the pool's worker
/// counter. After the optional initial task it keeps receiving jobs from the channel until a
/// receive fails (with `keep_alive` set, a timeout also ends the loop), then subtracts itself
/// from both the total and the idle count.
///
/// # Panics
/// Panics if the thread cannot be spawned.
fn start(self, task: Option<Job>) {
let worker_name = format!(
"{}_thread_{}",
self.worker_data.pool_name,
self.worker_data
.worker_number
.fetch_add(1, Ordering::Relaxed)
);
thread::Builder::new()
.name(worker_name)
.spawn(move || {
// Dropped when the closure exits or unwinds; its `Drop` impl starts a replacement worker
// when the thread is panicking.
let mut sentinel = Sentinel::new(&self);
if let Some(task) = task {
self.exec_task_and_notify(&mut sentinel, task);
}
loop {
let received_task: Result<Job, _> = match self.keep_alive {
Some(keep_alive) => self.receiver.recv_timeout(keep_alive).map_err(|_| ()),
None => self.receiver.recv().map_err(|_| ()),
};
match received_task {
Ok(task) => {
// The worker was counted as idle while waiting; it is busy from here on.
self.worker_data.worker_count_data.decrement_worker_idle();
self.exec_task_and_notify(&mut sentinel, task);
}
Err(_) => {
break;
}
}
}
// The exiting worker was counted as idle, so both halves of the counter shrink.
self.worker_data.worker_count_data.decrement_both();
})
.expect("could not spawn thread");
}
/// Runs `task`, then marks the worker idle and notifies joiners if appropriate.
///
/// `sentinel.is_working` is set for the duration of the task so that the sentinel's `Drop`
/// impl can tell whether a panic happened while a task was running.
#[inline]
fn exec_task_and_notify(&self, sentinel: &mut Sentinel, task: Job) {
sentinel.is_working = true;
task();
sentinel.is_working = false;
self.mark_idle_and_notify_joiners_if_no_work();
}
/// Counts this worker as idle and wakes all joiners if it was the last busy worker and the
/// job channel is empty.
///
/// # Panics
/// Panics if the join mutex cannot be locked.
#[inline]
fn mark_idle_and_notify_joiners_if_no_work(&self) {
    let counts = &self.worker_data.worker_count_data;
    let (total_before, idle_before) = counts.increment_worker_idle_ret_both();

    // Before the increment, every worker except this one was already idle.
    let was_last_busy_worker = total_before == idle_before + 1;
    if !(was_last_busy_worker && self.receiver.is_empty()) {
        return;
    }

    // The notification is sent while holding the lock that joiners wait with.
    let _lock = self
        .worker_data
        .join_notify_mutex
        .lock()
        .expect("could not get join notify mutex lock");
    self.worker_data.join_notify_condvar.notify_all();
}
}
/// Guard held by a worker thread; its `Drop` impl reacts when the thread is panicking.
struct Sentinel<'s> {
/// `true` while the worker is executing a task.
is_working: bool,
/// The worker this sentinel belongs to.
worker_ref: &'s Worker,
}
impl Sentinel<'_> {
/// Creates a sentinel for `worker_ref` that starts out as not working.
fn new(worker_ref: &Worker) -> Sentinel<'_> {
Sentinel {
is_working: false,
worker_ref,
}
}
}
impl Drop for Sentinel<'_> {
    /// Does nothing unless the thread is panicking. In that case it marks the worker idle
    /// if a task was running, then starts a replacement worker from a clone of this one.
    fn drop(&mut self) {
        if !thread::panicking() {
            return;
        }

        if self.is_working {
            // The panicking task never reached the regular idle bookkeeping.
            self.worker_ref.mark_idle_and_notify_joiners_if_no_work();
        }

        let replacement = self.worker_ref.clone();
        replacement.start(None);
    }
}
/// Mask selecting the lower half of the packed counter, which holds the idle worker count.
const WORKER_IDLE_MASK: usize = MAX_SIZE;
/// Value that adds one to the total worker count in the upper half of the packed counter.
const INCREMENT_TOTAL: usize = 1 << (BITS / 2);
/// Value that adds one to the idle worker count in the lower half of the packed counter.
const INCREMENT_IDLE: usize = 1;
/// Total and idle worker counts packed into one atomic: total in the upper half of the bits,
/// idle in the lower half.
#[derive(Default)]
struct WorkerCountData {
/// The packed counter.
worker_count: AtomicUsize,
}
/// Accessors and atomic updates for the packed counter. All operations use relaxed ordering;
/// every `fetch_*` based method reports the value from before the update.
impl WorkerCountData {
/// Loads the counter and returns the total worker count.
fn get_total_worker_count(&self) -> usize {
let curr_val = self.worker_count.load(Ordering::Relaxed);
WorkerCountData::get_total_count(curr_val)
}
/// Loads the counter and returns the idle worker count.
fn get_idle_worker_count(&self) -> usize {
let curr_val = self.worker_count.load(Ordering::Relaxed);
WorkerCountData::get_idle_count(curr_val)
}
/// Loads the counter once and returns `(total, idle)`.
fn get_both(&self) -> (usize, usize) {
let curr_val = self.worker_count.load(Ordering::Relaxed);
WorkerCountData::split(curr_val)
}
/// Adds one to both counts; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn increment_both(&self) -> (usize, usize) {
let old_val = self
.worker_count
.fetch_add(INCREMENT_TOTAL | INCREMENT_IDLE, Ordering::Relaxed);
WorkerCountData::split(old_val)
}
/// Subtracts one from both counts; returns the previous `(total, idle)`.
fn decrement_both(&self) -> (usize, usize) {
let old_val = self
.worker_count
.fetch_sub(INCREMENT_TOTAL | INCREMENT_IDLE, Ordering::Relaxed);
WorkerCountData::split(old_val)
}
/// Tries to add one to the total count only; see `try_increment_worker_count`.
fn try_increment_worker_total(&self, expected: usize, max_total: usize) -> usize {
self.try_increment_worker_count(expected, INCREMENT_TOTAL, max_total)
}
/// Adds `increment` to the counter with a compare-exchange loop starting from `expected`.
///
/// Returns the value the successful exchange replaced. If a failed exchange observes a
/// total count of at least `max_total`, that observed value is returned and nothing is
/// added. The limit is only checked on values observed by a failed exchange, not on the
/// initial `expected`.
fn try_increment_worker_count(
&self,
mut expected: usize,
increment: usize,
max_total: usize,
) -> usize {
loop {
match self.worker_count.compare_exchange_weak(
expected,
expected + increment,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(witnessed) => return witnessed,
Err(witnessed) if WorkerCountData::get_total_count(witnessed) >= max_total => {
return witnessed
}
// Retry with the value that is actually stored.
Err(witnessed) => expected = witnessed,
}
}
}
/// Adds one to the total count; returns the previous total.
#[allow(dead_code)]
fn increment_worker_total(&self) -> usize {
let old_val = self
.worker_count
.fetch_add(INCREMENT_TOTAL, Ordering::Relaxed);
WorkerCountData::get_total_count(old_val)
}
/// Adds one to the total count; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn increment_worker_total_ret_both(&self) -> (usize, usize) {
let old_val = self
.worker_count
.fetch_add(INCREMENT_TOTAL, Ordering::Relaxed);
WorkerCountData::split(old_val)
}
/// Subtracts one from the total count; returns the previous total.
#[allow(dead_code)]
fn decrement_worker_total(&self) -> usize {
let old_val = self
.worker_count
.fetch_sub(INCREMENT_TOTAL, Ordering::Relaxed);
WorkerCountData::get_total_count(old_val)
}
/// Subtracts one from the total count; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn decrement_worker_total_ret_both(&self) -> (usize, usize) {
let old_val = self
.worker_count
.fetch_sub(INCREMENT_TOTAL, Ordering::Relaxed);
WorkerCountData::split(old_val)
}
/// Adds one to the idle count; returns the previous idle count.
#[allow(dead_code)]
fn increment_worker_idle(&self) -> usize {
let old_val = self
.worker_count
.fetch_add(INCREMENT_IDLE, Ordering::Relaxed);
WorkerCountData::get_idle_count(old_val)
}
/// Adds one to the idle count; returns the previous `(total, idle)`.
fn increment_worker_idle_ret_both(&self) -> (usize, usize) {
let old_val = self
.worker_count
.fetch_add(INCREMENT_IDLE, Ordering::Relaxed);
WorkerCountData::split(old_val)
}
/// Subtracts one from the idle count; returns the previous idle count.
fn decrement_worker_idle(&self) -> usize {
let old_val = self
.worker_count
.fetch_sub(INCREMENT_IDLE, Ordering::Relaxed);
WorkerCountData::get_idle_count(old_val)
}
/// Subtracts one from the idle count; returns the previous `(total, idle)`.
#[allow(dead_code)]
fn decrement_worker_idle_ret_both(&self) -> (usize, usize) {
let old_val = self
.worker_count
.fetch_sub(INCREMENT_IDLE, Ordering::Relaxed);
WorkerCountData::split(old_val)
}
/// Splits a packed value into `(total, idle)`.
#[inline]
fn split(val: usize) -> (usize, usize) {
let total_count = val >> (BITS / 2);
let idle_count = val & WORKER_IDLE_MASK;
(total_count, idle_count)
}
/// Extracts the total count (upper half) from a packed value.
#[inline]
fn get_total_count(val: usize) -> usize {
val >> (BITS / 2)
}
/// Extracts the idle count (lower half) from a packed value.
#[inline]
fn get_idle_count(val: usize) -> usize {
val & WORKER_IDLE_MASK
}
}
/// Bookkeeping shared between the pool handles and all of their workers.
struct WorkerData {
/// Name of the pool; used as the prefix of worker thread names.
pool_name: String,
/// Packed total/idle worker counter.
worker_count_data: WorkerCountData,
/// Source of the running number appended to worker thread names.
worker_number: AtomicUsize,
/// Condvar joiners wait on; workers notify it from `mark_idle_and_notify_joiners_if_no_work`.
join_notify_condvar: Condvar,
/// Mutex paired with `join_notify_condvar`.
join_notify_mutex: Mutex<()>,
/// Generation counter read and advanced by `ThreadPool::_do_join`.
join_generation: AtomicUsize,
}
/// Both ends of the job channel created in `ThreadPool::new_named`.
struct ChannelData {
/// End used to queue jobs for the workers.
sender: crossbeam_channel::Sender<Job>,
/// End that is cloned into every worker.
receiver: crossbeam_channel::Receiver<Job>,
}