use std::fmt::Display;
use std::{error::Error as StdError, fmt::Debug};
/// A minimal error type that carries nothing but a human-readable message.
///
/// It is the default error marker for [`TError`] and the module's `Result` alias.
#[derive(Debug)]
pub struct SimpleError(pub String);

impl Display for SimpleError {
    /// Renders the wrapped message, honouring width, fill, alignment and
    /// precision flags exactly as formatting the inner `String` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.pad(message)
    }
}

/// `SimpleError` has no underlying cause, so the default `source()` is used.
impl StdError for SimpleError {}
/// Result alias whose error type is always a [`TError`].
///
/// The second parameter is the error *marker* `E`, not the complete error type:
/// `Result<T, E>` expands to `std::result::Result<T, TError<E>>`, and `E`
/// defaults to [`SimpleError`]. Because this alias shadows the prelude `Result`
/// inside this module, writing `Result<T, TError<E>>` here would produce a
/// doubly wrapped `TError<TError<E>>` error type.
pub type Result<T, E = SimpleError> = std::result::Result<T, TError<E>>;
/// An `anyhow::Error` tagged with a type-level marker `E`.
///
/// No value of type `E` is stored directly: the marker lives only in
/// `PhantomData`, and the payload is always the type-erased `anyhow::Error`.
/// `E` defaults to [`SimpleError`].
pub struct TError<E = SimpleError> {
// Zero-sized marker that ties this value to `E` without holding an `E`.
phantom: std::marker::PhantomData<E>,
// The wrapped, type-erased error that all formatting and downcasting delegate to.
error: anyhow::Error,
}
impl<E> Debug for TError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.error, f)
}
}
impl<E> Display for TError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.error, f)
}
}
impl<T> From<TError<T>> for anyhow::Error {
fn from(err: TError<T>) -> Self {
err.error
}
}
impl<T> From<TError<T>> for Box<dyn StdError> {
fn from(err: TError<T>) -> Self {
err.error.into()
}
}
/// Constructors, downcasting helpers and context helpers for [`TError`].
impl<E: Debug + Display + Send + Sync + 'static> TError<E> {
    /// Wraps an existing `anyhow::Error`, tagging it with the marker `E`.
    ///
    /// No check is made that the error actually contains an `E`.
    pub fn from_anyhow(error: anyhow::Error) -> Self {
        Self {
            phantom: std::marker::PhantomData,
            error,
        }
    }

    /// Builds an error from a plain message.
    pub fn from_msg(msg: &str) -> Self {
        Self::from_anyhow(anyhow::anyhow!("{}", msg))
    }

    /// Consumes the error and tries to extract the typed error `E`.
    ///
    /// # Errors
    ///
    /// If the wrapped error cannot be downcast to `E`, the error is handed
    /// back as `Err(Self)` so the caller can keep using it.
    //
    // The return type is spelled `std::result::Result` on purpose: this module's
    // `Result` alias wraps its second parameter in `TError`, so
    // `Result<E, TError<E>>` would mean `std::result::Result<E, TError<TError<E>>>`.
    pub fn try_get(self) -> std::result::Result<E, Self> {
        self.error.downcast::<E>().map_err(Self::from_anyhow)
    }

    /// Borrows the typed error `E` if the wrapped error can be downcast to it.
    pub fn get_ref(&self) -> Option<&E> {
        self.error.downcast_ref::<E>()
    }

    /// Borrows the wrapped error as an arbitrary type `T`, if it can be
    /// downcast to `T`.
    pub fn downcast_ref<T: Debug + Display + Send + Sync + 'static>(&self) -> Option<&T> {
        self.error.downcast_ref::<T>()
    }

    /// Consumes the error and tries to extract an arbitrary type `T`.
    ///
    /// # Errors
    ///
    /// If the wrapped error cannot be downcast to `T`, the error is handed
    /// back as `Err(Self)`, keeping the original marker `E`.
    //
    // `std::result::Result` is used explicitly; see the note on `try_get`.
    pub fn downcast<T: Debug + Display + Send + Sync + 'static>(
        self,
    ) -> std::result::Result<T, Self> {
        self.error.downcast::<T>().map_err(Self::from_anyhow)
    }

    /// Attaches `context` to the wrapped error via `anyhow::Error::context`,
    /// keeping the marker `E`.
    pub fn context<C>(self, context: C) -> TError<E>
    where
        C: Display + Send + Sync + 'static,
    {
        TError {
            phantom: std::marker::PhantomData,
            error: self.error.context(context),
        }
    }

    /// Like [`TError::context`], but takes a closure that produces the context.
    ///
    /// The closure is invoked immediately, because `self` is already an error.
    pub fn with_context<F, R>(self, context: F) -> TError<E>
    where
        F: FnOnce() -> R,
        R: Display + Send + Sync + 'static,
    {
        self.context(context())
    }

    /// Re-tags the error with a different marker `T`.
    ///
    /// Only the type-level marker changes; the wrapped `anyhow::Error` is
    /// moved over untouched.
    pub fn change_err<T>(self) -> TError<T> {
        TError::<T> {
            phantom: std::marker::PhantomData,
            error: self.error,
        }
    }
}
// NOTE(review): this impl block defines no items, so the `E: Default` bound has
// no effect anywhere. It looks like a leftover placeholder — confirm before removing.
impl<E: Default + Debug + Display + Send + Sync + 'static> TError<E> {}
impl<SRC: StdError + Send + Sync + 'static, DST: StdError + 'static> From<SRC> for TError<DST> {
fn from(err: SRC) -> Self {
let error = anyhow::Error::new(err);
Self {
phantom: std::marker::PhantomData,
error,
}
}
}
/// Crate-private home of the `Sealed` marker trait.
pub(crate) mod private {
/// Marker supertrait required by this module's `Result` extension traits.
///
/// Because the enclosing module is `pub(crate)`, code outside the crate cannot
/// name this trait.
pub trait Sealed {}
}
/// Extension trait that attaches context to the error of a `Result` and
/// returns it as a [`TError<X>`](TError).
///
/// Type parameters: `T` is the success type, `E` the source error type
/// (it appears only as a trait parameter, not in the method signatures), and
/// `X` the marker of the resulting `TError`.
pub trait Context<T, E, X: Display>: private::Sealed {
/// Attaches `context` to the error, yielding `Result<T, TError<X>>`.
fn context<C>(self, context: C) -> std::result::Result<T, TError<X>>
where
C: Display + Send + Sync + 'static;
/// Same as `context`, but the context value is produced by the closure `f`.
fn with_context<C, F>(self, f: F) -> std::result::Result<T, TError<X>>
where
C: Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
// Every `Result` is `Sealed`, so the extension traits in this module can be
// implemented for `Result` regardless of its success or error type.
impl<T, E> private::Sealed for std::result::Result<T, E> {}
impl<T, E: StdError + Send + Sync + 'static, X: StdError> Context<T, E, X>
for std::result::Result<T, E>
{
fn context<C>(self, context: C) -> std::result::Result<T, TError<X>>
where
C: Display + Send + Sync + 'static,
{
self.map_err(|err| {
let error = anyhow::Error::new(err);
let error = error.context(context.to_string());
TError {
phantom: std::marker::PhantomData,
error,
}
})
}
fn with_context<C, F>(self, f: F) -> std::result::Result<T, TError<X>>
where
C: Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
self.context(f())
}
}
/// Error types that can be built from an arbitrary `anyhow::Error`.
///
/// `TError::get` uses this as the fallback when the wrapped error cannot be
/// downcast to the marker type.
pub trait DefaultError {
/// Builds `Self` from an untyped error.
fn from_anyhow(err: anyhow::Error) -> Self;
}
impl<E: DefaultError + Debug + Display + Send + Sync + 'static> TError<E> {
    /// Consumes the error and always produces an `E`.
    ///
    /// If the wrapped error can be downcast to `E`, that value is returned;
    /// otherwise the untyped error is converted with
    /// [`DefaultError::from_anyhow`].
    pub fn get(self) -> E {
        match self.error.downcast::<E>() {
            Ok(typed) => typed,
            Err(untyped) => E::from_anyhow(untyped),
        }
    }
}
/// Extension trait that converts the error of a `Result` into a
/// [`TError<E>`](TError).
pub trait IntoTError<T, E>: private::Sealed {
/// Maps the `Err` variant into a `TError<E>`, leaving `Ok` untouched.
fn terror(self) -> std::result::Result<T, TError<E>>;
}
impl<T, EIn, EOut> IntoTError<T, EOut> for std::result::Result<T, EIn>
where
EIn: Into<EOut>,
EOut: std::error::Error + Send + Sync + 'static,
{
fn terror(self) -> std::result::Result<T, TError<EOut>> {
self.map_err(|e| TError {
phantom: std::marker::PhantomData,
error: anyhow::Error::new(e.into()),
})
}
}
/// Extension trait that re-tags the [`TError`] inside a `Result` with a
/// different marker type `E`.
pub trait WrapTError<T, E>: private::Sealed {
/// Changes the marker of the `Err` variant to `E`, leaving `Ok` untouched.
fn change_err(self) -> std::result::Result<T, TError<E>>;
}
/// Implements [`WrapTError`] for results that already carry a [`TError`].
impl<T, EIn, EOut> WrapTError<T, EOut> for std::result::Result<T, TError<EIn>>
where
    EIn: std::error::Error + Send + Sync + 'static,
    EOut: std::error::Error + Send + Sync + 'static,
{
    /// Delegates to [`TError::change_err`] on the error path; the wrapped
    /// error itself is not modified.
    fn change_err(self) -> std::result::Result<T, TError<EOut>> {
        match self {
            Ok(value) => Ok(value),
            Err(tagged) => Err(tagged.change_err()),
        }
    }
}
// Unit tests covering conversion via `?`, context handling, downcasting,
// `terror()` and `change_err()`.
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use super::*;
// Domain error used as the `E` marker throughout these tests.
#[derive(Debug, thiserror::Error)]
enum MyError {
#[error("something went wrong")]
One,
// Catch-all variant; `DefaultError::from_anyhow` below maps untyped errors here.
#[error("Error two")]
Two(Box<dyn StdError + Send + Sync + 'static>),
// `#[from]` provides the `io::Error -> MyError` conversion used by `test_terror`.
#[error("io error: {0}")]
Three(#[from] std::io::Error),
}
impl DefaultError for MyError {
fn from_anyhow(err: anyhow::Error) -> Self {
MyError::Two(err.into())
}
}
// An unrelated error type that reaches `TError<MyError>` only through `?`.
#[derive(Debug, PartialEq)]
struct OtherError;
impl Display for OtherError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "OtherError")
}
}
impl StdError for OtherError {}
// Fails with `OtherError` when `fail` is true.
fn do_other_task(fail: bool) -> std::result::Result<(), OtherError> {
if fail {
Err(OtherError)
} else {
Ok(())
}
}
// Always fails: with `OtherError` (converted by `?`) when `other` is true,
// otherwise with `MyError::One` plus the context "failed".
fn fallible_fn(other: bool) -> std::result::Result<(), TError<MyError>> {
do_other_task(other)?;
Err(MyError::One).context("failed")
}
#[test]
fn test_err() {
// Typed error with context: the typed value stays reachable and `Display`
// shows the context.
let err = fallible_fn(false).unwrap_err();
assert_matches!(err.get_ref(), Some(&MyError::One));
assert_eq!(format!("{err}"), "failed");
// Further string contexts do not hide the typed error.
let e2 = err
.context("add more context")
.context("and even more context");
assert_matches!(e2.get_ref(), Some(&MyError::One));
// A context value of type `MyError` is what `get_ref` finds afterwards.
let e3 = e2.context(MyError::Two(anyhow::anyhow!("other error").into()));
assert_matches!(e3.get_ref(), Some(&MyError::Two(_)));
// Foreign error: not a `MyError`, but reachable via `downcast_ref`, and
// `get` falls back to `DefaultError::from_anyhow`.
let err = fallible_fn(true).unwrap_err();
assert_matches!(err.get_ref(), None);
assert_eq!(err.downcast_ref(), Some(&OtherError));
assert_matches!(err.get(), MyError::Two(_)); }
#[test]
fn test_terror() {
// Reading a path that is expected not to exist yields an `io::Error`, which
// `terror()` converts into `MyError::Three`.
let path = std::path::Path::new("/invalid-dir-doesnt-exist");
let err: TError<MyError> = std::fs::read_to_string(path).terror().unwrap_err();
assert_matches!(err.get_ref(), Some(&MyError::Three(_)));
}
#[test]
fn test_change_err() {
// Re-tagging the marker makes the stored `OtherError` extractable via `try_get`.
let err = fallible_fn(true).unwrap_err();
let err: TError<OtherError> = err.change_err();
assert_eq!(err.try_get().unwrap(), OtherError);
}
#[test]
fn test_change_err_result() {
// Same as `test_change_err`, but through the `WrapTError` extension on `Result`.
let err = fallible_fn(true);
let err: std::result::Result<(), TError<OtherError>> = err.change_err();
assert_eq!(err.unwrap_err().try_get().unwrap(), OtherError);
}
}