use std::sync::{
Arc,
Mutex,
MutexGuard,
};
use qubit_atomic::{
Atomic,
AtomicCount,
};
use tokio::{
sync::Notify,
task::AbortHandle,
};
/// A Tokio abort handle paired with an identity token.
///
/// The token carries no data. Only its allocation address matters: entries are
/// matched with [`Arc::ptr_eq`] when a specific task's handle has to be removed.
struct TrackedAbortHandle {
    /// Handle used to cancel the spawned task.
    handle: AbortHandle,
    /// Identity token; compared by pointer, never by value.
    marker: Arc<()>,
}
/// Shared bookkeeping for the Tokio-backed I/O executor service.
///
/// Holds the shutdown flag, the active-task counter, the termination
/// notifier, and the registry of abort handles for tracked tasks.
#[derive(Default)]
pub(crate) struct TokioIoExecutorServiceState {
    /// Shutdown flag; termination is only signalled once this is `true`.
    pub(crate) shutdown: Atomic<bool>,
    /// Number of active tasks; termination is only signalled when this is zero.
    pub(crate) active_tasks: AtomicCount,
    /// Woken via `notify_waiters` when shutdown is set and no tasks are active.
    pub(crate) terminated_notify: Notify,
    /// Lock whose guard is handed out by `lock_submission`.
    /// NOTE(review): presumably held by submitters to order submission against
    /// shutdown — confirm at call sites.
    submission_lock: Mutex<()>,
    /// Abort handles of tasks that may still need to be cancelled.
    abort_handles: Mutex<Vec<TrackedAbortHandle>>,
}
impl TokioIoExecutorServiceState {
    /// Acquires the submission lock and returns its guard.
    ///
    /// A poisoned lock is recovered instead of propagating the panic. The mutex
    /// guards `()`, so a panicking holder cannot leave protected data inconsistent.
    pub(crate) fn lock_submission(&self) -> MutexGuard<'_, ()> {
        self.submission_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records `handle` so that [`Self::abort_tracked_tasks`] can cancel the task later.
    ///
    /// `marker` is the identity later passed to [`Self::remove_abort_handle`].
    /// Entries are matched with [`Arc::ptr_eq`], so every task needs its own `Arc`.
    /// A handle whose task has already finished is not stored.
    ///
    /// NOTE(review): entries leave the registry only through `remove_abort_handle`
    /// or `abort_tracked_tasks`. A task that finishes after the check below and
    /// never calls `remove_abort_handle` leaves a stale entry until the next
    /// `abort_tracked_tasks`.
    pub(crate) fn register_abort_handle(&self, marker: Arc<()>, handle: AbortHandle) {
        // The finished-check and the push happen under the same lock, so they
        // cannot interleave with a concurrent remove/abort on the registry.
        let mut handles = self.lock_abort_handles();
        if !handle.is_finished() {
            handles.push(TrackedAbortHandle { marker, handle });
        }
    }

    /// Drops every registry entry whose marker is the same allocation as `marker`.
    ///
    /// Does nothing if no such entry exists. For example, the handle may never
    /// have been stored, or [`Self::abort_tracked_tasks`] may already have drained it.
    pub(crate) fn remove_abort_handle(&self, marker: &Arc<()>) {
        self.lock_abort_handles()
            .retain(|tracked| !Arc::ptr_eq(&tracked.marker, marker));
    }

    /// Aborts every tracked task that has not finished yet and empties the registry.
    ///
    /// Returns the number of tasks on which `abort()` was called. Entries whose
    /// task had already finished are discarded without being counted.
    pub(crate) fn abort_tracked_tasks(&self) -> usize {
        // Detach the whole registry under a short-lived lock, then call into the
        // runtime with the lock released. The mutex is therefore never held
        // across `abort()`, and the registry's allocation is freed rather than
        // kept at its peak capacity.
        let tracked = {
            let mut handles = self.lock_abort_handles();
            std::mem::take(&mut *handles)
        };

        let mut cancellation_count = 0usize;
        for entry in tracked {
            if !entry.handle.is_finished() {
                entry.handle.abort();
                cancellation_count += 1;
            }
        }
        cancellation_count
    }

    /// Wakes waiters on `terminated_notify` when `shutdown` is set and no tasks
    /// are active.
    ///
    /// `Notify::notify_waiters` wakes only waiters that are already registered
    /// and stores no permit. A waiter must therefore register (e.g. `enable()`)
    /// its `Notified` future before re-checking the termination condition, or it
    /// can miss this wake-up.
    pub(crate) fn notify_if_terminated(&self) {
        if self.shutdown.load() && self.active_tasks.is_zero() {
            self.terminated_notify.notify_waiters();
        }
    }

    /// Locks the abort-handle registry.
    ///
    /// A poisoned lock is recovered instead of propagating the panic.
    fn lock_abort_handles(&self) -> MutexGuard<'_, Vec<TrackedAbortHandle>> {
        self.abort_handles
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}