use std::{
cell::UnsafeCell,
sync::{Mutex, Once},
thread::ThreadId,
};
use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
ffi_ptr_ext::FfiPtrExt,
types::{PyAnyMethods, PyTraceback, PyType},
Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
};
/// Internal state of a `PyErr`: either a lazily-constructed exception or a normalized one.
///
/// Normalization runs at most once (guarded by `normalized`). After it completes, `inner` is
/// not written again, which is what allows `&PyErrStateNormalized` to be handed out from `&self`.
pub(crate) struct PyErrState {
    // Completes once `inner` holds the normalized state; shared references into `inner` are
    // only handed out after this has completed.
    normalized: Once,
    // The thread that is (or last was) running the normalization closure. Used to report a
    // re-entrant normalization from that same thread instead of re-entering `Once::call_once`.
    normalizing_thread: Mutex<Option<ThreadId>>,
    // `None` only while normalization is in progress (the state is taken out, normalized, and
    // written back), or if normalization panicked part-way through.
    inner: UnsafeCell<Option<PyErrStateInner>>,
}
// SAFETY: `inner` is only written inside `normalized.call_once`, and shared references to it
// are only created after `normalized` has completed; `Once` and `Mutex` are Send + Sync.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
// NOTE(review): assumed sound for the same reasons as Send/Sync above — confirm.
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}
impl PyErrState {
pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
Self::from_inner(PyErrStateInner::Lazy(f))
}
pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}
})))
}
pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
Self::from_inner(PyErrStateInner::Normalized(normalized))
}
pub(crate) fn restore(self, py: Python<'_>) {
self.inner
.into_inner()
.expect("PyErr state should never be invalid outside of normalization")
.restore(py)
}
fn from_inner(inner: PyErrStateInner) -> Self {
Self {
normalized: Once::new(),
normalizing_thread: Mutex::new(None),
inner: UnsafeCell::new(Some(inner)),
}
}
#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if self.normalized.is_completed() {
match unsafe {
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}
self.make_normalized(py)
}
#[cold]
fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
assert!(
!(*thread == std::thread::current().id()),
"Re-entrant normalization of PyErrState detected"
);
}
py.allow_threads(|| {
self.normalized.call_once(|| {
self.normalizing_thread
.lock()
.unwrap()
.replace(std::thread::current().id());
let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};
let normalized_state =
Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));
unsafe {
*self.inner.get() = Some(normalized_state);
}
})
});
match unsafe {
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
}
}
/// A fully materialized exception.
///
/// On Python 3.12+ only the exception value is stored; the type and traceback are read from
/// the value when requested. On older versions all three are stored as captured.
pub(crate) struct PyErrStateNormalized {
    // Exception type (pre-3.12 only).
    #[cfg(not(Py_3_12))]
    ptype: Py<PyType>,
    // The exception instance itself.
    pub pvalue: Py<PyBaseException>,
    // Traceback, if any (pre-3.12 only).
    #[cfg(not(Py_3_12))]
    ptraceback: Option<Py<PyTraceback>>,
}
impl PyErrStateNormalized {
    /// Builds a normalized state from an exception instance.
    ///
    /// Before Python 3.12 the type and current traceback are captured alongside the value.
    pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
        Self {
            #[cfg(not(Py_3_12))]
            ptype: pvalue.get_type().into(),
            // SAFETY: `pvalue` is a live exception and holding a `Bound` proves the GIL is held;
            // `PyException_GetTraceback` returns a new (possibly null) reference.
            #[cfg(not(Py_3_12))]
            ptraceback: unsafe {
                Py::from_owned_ptr_or_opt(
                    pvalue.py(),
                    ffi::PyException_GetTraceback(pvalue.as_ptr()),
                )
            },
            pvalue: pvalue.into(),
        }
    }

    /// Returns the exception type.
    #[cfg(not(Py_3_12))]
    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
        self.ptype.bind(py).clone()
    }

    /// Returns the exception type, read from the stored value.
    #[cfg(Py_3_12)]
    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
        self.pvalue.bind(py).get_type()
    }

    /// Returns the traceback captured with the exception, if any.
    #[cfg(not(Py_3_12))]
    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
        self.ptraceback
            .as_ref()
            .map(|traceback| traceback.bind(py).clone())
    }

    /// Returns the traceback currently attached to the stored value, if any.
    #[cfg(Py_3_12)]
    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
        // SAFETY: `PyException_GetTraceback` returns a new reference that is null or a traceback.
        unsafe {
            ffi::PyException_GetTraceback(self.pvalue.as_ptr())
                .assume_owned_or_opt(py)
                .map(|b| b.downcast_into_unchecked())
        }
    }

    /// Takes the interpreter's current exception, if one is set, clearing it.
    pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
        #[cfg(Py_3_12)]
        {
            // SAFETY: `PyErr_GetRaisedException` returns a new reference to a normalized
            // exception instance, or null when no exception is set.
            unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
                PyErrStateNormalized {
                    pvalue: unsafe { pvalue.downcast_into_unchecked() }.unbind(),
                }
            })
        }

        #[cfg(not(Py_3_12))]
        {
            // SAFETY: `PyErr_Fetch` hands us owned (possibly null) references; after
            // `PyErr_NormalizeException` they are a type, an instance, and a traceback.
            let (ptype, pvalue, ptraceback) = unsafe {
                let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
                let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
                let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();

                ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);

                // Only normalize when an exception was actually set.
                if !ptype.is_null() {
                    ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
                }

                (
                    ptype
                        .assume_owned_or_opt(py)
                        .map(|b| b.downcast_into_unchecked()),
                    pvalue
                        .assume_owned_or_opt(py)
                        .map(|b| b.downcast_into_unchecked()),
                    ptraceback
                        .assume_owned_or_opt(py)
                        .map(|b| b.downcast_into_unchecked()),
                )
            };

            ptype.map(|ptype| PyErrStateNormalized {
                ptype: ptype.unbind(),
                pvalue: pvalue.expect("normalized exception value missing").unbind(),
                ptraceback: ptraceback.map(Bound::unbind),
            })
        }
    }

    /// Takes ownership of a normalized `(type, value, traceback)` triple, as produced by
    /// `PyErr_Fetch` followed by `PyErr_NormalizeException`.
    ///
    /// # Safety
    /// Each non-null pointer must be an owned (strong) reference to an object of the matching
    /// kind: an exception type, an exception instance, and a traceback respectively.
    ///
    /// # Panics
    /// If `ptype` or `pvalue` is null. All three references are taken over *before* checking,
    /// so none of them is leaked on that path.
    #[cfg(not(Py_3_12))]
    unsafe fn from_normalized_ffi_tuple(
        py: Python<'_>,
        ptype: *mut ffi::PyObject,
        pvalue: *mut ffi::PyObject,
        ptraceback: *mut ffi::PyObject,
    ) -> Self {
        let ptype = Py::from_owned_ptr_or_opt(py, ptype);
        let pvalue = Py::from_owned_ptr_or_opt(py, pvalue);
        let ptraceback = Py::from_owned_ptr_or_opt(py, ptraceback);
        PyErrStateNormalized {
            ptype: ptype.expect("Exception type missing"),
            pvalue: pvalue.expect("Exception value missing"),
            ptraceback,
        }
    }

    /// Returns a new handle to the same exception, incrementing reference counts.
    pub fn clone_ref(&self, py: Python<'_>) -> Self {
        Self {
            #[cfg(not(Py_3_12))]
            ptype: self.ptype.clone_ref(py),
            pvalue: self.pvalue.clone_ref(py),
            #[cfg(not(Py_3_12))]
            ptraceback: self
                .ptraceback
                .as_ref()
                .map(|ptraceback| ptraceback.clone_ref(py)),
        }
    }
}
/// The `(type, value)` pair produced by a lazy error constructor.
pub(crate) struct PyErrStateLazyFnOutput {
    // Type to raise. When raised, a `TypeError` is set instead if this is not an exception class.
    pub(crate) ptype: PyObject,
    // Value passed with `ptype` to `PyErr_SetObject`, e.g. the output of
    // `PyErrArguments::arguments`.
    pub(crate) pvalue: PyObject,
}
/// Deferred exception constructor. It is called at most once (it is `FnOnce`), with a GIL token,
/// when the error is normalized or restored. It is `Send + Sync` so the boxed closure can live
/// inside a shareable `PyErrState`.
pub(crate) type PyErrStateLazyFn =
    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
/// The two representations an error state can be in.
enum PyErrStateInner {
    // Not yet created in the interpreter; built on demand by the boxed closure.
    Lazy(Box<PyErrStateLazyFn>),
    // Already materialized as Python objects.
    Normalized(PyErrStateNormalized),
}
impl PyErrStateInner {
    /// Produces the normalized form of this state.
    ///
    /// A lazy state is raised in the interpreter and immediately fetched back, so the result
    /// matches what Python itself would have created.
    fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
        let lazy = match self {
            PyErrStateInner::Normalized(normalized) => return normalized,
            PyErrStateInner::Lazy(lazy) => lazy,
        };

        #[cfg(not(Py_3_12))]
        {
            let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
            // SAFETY: the triple comes straight from `PyErr_Fetch` + `PyErr_NormalizeException`,
            // so every non-null pointer is an owned reference that is handed over here.
            unsafe {
                PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
            }
        }

        #[cfg(Py_3_12)]
        {
            raise_lazy(py, lazy);
            PyErrStateNormalized::take(py)
                .expect("exception missing after writing to the interpreter")
        }
    }

    /// Sets this state as the interpreter's current exception (pre-3.12 triple API).
    #[cfg(not(Py_3_12))]
    fn restore(self, py: Python<'_>) {
        let (ptype, pvalue, ptraceback) = match self {
            PyErrStateInner::Normalized(normalized) => {
                let PyErrStateNormalized {
                    ptype,
                    pvalue,
                    ptraceback,
                } = normalized;
                let raw_traceback = match ptraceback {
                    Some(traceback) => traceback.into_ptr(),
                    None => std::ptr::null_mut(),
                };
                (ptype.into_ptr(), pvalue.into_ptr(), raw_traceback)
            }
            PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
        };
        // SAFETY: `PyErr_Restore` steals all three (possibly null) owned references.
        unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
    }

    /// Sets this state as the interpreter's current exception (3.12+ single-object API).
    #[cfg(Py_3_12)]
    fn restore(self, py: Python<'_>) {
        match self {
            PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => {
                // SAFETY: `PyErr_SetRaisedException` steals the owned reference to the value.
                unsafe { ffi::PyErr_SetRaisedException(pvalue.into_ptr()) }
            }
            PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
        }
    }
}
/// Materializes a lazy error and returns the normalized `(type, value, traceback)` triple as
/// owned raw pointers, clearing the interpreter's error indicator in the process.
#[cfg(not(Py_3_12))]
fn lazy_into_normalized_ffi_tuple(
    py: Python<'_>,
    lazy: Box<PyErrStateLazyFn>,
) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
    // Go through `raise_lazy` so this path behaves like the 3.12+ one (including the
    // `TypeError` for non-exception types), then fetch the exception and normalize it.
    raise_lazy(py, lazy);

    let null = std::ptr::null_mut::<ffi::PyObject>;
    let (mut ptype, mut pvalue, mut ptraceback) = (null(), null(), null());

    // SAFETY: the GIL is held (`py`), and the out-pointers are valid, writable locals.
    unsafe {
        ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
        ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
    }

    (ptype, pvalue, ptraceback)
}
/// Runs the lazy constructor and sets its result as the interpreter's current exception.
///
/// If the produced type is not an exception class, a `TypeError` is raised instead.
fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
    let PyErrStateLazyFnOutput {
        ptype: exc_type,
        pvalue: exc_value,
    } = lazy(py);

    // SAFETY: the GIL is held (`py`), and both pointers belong to owned objects that stay alive
    // until the end of this function.
    unsafe {
        let is_exception_class = ffi::PyExceptionClass_Check(exc_type.as_ptr()) != 0;
        if is_exception_class {
            ffi::PyErr_SetObject(exc_type.as_ptr(), exc_value.as_ptr())
        } else {
            ffi::PyErr_SetString(
                PyTypeError::type_object_raw(py).cast(),
                ffi::c_str!("exceptions must derive from BaseException").as_ptr(),
            )
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
    };

    /// Normalizing an error whose lazy arguments read back that same error must panic with a
    /// clear message rather than re-entering the normalization `Once`.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        /// Arguments whose construction asks for the value of the error being normalized.
        struct RecursiveArgs;

        impl PyErrArguments for RecursiveArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                let err = ERR.get(py).expect("is set just below");
                err.value(py).clone().into()
            }
        }

        Python::with_gil(|py| {
            ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
            let err = ERR.get(py).expect("is set just above");
            err.value(py);
        })
    }

    /// Many threads normalizing the same error concurrently must not deadlock, even when the
    /// lazy arguments release the GIL part-way through.
    #[test]
    // Threads are unavailable on wasm, so this test is skipped there.
    #[cfg(not(target_arch = "wasm32"))]
    fn test_no_deadlock_thread_switch() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        /// Arguments that release the GIL while being built, letting other threads race in.
        struct GILSwitchArgs;

        impl PyErrArguments for GILSwitchArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                py.allow_threads(|| std::thread::sleep(std::time::Duration::from_millis(10)));
                py.None()
            }
        }

        Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

        let mut workers = Vec::with_capacity(10);
        for _ in 0..10 {
            workers.push(std::thread::spawn(|| {
                Python::with_gil(|py| {
                    ERR.get(py).expect("is set just above").value(py);
                });
            }));
        }
        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        Python::with_gil(|py| {
            let err = ERR.get(py).expect("is set above");
            assert!(err.is_instance_of::<PyValueError>(py))
        });
    }
}