foxtive_worker/middleware/
circuit_breaker.rs1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::Mutex;
5
6use crate::error::WorkerError;
7use crate::message::ReceivedMessage;
8use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
9
/// Position of a circuit breaker in its Closed → Open → HalfOpen cycle.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: every request is passed through and failures are counted.
    Closed,

    /// Tripped: requests are rejected until the open timeout elapses.
    Open,

    /// Recovering: only trial requests are admitted; enough successes close the
    /// circuit again, while a failure reopens it.
    HalfOpen,
}
22
23struct CircuitBreakerState {
25 state: CircuitState,
27 failure_count: u32,
29 max_failures: u32,
31 timeout: Duration,
33 opened_at: Option<Instant>,
35 success_threshold: u32,
37 half_open_successes: u32,
39 test_request_in_progress: bool,
41}
42
43impl CircuitBreakerState {
44 fn new(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
45 Self {
46 state: CircuitState::Closed,
47 failure_count: 0,
48 max_failures,
49 timeout,
50 opened_at: None,
51 success_threshold,
52 half_open_successes: 0,
53 test_request_in_progress: false, }
55 }
56
57 fn should_allow_request(&mut self) -> bool {
58 match self.state {
59 CircuitState::Closed => true,
60 CircuitState::Open => {
61 if let Some(opened_at) = self.opened_at
63 && opened_at.elapsed() >= self.timeout
64 {
65 self.state = CircuitState::HalfOpen;
67 self.half_open_successes = 0;
68 self.test_request_in_progress = true; return true;
70 }
71 false }
73 CircuitState::HalfOpen => {
74 if !self.test_request_in_progress {
76 self.test_request_in_progress = true;
77 true
78 } else {
79 false }
81 }
82 }
83 }
84
85 fn record_success(&mut self) {
86 match self.state {
87 CircuitState::Closed => {
88 self.failure_count = 0;
90 }
91 CircuitState::HalfOpen => {
92 self.half_open_successes += 1;
94 if self.half_open_successes >= self.success_threshold {
95 self.state = CircuitState::Closed;
97 self.failure_count = 0;
98 self.opened_at = None;
99 self.test_request_in_progress = false; }
101 }
102 CircuitState::Open => {
103 }
105 }
106 }
107
108 fn record_failure(&mut self) {
109 match self.state {
110 CircuitState::Closed => {
111 self.failure_count += 1;
112 if self.failure_count >= self.max_failures {
113 self.state = CircuitState::Open;
115 self.opened_at = Some(Instant::now());
116 }
117 }
118 CircuitState::HalfOpen => {
119 self.test_request_in_progress = false; self.state = CircuitState::Open;
122 self.opened_at = Some(Instant::now());
123 self.half_open_successes = 0;
124 }
125 CircuitState::Open => {
126 }
128 }
129 }
130
131 fn current_state(&self) -> &CircuitState {
132 &self.state
133 }
134}
135
136pub struct CircuitBreakerMiddleware {
153 state: Arc<Mutex<CircuitBreakerState>>,
154 name: String,
155}
156
157impl std::fmt::Debug for CircuitBreakerMiddleware {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("CircuitBreakerMiddleware")
160 .field("name", &self.name)
161 .finish()
162 }
163}
164
165impl CircuitBreakerMiddleware {
166 pub fn new(max_failures: u32, timeout: Duration) -> Self {
172 Self {
173 state: Arc::new(Mutex::new(CircuitBreakerState::new(
174 max_failures,
175 timeout,
176 1, ))),
178 name: format!("circuit-breaker-{}failures", max_failures),
179 }
180 }
181
182 pub fn with_threshold(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
184 Self {
185 state: Arc::new(Mutex::new(CircuitBreakerState::new(
186 max_failures,
187 timeout,
188 success_threshold,
189 ))),
190 name: format!("circuit-breaker-{}failures", max_failures),
191 }
192 }
193
194 pub async fn get_state(&self) -> CircuitState {
196 let mut state = self.state.lock().await;
197 if state.current_state() == &CircuitState::Open
199 && let Some(opened_at) = state.opened_at
200 && opened_at.elapsed() >= state.timeout
201 {
202 state.state = CircuitState::HalfOpen;
203 state.half_open_successes = 0;
204 }
206 state.current_state().clone()
207 }
208}
209
210#[async_trait]
211impl Middleware for CircuitBreakerMiddleware {
212 fn name(&self) -> &str {
213 &self.name
214 }
215
216 async fn handle(
217 &self,
218 message: ReceivedMessage<serde_json::Value>,
219 next: Box<dyn MessageHandler>,
220 ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
221 {
223 let mut state = self.state.lock().await;
224 if !state.should_allow_request() {
225 return Err(WorkerError::ProcessingFailed(format!(
226 "Circuit breaker '{}' is open, rejecting request",
227 self.name
228 )));
229 }
230 }
231
232 let result = next.handle(message).await;
234
235 {
237 let mut state = self.state.lock().await;
238 match &result {
239 Ok(MiddlewareResult::Continue) | Ok(MiddlewareResult::Acknowledged) => {
240 state.record_success()
241 }
242 Err(_) => state.record_failure(),
243 }
244 }
245
246 result
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;

    /// Downstream handler that always succeeds.
    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Downstream handler that always fails.
    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed("test failure".to_string()))
        }
    }

    /// Builds a throwaway message whose ack/nack are no-ops.
    fn create_test_message() -> ReceivedMessage<serde_json::Value> {
        use crate::message::{AckHandle, Message, MessageMetadata};

        #[derive(Debug)]
        struct MockAckHandle;

        #[async_trait]
        impl AckHandle for MockAckHandle {
            async fn ack(&self) -> crate::WorkerResult<()> {
                Ok(())
            }

            async fn nack(&self, _requeue: bool) -> crate::WorkerResult<()> {
                Ok(())
            }
        }

        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(message, Arc::new(MockAckHandle))
    }

    /// Pushes `count` messages through `breaker` to a failing handler, ignoring results.
    async fn fail_times(breaker: &CircuitBreakerMiddleware, count: usize) {
        for _ in 0..count {
            let _ = breaker
                .handle(create_test_message(), Box::new(FailureHandler))
                .await;
        }
    }

    /// Sends one message through `breaker` to a handler that succeeds.
    async fn send_success(
        breaker: &CircuitBreakerMiddleware,
    ) -> Result<MiddlewareResult, WorkerError> {
        breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await
    }

    #[tokio::test]
    async fn test_circuit_closed_initially() {
        let breaker = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_opens_after_max_failures() {
        let breaker = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
        fail_times(&breaker, 3).await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_rejects_when_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_secs(1));
        fail_times(&breaker, 2).await;

        let rejected = send_success(&breaker).await;
        assert!(rejected.is_err());
        assert!(matches!(rejected, Err(WorkerError::ProcessingFailed(_))));
    }

    #[tokio::test]
    async fn test_circuit_transitions_to_half_open_and_allows_one_request() {
        // Default threshold: a single successful probe closes the circuit.
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
        fail_times(&breaker, 2).await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        time::sleep(Duration::from_millis(150)).await;

        assert!(send_success(&breaker).await.is_ok());
        assert_eq!(breaker.get_state().await, CircuitState::Closed);

        // Threshold of two: the first probe succeeds but leaves the circuit
        // half-open, and an immediate second request is refused.
        let strict = CircuitBreakerMiddleware::with_threshold(2, Duration::from_millis(100), 2);
        fail_times(&strict, 2).await;
        time::sleep(Duration::from_millis(150)).await;
        assert_eq!(strict.get_state().await, CircuitState::HalfOpen);

        assert!(send_success(&strict).await.is_ok());
        assert_eq!(strict.get_state().await, CircuitState::HalfOpen);

        let second = send_success(&strict).await;
        assert!(second.is_err());
        assert!(matches!(second, Err(WorkerError::ProcessingFailed(_))));
        assert_eq!(strict.get_state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_closes_after_success_in_half_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
        fail_times(&breaker, 2).await;

        time::sleep(Duration::from_millis(150)).await;

        send_success(&breaker).await.unwrap();
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_reopens_on_failure_in_half_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
        fail_times(&breaker, 2).await;

        time::sleep(Duration::from_millis(150)).await;

        fail_times(&breaker, 1).await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }
}