use std::sync::{
Arc,
Mutex,
MutexGuard,
atomic::AtomicU8,
};
use qubit_executor::service::ExecutorServiceLifecycle;
use qubit_lock::Monitor;
use tokio::{
sync::Notify,
task::AbortHandle,
};
use crate::executor_service_lifecycle_bits;
/// One tracked task entry stored in the service's abort-handle list.
///
/// Entries are created by `register_abort_handle`, removed by
/// `remove_abort_handle`, and consumed by `abort_tracked_tasks`.
struct TrackedAbortHandle {
/// Identity token for this entry; entries are matched by pointer identity
/// (`Arc::ptr_eq`) when a single registration is removed.
marker: Arc<()>,
/// Tokio abort handle for the tracked task; checked with `is_finished()` and
/// aborted during `abort_tracked_tasks`.
handle: AbortHandle,
/// One-shot callback invoked right after the task is aborted. A `true`
/// return value is counted as one cancellation by `abort_tracked_tasks`.
cancel: Box<dyn FnOnce() -> bool + Send + 'static>,
}
/// Counters for tasks that have been accepted but have not finished yet.
///
/// Both counters start at zero via `Default`.
#[derive(Default)]
struct TokioExecutorTaskCounts {
/// Tasks that were accepted but have not been marked as started.
queued: usize,
/// Tasks that were marked as started and have not finished.
running: usize,
}
impl TokioExecutorTaskCounts {
    /// Records a newly accepted task by bumping the queued counter.
    fn accept_task(&mut self) {
        self.queued += 1;
    }

    /// Moves one task from the queued counter to the running counter.
    ///
    /// Debug builds assert that a queued task exists; release builds clamp the
    /// queued counter at zero instead of underflowing.
    fn mark_started(&mut self) {
        debug_assert!(self.queued > 0);
        let still_queued = self.queued.saturating_sub(1);
        self.queued = still_queued;
        self.running += 1;
    }

    /// Removes one finished task from the counter it currently occupies.
    ///
    /// # Parameters
    ///
    /// * `started` - `true` when the task had been marked as started (it is
    ///   removed from `running`), `false` when it never started (it is removed
    ///   from `queued`).
    fn finish_task(&mut self, started: bool) {
        // Select the counter first so the clamped decrement is written once.
        let counter = if started {
            debug_assert!(self.running > 0);
            &mut self.running
        } else {
            debug_assert!(self.queued > 0);
            &mut self.queued
        };
        *counter = counter.saturating_sub(1);
    }

    /// Returns `true` when no task is queued and no task is running.
    fn is_empty(&self) -> bool {
        matches!((self.queued, self.running), (0, 0))
    }
}
/// Shared state of a Tokio-backed executor service: lifecycle bits, task
/// counters, tracked abort handles, and termination notification.
#[derive(Default)]
pub(crate) struct TokioExecutorServiceState {
/// Encoded lifecycle value, read and updated through
/// `executor_service_lifecycle_bits`.
lifecycle: AtomicU8,
/// Queued/running task counters guarded by a monitor, which is also used to
/// wait for and signal termination.
task_counts: Monitor<TokioExecutorTaskCounts>,
/// Lock handed out by `lock_submission`.
/// NOTE(review): presumably serializes task submission against
/// shutdown/stop — confirm against the callers.
submission_lock: Mutex<()>,
/// Abort handles of tasks that were registered and not yet removed or aborted.
abort_handles: Mutex<Vec<TrackedAbortHandle>>,
/// Tokio notifier signalled (via `notify_waiters`) when the service is
/// observed to be terminated.
pub(crate) terminated_notify: Notify,
}
impl TokioExecutorServiceState {
    /// Acquires the submission lock.
    ///
    /// A poisoned mutex is recovered by taking the inner guard, so this never
    /// panics because of poisoning.
    pub(crate) fn lock_submission(&self) -> MutexGuard<'_, ()> {
        self.submission_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records one accepted task in the queued counter.
    pub(crate) fn accept_task(&self) {
        self.task_counts.write(TokioExecutorTaskCounts::accept_task);
    }

    /// Moves one task from the queued counter to the running counter.
    pub(crate) fn mark_task_started(&self) {
        self.task_counts
            .write(TokioExecutorTaskCounts::mark_started);
    }

    /// Records that one task finished and notifies termination waiters when
    /// the service is no longer running and no tasks remain.
    ///
    /// # Parameters
    ///
    /// * `started` - whether the finished task had been marked as started.
    pub(crate) fn finish_task(&self, started: bool) {
        // The termination check is evaluated inside the same `write` call that
        // updates the counters, so it observes the counters it just changed.
        let terminated = self.task_counts.write(|counts| {
            counts.finish_task(started);
            self.is_not_running() && counts.is_empty()
        });
        if terminated {
            self.notify_termination_waiters();
        }
    }

    /// Returns the current `(queued, running)` task counts.
    pub(crate) fn task_count_snapshot(&self) -> (usize, usize) {
        self.task_counts
            .read(|counts| (counts.queued, counts.running))
    }

    /// Tracks an abort handle so that `abort_tracked_tasks` can abort it later.
    ///
    /// If `handle` already reports finished, nothing is stored and `marker`
    /// and `cancel` are dropped.
    ///
    /// # Parameters
    ///
    /// * `marker` - identity token later passed to `remove_abort_handle`.
    /// * `handle` - Tokio abort handle of the task.
    /// * `cancel` - one-shot callback run after the task is aborted; its
    ///   `bool` result is counted by `abort_tracked_tasks` when `true`.
    pub(crate) fn register_abort_handle<F>(&self, marker: Arc<()>, handle: AbortHandle, cancel: F)
    where
        F: FnOnce() -> bool + Send + 'static,
    {
        // NOTE(review): the finished check runs while the list is locked,
        // presumably so it cannot interleave with a concurrent removal — confirm.
        let mut handles = self.lock_abort_handles();
        if !handle.is_finished() {
            handles.push(TrackedAbortHandle {
                marker,
                handle,
                cancel: Box::new(cancel),
            });
        }
    }

    /// Removes every tracked entry whose marker is the same allocation as
    /// `marker` (compared with `Arc::ptr_eq`).
    pub(crate) fn remove_abort_handle(&self, marker: &Arc<()>) {
        self.lock_abort_handles()
            .retain(|tracked| !Arc::ptr_eq(&tracked.marker, marker));
    }

    /// Aborts every tracked task that has not finished and runs its cancel
    /// callback.
    ///
    /// The tracked list is emptied by this call, including entries whose task
    /// had already finished.
    ///
    /// # Returns
    ///
    /// The number of cancel callbacks that returned `true`.
    pub(crate) fn abort_tracked_tasks(&self) -> usize {
        // Detach the list while holding the lock, then release the lock before
        // aborting tasks or running callbacks. `std::sync::Mutex` is not
        // reentrant, so a callback that reaches `register_abort_handle` or
        // `remove_abort_handle` on this thread would otherwise never return.
        // The guard is a temporary and is dropped at the end of this statement.
        let detached = std::mem::take(&mut *self.lock_abort_handles());

        let mut cancellation_count = 0usize;
        for tracked in detached {
            if tracked.handle.is_finished() {
                continue;
            }
            tracked.handle.abort();
            if (tracked.cancel)() {
                cancellation_count += 1;
            }
        }
        cancellation_count
    }

    /// Notifies termination waiters if the service is not running and no
    /// tasks remain.
    pub(crate) fn notify_if_terminated(&self) {
        let terminated = self
            .task_counts
            .read(|counts| self.is_not_running() && counts.is_empty());
        if terminated {
            self.notify_termination_waiters();
        }
    }

    /// Waits on the task-count monitor until the service is not running and
    /// no tasks remain.
    ///
    /// NOTE(review): `Monitor::wait_until` presumably blocks the calling
    /// thread, so this should not be called from an async task — confirm.
    pub(crate) fn wait_termination(&self) {
        self.task_counts.wait_until(
            |counts| self.is_not_running() && counts.is_empty(),
            |_counts| {},
        );
    }

    /// Signals both kinds of termination waiters: those waiting on the
    /// task-count monitor and those waiting on `terminated_notify`.
    fn notify_termination_waiters(&self) {
        self.task_counts.notify_all();
        self.terminated_notify.notify_waiters();
    }

    /// Locks the abort-handle list, recovering the guard if the mutex was
    /// poisoned.
    fn lock_abort_handles(&self) -> MutexGuard<'_, Vec<TrackedAbortHandle>> {
        self.abort_handles
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the observable lifecycle of the service.
    ///
    /// When the stored lifecycle is anything other than `Running` and no tasks
    /// are queued or running, `Terminated` is reported; otherwise the stored
    /// lifecycle is returned unchanged.
    pub(crate) fn lifecycle(&self) -> ExecutorServiceLifecycle {
        let lifecycle = executor_service_lifecycle_bits::load(&self.lifecycle);
        let has_no_tasks = self.task_counts.read(TokioExecutorTaskCounts::is_empty);
        if lifecycle != ExecutorServiceLifecycle::Running && has_no_tasks {
            ExecutorServiceLifecycle::Terminated
        } else {
            lifecycle
        }
    }

    /// Returns `true` when the stored lifecycle is not `Running`.
    pub(crate) fn is_not_running(&self) -> bool {
        executor_service_lifecycle_bits::load(&self.lifecycle) != ExecutorServiceLifecycle::Running
    }

    /// Applies the `shutdown` transition of `executor_service_lifecycle_bits`
    /// to the stored lifecycle.
    pub(crate) fn shutdown(&self) {
        executor_service_lifecycle_bits::shutdown(&self.lifecycle);
    }

    /// Applies the `stop` transition of `executor_service_lifecycle_bits` to
    /// the stored lifecycle.
    pub(crate) fn stop(&self) {
        executor_service_lifecycle_bits::stop(&self.lifecycle);
    }
}