Skip to main content

breaker_machines/
classifier.rs

//! Failure classification for error filtering.
//!
//! This module provides traits and types for deciding which errors
//! should trip the circuit breaker and which should be ignored.
5
6use std::any::Any;
7
/// Context provided to failure classifiers for error evaluation.
///
/// Bundles the error of a failed call with metadata about that call. Every
/// field is a shared reference or a plain number, so the context is `Copy`
/// and can be handed to several classifiers (or stored) without rebuilding it.
#[derive(Debug, Clone, Copy)]
pub struct FailureContext<'a> {
    /// Name of the circuit the failure is being classified for.
    pub circuit_name: &'a str,
    /// The error that occurred.
    ///
    /// Use `downcast_ref::<T>()` to recover a concrete error type. Because
    /// `Any` is only implemented for `'static` types, the referenced error
    /// value's type must not borrow non-`'static` data.
    pub error: &'a dyn Any,
    /// Duration of the failed call in seconds.
    pub duration: f64,
}
18
19/// Trait for classifying failures - determines if an error should trip the circuit
20///
21/// Implementors can inspect the error type and context to decide whether
22/// this particular failure should count toward opening the circuit.
23///
24/// # Examples
25///
26/// ```rust
27/// use breaker_machines::{FailureClassifier, FailureContext};
28/// use std::sync::Arc;
29///
30/// #[derive(Debug)]
31/// struct ServerErrorClassifier;
32///
33/// impl FailureClassifier for ServerErrorClassifier {
34///     fn should_trip(&self, ctx: &FailureContext<'_>) -> bool {
35///         // Only trip on server errors (5xx), not client errors (4xx)
36///         // This would require your error type to be downcast-able
37///         true // Default: trip on all errors
38///     }
39/// }
40/// ```
41pub trait FailureClassifier: Send + Sync + std::fmt::Debug {
42    /// Determine if this error should count as a failure for circuit breaker logic
43    ///
44    /// Returns `true` if the error should trip the circuit, `false` to ignore it.
45    fn should_trip(&self, ctx: &FailureContext<'_>) -> bool;
46}
47
48/// Default classifier that trips on all errors
49#[derive(Debug, Clone, Copy)]
50pub struct DefaultClassifier;
51
52impl FailureClassifier for DefaultClassifier {
53    fn should_trip(&self, _ctx: &FailureContext<'_>) -> bool {
54        true // All errors trip the circuit (backward compatible)
55    }
56}
57
58impl Default for DefaultClassifier {
59    fn default() -> Self {
60        Self
61    }
62}
63
64/// Predicate-based classifier using a closure
65///
66/// Allows using simple closures for common filtering patterns.
67pub struct PredicateClassifier<F>
68where
69    F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
70{
71    predicate: F,
72}
73
74impl<F> PredicateClassifier<F>
75where
76    F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
77{
78    /// Create a new predicate-based classifier
79    pub fn new(predicate: F) -> Self {
80        Self { predicate }
81    }
82}
83
84impl<F> FailureClassifier for PredicateClassifier<F>
85where
86    F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
87{
88    fn should_trip(&self, ctx: &FailureContext<'_>) -> bool {
89        (self.predicate)(ctx)
90    }
91}
92
93impl<F> std::fmt::Debug for PredicateClassifier<F>
94where
95    F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
96{
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        f.debug_struct("PredicateClassifier")
99            .field("predicate", &"<closure>")
100            .finish()
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for a circuit named "test" around `error`.
    fn context(error: &dyn Any, duration: f64) -> FailureContext<'_> {
        FailureContext {
            circuit_name: "test",
            error,
            duration,
        }
    }

    #[test]
    fn test_default_classifier_trips_all() {
        let classifier = DefaultClassifier;

        assert!(classifier.should_trip(&context(&"any error", 0.1)));
        // Neither the error type nor the duration matters to the default classifier.
        assert!(classifier.should_trip(&context(&42_u32, 10.0)));
    }

    #[test]
    fn test_predicate_classifier() {
        // Classifier that only trips on slow errors
        let classifier = PredicateClassifier::new(|ctx| ctx.duration > 1.0);

        assert!(!classifier.should_trip(&context(&"fast error", 0.5)));
        assert!(classifier.should_trip(&context(&"slow error", 2.0)));
    }

    #[test]
    fn test_error_type_downcast() {
        #[derive(Debug)]
        struct MyError {
            is_server_error: bool,
        }

        let server_error = MyError {
            is_server_error: true,
        };
        let client_error = MyError {
            is_server_error: false,
        };

        let classifier = PredicateClassifier::new(|ctx| {
            ctx.error
                .downcast_ref::<MyError>()
                .map(|e| e.is_server_error)
                .unwrap_or(true) // Trip on unknown errors
        });

        assert!(classifier.should_trip(&context(&server_error, 0.1)));
        assert!(!classifier.should_trip(&context(&client_error, 0.1)));
        // Errors that are not `MyError` take the fallback branch and trip.
        assert!(classifier.should_trip(&context(&"unrelated error", 0.1)));
    }

    #[test]
    fn test_classifiers_work_as_trait_objects() {
        let default: Box<dyn FailureClassifier> = Box::new(DefaultClassifier);
        let slow_only: Box<dyn FailureClassifier> =
            Box::new(PredicateClassifier::new(|ctx| ctx.duration > 1.0));

        let fast = context(&"fast error", 0.5);
        assert!(default.should_trip(&fast));
        assert!(!slow_only.should_trip(&fast));
    }

    #[test]
    fn test_debug_output() {
        let default = DefaultClassifier;
        assert_eq!(format!("{default:?}"), "DefaultClassifier");

        let classifier = PredicateClassifier::new(|_ctx| true);
        assert_eq!(
            format!("{classifier:?}"),
            "PredicateClassifier { predicate: \"<closure>\" }"
        );
    }
}