nexus_async_rt/
shutdown.rs1use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::task::{Context, Poll, Waker};
34
35#[derive(Clone)]
37pub struct ShutdownHandle {
38 flag: Arc<AtomicBool>,
39 mio_waker: Option<Arc<mio::Waker>>,
41 pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
45}
46
47impl ShutdownHandle {
48 pub(crate) fn new() -> Self {
49 Self {
50 flag: Arc::new(AtomicBool::new(false)),
51 mio_waker: None,
52 task_waker: Arc::new(std::sync::Mutex::new(None)),
53 }
54 }
55
56 pub(crate) fn set_mio_waker(&mut self, waker: Arc<mio::Waker>) {
58 self.mio_waker = Some(waker);
59 }
60
61 pub fn trigger(&self) {
66 self.flag.store(true, Ordering::Release);
67 if let Ok(mut guard) = self.task_waker.lock() {
69 if let Some(w) = guard.take() {
70 w.wake();
71 }
72 }
73 if let Some(w) = &self.mio_waker {
74 let _ = w.wake();
75 }
76 }
77
78 pub fn is_shutdown(&self) -> bool {
80 self.flag.load(Ordering::Acquire)
81 }
82
83 pub(crate) fn flag_ptr(&self) -> Arc<AtomicBool> {
85 Arc::clone(&self.flag)
86 }
87
88 pub fn signal(&self) -> ShutdownSignal {
90 ShutdownSignal {
91 flag: Arc::as_ptr(&self.flag),
92 task_waker: self.task_waker.clone(),
93 }
94 }
95}
96
/// Future that completes once the shared shutdown flag has been set.
///
/// Normally obtained from `ShutdownHandle::signal`. When polled while the flag
/// is clear, it parks the current task's waker in `task_waker` so the code
/// that sets the flag can wake it.
///
/// Because it holds a raw pointer, this type is neither `Send` nor `Sync`.
///
/// NOTE(review): `flag` carries no ownership. Nothing here keeps the
/// `AtomicBool` alive, so the pointee must outlive every poll of this future.
/// Storing an `Arc<AtomicBool>` instead would make that guarantee structural.
pub struct ShutdownSignal {
    /// Non-owning pointer to the shutdown flag.
    pub(crate) flag: *const AtomicBool,
    /// Waker slot shared with the originating handle(s).
    pub(crate) task_waker: Arc<std::sync::Mutex<Option<Waker>>>,
}
116
117impl Future for ShutdownSignal {
118 type Output = ();
119
120 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
121 if unsafe { &*self.flag }.load(Ordering::Acquire) {
125 return Poll::Ready(());
126 }
127
128 if let Ok(mut guard) = self.task_waker.lock() {
132 *guard = Some(cx.waker().clone());
133 }
134
135 if unsafe { &*self.flag }.load(Ordering::Acquire) {
137 Poll::Ready(())
138 } else {
139 Poll::Pending
140 }
141 }
142}
143
144pub fn install_signal_handlers(flag: &Arc<AtomicBool>, mio_waker: &Arc<mio::Waker>) {
150 let waker_ref = Arc::clone(mio_waker);
151
152 signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(flag))
156 .expect("failed to register SIGTERM handler");
157 signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(flag))
158 .expect("failed to register SIGINT handler");
159
160 unsafe {
164 signal_hook::low_level::register(signal_hook::consts::SIGTERM, move || {
165 let _ = waker_ref.wake();
166 })
167 .expect("failed to register SIGTERM waker");
168 }
169 let waker_ref2 = Arc::clone(mio_waker);
170 unsafe {
171 signal_hook::low_level::register(signal_hook::consts::SIGINT, move || {
172 let _ = waker_ref2.wake();
173 })
174 .expect("failed to register SIGINT waker");
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::task::Wake;

    /// Waker payload that ignores every wake-up.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    /// Waker payload that records whether it has been woken.
    #[derive(Default)]
    struct FlagWake(AtomicBool);

    impl Wake for FlagWake {
        fn wake(self: Arc<Self>) {
            self.wake_by_ref();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Builds a waker that does nothing, without hand-written vtables.
    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWake))
    }

    #[test]
    fn shutdown_handle_trigger() {
        let handle = ShutdownHandle::new();
        assert!(!handle.is_shutdown());
        handle.trigger();
        assert!(handle.is_shutdown());
    }

    #[test]
    fn shutdown_signal_resolves_after_trigger() {
        use crate::{Runtime, spawn_boxed};
        use nexus_rt::WorldBuilder;
        use std::cell::Cell;
        use std::rc::Rc;

        let wb = WorldBuilder::new();
        let mut world = wb.build();
        let mut rt = Runtime::new(&mut world);
        let shutdown = rt.shutdown_handle();

        let done = Rc::new(Cell::new(false));
        let flag = done.clone();

        let sh = shutdown.clone();
        rt.block_on(async move {
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(50)).await;
                sh.trigger();
            });

            shutdown.signal().await;
            flag.set(true);
        });

        assert!(done.get());
    }

    #[test]
    fn shutdown_signal_waker_updates_on_repoll() {
        let handle = ShutdownHandle::new();
        let mut signal = Box::pin(handle.signal());

        let noop = noop_waker();
        let mut cx = Context::from_waker(&noop);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Pending);

        // Re-poll with a different waker; trigger must wake this newer one.
        let tracker = Arc::new(FlagWake::default());
        let tracking = Waker::from(Arc::clone(&tracker));
        let mut cx2 = Context::from_waker(&tracking);
        assert_eq!(signal.as_mut().poll(&mut cx2), Poll::Pending);

        handle.trigger();
        assert!(tracker.0.load(Ordering::SeqCst), "latest waker must fire on trigger");
    }

    #[test]
    fn shutdown_signal_already_triggered() {
        let handle = ShutdownHandle::new();
        handle.trigger();

        let mut signal = Box::pin(handle.signal());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        assert_eq!(signal.as_mut().poll(&mut cx), Poll::Ready(()));
    }
}