use std::any::Any;
/// Information about a single failed call, handed to a [`FailureClassifier`] so it can
/// decide whether the failure should trip the circuit.
///
/// Every field is either a shared reference or a plain `f64`, so the context is `Copy`:
/// it can be passed by value or reused for several classifiers without cloning.
#[derive(Debug, Clone, Copy)]
pub struct FailureContext<'a> {
    /// Name of the circuit the failed call went through.
    pub circuit_name: &'a str,
    /// The error value, type-erased. Classifiers recover the concrete type with
    /// [`Any::downcast_ref`](std::any::Any::downcast_ref).
    pub error: &'a dyn Any,
    /// How long the failed call took.
    /// NOTE(review): unit is chosen by the caller (seconds?) — confirm at the call site.
    pub duration: f64,
}
/// Policy that decides whether a failed call should trip a circuit.
///
/// Implementors must be `Send + Sync` so one classifier can be shared between threads,
/// and `Debug` so it can be printed in diagnostics.
pub trait FailureClassifier
where
    Self: Send + Sync + std::fmt::Debug,
{
    /// Returns `true` if the failure described by `ctx` should trip the circuit,
    /// `false` if it should be ignored.
    fn should_trip(&self, ctx: &FailureContext<'_>) -> bool;
}
/// Classifier that treats every failure as trip-worthy.
///
/// `Default` is derived rather than hand-written: for a unit struct the derived impl is
/// identical and clippy's `derivable_impls` lint flags the manual version.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultClassifier;

impl FailureClassifier for DefaultClassifier {
    /// Always returns `true`, regardless of the circuit, error, or duration in `_ctx`.
    fn should_trip(&self, _ctx: &FailureContext<'_>) -> bool {
        true
    }
}
/// Classifier that delegates the trip decision to a user-supplied predicate.
///
/// The struct definition places no bounds on `F`. The
/// `Fn(&FailureContext) -> bool + Send + Sync` requirement lives only on the impls that
/// actually call the predicate, so the bound is not forced onto every mention of the type.
/// The classifier is `Clone` whenever the predicate is, e.g. a closure that captures only
/// `Clone` data.
#[derive(Clone)]
pub struct PredicateClassifier<F> {
    predicate: F,
}

impl<F> PredicateClassifier<F>
where
    F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
{
    /// Wraps `predicate`. It is invoked once per `should_trip` call.
    ///
    /// The `Fn` bound is kept on this constructor so the compiler can infer a closure's
    /// argument type, e.g. `PredicateClassifier::new(|ctx| ctx.duration > 1.0)`.
    pub fn new(predicate: F) -> Self {
        Self { predicate }
    }
}

impl<F> FailureClassifier for PredicateClassifier<F>
where
    F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
{
    /// Returns exactly what the wrapped predicate returns for `ctx`.
    fn should_trip(&self, ctx: &FailureContext<'_>) -> bool {
        (self.predicate)(ctx)
    }
}

impl<F> std::fmt::Debug for PredicateClassifier<F> {
    /// Closures are not `Debug`, so the predicate is rendered as a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PredicateClassifier")
            .field("predicate", &"<closure>")
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_classifier_trips_all() {
        let ctx = FailureContext {
            circuit_name: "test",
            error: &"any error" as &dyn Any,
            duration: 0.1,
        };
        // The unit literal and `Default::default()` must behave identically.
        assert!(DefaultClassifier.should_trip(&ctx));
        assert!(DefaultClassifier::default().should_trip(&ctx));
    }

    #[test]
    fn test_predicate_classifier() {
        let classifier = PredicateClassifier::new(|ctx| ctx.duration > 1.0);
        let fast_ctx = FailureContext {
            circuit_name: "test",
            error: &"fast error" as &dyn Any,
            duration: 0.5,
        };
        let slow_ctx = FailureContext {
            circuit_name: "test",
            error: &"slow error" as &dyn Any,
            duration: 2.0,
        };
        assert!(!classifier.should_trip(&fast_ctx));
        assert!(classifier.should_trip(&slow_ctx));
    }

    #[test]
    fn test_predicate_classifier_debug_hides_closure() {
        let classifier = PredicateClassifier::new(|_ctx| true);
        assert_eq!(
            format!("{:?}", classifier),
            "PredicateClassifier { predicate: \"<closure>\" }"
        );
    }

    #[test]
    fn test_error_type_downcast() {
        #[derive(Debug)]
        struct MyError {
            is_server_error: bool,
        }
        let server_error = MyError {
            is_server_error: true,
        };
        let client_error = MyError {
            is_server_error: false,
        };
        // Trip on server errors, ignore client errors, and trip on anything that is
        // not a `MyError` at all (the `unwrap_or(true)` fallback).
        let classifier = PredicateClassifier::new(|ctx| {
            ctx.error
                .downcast_ref::<MyError>()
                .map(|e| e.is_server_error)
                .unwrap_or(true)
        });
        let server_ctx = FailureContext {
            circuit_name: "test",
            error: &server_error as &dyn Any,
            duration: 0.1,
        };
        let client_ctx = FailureContext {
            circuit_name: "test",
            error: &client_error as &dyn Any,
            duration: 0.1,
        };
        assert!(classifier.should_trip(&server_ctx));
        assert!(!classifier.should_trip(&client_ctx));

        // An error of a type the predicate cannot downcast must take the fallback path.
        let unknown_error = "not a MyError";
        let unknown_ctx = FailureContext {
            circuit_name: "test",
            error: &unknown_error as &dyn Any,
            duration: 0.1,
        };
        assert!(classifier.should_trip(&unknown_ctx));
    }
}