use once_cell::sync::Lazy;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
/// Bookkeeping message sent over the unbounded channel to the [`ActiveTasks`]
/// collector.
enum ActiveTaskOp {
/// Start tracking the spawned task `handle` under `id`.
Add { id: u64, handle: JoinHandle<()> },
/// Stop tracking the task that was registered under `id`.
Remove { id: u64 },
}
/// Guard held inside a spawned task; its `Drop` impl sends
/// [`ActiveTaskOp::Remove`] for `id` if the channel can still be upgraded.
struct RemoveOnDrop {
/// Id of the task this guard belongs to.
id: u64,
/// Weak handle to the bookkeeping channel. Being weak, it does not count as
/// a sender that keeps the channel open.
task_tx_weak: mpsc::WeakUnboundedSender<ActiveTaskOp>,
}
impl Drop for RemoveOnDrop {
    /// Tells the collector to forget this task.
    ///
    /// Nothing is sent when the weak sender can no longer be upgraded, i.e.
    /// when no strong sender for the channel exists any more. A failed send
    /// is ignored on purpose: there is nobody left to notify.
    fn drop(&mut self) {
        let Some(strong_tx) = self.task_tx_weak.upgrade() else {
            return;
        };

        let _ = strong_tx.send(ActiveTaskOp::Remove { id: self.id });
    }
}
/// Tracks tasks spawned through `spawn_task` so that they can all be aborted
/// at once by calling `activate`.
struct TaskKillswitch {
/// Strong sender to the collector task. `activate` takes it out, leaving
/// `None`; `spawn_task` returns without spawning once it is `None`.
task_tx: parking_lot::RwLock<Option<mpsc::UnboundedSender<ActiveTaskOp>>>,
/// Source of the ids assigned to spawned tasks.
task_counter: AtomicU64,
/// Watch receiver cloned by `killed`; the matching sender is owned by
/// `ActiveTasks`.
all_killed: watch::Receiver<()>,
}
impl TaskKillswitch {
fn new() -> Self {
let (task_tx, task_rx) = mpsc::unbounded_channel();
let (signal_killed, all_killed) = watch::channel(());
let active_tasks = ActiveTasks {
task_rx,
tasks: Default::default(),
signal_killed,
};
tokio::spawn(active_tasks.collect());
Self {
task_tx: parking_lot::RwLock::new(Some(task_tx)),
task_counter: Default::default(),
all_killed,
}
}
fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
let Some(task_tx) = self.task_tx.read().as_ref().cloned() else {
return;
};
let id = self.task_counter.fetch_add(1, Ordering::SeqCst);
let task_tx_weak = task_tx.downgrade();
let handle = tokio::spawn(async move {
let _guard = RemoveOnDrop { task_tx_weak, id };
fut.await;
});
let _ = task_tx.send(ActiveTaskOp::Add { id, handle });
}
fn activate(&self) {
assert!(
self.task_tx.write().take().is_some(),
"killswitch can't be used twice"
);
}
fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
let mut signal = self.all_killed.clone();
async move {
let _ = signal.changed().await;
}
}
}
/// State owned by the collector task that `TaskKillswitch::new` spawns.
struct ActiveTasks {
/// Receives `Add`/`Remove` operations from the killswitch and its tasks.
task_rx: mpsc::UnboundedReceiver<ActiveTaskOp>,
/// Handles of the tasks currently tracked, keyed by task id.
tasks: HashMap<u64, JoinHandle<()>>,
/// Dropped at the end of `collect`, after `abort` has been called on every
/// tracked task.
signal_killed: watch::Sender<()>,
}
impl ActiveTasks {
    /// Processes bookkeeping operations until the channel is closed, then
    /// aborts every task that is still tracked and drops `signal_killed`.
    ///
    /// `recv()` returns `None` once all strong senders have been dropped; the
    /// weak senders held by the spawned tasks do not keep the channel open.
    async fn collect(mut self) {
        // `TaskKillswitch::spawn_task` sends `Add` only after `tokio::spawn`
        // has returned, so a task that finishes quickly on another worker
        // can get its `Remove` into the channel first. Remember such ids so
        // that the late `Add` is discarded instead of leaving the handle of
        // a finished task in `tasks` until the killswitch is activated.
        let mut finished_early: std::collections::HashSet<u64> =
            std::collections::HashSet::new();

        while let Some(op) = self.task_rx.recv().await {
            self.handle_task_op(op, &mut finished_early);
        }

        for task in self.tasks.into_values() {
            task.abort();
        }

        // Dropping the watch sender is what completes the futures returned
        // by `TaskKillswitch::killed`.
        drop(self.signal_killed);
    }

    /// Applies a single bookkeeping operation.
    ///
    /// `finished_early` holds the ids for which `Remove` was seen before
    /// `Add`. Every id is added and removed at most once, so an entry is
    /// cleared again as soon as its matching `Add` arrives.
    fn handle_task_op(
        &mut self, op: ActiveTaskOp,
        finished_early: &mut std::collections::HashSet<u64>,
    ) {
        match op {
            ActiveTaskOp::Add { id, handle } => {
                // `remove` returns true when the task already reported
                // itself as finished; the handle is then simply dropped.
                if !finished_early.remove(&id) {
                    self.tasks.insert(id, handle);
                }
            },
            ActiveTaskOp::Remove { id } => {
                if self.tasks.remove(&id).is_none() {
                    finished_early.insert(id);
                }
            },
        }
    }
}
/// Process-wide killswitch backing the public functions of this module.
///
/// Initialized on first use by `TaskKillswitch::new`, which calls
/// `tokio::spawn`, so the first access has to happen inside a Tokio runtime.
static TASK_KILLSWITCH: Lazy<TaskKillswitch> = Lazy::new(TaskKillswitch::new);
/// Spawns `fut` as a Tokio task that is tracked by the process-wide
/// killswitch.
///
/// If the killswitch has already been activated, `fut` is dropped without
/// being spawned.
#[inline]
pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
TASK_KILLSWITCH.spawn_task(fut);
}
#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
pub async fn activate() {
TASK_KILLSWITCH.activate()
}
/// Activates the process-wide killswitch.
///
/// Later calls to [`spawn_with_killswitch`] drop their future without
/// spawning it.
///
/// # Panics
///
/// Panics with "killswitch can't be used twice" if the killswitch has been
/// activated before.
#[inline]
pub fn activate_now() {
TASK_KILLSWITCH.activate();
}
/// Returns a future that completes once the collector task has called
/// `abort` on every task it still tracked and has dropped its watch sender.
///
/// `abort` only requests cancellation, so this does not by itself show that
/// the aborted tasks have already stopped running.
#[inline]
pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
TASK_KILLSWITCH.killed()
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Number of long-running tasks each test spawns.
    const TASK_COUNT: usize = 1000;

    /// Sends on its oneshot channel when dropped, which lets a test observe
    /// that the future owning it has been dropped.
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        /// Returns the signal together with the receiver that observes it.
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (tx, rx) = oneshot::channel();
            (Self(Some(tx)), rx)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            // The sender is always present: `new` is the only constructor
            // and nothing else takes it.
            let sender = self.0.take().unwrap();
            let _ = sender.send(());
        }
    }

    /// Spawns `TASK_COUNT` tasks that sleep for an hour and returns one
    /// receiver per task; a receiver resolves when its task's future is
    /// dropped.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::with_capacity(TASK_COUNT);

        for _ in 0..TASK_COUNT {
            let (abort_signal, rx) = TaskAbortSignal::new();

            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(abort_signal);
            });

            receivers.push(rx);
        }

        receivers
    }

    /// Waits at most one second for every abort signal to fire.
    async fn expect_all_aborted(abort_signals: Vec<oneshot::Receiver<()>>) {
        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::new();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();

        expect_all_aborted(abort_signals).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::new();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        // The killed() future must stay pending while the killswitch has
        // not been activated.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!signal_handle.is_finished());

        killswitch.activate();

        expect_all_aborted(abort_signals).await;

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}