breaker_machines/
classifier.rs1use std::any::Any;
7
/// Information about a single failure, handed to a [`FailureClassifier`].
///
/// Every field is either a borrow or a plain `f64`, so the context is cheap to
/// copy and can be forwarded by value without cloning anything.
#[derive(Debug, Clone, Copy)]
pub struct FailureContext<'a> {
    /// Name of the circuit the failure was recorded against.
    pub circuit_name: &'a str,
    /// The error value, type-erased. Classifiers recover a concrete error type
    /// with `downcast_ref::<T>()`.
    pub error: &'a dyn Any,
    /// Duration associated with the failed call.
    // NOTE(review): units are not visible in this file (presumably seconds) — confirm
    // against the code that records failures.
    pub duration: f64,
}
19pub trait FailureClassifier: Send + Sync + std::fmt::Debug {
42 fn should_trip(&self, ctx: &FailureContext<'_>) -> bool;
46}
47
48#[derive(Debug, Clone, Copy)]
50pub struct DefaultClassifier;
51
52impl FailureClassifier for DefaultClassifier {
53 fn should_trip(&self, _ctx: &FailureContext<'_>) -> bool {
54 true }
56}
57
58impl Default for DefaultClassifier {
59 fn default() -> Self {
60 Self
61 }
62}
63
/// Classifier that delegates the trip decision to a caller-supplied closure.
///
/// The struct itself places no bounds on `F`. The
/// `Fn(&FailureContext<'_>) -> bool + Send + Sync` requirement is stated on the
/// `impl` blocks that actually call the predicate, so naming or storing a
/// `PredicateClassifier<F>` in generic code does not force that bound to be
/// repeated everywhere.
pub struct PredicateClassifier<F> {
    /// The predicate consulted by `should_trip`.
    predicate: F,
}
74impl<F> PredicateClassifier<F>
75where
76 F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
77{
78 pub fn new(predicate: F) -> Self {
80 Self { predicate }
81 }
82}
83
84impl<F> FailureClassifier for PredicateClassifier<F>
85where
86 F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
87{
88 fn should_trip(&self, ctx: &FailureContext<'_>) -> bool {
89 (self.predicate)(ctx)
90 }
91}
92
93impl<F> std::fmt::Debug for PredicateClassifier<F>
94where
95 F: Fn(&FailureContext<'_>) -> bool + Send + Sync,
96{
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("PredicateClassifier")
99 .field("predicate", &"<closure>")
100 .finish()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for the `"test"` circuit around `error` and `duration`.
    fn make_ctx(error: &dyn Any, duration: f64) -> FailureContext<'_> {
        FailureContext {
            circuit_name: "test",
            error,
            duration,
        }
    }

    #[test]
    fn test_default_classifier_trips_all() {
        let classifier = DefaultClassifier;
        assert!(classifier.should_trip(&make_ctx(&"any error", 0.1)));
    }

    #[test]
    fn test_default_classifier_via_default_trait() {
        // Exercises `Default` the way generic callers would use it.
        fn trips<C: FailureClassifier + Default>(ctx: &FailureContext<'_>) -> bool {
            C::default().should_trip(ctx)
        }
        assert!(trips::<DefaultClassifier>(&make_ctx(&"any error", 10.0)));
    }

    #[test]
    fn test_predicate_classifier() {
        let classifier = PredicateClassifier::new(|ctx| ctx.duration > 1.0);

        assert!(!classifier.should_trip(&make_ctx(&"fast error", 0.5)));
        assert!(classifier.should_trip(&make_ctx(&"slow error", 2.0)));
    }

    #[test]
    fn test_predicate_classifier_debug_hides_closure() {
        let classifier = PredicateClassifier::new(|_| true);
        assert_eq!(
            format!("{classifier:?}"),
            r#"PredicateClassifier { predicate: "<closure>" }"#
        );
    }

    #[test]
    fn test_classifiers_as_trait_objects() {
        let classifiers: Vec<Box<dyn FailureClassifier>> = vec![
            Box::new(DefaultClassifier),
            Box::new(PredicateClassifier::new(|ctx| ctx.duration > 1.0)),
        ];
        let fast = make_ctx(&"fast error", 0.5);
        let verdicts: Vec<bool> = classifiers.iter().map(|c| c.should_trip(&fast)).collect();
        assert_eq!(verdicts, [true, false]);
    }

    #[test]
    fn test_error_type_downcast() {
        #[derive(Debug)]
        struct MyError {
            is_server_error: bool,
        }

        let server_error = MyError {
            is_server_error: true,
        };
        let client_error = MyError {
            is_server_error: false,
        };

        // Trip on server errors; errors of any other type fall back to tripping.
        let classifier = PredicateClassifier::new(|ctx| {
            ctx.error
                .downcast_ref::<MyError>()
                .map(|e| e.is_server_error)
                .unwrap_or(true)
        });

        assert!(classifier.should_trip(&make_ctx(&server_error, 0.1)));
        assert!(!classifier.should_trip(&make_ctx(&client_error, 0.1)));
        assert!(classifier.should_trip(&make_ctx(&"not a MyError", 0.1)));
    }
}