1#![deny(unsafe_op_in_unsafe_fn)]
11#![allow(non_snake_case)]
12
13use core::ffi::CStr;
14use std::any::Any;
15use std::cell::Cell;
16use std::fmt::{Display, Formatter};
17use std::hint::unreachable_unchecked;
18use std::panic::{
19 catch_unwind, panic_any, resume_unwind, AssertUnwindSafe, Location, PanicHookInfo, UnwindSafe,
20};
21
22use crate::elog::PgLogLevel;
23use crate::errcodes::PgSqlErrorCode;
24use crate::{pfree, AsPgCStr, MemoryContextSwitchTo};
25
/// Unwraps a value, reporting the failure case instead of handing it back to the caller.
///
/// Implementors decide how the failure is reported; the `Result` impl in this module raises it
/// at Postgres `ERROR` level.
pub trait ErrorReportable {
    /// The value produced on success.
    type Inner;

    /// Returns the success value, or reports the failure (in which case this does not return
    /// normally for the `Result` impl in this module).
    fn unwrap_or_report(self) -> Self::Inner;
}
33
34impl<T, E> ErrorReportable for Result<T, E>
35where
36 E: Any + Display,
37{
38 type Inner = T;
39
40 fn unwrap_or_report(self) -> Self::Inner {
46 self.unwrap_or_else(|e| {
47 let any: Box<&dyn Any> = Box::new(&e);
48 if any.downcast_ref::<ErrorReport>().is_some() {
49 let any: Box<dyn Any> = Box::new(e);
50 any.downcast::<ErrorReport>().unwrap().report(PgLogLevel::ERROR);
51 unreachable!();
52 } else {
53 ereport!(ERROR, PgSqlErrorCode::ERRCODE_DATA_EXCEPTION, &format!("{e}"));
54 }
55 })
56 }
57}
58
/// Where an error was raised: source file, optional function name, line/column, and an
/// optional backtrace.
#[derive(Debug)]
pub struct ErrorReportLocation {
    /// Source file path (`"<unknown>"` for the default location).
    pub(crate) file: String,
    /// Name of the function that raised the error, when known.
    pub(crate) funcname: Option<String>,
    /// Line number (`0` for the default location).
    pub(crate) line: u32,
    /// Column number (`0` for the default location).
    pub(crate) col: u32,
    /// Backtrace taken when the error was raised, if any. It is only rendered when its status is
    /// `Captured`.
    pub(crate) backtrace: Option<std::backtrace::Backtrace>,
}
67
68impl Default for ErrorReportLocation {
69 fn default() -> Self {
70 Self {
71 file: std::string::String::from("<unknown>"),
72 funcname: None,
73 line: 0,
74 col: 0,
75 backtrace: None,
76 }
77 }
78}
79
80impl Display for ErrorReportLocation {
81 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
82 match &self.funcname {
83 Some(funcname) => {
84 write!(f, "{}, {}:{}:{}", funcname, self.file, self.line, self.col)?;
86 }
87
88 None => {
89 write!(f, "{}:{}:{}", self.file, self.line, self.col)?;
90 }
91 }
92
93 if let Some(backtrace) = &self.backtrace {
94 if backtrace.status() == std::backtrace::BacktraceStatus::Captured {
95 write!(f, "\n{backtrace}")?;
96 }
97 }
98
99 Ok(())
100 }
101}
102
103impl From<&Location<'_>> for ErrorReportLocation {
104 fn from(location: &Location<'_>) -> Self {
105 Self {
106 file: location.file().to_string(),
107 funcname: None,
108 line: location.line(),
109 col: location.column(),
110 backtrace: None,
111 }
112 }
113}
114
115impl From<&PanicHookInfo<'_>> for ErrorReportLocation {
116 fn from(pi: &PanicHookInfo<'_>) -> Self {
117 pi.location().map(|l| l.into()).unwrap_or_default()
118 }
119}
120
/// An error destined for Postgres' error reporting. It holds the SQL error code, primary
/// message, optional hint and detail, and the source location it was raised from.
///
/// It carries no log level of its own. The level is chosen when it is reported
/// (`ErrorReport::report`).
#[derive(Debug)]
pub struct ErrorReport {
    /// SQL error code passed to `errcode`.
    pub(crate) sqlerrcode: PgSqlErrorCode,
    /// Primary error message.
    pub(crate) message: String,
    /// Optional `HINT:` text.
    pub(crate) hint: Option<String>,
    /// Optional `DETAIL:` text.
    pub(crate) detail: Option<String>,
    /// Where the error was raised.
    pub(crate) location: ErrorReportLocation,
}
131
132impl Display for ErrorReport {
133 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
134 write!(f, "{}: {}", self.sqlerrcode, self.message)?;
135 if let Some(hint) = &self.hint {
136 write!(f, "\nHINT: {hint}")?;
137 }
138 if let Some(detail) = &self.detail {
139 write!(f, "\nDETAIL: {detail}")?;
140 }
141 write!(f, "\nLOCATION: {}", self.location)
142 }
143}
144
/// An [`ErrorReport`] paired with the Postgres log level it is to be raised at.
#[derive(Debug)]
pub struct ErrorReportWithLevel {
    /// Level passed to `errstart` (and used to decide whether raising it unwinds or returns).
    pub(crate) level: PgLogLevel,
    /// The report itself.
    pub(crate) inner: ErrorReport,
}
150
151impl ErrorReportWithLevel {
152 fn report(self) {
153 match self.level {
154 PgLogLevel::ERROR => panic_any(self),
156
157 PgLogLevel::FATAL | PgLogLevel::PANIC => {
159 do_ereport(self);
160 unreachable!()
161 }
162
163 _ => do_ereport(self),
165 }
166 }
167
168 pub fn level(&self) -> PgLogLevel {
170 self.level
171 }
172
173 pub fn sql_error_code(&self) -> PgSqlErrorCode {
175 self.inner.sqlerrcode
176 }
177
178 pub fn message(&self) -> &str {
180 self.inner.message()
181 }
182
183 pub fn detail(&self) -> Option<&str> {
185 self.inner.detail()
186 }
187
188 pub fn detail_with_backtrace(&self) -> Option<String> {
190 match (self.detail(), self.backtrace()) {
191 (Some(detail), Some(bt))
192 if bt.status() == std::backtrace::BacktraceStatus::Captured =>
193 {
194 Some(format!("{detail}\n{bt}"))
195 }
196 (Some(d), _) => Some(d.to_string()),
197 (None, Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
198 Some(format!("\n{bt}"))
199 }
200 (None, _) => None,
201 }
202 }
203
204 pub fn hint(&self) -> Option<&str> {
206 self.inner.hint()
207 }
208
209 pub fn file(&self) -> &str {
211 &self.inner.location.file
212 }
213
214 pub fn line_number(&self) -> u32 {
216 self.inner.location.line
217 }
218
219 pub fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
221 self.inner.location.backtrace.as_ref()
222 }
223
224 pub fn function_name(&self) -> Option<&str> {
226 self.inner.location.funcname.as_deref()
227 }
228
229 fn context_message(&self) -> Option<String> {
231 None
233 }
234}
235
236impl ErrorReport {
237 #[track_caller]
242 pub fn new<S: Into<String>>(
243 sqlerrcode: PgSqlErrorCode,
244 message: S,
245 funcname: &'static str,
246 ) -> Self {
247 let mut location: ErrorReportLocation = Location::caller().into();
248 location.funcname = Some(funcname.to_string());
249
250 Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
251 }
252
253 fn with_location<S: Into<String>>(
258 sqlerrcode: PgSqlErrorCode,
259 message: S,
260 location: ErrorReportLocation,
261 ) -> Self {
262 Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
263 }
264
265 pub fn set_detail<S: Into<String>>(mut self, detail: S) -> Self {
267 self.detail = Some(detail.into());
268 self
269 }
270
271 pub fn set_hint<S: Into<String>>(mut self, hint: S) -> Self {
273 self.hint = Some(hint.into());
274 self
275 }
276
277 pub fn message(&self) -> &str {
279 &self.message
280 }
281
282 pub fn detail(&self) -> Option<&str> {
284 self.detail.as_deref()
285 }
286
287 pub fn hint(&self) -> Option<&str> {
289 self.hint.as_deref()
290 }
291
292 pub fn report(self, level: PgLogLevel) {
296 ErrorReportWithLevel { level, inner: self }.report()
297 }
298}
299
300thread_local! { static PANIC_LOCATION: Cell<Option<ErrorReportLocation>> = const { Cell::new(None) }}
301
302fn take_panic_location() -> ErrorReportLocation {
303 PANIC_LOCATION.with(|p| p.take().unwrap_or_default())
304}
305
306pub fn register_pg_guard_panic_hook() {
307 use super::thread_check::is_os_main_thread;
308
309 let default_hook = std::panic::take_hook();
310 std::panic::set_hook(Box::new(move |info: _| {
311 if is_os_main_thread() == Some(true) {
312 PANIC_LOCATION.with(|thread_local| {
314 thread_local.replace({
315 let mut info: ErrorReportLocation = info.into();
316 info.backtrace = Some(std::backtrace::Backtrace::capture());
317 Some(info)
318 })
319 });
320 } else {
321 default_hook(info)
323 }
324 }))
325}
326
/// The classified payload of a caught panic, as produced by `downcast_panic_payload`.
#[derive(Debug)]
pub enum CaughtError {
    /// Presumably an error raised by Postgres itself (it is constructed outside this view).
    /// `pgrx_extern_c_guard` re-throws these with `pg_re_throw` instead of reporting them again.
    PostgresError(ErrorReportWithLevel),

    /// A panic whose payload was an `ErrorReportWithLevel` or a bare `ErrorReport`, i.e. an
    /// error deliberately raised from Rust code.
    ErrorReport(ErrorReportWithLevel),

    /// Any other Rust panic. When built by `downcast_panic_payload`, `ereport` is an `ERROR` with
    /// `ERRCODE_INTERNAL_ERROR` and the panic message, and `payload` is the original panic payload.
    RustPanic { ereport: ErrorReportWithLevel, payload: Box<dyn Any + Send> },
}
339
340impl CaughtError {
341 pub fn rethrow(self) -> ! {
345 resume_unwind(Box::new(self))
348 }
349}
350
/// What `pgrx_extern_c_guard` should do once the guarded closure has finished.
#[derive(Debug)]
enum GuardAction<R> {
    /// The closure returned normally; hand its value back.
    Return(R),
    /// A `CaughtError::PostgresError` was caught; re-throw it with `pg_re_throw`.
    ReThrow,
    /// Any other caught error or panic; raise it with `do_ereport`.
    Report(ErrorReportWithLevel),
}
357
/// Runs `f` and translates any panic that escapes it into Postgres' error handling.
///
/// - If `f` returns normally, its value is returned.
/// - If it panicked with a `CaughtError::PostgresError`, `CurrentMemoryContext` is set to
///   `ErrorContext` and the error is re-thrown with `pg_re_throw`.
/// - Otherwise the panic's error report is raised with `do_ereport`.
///
/// `f` is wrapped in `AssertUnwindSafe`, so it is not required to be `UnwindSafe`.
///
/// # Safety
///
/// NOTE(review): the original safety contract is not visible here. This writes Postgres'
/// `CurrentMemoryContext` global and calls into Postgres (`pg_re_throw`, `errstart`/`errfinish`),
/// which presumably requires running on the backend thread inside a live Postgres error-handling
/// context. Confirm against the callers.
#[doc(hidden)]
pub unsafe fn pgrx_extern_c_guard<Func, R>(f: Func) -> R
where
    Func: FnOnce() -> R,
{
    match unsafe { run_guarded(AssertUnwindSafe(f)) } {
        GuardAction::Return(r) => r,
        GuardAction::ReThrow => {
            #[cfg_attr(target_os = "windows", link(name = "postgres"))]
            unsafe extern "C-unwind" {
                // Declared `-> !`: control never comes back from here.
                fn pg_re_throw() -> !;
            }
            unsafe {
                // Point the current memory context at ErrorContext before handing control back
                // to Postgres' error handling.
                crate::CurrentMemoryContext = crate::ErrorContext;
                pg_re_throw()
            }
        }
        GuardAction::Report(ereport) => {
            // do_ereport does not return for ERROR and above. NOTE(review): a payload that was an
            // ErrorReportWithLevel below ERROR would reach this unreachable!.
            do_ereport(ereport);
            unreachable!("pgrx reported a CaughtError that wasn't raised at ERROR or above");
        }
    }
}
412
413#[inline(never)]
415unsafe fn run_guarded<F, R>(f: F) -> GuardAction<R>
416where
417 F: FnOnce() -> R + UnwindSafe,
418{
419 match catch_unwind(f) {
420 Ok(v) => GuardAction::Return(v),
421 Err(e) => match downcast_panic_payload(e) {
422 CaughtError::PostgresError(_) => {
423 GuardAction::ReThrow
426 }
427 CaughtError::ErrorReport(ereport) | CaughtError::RustPanic { ereport, .. } => {
428 GuardAction::Report(ereport)
429 }
430 },
431 }
432}
433
434pub(crate) fn downcast_panic_payload(e: Box<dyn Any + Send>) -> CaughtError {
436 if e.downcast_ref::<CaughtError>().is_some() {
437 *e.downcast::<CaughtError>().unwrap()
439 } else if e.downcast_ref::<ErrorReportWithLevel>().is_some() {
440 CaughtError::ErrorReport(*e.downcast().unwrap())
442 } else if e.downcast_ref::<ErrorReport>().is_some() {
443 CaughtError::ErrorReport(ErrorReportWithLevel {
445 level: PgLogLevel::ERROR,
446 inner: *e.downcast().unwrap(),
447 })
448 } else if let Some(message) = e.downcast_ref::<&str>() {
449 CaughtError::RustPanic {
451 ereport: ErrorReportWithLevel {
452 level: PgLogLevel::ERROR,
453 inner: ErrorReport::with_location(
454 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
455 *message,
456 take_panic_location(),
457 ),
458 },
459 payload: e,
460 }
461 } else if let Some(message) = e.downcast_ref::<String>() {
462 CaughtError::RustPanic {
464 ereport: ErrorReportWithLevel {
465 level: PgLogLevel::ERROR,
466 inner: ErrorReport::with_location(
467 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
468 message,
469 take_panic_location(),
470 ),
471 },
472 payload: e,
473 }
474 } else {
475 CaughtError::RustPanic {
477 ereport: ErrorReportWithLevel {
478 level: PgLogLevel::ERROR,
479 inner: ErrorReport::with_location(
480 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
481 "Box<Any>",
482 take_panic_location(),
483 ),
484 },
485 payload: e,
486 }
487 }
488}
489
/// Hands `ereport` to Postgres through the `errstart` / `errcode` / `errmsg` / ... / `errfinish`
/// sequence.
///
/// If `errstart` declines the level, nothing is emitted and `ereport` is simply dropped.
/// Otherwise the sequence below runs:
/// - The message, the detail (with any captured backtrace appended), the hint and the context are
///   copied into C strings and passed with a `"%s"` format, so `%` characters in the text are
///   never interpreted as format directives.
/// - The file and function-name strings are allocated in `ErrorContext` and handed to
///   `errfinish`.
/// - For levels at or above `ERROR`, `errfinish` is assumed not to return. The code after it is
///   `unreachable_unchecked`, so returning would be undefined behavior.
/// - For lower levels, the file/function strings are freed after `errfinish` returns.
fn do_ereport(ereport: ErrorReportWithLevel) {
    // Format string used for every message part; the text itself is passed as the argument.
    const PERCENT_S: &CStr = c"%s";
    // Null message domain passed to errstart.
    const DOMAIN: *const ::std::os::raw::c_char = std::ptr::null_mut();

    // NOTE(review): presumably verifies we are on the thread Postgres expects — confirm in
    // thread_check.
    crate::thread_check::check_active_thread();

    #[cfg_attr(target_os = "windows", link(name = "postgres"))]
    unsafe extern "C-unwind" {
        fn errcode(sqlerrcode: ::std::os::raw::c_int) -> ::std::os::raw::c_int;
        fn errmsg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errdetail(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errhint(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errcontext_msg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
    }

    #[cfg_attr(target_os = "windows", link(name = "postgres"))]
    unsafe extern "C-unwind" {
        fn errstart(elevel: ::std::os::raw::c_int, domain: *const ::std::os::raw::c_char) -> bool;
        fn errfinish(
            filename: *const ::std::os::raw::c_char,
            lineno: ::std::os::raw::c_int,
            funcname: *const ::std::os::raw::c_char,
        );
    }

    let level = ereport.level();
    unsafe {
        // errstart returns false when this level should not be emitted; then we skip everything.
        if errstart(level as _, DOMAIN) {
            // `None` parts become null pointers (hence the null checks below); the non-null ones
            // are released with pfree once Postgres has formatted them.
            let sqlerrcode = ereport.sql_error_code();
            let message = ereport.message().as_pg_cstr();
            let detail = ereport.detail_with_backtrace().as_pg_cstr();
            let hint = ereport.hint().as_pg_cstr();
            let context = ereport.context_message().as_pg_cstr();
            let lineno = ereport.line_number();

            // file/funcname are passed to errfinish as raw pointers. They are allocated in
            // ErrorContext, presumably so they outlive the current context when the error does not
            // return — confirm.
            let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
            let file = ereport.file().as_pg_cstr();
            let funcname = ereport.function_name().as_pg_cstr();
            MemoryContextSwitchTo(prev_cxt);

            // Drop the Rust-owned report now: for ERROR and above errfinish does not return, so
            // its destructor would otherwise never run.
            drop(ereport);

            errcode(sqlerrcode as _);
            // Each part is formatted through "%s" into Postgres' own string, so our copy can be
            // freed immediately afterwards.
            if !message.is_null() {
                errmsg(PERCENT_S.as_ptr(), message);
                pfree(message.cast());
            }
            if !detail.is_null() {
                errdetail(PERCENT_S.as_ptr(), detail);
                pfree(detail.cast());
            }
            if !hint.is_null() {
                errhint(PERCENT_S.as_ptr(), hint);
                pfree(hint.cast());
            }
            if !context.is_null() {
                errcontext_msg(PERCENT_S.as_ptr(), context);
                pfree(context.cast());
            }

            errfinish(file, lineno as _, funcname);

            if level >= PgLogLevel::ERROR {
                // errfinish is relied upon not to return for these levels.
                unreachable_unchecked()
            } else {
                // errfinish returned, so the strings allocated in ErrorContext above are ours to
                // release.
                if !file.is_null() {
                    pfree(file.cast());
                }
                if !funcname.is_null() {
                    pfree(funcname.cast());
                }
            }
        }
    }
}