use stillwater::ContextError;
/// Extension methods for attaching human-readable context to the error side
/// of a [`Result`].
///
/// The implementation for `Result<T, E>` in this module leaves an `Ok` value
/// untouched and wraps an `Err(e)` in a [`ContextError`] that carries the
/// supplied message.
///
/// Note: because the wrapped error type is `E`, calling these methods on a
/// `Result` whose error is already a `ContextError<E>` yields a nested
/// `ContextError<ContextError<E>>`. To add another layer to an existing
/// `ContextError`, map the error instead: `res.map_err(|e| e.context("..."))`.
pub trait ResultExt<T, E> {
    /// Wraps the error (if any) in a [`ContextError`] carrying `msg`.
    fn context(self, msg: impl Into<String>) -> Result<T, ContextError<E>>;

    /// Like [`ResultExt::context`], but the message is produced by `f`.
    ///
    /// The implementation for `Result` only invokes `f` on the `Err` path, so
    /// building the message costs nothing when the result is `Ok`.
    fn with_context<F: FnOnce() -> String>(self, f: F) -> Result<T, ContextError<E>>;
}
/// Blanket implementation for every `Result`: the `Ok` value is returned
/// as-is, and the error is wrapped in a fresh [`ContextError`] with the given
/// message attached.
impl<T, E> ResultExt<T, E> for Result<T, E> {
    fn context(self, msg: impl Into<String>) -> Result<T, ContextError<E>> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(ContextError::new(source).context(msg)),
        }
    }

    fn with_context<F>(self, f: F) -> Result<T, ContextError<E>>
    where
        F: FnOnce() -> String,
    {
        match self {
            Ok(value) => Ok(value),
            // The message closure is only called here, so `Ok` results never
            // pay for building it.
            Err(source) => Err(ContextError::new(source).context(f())),
        }
    }
}
/// Shorthand for a `Result` whose error has been wrapped in a [`ContextError`],
/// i.e. the return type of [`ResultExt::context`] and [`ResultExt::with_context`].
pub type ContextResult<T, E> = std::result::Result<T, ContextError<E>>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Context added via the trait and via `ContextError::context` on an
    /// existing error both end up in the same trail; the root error survives.
    #[test]
    fn test_context_preservation() {
        fn inner() -> Result<(), String> {
            Err("base error".to_string())
        }

        fn middle() -> ContextResult<(), String> {
            inner().context("middle operation")
        }

        fn outer() -> ContextResult<(), String> {
            // Extend the existing ContextError instead of calling
            // `ResultExt::context`, which would nest a second ContextError.
            middle().map_err(|e| e.context("outer operation"))?;
            Ok(())
        }

        let error = outer().expect_err("outer() should propagate the base error");
        assert_eq!(error.inner(), "base error");
        let trail = error.context_trail();
        assert_eq!(trail.len(), 2);
        assert!(trail.contains(&"middle operation".to_string()));
        assert!(trail.contains(&"outer operation".to_string()));
    }

    /// `context` must leave successful results untouched.
    #[test]
    fn test_context_passes_ok_through() {
        let ok: Result<i32, String> = Ok(7);
        assert_eq!(ok.context("unused").ok(), Some(7));
    }

    /// The `with_context` closure runs only on the error path, and its
    /// message is recorded when it does run.
    #[test]
    fn test_with_context_lazy_evaluation() {
        let mut call_count = 0;

        let success: Result<i32, String> = Ok(42);
        let passed = success.with_context(|| {
            call_count += 1;
            "should not be called".to_string()
        });
        assert_eq!(call_count, 0);
        assert_eq!(passed.ok(), Some(42));

        let failure: Result<i32, String> = Err("error".to_string());
        let error = failure
            .with_context(|| {
                call_count += 1;
                "should be called".to_string()
            })
            .expect_err("an Err input must stay an Err");
        assert_eq!(call_count, 1);
        assert_eq!(error.inner(), "error");
        assert!(error
            .context_trail()
            .contains(&"should be called".to_string()));
    }

    /// Three layers of context on one root cause are all kept, without
    /// nesting ContextError values.
    #[test]
    fn test_multiple_context_layers() {
        fn layer1() -> Result<(), String> {
            Err("root cause".to_string())
        }

        fn layer2() -> ContextResult<(), String> {
            layer1().context("layer 2 context")
        }

        fn layer3() -> ContextResult<(), String> {
            layer2().map_err(|e| e.context("layer 3 context"))?;
            Ok(())
        }

        fn layer4() -> ContextResult<(), String> {
            layer3().map_err(|e| e.context("layer 4 context"))?;
            Ok(())
        }

        let error = layer4().expect_err("layer4() should propagate the root cause");
        assert_eq!(error.inner(), "root cause");
        let trail = error.context_trail();
        assert_eq!(trail.len(), 3);
        assert!(trail.contains(&"layer 2 context".to_string()));
        assert!(trail.contains(&"layer 3 context".to_string()));
        assert!(trail.contains(&"layer 4 context".to_string()));
    }
}