use std::sync::Arc;
/// Decides whether the outcome of an operation should be counted as a failure.
///
/// Implementors see the whole `Result`, so they are free to treat selected
/// `Ok` values as failures or selected `Err` values as benign. The
/// `Send + Sync` supertraits allow a classifier to be shared across threads.
pub trait FailureClassifier<Res, Err>: Send + Sync {
    /// Returns `true` when `outcome` should be recorded as a failure.
    fn classify(&self, outcome: &Result<Res, Err>) -> bool;
}
/// Classifier that reports every `Err` as a failure and every `Ok` as a success.
///
/// It carries no state and places no constraints on the `Res`/`Err` types.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultClassifier;

impl<Res, Err> FailureClassifier<Res, Err> for DefaultClassifier {
    fn classify(&self, outcome: &Result<Res, Err>) -> bool {
        matches!(outcome, Err(_))
    }
}
/// Classifier backed by a caller-supplied predicate over the operation's `Result`.
///
/// The predicate returns `true` when the outcome should count as a failure.
/// This makes it possible to treat selected `Ok` values as failures, or
/// selected `Err` values as benign.
///
/// The predicate is stored behind an [`Arc`]. Cloning an `FnClassifier` only
/// bumps a reference count, and the clone shares the same predicate.
pub struct FnClassifier<F> {
    f: Arc<F>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would add an
// `F: Clone` bound. That bound rejects closures capturing non-`Clone` state,
// even though cloning the `Arc` never clones `F` itself.
impl<F> Clone for FnClassifier<F> {
    fn clone(&self) -> Self {
        Self {
            f: Arc::clone(&self.f),
        }
    }
}

impl<F> FnClassifier<F> {
    /// Wraps the predicate `f` as a classifier.
    pub fn new(f: F) -> Self {
        Self { f: Arc::new(f) }
    }
}
/// Delegates classification to the wrapped predicate. The predicate's return
/// value is used as-is: `true` means failure.
impl<F, Res, Err> FailureClassifier<Res, Err> for FnClassifier<F>
where
    F: Fn(&Result<Res, Err>) -> bool + Send + Sync,
{
    fn classify(&self, outcome: &Result<Res, Err>) -> bool {
        // Deref through the `Arc` once, then call the predicate by reference.
        let predicate: &F = &self.f;
        predicate(outcome)
    }
}
impl<F> std::fmt::Debug for FnClassifier<F> {
    /// Formats as `FnClassifier { f: "<closure>" }`. The predicate is opaque,
    /// so a fixed placeholder stands in for it and no `F: Debug` bound is needed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = formatter.debug_struct("FnClassifier");
        view.field("f", &"<closure>");
        view.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `classifier` against an owned outcome. The argument's type picks
    /// the `Res`/`Err` parameters of [`FailureClassifier`].
    fn is_failure<C, Res, Err>(classifier: &C, outcome: Result<Res, Err>) -> bool
    where
        C: FailureClassifier<Res, Err>,
    {
        classifier.classify(&outcome)
    }

    #[test]
    fn default_classifier_treats_errors_as_failures() {
        let success: Result<(), ()> = Ok(());
        let failure: Result<(), ()> = Err(());
        assert!(!is_failure(&DefaultClassifier, success));
        assert!(is_failure(&DefaultClassifier, failure));
    }

    #[test]
    fn default_classifier_works_with_any_types() {
        let classifier = DefaultClassifier;

        let text_ok: Result<String, std::io::Error> = Ok("ok".to_string());
        assert!(!is_failure(&classifier, text_ok));
        let io_err: Result<String, std::io::Error> = Err(std::io::Error::other("fail"));
        assert!(is_failure(&classifier, io_err));

        let int_ok: Result<i32, &str> = Ok(42);
        assert!(!is_failure(&classifier, int_ok));
        let str_err: Result<i32, &str> = Err("error");
        assert!(is_failure(&classifier, str_err));
    }

    #[test]
    fn fn_classifier_custom_logic() {
        let fatal_only = FnClassifier::new(|outcome: &Result<(), String>| match outcome {
            Err(message) => message.contains("fatal"),
            Ok(()) => false,
        });
        let cases = [
            (Ok(()), false),
            (Err("warning".to_string()), false),
            (Err("fatal error".to_string()), true),
        ];
        for (outcome, expected) in cases {
            assert_eq!(fatal_only.classify(&outcome), expected, "{outcome:?}");
        }
    }

    #[test]
    fn fn_classifier_can_treat_some_successes_as_failures() {
        // Everything except a sub-500 status is a failure, including `Err`.
        let server_errors = FnClassifier::new(|outcome: &Result<u16, ()>| {
            !matches!(outcome, Ok(status) if *status < 500)
        });
        for (status, expected) in [(200_u16, false), (404, false), (500, true), (503, true)] {
            assert_eq!(server_errors.classify(&Ok(status)), expected, "status {status}");
        }
        assert!(server_errors.classify(&Err(())));
    }
}