use std::{
error::Error,
fmt::{Debug, Display},
};
pub trait Context<T, E>
where
E: Error + 'static,
{
fn context<V, RE>(self, value: V) -> Result<T, Contextual<RE>>
where
V: Into<ContextValue>,
RE: From<E> + Error + 'static;
}
#[derive(Debug)]
pub struct Contextual<E: Error> {
inner: E,
context: Vec<ContextValue>,
}
impl<E: Error> Contextual<E> {
pub fn source(&self) -> &E {
&self.inner
}
}
/// A single piece of context attached to an error.
///
/// Derives the common value traits so callers can clone, compare and hash
/// frames obtained from a `Contextual`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextValue {
    /// A static message, for example a string literal passed to `Context::context`.
    Str(&'static str),
}
impl From<&'static str> for ContextValue {
fn from(value: &'static str) -> Self {
Self::Str(value)
}
}
impl<T, E> Context<T, E> for Result<T, E>
where
    E: Error + 'static,
{
    /// Wraps a bare error in a [`Contextual`] whose only frame is `value`.
    ///
    /// `Ok` values pass through untouched. `value` is converted only on the
    /// error path.
    fn context<V, RE>(self, value: V) -> Result<T, Contextual<RE>>
    where
        V: Into<ContextValue>,
        RE: From<E> + Error + 'static,
    {
        self.map_err(|err| Contextual {
            inner: err.into(),
            context: vec![value.into()],
        })
    }
}
impl<T, E> Context<T, E> for Result<T, Contextual<E>>
where
    E: Error + 'static,
{
    /// Appends `value` as the new outermost frame of an already-contextual error.
    ///
    /// Existing frames are preserved. The inner error is converted into `RE`.
    /// `Ok` values pass through untouched.
    fn context<V, RE>(self, value: V) -> Result<T, Contextual<RE>>
    where
        V: Into<ContextValue>,
        RE: From<E> + Error + 'static,
    {
        self.map_err(|Contextual { inner, mut context }| {
            // Pushing keeps the vector innermost-first, so the newest frame is last.
            context.push(value.into());
            Contextual {
                inner: inner.into(),
                context,
            }
        })
    }
}
impl<E, R> From<E> for Contextual<R>
where
E: Error + 'static,
R: From<E> + Error + 'static,
{
fn from(value: E) -> Self {
Contextual {
inner: value.into(),
context: vec![],
}
}
}
impl<E> Display for Contextual<E>
where
    E: Error,
{
    /// Prints the most recently attached context as the `Error:` headline.
    ///
    /// After the headline comes a `Caused by:` list with the remaining frames
    /// (newest to oldest) and then the underlying error. With no context
    /// attached, only `Error: ` followed by the underlying error is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const CAUSED_INDENT: &str = " ";
        // Frames are stored innermost-first, so walk them in reverse.
        let mut frames = self.context.iter().rev();
        match frames.next() {
            None => write!(f, "Error: {}", self.inner),
            Some(headline) => {
                writeln!(f, "Error: {headline}")?;
                writeln!(f, "Caused by:")?;
                for frame in frames {
                    writeln!(f, "{CAUSED_INDENT}{frame}")?;
                }
                write!(f, "{CAUSED_INDENT}{}", self.inner)
            }
        }
    }
}
impl Display for ContextValue {
    /// Writes the context message verbatim. Formatter flags such as width are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Str(message) => f.write_str(message),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display() {
        #[derive(Debug, thiserror::Error)]
        enum Error {
            #[error("test: {0}")]
            Test(&'static str),
        }

        // `Ok(x?)` is deliberate: it routes the `Contextual` through `?` and the
        // `From` conversion rather than returning it directly as `b` does.
        #[expect(clippy::needless_question_mark)]
        fn a() -> Result<(), Contextual<Error>> {
            Ok(source().context("failed to test")?)
        }
        fn b() -> Result<(), Contextual<Error>> {
            source().context("failed to test")
        }
        // Lifts the bare error with `?` alone, so no context frame is attached.
        fn bare() -> Result<(), Contextual<Error>> {
            source()?;
            Ok(())
        }
        fn source() -> Result<(), Error> {
            Err(Error::Test("value"))
        }
        fn ab() -> Result<(), Contextual<Error>> {
            a().context("failed to call a")
        }

        assert_eq!(
            a().unwrap_err().to_string(),
            "Error: failed to test\nCaused by:\n test: value"
        );
        assert_eq!(
            b().unwrap_err().to_string(),
            "Error: failed to test\nCaused by:\n test: value"
        );
        assert_eq!(
            ab().unwrap_err().to_string(),
            "Error: failed to call a\nCaused by:\n failed to test\n test: value"
        );
        // No context: only the underlying error follows the `Error:` prefix.
        assert_eq!(bare().unwrap_err().to_string(), "Error: test: value");
        // Nested context must still expose the original error.
        assert!(matches!(ab().unwrap_err().source(), Error::Test("value")));
    }
}