pgrx_tests/framework/
shutdown.rs1use std::panic::{self, AssertUnwindSafe, Location};
11use std::sync::{Mutex, PoisonError};
12use std::{any, io, mem, process};
13
14#[track_caller]
19pub fn add_shutdown_hook<F: FnOnce()>(func: F)
20where
21 F: Send + 'static,
22{
23 SHUTDOWN_HOOKS
24 .lock()
25 .unwrap_or_else(PoisonError::into_inner)
26 .push(ShutdownHook { source: Location::caller(), callback: Box::new(func) });
27}
28
29pub(super) fn register_shutdown_hook() {
30 unsafe {
31 libc::atexit(run_shutdown_hooks);
32 }
33}
34
35extern "C" fn run_shutdown_hooks() {
52 let guard = PanicGuard;
53 let mut any_panicked = false;
54 let mut hooks = SHUTDOWN_HOOKS.lock().unwrap_or_else(PoisonError::into_inner);
55 for hook in mem::take(&mut *hooks).into_iter().rev() {
57 any_panicked |= hook.run().is_err();
58 }
59 if any_panicked {
60 write_stderr("error: one or more shutdown hooks panicked (see `stderr` for details).\n");
61 process::abort()
62 }
63 mem::forget(guard);
64}
65
66struct PanicGuard;
75impl Drop for PanicGuard {
76 fn drop(&mut self) {
77 write_stderr("Failed to catch panic in the `atexit` callback, aborting!\n");
78 process::abort();
79 }
80}
81
/// Global registry of hooks that are still pending, kept in registration order.
static SHUTDOWN_HOOKS: Mutex<Vec<ShutdownHook>> = Mutex::new(Vec::new());

/// One queued shutdown callback together with the place it was registered from.
struct ShutdownHook {
    /// Closure to invoke at shutdown. It is consumed when invoked.
    callback: Box<dyn FnOnce() + Send>,
    /// Call site of the registering function, used when reporting a panic from `callback`.
    source: &'static Location<'static>,
}
88
89impl ShutdownHook {
90 fn run(self) -> Result<(), ()> {
91 let Self { source, callback } = self;
92 let result = panic::catch_unwind(AssertUnwindSafe(callback));
93 if let Err(e) = result {
94 let msg = failure_message(&e);
95 write_stderr(&format!(
96 "error: shutdown hook (registered at {source}) panicked: {msg}\n"
97 ));
98 Err(())
99 } else {
100 Ok(())
101 }
102 }
103}
104
/// Extracts a human-readable message from a panic payload.
///
/// Recognizes the payload types `panic!` commonly produces, `&'static str` and `String`. It also
/// looks through a `Box<dyn Any + Send>` that was passed by reference without being
/// dereferenced first. Any other payload yields a fixed placeholder string.
fn failure_message(e: &(dyn any::Any + Send)) -> &str {
    if let Some(&msg) = e.downcast_ref::<&'static str>() {
        msg
    } else if let Some(msg) = e.downcast_ref::<String>() {
        msg.as_str()
    } else if let Some(inner) = e.downcast_ref::<Box<dyn any::Any + Send>>() {
        // Calling `failure_message(&boxed_payload)` coerces the Box itself into `dyn Any`
        // (unsize coercion wins over deref coercion). Unwrap it and retry on the real payload.
        failure_message(&**inner)
    } else {
        "<panic payload of unknown type>"
    }
}
114
115fn write_stderr(s: &str) {
117 loop {
118 let res = unsafe { libc::write(libc::STDERR_FILENO, s.as_ptr().cast(), s.len()) };
119 if res >= 0 || io::Error::last_os_error().kind() != io::ErrorKind::Interrupted {
122 break;
123 }
124 }
125}