use crate::co_pool::creator::CoroutineCreator;
use crate::co_pool::task::Task;
use crate::common::beans::BeanFactory;
use crate::common::constants::PoolState;
use crate::common::ordered_work_steal::{OrderedLocalQueue, OrderedWorkStealQueue};
use crate::common::{get_timeout_time, now, CondvarBlocker};
use crate::coroutine::suspender::Suspender;
use crate::scheduler::{SchedulableCoroutine, Scheduler};
use crate::{error, impl_current_for, impl_display_by_debug, impl_for_named, trace, warn};
use dashmap::{DashMap, DashSet};
use once_cell::sync::Lazy;
use std::cell::Cell;
use std::ffi::c_longlong;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::Duration;
pub mod task;
mod state;
mod creator;
// Global map from task name -> name of the coroutine currently executing it
// (inserted/removed in `CoroutinePool::try_run`, read by `try_cancel_task`).
static RUNNING_TASKS: Lazy<DashMap<&str, &str>> = Lazy::new(DashMap::new);
// Global set of task names flagged for cancellation before they started running;
// `try_run` drops such tasks instead of executing them.
static CANCEL_TASKS: Lazy<DashSet<&str>> = Lazy::new(DashSet::new);
/// A coroutine pool: a [`Scheduler`] of worker coroutines that pull [`Task`]s
/// from an ordered work-steal queue and publish their results for waiters.
#[repr(C)]
#[derive(Debug)]
pub struct CoroutinePool<'p> {
    // Lifecycle state (Running -> Stopping -> Stopped); single-threaded Cell access.
    state: Cell<PoolState>,
    #[doc = include_str!("../../docs/en/ordered-work-steal.md")]
    task_queue: OrderedLocalQueue<'p, Task<'p>>,
    // The underlying scheduler; `Deref`/`DerefMut` expose it directly.
    workers: Scheduler<'p>,
    // Count of live worker coroutines (incremented in `submit_co`).
    running: AtomicUsize,
    // Consecutive failed pops by workers; drives the back-off in `try_grow`.
    pop_fail_times: AtomicUsize,
    // Minimum number of workers kept alive even when idle.
    min_size: AtomicUsize,
    // Hard cap on worker coroutines; `submit_co` rejects beyond this.
    max_size: AtomicUsize,
    // How long (same clock units as `now()`) an idle worker above `min_size`
    // survives before recycling itself.
    keep_alive_time: AtomicU64,
    // Parks idle workers; notified whenever a task is pushed.
    blocker: Arc<CondvarBlocker>,
    // Per-task (mutex, condvar) pairs used by `wait_task_result`/`notify`.
    waits: DashMap<&'p str, Arc<(Mutex<bool>, Condvar)>>,
    // Completed task results, keyed by task name, until retrieved.
    results: DashMap<String, Result<Option<usize>, &'p str>>,
    // Task names whose results nobody will collect; `try_run` discards them.
    no_waits: DashSet<&'p str>,
}
impl Drop for CoroutinePool<'_> {
    /// Stops the pool on drop and asserts it shut down cleanly.
    fn drop(&mut self) {
        // Never panic inside drop during an unwind: a double panic aborts the process.
        if std::thread::panicking() {
            return;
        }
        // Best-effort graceful shutdown with a 30s budget for in-flight tasks.
        self.stop(Duration::from_secs(30)).unwrap_or_else(|e| {
            panic!("Failed to stop coroutine pool {} due to {e} !", self.name())
        });
        assert_eq!(
            PoolState::Stopped,
            self.state(),
            "The coroutine pool is not stopped !"
        );
        assert_eq!(
            0,
            self.get_running_size(),
            "There are still tasks in progress !"
        );
        // Tasks left in the queue at this point are silently lost; log it.
        if !self.task_queue.is_empty() {
            error!("Forget some tasks when closing the pool");
        }
    }
}
impl Default for CoroutinePool<'_> {
    /// Builds a pool named after the current thread, with the default stack
    /// size, no minimum workers, a 65536-worker cap, and zero keep-alive time.
    fn default() -> Self {
        let name = format!(
            "open-coroutine-pool-{:?}",
            std::thread::current().id()
        );
        Self::new(
            name,
            crate::common::constants::DEFAULT_STACK_SIZE,
            0,
            65536,
            0,
        )
    }
}
impl<'p> Deref for CoroutinePool<'p> {
    type Target = Scheduler<'p>;

    /// Delegates to the inner [`Scheduler`], exposing its API on the pool.
    fn deref(&self) -> &Self::Target {
        &self.workers
    }
}
impl DerefMut for CoroutinePool<'_> {
    /// Mutable counterpart of `Deref`: exposes the inner [`Scheduler`] mutably.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.workers
    }
}
// Crate-local derive-style macros (definitions not visible here — presumably:
// a `Named` impl, thread-local current-instance accessors used via
// `Self::init_current`/`Self::current`/`Self::clean_current`, and a `Display`
// impl forwarding to `Debug`).
impl_for_named!(CoroutinePool<'p>);
impl_current_for!(COROUTINE_POOL, CoroutinePool<'p>);
impl_display_by_debug!(CoroutinePool<'p>);
impl<'p> CoroutinePool<'p> {
    /// Creates a pool in the `Running` state.
    ///
    /// The worker [`Scheduler`] gets a [`CoroutineCreator`] listener attached
    /// (presumably responsible for growing/shrinking workers on scheduler
    /// events — its code is not visible here). The task queue is a local view
    /// of the process-wide ordered work-steal queue bean, so tasks can be
    /// stolen across pools sharing that bean.
    #[must_use]
    pub fn new(
        name: String,
        stack_size: usize,
        min_size: usize,
        max_size: usize,
        keep_alive_time: u64,
    ) -> Self {
        let mut workers = Scheduler::new(name, stack_size);
        workers.add_listener(CoroutineCreator::default());
        CoroutinePool {
            state: Cell::new(PoolState::Running),
            workers,
            running: AtomicUsize::new(0),
            pop_fail_times: AtomicUsize::new(0),
            min_size: AtomicUsize::new(min_size),
            max_size: AtomicUsize::new(max_size),
            task_queue: BeanFactory::get_or_default::<OrderedWorkStealQueue<Task<'p>>>(
                crate::common::constants::TASK_GLOBAL_QUEUE_BEAN,
            )
            .local_queue(),
            keep_alive_time: AtomicU64::new(keep_alive_time),
            blocker: Arc::default(),
            results: DashMap::new(),
            waits: DashMap::default(),
            no_waits: DashSet::default(),
        }
    }

    /// Sets the minimum number of workers kept alive while idle.
    pub fn set_min_size(&self, min_size: usize) {
        self.min_size.store(min_size, Ordering::Release);
    }

    /// Returns the configured minimum worker count.
    pub fn get_min_size(&self) -> usize {
        self.min_size.load(Ordering::Acquire)
    }

    /// Returns the current number of live worker coroutines
    /// (the counter incremented by [`Self::submit_co`]).
    pub fn get_running_size(&self) -> usize {
        self.running.load(Ordering::Acquire)
    }

    /// Sets the hard cap on worker coroutines.
    pub fn set_max_size(&self, max_size: usize) {
        self.max_size.store(max_size, Ordering::Release);
    }

    /// Returns the configured maximum worker count.
    pub fn get_max_size(&self) -> usize {
        self.max_size.load(Ordering::Acquire)
    }

    /// Sets how long an idle worker above the minimum survives before
    /// recycling itself (units match `now()` — TODO confirm: nanoseconds).
    pub fn set_keep_alive_time(&self, keep_alive_time: u64) {
        self.keep_alive_time
            .store(keep_alive_time, Ordering::Release);
    }

    /// Returns the configured keep-alive time.
    pub fn get_keep_alive_time(&self) -> u64 {
        self.keep_alive_time.load(Ordering::Acquire)
    }

    /// Returns `true` when no tasks are queued.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Returns the number of queued (not yet running) tasks visible to this
    /// pool's local queue.
    pub fn size(&self) -> usize {
        self.task_queue.len()
    }

    /// Stops the pool, giving in-flight and queued tasks up to `dur` to finish.
    ///
    /// Idempotent across states: `Running` transitions through `Stopping`;
    /// a pool already `Stopping` just continues the stop; a `Stopped` pool
    /// only re-runs the waiter cleanup.
    ///
    /// # Errors
    /// Propagates state-transition and scheduling failures.
    pub fn stop(&mut self, dur: Duration) -> std::io::Result<()> {
        match self.state() {
            PoolState::Running => {
                assert_eq!(PoolState::Running, self.stopping()?);
                self.do_stop(dur)?;
            }
            PoolState::Stopping => self.do_stop(dur)?,
            PoolState::Stopped => self.do_clean(),
        }
        Ok(())
    }

    /// Drains remaining work for up to `dur`, marks the pool `Stopped`,
    /// then fails and wakes any remaining waiters.
    fn do_stop(&mut self, dur: Duration) -> std::io::Result<()> {
        _ = self.try_timed_schedule_task(dur)?;
        assert_eq!(PoolState::Stopping, self.stopped()?);
        self.do_clean();
        Ok(())
    }

    /// For every thread still blocked in `wait_task_result`, publishes an
    /// "pool has stopped" error result and wakes it.
    fn do_clean(&mut self) {
        for r in &self.waits {
            let task_name = *r.key();
            _ = self
                .results
                .insert(task_name.to_string(), Err("The coroutine pool has stopped"));
            self.notify(task_name);
        }
    }

    /// Submits a task to the pool.
    ///
    /// * `name` — optional task name; defaults to `"<pool>@<uuid-v4>"`.
    /// * `func` — the task body, receiving `param` and returning an optional result.
    /// * `priority` — ordering hint for the ordered work-steal queue.
    ///
    /// Returns the (possibly generated) task name, to be used with
    /// [`Self::wait_task_result`] / [`Self::try_take_task_result`].
    ///
    /// # Errors
    /// Fails when the pool is stopping or stopped.
    pub fn submit_task(
        &self,
        name: Option<String>,
        func: impl FnOnce(Option<usize>) -> Option<usize> + 'p,
        param: Option<usize>,
        priority: Option<c_longlong>,
    ) -> std::io::Result<String> {
        match self.state() {
            PoolState::Running => {}
            PoolState::Stopping | PoolState::Stopped => {
                return Err(Error::other("The coroutine pool is stopping or stopped !"))
            }
        }
        let name = name.unwrap_or(format!("{}@{}", self.name(), uuid::Uuid::new_v4()));
        self.submit_raw_task(Task::new(name.clone(), func, param, priority));
        Ok(name)
    }

    /// Enqueues a prebuilt task and wakes one parked worker.
    pub(crate) fn submit_raw_task(&self, task: Task<'p>) {
        self.task_queue.push(task);
        self.blocker.notify();
    }

    /// Removes and returns the task's result, if it has completed.
    pub fn try_take_task_result(&self, task_name: &str) -> Option<Result<Option<usize>, &'p str>> {
        self.results.remove(task_name).map(|(_, r)| r)
    }

    /// Discards a task's result: takes it if already present, otherwise marks
    /// the task so `try_run` drops its result on completion, and clears any
    /// pending cancel flag for it.
    ///
    /// NOTE(review): `Box::leak` intentionally(?) leaks the name to satisfy the
    /// `&'p str` key lifetime — each call with a not-yet-finished task leaks a
    /// small allocation; confirm this is the accepted crate-wide trade-off.
    pub fn clean_task_result(&self, task_name: &str) {
        if self.try_take_task_result(task_name).is_some() {
            return;
        }
        _ = self.no_waits.insert(Box::leak(Box::from(task_name)));
        _ = CANCEL_TASKS.remove(task_name);
    }

    /// Blocks until the named task completes or `wait_time` elapses.
    ///
    /// Two strategies:
    /// * Inside a coroutine (`SchedulableCoroutine::current()` is `Some`):
    ///   busy-loops on `try_run`, helping execute queued tasks until the
    ///   result appears or the deadline passes.
    /// * On a plain thread: registers a `(Mutex<bool>, Condvar)` pair in
    ///   `waits` and sleeps in `wait_timeout_while` until `notify` flips the
    ///   pending flag (or spurious-wakeup-safe timeout).
    ///
    /// NOTE(review): the key is `Box::leak`ed on every call (see
    /// `clean_task_result`); also the get-then-insert on `waits` assumes a
    /// single waiter per task name — the `assert!` would fire if two threads
    /// raced to register for the same task. Confirm single-waiter contract.
    ///
    /// # Errors
    /// `ErrorKind::TimedOut` on deadline expiry; other errors from a poisoned
    /// wait mutex.
    pub fn wait_task_result(
        &self,
        task_name: &str,
        wait_time: Duration,
    ) -> std::io::Result<Result<Option<usize>, &str>> {
        let key = Box::leak(Box::from(task_name));
        if let Some(r) = self.try_take_task_result(key) {
            self.notify(key);
            return Ok(r);
        }
        if SchedulableCoroutine::current().is_some() {
            let timeout_time = get_timeout_time(wait_time);
            loop {
                // Help drain the queue instead of blocking the carrier thread.
                _ = self.try_run();
                if let Some(r) = self.try_take_task_result(key) {
                    return Ok(r);
                }
                if timeout_time.saturating_sub(now()) == 0 {
                    return Err(Error::new(ErrorKind::TimedOut, "wait timeout"));
                }
            }
        }
        let arc = if let Some(arc) = self.waits.get(key) {
            arc.clone()
        } else {
            // `true` means "result still pending".
            let arc = Arc::new((Mutex::new(true), Condvar::new()));
            assert!(self.waits.insert(key, arc.clone()).is_none());
            arc
        };
        let (lock, cvar) = &*arc;
        drop(
            cvar.wait_timeout_while(
                lock.lock().map_err(|e| Error::other(format!("{e}")))?,
                wait_time,
                |&mut pending| pending,
            )
            .map_err(|e| Error::other(format!("{e}")))?,
        );
        if let Some(r) = self.try_take_task_result(key) {
            self.notify(key);
            return Ok(r);
        }
        Err(Error::new(ErrorKind::TimedOut, "wait timeout"))
    }

    /// Whether idle workers may exit regardless of `min_size`
    /// (i.e. the pool is shutting down).
    fn can_recycle(&self) -> bool {
        match self.state() {
            PoolState::Running => false,
            PoolState::Stopping | PoolState::Stopped => true,
        }
    }

    /// Spawns one additional worker coroutine if there is queued work.
    ///
    /// The worker loops: run a task if available; otherwise exit when it has
    /// been idle past `keep_alive_time` while the pool is above `min_size`,
    /// or when the pool is stopping. Short idle streaks yield via
    /// `suspender.suspend()`; once consecutive pop failures reach the number
    /// of running workers, it parks 1ms on the shared blocker instead.
    fn try_grow(&self) -> std::io::Result<()> {
        if self.task_queue.is_empty() {
            trace!("The coroutine pool:{} has no task !", self.name());
            return Ok(());
        }
        let create_time = now();
        self.submit_co(
            move |suspender, ()| {
                loop {
                    let pool = Self::current().expect("current pool not found");
                    if pool.try_run().is_some() {
                        pool.reset_pop_fail_times();
                        continue;
                    }
                    let running = pool.get_running_size();
                    // Precedence note: `(idle_expired && above_min) || can_recycle`.
                    if now().saturating_sub(create_time) >= pool.get_keep_alive_time()
                        && running > pool.get_min_size()
                        || pool.can_recycle()
                    {
                        return None;
                    }
                    _ = pool.pop_fail_times.fetch_add(1, Ordering::Release);
                    match pool.pop_fail_times.load(Ordering::Acquire).cmp(&running) {
                        std::cmp::Ordering::Less => suspender.suspend(),
                        std::cmp::Ordering::Equal | std::cmp::Ordering::Greater => {
                            pool.blocker.clone().block(Duration::from_millis(1));
                            pool.reset_pop_fail_times();
                        }
                    }
                }
            },
            None,
            None,
        )
    }

    /// Submits a raw coroutine to the underlying scheduler, enforcing
    /// `max_size` and bumping the live-worker counter on success.
    ///
    /// NOTE(review): `running` is never decremented in this file — presumably
    /// the `CoroutineCreator` listener decrements it when a worker completes;
    /// verify in `creator.rs`.
    ///
    /// # Errors
    /// Fails when the pool already has `max_size` workers, or when the
    /// scheduler rejects the coroutine.
    pub fn submit_co(
        &self,
        f: impl FnOnce(&Suspender<(), ()>, ()) -> Option<usize> + 'static,
        stack_size: Option<usize>,
        priority: Option<c_longlong>,
    ) -> std::io::Result<()> {
        if self.get_running_size() >= self.get_max_size() {
            trace!(
                "The coroutine pool:{} has reached its maximum size !",
                self.name()
            );
            return Err(Error::other(
                "The coroutine pool has reached its maximum size !",
            ));
        }
        self.deref().submit_co(f, stack_size, priority).map(|_| {
            _ = self.running.fetch_add(1, Ordering::Release);
        })
    }

    /// Resets the consecutive-pop-failure counter (a task was found).
    fn reset_pop_fail_times(&self) {
        self.pop_fail_times.store(0, Ordering::Release);
    }

    /// Pops and runs one task. Returns `Some(())` if a task was popped
    /// (even if it was cancelled instead of run), `None` if the queue was empty.
    ///
    /// Flow: drop the task if it was flagged in `CANCEL_TASKS`; otherwise
    /// record the task->coroutine mapping while it runs, then either discard
    /// the result (if `clean_task_result` marked it) or publish it and wake
    /// the waiter.
    ///
    /// NOTE(review): both `.leak()` calls leak one small string per executed
    /// task to obtain `&'static`/`&'p` keys — confirm this is the accepted
    /// crate-wide trade-off.
    fn try_run(&self) -> Option<()> {
        self.task_queue.pop().map(|task| {
            let tname = task.get_name().to_string().leak();
            if CANCEL_TASKS.contains(tname) {
                _ = CANCEL_TASKS.remove(tname);
                warn!("Cancel task:{} successfully !", tname);
                return;
            }
            if let Some(co) = SchedulableCoroutine::current() {
                _ = RUNNING_TASKS.insert(tname, co.name());
            }
            let (task_name, result) = task.run();
            _ = RUNNING_TASKS.remove(tname);
            let n = task_name.clone().leak();
            if self.no_waits.contains(n) {
                _ = self.no_waits.remove(n);
                return;
            }
            assert!(
                self.results.insert(task_name.clone(), result).is_none(),
                "The previous result was not retrieved in a timely manner"
            );
            self.notify(&task_name);
        })
    }

    /// Wakes the thread (if any) blocked in `wait_task_result` for this task:
    /// removes its wait entry, clears the pending flag, signals the condvar.
    fn notify(&self, task_name: &str) {
        if let Some((_, arc)) = self.waits.remove(task_name) {
            let (lock, cvar) = &*arc;
            let mut pending = lock.lock().expect("notify task failed");
            *pending = false;
            cvar.notify_one();
        }
    }

    /// Attempts to cancel a task by name.
    ///
    /// * Task currently running: if its coroutine's scheduling thread is
    ///   known, send it `SIGVTALRM` via `pthread_kill` (unix only — presumably
    ///   a signal handler elsewhere performs the actual cancellation; on
    ///   non-unix targets this branch only logs nothing); otherwise fall back
    ///   to `Scheduler::try_cancel_coroutine`.
    /// * Task not yet running: flag it in `CANCEL_TASKS` so `try_run` drops it
    ///   before execution (leaks the name — see `clean_task_result`).
    pub fn try_cancel_task(task_name: &str) {
        if let Some(info) = RUNNING_TASKS.get(task_name) {
            let co_name = *info;
            #[allow(unused_variables)]
            if let Some(pthread) = Scheduler::get_scheduling_thread(co_name) {
                #[cfg(unix)]
                if nix::sys::pthread::pthread_kill(pthread, nix::sys::signal::Signal::SIGVTALRM)
                    .is_ok()
                {
                    warn!(
                        "Attempt to cancel task:{} running on coroutine:{} by thread:{}, cancelling...",
                        task_name, co_name, pthread
                    );
                } else {
                    error!(
                        "Attempt to cancel task:{} running on coroutine:{} by thread:{} failed !",
                        task_name, co_name, pthread
                    );
                }
            } else {
                Scheduler::try_cancel_coroutine(co_name);
                warn!(
                    "Attempt to cancel task:{} running on coroutine:{}, cancelling...",
                    task_name, co_name
                );
            }
        } else {
            _ = CANCEL_TASKS.insert(Box::leak(Box::from(task_name)));
            warn!("Attempt to cancel task:{}, cancelling...", task_name);
        }
    }

    /// Schedules tasks with no deadline (`u64::MAX` timeout).
    ///
    /// # Errors
    /// See [`Self::try_timeout_schedule_task`].
    pub fn try_schedule_task(&mut self) -> std::io::Result<()> {
        self.try_timeout_schedule_task(u64::MAX).map(|_| ())
    }

    /// Schedules tasks for at most `dur`, returning the remaining time budget.
    ///
    /// # Errors
    /// See [`Self::try_timeout_schedule_task`].
    pub fn try_timed_schedule_task(&mut self, dur: Duration) -> std::io::Result<u64> {
        self.try_timeout_schedule_task(get_timeout_time(dur))
    }

    /// Grows the worker set if needed, then drives the scheduler until
    /// `timeout_time` (absolute, same clock as `now()`). The pool is exposed
    /// to workers as the thread-local "current" pool only for the duration of
    /// the scheduling call. Returns the unused time budget.
    ///
    /// # Errors
    /// Fails when the pool is stopped, or when the scheduler fails.
    pub fn try_timeout_schedule_task(&mut self, timeout_time: u64) -> std::io::Result<u64> {
        match self.state() {
            PoolState::Running | PoolState::Stopping => {
                // Growth failure (e.g. max size reached) is non-fatal here.
                drop(self.try_grow());
            }
            PoolState::Stopped => return Err(Error::other("The coroutine pool is stopped !")),
        }
        Self::init_current(self);
        let r = self.try_timeout_schedule(timeout_time);
        Self::clean_current();
        r.map(|(left_time, _)| left_time)
    }
}