use std::fmt::{Debug, Display};
use crate::{ANNError, ANNResult};
/// A recoverable ("transient") error that the caller must explicitly handle.
///
/// A transient error is consumed, by value, in exactly one of two ways:
///
/// * [`acknowledge`](TransientError::acknowledge) — treat the error as benign and continue;
/// * [`escalate`](TransientError::escalate) — promote it to the non-transient error type `T`.
///
/// In both cases `why` describes the reason for the decision. What an implementation does with
/// it (e.g. logging) is up to that implementation.
///
/// Every method here is `#[track_caller]`. The attribute sits on the trait declarations, so it
/// applies to *all* implementations, including ones that do not repeat it and overrides of the
/// provided `*_with` methods. Inside an implementation, [`std::panic::Location::caller`]
/// therefore reports the code that decided how to handle the error, not the implementation.
pub trait TransientError<T>: Sized + std::fmt::Debug + Send + Sync {
    /// Handle this error as benign, with `why` describing the reason.
    #[track_caller]
    fn acknowledge<D>(self, why: D)
    where
        D: Display;

    /// Like [`acknowledge`](TransientError::acknowledge), but the reason is produced by `why`.
    ///
    /// The provided implementation calls `why` exactly once and forwards the result to
    /// `acknowledge`.
    #[track_caller]
    fn acknowledge_with<F, D>(self, why: F)
    where
        F: FnOnce() -> D,
        D: Display,
    {
        self.acknowledge(why())
    }

    /// Promote this error into the non-transient error `T`, with `why` describing the reason.
    #[track_caller]
    fn escalate<D>(self, why: D) -> T
    where
        D: Display;

    /// Like [`escalate`](TransientError::escalate), but the reason is produced by `why`.
    ///
    /// The provided implementation calls `why` exactly once and forwards the result to
    /// `escalate`.
    #[track_caller]
    fn escalate_with<F, D>(self, why: F) -> T
    where
        F: FnOnce() -> D,
        D: Display,
    {
        self.escalate(why())
    }
}
pub trait ToRanked {
type Transient: TransientError<Self::Error>;
type Error: Into<ANNError> + std::fmt::Debug + Send + Sync;
fn to_ranked(self) -> RankedError<Self::Transient, Self::Error>;
fn from_transient(transient: Self::Transient) -> Self;
fn from_error(error: Self::Error) -> Self;
}
/// The outcome of classifying an error with [`ToRanked::to_ranked`].
///
/// It is either a recoverable `Transient` error `R` or a hard `Error` `E`. The transient error
/// must be acknowledged or escalated into an `E`. Marked `#[must_use]` so a classified error
/// cannot be silently dropped.
#[derive(Debug)]
#[must_use]
pub enum RankedError<R, E>
where
    R: TransientError<E>,
{
    /// A recoverable error. Handle it with [`TransientError::acknowledge`] or
    /// [`TransientError::escalate`].
    Transient(R),
    /// A non-recoverable error.
    Error(E),
}
impl<R, E> ToRanked for RankedError<R, E>
where
R: TransientError<E>,
E: Into<ANNError> + std::fmt::Debug + Send + Sync,
{
type Transient = R;
type Error = E;
fn to_ranked(self) -> Self {
self
}
fn from_transient(transient: <Self as ToRanked>::Transient) -> Self {
Self::Transient(transient)
}
fn from_error(error: <Self as ToRanked>::Error) -> Self {
Self::Error(error)
}
}
/// An uninhabited transient-error type, for error types that are never transient.
///
/// No value of this type can exist, so a `RankedError<NeverTransient, E>` can only ever hold
/// its `Error` variant. The `always_escalate!` macro uses it as the `Transient` type.
#[derive(Debug)]
pub enum NeverTransient {}
/// Implement `ToRanked` for an error type `$T` that is never transient.
///
/// The expansion contains two impls:
///
/// * `TransientError<$T>` for `NeverTransient`, which the `ToRanked` impl requires. Its bodies
///   can never run because `NeverTransient` has no values.
/// * `ToRanked` for `$T`, with `Transient = NeverTransient` and `Error = $T`. `to_ranked`
///   therefore always yields `RankedError::Error(self)`.
///
/// `$T` must satisfy the `ToRanked::Error` bounds (`Into<ANNError> + Debug + Send + Sync`).
/// Invoke the macro at most once per type: a second invocation would repeat both impls.
#[macro_export]
macro_rules! always_escalate {
    ($T:ty) => {
        impl $crate::error::TransientError<$T> for $crate::error::NeverTransient {
            // `NeverTransient` is uninhabited. An empty `match` lets the compiler prove each
            // body unreachable instead of relying on a runtime `unreachable!()` panic.
            fn acknowledge<D>(self, _: D)
            where
                D: std::fmt::Display,
            {
                match self {}
            }
            fn acknowledge_with<F, D>(self, _: F)
            where
                F: FnOnce() -> D,
                D: std::fmt::Display,
            {
                match self {}
            }
            fn escalate<D>(self, _: D) -> $T
            where
                D: std::fmt::Display,
            {
                match self {}
            }
            fn escalate_with<F, D>(self, _: F) -> $T
            where
                F: FnOnce() -> D,
                D: std::fmt::Display,
            {
                match self {}
            }
        }
        impl $crate::error::ToRanked for $T {
            type Transient = $crate::error::NeverTransient;
            type Error = Self;
            fn to_ranked(self) -> $crate::error::RankedError<Self::Transient, Self::Error> {
                $crate::error::RankedError::Error(self)
            }
            fn from_transient(never: $crate::error::NeverTransient) -> Self {
                match never {}
            }
            fn from_error(error: Self) -> Self {
                error
            }
        }
    };
}
/// A crate-local uninhabited error type: a `Result<T, Infallible>` is always `Ok`.
///
/// Other impls in this module let it convert into [`ANNError`], implement `Display` and
/// `std::error::Error`, and rank as always-escalating via `always_escalate!`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Infallible {}
/// Lets an `Infallible` be converted into [`ANNError`] (e.g. via `.into()` or `?`).
///
/// `ToRanked::Error` requires this conversion for the `always_escalate!(Infallible)` impl.
/// The body can never run because `Infallible` has no values. The empty `match` proves that to
/// the compiler instead of relying on a runtime `unreachable!()` panic.
impl From<Infallible> for ANNError {
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
/// Formatting an `Infallible` can never happen because the type has no values. The empty
/// `match` on `*self` proves this at compile time, replacing a runtime `unreachable!()` panic.
impl std::fmt::Display for Infallible {
    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {}
    }
}

/// Lets `Infallible` be used wherever a `std::error::Error` is required.
impl std::error::Error for Infallible {}
impl Infallible {
    /// Extract the success value from a `Result` whose error type is uninhabited.
    ///
    /// This never panics: no `Infallible` value can be constructed, so the `Err` arm is
    /// statically unreachable.
    pub fn match_infallible<T>(x: Result<T, Infallible>) -> T {
        match x {
            Ok(value) => value,
            Err(never) => match never {},
        }
    }
}
// Rank `Infallible` as always-escalating (`Transient = NeverTransient`, `Error = Infallible`).
// This gives `Result<_, Infallible>` the `ErrorExt` methods through the blanket impl below.
always_escalate!(Infallible);
/// Extension methods that make the caller decide how a transient error is handled.
///
/// For the blanket implementation on `Result<T, E>` (with `E: ToRanked`):
///
/// * hard errors are always converted into [`ANNError`] and returned as `Err`;
/// * `allow_transient*` acknowledges a transient error and returns `Ok(None)`;
/// * `escalate*` escalates a transient error and returns it as `Err`.
pub trait ErrorExt<T> {
    /// Return `Ok(Some(value))` on success. On a transient error, acknowledge it with `why`
    /// and return `Ok(None)`. Otherwise return the hard error.
    fn allow_transient<D: Display>(self, why: D) -> ANNResult<Option<T>>;

    /// Like [`ErrorExt::allow_transient`], but the reason is produced lazily by `why`.
    fn allow_transient_with<F: FnOnce() -> D, D: Display>(
        self,
        why: F,
    ) -> ANNResult<Option<T>>;

    /// Return `Ok(value)` on success. On a transient error, escalate it with `why` and return
    /// it as `Err`. Otherwise return the hard error.
    fn escalate<D: Display>(self, why: D) -> ANNResult<T>;

    /// Like [`ErrorExt::escalate`], but the reason is produced lazily by `why`.
    fn escalate_with<F: FnOnce() -> D, D: Display>(self, why: F) -> ANNResult<T>;
}
impl<T, E> ErrorExt<T> for Result<T, E>
where
E: ToRanked,
{
#[track_caller]
fn allow_transient<D>(self, why: D) -> ANNResult<Option<T>>
where
D: Display,
{
match self {
Ok(v) => Ok(Some(v)),
Err(err) => match err.to_ranked() {
RankedError::Transient(transient) => {
transient.acknowledge(why);
Ok(None)
}
RankedError::Error(err) => Err(err.into()),
},
}
}
#[track_caller]
fn allow_transient_with<F, D>(self, why: F) -> ANNResult<Option<T>>
where
F: FnOnce() -> D,
D: Display,
{
match self {
Ok(v) => Ok(Some(v)),
Err(err) => match err.to_ranked() {
RankedError::Transient(transient) => {
transient.acknowledge_with(why);
Ok(None)
}
RankedError::Error(err) => Err(err.into()),
},
}
}
#[track_caller]
fn escalate<D>(self, why: D) -> ANNResult<T>
where
D: Display,
{
match self {
Ok(v) => Ok(v),
Err(err) => match err.to_ranked() {
RankedError::Transient(transient) => Err(transient.escalate(why).into()),
RankedError::Error(err) => Err(err.into()),
},
}
}
#[track_caller]
fn escalate_with<F, D>(self, why: F) -> ANNResult<T>
where
F: FnOnce() -> D,
D: Display,
{
match self {
Ok(v) => Ok(v),
Err(err) => match err.to_ranked() {
RankedError::Transient(transient) => Err(transient.escalate_with(why).into()),
RankedError::Error(err) => Err(err.into()),
},
}
}
}
#[cfg(test)]
mod tests {
// Unit tests for `TransientError`, `ToRanked`/`RankedError`, `always_escalate!`, `ErrorExt` and
// `Infallible`.
use std::sync::Mutex;
use thiserror::Error;
use super::*;
// A hard error with no transient form. `always_escalate!` gives it a `ToRanked` impl whose
// `Transient` type is the uninhabited `NeverTransient`.
#[derive(Debug, Clone, Copy, Error)]
#[error("generic error message: {0}")]
struct AlwaysEscalate(usize);
impl From<AlwaysEscalate> for ANNError {
fn from(value: AlwaysEscalate) -> ANNError {
ANNError::log_index_error(value)
}
}
always_escalate!(AlwaysEscalate);
#[test]
fn test_always_escalate() {
// `NeverTransient` is uninhabited, so the `Transient` variant should add no space to the
// ranked error. Ranking must always produce the `Error` variant.
assert_eq!(
std::mem::size_of::<RankedError<NeverTransient, AlwaysEscalate>>(),
std::mem::size_of::<AlwaysEscalate>()
);
let r = AlwaysEscalate(10).to_ranked();
assert!(matches!(r, RankedError::Error(AlwaysEscalate(10))));
}
// A transient error that enforces "handle exactly once": its `Drop` panics unless exactly one
// of `acknowledged` / `escalated` has been set. Each acknowledge/escalate pushes
// `(message, caller line)` into `messages`.
#[derive(Debug, Error)]
#[error(
"Bomb: value = {}, ack = {}, escalated = {}",
value,
acknowledged,
escalated
)]
struct Bomb<'a> {
messages: &'a Mutex<Vec<(String, u32)>>,
acknowledged: bool,
escalated: bool,
value: u64,
}
impl<'a> Bomb<'a> {
fn new(messages: &'a Mutex<Vec<(String, u32)>>, value: u64) -> Self {
Self {
messages,
acknowledged: false,
escalated: false,
value,
}
}
}
impl Drop for Bomb<'_> {
fn drop(&mut self) {
if !self.acknowledged && !self.escalated {
panic!("Bomb error was neither acknowledged nor escalated");
}
if self.acknowledged && self.escalated {
panic!("Bomb error was both acknowledged and escalated");
}
}
}
// The hard error produced by escalating a `Bomb` (or by ranking a non-transient
// `MaybeTransient`). It has no drop check.
#[derive(Debug, Error)]
#[error("Disarmed: value = {}", value)]
struct Disarmed<'a> {
messages: &'a Mutex<Vec<(String, u32)>>,
value: u64,
}
impl<'a> Disarmed<'a> {
fn new(messages: &'a Mutex<Vec<(String, u32)>>, value: u64) -> Self {
Self { messages, value }
}
}
// Both methods are `#[track_caller]`, so the recorded line is the caller's line. The tests
// below check it against the line of the `ErrorExt` method call.
impl<'a> TransientError<Disarmed<'a>> for Bomb<'a> {
#[track_caller]
fn acknowledge<D>(mut self, why: D)
where
D: Display,
{
self.acknowledged = true;
let mut v = self.messages.lock().unwrap();
let location = std::panic::Location::caller();
v.push((format!("acknowledged: {}", why), location.line()))
}
#[track_caller]
fn escalate<D>(mut self, why: D) -> Disarmed<'a>
where
D: Display,
{
self.escalated = true;
let mut v = self.messages.lock().unwrap();
let location = std::panic::Location::caller();
v.push((format!("escalated: {}", why), location.line()));
Disarmed {
messages: self.messages,
value: self.value,
}
}
}
impl From<Disarmed<'_>> for ANNError {
#[track_caller]
fn from(value: Disarmed<'_>) -> ANNError {
ANNError::log_index_error(&value)
}
}
// An error that ranks as a transient `Bomb` when `transient` is true and as a hard `Disarmed`
// otherwise.
struct MaybeTransient<'a> {
messages: &'a Mutex<Vec<(String, u32)>>,
value: u64,
transient: bool,
}
impl<'a> MaybeTransient<'a> {
fn new(messages: &'a Mutex<Vec<(String, u32)>>, value: u64, transient: bool) -> Self {
Self {
messages,
value,
transient,
}
}
}
impl<'a> ToRanked for MaybeTransient<'a> {
type Transient = Bomb<'a>;
type Error = Disarmed<'a>;
fn to_ranked(self) -> RankedError<Self::Transient, Self::Error> {
if self.transient {
RankedError::Transient(Bomb::new(self.messages, self.value))
} else {
RankedError::Error(Disarmed::new(self.messages, self.value))
}
}
// Marks the consumed bomb as acknowledged so dropping it here does not panic.
fn from_transient(mut transient: Bomb<'a>) -> Self {
transient.acknowledged = true;
Self::new(transient.messages, transient.value, true)
}
fn from_error(error: Disarmed<'a>) -> Self {
Self::new(error.messages, error.value, false)
}
}
// Ranking an already-ranked error must keep the variant unchanged.
#[test]
fn to_ranked_idempotent() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let v = MaybeTransient::new(&messages, 10, true).to_ranked();
assert!(matches!(v, RankedError::Transient(..)));
match v.to_ranked() {
RankedError::Transient(v) => v.acknowledge(""),
_ => panic!("wrong variant"),
}
let v = MaybeTransient::new(&messages, 10, false).to_ranked();
assert!(matches!(v, RankedError::Error(..)));
let v = v.to_ranked();
assert!(matches!(v, RankedError::Error(..)));
}
// `Ok` inputs: the value passes through unchanged and nothing is recorded.
#[test]
fn error_ext_allow_transient_ok() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let v: usize = Result::<usize, MaybeTransient<'_>>::Ok(10)
.allow_transient("hello")
.unwrap()
.unwrap();
assert_eq!(v, 10);
assert!(messages.lock().unwrap().is_empty());
}
// On `Ok`, the lazy `why` closure must not be called.
#[test]
fn error_ext_allow_transient_with_ok() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let v: usize = Result::<usize, MaybeTransient<'_>>::Ok(10)
.allow_transient_with(|| -> &str {
panic!("this should not be called!");
})
.unwrap()
.unwrap();
assert_eq!(v, 10);
assert!(messages.lock().unwrap().is_empty());
}
#[test]
fn error_ext_escalate_ok() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let v: usize = Result::<usize, MaybeTransient<'_>>::Ok(10)
.escalate("hello")
.unwrap();
assert_eq!(v, 10);
assert!(messages.lock().unwrap().is_empty());
}
#[test]
fn error_ext_escalate_with_ok() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let v: usize = Result::<usize, MaybeTransient<'_>>::Ok(10)
.escalate_with(|| -> &str {
panic!("this should not be called");
})
.unwrap();
assert_eq!(v, 10);
assert!(messages.lock().unwrap().is_empty());
}
// Transient inputs: exactly one message is recorded, attributed to the line of the `ErrorExt`
// method call, which is 3 lines after `line!()`. Do not insert lines between the `line!()`
// statement and that call.
#[test]
fn error_ext_allow_transient_transient() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let why = "foo";
let line = line!();
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, true))
.allow_transient(why)
.unwrap()
.is_none()
);
let m = messages.lock().unwrap();
assert_eq!(m.len(), 1);
assert_eq!(m[0].1, line + 3);
assert_eq!(m[0].0, format!("acknowledged: {}", why));
}
// Same as above for the lazy form; the closure must be called exactly on this path.
#[test]
fn error_ext_allow_transient_with_transient() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let why = "foo";
let mut called: bool = false;
let line = line!();
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, true))
.allow_transient_with(|| {
called = true;
why
})
.unwrap()
.is_none()
);
assert!(called);
let m = messages.lock().unwrap();
assert_eq!(m.len(), 1);
assert_eq!(m[0].1, line + 3);
assert_eq!(m[0].0, format!("acknowledged: {}", why));
}
// Escalating a transient error yields `Err` and records an "escalated" message at the call line.
#[test]
fn error_ext_escalate_transient() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let why = "foo";
let line = line!();
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, true))
.escalate(why)
.is_err()
);
let m = messages.lock().unwrap();
assert_eq!(m.len(), 1);
assert_eq!(m[0].1, line + 3);
assert_eq!(m[0].0, format!("escalated: {}", why));
}
#[test]
fn error_ext_escalate_with_transient() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let why = "foo";
let mut called: bool = false;
let line = line!();
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, true))
.escalate_with(|| {
called = true;
why
})
.is_err()
);
assert!(called);
let m = messages.lock().unwrap();
assert_eq!(m.len(), 1);
assert_eq!(m[0].1, line + 3);
assert_eq!(m[0].0, format!("escalated: {}", why));
}
// Hard-error inputs: returned as `Err` without acknowledging or escalating, so nothing is
// recorded and no `why` closure is called.
#[test]
fn error_ext_allow_transient_error() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let why = "foo";
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, false))
.allow_transient(why)
.is_err()
);
assert!(messages.lock().unwrap().is_empty());
}
#[test]
fn error_ext_allow_transient_with_error() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, false))
.allow_transient_with(|| -> &str {
panic!("should not be called");
})
.is_err()
);
assert!(messages.lock().unwrap().is_empty());
}
#[test]
fn error_ext_escalate_error() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
let why = "foo";
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, false))
.escalate(why)
.is_err()
);
assert!(messages.lock().unwrap().is_empty());
}
#[test]
fn error_ext_escalate_with_error() {
let messages = Mutex::new(Vec::<(String, u32)>::new());
assert!(
Result::<usize, MaybeTransient<'_>>::Err(MaybeTransient::new(&messages, 10, false))
.escalate_with(|| -> &str {
panic!("should not be called");
})
.is_err()
);
assert!(messages.lock().unwrap().is_empty());
}
// `match_infallible` unwraps `Ok` values of any type.
#[test]
fn test_infallible() {
let result: Result<i32, Infallible> = Ok(42);
let value = Infallible::match_infallible(result);
assert_eq!(value, 42);
let result: Result<String, Infallible> = Ok("hello".to_string());
let value = Infallible::match_infallible(result);
assert_eq!(value, "hello");
// NOTE(review): unused helper; presumably kept only so this signature is type-checked.
fn _test_infallible_into_ann_error(_: Infallible) -> ANNError {
ANNError::log_index_error("This should never be called")
}
}
}