use dashmap::DashMap;
use parking_lot::Mutex;
use tokio::sync::watch;
use tokio::task;
use tokio::task::AbortHandle;
use std::future::Future;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::LazyLock;
/// Guard placed inside every task spawned by `TaskKillswitch::spawn_task`.
///
/// When the task's future finishes or is dropped (for example after an abort), this
/// guard unregisters the task's ID from `storage`. Because the guard is created inside
/// the task's async block, it only exists once the task has been polled at least once;
/// a task dropped before its first poll never runs it.
struct RemoveOnDrop {
    /// Tokio ID of the task that owns this guard.
    id: task::Id,
    /// Registry the task is removed from when the guard drops.
    storage: &'static ActiveTasks,
}

impl Drop for RemoveOnDrop {
    fn drop(&mut self) {
        let Self { id, storage } = self;
        storage.remove_task(*id);
    }
}
/// Tracks every task spawned through it so that all of them can be aborted at once.
///
/// The killswitch is one-shot. [`activate`](Self::activate) flips it and aborts the
/// tracked tasks on a background thread. Futures from [`killed`](Self::killed) resolve
/// once that abort pass has finished.
struct TaskKillswitch {
    /// Set by `activate` (at most once; a second call panics).
    activated: AtomicBool,
    /// Registry of spawned tasks; `'static` so each task's guard can refer to it.
    storage: &'static ActiveTasks,
    /// Receiver cloned by `killed()`. No value is ever sent on this channel; receivers
    /// only ever observe the sender being dropped.
    all_killed: watch::Receiver<()>,
    /// Sender half of `all_killed`. `activate` takes it and drops it after the abort
    /// pass, which wakes every pending `killed()` future.
    signal_killed: Mutex<Option<watch::Sender<()>>>,
}

impl TaskKillswitch {
    /// Creates an inactive killswitch that records its tasks in `storage`.
    fn new(storage: &'static ActiveTasks) -> Self {
        let (sender, receiver) = watch::channel(());
        Self {
            activated: AtomicBool::new(false),
            storage,
            all_killed: receiver,
            signal_killed: Mutex::new(Some(sender)),
        }
    }

    /// Creates an inactive killswitch backed by a freshly allocated registry that is
    /// intentionally leaked (never freed).
    fn with_leaked_storage() -> Self {
        Self::new(Box::leak(Box::<ActiveTasks>::default()))
    }

    /// Returns `true` once `activate` has been called.
    fn was_activated(&self) -> bool {
        self.activated.load(Ordering::Relaxed)
    }

    /// Spawns `fut` as a Tokio task tracked by this killswitch.
    ///
    /// Does nothing (and drops `fut`) if the killswitch is already activated. A task
    /// whose registration loses the race against `activate` is aborted right after
    /// being spawned.
    fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
        // Fast path only. The authoritative check is the condition passed to
        // `add_task_if`, which evaluates it while holding the registry entry for this task.
        if self.was_activated() {
            return;
        }
        let storage = self.storage;
        let abort_handle = tokio::spawn(async move {
            // Lives as long as the task's future and unregisters the task when dropped.
            let _guard = RemoveOnDrop {
                id: task::id(),
                storage,
            };
            fut.await;
        })
        .abort_handle();
        match storage.add_task_if(abort_handle, || !self.was_activated()) {
            Ok(()) => {}
            Err(rejected) => rejected.abort(),
        }
    }

    /// Activates the killswitch and aborts every tracked task on a background thread.
    ///
    /// Once that thread has finished the abort pass, it drops the watch sender, which
    /// resolves all [`killed`](Self::killed) futures.
    ///
    /// # Panics
    /// Panics if the killswitch has already been activated.
    fn activate(&self) {
        let already_activated = self.activated.swap(true, Ordering::Relaxed);
        assert!(!already_activated, "killswitch can't be used twice");
        let storage = self.storage;
        let sender = self.signal_killed.lock().take();
        // Presumably a dedicated thread so the caller (possibly an async task) is not
        // blocked while every registered task is aborted.
        std::thread::spawn(move || {
            storage.kill_all();
            // Dropping the sender closes the channel; that closure is the "killed" signal.
            drop(sender);
        });
    }

    /// Returns a future that resolves once the abort pass started by
    /// [`activate`](Self::activate) has finished requesting cancellation of every
    /// tracked task.
    ///
    /// The future resolves immediately if that has already happened. It also resolves
    /// if this killswitch is dropped without being activated, because the sender is
    /// dropped with it. Otherwise it stays pending.
    fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
        let mut receiver = self.all_killed.clone();
        async move {
            // No value is ever sent, so this only returns (with `Err`) once the sender is gone.
            let _ = receiver.changed().await;
        }
    }
}
/// What `ActiveTasks` records for a single Tokio task ID.
enum TaskEntry {
    /// The task finished before its handle was registered. The pending registration
    /// (`add_task_if`) removes this marker instead of storing a handle.
    Tombstone,
    /// A registered task that can be aborted through this handle.
    Handle(AbortHandle),
}
#[derive(Default)]
struct ActiveTasks {
tasks: DashMap<task::Id, TaskEntry>,
}
impl ActiveTasks {
fn kill_all(&self) {
self.tasks.retain(|_, entry| {
if let TaskEntry::Handle(task) = entry {
task.abort();
}
false });
}
fn add_task_if(
&self, handle: AbortHandle, cond: impl FnOnce() -> bool,
) -> Result<(), AbortHandle> {
use dashmap::Entry::*;
let id = handle.id();
match self.tasks.entry(id) {
Vacant(e) => {
if !cond() {
return Err(handle);
}
e.insert(TaskEntry::Handle(handle));
},
Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
e.remove();
},
Occupied(_) => panic!("tokio task ID already in use: {id}"),
}
Ok(())
}
fn remove_task(&self, id: task::Id) {
use dashmap::Entry::*;
match self.tasks.entry(id) {
Vacant(e) => {
e.insert(TaskEntry::Tombstone);
},
Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {},
Occupied(e) => {
e.remove();
},
}
}
}
static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
LazyLock::new(TaskKillswitch::with_leaked_storage);
#[inline]
pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
TASK_KILLSWITCH.spawn_task(fut);
}
#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
pub async fn activate() {
TASK_KILLSWITCH.activate()
}
#[inline]
pub fn activate_now() {
TASK_KILLSWITCH.activate();
}
#[inline]
pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
TASK_KILLSWITCH.killed()
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Fires its oneshot channel when dropped. Tests use it to observe that a task's
    /// future was torn down, which for these long-sleeping tasks means it was aborted.
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (sender, receiver) = oneshot::channel();
            (Self(Some(sender)), receiver)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let sender = self.0.take().unwrap();
            let _ = sender.send(());
        }
    }

    /// Spawns 1000 tasks that each sleep for an hour. Returns one receiver per task,
    /// which completes when that task's future is dropped.
    fn start_test_tasks(killswitch: &TaskKillswitch) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::with_capacity(1000);
        for _ in 0..1000 {
            let (signal, receiver) = TaskAbortSignal::new();
            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(signal);
            });
            receivers.push(receiver);
        }
        receivers
    }

    /// Fails the test unless every receiver completes within one second.
    async fn expect_all_aborted(receivers: Vec<oneshot::Receiver<()>>) {
        tokio::time::timeout(Duration::from_secs(1), future::join_all(receivers))
            .await
            .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let receivers = start_test_tasks(&killswitch);
        killswitch.activate();
        expect_all_aborted(receivers).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let receivers = start_test_tasks(&killswitch);
        let killed_task = tokio::spawn(killswitch.killed());
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!killed_task.is_finished());
        killswitch.activate();
        expect_all_aborted(receivers).await;
        tokio::time::timeout(Duration::from_secs(1), killed_task)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}