use std::{borrow::Cow, error::Error, fmt, io};
use pyo3::{exceptions::PyException, intern, prelude::*, sync::PyOnceLock, types::IntoPyDict};
/// A Python exception together with the chain of its causes.
///
/// `PyErrChain` implements [`Error`], and [`Error::source`] yields the next
/// link of the chain. This lets the Python exception's cause chain be
/// walked with Rust's standard error machinery.
pub struct PyErrChain {
    /// The outermost Python exception of this link.
    err: PyErr,
    /// The next (inner) link of the chain, if any.
    cause: Option<Box<Self>>,
}
impl PyErrChain {
    /// Converts any Rust error into a [`PyErrChain`].
    ///
    /// This uses the default translators, [`ErrorNoPyErr`] and
    /// [`DowncastToPyErr`]. The error's [`Error::source`] chain becomes the
    /// chain of Python causes.
    #[must_use]
    #[inline]
    pub fn new<T: Into<Box<dyn Error + 'static>>>(py: Python, err: T) -> Self {
        Self::from_pyerr(py, Self::pyerr_from_err(py, err))
    }

    /// Like [`PyErrChain::new`], but with custom translators.
    ///
    /// `T` translates individual errors into Python exceptions. `M` detects
    /// (downcasts) concrete error types.
    #[must_use]
    #[inline]
    pub fn new_with_translator<
        E: Into<Box<dyn Error + 'static>>,
        T: AnyErrorToPyErr,
        M: MapErrorToPyErr,
    >(
        py: Python,
        err: E,
    ) -> Self {
        Self::from_pyerr(py, Self::pyerr_from_err_with_translator::<E, T, M>(py, err))
    }

    /// Converts any Rust error into a [`PyErr`] whose Python cause chain
    /// mirrors the error's [`Error::source`] chain.
    ///
    /// This uses the default translators, [`ErrorNoPyErr`] and
    /// [`DowncastToPyErr`].
    #[must_use]
    #[inline]
    pub fn pyerr_from_err<T: Into<Box<dyn Error + 'static>>>(py: Python, err: T) -> PyErr {
        Self::pyerr_from_err_with_translator::<T, ErrorNoPyErr, DowncastToPyErr>(py, err)
    }

    /// Converts any Rust error into a [`PyErr`] using the translators `T`
    /// and `M`.
    ///
    /// If `err` already is a [`PyErrChain`] or a [`PyErr`] (as detected by
    /// `M`), its Python exception is returned as-is.
    ///
    /// Otherwise, `err` and each error in its [`Error::source`] chain are
    /// translated with `T`. Any error that `T` does not handle becomes a
    /// plain Python `Exception` built from its `Display` output. Each
    /// translated source is set as the cause of the error before it.
    ///
    /// The source walk stops at the first source that already is a
    /// [`PyErrChain`] or [`PyErr`]; that exception keeps its existing Python
    /// cause chain.
    #[must_use]
    pub fn pyerr_from_err_with_translator<
        E: Into<Box<dyn Error + 'static>>,
        T: AnyErrorToPyErr,
        M: MapErrorToPyErr,
    >(
        py: Python,
        err: E,
    ) -> PyErr {
        let err: Box<dyn Error + 'static> = err.into();

        // Errors that already are Python exceptions pass through untouched.
        let err = match M::try_map(py, err, |err: Box<Self>| err.into_pyerr()) {
            Ok(err) => return err,
            Err(err) => err,
        };
        let err = match M::try_map(py, err, |err: Box<PyErr>| *err) {
            Ok(err) => return err,
            Err(err) => err,
        };

        // Translate the source chain, outermost first. `cause` receives the
        // existing Python cause of the exception the walk stops at, if any.
        let mut chain = Vec::new();
        let mut source = err.source();
        let mut cause = None;
        while let Some(err) = source.take() {
            if let Some(err) = M::try_map_ref(py, err, |err: &Self| err.as_pyerr().clone_ref(py)) {
                cause = err.cause(py);
                chain.push(err);
                break;
            }
            if let Some(err) = M::try_map_ref(py, err, |err: &PyErr| err.clone_ref(py)) {
                cause = err.cause(py);
                chain.push(err);
                break;
            }
            source = err.source();
            #[allow(clippy::option_if_let_else)]
            chain.push(match T::try_from_err_ref::<M>(py, err) {
                Some(err) => err,
                None => PyException::new_err(format!("{err}")),
            });
        }

        // Link the translated sources together, innermost first.
        while let Some(err) = chain.pop() {
            Self::link_cause(py, &err, cause.take());
            cause = Some(err);
        }

        let err = match T::try_from_err::<M>(py, err) {
            Ok(err) => err,
            Err(err) => PyException::new_err(format!("{err}")),
        };
        Self::link_cause(py, &err, cause);
        err
    }

    /// Sets `cause` as the Python cause of `err`, but only if there is one.
    ///
    /// A translator may return an existing exception that already has a
    /// cause, for example a [`PyErr`] unwrapped from an [`std::io::Error`].
    /// Skipping `None` keeps that cause instead of clearing it. Such an
    /// exception may be shared with the original, because `clone_ref`
    /// shares the Python object. Clearing its cause would therefore also
    /// mutate the caller's exception.
    fn link_cause(py: Python, err: &PyErr, cause: Option<PyErr>) {
        if let Some(cause) = cause {
            err.set_cause(py, Some(cause));
        }
    }

    /// Builds a [`PyErrChain`] from a Python exception by following its
    /// cause chain (via [`PyErr::cause`]).
    ///
    /// Python allows cyclic `__cause__` chains. The walk therefore stops
    /// when it reaches an exception it has already visited, so this always
    /// terminates.
    #[must_use]
    pub fn from_pyerr(py: Python, err: PyErr) -> Self {
        // Identities of the exception objects visited so far. They stay
        // alive (held by `err` and `chain`), so their pointers are stable.
        let mut seen = vec![err.value(py).as_ptr()];
        let mut chain = Vec::new();
        let mut cause = err.cause(py);
        while let Some(err) = cause.take() {
            let ptr = err.value(py).as_ptr();
            if seen.contains(&ptr) {
                // Cycle detected: truncate the Rust-side chain here.
                break;
            }
            seen.push(ptr);
            cause = err.cause(py);
            chain.push(Self { err, cause: None });
        }

        // Nest the collected links, innermost first.
        let mut cause = None;
        while let Some(mut err) = chain.pop() {
            err.cause = cause.take();
            cause = Some(Box::new(err));
        }
        Self { err, cause }
    }

    /// Consumes the chain and returns its outermost Python exception.
    #[must_use]
    pub fn into_pyerr(self) -> PyErr {
        self.err
    }

    /// Borrows the outermost Python exception of the chain.
    #[must_use]
    pub const fn as_pyerr(&self) -> &PyErr {
        &self.err
    }

    /// Borrows the Python exception of the next link, if there is one.
    #[must_use]
    pub fn cause(&self) -> Option<&PyErr> {
        self.cause.as_deref().map(Self::as_pyerr)
    }

    /// Clones the whole chain with [`PyErr::clone_ref`] on every link.
    #[must_use]
    pub fn clone_ref(&self, py: Python) -> Self {
        Self {
            err: self.err.clone_ref(py),
            cause: self
                .cause
                .as_ref()
                .map(|cause| Box::new(cause.clone_ref(py))),
        }
    }
}
impl fmt::Debug for PyErrChain {
    /// Shows the exception's type, value, formatted traceback (if any) and
    /// the nested cause chain.
    ///
    /// Attaches to the Python interpreter to read them.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        Python::attach(|py| {
            let traceback = self.err.traceback(py).map(|tb| match tb.format() {
                Ok(formatted) => Cow::Owned(formatted),
                Err(_) => Cow::Borrowed("<traceback str() failed>"),
            });

            let mut debug = fmt.debug_struct("PyErrChain");
            debug
                .field("type", &self.err.get_type(py))
                .field("value", self.err.value(py))
                .field("traceback", &traceback)
                .field("cause", &self.cause);
            debug.finish()
        })
    }
}
impl fmt::Display for PyErrChain {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.err, fmt)
}
}
impl Error for PyErrChain {
    /// The source of a link is the next link of the chain, if any.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.cause {
            Some(next) => Some(&**next as &dyn Error),
            None => None,
        }
    }
}
impl From<PyErr> for PyErrChain {
fn from(err: PyErr) -> Self {
Python::attach(|py| Self::from_pyerr(py, err))
}
}
impl From<PyErrChain> for PyErr {
    /// Returns the outermost exception of the chain.
    ///
    /// The Rust-side cause links are dropped.
    fn from(chain: PyErrChain) -> Self {
        chain.err
    }
}
/// Translates type-erased Rust errors into Python exceptions.
///
/// Each method receives a [`MapErrorToPyErr`] strategy `T`, which
/// implementations use to detect the concrete error types they know how to
/// translate.
pub trait AnyErrorToPyErr {
    /// Tries to translate the owned error `err` into a [`PyErr`].
    ///
    /// # Errors
    ///
    /// Returns `err` back when this translator does not handle it.
    fn try_from_err<T: MapErrorToPyErr>(
        py: Python,
        err: Box<dyn Error + 'static>,
    ) -> Result<PyErr, Box<dyn Error + 'static>>;
    /// Tries to translate the borrowed error `err` into a [`PyErr`].
    ///
    /// Returns `None` when this translator does not handle it.
    fn try_from_err_ref<T: MapErrorToPyErr>(
        py: Python,
        err: &(dyn Error + 'static),
    ) -> Option<PyErr>;
}
/// Strategy for checking whether a type-erased error is of a concrete type
/// `T`, and if so, mapping it into a [`PyErr`].
pub trait MapErrorToPyErr {
    /// Applies `map` to the owned error `err` if it is a `T`.
    ///
    /// # Errors
    ///
    /// Returns `err` back when it is not a `T`.
    fn try_map<T: Error + 'static>(
        py: Python,
        err: Box<dyn Error + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + 'static>>;
    /// Same as [`MapErrorToPyErr::try_map`], for `Send + Sync` errors such
    /// as the payload wrapped inside an [`io::Error`].
    ///
    /// # Errors
    ///
    /// Returns `err` back when it is not a `T`.
    fn try_map_send_sync<T: Error + 'static>(
        py: Python,
        err: Box<dyn Error + Send + Sync + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + Send + Sync + 'static>>;
    /// Applies `map` to the borrowed error `err` if it is a `T`.
    ///
    /// Returns `None` otherwise.
    fn try_map_ref<T: Error + 'static>(
        py: Python,
        err: &(dyn Error + 'static),
        map: impl FnOnce(&T) -> PyErr,
    ) -> Option<PyErr>;
}
/// Translator that handles no error type at all.
///
/// With it, [`PyErrChain`] turns every error into a plain Python
/// `Exception` built from that error's `Display` output.
pub struct ErrorNoPyErr;
impl AnyErrorToPyErr for ErrorNoPyErr {
    /// Always hands `err` back unchanged.
    #[inline]
    fn try_from_err<T: MapErrorToPyErr>(
        _: Python,
        err: Box<dyn Error + 'static>,
    ) -> Result<PyErr, Box<dyn Error + 'static>> {
        Result::Err(err)
    }

    /// Always returns `None`.
    #[inline]
    fn try_from_err_ref<T: MapErrorToPyErr>(
        _: Python,
        _: &(dyn Error + 'static),
    ) -> Option<PyErr> {
        Option::None
    }
}
pub struct IoErrorToPyErr;
impl AnyErrorToPyErr for IoErrorToPyErr {
fn try_from_err<T: MapErrorToPyErr>(
py: Python,
err: Box<dyn Error + 'static>,
) -> Result<PyErr, Box<dyn Error + 'static>> {
T::try_map(py, err, |err: Box<io::Error>| {
let kind = err.kind();
if err.get_ref().is_some() {
#[allow(clippy::unwrap_used)] let err = err.into_inner().unwrap();
let err = match T::try_map_send_sync(py, err, |err: Box<PyErr>| *err) {
Ok(err) => return err,
Err(err) => err,
};
let err =
match T::try_map_send_sync(py, err, |err: Box<PyErrChain>| err.into_pyerr()) {
Ok(err) => return err,
Err(err) => err,
};
return PyErr::from(io::Error::new(kind, err));
}
PyErr::from(*err)
})
}
fn try_from_err_ref<T: MapErrorToPyErr>(
py: Python,
err: &(dyn Error + 'static),
) -> Option<PyErr> {
T::try_map_ref(py, err, |err: &io::Error| {
if let Some(err) = err.get_ref() {
if let Some(err) = T::try_map_ref(py, err, |err: &PyErr| err.clone_ref(py)) {
return err;
}
if let Some(err) =
T::try_map_ref(py, err, |err: &PyErrChain| err.as_pyerr().clone_ref(py))
{
return err;
}
}
PyErr::from(io::Error::new(err.kind(), format!("{err}")))
})
}
}
/// [`MapErrorToPyErr`] strategy that detects concrete error types with the
/// standard library's `downcast` / `downcast_ref`.
pub struct DowncastToPyErr;
impl MapErrorToPyErr for DowncastToPyErr {
    fn try_map<T: Error + 'static>(
        _: Python,
        err: Box<dyn Error + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + 'static>> {
        // `?` hands the untouched error back when it is not a `T`.
        let concrete = err.downcast::<T>()?;
        Ok(map(concrete))
    }

    fn try_map_send_sync<T: Error + 'static>(
        _: Python,
        err: Box<dyn Error + Send + Sync + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + Send + Sync + 'static>> {
        let concrete = err.downcast::<T>()?;
        Ok(map(concrete))
    }

    fn try_map_ref<T: Error + 'static>(
        _: Python,
        err: &(dyn Error + 'static),
        map: impl FnOnce(&T) -> PyErr,
    ) -> Option<PyErr> {
        let concrete = err.downcast_ref::<T>()?;
        Some(map(concrete))
    }
}
/// Re-raises `err` from a synthetic Python snippet compiled under the file
/// name `file`, placed so that `raise err` sits on line `line`.
///
/// This adds a `file`:`line` frame to the exception's traceback. A `line`
/// of `0` behaves like `1`. `column` is currently unused.
///
/// # Panics
///
/// Panics if `builtins.compile` or `builtins.exec` cannot be imported, if
/// compiling the snippet or building its globals dict fails, or if executing
/// the snippet does not raise.
#[allow(clippy::missing_panics_doc)]
#[must_use]
pub fn err_with_location(py: Python, err: PyErr, file: &str, line: u32, column: u32) -> PyErr {
    const RAISE: &str = "raise err";
    static BUILTIN_COMPILE: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
    static BUILTIN_EXEC: PyOnceLock<Py<PyAny>> = PyOnceLock::new();

    let _ = column;

    #[allow(clippy::expect_used)] // `compile` is a Python builtin
    let compile = BUILTIN_COMPILE
        .import(py, "builtins", "compile")
        .expect("Python does not provide a compile() function");
    #[allow(clippy::expect_used)] // `exec` is a Python builtin
    let exec = BUILTIN_EXEC
        .import(py, "builtins", "exec")
        .expect("Python does not provide an exec() function");

    // Leading newlines push `raise err` down to the requested line.
    // `u32 -> usize` is lossless on 32- and 64-bit targets.
    let mut source = "\n".repeat(line.saturating_sub(1) as usize);
    source.push_str(RAISE);

    #[allow(clippy::expect_used)] // the snippet is fixed and syntactically valid
    let code = compile
        .call1((source, file, intern!(py, "exec")))
        .expect("failed to compile PyErr location helper");
    #[allow(clippy::expect_used)]
    let globals = [(intern!(py, "err"), err)]
        .into_py_dict(py)
        .expect("failed to create a dict(err=...)");
    #[allow(clippy::expect_used)] // the snippet unconditionally raises
    let raised = exec.call1((code, globals)).expect_err("raise must raise");
    raised
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A Python `raise ... from ...` chain is exposed through `Error::source`.
    #[test]
    fn python_cause() {
        Python::attach(|py| {
            // The embedded Python must keep its indentation: `try:` / `except:`
            // bodies are indentation-delimited.
            let err = py
                .run(
                    &std::ffi::CString::new(
                        r#"
try:
    try:
        raise Exception("source")
    except Exception as err:
        raise IndexError("middle") from err
except Exception as err:
    raise LookupError("top") from err
"#,
                    )
                    .unwrap(),
                    None,
                    None,
                )
                .expect_err("raise must raise");
            let err = PyErrChain::new(py, err);
            assert_eq!(format!("{err}"), "LookupError: top");
            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "IndexError: middle");
            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: source");
            assert!(err.source().is_none());
        });
    }

    /// A Rust `source()` chain becomes a chain of Python causes.
    #[test]
    fn rust_source() {
        #[derive(Debug)]
        struct MyErr {
            msg: &'static str,
            source: Option<Box<Self>>,
        }

        impl fmt::Display for MyErr {
            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                fmt.write_str(self.msg)
            }
        }

        impl Error for MyErr {
            fn source(&self) -> Option<&(dyn Error + 'static)> {
                self.source.as_ref().map(|source| &**source as &dyn Error)
            }
        }

        Python::attach(|py| {
            let err = PyErrChain::new(
                py,
                MyErr {
                    msg: "top",
                    source: Some(Box::new(MyErr {
                        msg: "middle",
                        source: Some(Box::new(MyErr {
                            msg: "source",
                            source: None,
                        })),
                    })),
                },
            );
            let source = err.source().expect("must have source");
            let source = source.source().expect("must have source");
            assert!(source.source().is_none());
            let err = PyErr::from(err);
            assert_eq!(format!("{err}"), "Exception: top");
            let err = err.cause(py).expect("must have cause");
            assert_eq!(format!("{err}"), "Exception: middle");
            let err = err.cause(py).expect("must have cause");
            assert_eq!(format!("{err}"), "Exception: source");
            assert!(err.cause(py).is_none());
        });
    }

    /// `err_with_location` adds synthetic traceback frames. Python indents
    /// each `File ...` frame line by two spaces.
    #[test]
    fn err_location() {
        Python::attach(|py| {
            let err = err_with_location(py, PyException::new_err("oh no"), "foo.rs", 27, 15);
            assert_eq!(format!("{err}"), "Exception: oh no");
            assert_eq!(
                err.traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "foo.rs", line 27, in <module>
"#,
            );
            assert!(err.cause(py).is_none());

            let err = err_with_location(py, err, "bar.rs", 24, 18);
            let top = PyException::new_err("oh yes");
            top.set_cause(py, Some(err));
            let err = err_with_location(py, top, "baz.rs", 41, 1);
            assert_eq!(format!("{err}"), "Exception: oh yes");
            assert_eq!(
                err.traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "baz.rs", line 41, in <module>
"#,
            );

            let cause = err.cause(py).expect("must have a cause");
            assert_eq!(format!("{cause}"), "Exception: oh no");
            assert_eq!(
                cause
                    .traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "bar.rs", line 24, in <module>
  File "foo.rs", line 27, in <module>
"#,
            );
            assert!(cause.cause(py).is_none());
        });
    }

    /// `anyhow` context layers become a chain of Python causes.
    #[test]
    fn anyhow() {
        Python::attach(|py| {
            let err = anyhow::anyhow!("source").context("middle").context("top");
            let err = PyErrChain::new(py, err);
            assert_eq!(format!("{err}"), "Exception: top");
            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: middle");
            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: source");
            assert!(err.source().is_none());
        });
    }
}