1use dashmap::DashMap;
28use parking_lot::Mutex;
29use tokio::sync::watch;
30use tokio::task;
31use tokio::task::AbortHandle;
32use tokio::task::Id;
33
34use std::future::Future;
35use std::sync::atomic::AtomicBool;
36use std::sync::atomic::Ordering;
37use std::sync::LazyLock;
38
39struct RemoveOnDrop {
42 id: task::Id,
43 storage: &'static ActiveTasks,
44}
45impl Drop for RemoveOnDrop {
46 fn drop(&mut self) {
47 self.storage.remove_task(self.id);
48 }
49}
50
51struct TaskKillswitch {
57 activated: AtomicBool,
59 storage: &'static ActiveTasks,
60
61 all_killed: watch::Receiver<()>,
65 signal_killed: Mutex<Option<watch::Sender<()>>>,
70}
71
72impl TaskKillswitch {
73 fn new(storage: &'static ActiveTasks) -> Self {
74 let (signal_killed, all_killed) = watch::channel(());
75 let signal_killed = Mutex::new(Some(signal_killed));
76
77 Self {
78 activated: AtomicBool::new(false),
79 storage,
80 signal_killed,
81 all_killed,
82 }
83 }
84
85 fn with_leaked_storage() -> Self {
90 let storage = Box::leak(Box::new(ActiveTasks::default()));
91 Self::new(storage)
92 }
93
94 fn was_activated(&self) -> bool {
95 self.activated.load(Ordering::Relaxed)
98 }
99
100 #[track_caller]
101 fn spawn_task(
102 &self, fut: impl Future<Output = ()> + Send + 'static,
103 ) -> Option<Id> {
104 if self.was_activated() {
105 return None;
106 }
107
108 let storage = self.storage;
109 let handle = tokio::spawn(async move {
110 let id = task::id();
111 let _guard = RemoveOnDrop { id, storage };
112 fut.await;
113 })
114 .abort_handle();
115
116 let id = handle.id();
117
118 let res = self.storage.add_task_if(handle, || !self.was_activated());
119 if let Err(handle) = res {
120 handle.abort();
122 return None;
123 }
124 Some(id)
125 }
126
127 fn activate(&self) {
128 assert!(
133 !self.activated.swap(true, Ordering::Relaxed),
134 "killswitch can't be used twice"
135 );
136
137 let tasks = self.storage;
138 let signal_killed = self.signal_killed.lock().take();
139 std::thread::spawn(move || {
140 tasks.kill_all();
141 drop(signal_killed);
142 });
143 }
144
145 fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
146 let mut signal = self.all_killed.clone();
147 async move {
148 let _ = signal.changed().await;
149 }
150 }
151}
152
153enum TaskEntry {
154 Handle(AbortHandle),
156 Tombstone,
159}
160
161#[derive(Default)]
162struct ActiveTasks {
163 tasks: DashMap<task::Id, TaskEntry>,
164}
165
166impl ActiveTasks {
167 fn kill_all(&self) {
168 self.tasks.retain(|_, entry| {
169 if let TaskEntry::Handle(task) = entry {
170 task.abort();
171 }
172 false });
174 }
175
176 fn add_task_if(
177 &self, handle: AbortHandle, cond: impl FnOnce() -> bool,
178 ) -> Result<(), AbortHandle> {
179 use dashmap::Entry::*;
180 let id = handle.id();
181
182 match self.tasks.entry(id) {
183 Vacant(e) => {
184 if !cond() {
185 return Err(handle);
186 }
187 e.insert(TaskEntry::Handle(handle));
188 },
189 Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
190 e.remove();
193 },
194 Occupied(_) => panic!("tokio task ID already in use: {id}"),
195 }
196
197 Ok(())
198 }
199
200 fn remove_task(&self, id: task::Id) {
201 use dashmap::Entry::*;
202 match self.tasks.entry(id) {
203 Vacant(e) => {
204 e.insert(TaskEntry::Tombstone);
206 },
207 Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {},
208 Occupied(e) => {
209 e.remove();
210 },
211 }
212 }
213}
214
215static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
217 LazyLock::new(TaskKillswitch::with_leaked_storage);
218
219#[inline]
224#[track_caller]
225pub fn spawn_with_killswitch(
226 fut: impl Future<Output = ()> + Send + 'static,
227) -> Option<Id> {
228 TASK_KILLSWITCH.spawn_task(fut)
229}
230
231#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
232pub async fn activate() {
233 TASK_KILLSWITCH.activate()
234}
235
236#[inline]
242pub fn activate_now() {
243 TASK_KILLSWITCH.activate();
244}
245
246#[inline]
253pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
254 TASK_KILLSWITCH.killed()
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Number of long-running tasks each test spawns.
    const TASK_COUNT: usize = 1000;

    /// Fires its paired oneshot receiver when dropped, i.e. when the future owning it is
    /// dropped (whether it completed or was aborted).
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (sender, receiver) = oneshot::channel();
            (Self(Some(sender)), receiver)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let sender = self.0.take().unwrap();
            let _ = sender.send(());
        }
    }

    /// Spawns `TASK_COUNT` hour-long tasks through `killswitch` and returns receivers that
    /// fire when the corresponding task's future is dropped.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::with_capacity(TASK_COUNT);
        for _ in 0..TASK_COUNT {
            let (signal, receiver) = TaskAbortSignal::new();
            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(signal);
            });
            receivers.push(receiver);
        }
        receivers
    }

    /// Fails the test unless every abort signal fires within one second.
    async fn expect_tasks_killed(abort_signals: Vec<oneshot::Receiver<()>>) {
        let all_signals = future::join_all(abort_signals);
        tokio::time::timeout(Duration::from_secs(1), all_signals)
            .await
            .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();
        expect_tasks_killed(abort_signals).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        // Let the tasks start running before pulling the switch.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!signal_handle.is_finished());

        killswitch.activate();
        expect_tasks_killed(abort_signals).await;

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}