#![allow(static_mut_refs)]
use crate as pgrx; use crate::pg_sys;
use crate::prelude::*;
use enum_map::{Enum, EnumMap};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
/// Top-level transaction events that [`register_xact_callback`] can hook.
///
/// Each variant corresponds to one PostgreSQL `XACT_EVENT_*` code (the mapping lives in
/// `PgXactCallbackEvent::translate_pg_event`). The `Enum` derive lets these variants index an
/// `EnumMap`, which is how pending callbacks are stored per event.
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Enum)]
pub enum PgXactCallbackEvent {
/// `XACT_EVENT_ABORT`
Abort,
/// `XACT_EVENT_COMMIT`
Commit,
/// `XACT_EVENT_PRE_COMMIT`
PreCommit,
/// `XACT_EVENT_PARALLEL_ABORT`
ParallelAbort,
/// `XACT_EVENT_PARALLEL_COMMIT`
ParallelCommit,
/// `XACT_EVENT_PARALLEL_PRE_COMMIT`
ParallelPreCommit,
/// `XACT_EVENT_PREPARE`
Prepare,
/// `XACT_EVENT_PRE_PREPARE`
PrePrepare,
}
impl PgXactCallbackEvent {
fn translate_pg_event(pg_event: pg_sys::XactEvent::Type) -> Self {
use pg_sys::XactEvent::*;
match pg_event {
XACT_EVENT_ABORT => PgXactCallbackEvent::Abort,
XACT_EVENT_COMMIT => PgXactCallbackEvent::Commit,
XACT_EVENT_PARALLEL_ABORT => PgXactCallbackEvent::ParallelAbort,
XACT_EVENT_PARALLEL_COMMIT => PgXactCallbackEvent::ParallelCommit,
XACT_EVENT_PARALLEL_PRE_COMMIT => PgXactCallbackEvent::ParallelPreCommit,
XACT_EVENT_PREPARE => PgXactCallbackEvent::Prepare,
XACT_EVENT_PRE_COMMIT => PgXactCallbackEvent::PreCommit,
XACT_EVENT_PRE_PREPARE => PgXactCallbackEvent::PrePrepare,
unknown => panic!("Unrecognized XactEvent: {unknown}"),
}
}
}
/// Handle returned by [`register_xact_callback`] for cancelling a queued callback.
///
/// It shares its slot (`Rc<RefCell<Option<..>>>`) with the callback registry. Emptying the slot
/// means the dispatcher finds `None` there and skips the callback.
pub struct XactCallbackReceipt(Rc<RefCell<Option<XactCallbackWrapper>>>);

impl XactCallbackReceipt {
    /// Cancels the callback by emptying the shared slot, dropping the boxed closure immediately.
    ///
    /// If the callback has already run (its slot is already `None`), this does nothing.
    pub fn unregister_callback(self) {
        let previous = self.0.take();
        drop(previous);
    }
}
/// Type-erased transaction callback stored in the registry.
///
/// The closure is `FnOnce`: the dispatcher in `register_xact_callback` moves it out of its
/// slot (leaving `None` behind) before invoking it, so it runs at most once.
struct XactCallbackWrapper(
Box<dyn FnOnce() + std::panic::UnwindSafe + std::panic::RefUnwindSafe + 'static>,
);
/// Pending transaction callbacks, indexed by [`PgXactCallbackEvent`].
///
/// For each event, `None` means nothing is queued. Otherwise the `Vec` holds shared slots in
/// registration order. Each slot's `Rc` is also held by the matching [`XactCallbackReceipt`],
/// which can empty the slot to cancel the callback.
type CallbackMap =
EnumMap<PgXactCallbackEvent, Option<Vec<Rc<RefCell<Option<XactCallbackWrapper>>>>>>;
/// Queues `f` to run once, the next time transaction event `which_event` fires.
///
/// The returned [`XactCallbackReceipt`] can cancel the callback before it runs.
///
/// Dispatch rules:
/// - Each callback is one-shot. Its slot is emptied right before the closure is called.
/// - On `Commit`, `Abort`, `ParallelCommit` and `ParallelAbort` the *entire* table is replaced
///   with an empty one. Callbacks queued for any other event that did not fire are discarded.
///   That replacement happens before the firing event's callbacks run, so a callback registered
///   from inside one of them lands in the fresh table.
/// - For every other event only that event's list is taken and run.
///
/// The PostgreSQL-level callback is installed the first time this function is called.
/// `XACT_HOOKS` is never reset to `None` afterwards, so it is not installed again.
///
/// NOTE(review): the table is not reset on `Prepare`. If PostgreSQL delivers no
/// Commit/Abort event to this backend after `PREPARE TRANSACTION`, callbacks queued for those
/// events in the prepared transaction would stay queued into the next transaction. Confirm
/// before changing, since `register_subxact_callback`'s cleanup relies on the current behaviour.
pub fn register_xact_callback<F>(which_event: PgXactCallbackEvent, f: F) -> XactCallbackReceipt
where
    F: FnOnce() + std::panic::UnwindSafe + std::panic::RefUnwindSafe + 'static,
{
    static mut XACT_HOOKS: Option<CallbackMap> = None;

    #[pg_guard]
    unsafe extern "C-unwind" fn callback(
        event: pg_sys::XactEvent::Type,
        _arg: *mut ::std::os::raw::c_void,
    ) {
        let current = PgXactCallbackEvent::translate_pg_event(event);
        let ends_transaction = matches!(
            current,
            PgXactCallbackEvent::Commit
                | PgXactCallbackEvent::Abort
                | PgXactCallbackEvent::ParallelCommit
                | PgXactCallbackEvent::ParallelAbort
        );

        let pending = if ends_transaction {
            // Swap in an empty table; everything left in the old one is dropped with it.
            let mut finished = XACT_HOOKS
                .replace(CallbackMap::default())
                .expect("XACT_HOOKS was None during Commit/Abort");
            finished[current].take()
        } else {
            XACT_HOOKS.as_mut().expect("XACT_HOOKS was None")[current].take()
        };

        for slot in pending.into_iter().flatten() {
            // Empty the slot first so the closure can only ever be called once.
            if let Some(wrapper) = slot.replace(None) {
                (wrapper.0)();
            }
        }
    }

    /// Returns the callback table, creating it (and installing `callback` with PostgreSQL)
    /// on first use.
    fn maybe_initialize<'a>() -> &'a mut CallbackMap {
        unsafe {
            if XACT_HOOKS.is_none() {
                XACT_HOOKS = Some(CallbackMap::default());
                pg_sys::RegisterXactCallback(Some(callback), std::ptr::null_mut());
            }
            XACT_HOOKS.as_mut().expect("XACT_HOOKS was None during maybe_initialize")
        }
    }

    let table = maybe_initialize();
    let slot = Rc::new(RefCell::new(Some(XactCallbackWrapper(Box::new(f)))));
    table[which_event]
        .get_or_insert_with(Vec::new)
        .push(Rc::clone(&slot));
    XactCallbackReceipt(slot)
}
/// Subtransaction events that [`register_subxact_callback`] can hook.
///
/// Each variant corresponds to one PostgreSQL `SUBXACT_EVENT_*` code. `Copy` is derived for
/// consistency with [`PgXactCallbackEvent`]; this is a plain fieldless enum passed by value.
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
pub enum PgSubXactCallbackEvent {
    /// `SUBXACT_EVENT_ABORT_SUB`
    AbortSub,
    /// `SUBXACT_EVENT_COMMIT_SUB`
    CommitSub,
    /// `SUBXACT_EVENT_PRE_COMMIT_SUB`
    PreCommitSub,
    /// `SUBXACT_EVENT_START_SUB`
    StartSub,
}
impl PgSubXactCallbackEvent {
fn translate_pg_event(event: pg_sys::SubXactEvent::Type) -> Self {
use pg_sys::SubXactEvent::*;
match event {
SUBXACT_EVENT_ABORT_SUB => PgSubXactCallbackEvent::AbortSub,
SUBXACT_EVENT_COMMIT_SUB => PgSubXactCallbackEvent::CommitSub,
SUBXACT_EVENT_PRE_COMMIT_SUB => PgSubXactCallbackEvent::PreCommitSub,
SUBXACT_EVENT_START_SUB => PgSubXactCallbackEvent::StartSub,
_ => panic!("Unrecognized SubXactEvent: {event}"),
}
}
}
/// Handle returned by [`register_subxact_callback`] for cancelling a registered callback.
///
/// It shares its slot with the subtransaction callback registry. Once the slot is empty, the
/// dispatcher skips it.
pub struct SubXactCallbackReceipt(Rc<RefCell<Option<SubXactCallbackWrapper>>>);

impl SubXactCallbackReceipt {
    /// Cancels the callback by emptying the shared slot, dropping the boxed closure now.
    pub fn unregister_callback(self) {
        let previous = self.0.take();
        drop(previous);
    }
}
/// Type-erased subtransaction callback stored in the registry.
///
/// It receives `(my_subid, parent_subid)` exactly as passed to the PostgreSQL subxact callback.
/// It is `Fn` rather than `FnOnce` because it stays registered and is invoked through a shared
/// `RefCell` borrow each time its event fires.
struct SubXactCallbackWrapper(
Box<
dyn Fn(pg_sys::SubTransactionId, pg_sys::SubTransactionId)
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe
+ 'static,
>,
);
/// Registered subtransaction callbacks keyed by event, each list in registration order.
///
/// Every slot's `Rc` is shared with the [`SubXactCallbackReceipt`] returned at registration.
/// Emptying the slot cancels the callback.
type SubCallbackMap =
HashMap<PgSubXactCallbackEvent, Vec<Rc<RefCell<Option<SubXactCallbackWrapper>>>>>;
pub fn register_subxact_callback<F>(
which_event: PgSubXactCallbackEvent,
f: F,
) -> SubXactCallbackReceipt
where
F: Fn(pg_sys::SubTransactionId, pg_sys::SubTransactionId)
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe
+ 'static,
{
static mut SUB_HOOKS: Option<SubCallbackMap> = None;
#[pg_guard]
unsafe extern "C-unwind" fn callback(
event: pg_sys::SubXactEvent::Type,
my_subid: pg_sys::SubTransactionId,
parent_subid: pg_sys::SubTransactionId,
_arg: *mut ::std::os::raw::c_void,
) {
let which_event = PgSubXactCallbackEvent::translate_pg_event(event);
let hooks = SUB_HOOKS.as_mut();
if let Some(hooks) = hooks {
let hooks = hooks.get(&which_event);
if let Some(hooks) = hooks {
for hook in hooks.iter() {
let hook = hook.borrow();
if let Some(hook) = hook.as_ref() {
(hook.0)(my_subid, parent_subid)
}
}
}
}
}
fn maybe_initialize<'a>() -> &'a mut SubCallbackMap {
unsafe {
if SUB_HOOKS.is_none() {
SUB_HOOKS.replace(HashMap::new());
pg_sys::UnregisterSubXactCallback(Some(callback), std::ptr::null_mut());
pg_sys::RegisterSubXactCallback(Some(callback), std::ptr::null_mut());
register_xact_callback(PgXactCallbackEvent::Commit, || {
SUB_HOOKS.take();
});
register_xact_callback(PgXactCallbackEvent::Abort, || {
SUB_HOOKS.take();
});
}
SUB_HOOKS.as_mut().expect("SUB_HOOKS was None during maybe_initialize()")
}
}
let hooks = maybe_initialize();
let entry = hooks.entry(which_event).or_default();
let wrapped_func = Rc::new(RefCell::new(Some(SubXactCallbackWrapper(Box::new(f)))));
entry.push(wrapped_func.clone());
SubXactCallbackReceipt(wrapped_func)
}