use std::{
collections::HashMap,
sync::{
Mutex,
MutexGuard,
atomic::{
AtomicU8,
AtomicUsize,
Ordering,
},
},
};
use qubit_atomic::AtomicCount;
use qubit_executor::service::ExecutorServiceLifecycle;
use tokio::sync::Notify;
use crate::pending_cancel::PendingCancel;
/// Shared state behind the rayon-backed executor service.
///
/// Holds the encoded lifecycle, the task counters, the cancel handles of
/// tasks that have not started yet, and the notifier used to wake
/// termination waiters.
pub(crate) struct RayonExecutorServiceState {
    /// `ExecutorServiceLifecycle` stored as its `u8` discriminant.
    lifecycle: AtomicU8,
    /// Accepted tasks that have neither completed nor been cancelled.
    active_tasks: AtomicCount,
    /// Accepted tasks that have neither started nor been cancelled.
    queued_tasks: AtomicCount,
    /// Source of task ids; bumped once per id handed out.
    next_task_id: AtomicUsize,
    /// Protects no data (`()`); handed out through `lock_submission`,
    /// presumably so callers can serialize submission against lifecycle
    /// changes — confirm at the call sites.
    submission_lock: Mutex<()>,
    /// Cancel handles, keyed by task id, for tasks that are still pending.
    pending_tasks: Mutex<HashMap<usize, PendingCancel>>,
    /// Woken when the service is no longer running and no task is active.
    terminated_notify: Notify,
}
impl RayonExecutorServiceState {
    /// Creates a state in the `Running` lifecycle with zeroed task counters,
    /// an empty pending-task map and task ids starting from 0.
    pub(crate) fn new() -> Self {
        let initial_lifecycle = ExecutorServiceLifecycle::Running as u8;
        Self {
            lifecycle: AtomicU8::new(initial_lifecycle),
            active_tasks: AtomicCount::new(0),
            queued_tasks: AtomicCount::new(0),
            submission_lock: Mutex::new(()),
            pending_tasks: Mutex::new(HashMap::new()),
            next_task_id: AtomicUsize::new(0),
            terminated_notify: Notify::new(),
        }
    }

    /// Acquires the submission lock.
    ///
    /// Poisoning is deliberately ignored: if a previous holder panicked, the
    /// guard is recovered and returned anyway (the lock only protects `()`).
    pub(crate) fn lock_submission(&self) -> MutexGuard<'_, ()> {
        match self.submission_lock.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Acquires the pending-task map, recovering the guard if the mutex was
    /// poisoned by a panicking holder.
    fn lock_pending_tasks(&self) -> MutexGuard<'_, HashMap<usize, PendingCancel>> {
        match self.pending_tasks.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Decodes the lifecycle exactly as last stored (by `new`, `shutdown` or
    /// `stop`), without deriving `Terminated` from the task counters.
    fn stored_lifecycle(&self) -> ExecutorServiceLifecycle {
        let raw = self.lifecycle.load(Ordering::Acquire);
        lifecycle_from_u8(raw)
    }

    /// Returns the effective lifecycle.
    ///
    /// Once the stored lifecycle has left `Running` and no accepted task is
    /// still active, the service reports `Terminated`; otherwise the stored
    /// value is returned unchanged.
    pub(crate) fn lifecycle(&self) -> ExecutorServiceLifecycle {
        let stored = self.stored_lifecycle();
        if stored == ExecutorServiceLifecycle::Running || !self.has_no_active_tasks() {
            return stored;
        }
        ExecutorServiceLifecycle::Terminated
    }

    /// Returns `true` once the stored lifecycle is anything other than
    /// `Running`.
    pub(crate) fn is_not_running(&self) -> bool {
        !matches!(self.stored_lifecycle(), ExecutorServiceLifecycle::Running)
    }

    /// Requests a graceful shutdown.
    ///
    /// Atomically moves `Running` to `ShuttingDown`. Any other stored state
    /// (e.g. `Stopping`) is left untouched, so this never undoes `stop`.
    pub(crate) fn shutdown(&self) {
        // An `Err` here just means the service was already not running.
        let _ = self
            .lifecycle
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |raw| {
                if lifecycle_from_u8(raw) == ExecutorServiceLifecycle::Running {
                    Some(ExecutorServiceLifecycle::ShuttingDown as u8)
                } else {
                    None
                }
            });
    }

    /// Requests an immediate stop by storing `Stopping` unconditionally.
    pub(crate) fn stop(&self) {
        let stopping = ExecutorServiceLifecycle::Stopping as u8;
        self.lifecycle.store(stopping, Ordering::Release);
    }

    /// Returns `true` when no accepted task is still counted as active.
    pub(crate) fn has_no_active_tasks(&self) -> bool {
        self.active_tasks.is_zero()
    }

    /// Hands out the next task id.
    ///
    /// `Relaxed` suffices for uniqueness. Within this type the id is only used
    /// as a key into the mutex-guarded pending-task map.
    pub(crate) fn next_task_id(&self) -> usize {
        self.next_task_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Records a newly accepted task as both active and queued.
    pub(crate) fn on_task_accepted(&self) {
        self.active_tasks.inc();
        self.queued_tasks.inc();
    }

    /// Stores the cancel handle of a task that has not started yet.
    pub(crate) fn register_pending_task(&self, task_id: usize, cancel: PendingCancel) {
        let mut pending = self.lock_pending_tasks();
        pending.insert(task_id, cancel);
    }

    /// Transitions a pending task to running.
    ///
    /// `start` is invoked only while `task_id` is still registered, and it
    /// runs with the pending-task lock held. This prevents it from racing
    /// with `cancel_pending_task` or `drain_pending_tasks_for_shutdown`.
    /// Because `std::sync::Mutex` is not reentrant, `start` must not call back
    /// into anything that locks the pending-task map.
    ///
    /// On success the task leaves the map and the queued count. It stays
    /// counted as active; `on_task_completed` is what decrements that.
    ///
    /// Returns `true` only if the task was pending and `start` succeeded.
    pub(crate) fn start_pending_task<F>(&self, task_id: usize, start: F) -> bool
    where
        F: FnOnce() -> bool,
    {
        let mut pending = self.lock_pending_tasks();
        // Short-circuit: `start` is never called for an unknown task id.
        let started = pending.contains_key(&task_id) && start();
        if started {
            pending.remove(&task_id);
            self.queued_tasks.dec();
        }
        started
    }

    /// Cancels a task that has not started yet.
    ///
    /// `cancel` is invoked under the pending-task lock, and only while
    /// `task_id` is still registered (the same non-reentrancy caveat as
    /// `start_pending_task` applies). On success the task is removed from the
    /// map and from both the queued and active counts. If that leaves no
    /// active task, the termination check runs after the lock is released.
    ///
    /// Returns `false` if the task was no longer pending or `cancel` failed.
    pub(crate) fn cancel_pending_task(&self, task_id: usize, cancel: &PendingCancel) -> bool {
        let mut pending = self.lock_pending_tasks();
        if !pending.contains_key(&task_id) || !cancel() {
            return false;
        }
        pending.remove(&task_id);
        self.queued_tasks.dec();
        let no_tasks_left = self.active_tasks.dec() == 0;
        // Release the pending-task lock before running the termination check.
        drop(pending);
        if no_tasks_left {
            self.notify_if_terminated();
        }
        true
    }

    /// Empties the pending-task map during shutdown.
    ///
    /// Returns `(queued, running, handles)`:
    /// - `queued` is the queued-task count;
    /// - `running` is the number of active tasks that are not queued
    ///   (saturating at 0);
    /// - `handles` are the cancel handles taken out of the map.
    ///
    /// The counters are read under the pending-task lock but are not
    /// decremented here; `cancel_drained_pending_tasks` does that per handle.
    pub(crate) fn drain_pending_tasks_for_shutdown(&self) -> (usize, usize, Vec<PendingCancel>) {
        let mut pending = self.lock_pending_tasks();
        let queued = self.queued_tasks.get();
        let active = self.active_tasks.get();
        let running = active.saturating_sub(queued);
        let handles: Vec<PendingCancel> = pending.drain().map(|(_, handle)| handle).collect();
        debug_assert_eq!(handles.len(), queued);
        (queued, running, handles)
    }

    /// Cancels handles previously returned by `drain_pending_tasks_for_shutdown`.
    ///
    /// Each successfully cancelled task is removed from the queued and active
    /// counts. A handle that fails to cancel trips a debug assertion; in
    /// release builds it is skipped without touching the counters. The
    /// termination check runs once, after every handle has been processed.
    ///
    /// Returns the number of tasks actually cancelled.
    pub(crate) fn cancel_drained_pending_tasks(&self, pending: Vec<PendingCancel>) -> usize {
        let cancelled = pending.into_iter().fold(0usize, |count, cancel| {
            let was_cancelled = cancel();
            debug_assert!(
                was_cancelled,
                "drained pending rayon task should cancel before start",
            );
            if !was_cancelled {
                return count;
            }
            self.queued_tasks.dec();
            self.active_tasks.dec();
            count + 1
        });
        self.notify_if_terminated();
        cancelled
    }

    /// Marks one active task as finished.
    ///
    /// Runs the termination check when the decremented count reaches 0. This
    /// assumes `AtomicCount::dec` returns the post-decrement value — TODO
    /// confirm.
    pub(crate) fn on_task_completed(&self) {
        let remaining = self.active_tasks.dec();
        if remaining == 0 {
            self.notify_if_terminated();
        }
    }

    /// Wakes termination waiters if the service is not running and no task
    /// is active.
    ///
    /// `Notify::notify_waiters` stores no permit, so a waiter should obtain
    /// its `notified()` future before re-checking `lifecycle()`. Otherwise a
    /// wake-up fired in between is lost.
    pub(crate) fn notify_if_terminated(&self) {
        if !self.is_not_running() || !self.has_no_active_tasks() {
            return;
        }
        self.terminated_notify.notify_waiters();
    }
}
/// Decodes a lifecycle value previously stored as `ExecutorServiceLifecycle::X as u8`.
///
/// The match arms are derived from the enum's own discriminants, the same
/// `as u8` casts the writers use. This keeps the encoding (in `new`,
/// `shutdown` and `stop`) and this decoding from drifting apart if the
/// enum's discriminant values ever change. Any unrecognized value decodes
/// as `Terminated`.
fn lifecycle_from_u8(value: u8) -> ExecutorServiceLifecycle {
    const RUNNING: u8 = ExecutorServiceLifecycle::Running as u8;
    const SHUTTING_DOWN: u8 = ExecutorServiceLifecycle::ShuttingDown as u8;
    const STOPPING: u8 = ExecutorServiceLifecycle::Stopping as u8;

    match value {
        RUNNING => ExecutorServiceLifecycle::Running,
        SHUTTING_DOWN => ExecutorServiceLifecycle::ShuttingDown,
        STOPPING => ExecutorServiceLifecycle::Stopping,
        _ => ExecutorServiceLifecycle::Terminated,
    }
}