use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Id (as produced by `nonzero_thread_id`) of the thread that is allowed to call into postgres,
/// or `0` while no thread has claimed that role yet.
static ACTIVE_THREAD: AtomicUsize = AtomicUsize::new(0);

/// Asserts that the calling thread is the one permitted to use the postgres FFI.
///
/// The first thread to get here claims the role. Later calls from that same thread pass silently
/// until `active_thread::clear` resets the claim.
///
/// # Panics
///
/// Panics, reporting the caller's source location, when invoked from any thread other than the
/// one that claimed the role.
#[track_caller]
pub fn check_active_thread() {
    let me = nonzero_thread_id();
    let owner = ACTIVE_THREAD.load(Ordering::Relaxed);
    if owner == 0 {
        // Nobody owns the slot yet; try to take it (another thread may still win the race).
        init_active_thread(me);
    } else if owner != me.get() {
        thread_id_check_failed();
    }
}
#[doc(hidden)]
pub mod active_thread {
#[doc(hidden)]
pub fn clear() {
super::ACTIVE_THREAD.store(0, std::sync::atomic::Ordering::Relaxed);
}
}
/// Reports whether the calling thread is the process's initial (OS main) thread.
///
/// Returns `None` in two cases:
/// - on platforms without a supported query;
/// - when `pthread_main_np` reports something other than `0` or `1`.
pub(super) fn is_os_main_thread() -> Option<bool> {
    #[cfg(any(target_os = "macos", target_os = "openbsd", target_os = "freebsd"))]
    {
        // SAFETY: `pthread_main_np` takes no arguments and has no preconditions.
        let status = unsafe { libc::pthread_main_np() };
        match status {
            1 => Some(true),
            0 => Some(false),
            _ => None,
        }
    }
    #[cfg(target_os = "linux")]
    {
        // On Linux the initial thread's kernel TID equals the process ID.
        // SAFETY: neither call takes pointers or has preconditions.
        let (tid, pid) = unsafe {
            (
                libc::syscall(libc::SYS_gettid) as core::ffi::c_long,
                libc::getpid() as core::ffi::c_long,
            )
        };
        Some(tid == pid)
    }
    #[cfg(not(any(
        target_os = "macos",
        target_os = "openbsd",
        target_os = "freebsd",
        target_os = "linux"
    )))]
    {
        None
    }
}
/// Tries to claim the active-thread slot for `tid`.
///
/// The caller has just observed `ACTIVE_THREAD == 0`. On a successful claim on Unix (excluding
/// Emscripten), this also makes sure a `pthread_atfork` child handler is installed, at most once
/// per process image. That handler resets `ACTIVE_THREAD` in forked children.
///
/// # Panics
///
/// Panics via `thread_id_check_failed` if a different thread claimed the slot first.
#[track_caller]
fn init_active_thread(tid: NonZeroUsize) {
    match ACTIVE_THREAD.compare_exchange(0, tid.get(), Ordering::Relaxed, Ordering::Relaxed) {
        Ok(_) => {
            #[cfg(all(target_family = "unix", not(target_os = "emscripten")))]
            {
                static ATFORK_REGISTERED: AtomicBool = AtomicBool::new(false);
                if !ATFORK_REGISTERED.swap(true, Ordering::Relaxed) {
                    // Runs in the child after `fork()`, so the child is not locked to a thread
                    // id inherited from the parent.
                    extern "C" fn clear_in_child() {
                        ACTIVE_THREAD.store(0, Ordering::Relaxed);
                        // `ATFORK_REGISTERED` is deliberately left `true`. The child inherits the
                        // parent's atfork handlers, so clearing the flag would register a
                        // duplicate handler in every fork generation.
                    }
                    // SAFETY: `clear_in_child` is a valid `extern "C"` fn with static lifetime
                    // that only performs an atomic store.
                    let rc = unsafe { libc::pthread_atfork(None, None, Some(clear_in_child)) };
                    if rc != 0 {
                        // Registration failed (e.g. ENOMEM). Let a later claim retry instead of
                        // believing a handler is installed.
                        ATFORK_REGISTERED.store(false, Ordering::Relaxed);
                    }
                }
            }
        }
        // Ids come from per-thread storage, so only this thread stores `tid`. Finding it already
        // present means an earlier re-entrant call on this same thread (presumably e.g. a signal
        // handler) won the race. That is not a multi-thread violation.
        Err(current) if current == tid.get() => {}
        Err(_) => thread_id_check_failed(),
    }
}
/// Cold failure path for the thread check; never returns.
///
/// # Panics
///
/// Always panics, in one of two ways:
/// - If the current thread is the OS main thread, the recorded active thread must be some other
///   thread, and the `assert_ne!` fires.
/// - Otherwise it panics with the caller's location and a message that the postgres FFI may not
///   be used from multiple threads.
#[cold]
#[inline(never)]
#[track_caller]
fn thread_id_check_failed() -> ! {
    let on_main = is_os_main_thread();
    assert_ne!(on_main, Some(true), "`pgrx` active thread is not the main thread!?");
    let caller = std::panic::Location::caller();
    panic!("{}: postgres FFI may not be called from multiple threads.", caller);
}
/// Returns a cheap, non-zero identifier for the calling thread.
///
/// The id is the address of a per-thread byte. It is therefore distinct among threads alive at
/// the same time, but an address may be reused once a thread exits.
fn nonzero_thread_id() -> NonZeroUsize {
    std::thread_local! {
        static BYTE: u8 = const { 0 };
    }
    let addr = BYTE.with(|byte: &u8| (byte as *const u8).addr());
    // SAFETY: `addr` is the address of a live thread-local, and a reference is never null.
    unsafe { NonZeroUsize::new_unchecked(addr) }
}