#![deny(unsafe_op_in_unsafe_fn)]
#![allow(non_snake_case)]
use core::ffi::CStr;
use std::any::Any;
use std::cell::Cell;
use std::fmt::{Display, Formatter};
use std::hint::unreachable_unchecked;
use std::panic::{
catch_unwind, panic_any, resume_unwind, Location, PanicInfo, RefUnwindSafe, UnwindSafe,
};
use crate::elog::PgLogLevel;
use crate::errcodes::PgSqlErrorCode;
use crate::{pfree, AsPgCStr, MemoryContextSwitchTo};
/// Unwraps a fallible value into its success value, reporting the failure to Postgres otherwise.
pub trait ErrorReportable {
    /// The value produced on success.
    type Inner;

    /// Returns the success value, or reports the failure to Postgres.
    ///
    /// NOTE(review): the `Result` impl in this module reports at `ERROR` and so does not return
    /// on failure; any other implementor should be checked for the same guarantee.
    fn report(self) -> Self::Inner;
}
impl<T, E> ErrorReportable for Result<T, E>
where
E: Any + Display,
{
type Inner = T;
fn report(self) -> Self::Inner {
match self {
Ok(value) => value,
Err(e) => {
let any: Box<&dyn Any> = Box::new(&e);
if any.downcast_ref::<ErrorReport>().is_some() {
let any: Box<dyn Any> = Box::new(e);
any.downcast::<ErrorReport>().unwrap().report(PgLogLevel::ERROR);
unreachable!();
} else {
ereport!(ERROR, PgSqlErrorCode::ERRCODE_DATA_EXCEPTION, &format!("{}", e));
}
}
}
}
}
/// Where an error was raised: source file, optional function name, line/column, and an
/// optional backtrace.
#[derive(Debug)]
pub struct ErrorReportLocation {
    /// Source file path.
    pub(crate) file: String,
    /// Name of the reporting function, when known.
    pub(crate) funcname: Option<String>,
    /// 1-based line number, or `0` when unknown.
    pub(crate) line: u32,
    /// Column number, or `0` when unknown.
    pub(crate) col: u32,
    /// Backtrace captured at the error site, if any.
    pub(crate) backtrace: Option<std::backtrace::Backtrace>,
}

impl Default for ErrorReportLocation {
    /// A placeholder used when no real position is available: file `"<unknown>"`,
    /// line and column `0`, and neither a function name nor a backtrace.
    fn default() -> Self {
        let unknown_file = "<unknown>".to_owned();
        Self {
            backtrace: None,
            col: 0,
            line: 0,
            funcname: None,
            file: unknown_file,
        }
    }
}
impl Display for ErrorReportLocation {
    /// Renders `funcname, file:line:col` (or `file:line:col` without a function name), followed
    /// by the backtrace on a new line only when one was actually captured.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.funcname.as_deref() {
            Some(name) => write!(f, "{}, {}:{}:{}", name, self.file, self.line, self.col)?,
            None => write!(f, "{}:{}:{}", self.file, self.line, self.col)?,
        }
        // Disabled/unsupported backtraces are skipped entirely.
        let captured = self
            .backtrace
            .as_ref()
            .filter(|bt| bt.status() == std::backtrace::BacktraceStatus::Captured);
        match captured {
            Some(bt) => write!(f, "\n{}", bt),
            None => Ok(()),
        }
    }
}
impl From<&Location<'_>> for ErrorReportLocation {
    /// Copies file, line and column from a `std::panic::Location`; the function name and
    /// backtrace are left empty.
    fn from(location: &Location<'_>) -> Self {
        let (file, line, col) = (location.file(), location.line(), location.column());
        Self {
            file: file.to_owned(),
            funcname: None,
            line,
            col,
            backtrace: None,
        }
    }
}
impl From<&PanicInfo<'_>> for ErrorReportLocation {
    /// Uses the panic's source location when the runtime provides one, otherwise the
    /// `"<unknown>"` placeholder from `Default`.
    fn from(pi: &PanicInfo<'_>) -> Self {
        match pi.location() {
            Some(location) => Self::from(location),
            None => Self::default(),
        }
    }
}
/// A Postgres-style error report: SQL error code, message, optional hint/detail, and the
/// location it was raised from.
#[derive(Debug)]
pub struct ErrorReport {
    /// SQL error code reported to Postgres via `errcode`.
    pub(crate) sqlerrcode: PgSqlErrorCode,
    /// Primary error message.
    pub(crate) message: String,
    /// Optional hint text.
    pub(crate) hint: Option<String>,
    /// Optional detail text.
    pub(crate) detail: Option<String>,
    /// Where the report was raised.
    pub(crate) location: ErrorReportLocation,
}

impl Display for ErrorReport {
    /// Renders `<sqlerrcode>: <message>`, then optional `HINT:` and `DETAIL:` lines (in that
    /// order), then a final `LOCATION:` line.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let ErrorReport { sqlerrcode, message, hint, detail, location } = self;
        write!(f, "{}: {}", sqlerrcode, message)?;
        if let Some(text) = hint.as_deref() {
            write!(f, "\nHINT: {}", text)?;
        }
        if let Some(text) = detail.as_deref() {
            write!(f, "\nDETAIL: {}", text)?;
        }
        write!(f, "\nLOCATION: {}", location)
    }
}
#[derive(Debug)]
pub struct ErrorReportWithLevel {
pub(crate) level: PgLogLevel,
pub(crate) inner: ErrorReport,
}
impl ErrorReportWithLevel {
fn report(self) {
if crate::ERROR <= self.level as _ {
panic_any(self)
} else {
do_ereport(self)
}
}
pub fn level(&self) -> PgLogLevel {
self.level
}
pub fn sql_error_code(&self) -> PgSqlErrorCode {
self.inner.sqlerrcode
}
pub fn message(&self) -> &str {
self.inner.message()
}
pub fn detail(&self) -> Option<&str> {
self.inner.detail()
}
pub fn detail_with_backtrace(&self) -> Option<String> {
match (self.detail(), self.backtrace()) {
(Some(d), Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
Some(format!("{}\n{}", d, bt))
}
(Some(d), _) => Some(d.to_string()),
(None, Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
Some(format!("\n{}", bt))
}
(None, _) => None,
}
}
pub fn hint(&self) -> Option<&str> {
self.inner.hint()
}
pub fn file(&self) -> &str {
&self.inner.location.file
}
pub fn line_number(&self) -> u32 {
self.inner.location.line
}
pub fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
self.inner.location.backtrace.as_ref()
}
pub fn function_name(&self) -> Option<&str> {
self.inner.location.funcname.as_ref().map(|s| s.as_str())
}
fn context_message(&self) -> Option<String> {
None
}
}
impl ErrorReport {
#[track_caller]
pub fn new<S: Into<String>>(
sqlerrcode: PgSqlErrorCode,
message: S,
funcname: &'static str,
) -> Self {
let mut location: ErrorReportLocation = Location::caller().into();
location.funcname = Some(funcname.to_string());
Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
}
fn with_location<S: Into<String>>(
sqlerrcode: PgSqlErrorCode,
message: S,
location: ErrorReportLocation,
) -> Self {
Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
}
pub fn set_detail<S: Into<String>>(mut self, detail: S) -> Self {
self.detail = Some(detail.into());
self
}
pub fn set_hint<S: Into<String>>(mut self, hint: S) -> Self {
self.hint = Some(hint.into());
self
}
pub fn message(&self) -> &str {
&self.message
}
pub fn detail(&self) -> Option<&str> {
self.detail.as_ref().map(|s| s.as_str())
}
pub fn hint(&self) -> Option<&str> {
self.hint.as_ref().map(|s| s.as_str())
}
pub fn report(self, level: PgLogLevel) {
ErrorReportWithLevel { level, inner: self }.report()
}
}
thread_local! {
    /// Location (with backtrace) of the most recent panic on this thread, stored by the hook
    /// installed in `register_pg_guard_panic_hook`.
    static PANIC_LOCATION: Cell<Option<ErrorReportLocation>> = const { Cell::new(None) }
}

/// Removes and returns the stored panic location for this thread, or the `"<unknown>"`
/// placeholder when nothing was stored.
fn take_panic_location() -> ErrorReportLocation {
    let stashed = PANIC_LOCATION.with(Cell::take);
    stashed.unwrap_or_default()
}
pub fn register_pg_guard_panic_hook() {
std::panic::set_hook(Box::new(|info| {
PANIC_LOCATION.with(|thread_local| {
thread_local.replace({
let mut info: ErrorReportLocation = info.into();
info.backtrace = Some(std::backtrace::Backtrace::capture());
Some(info)
})
});
}))
}
/// Classification of a caught panic payload (see `downcast_panic_payload`).
#[derive(Debug)]
pub enum CaughtError {
    /// An error that `pgx_extern_c_guard` hands back to Postgres with `pg_re_throw`.
    /// NOTE(review): presumably one raised by Postgres itself and converted to a panic
    /// elsewhere — confirm.
    PostgresError(ErrorReportWithLevel),
    /// A pgx error report that was raised by panicking with it.
    ErrorReport(ErrorReportWithLevel),
    /// Any other Rust panic. `ereport` is synthesized from the panic message and location;
    /// `payload` is the original panic payload.
    RustPanic { ereport: ErrorReportWithLevel, payload: Box<dyn Any + Send> },
}

impl CaughtError {
    /// Resumes unwinding with `self` as the payload. The next `downcast_panic_payload` call
    /// returns it unchanged.
    pub fn rethrow(self) -> ! {
        let payload: Box<dyn Any + Send> = Box::new(self);
        resume_unwind(payload)
    }
}
/// What `pgx_extern_c_guard` should do once the guarded closure has finished.
#[derive(Debug)]
enum GuardAction<R> {
    /// The closure returned normally with this value.
    Return(R),
    /// A `CaughtError::PostgresError` was caught; hand it back to Postgres via `pg_re_throw`.
    ReThrow,
    /// A pgx error report or an ordinary Rust panic was caught; emit it with `do_ereport`.
    Report(ErrorReportWithLevel),
}
/// Runs `f` and converts any Rust unwind it produces into Postgres error handling.
///
/// * If `f` returns normally, its value is returned.
/// * If the caught payload is a `CaughtError::PostgresError`, `CurrentMemoryContext` is set to
///   `ErrorContext` and `pg_re_throw()` (declared `-> !`) is called, handing the error back to
///   Postgres.
/// * Otherwise (a pgx `ErrorReport` or an ordinary Rust panic), the error is emitted through
///   `do_ereport`. At `ERROR` or above that call does not return. If it does return, because a
///   sub-`ERROR` report was panicked with directly, this function panics via `unreachable!`.
///
/// NOTE(review): the `R: Copy` bound presumably guarantees the return value has no destructor
/// that non-local control flow could skip — confirm.
///
/// # Safety
///
/// Both error paths transfer control into Postgres' non-returning error machinery.
/// NOTE(review): this is presumably only sound when called from a function invoked by Postgres,
/// inside its error-handling setup, on the backend's active thread — confirm the exact contract.
#[doc(hidden)]
pub unsafe fn pgx_extern_c_guard<Func, R: Copy>(f: Func) -> R
where
    Func: FnOnce() -> R + UnwindSafe + RefUnwindSafe,
{
    match run_guarded(f) {
        GuardAction::Return(r) => r,
        GuardAction::ReThrow => {
            extern "C" {
                fn pg_re_throw() -> !;
            }
            // SAFETY: writes the Postgres-global `CurrentMemoryContext` and calls the
            // non-returning `pg_re_throw`. NOTE(review): switching to `ErrorContext` first
            // presumably mirrors what Postgres expects when re-throwing — confirm.
            unsafe {
                crate::CurrentMemoryContext = crate::ErrorContext;
                pg_re_throw()
            }
        }
        GuardAction::Report(ereport) => {
            // At ERROR or above `do_ereport` does not return.
            do_ereport(ereport);
            unreachable!("pgx reported a CaughtError that wasn't raised at ERROR or above");
        }
    }
}
/// Runs `f` under `catch_unwind` and decides what `pgx_extern_c_guard` should do next.
///
/// A normal return becomes `GuardAction::Return`. A caught `PostgresError` becomes
/// `GuardAction::ReThrow`. Caught pgx error reports and Rust panics become
/// `GuardAction::Report` carrying their error report; a Rust panic's original payload is dropped.
#[inline(never)]
fn run_guarded<F, R: Copy>(f: F) -> GuardAction<R>
where
    F: FnOnce() -> R + UnwindSafe + RefUnwindSafe,
{
    let payload = match catch_unwind(f) {
        Ok(value) => return GuardAction::Return(value),
        Err(payload) => payload,
    };
    match downcast_panic_payload(payload) {
        CaughtError::PostgresError(_) => GuardAction::ReThrow,
        CaughtError::ErrorReport(ereport) => GuardAction::Report(ereport),
        CaughtError::RustPanic { ereport, .. } => GuardAction::Report(ereport),
    }
}
pub(crate) fn downcast_panic_payload(e: Box<dyn Any + Send>) -> CaughtError {
if e.downcast_ref::<CaughtError>().is_some() {
*e.downcast::<CaughtError>().unwrap()
} else if e.downcast_ref::<ErrorReportWithLevel>().is_some() {
CaughtError::ErrorReport(*e.downcast().unwrap())
} else if e.downcast_ref::<ErrorReport>().is_some() {
CaughtError::ErrorReport(ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: *e.downcast().unwrap(),
})
} else if let Some(message) = e.downcast_ref::<&str>() {
CaughtError::RustPanic {
ereport: ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: ErrorReport::with_location(
PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
*message,
take_panic_location(),
),
},
payload: e,
}
} else if let Some(message) = e.downcast_ref::<String>() {
CaughtError::RustPanic {
ereport: ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: ErrorReport::with_location(
PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
message,
take_panic_location(),
),
},
payload: e,
}
} else {
CaughtError::RustPanic {
ereport: ErrorReportWithLevel {
level: PgLogLevel::ERROR,
inner: ErrorReport::with_location(
PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
"Box<Any>",
take_panic_location(),
),
},
payload: e,
}
}
}
/// Emits `ereport` through Postgres' C error API (`errstart`, `errcode`, `errmsg`, `errdetail`,
/// `errhint`, `errcontext_msg`, `errfinish`), using the signatures for the enabled `pgNN`
/// feature.
///
/// At `ERROR` or above this function does not return. After `errfinish` the code executes
/// `unreachable_unchecked()`, so it relies on Postgres never returning from `errfinish` at those
/// levels; if that ever happened it would be undefined behavior. For the same reason, all
/// Rust-owned data is dropped before the final C calls. Below `ERROR`, `errfinish` returns and
/// the file/function-name copies are freed here.
///
/// All text is passed as the argument of a `"%s"` format string, so `%` characters in the
/// message, detail, hint or context are not interpreted as format directives.
fn do_ereport(ereport: ErrorReportWithLevel) {
    // SAFETY: the byte literal ends with exactly one NUL and has no interior NUL.
    const PERCENT_S: &CStr = unsafe { CStr::from_bytes_with_nul_unchecked(b"%s\0") };
    // Null message domain. NOTE(review): presumably makes Postgres use its default domain.
    const DOMAIN: *const ::std::os::raw::c_char = std::ptr::null_mut();
    // NOTE(review): presumably asserts we are on the thread allowed to call into Postgres;
    // see `crate::thread_check`.
    crate::thread_check::check_active_thread();
    extern "C" {
        fn errcode(sqlerrcode: ::std::os::raw::c_int) -> ::std::os::raw::c_int;
        fn errmsg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errdetail(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errhint(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errcontext_msg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
    }

    /// Postgres 13+ flavor: `errstart(level, domain)` and
    /// `errfinish(filename, lineno, funcname)`.
    #[inline(always)]
    #[rustfmt::skip] #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))]
    fn do_ereport_impl(ereport: ErrorReportWithLevel) {
        extern "C" {
            fn errstart(elevel: ::std::os::raw::c_int, domain: *const ::std::os::raw::c_char) -> bool;
            fn errfinish(filename: *const ::std::os::raw::c_char, lineno: ::std::os::raw::c_int, funcname: *const ::std::os::raw::c_char);
        }
        let level = ereport.level();
        unsafe {
            // `errstart` returning false means the report is not emitted; nothing else happens.
            if errstart(level as _, DOMAIN) {
                let sqlerrcode = ereport.sql_error_code();
                // C-string copies, released with `pfree` below.
                // NOTE(review): presumably palloc'd in CurrentMemoryContext by `as_pg_cstr`.
                let message = ereport.message().as_pg_cstr();
                let detail = ereport.detail_with_backtrace().as_pg_cstr();
                let hint = ereport.hint().as_pg_cstr();
                let context = ereport.context_message().as_pg_cstr();
                let lineno = ereport.line_number();
                // File and function name are handed to `errfinish` and are not freed on the
                // ERROR path, so they are allocated in `ErrorContext` instead.
                let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
                let file = ereport.file().as_pg_cstr();
                let funcname = ereport.function_name().as_pg_cstr();
                MemoryContextSwitchTo(prev_cxt);
                // Drop Rust-owned data now: at ERROR and above `errfinish` does not return, so
                // destructors scheduled after it would never run.
                drop(ereport);
                errcode(sqlerrcode as _);
                // Each copy is freed right after use; this relies on the callee having
                // copied the formatted text.
                if !message.is_null() { errmsg(PERCENT_S.as_ptr(), message); pfree(message.cast()); }
                if !detail.is_null() { errdetail(PERCENT_S.as_ptr(), detail); pfree(detail.cast()); }
                if !hint.is_null() { errhint(PERCENT_S.as_ptr(), hint); pfree(hint.cast()); }
                if !context.is_null() { errcontext_msg(PERCENT_S.as_ptr(), context); pfree(context.cast()); }
                errfinish(file, lineno as _, funcname);
                if level >= PgLogLevel::ERROR {
                    // Postgres does not return from `errfinish` at ERROR and above.
                    unreachable_unchecked()
                } else {
                    // Below ERROR `errfinish` returned, so the ErrorContext copies are ours to free.
                    if !file.is_null() { pfree(file.cast()); }
                    if !funcname.is_null() { pfree(funcname.cast()); }
                }
            }
        }
    }

    /// Postgres 11/12 flavor: `errstart(level, filename, lineno, funcname, domain)` and a
    /// variadic `errfinish(dummy, ...)`.
    #[inline(always)]
    #[rustfmt::skip] #[cfg(any(feature = "pg11", feature = "pg12"))]
    fn do_ereport_impl(ereport: ErrorReportWithLevel) {
        extern "C" {
            fn errstart(elevel: ::std::os::raw::c_int, filename: *const ::std::os::raw::c_char, lineno: ::std::os::raw::c_int, funcname: *const ::std::os::raw::c_char, domain: *const ::std::os::raw::c_char) -> bool;
            fn errfinish(dummy: ::std::os::raw::c_int, ...);
        }
        unsafe {
            // This API takes file/function name at `errstart`, so they are allocated (in
            // `ErrorContext`) before the call.
            let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
            let file = ereport.file().as_pg_cstr();
            let lineno = ereport.line_number();
            let funcname = ereport.function_name().as_pg_cstr();
            MemoryContextSwitchTo(prev_cxt);
            let level = ereport.level();
            if errstart(level as _, file, lineno as _, funcname, DOMAIN) {
                let sqlerrcode = ereport.sql_error_code();
                let message = ereport.message().as_pg_cstr();
                let detail = ereport.detail_with_backtrace().as_pg_cstr();
                let hint = ereport.hint().as_pg_cstr();
                let context = ereport.context_message().as_pg_cstr();
                // Drop Rust-owned data before the call that may not return.
                drop(ereport);
                errcode(sqlerrcode as _);
                if !message.is_null() { errmsg(PERCENT_S.as_ptr(), message); pfree(message.cast()); }
                if !detail.is_null() { errdetail(PERCENT_S.as_ptr(), detail); pfree(detail.cast()); }
                if !hint.is_null() { errhint(PERCENT_S.as_ptr(), hint); pfree(hint.cast()); }
                if !context.is_null() { errcontext_msg(PERCENT_S.as_ptr(), context); pfree(context.cast()); }
                errfinish(0);
            }
            // Unlike the pg13+ path, this check sits outside the `errstart` branch.
            // NOTE(review): reaching it at ERROR would require `errstart` to decline an
            // ERROR-level report; this relies on Postgres always accepting those.
            if level >= PgLogLevel::ERROR {
                unreachable_unchecked()
            } else {
                if !file.is_null() { pfree(file.cast()); }
                if !funcname.is_null() { pfree(funcname.cast()); }
            }
        }
    }

    do_ereport_impl(ereport)
}