task_killswitch/lib.rs

// Copyright (C) 2025, Cloudflare, Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright notice,
//       this list of conditions and the following disclaimer.
//
//     * Redistributions in binary form must reproduce the above copyright
//       notice, this list of conditions and the following disclaimer in the
//       documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use dashmap::DashMap;
use parking_lot::Mutex;
use tokio::sync::watch;
use tokio::task;
use tokio::task::AbortHandle;
use tokio::task::Id;

use std::future::Future;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::LazyLock;

/// Drop guard for task removal. If a task panics, this makes sure
/// it is removed from [`ActiveTasks`] properly.
struct RemoveOnDrop {
    id: task::Id,
    storage: &'static ActiveTasks,
}
impl Drop for RemoveOnDrop {
    fn drop(&mut self) {
        self.storage.remove_task(self.id);
    }
}

/// A task killswitch that allows aborting all the tasks spawned with it at
/// once. The implementation strives to minimize in-band locking: spawning a
/// future takes only a single shard lock of an internal [`DashMap`].
/// Conflicts are expected to be very rare (dashmap defaults to `4 * nproc`
/// shards, and each thread can only spawn one task at a time).
struct TaskKillswitch {
    // Invariant: if `activated` is true, we don't add new tasks anymore.
    activated: AtomicBool,
    storage: &'static ActiveTasks,

    /// Watcher that is triggered after all kill signals have been sent (by
    /// dropping `signal_killed`). Currently running tasks are killed after
    /// their next yield, which may be after this triggers.
    all_killed: watch::Receiver<()>,
    // NOTE: All we want here is to take ownership of `signal_killed` when
    // activating the killswitch. That code path only runs once per instance, but
    // requires interior mutability. Using a `Mutex` is easier than bothering
    // with an `UnsafeCell`. The mutex is guaranteed to be unlocked.
    signal_killed: Mutex<Option<watch::Sender<()>>>,
}

impl TaskKillswitch {
    fn new(storage: &'static ActiveTasks) -> Self {
        let (signal_killed, all_killed) = watch::channel(());
        let signal_killed = Mutex::new(Some(signal_killed));

        Self {
            activated: AtomicBool::new(false),
            storage,
            signal_killed,
            all_killed,
        }
    }

    /// Creates a killswitch by allocating and leaking the task storage.
    ///
    /// **NOTE:** This is intended for use in `static`s and tests. It should not
    /// be exposed publicly!
    fn with_leaked_storage() -> Self {
        let storage = Box::leak(Box::new(ActiveTasks::default()));
        Self::new(storage)
    }

    fn was_activated(&self) -> bool {
        // All synchronization is done using locks,
        // so we can use relaxed for our atomics.
        self.activated.load(Ordering::Relaxed)
    }

    #[track_caller]
    fn spawn_task(
        &self, fut: impl Future<Output = ()> + Send + 'static,
    ) -> Option<Id> {
        if self.was_activated() {
            return None;
        }

        let storage = self.storage;
        let handle = tokio::spawn(async move {
            let id = task::id();
            let _guard = RemoveOnDrop { id, storage };
            fut.await;
        })
        .abort_handle();

        let id = handle.id();

        let res = self.storage.add_task_if(handle, || !self.was_activated());
        if let Err(handle) = res {
            // Killswitch was activated by the time we got a lock on the map shard
            handle.abort();
            return None;
        }
        Some(id)
    }

    fn activate(&self) {
        // `spawn_task` checks `activated` after locking the map shard and
        // before inserting an element. This ensures in-progress spawns either
        // complete before `tasks.kill_all()` obtains the lock for that shard,
        // or they abort afterwards.
        assert!(
            !self.activated.swap(true, Ordering::Relaxed),
            "killswitch can't be used twice"
        );

        let tasks = self.storage;
        let signal_killed = self.signal_killed.lock().take();
        std::thread::spawn(move || {
            tasks.kill_all();
            drop(signal_killed);
        });
    }

    fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
        let mut signal = self.all_killed.clone();
        async move {
            let _ = signal.changed().await;
        }
    }
}

enum TaskEntry {
    /// Task was added and not yet removed.
    Handle(AbortHandle),
    /// Task was removed before it was added. This can happen if a spawned
    /// future completes before the spawning thread can add it to the map.
    Tombstone,
}
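// An illustrative interleaving that produces a `Tombstone` (a sketch derived
// from `remove_task()` and `add_task_if()` below, not additional behavior):
//
//   spawning thread                      spawned task
//   ---------------                      ------------
//   tokio::spawn(fut)
//                                        `fut` completes immediately
//                                        `RemoveOnDrop` calls `remove_task(id)`
//                                          -> entry is Vacant, insert Tombstone
//   add_task_if(handle, cond)
//     -> entry is Occupied(Tombstone), remove it and drop the handle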

#[derive(Default)]
struct ActiveTasks {
    tasks: DashMap<task::Id, TaskEntry>,
}

impl ActiveTasks {
    fn kill_all(&self) {
        self.tasks.retain(|_, entry| {
            if let TaskEntry::Handle(task) = entry {
                task.abort();
            }
            false // remove all elements
        });
    }

    fn add_task_if(
        &self, handle: AbortHandle, cond: impl FnOnce() -> bool,
    ) -> Result<(), AbortHandle> {
        use dashmap::Entry::*;
        let id = handle.id();

        match self.tasks.entry(id) {
            Vacant(e) => {
                if !cond() {
                    return Err(handle);
                }
                e.insert(TaskEntry::Handle(handle));
            },
            Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
                // Task was removed before it was added. Clear the map entry and
                // drop the handle.
                e.remove();
            },
            Occupied(_) => panic!("tokio task ID already in use: {id}"),
        }

        Ok(())
    }

    fn remove_task(&self, id: task::Id) {
        use dashmap::Entry::*;
        match self.tasks.entry(id) {
            Vacant(e) => {
                // Task was not added yet, set a tombstone instead.
                e.insert(TaskEntry::Tombstone);
            },
            Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {},
            Occupied(e) => {
                e.remove();
            },
        }
    }
}

/// The global [`TaskKillswitch`] exposed publicly from the crate.
static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
    LazyLock::new(TaskKillswitch::with_leaked_storage);

/// Spawns a new asynchronous task and registers it in the crate's global
/// killswitch.
///
/// Under the hood, [`tokio::spawn`] schedules the actual execution.
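///
/// A minimal usage sketch (illustrative only; it assumes a Tokio runtime is
/// already running and that this crate is named `task_killswitch`):
///
/// ```ignore
/// let id = task_killswitch::spawn_with_killswitch(async {
///     // Long-running work that should be aborted when the killswitch fires.
/// });
/// // `None` means the killswitch had already been activated and nothing was
/// // spawned.
/// assert!(id.is_some());
/// ```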
#[inline]
#[track_caller]
pub fn spawn_with_killswitch(
    fut: impl Future<Output = ()> + Send + 'static,
) -> Option<Id> {
    TASK_KILLSWITCH.spawn_task(fut)
}

#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
pub async fn activate() {
    TASK_KILLSWITCH.activate()
}

/// Triggers the killswitch, thereby scheduling all registered tasks to be
/// killed.
///
/// Note: tasks are not killed synchronously in this function. This means
/// `activate_now()` will return before all tasks have been stopped.
#[inline]
pub fn activate_now() {
    TASK_KILLSWITCH.activate();
}

/// Returns a future that resolves once [`activate_now`] has been called and
/// every registered task has been signaled to abort.
///
/// Note: tokio does not kill a task until the next time it yields to the
/// runtime. This means some killed tasks may still be running by the time this
/// future resolves.
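///
/// A sketch of a typical shutdown sequence (illustrative only; assumes this
/// crate is named `task_killswitch` and is driven by a Tokio runtime):
///
/// ```ignore
/// task_killswitch::activate_now();
/// // Resolves once every registered task has been signaled to abort.
/// task_killswitch::killed_signal().await;
/// ```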
#[inline]
pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
    TASK_KILLSWITCH.killed()
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (tx, rx) = oneshot::channel();

            (Self(Some(tx)), rx)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let _ = self.0.take().unwrap().send(());
        }
    }

    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        (0..1000)
            .map(|_| {
                let (tx, rx) = TaskAbortSignal::new();

                killswitch.spawn_task(async move {
                    tokio::time::sleep(tokio::time::Duration::from_secs(3600))
                        .await;
                    drop(tx);
                });

                rx
            })
            .collect()
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();

        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        // NOTE: give tasks time to start executing.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        assert!(!signal_handle.is_finished());
        killswitch.activate();

        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}