Skip to main content

nexus_async_rt/
shutdown.rs

1//! Graceful shutdown support.
2//!
3//! [`ShutdownSignal`] is a future that resolves when a shutdown is
4//! requested — either by a Unix signal (SIGTERM, SIGINT) or by
5//! explicitly calling [`ShutdownHandle::trigger`].
6//!
7//! The Runtime checks the shutdown flag each poll cycle. When set,
8//! the root future can observe it via the `ShutdownSignal` future
9//! and begin connection draining.
10//!
11//! # Usage
12//!
13//! ```ignore
14//! let mut rt = Runtime::new(&mut world);
15//!
16//! // Install signal handlers (call once at startup).
17//! rt.install_signal_handlers();
18//!
19//! rt.block_on(async move {
20//!     spawn_boxed(connection_tasks...);
21//!
22//!     // Wait for SIGTERM/SIGINT.
23//!     nexus_async_rt::ShutdownSignal::current().await;
24//!
25//!     // Drain connections, flush buffers, etc.
26//! });
27//! ```
28
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::task::{Context, Poll, Waker};
34
/// Shared shutdown flag.
///
/// Every field is behind an `Arc`, so clones of a handle share the same
/// flag and the same task-waker slot: triggering any clone is observed by
/// all of them.
#[derive(Clone)]
pub struct ShutdownHandle {
    /// Set to `true` once shutdown has been requested. Stored with
    /// `Release` in `trigger()` and read with `Acquire` by `is_shutdown()`
    /// and by `ShutdownSignal::poll`.
    flag: Arc<AtomicBool>,
    /// Mio waker to break epoll_wait when shutdown is triggered.
    /// `None` until `set_mio_waker` has been called.
    mio_waker: Option<Arc<mio::Waker>>,
    /// Task waker slot — the ShutdownSignal future registers here.
    /// Locked by `trigger()` (to take and wake the waker) and by
    /// `ShutdownSignal::poll` (to store the latest waker). Only contested
    /// at shutdown time (once per process).
    ///
    /// NOTE(review): the handlers registered by `install_signal_handlers`
    /// in this file only set the flag and fire the mio waker; they never
    /// touch this slot.
    pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
}
46
47impl ShutdownHandle {
48    pub(crate) fn new() -> Self {
49        Self {
50            flag: Arc::new(AtomicBool::new(false)),
51            mio_waker: None,
52            task_waker: Arc::new(std::sync::Mutex::new(None)),
53        }
54    }
55
56    /// Set the mio waker. Called by Runtime during construction.
57    pub(crate) fn set_mio_waker(&mut self, waker: Arc<mio::Waker>) {
58        self.mio_waker = Some(waker);
59    }
60
61    /// Trigger shutdown programmatically.
62    ///
63    /// Sets the flag, wakes the registered task waker (if any), and
64    /// breaks epoll_wait so the runtime loop re-polls the root future.
65    pub fn trigger(&self) {
66        self.flag.store(true, Ordering::Release);
67        // Wake the task waker first — signal the future directly.
68        if let Ok(mut guard) = self.task_waker.lock() {
69            if let Some(w) = guard.take() {
70                w.wake();
71            }
72        }
73        if let Some(w) = &self.mio_waker {
74            let _ = w.wake();
75        }
76    }
77
78    /// Check if shutdown has been requested.
79    pub fn is_shutdown(&self) -> bool {
80        self.flag.load(Ordering::Acquire)
81    }
82
83    /// Get the underlying flag Arc for signal handler registration.
84    pub(crate) fn flag_ptr(&self) -> Arc<AtomicBool> {
85        Arc::clone(&self.flag)
86    }
87
88    /// Returns a future that completes when shutdown is triggered.
89    pub fn signal(&self) -> ShutdownSignal {
90        ShutdownSignal {
91            flag: Arc::as_ptr(&self.flag),
92            task_waker: self.task_waker.clone(),
93        }
94    }
95}
96
/// Future that resolves when shutdown is triggered.
///
/// Every poll that finds the flag still unset stores the caller's waker in
/// the shared slot so that `ShutdownHandle::trigger()` can wake the
/// awaiting task directly. The waker is overwritten on each such poll to
/// handle the case where the future is re-polled from a different task
/// context.
///
/// **Single waiter only.** Only one task may await `ShutdownSignal` at a
/// time. If a second task polls while a waker is already registered, the
/// waker is replaced (not duplicated). For multi-waiter shutdown, use
/// [`CancellationToken`](crate::CancellationToken) instead.
///
/// # Pointer validity
///
/// Holds a raw pointer to the AtomicBool flag, valid for the lifetime
/// of the Runtime (which outlives `block_on` which outlives all tasks).
/// The pointer does not keep the flag alive: a signal obtained from
/// `ShutdownHandle::signal` must not be polled once every owner of the
/// flag has been dropped.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ShutdownSignal {
    /// Borrowed (non-owning) pointer to the shared shutdown flag.
    pub(crate) flag: *const AtomicBool,
    /// Slot shared with `ShutdownHandle`; holds the waker of the task
    /// currently awaiting this signal, if any.
    pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
}
116
117impl ShutdownSignal {
118    /// Returns a [`ShutdownSignal`] future for the currently running runtime.
119    ///
120    /// The returned future resolves when shutdown is triggered — either by
121    /// a Unix signal handler installed via
122    /// [`Runtime::install_signal_handlers`](crate::Runtime::install_signal_handlers)
123    /// (SIGTERM / SIGINT) or by an explicit
124    /// [`ShutdownHandle::trigger`] call. Mirrors
125    /// `tokio::runtime::Handle::current()`. Read as
126    /// `ShutdownSignal::current().await` — "await the current shutdown
127    /// signal".
128    ///
129    /// **Single waiter only** — see the type-level docs. For multi-waiter
130    /// patterns, use [`CancellationToken`](crate::CancellationToken).
131    ///
132    /// # Panics
133    ///
134    /// Panics if called outside a [`Runtime::block_on`](crate::Runtime::block_on)
135    /// context.
136    #[must_use]
137    pub fn current() -> ShutdownSignal {
138        let (flag, waker_ptr) = crate::context::current_shutdown_ptrs();
139        assert!(
140            !flag.is_null(),
141            "ShutdownSignal::current() called outside Runtime::block_on"
142        );
143        // Defense-in-depth: flag and waker_ptr are written together by
144        // install(), so this should be unreachable — but a future refactor
145        // that splits the install path would make a null waker_ptr deref UB.
146        // Catch it at the call site instead.
147        assert!(
148            !waker_ptr.is_null(),
149            "ShutdownSignal::current(): waker_ptr null while flag non-null (runtime install bug)"
150        );
151        // SAFETY: install() writes flag and waker pointers together; both
152        // verified non-null above; pointers are valid for Runtime lifetime.
153        let task_waker = unsafe { (*waker_ptr).clone() };
154        ShutdownSignal { flag, task_waker }
155    }
156}
157
158impl Future for ShutdownSignal {
159    type Output = ();
160
161    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
162        // SAFETY: flag points to the AtomicBool inside the Runtime's
163        // ShutdownHandle (Arc-allocated, stable address). Valid for
164        // Runtime lifetime.
165        if unsafe { &*self.flag }.load(Ordering::Acquire) {
166            return Poll::Ready(());
167        }
168
169        // Register (or update) the waker so trigger() can wake us.
170        // Always update — the waker may have changed if the future was
171        // re-polled from a different task context.
172        if let Ok(mut guard) = self.task_waker.lock() {
173            *guard = Some(cx.waker().clone());
174        }
175
176        // Double-check after registration (lost wakeup prevention).
177        if unsafe { &*self.flag }.load(Ordering::Acquire) {
178            Poll::Ready(())
179        } else {
180            Poll::Pending
181        }
182    }
183}
184
185/// Install signal handlers for SIGTERM and SIGINT that trigger shutdown.
186///
187/// Uses `signal-hook` for safe, portable signal registration. The
188/// handler atomically sets the flag. The mio waker breaks epoll_wait
189/// so the runtime notices the flag promptly.
190pub fn install_signal_handlers(flag: &Arc<AtomicBool>, mio_waker: &Arc<mio::Waker>) {
191    let waker_ref = Arc::clone(mio_waker);
192
193    // signal-hook provides safe registration with proper cleanup.
194    // The closure runs in signal context — only async-signal-safe
195    // operations (atomic store + eventfd write).
196    signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(flag))
197        .expect("failed to register SIGTERM handler");
198    signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(flag))
199        .expect("failed to register SIGINT handler");
200
201    // signal-hook::flag::register sets the AtomicBool on signal, but
202    // we also need to break epoll_wait. Register a second handler that
203    // fires the mio waker.
204    unsafe {
205        signal_hook::low_level::register(signal_hook::consts::SIGTERM, move || {
206            let _ = waker_ref.wake();
207        })
208        .expect("failed to register SIGTERM waker");
209    }
210    let waker_ref2 = Arc::clone(mio_waker);
211    unsafe {
212        signal_hook::low_level::register(signal_hook::consts::SIGINT, move || {
213            let _ = waker_ref2.wake();
214        })
215        .expect("failed to register SIGINT waker");
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{RawWaker, RawWakerVTable};

    /// Builds a waker whose clone, wake, wake-by-ref and drop hooks all do
    /// nothing. Used where a `Context` is needed but wake-ups are ignored.
    fn noop_waker() -> Waker {
        static NOOP: RawWakerVTable =
            RawWakerVTable::new(|data| RawWaker::new(data, &NOOP), |_| {}, |_| {}, |_| {});
        // SAFETY: no vtable entry ever dereferences the (null) data
        // pointer, so the RawWaker contract is trivially upheld.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP)) }
    }

    /// Drives a runtime whose root future awaits the signal produced by
    /// `obtain`, while a spawned task triggers shutdown after 50 ms.
    /// Returns whether the root future ran past the await.
    fn root_finishes_after_trigger(obtain: fn(&ShutdownHandle) -> ShutdownSignal) -> bool {
        use crate::{Runtime, spawn_boxed};
        use nexus_rt::WorldBuilder;
        use std::cell::Cell;
        use std::rc::Rc;

        let wb = WorldBuilder::new();
        let mut world = wb.build();
        let mut rt = Runtime::new(&mut world);
        let handle = rt.shutdown_handle();
        let trigger_side = handle.clone();

        let finished = Rc::new(Cell::new(false));
        let finished_in_root = Rc::clone(&finished);

        rt.block_on(async move {
            // Trigger shutdown from a spawned task after a short delay.
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(50)).await;
                trigger_side.trigger();
            });

            // Root future waits for shutdown.
            obtain(&handle).await;
            finished_in_root.set(true);
        });

        finished.get()
    }

    #[test]
    fn shutdown_handle_trigger() {
        let handle = ShutdownHandle::new();
        assert!(!handle.is_shutdown());
        handle.trigger();
        assert!(handle.is_shutdown());
    }

    #[test]
    fn shutdown_signal_resolves_after_trigger() {
        // Signal obtained straight from the handle.
        assert!(root_finishes_after_trigger(ShutdownHandle::signal));
    }

    #[test]
    #[should_panic(expected = "called outside Runtime::block_on")]
    fn shutdown_signal_current_panics_outside_runtime() {
        // Pins the documented panic contract of `ShutdownSignal::current()`.
        // Counterpart of `IoHandle::current_panics_outside_runtime` and
        // `WorldCtx::current_panics_outside_runtime`.
        let _ = ShutdownSignal::current();
    }

    #[test]
    fn shutdown_signal_current_resolves_after_trigger() {
        // Sister test to `shutdown_signal_resolves_after_trigger`, but the
        // signal comes from the TLS-based `ShutdownSignal::current()` —
        // the path users will hit. Catches regressions in the
        // CTX_SHUTDOWN / CTX_SHUTDOWN_WAKER install/uninstall wiring.
        assert!(root_finishes_after_trigger(|_| ShutdownSignal::current()));
    }

    #[test]
    fn shutdown_signal_waker_updates_on_repoll() {
        // The slot must hold the waker from the most recent poll, not a
        // stale one from the first poll.
        use std::cell::Cell;

        let handle = ShutdownHandle::new();
        let mut signal = Box::pin(handle.signal());

        // First poll registers a waker that does nothing.
        let first = noop_waker();
        let mut first_cx = Context::from_waker(&first);
        assert_eq!(signal.as_mut().poll(&mut first_cx), Poll::Pending);

        // Second poll uses a waker that records being woken; it should
        // replace the first one.
        static TRACKING: RawWakerVTable = RawWakerVTable::new(
            |data| RawWaker::new(data, &TRACKING),
            |data: *const ()| unsafe { (*data.cast::<Cell<bool>>()).set(true) },
            |data: *const ()| unsafe { (*data.cast::<Cell<bool>>()).set(true) },
            |_| {},
        );
        let woke = Cell::new(false);
        let woke_ptr = (&woke as *const Cell<bool>).cast::<()>();
        // SAFETY: `woke` outlives every waker built from `woke_ptr` in this
        // test, and the vtable only ever treats the data as `Cell<bool>`.
        let tracking = unsafe { Waker::from_raw(RawWaker::new(woke_ptr, &TRACKING)) };
        let mut second_cx = Context::from_waker(&tracking);
        assert_eq!(signal.as_mut().poll(&mut second_cx), Poll::Pending);

        // Trigger shutdown — must wake the tracking waker, not the noop.
        handle.trigger();
        assert!(woke.get(), "latest waker must fire on trigger");
    }

    #[test]
    fn shutdown_signal_already_triggered() {
        // Trigger before first poll — immediate Ready, no waker registration.
        let handle = ShutdownHandle::new();
        handle.trigger();

        let mut signal = Box::pin(handle.signal());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Ready(()));
    }
}