#![deny(unsafe_op_in_unsafe_fn)]
#![allow(non_snake_case)]
use core::ffi::CStr;
use std::any::Any;
use std::borrow::Cow;
use std::cell::Cell;
use std::fmt::{Display, Formatter};
use std::hint::unreachable_unchecked;
use std::panic::{
AssertUnwindSafe, Location, PanicHookInfo, UnwindSafe, catch_unwind, panic_any, resume_unwind,
};
use crate::elog::PgLogLevel;
use crate::errcodes::PgSqlErrorCode;
use crate::{AsPgCStr, MemoryContextSwitchTo, pfree};
/// A fallible value that can be unwrapped, reporting the failure instead of returning it.
///
/// Implementors decide how a failure is reported; the `Result` implementation in this module
/// raises it as a Postgres `ERROR`.
pub trait ErrorReportable {
    /// The success value handed back when there is nothing to report.
    type Inner;

    /// Returns the success value, or reports the failure (implementation-defined).
    fn unwrap_or_report(self) -> Self::Inner;
}
impl<T, E> ErrorReportable for Result<T, E>
where
E: Any + Display,
{
type Inner = T;
fn unwrap_or_report(self) -> Self::Inner {
self.unwrap_or_else(|e| {
let any: Box<&dyn Any> = Box::new(&e);
if any.downcast_ref::<ErrorReport>().is_some() {
let any: Box<dyn Any> = Box::new(e);
any.downcast::<ErrorReport>().unwrap().report(PgLogLevel::ERROR);
unreachable!();
} else {
ereport!(ERROR, PgSqlErrorCode::ERRCODE_DATA_EXCEPTION, format!("{e}"));
}
})
}
}
/// Where an error report originated: file, optional function name, line, column and an
/// optional backtrace.
#[derive(Debug)]
pub struct ErrorReportLocation {
    /// Source file name; `"<unknown>"` when no location was available.
    pub(crate) file: String,
    /// Name of the reporting function, when the caller supplied one.
    pub(crate) funcname: Option<String>,
    /// Line number as reported by `std::panic::Location` (0 when unknown).
    pub(crate) line: u32,
    /// Column number as reported by `std::panic::Location` (0 when unknown).
    pub(crate) col: u32,
    /// Backtrace, if one was captured; only printed when its status is `Captured`.
    pub(crate) backtrace: Option<std::backtrace::Backtrace>,
}

impl Default for ErrorReportLocation {
    /// An "unknown" location: file `"<unknown>"`, line/column 0, no function, no backtrace.
    fn default() -> Self {
        ErrorReportLocation {
            file: "<unknown>".to_owned(),
            funcname: None,
            line: 0,
            col: 0,
            backtrace: None,
        }
    }
}

impl Display for ErrorReportLocation {
    /// Renders as `[funcname, ]file:line:col`, followed by the backtrace on new lines when one
    /// was actually captured.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(name) = self.funcname.as_deref() {
            write!(f, "{}, ", name)?;
        }
        write!(f, "{}:{}:{}", self.file, self.line, self.col)?;

        let captured = self
            .backtrace
            .as_ref()
            .filter(|bt| bt.status() == std::backtrace::BacktraceStatus::Captured);
        match captured {
            Some(bt) => write!(f, "\n{}", bt),
            None => Ok(()),
        }
    }
}

impl From<&Location<'_>> for ErrorReportLocation {
    /// Copies file, line and column from a `std::panic::Location`; no function name or backtrace.
    fn from(location: &Location<'_>) -> Self {
        Self {
            line: location.line(),
            col: location.column(),
            file: location.file().to_owned(),
            funcname: None,
            backtrace: None,
        }
    }
}

impl From<&PanicHookInfo<'_>> for ErrorReportLocation {
    /// Uses the panic's location when available, otherwise the default "unknown" location.
    fn from(pi: &PanicHookInfo<'_>) -> Self {
        match pi.location() {
            Some(location) => Self::from(location),
            None => Self::default(),
        }
    }
}
/// An error to be reported to Postgres: SQL error code, message, optional hint/detail/domain,
/// and the location it was raised from.
#[derive(Debug)]
pub struct ErrorReport {
    /// SQL error code passed to Postgres' `errcode()`.
    pub(crate) sqlerrcode: PgSqlErrorCode,
    /// Primary message (borrowed when it is a `'static` string, owned otherwise).
    pub(crate) message: Cow<'static, str>,
    /// Optional hint, sent through `errhint()`.
    pub(crate) hint: Option<String>,
    /// Optional detail, sent through `errdetail()` (possibly extended with a backtrace).
    pub(crate) detail: Option<String>,
    /// Optional message domain handed to `errstart()`; `None` means the default (null) domain.
    pub(crate) domain: Option<String>,
    /// Where the report originated.
    pub(crate) location: ErrorReportLocation,
}
impl Display for ErrorReport {
    /// Renders `<sqlerrcode>: <message>`, then optional `HINT:` and `DETAIL:` lines (in that
    /// order), then a final `LOCATION:` line.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.sqlerrcode, self.message)?;

        for (label, text) in [("HINT", &self.hint), ("DETAIL", &self.detail)] {
            if let Some(text) = text {
                write!(f, "\n{label}: {text}")?;
            }
        }

        write!(f, "\nLOCATION: {}", self.location)
    }
}
/// An [`ErrorReport`] paired with the Postgres log level it should be raised at.
#[derive(Debug)]
pub struct ErrorReportWithLevel {
    /// Log level passed to `errstart()`; also decides how `report()` raises the error.
    pub(crate) level: PgLogLevel,
    /// The report itself.
    pub(crate) inner: ErrorReport,
}
impl ErrorReportWithLevel {
fn report(self) {
match self.level {
PgLogLevel::ERROR => panic_any(self),
PgLogLevel::FATAL | PgLogLevel::PANIC => {
do_ereport(self);
unreachable!()
}
_ => do_ereport(self),
}
}
pub fn level(&self) -> PgLogLevel {
self.level
}
pub fn sql_error_code(&self) -> PgSqlErrorCode {
self.inner.sqlerrcode
}
pub fn message(&self) -> &str {
self.inner.message()
}
pub fn detail(&self) -> Option<&str> {
self.inner.detail()
}
pub fn detail_with_backtrace(&self) -> Option<String> {
match (self.detail(), self.backtrace()) {
(Some(detail), Some(bt))
if bt.status() == std::backtrace::BacktraceStatus::Captured =>
{
Some(format!("{detail}\n{bt}"))
}
(Some(d), _) => Some(d.to_string()),
(None, Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
Some(format!("\n{bt}"))
}
(None, _) => None,
}
}
pub fn hint(&self) -> Option<&str> {
self.inner.hint()
}
pub fn domain(&self) -> Option<&str> {
self.inner.domain()
}
pub fn file(&self) -> &str {
&self.inner.location.file
}
pub fn line_number(&self) -> u32 {
self.inner.location.line
}
pub fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
self.inner.location.backtrace.as_ref()
}
pub fn function_name(&self) -> Option<&str> {
self.inner.location.funcname.as_deref()
}
fn context_message(&self) -> Option<String> {
None
}
}
impl ErrorReport {
#[track_caller]
pub fn new<S: Into<Cow<'static, str>>>(
sqlerrcode: PgSqlErrorCode,
message: S,
funcname: &'static str,
) -> Self {
let mut location: ErrorReportLocation = Location::caller().into();
location.funcname = Some(funcname.to_string());
Self {
sqlerrcode,
message: message.into(),
hint: None,
detail: None,
domain: None,
location,
}
}
fn with_location<S: Into<Cow<'static, str>>>(
sqlerrcode: PgSqlErrorCode,
message: S,
location: ErrorReportLocation,
) -> Self {
Self {
sqlerrcode,
message: message.into(),
hint: None,
detail: None,
domain: None,
location,
}
}
pub fn set_detail<S: Into<String>>(mut self, detail: S) -> Self {
self.detail = Some(detail.into());
self
}
pub fn set_hint<S: Into<String>>(mut self, hint: S) -> Self {
self.hint = Some(hint.into());
self
}
pub fn set_domain<S: Into<String>>(mut self, domain: S) -> Self {
self.domain = Some(domain.into());
self
}
pub fn message(&self) -> &str {
&self.message
}
pub fn detail(&self) -> Option<&str> {
self.detail.as_deref()
}
pub fn hint(&self) -> Option<&str> {
self.hint.as_deref()
}
pub fn domain(&self) -> Option<&str> {
self.domain.as_deref()
}
pub fn report(self, level: PgLogLevel) {
ErrorReportWithLevel { level, inner: self }.report()
}
}
// Per-thread slot where the pg_guard panic hook stores the location (and backtrace) of the most
// recent panic, for `downcast_panic_payload` to pick up.
thread_local! {
    static PANIC_LOCATION: Cell<Option<ErrorReportLocation>> = const { Cell::new(None) };
}

/// Removes and returns the location recorded by the panic hook on this thread, or the default
/// "unknown" location when nothing was recorded.
fn take_panic_location() -> ErrorReportLocation {
    PANIC_LOCATION.with(Cell::take).unwrap_or_default()
}
pub fn register_pg_guard_panic_hook() {
use super::thread_check::is_os_main_thread;
let default_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info: _| {
PANIC_LOCATION.with(|thread_local| {
thread_local.replace({
let mut info: ErrorReportLocation = info.into();
info.backtrace = Some(std::backtrace::Backtrace::capture());
Some(info)
})
});
if is_os_main_thread() == Some(false) {
default_hook(info)
}
}))
}
/// The classified form of a payload caught by `catch_unwind` (see `downcast_panic_payload`).
#[derive(Debug)]
pub enum CaughtError {
    /// NOTE(review): not constructed in this chunk; presumably an error that originated in
    /// Postgres itself and was converted into an unwind elsewhere — confirm.
    PostgresError(ErrorReportWithLevel),
    /// A payload that was an `ErrorReportWithLevel`, or a bare `ErrorReport` (treated as `ERROR`).
    ErrorReport(ErrorReportWithLevel),
    /// Any other panic: an internal-error report built from the payload text, plus the original
    /// payload.
    RustPanic { ereport: ErrorReportWithLevel, payload: Box<dyn Any + Send> },
}
impl CaughtError {
pub fn rethrow(self) -> ! {
resume_unwind(Box::new(self))
}
}
/// Outcome of `run_guarded`: what `pgrx_extern_c_guard` should do next.
#[derive(Debug)]
enum GuardAction<R> {
    /// The guarded closure returned normally with this value.
    Return(R),
    /// The guarded closure unwound; this report is handed to `do_ereport`.
    Report(ErrorReportWithLevel),
}
/// Runs `f`, returning its value. If `f` unwinds, the unwind is turned into a Postgres error
/// report.
///
/// A caught unwind is classified by `downcast_panic_payload` and handed to `do_ereport`, which is
/// expected not to return for the `ERROR`-or-above reports produced here.
///
/// # Safety
///
/// NOTE(review): the precise contract is not spelled out in this chunk. This is `#[doc(hidden)]`
/// and appears intended for pgrx-generated guard wrappers. Reporting calls into Postgres, and
/// `do_ereport` checks that this is the active thread — confirm the full requirements.
#[doc(hidden)]
pub unsafe fn pgrx_extern_c_guard<Func, R>(f: Func) -> R
where
    Func: FnOnce() -> R,
{
    // SAFETY: forwarded from this function's caller contract.
    let ereport = match unsafe { run_guarded(AssertUnwindSafe(f)) } {
        GuardAction::Return(value) => return value,
        GuardAction::Report(ereport) => ereport,
    };

    do_ereport(ereport);
    unreachable!("pgrx reported a CaughtError that wasn't raised at ERROR or above");
}
/// Calls `f` under `catch_unwind`.
///
/// A normal return becomes `GuardAction::Return`. Any unwind is classified with
/// `downcast_panic_payload`, and its report becomes `GuardAction::Report`. For `RustPanic`, the
/// original payload is dropped.
#[inline(never)]
unsafe fn run_guarded<F, R>(f: F) -> GuardAction<R>
where
    F: FnOnce() -> R + UnwindSafe,
{
    catch_unwind(f).map_or_else(
        |payload| {
            let ereport = match downcast_panic_payload(payload) {
                CaughtError::PostgresError(ereport)
                | CaughtError::ErrorReport(ereport)
                | CaughtError::RustPanic { ereport, .. } => ereport,
            };
            GuardAction::Report(ereport)
        },
        GuardAction::Return,
    )
}
pub(crate) fn downcast_panic_payload(e: Box<dyn Any + Send>) -> CaughtError {
if e.downcast_ref::<CaughtError>().is_some() {
let mut caught = *e.downcast::<CaughtError>().unwrap();
if let CaughtError::PostgresError(ref mut ereport) = caught {
if ereport.inner.location.backtrace.is_none() {
let panic_location = take_panic_location();
ereport.inner.location.backtrace = panic_location.backtrace;
}
}
caught
} else if e.downcast_ref::<ErrorReportWithLevel>().is_some() {
CaughtError::ErrorReport(*e.downcast().unwrap())
} else if e.downcast_ref::<ErrorReport>().is_some() {
CaughtError::ErrorReport(ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: *e.downcast().unwrap(),
})
} else if let Some(message) = e.downcast_ref::<&str>() {
CaughtError::RustPanic {
ereport: ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: ErrorReport::with_location(
PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
message.to_string(),
take_panic_location(),
),
},
payload: e,
}
} else if let Some(message) = e.downcast_ref::<String>() {
CaughtError::RustPanic {
ereport: ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: ErrorReport::with_location(
PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
message.clone(),
take_panic_location(),
),
},
payload: e,
}
} else {
CaughtError::RustPanic {
ereport: ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: ErrorReport::with_location(
PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
"Box<Any>",
take_panic_location(),
),
},
payload: e,
}
}
}
/// Hands `ereport` to Postgres' `errstart()` / `errfinish()` error-reporting functions.
///
/// When `level >= ERROR`, `errfinish()` does not return to this frame (see the
/// `unreachable_unchecked()` below). Nothing with a pending Rust destructor may be alive here
/// when it is called:
/// - `ereport` is dropped explicitly beforehand;
/// - the optional domain string is held as a raw pointer, not a `CString` local.
///
/// For lower levels `errfinish()` returns, and the strings allocated here are freed afterwards.
///
/// A domain containing an interior NUL byte cannot be passed to C. It is ignored and the default
/// (null) domain is used instead, rather than panicking in the middle of error reporting.
fn do_ereport(ereport: ErrorReportWithLevel) {
    const PERCENT_S: &CStr = c"%s";
    // A null domain leaves the choice of message domain to `errstart()`.
    const DEFAULT_DOMAIN: *const ::std::os::raw::c_char = std::ptr::null();

    // The `extern` declarations below are not `#[pg_guard]`-wrapped, so verify by hand that we
    // are on the thread that is allowed to call into Postgres.
    crate::thread_check::check_active_thread();

    #[cfg_attr(target_os = "windows", link(name = "postgres"))]
    unsafe extern "C-unwind" {
        fn errcode(sqlerrcode: ::std::os::raw::c_int) -> ::std::os::raw::c_int;
        fn errmsg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errdetail(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errhint(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errcontext_msg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
    }
    #[cfg_attr(target_os = "windows", link(name = "postgres"))]
    unsafe extern "C-unwind" {
        fn errstart(elevel: ::std::os::raw::c_int, domain: *const ::std::os::raw::c_char) -> bool;
        fn errfinish(
            filename: *const ::std::os::raw::c_char,
            lineno: ::std::os::raw::c_int,
            funcname: *const ::std::os::raw::c_char,
        );
    }

    /// Frees a domain string produced by `CString::into_raw`; a null pointer is a no-op.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or have come from `CString::into_raw`, and must not be used again.
    unsafe fn free_domain(ptr: *mut ::std::os::raw::c_char) {
        if !ptr.is_null() {
            // SAFETY: upheld by the caller, see `# Safety`.
            drop(unsafe { std::ffi::CString::from_raw(ptr) });
        }
    }

    let level = ereport.level();

    // Kept as a raw pointer (see the function docs). If `errfinish()` never returns, the string
    // is intentionally leaked and stays valid. NOTE(review): Postgres may retain the domain
    // pointer in its error state, so it must not be freed in that case. On every returning path
    // it is reclaimed with `free_domain`.
    let owned_domain: *mut ::std::os::raw::c_char = ereport
        .domain()
        .and_then(|d| std::ffi::CString::new(d).ok())
        .map_or(std::ptr::null_mut(), std::ffi::CString::into_raw);
    let domain_ptr: *const ::std::os::raw::c_char =
        if owned_domain.is_null() { DEFAULT_DOMAIN } else { owned_domain };

    unsafe {
        if !errstart(level as _, domain_ptr) {
            // Postgres is not reporting at this level, so nothing refers to the domain string.
            free_domain(owned_domain);
            return;
        }

        let sqlerrcode = ereport.sql_error_code();
        let message = ereport.message().as_pg_cstr();
        let detail = ereport.detail_with_backtrace().as_pg_cstr();
        let hint = ereport.hint().as_pg_cstr();
        let context = ereport.context_message().as_pg_cstr();
        let lineno = ereport.line_number();

        // `file` and `funcname` are handed to `errfinish()`, so allocate them in `ErrorContext`
        // rather than the current memory context.
        let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
        let file = ereport.file().as_pg_cstr();
        let funcname = ereport.function_name().as_pg_cstr();
        MemoryContextSwitchTo(prev_cxt);

        // Drop the Rust report now: if `errfinish()` does not return, its destructor never would.
        drop(ereport);

        // Each string may be null (e.g. an absent hint), so every call is guarded.
        errcode(sqlerrcode as _);
        if !message.is_null() {
            errmsg(PERCENT_S.as_ptr(), message);
            pfree(message.cast());
        }
        if !detail.is_null() {
            errdetail(PERCENT_S.as_ptr(), detail);
            pfree(detail.cast());
        }
        if !hint.is_null() {
            errhint(PERCENT_S.as_ptr(), hint);
            pfree(hint.cast());
        }
        if !context.is_null() {
            errcontext_msg(PERCENT_S.as_ptr(), context);
            pfree(context.cast());
        }

        if level >= PgLogLevel::ERROR {
            crate::submodules::thread_check::active_thread::clear();
        }

        errfinish(file, lineno as _, funcname);

        if level >= PgLogLevel::ERROR {
            // SAFETY: `errfinish()` does not return for ERROR and above, so neither do we.
            unreachable_unchecked()
        } else {
            // Below ERROR, Postgres returned control; release what it did not take ownership of.
            if !file.is_null() {
                pfree(file.cast());
            }
            if !funcname.is_null() {
                pfree(funcname.cast());
            }
            free_domain(owned_domain);
        }
    }
}