Skip to main content

nexus_async_rt/
shutdown.rs

//! Graceful shutdown support.
//!
//! [`ShutdownSignal`] is a future that resolves when a shutdown is
//! requested — either by a Unix signal (SIGTERM, SIGINT) or by
//! explicitly calling [`ShutdownHandle::trigger`].
//!
//! The Runtime checks the shutdown flag each poll cycle. When the flag
//! is set, the root future can observe it via the `ShutdownSignal`
//! future and begin draining connections.
//!
//! # Usage
//!
//! ```ignore
//! let mut rt = Runtime::new(&mut world);
//!
//! // Install signal handlers (call once at startup).
//! rt.install_signal_handlers();
//!
//! rt.block_on(async move {
//!     spawn_boxed(connection_tasks...);
//!
//!     // Wait for SIGTERM/SIGINT.
//!     nexus_async_rt::shutdown_signal().await;
//!
//!     // Drain connections, flush buffers, etc.
//! });
//! ```
28
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::task::{Context, Poll, Waker};
34
35/// Shared shutdown flag.
36#[derive(Clone)]
37pub struct ShutdownHandle {
38    flag: Arc<AtomicBool>,
39    /// Mio waker to break epoll_wait when shutdown is triggered.
40    mio_waker: Option<Arc<mio::Waker>>,
41    /// Task waker slot — the ShutdownSignal future registers here.
42    /// Protected by Mutex because the signal handler thread may
43    /// call wake(). Only contested at shutdown time (once per process).
44    pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
45}
46
47impl ShutdownHandle {
48    pub(crate) fn new() -> Self {
49        Self {
50            flag: Arc::new(AtomicBool::new(false)),
51            mio_waker: None,
52            task_waker: Arc::new(std::sync::Mutex::new(None)),
53        }
54    }
55
56    /// Set the mio waker. Called by Runtime during construction.
57    pub(crate) fn set_mio_waker(&mut self, waker: Arc<mio::Waker>) {
58        self.mio_waker = Some(waker);
59    }
60
61    /// Trigger shutdown programmatically.
62    ///
63    /// Sets the flag, wakes the registered task waker (if any), and
64    /// breaks epoll_wait so the runtime loop re-polls the root future.
65    pub fn trigger(&self) {
66        self.flag.store(true, Ordering::Release);
67        // Wake the task waker first — signal the future directly.
68        if let Ok(mut guard) = self.task_waker.lock() {
69            if let Some(w) = guard.take() {
70                w.wake();
71            }
72        }
73        if let Some(w) = &self.mio_waker {
74            let _ = w.wake();
75        }
76    }
77
78    /// Check if shutdown has been requested.
79    pub fn is_shutdown(&self) -> bool {
80        self.flag.load(Ordering::Acquire)
81    }
82
83    /// Get the underlying flag Arc for signal handler registration.
84    pub(crate) fn flag_ptr(&self) -> Arc<AtomicBool> {
85        Arc::clone(&self.flag)
86    }
87
88    /// Returns a future that completes when shutdown is triggered.
89    pub fn signal(&self) -> ShutdownSignal {
90        ShutdownSignal {
91            flag: Arc::as_ptr(&self.flag),
92            task_waker: self.task_waker.clone(),
93        }
94    }
95}
96
/// Future that resolves when shutdown is triggered.
///
/// Registers (and updates) a waker on every pending poll, so that
/// [`ShutdownHandle::trigger`] can wake the awaiting task directly. The
/// stored waker is replaced whenever the future is re-polled with a waker
/// for a different task.
///
/// **Single waiter only.** Only one task may await `ShutdownSignal` at a
/// time. If a second task polls while a waker is already registered, the
/// waker is replaced (not duplicated). For multi-waiter shutdown, use
/// [`CancellationToken`](crate::CancellationToken) instead.
///
/// # Lifetime contract
///
/// `flag` is a raw pointer into the `Arc<AtomicBool>` owned by the
/// [`ShutdownHandle`] this future was created from. It does **not** keep
/// that allocation alive. Polling is sound only while some `Arc` clone of
/// the flag still exists. Normally that clone is the one held by the
/// Runtime, which outlives `block_on` and therefore every task. Polling
/// after the last clone has been dropped is undefined behaviour.
///
/// Because of the raw-pointer field, the compiler does not auto-implement
/// `Send` or `Sync` for this type.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ShutdownSignal {
    pub(crate) flag: *const AtomicBool,
    pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
}
116
117impl Future for ShutdownSignal {
118    type Output = ();
119
120    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
121        // SAFETY: flag points to the AtomicBool inside the Runtime's
122        // ShutdownHandle (Arc-allocated, stable address). Valid for
123        // Runtime lifetime.
124        if unsafe { &*self.flag }.load(Ordering::Acquire) {
125            return Poll::Ready(());
126        }
127
128        // Register (or update) the waker so trigger() can wake us.
129        // Always update — the waker may have changed if the future was
130        // re-polled from a different task context.
131        if let Ok(mut guard) = self.task_waker.lock() {
132            *guard = Some(cx.waker().clone());
133        }
134
135        // Double-check after registration (lost wakeup prevention).
136        if unsafe { &*self.flag }.load(Ordering::Acquire) {
137            Poll::Ready(())
138        } else {
139            Poll::Pending
140        }
141    }
142}
143
144/// Install signal handlers for SIGTERM and SIGINT that trigger shutdown.
145///
146/// Uses `signal-hook` for safe, portable signal registration. The
147/// handler atomically sets the flag. The mio waker breaks epoll_wait
148/// so the runtime notices the flag promptly.
149pub fn install_signal_handlers(flag: &Arc<AtomicBool>, mio_waker: &Arc<mio::Waker>) {
150    let waker_ref = Arc::clone(mio_waker);
151
152    // signal-hook provides safe registration with proper cleanup.
153    // The closure runs in signal context — only async-signal-safe
154    // operations (atomic store + eventfd write).
155    signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(flag))
156        .expect("failed to register SIGTERM handler");
157    signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(flag))
158        .expect("failed to register SIGINT handler");
159
160    // signal-hook::flag::register sets the AtomicBool on signal, but
161    // we also need to break epoll_wait. Register a second handler that
162    // fires the mio waker.
163    unsafe {
164        signal_hook::low_level::register(signal_hook::consts::SIGTERM, move || {
165            let _ = waker_ref.wake();
166        })
167        .expect("failed to register SIGTERM waker");
168    }
169    let waker_ref2 = Arc::clone(mio_waker);
170    unsafe {
171        signal_hook::low_level::register(signal_hook::consts::SIGINT, move || {
172            let _ = waker_ref2.wake();
173        })
174        .expect("failed to register SIGINT waker");
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// `trigger()` flips `is_shutdown()` from false to true.
    #[test]
    fn shutdown_handle_trigger() {
        let handle = ShutdownHandle::new();
        assert!(!handle.is_shutdown());
        handle.trigger();
        assert!(handle.is_shutdown());
    }

    /// End-to-end: a spawned task triggers shutdown after a short sleep, and
    /// the root future awaiting `signal()` completes so `block_on` returns.
    /// NOTE(review): depends on the crate's runtime and timer
    /// (`crate::context::sleep`); the test sleeps for about 50 ms.
    #[test]
    fn shutdown_signal_resolves_after_trigger() {
        use crate::{Runtime, spawn_boxed};
        use nexus_rt::WorldBuilder;
        use std::cell::Cell;
        use std::rc::Rc;

        let wb = WorldBuilder::new();
        let mut world = wb.build();
        let mut rt = Runtime::new(&mut world);
        let shutdown = rt.shutdown_handle();

        // `done` is read after `block_on`; `flag` is moved into the root future.
        let done = Rc::new(Cell::new(false));
        let flag = done.clone();

        // Trigger shutdown from a spawned task after a short delay.
        let sh = shutdown.clone();
        rt.block_on(async move {
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(50)).await;
                sh.trigger();
            });

            // Root future waits for shutdown.
            shutdown.signal().await;
            flag.set(true);
        });

        assert!(done.get());
    }

    /// Re-polling with a different waker must replace the stored one, so
    /// `trigger()` wakes the most recent waker rather than the first.
    #[test]
    fn shutdown_signal_waker_updates_on_repoll() {
        // Verify the waker is updated on each poll (not stale from first poll).
        use std::task::{RawWaker, RawWakerVTable, Waker};

        let handle = ShutdownHandle::new();
        let mut signal = Box::pin(handle.signal());

        // First poll with noop waker — registers it.
        // Vtable: clone rebuilds the same RawWaker; wake, wake_by_ref and drop do nothing.
        let noop = unsafe {
            static V: RawWakerVTable =
                RawWakerVTable::new(|p| RawWaker::new(p, &V), |_| {}, |_| {}, |_| {});
            Waker::from_raw(RawWaker::new(std::ptr::null(), &V))
        };
        let mut cx = Context::from_waker(&noop);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Pending);

        // Second poll with a tracking waker — should overwrite.
        // The data pointer is `&woke`. `woke` is declared before `tracking`,
        // so it outlives it.
        let woke = std::cell::Cell::new(false);
        let flag_ptr = &woke as *const std::cell::Cell<bool> as *const ();
        // Vtable: clone keeps the same pointer; wake and wake_by_ref set the
        // pointed-to Cell to true; drop does nothing.
        let tracking = unsafe {
            static V2: RawWakerVTable = RawWakerVTable::new(
                |p| RawWaker::new(p, &V2),
                |p| unsafe { (*(p as *const std::cell::Cell<bool>)).set(true) },
                |p| unsafe { (*(p as *const std::cell::Cell<bool>)).set(true) },
                |_| {},
            );
            Waker::from_raw(RawWaker::new(flag_ptr, &V2))
        };
        let mut cx2 = Context::from_waker(&tracking);
        assert_eq!(signal.as_mut().poll(&mut cx2), Poll::Pending);

        // Trigger shutdown — must wake the tracking waker, not the noop.
        handle.trigger();
        assert!(woke.get(), "latest waker must fire on trigger");
    }

    /// A signal created after `trigger()` is Ready on its very first poll.
    #[test]
    fn shutdown_signal_already_triggered() {
        // Trigger before first poll — immediate Ready, no waker registration.
        use std::task::{RawWaker, RawWakerVTable, Waker};

        let handle = ShutdownHandle::new();
        handle.trigger();

        let mut signal = Box::pin(handle.signal());
        // Noop waker: every vtable entry except clone is a no-op.
        let waker = unsafe {
            static V: RawWakerVTable =
                RawWakerVTable::new(|p| RawWaker::new(p, &V), |_| {}, |_| {}, |_| {});
            Waker::from_raw(RawWaker::new(std::ptr::null(), &V))
        };
        let mut cx = Context::from_waker(&waker);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Ready(()));
    }
}