// dig_service/tasks.rs
1//! Tracked-task registry.
2//!
3//! [`TaskRegistry`] is a cheap clone-and-spawn utility: every task spawned
4//! through it is tracked, given a stable name, and joined at shutdown with
5//! a configurable deadline. Stragglers past the deadline are aborted and
6//! counted so operators see a graceful-but-incomplete shutdown.
7//!
8//! # Why not `tokio::task::JoinSet`?
9//!
10//! `JoinSet` forgets tasks as soon as they complete; we want to **track**
11//! live tasks for the service's lifetime (so `ServiceHandle::tasks()` can
12//! enumerate them). `JoinSet` also doesn't carry task names. `TaskRegistry`
13//! layers both on top.
14//!
15//! # Guarantees
16//!
//! - Every task runs under a registry that holds the [`ShutdownToken`](crate::ShutdownToken);
//!   tasks that need it obtain a clone via [`TaskRegistry::shutdown`] and capture it via move.
18//! - `join_all(deadline)` waits up to the deadline for each live task;
19//! remaining tasks are aborted and counted in
20//! [`ServiceError::ShutdownDeadlineExceeded`](crate::ServiceError::ShutdownDeadlineExceeded).
21
22use std::future::Future;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use parking_lot::Mutex;
27use tokio::task::{AbortHandle, JoinHandle};
28
29use crate::error::{Result, ServiceError};
30use crate::shutdown::ShutdownToken;
31
32/// Kind of task, for observability.
33#[non_exhaustive]
34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub enum TaskKind {
36 /// A long-running background loop (sync loop, heartbeat).
37 BackgroundLoop,
38 /// A per-peer connection handler.
39 PeerConnection,
40 /// A per-request RPC handler.
41 RpcHandler,
42 /// A periodic maintenance job (compaction, log rotation).
43 Maintenance,
44}
45
46/// A clone-safe registry of spawned tasks.
47///
48/// `TaskRegistry: Clone + Send + Sync`. Cloning is cheap (`Arc` bump) and
49/// all clones share the same underlying task list.
50#[derive(Clone)]
51pub struct TaskRegistry {
52 inner: Arc<Inner>,
53}
54
55struct Inner {
56 shutdown: ShutdownToken,
57 entries: Mutex<Vec<Entry>>,
58}
59
60struct Entry {
61 name: &'static str,
62 kind: TaskKind,
63 spawned_at: Instant,
64 handle: JoinHandle<()>,
65}
66
67/// Per-task snapshot surfaced by [`TaskRegistry::snapshot`].
68#[derive(Debug, Clone)]
69pub struct TaskSummary {
70 /// Static name passed to `spawn`.
71 pub name: &'static str,
72 /// When the task was spawned.
73 pub spawned_at: Instant,
74 /// Kind classification.
75 pub kind: TaskKind,
76}
77
78impl TaskRegistry {
79 /// Construct an empty registry linked to the given shutdown token.
80 pub fn new(shutdown: ShutdownToken) -> Self {
81 Self {
82 inner: Arc::new(Inner {
83 shutdown,
84 entries: Mutex::new(Vec::new()),
85 }),
86 }
87 }
88
89 /// Spawn a tracked task.
90 ///
91 /// The closure receives no arguments — if it needs the shutdown token
92 /// or a sub-registry, capture them via move. Use
93 /// [`shutdown`](Self::shutdown) to obtain a clone of the token.
94 ///
95 /// Returns an [`AbortHandle`] so the caller can force-abort the task
96 /// if needed (e.g., a per-peer handler when the peer disconnects). The
97 /// real `JoinHandle` is held by the registry and is awaited in
98 /// [`join_all`](Self::join_all) at shutdown. Callers that want to wait
99 /// for a single task to complete must await by their own means (a
100 /// `tokio::sync::oneshot` or similar inside `fut`) — the registry is
101 /// optimised for the shutdown-join case, not per-task completion.
102 pub fn spawn<F>(&self, name: &'static str, kind: TaskKind, fut: F) -> AbortHandle
103 where
104 F: Future<Output = anyhow::Result<()>> + Send + 'static,
105 {
106 let handle: JoinHandle<()> = tokio::spawn(async move {
107 if let Err(e) = fut.await {
108 tracing::error!(task = name, error = %e, "task exited with error");
109 }
110 });
111 let abort = handle.abort_handle();
112 self.inner.entries.lock().push(Entry {
113 name,
114 kind,
115 spawned_at: Instant::now(),
116 handle,
117 });
118 abort
119 }
120
121 /// Borrow the shutdown token this registry is linked to.
122 pub fn shutdown(&self) -> &ShutdownToken {
123 &self.inner.shutdown
124 }
125
126 /// Await all tracked tasks up to `deadline`; abort laggards.
127 ///
128 /// Returns `Ok(())` if every task exited inside the deadline.
129 /// Returns [`ServiceError::ShutdownDeadlineExceeded`] otherwise, with
130 /// the count of aborted tasks.
131 pub async fn join_all(&self, deadline: Duration) -> Result<()> {
132 // Drain the list; we don't hold the lock across awaits.
133 let entries = {
134 let mut g = self.inner.entries.lock();
135 std::mem::take(&mut *g)
136 };
137
138 let start = Instant::now();
139 let mut pending = 0usize;
140
141 for entry in entries {
142 let remaining = deadline.saturating_sub(start.elapsed());
143 if remaining.is_zero() {
144 entry.handle.abort();
145 pending += 1;
146 continue;
147 }
148 match tokio::time::timeout(remaining, entry.handle).await {
149 Ok(Ok(())) => {}
150 Ok(Err(join_err)) => {
151 // Task panicked or was cancelled.
152 tracing::warn!(
153 task = entry.name,
154 elapsed = ?entry.spawned_at.elapsed(),
155 panic = join_err.is_panic(),
156 "tracked task did not exit cleanly",
157 );
158 }
159 Err(_elapsed) => {
160 // The `.handle` was consumed by `timeout`; it will abort
161 // automatically when dropped.
162 pending += 1;
163 tracing::warn!(
164 task = entry.name,
165 "task exceeded shutdown deadline; aborting",
166 );
167 }
168 }
169 }
170
171 if pending == 0 {
172 Ok(())
173 } else {
174 Err(ServiceError::ShutdownDeadlineExceeded { deadline, pending })
175 }
176 }
177
178 /// Snapshot of currently-tracked tasks (live only; completed tasks
179 /// are not garbage-collected until `join_all`).
180 pub fn snapshot(&self) -> Vec<TaskSummary> {
181 self.inner
182 .entries
183 .lock()
184 .iter()
185 .map(|e| TaskSummary {
186 name: e.name,
187 spawned_at: e.spawned_at,
188 kind: e.kind,
189 })
190 .collect()
191 }
192
193 /// Current live-task count.
194 pub fn len(&self) -> usize {
195 self.inner.entries.lock().len()
196 }
197
198 /// Whether no tasks are currently registered.
199 pub fn is_empty(&self) -> bool {
200 self.inner.entries.lock().is_empty()
201 }
202}
203
204#[cfg(test)]
205mod tests {
206 use super::*;
207
208 /// **Proves:** a freshly-constructed registry is empty.
209 ///
210 /// **Why it matters:** Basic invariant — nothing should be "tracked"
211 /// before `spawn` is called.
212 ///
213 /// **Catches:** a default-state regression.
214 #[test]
215 fn empty_on_construction() {
216 let r = TaskRegistry::new(ShutdownToken::new());
217 assert!(r.is_empty());
218 assert_eq!(r.len(), 0);
219 }
220
221 /// **Proves:** `spawn` increments `len` and adds an entry to the
222 /// snapshot.
223 ///
224 /// **Why it matters:** `ServiceHandle::tasks()` forwards to `snapshot`
225 /// so operators can see what's running. An empty snapshot for a
226 /// running service is a red flag.
227 ///
228 /// **Catches:** a regression where `spawn` forgets to push to `entries`.
229 #[tokio::test]
230 async fn spawn_registers() {
231 let r = TaskRegistry::new(ShutdownToken::new());
232 let _h = r.spawn("t1", TaskKind::BackgroundLoop, async { Ok(()) });
233 assert_eq!(r.len(), 1);
234 assert_eq!(r.snapshot()[0].name, "t1");
235 assert_eq!(r.snapshot()[0].kind, TaskKind::BackgroundLoop);
236 }
237
238 /// **Proves:** `join_all` waits for tasks that finish inside the deadline,
239 /// reporting success.
240 ///
241 /// **Why it matters:** The happy path for graceful shutdown.
242 ///
243 /// **Catches:** a regression where `join_all` returns before tasks
244 /// actually complete.
245 #[tokio::test]
246 async fn join_all_awaits_fast_tasks() {
247 let r = TaskRegistry::new(ShutdownToken::new());
248 r.spawn("fast", TaskKind::BackgroundLoop, async {
249 tokio::time::sleep(Duration::from_millis(5)).await;
250 Ok(())
251 });
252 let res = r.join_all(Duration::from_secs(5)).await;
253 assert!(res.is_ok());
254 }
255
256 /// **Proves:** `join_all` aborts tasks that outlast the deadline and
257 /// returns `ShutdownDeadlineExceeded` with the correct count.
258 ///
259 /// **Why it matters:** If a task has a bug (infinite loop, stuck await),
260 /// shutdown must not hang forever. Aborting past the deadline is the
261 /// safety valve, and the error count lets operators see how many tasks
262 /// went wrong.
263 ///
264 /// **Catches:** a regression where `join_all` blocks without abort,
265 /// or where the `pending` count gets off by one.
266 #[tokio::test]
267 async fn join_all_aborts_slow_tasks() {
268 let r = TaskRegistry::new(ShutdownToken::new());
269 r.spawn("slow", TaskKind::BackgroundLoop, async {
270 tokio::time::sleep(Duration::from_secs(60)).await;
271 Ok(())
272 });
273 let res = r.join_all(Duration::from_millis(20)).await;
274 assert!(matches!(
275 res,
276 Err(ServiceError::ShutdownDeadlineExceeded { pending: 1, .. })
277 ));
278 }
279}