// dig_service/tasks.rs
1//! Tracked-task registry.
2//!
3//! [`TaskRegistry`] is a cheap clone-and-spawn utility: every task spawned
4//! through it is tracked, given a stable name, and joined at shutdown with
5//! a configurable deadline. Stragglers past the deadline are aborted and
6//! counted so operators see a graceful-but-incomplete shutdown.
7//!
8//! # Why not `tokio::task::JoinSet`?
9//!
10//! `JoinSet` forgets tasks as soon as they complete; we want to **track**
11//! live tasks for the service's lifetime (so `ServiceHandle::tasks()` can
12//! enumerate them). `JoinSet` also doesn't carry task names. `TaskRegistry`
13//! layers both on top.
14//!
15//! # Guarantees
16//!
//! - Every task can observe shutdown by capturing a clone of the
//!   [`ShutdownToken`](crate::ShutdownToken) obtained from `TaskRegistry::shutdown`
//!   — the registry does not inject it automatically.
18//! - `join_all(deadline)` waits up to the deadline for each live task;
19//!   remaining tasks are aborted and counted in
20//!   [`ServiceError::ShutdownDeadlineExceeded`](crate::ServiceError::ShutdownDeadlineExceeded).
21
22use std::future::Future;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use parking_lot::Mutex;
27use tokio::task::{AbortHandle, JoinHandle};
28
29use crate::error::{Result, ServiceError};
30use crate::shutdown::ShutdownToken;
31
/// Kind of task, for observability.
///
/// Marked `#[non_exhaustive]` so new kinds can be added without a
/// semver-breaking change; downstream `match`es need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskKind {
    /// A long-running background loop (sync loop, heartbeat).
    BackgroundLoop,
    /// A per-peer connection handler.
    PeerConnection,
    /// A per-request RPC handler.
    RpcHandler,
    /// A periodic maintenance job (compaction, log rotation).
    Maintenance,
}
45
/// A clone-safe registry of spawned tasks.
///
/// `TaskRegistry: Clone + Send + Sync`. Cloning is cheap (`Arc` bump) and
/// all clones share the same underlying task list.
#[derive(Clone)]
pub struct TaskRegistry {
    // Shared state; every clone points at the same `Inner`.
    inner: Arc<Inner>,
}
54
/// State shared by all clones of [`TaskRegistry`].
struct Inner {
    // Token handed out by `TaskRegistry::shutdown` so tasks can observe shutdown.
    shutdown: ShutdownToken,
    // Tracked tasks: appended by `spawn`, drained by `join_all`.
    // parking_lot::Mutex — never held across an await point.
    entries: Mutex<Vec<Entry>>,
}
59
/// Book-keeping for one tracked task.
struct Entry {
    // Static name supplied at spawn time; used in logs and snapshots.
    name: &'static str,
    // Observability classification.
    kind: TaskKind,
    // When `spawn` registered the task.
    spawned_at: Instant,
    // Real join handle; held until `join_all` consumes it, so completed
    // tasks remain listed until shutdown.
    handle: JoinHandle<()>,
}
66
/// Per-task snapshot surfaced by [`TaskRegistry::snapshot`].
///
/// Plain data — carries no handle, so cloning or retaining a summary
/// neither keeps the task alive nor keeps it tracked.
#[derive(Debug, Clone)]
pub struct TaskSummary {
    /// Static name passed to `spawn`.
    pub name: &'static str,
    /// When the task was spawned.
    pub spawned_at: Instant,
    /// Kind classification.
    pub kind: TaskKind,
}
77
78impl TaskRegistry {
79    /// Construct an empty registry linked to the given shutdown token.
80    pub fn new(shutdown: ShutdownToken) -> Self {
81        Self {
82            inner: Arc::new(Inner {
83                shutdown,
84                entries: Mutex::new(Vec::new()),
85            }),
86        }
87    }
88
89    /// Spawn a tracked task.
90    ///
91    /// The closure receives no arguments — if it needs the shutdown token
92    /// or a sub-registry, capture them via move. Use
93    /// [`shutdown`](Self::shutdown) to obtain a clone of the token.
94    ///
95    /// Returns an [`AbortHandle`] so the caller can force-abort the task
96    /// if needed (e.g., a per-peer handler when the peer disconnects). The
97    /// real `JoinHandle` is held by the registry and is awaited in
98    /// [`join_all`](Self::join_all) at shutdown. Callers that want to wait
99    /// for a single task to complete must await by their own means (a
100    /// `tokio::sync::oneshot` or similar inside `fut`) — the registry is
101    /// optimised for the shutdown-join case, not per-task completion.
102    pub fn spawn<F>(&self, name: &'static str, kind: TaskKind, fut: F) -> AbortHandle
103    where
104        F: Future<Output = anyhow::Result<()>> + Send + 'static,
105    {
106        let handle: JoinHandle<()> = tokio::spawn(async move {
107            if let Err(e) = fut.await {
108                tracing::error!(task = name, error = %e, "task exited with error");
109            }
110        });
111        let abort = handle.abort_handle();
112        self.inner.entries.lock().push(Entry {
113            name,
114            kind,
115            spawned_at: Instant::now(),
116            handle,
117        });
118        abort
119    }
120
121    /// Borrow the shutdown token this registry is linked to.
122    pub fn shutdown(&self) -> &ShutdownToken {
123        &self.inner.shutdown
124    }
125
126    /// Await all tracked tasks up to `deadline`; abort laggards.
127    ///
128    /// Returns `Ok(())` if every task exited inside the deadline.
129    /// Returns [`ServiceError::ShutdownDeadlineExceeded`] otherwise, with
130    /// the count of aborted tasks.
131    pub async fn join_all(&self, deadline: Duration) -> Result<()> {
132        // Drain the list; we don't hold the lock across awaits.
133        let entries = {
134            let mut g = self.inner.entries.lock();
135            std::mem::take(&mut *g)
136        };
137
138        let start = Instant::now();
139        let mut pending = 0usize;
140
141        for entry in entries {
142            let remaining = deadline.saturating_sub(start.elapsed());
143            if remaining.is_zero() {
144                entry.handle.abort();
145                pending += 1;
146                continue;
147            }
148            match tokio::time::timeout(remaining, entry.handle).await {
149                Ok(Ok(())) => {}
150                Ok(Err(join_err)) => {
151                    // Task panicked or was cancelled.
152                    tracing::warn!(
153                        task = entry.name,
154                        elapsed = ?entry.spawned_at.elapsed(),
155                        panic = join_err.is_panic(),
156                        "tracked task did not exit cleanly",
157                    );
158                }
159                Err(_elapsed) => {
160                    // The `.handle` was consumed by `timeout`; it will abort
161                    // automatically when dropped.
162                    pending += 1;
163                    tracing::warn!(
164                        task = entry.name,
165                        "task exceeded shutdown deadline; aborting",
166                    );
167                }
168            }
169        }
170
171        if pending == 0 {
172            Ok(())
173        } else {
174            Err(ServiceError::ShutdownDeadlineExceeded { deadline, pending })
175        }
176    }
177
178    /// Snapshot of currently-tracked tasks (live only; completed tasks
179    /// are not garbage-collected until `join_all`).
180    pub fn snapshot(&self) -> Vec<TaskSummary> {
181        self.inner
182            .entries
183            .lock()
184            .iter()
185            .map(|e| TaskSummary {
186                name: e.name,
187                spawned_at: e.spawned_at,
188                kind: e.kind,
189            })
190            .collect()
191    }
192
193    /// Current live-task count.
194    pub fn len(&self) -> usize {
195        self.inner.entries.lock().len()
196    }
197
198    /// Whether no tasks are currently registered.
199    pub fn is_empty(&self) -> bool {
200        self.inner.entries.lock().is_empty()
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// **Proves:** a freshly-constructed registry is empty.
    ///
    /// **Why it matters:** Basic invariant — nothing should be "tracked"
    /// before `spawn` is called.
    ///
    /// **Catches:** a default-state regression.
    #[test]
    fn empty_on_construction() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    /// **Proves:** `spawn` increments `len` and adds an entry to the
    /// snapshot.
    ///
    /// **Why it matters:** `ServiceHandle::tasks()` forwards to `snapshot`
    /// so operators can see what's running. An empty snapshot for a
    /// running service is a red flag.
    ///
    /// **Catches:** a regression where `spawn` forgets to push to `entries`.
    #[tokio::test]
    async fn spawn_registers() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        let _abort = registry.spawn("t1", TaskKind::BackgroundLoop, async { Ok(()) });

        assert_eq!(registry.len(), 1);
        let snap = registry.snapshot();
        assert_eq!(snap[0].name, "t1");
        assert_eq!(snap[0].kind, TaskKind::BackgroundLoop);
    }

    /// **Proves:** `join_all` waits for tasks that finish inside the deadline,
    /// reporting success.
    ///
    /// **Why it matters:** The happy path for graceful shutdown.
    ///
    /// **Catches:** a regression where `join_all` returns before tasks
    /// actually complete.
    #[tokio::test]
    async fn join_all_awaits_fast_tasks() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        registry.spawn("fast", TaskKind::BackgroundLoop, async {
            tokio::time::sleep(Duration::from_millis(5)).await;
            Ok(())
        });

        assert!(registry.join_all(Duration::from_secs(5)).await.is_ok());
    }

    /// **Proves:** `join_all` aborts tasks that outlast the deadline and
    /// returns `ShutdownDeadlineExceeded` with the correct count.
    ///
    /// **Why it matters:** If a task has a bug (infinite loop, stuck await),
    /// shutdown must not hang forever. Aborting past the deadline is the
    /// safety valve, and the error count lets operators see how many tasks
    /// went wrong.
    ///
    /// **Catches:** a regression where `join_all` blocks without abort,
    /// or where the `pending` count gets off by one.
    #[tokio::test]
    async fn join_all_aborts_slow_tasks() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        registry.spawn("slow", TaskKind::BackgroundLoop, async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(())
        });

        let outcome = registry.join_all(Duration::from_millis(20)).await;
        assert!(matches!(
            outcome,
            Err(ServiceError::ShutdownDeadlineExceeded { pending: 1, .. })
        ));
    }
}