1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use parking_lot::Mutex;
9
10use crate::error::{Error, Result};
11use crate::hook::ShutdownHook;
12use crate::reason::ShutdownReason;
13use crate::signal::SignalSet;
14use crate::state::Inner;
15use crate::token::{ShutdownToken, ShutdownTrigger};
16
17#[cfg(any(feature = "tokio", all(feature = "async-std", not(feature = "tokio")),))]
18use crate::signal::Signal;
19
/// Default hook budget, in milliseconds, used for `Coordinator::graceful_timeout`
/// when the builder is not given an explicit value.
const DEFAULT_GRACEFUL_MS: u64 = 5 * 1_000;

/// Default value, in milliseconds, for `Coordinator::force_timeout` when the
/// builder is not given an explicit value.
const DEFAULT_FORCE_MS: u64 = 10 * 1_000;
25
26pub struct Coordinator {
40 inner: Arc<Inner>,
41 signals: SignalSet,
42 graceful_timeout: Duration,
43 force_timeout: Duration,
44 hooks: Mutex<Vec<Box<dyn ShutdownHook>>>,
45 installed: AtomicBool,
46 hooks_completed: AtomicUsize,
47}
48
49impl core::fmt::Debug for Coordinator {
50 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51 f.debug_struct("Coordinator")
52 .field("signals", &self.signals)
53 .field("graceful_timeout", &self.graceful_timeout)
54 .field("force_timeout", &self.force_timeout)
55 .field(
56 "hooks",
57 &format_args!("[{} hook(s)]", self.hooks.lock().len()),
58 )
59 .field("installed", &self.installed.load(Ordering::Relaxed))
60 .finish()
61 }
62}
63
64impl Coordinator {
65 #[must_use]
67 pub fn builder() -> CoordinatorBuilder {
68 CoordinatorBuilder::new()
69 }
70
71 #[must_use]
73 pub fn token(&self) -> ShutdownToken {
74 ShutdownToken::new(Arc::clone(&self.inner))
75 }
76
77 #[must_use]
79 pub fn trigger(&self) -> ShutdownTrigger {
80 ShutdownTrigger::new(Arc::clone(&self.inner))
81 }
82
83 #[must_use]
85 pub fn signals(&self) -> SignalSet {
86 self.signals
87 }
88
89 #[must_use]
91 pub fn graceful_timeout(&self) -> Duration {
92 self.graceful_timeout
93 }
94
95 #[must_use]
97 pub fn force_timeout(&self) -> Duration {
98 self.force_timeout
99 }
100
101 #[must_use]
103 pub fn is_installed(&self) -> bool {
104 self.installed.load(Ordering::Relaxed)
105 }
106
107 #[must_use]
109 pub fn statistics(&self) -> Statistics {
110 let hooks_registered = self.hooks.lock().len();
111 let hooks_completed = self.hooks_completed.load(Ordering::Relaxed);
112 Statistics {
113 initiated: self.inner.is_initiated(),
114 reason: self.inner.reason(),
115 hooks_registered,
116 hooks_completed,
117 elapsed: self.inner.elapsed(),
118 }
119 }
120
121 pub fn run_hooks(&self, reason: ShutdownReason) -> usize {
129 let mut hooks = self.hooks.lock();
130 hooks.sort_by_key(|h| core::cmp::Reverse(h.priority()));
131 let start = Instant::now();
132 let mut count = 0usize;
133 for hook in hooks.iter() {
134 if start.elapsed() > self.graceful_timeout {
135 break;
136 }
137 hook.run(reason);
138 count += 1;
139 self.hooks_completed.fetch_add(1, Ordering::Relaxed);
140 }
141 count
142 }
143
144 pub fn install(&self) -> Result<()> {
167 if self
168 .installed
169 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
170 .is_err()
171 {
172 return Err(Error::AlreadyInstalled);
173 }
174
175 let result = self.install_impl();
176
177 if result.is_err() {
178 self.installed.store(false, Ordering::Release);
179 }
180 result
181 }
182
183 #[cfg(feature = "tokio")]
184 fn install_impl(&self) -> Result<()> {
185 install_tokio(self)
186 }
187
188 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
189 fn install_impl(&self) -> Result<()> {
190 install_async_std(self)
191 }
192
193 #[cfg(all(
194 feature = "ctrlc-fallback",
195 not(feature = "tokio"),
196 not(feature = "async-std")
197 ))]
198 fn install_impl(&self) -> Result<()> {
199 install_ctrlc(self)
200 }
201
202 #[cfg(not(any(feature = "tokio", feature = "async-std", feature = "ctrlc-fallback")))]
203 #[allow(clippy::unused_self)]
204 fn install_impl(&self) -> Result<()> {
205 Err(Error::NoRuntime)
206 }
207}
208
209pub struct CoordinatorBuilder {
211 signals: SignalSet,
212 graceful_timeout: Duration,
213 force_timeout: Duration,
214 hooks: Vec<Box<dyn ShutdownHook>>,
215}
216
217impl core::fmt::Debug for CoordinatorBuilder {
218 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
219 f.debug_struct("CoordinatorBuilder")
220 .field("signals", &self.signals)
221 .field("graceful_timeout", &self.graceful_timeout)
222 .field("force_timeout", &self.force_timeout)
223 .field("hooks", &format_args!("[{} hook(s)]", self.hooks.len()))
224 .finish()
225 }
226}
227
228impl CoordinatorBuilder {
229 #[must_use]
232 pub fn new() -> Self {
233 Self {
234 signals: SignalSet::graceful(),
235 graceful_timeout: Duration::from_millis(DEFAULT_GRACEFUL_MS),
236 force_timeout: Duration::from_millis(DEFAULT_FORCE_MS),
237 hooks: Vec::new(),
238 }
239 }
240
241 #[must_use]
243 pub fn signals(mut self, set: SignalSet) -> Self {
244 self.signals = set;
245 self
246 }
247
248 #[must_use]
251 pub fn graceful_timeout(mut self, d: Duration) -> Self {
252 self.graceful_timeout = d;
253 self
254 }
255
256 #[must_use]
260 pub fn force_timeout(mut self, d: Duration) -> Self {
261 self.force_timeout = d;
262 self
263 }
264
265 #[must_use]
267 pub fn hook<H: ShutdownHook>(mut self, h: H) -> Self {
268 self.hooks.push(Box::new(h));
269 self
270 }
271
272 #[must_use]
274 pub fn build(self) -> Coordinator {
275 Coordinator {
276 inner: Inner::new(),
277 signals: self.signals,
278 graceful_timeout: self.graceful_timeout,
279 force_timeout: self.force_timeout,
280 hooks: Mutex::new(self.hooks),
281 installed: AtomicBool::new(false),
282 hooks_completed: AtomicUsize::new(0),
283 }
284 }
285}
286
287impl Default for CoordinatorBuilder {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293#[derive(Debug, Clone)]
296pub struct Statistics {
297 pub initiated: bool,
299 pub reason: Option<ShutdownReason>,
301 pub hooks_registered: usize,
303 pub hooks_completed: usize,
305 pub elapsed: Option<Duration>,
307}
308
#[cfg(feature = "tokio")]
/// Registers Tokio signal listeners for every signal in `coord.signals`.
///
/// Each registered signal gets a spawned task that calls the coordinator's
/// trigger with `ShutdownReason::Signal(..)` each time the signal arrives.
///
/// On Windows, console events map to signals as follows:
/// - Ctrl-C → `Interrupt`
/// - Ctrl-Break → `Quit`
/// - close → `Terminate`
/// - shutdown → `Hangup`
///
/// # Errors
///
/// - `Error::NoRuntime` when called outside a Tokio runtime.
/// - `Error::SignalRegistration` if Tokio fails to register a listener.
///   Listeners registered before the failure are left running.
fn install_tokio(coord: &Coordinator) -> Result<()> {
    // `tokio::spawn` (and, on Unix, `signal::unix::signal`) panic when there is
    // no current runtime. Report that as an error instead of panicking.
    if tokio::runtime::Handle::try_current().is_err() {
        return Err(Error::NoRuntime);
    }

    let trigger = coord.trigger();
    let set = coord.signals;

    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};

        // Registers one Unix signal if requested and forwards every delivery
        // to the trigger from a spawned task.
        macro_rules! reg {
            ($sig:expr, $kind:expr) => {{
                if set.contains($sig) {
                    let mut stream = signal($kind).map_err(|e| Error::SignalRegistration {
                        signal: $sig,
                        source: e,
                    })?;
                    let t = trigger.clone();
                    tokio::spawn(async move {
                        while stream.recv().await.is_some() {
                            t.trigger(ShutdownReason::Signal($sig));
                        }
                    });
                }
            }};
        }

        reg!(Signal::Terminate, SignalKind::terminate());
        reg!(Signal::Interrupt, SignalKind::interrupt());
        reg!(Signal::Quit, SignalKind::quit());
        reg!(Signal::Hangup, SignalKind::hangup());
        reg!(Signal::Pipe, SignalKind::pipe());
        reg!(Signal::User1, SignalKind::user_defined1());
        reg!(Signal::User2, SignalKind::user_defined2());
    }

    #[cfg(windows)]
    {
        use tokio::signal::windows::{ctrl_break, ctrl_c, ctrl_close, ctrl_shutdown};

        if set.contains(Signal::Interrupt) {
            let mut s = ctrl_c().map_err(|e| Error::SignalRegistration {
                signal: Signal::Interrupt,
                source: e,
            })?;
            let t = trigger.clone();
            tokio::spawn(async move {
                while s.recv().await.is_some() {
                    t.trigger(ShutdownReason::Signal(Signal::Interrupt));
                }
            });
        }
        if set.contains(Signal::Quit) {
            let mut s = ctrl_break().map_err(|e| Error::SignalRegistration {
                signal: Signal::Quit,
                source: e,
            })?;
            let t = trigger.clone();
            tokio::spawn(async move {
                while s.recv().await.is_some() {
                    t.trigger(ShutdownReason::Signal(Signal::Quit));
                }
            });
        }
        if set.contains(Signal::Terminate) {
            let mut s = ctrl_close().map_err(|e| Error::SignalRegistration {
                signal: Signal::Terminate,
                source: e,
            })?;
            let t = trigger.clone();
            tokio::spawn(async move {
                while s.recv().await.is_some() {
                    t.trigger(ShutdownReason::Signal(Signal::Terminate));
                }
            });
        }
        if set.contains(Signal::Hangup) {
            let mut s = ctrl_shutdown().map_err(|e| Error::SignalRegistration {
                signal: Signal::Hangup,
                source: e,
            })?;
            let t = trigger.clone();
            tokio::spawn(async move {
                while s.recv().await.is_some() {
                    t.trigger(ShutdownReason::Signal(Signal::Hangup));
                }
            });
        }
        // Keeps `trigger` used when none of the branches above is taken.
        let _ = &trigger;
    }

    Ok(())
}
406
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
/// Registers async-std signal listeners for the signals in `coord.signals`.
///
/// On Unix, all requested signals that have a Unix number are registered
/// through one `signal_hook_async_std::Signals` stream. A spawned task maps
/// each delivered number back to its `Signal` and fires the trigger.
/// On Windows, only `Interrupt` is supported, via a `ctrlc` handler.
///
/// # Errors
///
/// `Error::SignalRegistration` if registration fails. On Unix the reported
/// signal is the first one requested, because a batch failure does not say
/// which signal caused it.
fn install_async_std(coord: &Coordinator) -> Result<()> {
    let trigger = coord.trigger();
    let set = coord.signals;

    #[cfg(unix)]
    {
        use futures::stream::StreamExt;
        use signal_hook_async_std::Signals as SHSignals;

        // (raw Unix number, Signal) pairs; signals without a number are skipped.
        let lookup: Vec<(i32, Signal)> = set
            .iter()
            .filter_map(|sig| sig.unix_number().map(|num| (num, sig)))
            .collect();

        if !lookup.is_empty() {
            let numbers: Vec<i32> = lookup.iter().map(|&(num, _)| num).collect();
            let stream = SHSignals::new(&numbers).map_err(|e| Error::SignalRegistration {
                signal: lookup.first().map_or(Signal::Terminate, |&(_, sig)| sig),
                source: e,
            })?;
            let t = trigger.clone();
            async_std::task::spawn(async move {
                let mut stream = stream;
                while let Some(received) = stream.next().await {
                    let matched = lookup
                        .iter()
                        .find_map(|&(num, sig)| (num == received).then_some(sig));
                    if let Some(sig) = matched {
                        t.trigger(ShutdownReason::Signal(sig));
                    }
                }
            });
        }
    }

    #[cfg(windows)]
    {
        if set.contains(Signal::Interrupt) {
            let t = trigger.clone();
            ctrlc::try_set_handler(move || {
                let _ = t.trigger(ShutdownReason::Signal(Signal::Interrupt));
            })
            .map_err(|e| Error::SignalRegistration {
                signal: Signal::Interrupt,
                source: std::io::Error::other(e),
            })?;
        }
        // Keeps `trigger` used when `Interrupt` is not requested.
        let _ = &trigger;
    }

    Ok(())
}
473
#[cfg(all(
    feature = "ctrlc-fallback",
    not(feature = "tokio"),
    not(feature = "async-std")
))]
/// Fallback backend: installs a process-wide Ctrl-C handler through the
/// `ctrlc` crate when `Interrupt` is among the requested signals. No other
/// signal is handled by this backend.
///
/// # Errors
///
/// `Error::SignalRegistration` (for `Interrupt`) if `ctrlc` refuses to
/// install the handler.
fn install_ctrlc(coord: &Coordinator) -> Result<()> {
    use crate::signal::Signal;

    let trigger = coord.trigger();
    if !coord.signals.contains(Signal::Interrupt) {
        return Ok(());
    }

    let on_ctrl_c = move || {
        let _ = trigger.trigger(ShutdownReason::Signal(Signal::Interrupt));
    };
    ctrlc::try_set_handler(on_ctrl_c).map_err(|err| Error::SignalRegistration {
        signal: Signal::Interrupt,
        source: std::io::Error::other(err),
    })
}
497
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::hook::hook_from_fn;

    /// A freshly built coordinator exposes the documented defaults.
    #[test]
    fn builder_defaults() {
        let c = Coordinator::builder().build();
        assert_eq!(c.signals(), SignalSet::graceful());
        assert_eq!(c.graceful_timeout(), Duration::from_millis(5_000));
        assert_eq!(c.force_timeout(), Duration::from_millis(10_000));
        assert!(!c.is_installed());
        let stats = c.statistics();
        assert!(!stats.initiated);
        assert_eq!(stats.hooks_registered, 0);
        assert_eq!(stats.hooks_completed, 0);
    }

    /// Only the first trigger wins; its reason sticks.
    #[test]
    fn token_observes_trigger() {
        let c = Coordinator::builder().build();
        let token = c.token();
        let trigger = c.trigger();
        assert!(!token.is_initiated());
        assert!(trigger.trigger(ShutdownReason::Requested));
        assert!(token.is_initiated());
        assert_eq!(token.reason(), Some(ShutdownReason::Requested));
        assert!(!trigger.trigger(ShutdownReason::Forced));
        assert_eq!(token.reason(), Some(ShutdownReason::Requested));
    }

    /// Hooks run highest priority first, regardless of registration order.
    #[test]
    fn hooks_run_in_priority_order() {
        let order = Arc::new(parking_lot::Mutex::new(Vec::<i32>::new()));

        let push = |p: i32, order: &Arc<parking_lot::Mutex<Vec<i32>>>| {
            let o = Arc::clone(order);
            hook_from_fn(format!("p{p}"), p, move |_| {
                o.lock().push(p);
            })
        };

        let c = Coordinator::builder()
            .hook(push(0, &order))
            .hook(push(100, &order))
            .hook(push(50, &order))
            .build();

        let count = c.run_hooks(ShutdownReason::Requested);
        assert_eq!(count, 3);
        assert_eq!(*order.lock(), vec![100, 50, 0]);
        assert_eq!(c.statistics().hooks_completed, 3);
    }

    /// A hook that overruns the budget stops later hooks from starting.
    #[test]
    fn hooks_respect_graceful_budget() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::clone(&counter);
        let c2 = Arc::clone(&counter);

        let slow = hook_from_fn("slow", 100, move |_| {
            c1.fetch_add(1, Ordering::Relaxed);
            std::thread::sleep(Duration::from_millis(30));
        });
        let later = hook_from_fn("later", 0, move |_| {
            c2.fetch_add(1, Ordering::Relaxed);
        });

        let c = Coordinator::builder()
            .graceful_timeout(Duration::from_millis(5))
            .hook(slow)
            .hook(later)
            .build();

        let count = c.run_hooks(ShutdownReason::Requested);
        assert_eq!(count, 1);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    /// `elapsed` is `None` before a trigger and non-decreasing afterwards.
    #[test]
    fn elapsed_increases_after_trigger() {
        let c = Coordinator::builder().build();
        let token = c.token();
        assert!(token.elapsed().is_none());
        let _ = c.trigger().trigger(ShutdownReason::Requested);
        let first = token.elapsed().unwrap();
        std::thread::sleep(Duration::from_millis(5));
        let second = token.elapsed().unwrap();
        assert!(second >= first);
    }

    #[test]
    fn wait_blocking_timeout_returns_false_on_expiry() {
        let c = Coordinator::builder().build();
        let token = c.token();
        assert!(!token.wait_blocking_timeout(Duration::from_millis(5)));
    }

    #[test]
    fn wait_blocking_timeout_returns_true_on_trigger() {
        let c = Coordinator::builder().build();
        let token = c.token();
        let trigger = c.trigger();

        let handle = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            // First trigger on a fresh coordinator, so it must report success.
            // A failure here surfaces through `handle.join()` below.
            assert!(trigger.trigger(ShutdownReason::Requested));
        });

        assert!(token.wait_blocking_timeout(Duration::from_secs(1)));
        handle.join().unwrap();
    }

    #[cfg(not(any(feature = "tokio", feature = "async-std", feature = "ctrlc-fallback")))]
    #[test]
    fn install_errors_with_no_runtime() {
        let c = Coordinator::builder().build();
        assert!(matches!(c.install(), Err(Error::NoRuntime)));
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn token_wait_resolves_on_trigger() {
        let c = Coordinator::builder().build();
        let token = c.token();
        let trigger = c.trigger();

        let waiter = tokio::spawn(async move { token.wait().await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(trigger.trigger(ShutdownReason::Requested));
        // Unwrap both layers. The outer `expect` fails on timeout; the inner
        // one fails if the waiter task panicked, which the previous
        // `let _ =` silently ignored.
        tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .expect("wait did not resolve within 1s")
            .expect("waiter task panicked");
    }
}