1#![deny(unsafe_op_in_unsafe_fn)]
10#![allow(non_snake_case)]
11
12use core::ffi::CStr;
13use std::any::Any;
14use std::cell::Cell;
15use std::fmt::{Display, Formatter};
16use std::hint::unreachable_unchecked;
17use std::panic::{
18 catch_unwind, panic_any, resume_unwind, Location, PanicInfo, RefUnwindSafe, UnwindSafe,
19};
20
21use crate::elog::PgLogLevel;
22use crate::errcodes::PgSqlErrorCode;
23use crate::{pfree, AsPgCStr, MemoryContextSwitchTo};
24
/// Unwraps a fallible value, turning its failure case into a Postgres error report.
///
/// This file implements it for `Result<T, E>`.
pub trait ErrorReportable {
    /// The type handed back when there is nothing to report.
    type Inner;

    /// Consumes `self`, returning the inner value or reporting the failure.
    fn report(self) -> Self::Inner;
}
32
33impl<T, E> ErrorReportable for Result<T, E>
34where
35 E: Any + Display,
36{
37 type Inner = T;
38
39 fn report(self) -> Self::Inner {
45 match self {
46 Ok(value) => value,
47 Err(e) => {
48 let any: Box<&dyn Any> = Box::new(&e);
49 if any.downcast_ref::<ErrorReport>().is_some() {
50 let any: Box<dyn Any> = Box::new(e);
51 any.downcast::<ErrorReport>().unwrap().report(PgLogLevel::ERROR);
52 unreachable!();
53 } else {
54 ereport!(ERROR, PgSqlErrorCode::ERRCODE_DATA_EXCEPTION, &format!("{}", e));
55 }
56 }
57 }
58 }
59}
60
/// Where an error report was raised: source position, optional function name, and an
/// optional Rust backtrace.
#[derive(Debug)]
pub struct ErrorReportLocation {
    /// Source file path; `"<unknown>"` in the `Default` value.
    pub(crate) file: String,
    /// Name of the reporting function, when one was supplied (e.g. by `ErrorReport::new`).
    pub(crate) funcname: Option<String>,
    /// Line number; 0 in the `Default` value.
    pub(crate) line: u32,
    /// Column number; 0 in the `Default` value.
    pub(crate) col: u32,
    /// Backtrace set by the panic hook in `register_pg_guard_panic_hook`. `Display` only
    /// prints it when its status is `Captured`.
    pub(crate) backtrace: Option<std::backtrace::Backtrace>,
}
69
70impl Default for ErrorReportLocation {
71 fn default() -> Self {
72 Self {
73 file: std::string::String::from("<unknown>"),
74 funcname: None,
75 line: 0,
76 col: 0,
77 backtrace: None,
78 }
79 }
80}
81
82impl Display for ErrorReportLocation {
83 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
84 match &self.funcname {
85 Some(funcname) => {
86 write!(f, "{}, {}:{}:{}", funcname, self.file, self.line, self.col)?;
88 }
89
90 None => {
91 write!(f, "{}:{}:{}", self.file, self.line, self.col)?;
92 }
93 }
94
95 if let Some(backtrace) = &self.backtrace {
96 if backtrace.status() == std::backtrace::BacktraceStatus::Captured {
97 write!(f, "\n{}", backtrace)?;
98 }
99 }
100
101 Ok(())
102 }
103}
104
105impl From<&Location<'_>> for ErrorReportLocation {
106 fn from(location: &Location<'_>) -> Self {
107 Self {
108 file: location.file().to_string(),
109 funcname: None,
110 line: location.line(),
111 col: location.column(),
112 backtrace: None,
113 }
114 }
115}
116
117impl From<&PanicInfo<'_>> for ErrorReportLocation {
118 fn from(pi: &PanicInfo<'_>) -> Self {
119 pi.location().map(|l| l.into()).unwrap_or_default()
120 }
121}
122
/// An error report assembled on the Rust side before being handed to Postgres.
///
/// It carries the SQL error code, primary message, optional hint and detail lines, and
/// where it was raised. Pair it with a log level through `ErrorReport::report`.
#[derive(Debug)]
pub struct ErrorReport {
    /// SQL error code reported via `errcode`.
    pub(crate) sqlerrcode: PgSqlErrorCode,
    /// Primary message, reported via `errmsg`.
    pub(crate) message: String,
    /// Optional hint line, reported via `errhint`.
    pub(crate) hint: Option<String>,
    /// Optional detail line, reported via `errdetail`; a captured backtrace may be appended.
    pub(crate) detail: Option<String>,
    /// Source location (and possibly backtrace) the report is attributed to.
    pub(crate) location: ErrorReportLocation,
}
133
134impl Display for ErrorReport {
135 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
136 write!(f, "{}: {}", self.sqlerrcode, self.message)?;
137 if let Some(hint) = &self.hint {
138 write!(f, "\nHINT: {}", hint)?;
139 }
140 if let Some(detail) = &self.detail {
141 write!(f, "\nDETAIL: {}", detail)?;
142 }
143 write!(f, "\nLOCATION: {}", self.location)
144 }
145}
146
/// An `ErrorReport` paired with the Postgres log level it is to be raised at.
///
/// This is also the panic payload used when a report at `ERROR` or above is raised.
#[derive(Debug)]
pub struct ErrorReportWithLevel {
    /// Severity the report is raised at.
    pub(crate) level: PgLogLevel,
    /// The report itself.
    pub(crate) inner: ErrorReport,
}
152
153impl ErrorReportWithLevel {
154 fn report(self) {
155 if crate::ERROR <= self.level as _ {
160 panic_any(self)
161 } else {
162 do_ereport(self)
163 }
164 }
165
166 pub fn level(&self) -> PgLogLevel {
168 self.level
169 }
170
171 pub fn sql_error_code(&self) -> PgSqlErrorCode {
173 self.inner.sqlerrcode
174 }
175
176 pub fn message(&self) -> &str {
178 self.inner.message()
179 }
180
181 pub fn detail(&self) -> Option<&str> {
183 self.inner.detail()
184 }
185
186 pub fn detail_with_backtrace(&self) -> Option<String> {
188 match (self.detail(), self.backtrace()) {
189 (Some(d), Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
190 Some(format!("{}\n{}", d, bt))
191 }
192 (Some(d), _) => Some(d.to_string()),
193 (None, Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
194 Some(format!("\n{}", bt))
195 }
196 (None, _) => None,
197 }
198 }
199
200 pub fn hint(&self) -> Option<&str> {
202 self.inner.hint()
203 }
204
205 pub fn file(&self) -> &str {
207 &self.inner.location.file
208 }
209
210 pub fn line_number(&self) -> u32 {
212 self.inner.location.line
213 }
214
215 pub fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
217 self.inner.location.backtrace.as_ref()
218 }
219
220 pub fn function_name(&self) -> Option<&str> {
222 self.inner.location.funcname.as_ref().map(|s| s.as_str())
223 }
224
225 fn context_message(&self) -> Option<String> {
227 None
229 }
230}
231
232impl ErrorReport {
233 #[track_caller]
238 pub fn new<S: Into<String>>(
239 sqlerrcode: PgSqlErrorCode,
240 message: S,
241 funcname: &'static str,
242 ) -> Self {
243 let mut location: ErrorReportLocation = Location::caller().into();
244 location.funcname = Some(funcname.to_string());
245
246 Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
247 }
248
249 fn with_location<S: Into<String>>(
254 sqlerrcode: PgSqlErrorCode,
255 message: S,
256 location: ErrorReportLocation,
257 ) -> Self {
258 Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
259 }
260
261 pub fn set_detail<S: Into<String>>(mut self, detail: S) -> Self {
263 self.detail = Some(detail.into());
264 self
265 }
266
267 pub fn set_hint<S: Into<String>>(mut self, hint: S) -> Self {
269 self.hint = Some(hint.into());
270 self
271 }
272
273 pub fn message(&self) -> &str {
275 &self.message
276 }
277
278 pub fn detail(&self) -> Option<&str> {
280 self.detail.as_ref().map(|s| s.as_str())
281 }
282
283 pub fn hint(&self) -> Option<&str> {
285 self.hint.as_ref().map(|s| s.as_str())
286 }
287
288 pub fn report(self, level: PgLogLevel) {
292 ErrorReportWithLevel { level, inner: self }.report()
293 }
294}
295
296thread_local! { static PANIC_LOCATION: Cell<Option<ErrorReportLocation>> = const { Cell::new(None) }}
297
298fn take_panic_location() -> ErrorReportLocation {
299 PANIC_LOCATION.with(|p| p.take().unwrap_or_default())
300}
301
302pub fn register_pg_guard_panic_hook() {
303 std::panic::set_hook(Box::new(|info| {
304 PANIC_LOCATION.with(|thread_local| {
305 thread_local.replace({
306 let mut info: ErrorReportLocation = info.into();
307 info.backtrace = Some(std::backtrace::Backtrace::capture());
308 Some(info)
309 })
310 });
311 }))
312}
313
/// A caught panic, classified by `downcast_panic_payload`.
#[derive(Debug)]
pub enum CaughtError {
    /// A Postgres error. `run_guarded` hands these back to Postgres via `pg_re_throw`.
    /// NOTE(review): presumably constructed elsewhere when a Postgres `ERROR` is turned into
    /// a Rust panic; confirm at the construction site.
    PostgresError(ErrorReportWithLevel),

    /// A panic whose payload was an `ErrorReportWithLevel`, or a bare `ErrorReport`
    /// (which is given `ERROR` level).
    ErrorReport(ErrorReportWithLevel),

    /// Any other panic. `ereport` is an `ERRCODE_INTERNAL_ERROR` report at `ERROR` level built
    /// from the panic message; `payload` is the original panic payload.
    RustPanic { ereport: ErrorReportWithLevel, payload: Box<dyn Any + Send> },
}
326
327impl CaughtError {
328 pub fn rethrow(self) -> ! {
332 resume_unwind(Box::new(self))
335 }
336}
337
/// What `pgx_extern_c_guard` must do once `run_guarded` has run the guarded closure.
#[derive(Debug)]
enum GuardAction<R> {
    /// The closure returned normally with this value.
    Return(R),
    /// A `CaughtError::PostgresError` was caught; the pending Postgres error is re-thrown.
    ReThrow,
    /// Any other caught panic; the contained report is raised through `do_ereport`.
    Report(ErrorReportWithLevel),
}
344
/// Runs `f`, translating any panic it raises into Postgres error handling.
///
/// - A normal return from `f` is passed through.
/// - A caught `CaughtError::PostgresError` is re-thrown into Postgres with `pg_re_throw`,
///   after `CurrentMemoryContext` is set to `ErrorContext`.
/// - Any other panic's report is raised with `do_ereport`. The `unreachable!` afterwards
///   encodes the expectation that such reports are at `ERROR` or above, so `do_ereport`
///   does not return.
///
/// # Safety
///
/// NOTE(review): the original contract is not visible here. At minimum, both error paths
/// transfer control into Postgres' error machinery, so this must only be called where
/// that is valid (on the backend thread, within Postgres' error-handling setup). Confirm
/// against the callers that generate `extern "C"` wrappers.
#[doc(hidden)]
pub unsafe fn pgx_extern_c_guard<Func, R: Copy>(f: Func) -> R
where
    Func: FnOnce() -> R + UnwindSafe + RefUnwindSafe,
{
    match run_guarded(f) {
        GuardAction::Return(r) => r,
        GuardAction::ReThrow => {
            extern "C" {
                fn pg_re_throw() -> !;
            }
            // Order matters: the memory context is switched before control is handed back
            // to Postgres, and `pg_re_throw` never returns.
            unsafe {
                crate::CurrentMemoryContext = crate::ErrorContext;
                pg_re_throw()
            }
        }
        GuardAction::Report(ereport) => {
            do_ereport(ereport);
            unreachable!("pgx reported a CaughtError that wasn't raised at ERROR or above");
        }
    }
}
397
398#[inline(never)]
399fn run_guarded<F, R: Copy>(f: F) -> GuardAction<R>
400where
401 F: FnOnce() -> R + UnwindSafe + RefUnwindSafe,
402{
403 match catch_unwind(f) {
404 Ok(v) => GuardAction::Return(v),
405 Err(e) => match downcast_panic_payload(e) {
406 CaughtError::PostgresError(_) => {
407 GuardAction::ReThrow
410 }
411 CaughtError::ErrorReport(ereport) | CaughtError::RustPanic { ereport, .. } => {
412 GuardAction::Report(ereport)
413 }
414 },
415 }
416}
417
418pub(crate) fn downcast_panic_payload(e: Box<dyn Any + Send>) -> CaughtError {
420 if e.downcast_ref::<CaughtError>().is_some() {
421 *e.downcast::<CaughtError>().unwrap()
423 } else if e.downcast_ref::<ErrorReportWithLevel>().is_some() {
424 CaughtError::ErrorReport(*e.downcast().unwrap())
426 } else if e.downcast_ref::<ErrorReport>().is_some() {
427 CaughtError::ErrorReport(ErrorReportWithLevel {
429 level: PgLogLevel::ERROR,
430 inner: *e.downcast().unwrap(),
431 })
432 } else if let Some(message) = e.downcast_ref::<&str>() {
433 CaughtError::RustPanic {
435 ereport: ErrorReportWithLevel {
436 level: PgLogLevel::ERROR,
437 inner: ErrorReport::with_location(
438 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
439 *message,
440 take_panic_location(),
441 ),
442 },
443 payload: e,
444 }
445 } else if let Some(message) = e.downcast_ref::<String>() {
446 CaughtError::RustPanic {
448 ereport: ErrorReportWithLevel {
449 level: PgLogLevel::ERROR,
450 inner: ErrorReport::with_location(
451 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
452 message,
453 take_panic_location(),
454 ),
455 },
456 payload: e,
457 }
458 } else {
459 CaughtError::RustPanic {
461 ereport: ErrorReportWithLevel {
462 level: PgLogLevel::ERROR,
463 inner: ErrorReport::with_location(
464 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
465 "Box<Any>",
466 take_panic_location(),
467 ),
468 },
469 payload: e,
470 }
471 }
472}
473
/// Hands `ereport` to Postgres' ereport API (`errstart`, `errcode`, `errmsg`, `errdetail`,
/// `errhint`, `errcontext_msg`, `errfinish`).
///
/// The FFI signatures are chosen per Postgres major version by the `pg11`..`pg15` cargo
/// features. Below `ERROR` this returns once Postgres has processed the report. At `ERROR`
/// and above, the code assumes control never comes back from Postgres (that path is marked
/// `unreachable_unchecked`), so this function does not return.
fn do_ereport(ereport: ErrorReportWithLevel) {
    // Every text field is passed as the argument to a literal "%s" format. This prevents
    // `%` characters in Rust-provided text from being read as format directives.
    // SAFETY: the byte literal ends in exactly one NUL and contains no interior NUL.
    const PERCENT_S: &CStr = unsafe { CStr::from_bytes_with_nul_unchecked(b"%s\0") };
    // Message domain passed to `errstart`. It is null here. NOTE(review): presumably means
    // "use the default domain"; confirm against Postgres' elog.c.
    const DOMAIN: *const ::std::os::raw::c_char = std::ptr::null_mut();

    // NOTE(review): presumably guards against calling into the Postgres backend from a
    // thread other than the active one; confirm in `crate::thread_check`.
    crate::thread_check::check_active_thread();

    // Parts of the ereport API whose signatures are shared by both implementations below.
    extern "C" {
        fn errcode(sqlerrcode: ::std::os::raw::c_int) -> ::std::os::raw::c_int;
        fn errmsg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errdetail(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errhint(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errcontext_msg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
    }

    // Postgres 13-15: `errstart` takes only level and domain. The source location is
    // passed to `errfinish`.
    #[inline(always)]
    #[rustfmt::skip] #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))]
    fn do_ereport_impl(ereport: ErrorReportWithLevel) {

        extern "C" {
            fn errstart(elevel: ::std::os::raw::c_int, domain: *const ::std::os::raw::c_char) -> bool;
            fn errfinish(filename: *const ::std::os::raw::c_char, lineno: ::std::os::raw::c_int, funcname: *const ::std::os::raw::c_char);
        }

        let level = ereport.level();
        unsafe {
            // When `errstart` returns false nothing below runs, so nothing is allocated or freed.
            if errstart(level as _, DOMAIN) {

                // Text fields as Postgres C strings (null for `None`, judging by the null
                // checks below). Each one is pfree'd right after its err* call.
                let sqlerrcode = ereport.sql_error_code();
                let message = ereport.message().as_pg_cstr();
                let detail = ereport.detail_with_backtrace().as_pg_cstr();
                let hint = ereport.hint().as_pg_cstr();
                let context = ereport.context_message().as_pg_cstr();
                let lineno = ereport.line_number();

                // File and function name are built in `ErrorContext` rather than the current
                // context and are only freed after `errfinish`. NOTE(review): presumably
                // because Postgres keeps these pointers until the report is finished;
                // confirm against elog.c.
                let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
                let file = ereport.file().as_pg_cstr();
                let funcname = ereport.function_name().as_pg_cstr();
                MemoryContextSwitchTo(prev_cxt);

                // Free the Rust-owned report now. At ERROR and above nothing after
                // `errfinish` runs, so its destructor would otherwise be skipped.
                drop(ereport);

                errcode(sqlerrcode as _);
                if !message.is_null() { errmsg(PERCENT_S.as_ptr(), message); pfree(message.cast()); }
                if !detail.is_null() { errdetail(PERCENT_S.as_ptr(), detail); pfree(detail.cast()); }
                if !hint.is_null() { errhint(PERCENT_S.as_ptr(), hint); pfree(hint.cast()); }
                if !context.is_null() { errcontext_msg(PERCENT_S.as_ptr(), context); pfree(context.cast()); }

                errfinish(file, lineno as _, funcname);

                if level >= PgLogLevel::ERROR {
                    // Assumes Postgres does not return from `errfinish` at ERROR and above.
                    // Reaching this point would be undefined behavior.
                    unreachable_unchecked()
                } else {
                    // Below ERROR `errfinish` returned, so the location strings can be freed.
                    if !file.is_null() { pfree(file.cast()); }
                    if !funcname.is_null() { pfree(funcname.cast()); }
                }
            }
        }
    }

    // Postgres 11-12: the source location is passed to `errstart` up front. `errfinish`
    // takes a dummy argument.
    #[inline(always)]
    #[rustfmt::skip] #[cfg(any(feature = "pg11", feature = "pg12"))]
    fn do_ereport_impl(ereport: ErrorReportWithLevel) {

        extern "C" {
            fn errstart(elevel: ::std::os::raw::c_int, filename: *const ::std::os::raw::c_char, lineno: ::std::os::raw::c_int, funcname: *const ::std::os::raw::c_char, domain: *const ::std::os::raw::c_char) -> bool;
            fn errfinish(dummy: ::std::os::raw::c_int, ...);
        }

        unsafe {
            // `errstart` itself needs the location here, so file and function name are built
            // (in `ErrorContext`) before it is called, whether or not the message is emitted.
            let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
            let file = ereport.file().as_pg_cstr();
            let lineno = ereport.line_number();
            let funcname = ereport.function_name().as_pg_cstr();
            MemoryContextSwitchTo(prev_cxt);

            let level = ereport.level();
            if errstart(level as _, file, lineno as _, funcname, DOMAIN) {

                let sqlerrcode = ereport.sql_error_code();
                let message = ereport.message().as_pg_cstr();
                let detail = ereport.detail_with_backtrace().as_pg_cstr();
                let hint = ereport.hint().as_pg_cstr();
                let context = ereport.context_message().as_pg_cstr();


                // Free the Rust-owned report before `errfinish`, which does not return at
                // ERROR and above.
                drop(ereport);

                errcode(sqlerrcode as _);
                if !message.is_null() { errmsg(PERCENT_S.as_ptr(), message); pfree(message.cast()); }
                if !detail.is_null() { errdetail(PERCENT_S.as_ptr(), detail); pfree(detail.cast()); }
                if !hint.is_null() { errhint(PERCENT_S.as_ptr(), hint); pfree(hint.cast()); }
                if !context.is_null() { errcontext_msg(PERCENT_S.as_ptr(), context); pfree(context.cast()); }

                errfinish(0);
            }

            // At ERROR and above this assumes control never reaches here: `errfinish` did not
            // return, and (NOTE(review)) `errstart` is assumed never to return false at such
            // levels. Otherwise this would be undefined behavior.
            if level >= PgLogLevel::ERROR {
                unreachable_unchecked()
            } else {
                if !file.is_null() { pfree(file.cast()); }
                if !funcname.is_null() { pfree(funcname.cast()); }
            }
        }
    }

    do_ereport_impl(ereport)
}