nexus_async_rt/shutdown.rs
1//! Graceful shutdown support.
2//!
3//! [`ShutdownSignal`] is a future that resolves when a shutdown is
4//! requested — either by a Unix signal (SIGTERM, SIGINT) or by
5//! explicitly calling [`ShutdownHandle::trigger`].
6//!
7//! The Runtime checks the shutdown flag each poll cycle. When set,
8//! the root future can observe it via the `ShutdownSignal` future
9//! and begin connection draining.
10//!
11//! # Usage
12//!
13//! ```ignore
14//! let mut rt = Runtime::new(&mut world);
15//!
16//! // Install signal handlers (call once at startup).
17//! rt.install_signal_handlers();
18//!
19//! rt.block_on(async move {
20//! spawn_boxed(connection_tasks...);
21//!
22//! // Wait for SIGTERM/SIGINT.
23//! nexus_async_rt::ShutdownSignal::current().await;
24//!
25//! // Drain connections, flush buffers, etc.
26//! });
27//! ```
28
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::task::{Context, Poll, Waker};
34
35/// Shared shutdown flag.
36#[derive(Clone)]
37pub struct ShutdownHandle {
38 flag: Arc<AtomicBool>,
39 /// Mio waker to break epoll_wait when shutdown is triggered.
40 mio_waker: Option<Arc<mio::Waker>>,
41 /// Task waker slot — the ShutdownSignal future registers here.
42 /// Protected by Mutex because the signal handler thread may
43 /// call wake(). Only contested at shutdown time (once per process).
44 pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
45}
46
47impl ShutdownHandle {
48 pub(crate) fn new() -> Self {
49 Self {
50 flag: Arc::new(AtomicBool::new(false)),
51 mio_waker: None,
52 task_waker: Arc::new(std::sync::Mutex::new(None)),
53 }
54 }
55
56 /// Set the mio waker. Called by Runtime during construction.
57 pub(crate) fn set_mio_waker(&mut self, waker: Arc<mio::Waker>) {
58 self.mio_waker = Some(waker);
59 }
60
61 /// Trigger shutdown programmatically.
62 ///
63 /// Sets the flag, wakes the registered task waker (if any), and
64 /// breaks epoll_wait so the runtime loop re-polls the root future.
65 pub fn trigger(&self) {
66 self.flag.store(true, Ordering::Release);
67 // Wake the task waker first — signal the future directly.
68 if let Ok(mut guard) = self.task_waker.lock() {
69 if let Some(w) = guard.take() {
70 w.wake();
71 }
72 }
73 if let Some(w) = &self.mio_waker {
74 let _ = w.wake();
75 }
76 }
77
78 /// Check if shutdown has been requested.
79 pub fn is_shutdown(&self) -> bool {
80 self.flag.load(Ordering::Acquire)
81 }
82
83 /// Get the underlying flag Arc for signal handler registration.
84 pub(crate) fn flag_ptr(&self) -> Arc<AtomicBool> {
85 Arc::clone(&self.flag)
86 }
87
88 /// Returns a future that completes when shutdown is triggered.
89 pub fn signal(&self) -> ShutdownSignal {
90 ShutdownSignal {
91 flag: Arc::as_ptr(&self.flag),
92 task_waker: self.task_waker.clone(),
93 }
94 }
95}
96
/// Future that resolves when shutdown is triggered.
///
/// Every poll that finds the flag unset (re-)registers the polling task's
/// waker in the slot shared with [`ShutdownHandle`], so
/// `ShutdownHandle::trigger()` can wake the awaiting task directly. The
/// registration is refreshed on each such poll, so a future that is
/// re-polled from a different task context is still woken correctly.
///
/// **Single waiter only.** All `ShutdownSignal`s created from the same
/// handle share one waker slot. If a second task polls while another task's
/// waker is registered, that waker is replaced (not duplicated), and the
/// first task is no longer woken by `trigger()`. For multi-waiter shutdown,
/// use [`CancellationToken`](crate::CancellationToken) instead.
///
/// Holds a raw pointer to the `AtomicBool` flag, valid for the lifetime
/// of the Runtime (which outlives `block_on` which outlives all tasks).
/// The raw pointer also makes this type neither `Send` nor `Sync`.
#[derive(Debug)]
pub struct ShutdownSignal {
    /// Points at the flag owned by the runtime's `ShutdownHandle`.
    pub(crate) flag: *const AtomicBool,
    /// Waker slot shared with the `ShutdownHandle`.
    pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
}
116
117impl ShutdownSignal {
118 /// Returns a [`ShutdownSignal`] future for the currently running runtime.
119 ///
120 /// The returned future resolves when shutdown is triggered — either by
121 /// a Unix signal handler installed via
122 /// [`Runtime::install_signal_handlers`](crate::Runtime::install_signal_handlers)
123 /// (SIGTERM / SIGINT) or by an explicit
124 /// [`ShutdownHandle::trigger`] call. Mirrors
125 /// `tokio::runtime::Handle::current()`. Read as
126 /// `ShutdownSignal::current().await` — "await the current shutdown
127 /// signal".
128 ///
129 /// **Single waiter only** — see the type-level docs. For multi-waiter
130 /// patterns, use [`CancellationToken`](crate::CancellationToken).
131 ///
132 /// # Panics
133 ///
134 /// Panics if called outside a [`Runtime::block_on`](crate::Runtime::block_on)
135 /// context.
136 #[must_use]
137 pub fn current() -> ShutdownSignal {
138 let (flag, waker_ptr) = crate::context::current_shutdown_ptrs();
139 assert!(
140 !flag.is_null(),
141 "ShutdownSignal::current() called outside Runtime::block_on"
142 );
143 // Defense-in-depth: flag and waker_ptr are written together by
144 // install(), so this should be unreachable — but a future refactor
145 // that splits the install path would make a null waker_ptr deref UB.
146 // Catch it at the call site instead.
147 assert!(
148 !waker_ptr.is_null(),
149 "ShutdownSignal::current(): waker_ptr null while flag non-null (runtime install bug)"
150 );
151 // SAFETY: install() writes flag and waker pointers together; both
152 // verified non-null above; pointers are valid for Runtime lifetime.
153 let task_waker = unsafe { (*waker_ptr).clone() };
154 ShutdownSignal { flag, task_waker }
155 }
156}
157
158impl Future for ShutdownSignal {
159 type Output = ();
160
161 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
162 // SAFETY: flag points to the AtomicBool inside the Runtime's
163 // ShutdownHandle (Arc-allocated, stable address). Valid for
164 // Runtime lifetime.
165 if unsafe { &*self.flag }.load(Ordering::Acquire) {
166 return Poll::Ready(());
167 }
168
169 // Register (or update) the waker so trigger() can wake us.
170 // Always update — the waker may have changed if the future was
171 // re-polled from a different task context.
172 if let Ok(mut guard) = self.task_waker.lock() {
173 *guard = Some(cx.waker().clone());
174 }
175
176 // Double-check after registration (lost wakeup prevention).
177 if unsafe { &*self.flag }.load(Ordering::Acquire) {
178 Poll::Ready(())
179 } else {
180 Poll::Pending
181 }
182 }
183}
184
185/// Install signal handlers for SIGTERM and SIGINT that trigger shutdown.
186///
187/// Uses `signal-hook` for safe, portable signal registration. The
188/// handler atomically sets the flag. The mio waker breaks epoll_wait
189/// so the runtime notices the flag promptly.
190pub fn install_signal_handlers(flag: &Arc<AtomicBool>, mio_waker: &Arc<mio::Waker>) {
191 let waker_ref = Arc::clone(mio_waker);
192
193 // signal-hook provides safe registration with proper cleanup.
194 // The closure runs in signal context — only async-signal-safe
195 // operations (atomic store + eventfd write).
196 signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(flag))
197 .expect("failed to register SIGTERM handler");
198 signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(flag))
199 .expect("failed to register SIGINT handler");
200
201 // signal-hook::flag::register sets the AtomicBool on signal, but
202 // we also need to break epoll_wait. Register a second handler that
203 // fires the mio waker.
204 unsafe {
205 signal_hook::low_level::register(signal_hook::consts::SIGTERM, move || {
206 let _ = waker_ref.wake();
207 })
208 .expect("failed to register SIGTERM waker");
209 }
210 let waker_ref2 = Arc::clone(mio_waker);
211 unsafe {
212 signal_hook::low_level::register(signal_hook::consts::SIGINT, move || {
213 let _ = waker_ref2.wake();
214 })
215 .expect("failed to register SIGINT waker");
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// `trigger()` on a fresh handle (no task or mio waker registered)
    /// flips `is_shutdown()` from false to true.
    #[test]
    fn shutdown_handle_trigger() {
        let handle = ShutdownHandle::new();
        assert!(!handle.is_shutdown());
        handle.trigger();
        assert!(handle.is_shutdown());
    }

    /// End-to-end through the runtime. A spawned task triggers shutdown
    /// after a 50 ms sleep. The root future awaits `handle.signal()` and
    /// records completion.
    #[test]
    fn shutdown_signal_resolves_after_trigger() {
        use crate::{Runtime, spawn_boxed};
        use nexus_rt::WorldBuilder;
        use std::cell::Cell;
        use std::rc::Rc;

        let wb = WorldBuilder::new();
        let mut world = wb.build();
        let mut rt = Runtime::new(&mut world);
        let shutdown = rt.shutdown_handle();

        // Set by the root future after the signal resolves; checked once
        // `block_on` has returned.
        let done = Rc::new(Cell::new(false));
        let flag = done.clone();

        // Trigger shutdown from a spawned task after a short delay.
        let sh = shutdown.clone();
        rt.block_on(async move {
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(50)).await;
                sh.trigger();
            });

            // Root future waits for shutdown.
            shutdown.signal().await;
            flag.set(true);
        });

        assert!(done.get());
    }

    #[test]
    #[should_panic(expected = "called outside Runtime::block_on")]
    fn shutdown_signal_current_panics_outside_runtime() {
        // Pins the documented panic contract for
        // `ShutdownSignal::current()`. Symmetric to
        // `IoHandle::current_panics_outside_runtime` and
        // `WorldCtx::current_panics_outside_runtime`.
        let _ = ShutdownSignal::current();
    }

    #[test]
    fn shutdown_signal_current_resolves_after_trigger() {
        // Sister test to `shutdown_signal_resolves_after_trigger`, but
        // exercises the TLS-fetcher path (`ShutdownSignal::current()`)
        // instead of `handle.signal()`. Catches regressions in the
        // CTX_SHUTDOWN / CTX_SHUTDOWN_WAKER install/uninstall wiring.
        use crate::{Runtime, spawn_boxed};
        use nexus_rt::WorldBuilder;
        use std::cell::Cell;
        use std::rc::Rc;

        let wb = WorldBuilder::new();
        let mut world = wb.build();
        let mut rt = Runtime::new(&mut world);
        let shutdown = rt.shutdown_handle();

        let done = Rc::new(Cell::new(false));
        let flag = done.clone();

        let sh = shutdown.clone();
        rt.block_on(async move {
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(50)).await;
                sh.trigger();
            });

            // Fetch the signal via the TLS-based current() rather than
            // handle.signal() — this is the path users will hit.
            ShutdownSignal::current().await;
            flag.set(true);
        });

        assert!(done.get());
    }

    #[test]
    fn shutdown_signal_waker_updates_on_repoll() {
        // Verify the waker is updated on each poll (not stale from first poll).
        use std::task::{RawWaker, RawWakerVTable, Waker};

        let handle = ShutdownHandle::new();
        let mut signal = Box::pin(handle.signal());

        // First poll with noop waker — registers it.
        // Vtable order is (clone, wake, wake_by_ref, drop): clone re-wraps the
        // same data pointer; wake, wake_by_ref and drop do nothing.
        let noop = unsafe {
            static V: RawWakerVTable =
                RawWakerVTable::new(|p| RawWaker::new(p, &V), |_| {}, |_| {}, |_| {});
            Waker::from_raw(RawWaker::new(std::ptr::null(), &V))
        };
        let mut cx = Context::from_waker(&noop);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Pending);

        // Second poll with a tracking waker — should overwrite.
        // Its data pointer is the address of `woke`; wake and wake_by_ref set
        // that Cell to true. `woke` lives on this stack frame, which outlives
        // every use of the waker below.
        let woke = std::cell::Cell::new(false);
        let flag_ptr = &woke as *const std::cell::Cell<bool> as *const ();
        let tracking = unsafe {
            static V2: RawWakerVTable = RawWakerVTable::new(
                |p| RawWaker::new(p, &V2),
                |p| unsafe { (*(p as *const std::cell::Cell<bool>)).set(true) },
                |p| unsafe { (*(p as *const std::cell::Cell<bool>)).set(true) },
                |_| {},
            );
            Waker::from_raw(RawWaker::new(flag_ptr, &V2))
        };
        let mut cx2 = Context::from_waker(&tracking);
        assert_eq!(signal.as_mut().poll(&mut cx2), Poll::Pending);

        // Trigger shutdown — must wake the tracking waker, not the noop.
        handle.trigger();
        assert!(woke.get(), "latest waker must fire on trigger");
    }

    #[test]
    fn shutdown_signal_already_triggered() {
        // Trigger before first poll — immediate Ready, no waker registration.
        use std::task::{RawWaker, RawWakerVTable, Waker};

        let handle = ShutdownHandle::new();
        handle.trigger();

        let mut signal = Box::pin(handle.signal());
        // Noop waker: the poll below must not need to wake anything.
        let waker = unsafe {
            static V: RawWakerVTable =
                RawWakerVTable::new(|p| RawWaker::new(p, &V), |_| {}, |_| {}, |_| {});
            Waker::from_raw(RawWaker::new(std::ptr::null(), &V))
        };
        let mut cx = Context::from_waker(&waker);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Ready(()));
    }
}