pgrx_tests/framework/
shutdown.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10use std::panic::{self, AssertUnwindSafe, Location};
11use std::sync::{Mutex, PoisonError};
12use std::{any, io, mem, process};
13
14/// Register a shutdown hook to be called when the process exits.
15///
16/// Note that shutdown hooks are only run on the client, so must be added from
17/// your `setup` function, not the `#[pg_test]` itself.
18#[track_caller]
19pub fn add_shutdown_hook<F: FnOnce()>(func: F)
20where
21    F: Send + 'static,
22{
23    SHUTDOWN_HOOKS
24        .lock()
25        .unwrap_or_else(PoisonError::into_inner)
26        .push(ShutdownHook { source: Location::caller(), callback: Box::new(func) });
27}
28
29pub(super) fn register_shutdown_hook() {
30    unsafe {
31        libc::atexit(run_shutdown_hooks);
32    }
33}
34
35/// The `atexit` callback.
36///
37/// If we panic from `atexit`, we end up causing `exit` to unwind. Unwinding
38/// from a nounwind + noreturn function can cause some destructors to run twice,
39/// causing (for example) libtest to SIGSEGV.
40///
41/// This ends up looking like a memory bug in either `pgrx` or the user code, and
42/// is very hard to track down, so we go to some lengths to prevent it.
43/// Essentially:
44///
45/// - Panics in each user hook are caught and reported.
46/// - As a stop-gap an abort-on-drop panic guard is used to ensure there isn't a
47///   place we missed.
48///
49/// We also write to stderr directly instead, since otherwise our output will
50/// sometimes be redirected.
51extern "C" fn run_shutdown_hooks() {
52    let guard = PanicGuard;
53    let mut any_panicked = false;
54    let mut hooks = SHUTDOWN_HOOKS.lock().unwrap_or_else(PoisonError::into_inner);
55    // Note: run hooks in the opposite order they were registered.
56    for hook in mem::take(&mut *hooks).into_iter().rev() {
57        any_panicked |= hook.run().is_err();
58    }
59    if any_panicked {
60        write_stderr("error: one or more shutdown hooks panicked (see `stderr` for details).\n");
61        process::abort()
62    }
63    mem::forget(guard);
64}
65
66/// Prevent panics in a block of code.
67///
68/// Prints a message and aborts in its drop. Intended usage is like:
69/// ```ignore
70/// let guard = PanicGuard;
71/// // ...code that absolutely must never unwind goes here...
72/// core::mem::forget(guard);
73/// ```
74struct PanicGuard;
75impl Drop for PanicGuard {
76    fn drop(&mut self) {
77        write_stderr("Failed to catch panic in the `atexit` callback, aborting!\n");
78        process::abort();
79    }
80}
81
/// Hooks registered via `add_shutdown_hook`, in registration order.
///
/// `run_shutdown_hooks` drains this and runs the entries in reverse order.
static SHUTDOWN_HOOKS: Mutex<Vec<ShutdownHook>> = Mutex::new(Vec::new());

/// A single registered shutdown callback.
struct ShutdownHook {
    /// The callback itself; being `FnOnce`, it can be invoked at most once.
    callback: Box<dyn FnOnce() + Send>,
    /// Where `add_shutdown_hook` was called from; used in panic reports.
    source: &'static Location<'static>,
}
88
89impl ShutdownHook {
90    fn run(self) -> Result<(), ()> {
91        let Self { source, callback } = self;
92        let result = panic::catch_unwind(AssertUnwindSafe(callback));
93        if let Err(e) = result {
94            let msg = failure_message(&e);
95            write_stderr(&format!(
96                "error: shutdown hook (registered at {source}) panicked: {msg}\n"
97            ));
98            Err(())
99        } else {
100            Ok(())
101        }
102    }
103}
104
/// Extract a human-readable message from a panic payload.
///
/// Panic payloads are commonly either a `&'static str` (literal message) or a
/// `String` (formatted message). Any other payload type yields a fixed placeholder.
fn failure_message(e: &(dyn any::Any + Send)) -> &str {
    e.downcast_ref::<&'static str>()
        .copied()
        .or_else(|| e.downcast_ref::<String>().map(String::as_str))
        .unwrap_or("<panic payload of unknown type>")
}
114
115/// Write to stderr, bypassing libtest's output redirection. Doesn't append `\n`.
116fn write_stderr(s: &str) {
117    loop {
118        let res = unsafe { libc::write(libc::STDERR_FILENO, s.as_ptr().cast(), s.len()) };
119        // Handle EINTR to ensure we don't drop messages.
120        // `Error::last_os_error()` just reads from errno, so it's fine to use here.
121        if res >= 0 || io::Error::last_os_error().kind() != io::ErrorKind::Interrupted {
122            break;
123        }
124    }
125}