// foxtive_worker/middleware/circuit_breaker.rs
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::Mutex;

use crate::error::WorkerError;
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};

/// Position of a circuit breaker in its Closed → Open → HalfOpen cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are admitted; consecutive failures are counted toward tripping the breaker.
    Closed,

    /// Requests are rejected until the configured timeout has elapsed since the breaker
    /// opened.
    Open,

    /// A limited number of trial requests are admitted to test whether downstream
    /// processing has recovered; enough successes close the circuit and a failure
    /// re-opens it.
    HalfOpen,
}

23struct CircuitBreakerState {
25 state: CircuitState,
27 failure_count: u32,
29 max_failures: u32,
31 timeout: Duration,
33 opened_at: Option<Instant>,
35 success_threshold: u32,
37 half_open_successes: u32,
39 test_request_in_progress: bool,
41}
42
43impl CircuitBreakerState {
44 fn new(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
45 Self {
46 state: CircuitState::Closed,
47 failure_count: 0,
48 max_failures,
49 timeout,
50 opened_at: None,
51 success_threshold,
52 half_open_successes: 0,
53 test_request_in_progress: false, }
55 }
56
57 fn should_allow_request(&mut self) -> bool {
58 match self.state {
59 CircuitState::Closed => true,
60 CircuitState::Open => {
61 if let Some(opened_at) = self.opened_at
63 && opened_at.elapsed() >= self.timeout
64 {
65 self.state = CircuitState::HalfOpen;
67 self.half_open_successes = 0;
68 self.test_request_in_progress = true; return true;
70 }
71 false }
73 CircuitState::HalfOpen => {
74 if !self.test_request_in_progress {
76 self.test_request_in_progress = true;
77 true
78 } else {
79 false }
81 }
82 }
83 }
84
85 fn record_success(&mut self) {
86 match self.state {
87 CircuitState::Closed => {
88 self.failure_count = 0;
90 }
91 CircuitState::HalfOpen => {
92 self.half_open_successes += 1;
94 if self.half_open_successes >= self.success_threshold {
95 self.state = CircuitState::Closed;
97 self.failure_count = 0;
98 self.opened_at = None;
99 self.test_request_in_progress = false; }
101 }
102 CircuitState::Open => {
103 }
105 }
106 }
107
108 fn record_failure(&mut self) {
109 match self.state {
110 CircuitState::Closed => {
111 self.failure_count += 1;
112 if self.failure_count >= self.max_failures {
113 self.state = CircuitState::Open;
115 self.opened_at = Some(Instant::now());
116 }
117 }
118 CircuitState::HalfOpen => {
119 self.test_request_in_progress = false; self.state = CircuitState::Open;
122 self.opened_at = Some(Instant::now());
123 self.half_open_successes = 0;
124 }
125 CircuitState::Open => {
126 }
128 }
129 }
130
131 fn current_state(&self) -> &CircuitState {
132 &self.state
133 }
134}
135
136pub struct CircuitBreakerMiddleware {
153 state: Arc<Mutex<CircuitBreakerState>>,
154 name: String,
155}
156
157impl std::fmt::Debug for CircuitBreakerMiddleware {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("CircuitBreakerMiddleware")
160 .field("name", &self.name)
161 .finish()
162 }
163}
164
165impl CircuitBreakerMiddleware {
166 pub fn new(max_failures: u32, timeout: Duration) -> Self {
172 Self {
173 state: Arc::new(Mutex::new(CircuitBreakerState::new(
174 max_failures,
175 timeout,
176 1, ))),
178 name: format!("circuit-breaker-{}failures", max_failures),
179 }
180 }
181
182 pub fn with_threshold(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
184 Self {
185 state: Arc::new(Mutex::new(CircuitBreakerState::new(
186 max_failures,
187 timeout,
188 success_threshold,
189 ))),
190 name: format!("circuit-breaker-{}failures", max_failures),
191 }
192 }
193
194 pub async fn get_state(&self) -> CircuitState {
196 let mut state = self.state.lock().await;
197 if state.current_state() == &CircuitState::Open
199 && let Some(opened_at) = state.opened_at
200 && opened_at.elapsed() >= state.timeout
201 {
202 state.state = CircuitState::HalfOpen;
203 state.half_open_successes = 0;
204 }
206 state.current_state().clone()
207 }
208}
209
210#[async_trait]
211impl Middleware for CircuitBreakerMiddleware {
212 fn name(&self) -> &str {
213 &self.name
214 }
215
216 async fn handle(
217 &self,
218 message: ReceivedMessage<serde_json::Value>,
219 next: Box<dyn MessageHandler>,
220 ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
221 {
223 let mut state = self.state.lock().await;
224 if !state.should_allow_request() {
225 return Err(WorkerError::ProcessingFailed(format!(
226 "Circuit breaker '{}' is open, rejecting request",
227 self.name
228 )));
229 }
230 }
231
232 let result = next.handle(message).await;
234
235 {
237 let mut state = self.state.lock().await;
238 match &result {
239 Ok(MiddlewareResult::Continue) | Ok(MiddlewareResult::Acknowledged) => state.record_success(),
240 Err(_) => state.record_failure(),
241 }
242 }
243
244 result
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;

    /// Downstream handler that always reports success.
    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Downstream handler that always reports a processing failure.
    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed("test failure".to_string()))
        }
    }

    /// Builds a throwaway message whose ack handle does nothing.
    fn create_test_message() -> ReceivedMessage<serde_json::Value> {
        use crate::message::{AckHandle, Message, MessageMetadata};

        #[derive(Debug)]
        struct MockAckHandle;

        #[async_trait]
        impl AckHandle for MockAckHandle {
            async fn ack(&self) -> crate::WorkerResult<()> {
                Ok(())
            }

            async fn nack(&self, _requeue: bool) -> crate::WorkerResult<()> {
                Ok(())
            }
        }

        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(message, Arc::new(MockAckHandle))
    }

    /// Pushes `times` fresh messages through `breaker` into a failing handler.
    async fn fail_repeatedly(breaker: &CircuitBreakerMiddleware, times: usize) {
        for _ in 0..times {
            let _ = breaker
                .handle(create_test_message(), Box::new(FailureHandler))
                .await;
        }
    }

    /// Pushes one fresh message through `breaker` into a succeeding handler.
    async fn send_ok(
        breaker: &CircuitBreakerMiddleware,
    ) -> Result<MiddlewareResult, WorkerError> {
        breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await
    }

    #[tokio::test]
    async fn test_circuit_closed_initially() {
        let breaker = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_opens_after_max_failures() {
        let breaker = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
        fail_repeatedly(&breaker, 3).await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_rejects_when_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_secs(1));
        fail_repeatedly(&breaker, 2).await;

        let result = send_ok(&breaker).await;
        assert!(result.is_err());
        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
    }

    #[tokio::test]
    async fn test_circuit_transitions_to_half_open_and_allows_one_request() {
        // Default threshold: a single successful trial closes the circuit.
        let single = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
        fail_repeatedly(&single, 2).await;
        assert_eq!(single.get_state().await, CircuitState::Open);

        time::sleep(Duration::from_millis(150)).await;

        assert!(send_ok(&single).await.is_ok());
        assert_eq!(single.get_state().await, CircuitState::Closed);

        // Threshold of two: after one successful trial the circuit is still half-open,
        // and an immediate follow-up request is turned away.
        let double = CircuitBreakerMiddleware::with_threshold(2, Duration::from_millis(100), 2);
        fail_repeatedly(&double, 2).await;
        time::sleep(Duration::from_millis(150)).await;
        assert_eq!(double.get_state().await, CircuitState::HalfOpen);

        assert!(send_ok(&double).await.is_ok());
        assert_eq!(double.get_state().await, CircuitState::HalfOpen);

        let follow_up = send_ok(&double).await;
        assert!(follow_up.is_err());
        assert!(matches!(follow_up, Err(WorkerError::ProcessingFailed(_))));
        assert_eq!(double.get_state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_closes_after_success_in_half_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
        fail_repeatedly(&breaker, 2).await;

        time::sleep(Duration::from_millis(150)).await;

        send_ok(&breaker).await.unwrap();
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_reopens_on_failure_in_half_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
        fail_repeatedly(&breaker, 2).await;

        time::sleep(Duration::from_millis(150)).await;

        fail_repeatedly(&breaker, 1).await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }
}