1use dashmap::DashMap;
28use parking_lot::Mutex;
29use tokio::sync::watch;
30use tokio::task;
31use tokio::task::AbortHandle;
32use tokio::task::Id;
33
34use std::future::Future;
35use std::sync::atomic::AtomicBool;
36use std::sync::atomic::Ordering;
37use std::sync::LazyLock;
38
39struct RemoveOnDrop {
42 id: task::Id,
43 storage: &'static ActiveTasks,
44}
45impl Drop for RemoveOnDrop {
46 fn drop(&mut self) {
47 self.storage.remove_task(self.id);
48 }
49}
50
51struct TaskKillswitch {
57 activated: AtomicBool,
59 storage: &'static ActiveTasks,
60
61 all_killed: watch::Receiver<()>,
65 signal_killed: Mutex<Option<watch::Sender<()>>>,
70}
71
72impl TaskKillswitch {
73 fn new(storage: &'static ActiveTasks) -> Self {
74 let (signal_killed, all_killed) = watch::channel(());
75 let signal_killed = Mutex::new(Some(signal_killed));
76
77 Self {
78 activated: AtomicBool::new(false),
79 storage,
80 signal_killed,
81 all_killed,
82 }
83 }
84
85 fn with_leaked_storage() -> Self {
90 let storage = Box::leak(Box::new(ActiveTasks::default()));
91 Self::new(storage)
92 }
93
94 fn was_activated(&self) -> bool {
95 self.activated.load(Ordering::Relaxed)
98 }
99
100 fn spawn_task(
101 &self, fut: impl Future<Output = ()> + Send + 'static,
102 ) -> Option<Id> {
103 if self.was_activated() {
104 return None;
105 }
106
107 let storage = self.storage;
108 let handle = tokio::spawn(async move {
109 let id = task::id();
110 let _guard = RemoveOnDrop { id, storage };
111 fut.await;
112 })
113 .abort_handle();
114
115 let id = handle.id();
116
117 let res = self.storage.add_task_if(handle, || !self.was_activated());
118 if let Err(handle) = res {
119 handle.abort();
121 return None;
122 }
123 Some(id)
124 }
125
126 fn activate(&self) {
127 assert!(
132 !self.activated.swap(true, Ordering::Relaxed),
133 "killswitch can't be used twice"
134 );
135
136 let tasks = self.storage;
137 let signal_killed = self.signal_killed.lock().take();
138 std::thread::spawn(move || {
139 tasks.kill_all();
140 drop(signal_killed);
141 });
142 }
143
144 fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
145 let mut signal = self.all_killed.clone();
146 async move {
147 let _ = signal.changed().await;
148 }
149 }
150}
151
152enum TaskEntry {
153 Handle(AbortHandle),
155 Tombstone,
158}
159
160#[derive(Default)]
161struct ActiveTasks {
162 tasks: DashMap<task::Id, TaskEntry>,
163}
164
165impl ActiveTasks {
166 fn kill_all(&self) {
167 self.tasks.retain(|_, entry| {
168 if let TaskEntry::Handle(task) = entry {
169 task.abort();
170 }
171 false });
173 }
174
175 fn add_task_if(
176 &self, handle: AbortHandle, cond: impl FnOnce() -> bool,
177 ) -> Result<(), AbortHandle> {
178 use dashmap::Entry::*;
179 let id = handle.id();
180
181 match self.tasks.entry(id) {
182 Vacant(e) => {
183 if !cond() {
184 return Err(handle);
185 }
186 e.insert(TaskEntry::Handle(handle));
187 },
188 Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
189 e.remove();
192 },
193 Occupied(_) => panic!("tokio task ID already in use: {id}"),
194 }
195
196 Ok(())
197 }
198
199 fn remove_task(&self, id: task::Id) {
200 use dashmap::Entry::*;
201 match self.tasks.entry(id) {
202 Vacant(e) => {
203 e.insert(TaskEntry::Tombstone);
205 },
206 Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {},
207 Occupied(e) => {
208 e.remove();
209 },
210 }
211 }
212}
213
214static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
216 LazyLock::new(TaskKillswitch::with_leaked_storage);
217
218#[inline]
223pub fn spawn_with_killswitch(
224 fut: impl Future<Output = ()> + Send + 'static,
225) -> Option<Id> {
226 TASK_KILLSWITCH.spawn_task(fut)
227}
228
229#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
230pub async fn activate() {
231 TASK_KILLSWITCH.activate()
232}
233
234#[inline]
240pub fn activate_now() {
241 TASK_KILLSWITCH.activate();
242}
243
244#[inline]
251pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
252 TASK_KILLSWITCH.killed()
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Sends on its oneshot channel when dropped. This lets a test observe that
    /// the future owning it was dropped (completed or aborted).
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (tx, rx) = oneshot::channel();

            (Self(Some(tx)), rx)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            if let Some(tx) = self.0.take() {
                let _ = tx.send(());
            }
        }
    }

    /// Spawns 1000 long-sleeping tasks. Returns one receiver per task, which
    /// resolves once that task's future is dropped.
    ///
    /// Every spawn is asserted to be accepted. A rejected spawn drops its future
    /// immediately, which would resolve the receiver without the killswitch doing
    /// anything and make the tests below pass vacuously.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        (0..1000)
            .map(|_| {
                let (tx, rx) = TaskAbortSignal::new();

                let spawned = killswitch.spawn_task(async move {
                    tokio::time::sleep(Duration::from_secs(3600)).await;
                    drop(tx);
                });
                assert!(spawned.is_some(), "spawn should succeed before activation");

                rx
            })
            .collect()
    }

    /// Waits (bounded) for every receiver and checks that each one got the
    /// drop notification rather than a closed channel.
    async fn expect_all_killed(abort_signals: Vec<oneshot::Receiver<()>>) {
        let results = tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");

        assert!(
            results.iter().all(Result::is_ok),
            "every task should have signalled that it was dropped"
        );
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();

        expect_all_killed(abort_signals).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let mut abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert!(!signal_handle.is_finished());
        // Nothing may be aborted before the killswitch is activated.
        for rx in &mut abort_signals {
            assert!(matches!(
                rx.try_recv(),
                Err(oneshot::error::TryRecvError::Empty)
            ));
        }

        killswitch.activate();

        expect_all_killed(abort_signals).await;

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }

    #[test]
    fn spawn_after_activation_is_rejected() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        killswitch.activate();

        // Rejected before reaching `tokio::spawn`, so no runtime is needed.
        assert!(killswitch.spawn_task(async {}).is_none());
    }

    #[test]
    #[should_panic(expected = "killswitch can't be used twice")]
    fn activate_twice_panics() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        killswitch.activate();
        killswitch.activate();
    }

    #[tokio::test]
    async fn killed_resolves_when_requested_after_activation() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        killswitch.activate();

        tokio::time::timeout(Duration::from_secs(1), killswitch.killed())
            .await
            .expect("killed() should resolve after activation");
    }

    #[tokio::test]
    async fn finished_tasks_are_removed_from_storage() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        for _ in 0..100 {
            assert!(killswitch.spawn_task(async {}).is_some());
        }

        tokio::time::timeout(Duration::from_secs(1), async {
            while !killswitch.storage.tasks.is_empty() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("finished tasks should be removed from storage");
    }
}