use std::{
collections::HashMap,
sync::{
Mutex,
MutexGuard,
atomic::{
AtomicU8,
AtomicUsize,
Ordering,
},
},
};
use qubit_atomic::AtomicCount;
use qubit_executor::service::{
ExecutorServiceLifecycle,
StopReport,
};
use qubit_lock::Monitor;
use crate::pending_cancel::PendingCancel;
/// Shared bookkeeping for a rayon-backed executor service: the stored
/// lifecycle value, the count of accepted-but-unfinished tasks, the registry
/// of tasks that can still be cancelled before they start, and the
/// termination flag that waiters block on.
pub(crate) struct RayonExecutorServiceState {
/// Stored lifecycle, encoded with `ExecutorServiceLifecycle as u8` and
/// decoded with `lifecycle_from_u8`.
lifecycle: AtomicU8,
/// Number of accepted tasks that have neither completed nor been cancelled;
/// incremented by `on_task_accepted`, decremented on completion/cancellation.
active_tasks: AtomicCount,
/// Unit mutex handed out by `lock_submission`.
/// NOTE(review): presumably serializes task submission against lifecycle
/// changes; the callers are not visible in this file, so confirm there.
submission_lock: Mutex<()>,
/// Cancel handles for registered tasks that have not started yet, keyed by
/// task id. Entries are removed when a task starts or is cancelled.
pending_tasks: Mutex<HashMap<usize, PendingCancel>>,
/// Source of task ids; `next_task_id` hands out the current value and
/// advances it by one.
next_task_id: AtomicUsize,
/// Termination flag: set to `true` (and waiters notified) once the service
/// is no longer running and has no active tasks.
terminated: Monitor<bool>,
}
impl RayonExecutorServiceState {
    /// Builds the state of a freshly created service: lifecycle `Running`,
    /// zero active tasks, an empty pending-task registry, task ids starting
    /// at `0`, and the termination flag cleared.
    pub(crate) fn new() -> Self {
        let running = ExecutorServiceLifecycle::Running as u8;
        Self {
            lifecycle: AtomicU8::new(running),
            active_tasks: AtomicCount::new(0),
            submission_lock: Mutex::new(()),
            pending_tasks: Mutex::new(HashMap::new()),
            next_task_id: AtomicUsize::new(0),
            terminated: Monitor::new(false),
        }
    }

    /// Acquires the submission lock and returns its guard.
    ///
    /// A poisoned mutex is not treated as an error: the guard is recovered
    /// from the poison error and returned as usual.
    pub(crate) fn lock_submission(&self) -> MutexGuard<'_, ()> {
        match self.submission_lock.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Acquires the pending-task registry, recovering the guard if the mutex
    /// was poisoned.
    fn lock_pending_tasks(&self) -> MutexGuard<'_, HashMap<usize, PendingCancel>> {
        match self.pending_tasks.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Loads and decodes the lifecycle value exactly as stored, without
    /// taking the active-task count into account.
    fn stored_lifecycle(&self) -> ExecutorServiceLifecycle {
        let raw = self.lifecycle.load(Ordering::Acquire);
        lifecycle_from_u8(raw)
    }

    /// Returns the observable lifecycle.
    ///
    /// While the stored lifecycle is `Running`, or while tasks are still
    /// active, the stored value is returned unchanged. Once the service has
    /// left `Running` and no active tasks remain, `Terminated` is reported.
    pub(crate) fn lifecycle(&self) -> ExecutorServiceLifecycle {
        let stored = self.stored_lifecycle();
        if stored == ExecutorServiceLifecycle::Running || !self.has_no_active_tasks() {
            return stored;
        }
        ExecutorServiceLifecycle::Terminated
    }

    /// Returns `true` when the stored lifecycle is anything other than
    /// `Running`.
    pub(crate) fn is_not_running(&self) -> bool {
        let stored = self.stored_lifecycle();
        stored != ExecutorServiceLifecycle::Running
    }

    /// Moves the stored lifecycle from `Running` to `ShuttingDown`.
    ///
    /// Any other stored lifecycle is left untouched.
    pub(crate) fn shutdown(&self) {
        let shutting_down = ExecutorServiceLifecycle::ShuttingDown as u8;
        // `fetch_update` returns `Err` when the closure declines the update,
        // i.e. the service was not running; that outcome needs no handling.
        let _ = self
            .lifecycle
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                if lifecycle_from_u8(current) == ExecutorServiceLifecycle::Running {
                    Some(shutting_down)
                } else {
                    None
                }
            });
    }

    /// Unconditionally sets the stored lifecycle to `Stopping`.
    pub(crate) fn stop(&self) {
        let stopping = ExecutorServiceLifecycle::Stopping as u8;
        self.lifecycle.store(stopping, Ordering::Release);
    }

    /// Returns `true` when the active-task counter is zero.
    pub(crate) fn has_no_active_tasks(&self) -> bool {
        self.active_tasks.is_zero()
    }

    /// Returns the current task id and advances the id counter by one.
    pub(crate) fn next_task_id(&self) -> usize {
        self.next_task_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Records that one more task has been accepted by incrementing the
    /// active-task counter.
    pub(crate) fn on_task_accepted(&self) {
        self.active_tasks.inc();
    }

    /// Stores `cancel` in the pending-task registry under `task_id`.
    pub(crate) fn register_pending_task(&self, task_id: usize, cancel: PendingCancel) {
        let mut pending = self.lock_pending_tasks();
        pending.insert(task_id, cancel);
    }

    /// Attempts to move a pending task into the started state.
    ///
    /// `start` is invoked while the registry lock is held, and only if
    /// `task_id` is still registered. The entry is removed and `true` is
    /// returned only when `start` itself returns `true`; otherwise the
    /// registry is left as it was and `false` is returned.
    pub(crate) fn start_pending_task<F>(&self, task_id: usize, start: F) -> bool
    where
        F: FnOnce() -> bool,
    {
        let mut pending = self.lock_pending_tasks();
        // `&&` keeps `start` from running for a task that is no longer pending.
        let started = pending.contains_key(&task_id) && start();
        if started {
            pending.remove(&task_id);
        }
        started
    }

    /// Attempts to cancel a single pending task.
    ///
    /// `cancel` is invoked while the registry lock is held, and only if
    /// `task_id` is still registered. Returns `false` without changing any
    /// state when the task is not registered or `cancel` returns `false`.
    /// On success the entry is removed, the active-task counter is
    /// decremented, and termination waiters are notified if appropriate.
    pub(crate) fn cancel_pending_task(&self, task_id: usize, cancel: &PendingCancel) -> bool {
        let was_last_task = {
            let mut pending = self.lock_pending_tasks();
            if !(pending.contains_key(&task_id) && cancel()) {
                return false;
            }
            pending.remove(&task_id);
            self.active_tasks.dec() == 0
        };
        // The registry lock has been released before waiters are notified.
        if was_last_task {
            self.notify_if_terminated();
        }
        true
    }

    /// Drains the pending-task registry, invoking every cancel handle.
    ///
    /// The returned report is built from, in order: the number of entries
    /// that were pending, the active-task count minus that number
    /// (saturating at zero), and the number of handles that reported a
    /// successful cancellation. Each successful cancellation decrements the
    /// active-task counter. Termination waiters are notified afterwards if no
    /// active tasks remain.
    pub(crate) fn cancel_pending_tasks_for_stop(&self) -> StopReport {
        let (report, is_idle) = {
            let mut pending = self.lock_pending_tasks();
            let queued = pending.len();
            let running = self.active_tasks.get().saturating_sub(queued);
            let mut cancelled = 0usize;
            for (_task_id, cancel) in pending.drain() {
                let was_cancelled = cancel();
                debug_assert!(
                    was_cancelled,
                    "drained pending rayon task should cancel before start",
                );
                if !was_cancelled {
                    continue;
                }
                self.active_tasks.dec();
                cancelled += 1;
            }
            let report = StopReport::new(queued, running, cancelled);
            (report, self.has_no_active_tasks())
        };
        // The registry lock has been released before waiters are notified.
        if is_idle {
            self.notify_if_terminated();
        }
        report
    }

    /// Records that an accepted task has finished: decrements the active-task
    /// counter and, when the decrement reports zero, notifies termination
    /// waiters if appropriate.
    pub(crate) fn on_task_completed(&self) {
        let was_last_task = self.active_tasks.dec() == 0;
        if was_last_task {
            self.notify_if_terminated();
        }
    }

    /// Blocks the calling thread until the termination flag is `true`.
    pub(crate) fn wait_for_termination(&self) {
        self.terminated.wait_until(|done| *done, |_| ());
    }

    /// Sets the termination flag and wakes all waiters, but only when the
    /// service is no longer running and has no active tasks; otherwise does
    /// nothing.
    pub(crate) fn notify_if_terminated(&self) {
        if !self.is_not_running() || !self.has_no_active_tasks() {
            return;
        }
        self.terminated.write(|flag| *flag = true);
        self.terminated.notify_all();
    }
}
/// Decodes a raw `u8`, as held in the lifecycle atomic, back into an
/// [`ExecutorServiceLifecycle`].
///
/// Values that match `Running`, `ShuttingDown` or `Stopping` decode to that
/// variant; every other value decodes to `Terminated`.
fn lifecycle_from_u8(value: u8) -> ExecutorServiceLifecycle {
    // The writers encode with `Variant as u8`, so the decoder derives its
    // match arms from the same casts instead of hard-coding discriminants.
    // This keeps encode and decode in agreement whatever values the enum
    // assigns to its variants.
    const RUNNING: u8 = ExecutorServiceLifecycle::Running as u8;
    const SHUTTING_DOWN: u8 = ExecutorServiceLifecycle::ShuttingDown as u8;
    const STOPPING: u8 = ExecutorServiceLifecycle::Stopping as u8;

    match value {
        RUNNING => ExecutorServiceLifecycle::Running,
        SHUTTING_DOWN => ExecutorServiceLifecycle::ShuttingDown,
        STOPPING => ExecutorServiceLifecycle::Stopping,
        _ => ExecutorServiceLifecycle::Terminated,
    }
}