Skip to main content

foxtive_worker/middleware/
circuit_breaker.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::Mutex;
5
6use crate::error::WorkerError;
7use crate::message::ReceivedMessage;
8use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
9
/// The three states a circuit breaker can be in.
#[derive(Clone, Debug, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests flow through.
    Closed,

    /// Tripped: requests are rejected.
    Open,

    /// Recovery check: the breaker is testing whether the service has recovered.
    HalfOpen,
}
22
23/// Circuit breaker configuration and state.
24struct CircuitBreakerState {
25    /// Current state of the circuit breaker
26    state: CircuitState,
27    /// Number of consecutive failures
28    failure_count: u32,
29    /// Maximum failures before opening circuit
30    max_failures: u32,
31    /// Time to wait before transitioning from Open to HalfOpen
32    timeout: Duration,
33    /// When the circuit was opened
34    opened_at: Option<Instant>,
35    /// Number of successes in HalfOpen state needed to close
36    success_threshold: u32,
37    /// Current successes in HalfOpen state
38    half_open_successes: u32,
39    /// Flag to ensure only one request is allowed through in HalfOpen state
40    test_request_in_progress: bool,
41}
42
43impl CircuitBreakerState {
44    fn new(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
45        Self {
46            state: CircuitState::Closed,
47            failure_count: 0,
48            max_failures,
49            timeout,
50            opened_at: None,
51            success_threshold,
52            half_open_successes: 0,
53            test_request_in_progress: false, // Initialize the new flag
54        }
55    }
56
57    fn should_allow_request(&mut self) -> bool {
58        match self.state {
59            CircuitState::Closed => true,
60            CircuitState::Open => {
61                // Check if timeout has elapsed
62                if let Some(opened_at) = self.opened_at
63                    && opened_at.elapsed() >= self.timeout
64                {
65                    // Transition to HalfOpen and allow one test request
66                    self.state = CircuitState::HalfOpen;
67                    self.half_open_successes = 0;
68                    self.test_request_in_progress = true; // This request is the test request
69                    return true;
70                }
71                false // Still in Open state, timeout not elapsed
72            }
73            CircuitState::HalfOpen => {
74                // Only allow one request through in HalfOpen state
75                if !self.test_request_in_progress {
76                    self.test_request_in_progress = true;
77                    true
78                } else {
79                    false // Another request is already testing
80                }
81            }
82        }
83    }
84
85    fn record_success(&mut self) {
86        match self.state {
87            CircuitState::Closed => {
88                // Reset failure count on success
89                self.failure_count = 0;
90            }
91            CircuitState::HalfOpen => {
92                // Don't reset test_request_in_progress here - keep blocking additional requests
93                self.half_open_successes += 1;
94                if self.half_open_successes >= self.success_threshold {
95                    // Transition to Closed
96                    self.state = CircuitState::Closed;
97                    self.failure_count = 0;
98                    self.opened_at = None;
99                    self.test_request_in_progress = false; // Reset flag when closing circuit
100                }
101            }
102            CircuitState::Open => {
103                // Should not happen if `should_allow_request` is respected
104            }
105        }
106    }
107
108    fn record_failure(&mut self) {
109        match self.state {
110            CircuitState::Closed => {
111                self.failure_count += 1;
112                if self.failure_count >= self.max_failures {
113                    // Open the circuit
114                    self.state = CircuitState::Open;
115                    self.opened_at = Some(Instant::now());
116                }
117            }
118            CircuitState::HalfOpen => {
119                // Any failure in HalfOpen reopens the circuit
120                self.test_request_in_progress = false; // Test request completed (with failure)
121                self.state = CircuitState::Open;
122                self.opened_at = Some(Instant::now());
123                self.half_open_successes = 0;
124            }
125            CircuitState::Open => {
126                // Already open, do nothing
127            }
128        }
129    }
130
131    fn current_state(&self) -> &CircuitState {
132        &self.state
133    }
134}
135
136/// Middleware that implements the circuit breaker pattern.
137///
138/// The circuit breaker protects downstream services from being overwhelmed
139/// when they're failing. It has three states:
140/// - **Closed**: Normal operation, requests pass through
141/// - **Open**: Requests are rejected immediately (fail fast)
142/// - **HalfOpen**: Testing if service recovered with limited requests
143///
144/// # Example
145/// ```rust,no_run
146/// use foxtive_worker::CircuitBreakerMiddleware;
147/// use std::time::Duration;
148///
149/// // Open circuit after 5 failures, retry after 30 seconds
150/// let middleware = CircuitBreakerMiddleware::new(5, Duration::from_secs(30));
151/// ```
152pub struct CircuitBreakerMiddleware {
153    state: Arc<Mutex<CircuitBreakerState>>,
154    name: String,
155}
156
157impl std::fmt::Debug for CircuitBreakerMiddleware {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("CircuitBreakerMiddleware")
160            .field("name", &self.name)
161            .finish()
162    }
163}
164
165impl CircuitBreakerMiddleware {
166    /// Create a new circuit breaker middleware.
167    ///
168    /// # Arguments
169    /// * `max_failures` - Number of consecutive failures before opening circuit
170    /// * `timeout` - Time to wait before transitioning from Open to HalfOpen
171    pub fn new(max_failures: u32, timeout: Duration) -> Self {
172        Self {
173            state: Arc::new(Mutex::new(CircuitBreakerState::new(
174                max_failures,
175                timeout,
176                1, // Default: 1 success to close
177            ))),
178            name: format!("circuit-breaker-{}failures", max_failures),
179        }
180    }
181
182    /// Create a circuit breaker with custom success threshold.
183    pub fn with_threshold(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
184        Self {
185            state: Arc::new(Mutex::new(CircuitBreakerState::new(
186                max_failures,
187                timeout,
188                success_threshold,
189            ))),
190            name: format!("circuit-breaker-{}failures", max_failures),
191        }
192    }
193
194    /// Get the current circuit state.
195    pub async fn get_state(&self) -> CircuitState {
196        let mut state = self.state.lock().await;
197        // Check if Open state should transition to HalfOpen
198        if state.current_state() == &CircuitState::Open
199            && let Some(opened_at) = state.opened_at
200            && opened_at.elapsed() >= state.timeout
201        {
202            state.state = CircuitState::HalfOpen;
203            state.half_open_successes = 0;
204            // Don't reset test_request_in_progress here - it will be set by should_allow_request
205        }
206        state.current_state().clone()
207    }
208}
209
210#[async_trait]
211impl Middleware for CircuitBreakerMiddleware {
212    fn name(&self) -> &str {
213        &self.name
214    }
215
216    async fn handle(
217        &self,
218        message: ReceivedMessage<serde_json::Value>,
219        next: Box<dyn MessageHandler>,
220    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
221        // Check if request should be allowed
222        {
223            let mut state = self.state.lock().await;
224            if !state.should_allow_request() {
225                return Err(WorkerError::ProcessingFailed(format!(
226                    "Circuit breaker '{}' is open, rejecting request",
227                    self.name
228                )));
229            }
230        }
231
232        // Process the message
233        let result = next.handle(message).await;
234
235        // Record success or failure
236        {
237            let mut state = self.state.lock().await;
238            match &result {
239                Ok(MiddlewareResult::Continue) | Ok(MiddlewareResult::Acknowledged) => {
240                    state.record_success()
241                }
242                Err(_) => state.record_failure(),
243            }
244        }
245
246        result
247    }
248}
249
250#[cfg(test)]
251mod tests {
252    use super::*;
253    use tokio::time;
254
255    struct SuccessHandler;
256
257    #[async_trait]
258    impl MessageHandler for SuccessHandler {
259        async fn handle(
260            &self,
261            _message: ReceivedMessage<serde_json::Value>,
262        ) -> Result<MiddlewareResult, WorkerError> {
263            Ok(MiddlewareResult::Continue)
264        }
265    }
266
267    struct FailureHandler;
268
269    #[async_trait]
270    impl MessageHandler for FailureHandler {
271        async fn handle(
272            &self,
273            _message: ReceivedMessage<serde_json::Value>,
274        ) -> Result<MiddlewareResult, WorkerError> {
275            Err(WorkerError::ProcessingFailed("test failure".to_string()))
276        }
277    }
278
279    fn create_test_message() -> ReceivedMessage<serde_json::Value> {
280        use crate::message::{AckHandle, Message, MessageMetadata};
281
282        #[derive(Debug)]
283        struct MockAckHandle;
284
285        #[async_trait]
286        impl AckHandle for MockAckHandle {
287            async fn ack(&self) -> crate::WorkerResult<()> {
288                Ok(())
289            }
290
291            async fn nack(&self, _requeue: bool) -> crate::WorkerResult<()> {
292                Ok(())
293            }
294        }
295
296        let message = Message {
297            id: "test-1".to_string(),
298            payload: serde_json::json!({"test": "data"}),
299            metadata: MessageMetadata::new("test-queue"),
300        };
301        ReceivedMessage::new(message, Arc::new(MockAckHandle))
302    }
303
304    #[tokio::test]
305    async fn test_circuit_closed_initially() {
306        let middleware = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
307        assert_eq!(middleware.get_state().await, CircuitState::Closed);
308    }
309
310    #[tokio::test]
311    async fn test_circuit_opens_after_max_failures() {
312        let middleware = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
313
314        // Cause 3 failures
315        for _ in 0..3 {
316            let message = create_test_message();
317            let _ = middleware.handle(message, Box::new(FailureHandler)).await;
318        }
319
320        assert_eq!(middleware.get_state().await, CircuitState::Open);
321    }
322
323    #[tokio::test]
324    async fn test_circuit_rejects_when_open() {
325        let middleware = CircuitBreakerMiddleware::new(2, Duration::from_secs(1));
326
327        // Open the circuit
328        for _ in 0..2 {
329            let message = create_test_message();
330            let _ = middleware.handle(message, Box::new(FailureHandler)).await;
331        }
332
333        // Next request should be rejected
334        let message = create_test_message();
335        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
336        assert!(result.is_err());
337        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
338    }
339
340    #[tokio::test]
341    async fn test_circuit_transitions_to_half_open_and_allows_one_request() {
342        let middleware = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
343
344        // Open the circuit
345        for _ in 0..2 {
346            let message = create_test_message();
347            let _ = middleware.handle(message, Box::new(FailureHandler)).await;
348        }
349
350        assert_eq!(middleware.get_state().await, CircuitState::Open);
351
352        // Wait for timeout
353        time::sleep(Duration::from_millis(150)).await;
354
355        // Should transition to HalfOpen and allow the first request
356        let message1 = create_test_message();
357        let result1 = middleware.handle(message1, Box::new(SuccessHandler)).await;
358        assert!(result1.is_ok());
359        assert_eq!(middleware.get_state().await, CircuitState::Closed); // Should close after 1 success with default threshold
360
361        // If it were still HalfOpen, a second request should be rejected
362        let middleware_half_open_test =
363            CircuitBreakerMiddleware::with_threshold(2, Duration::from_millis(100), 2); // Need 2 successes to close
364        for _ in 0..2 {
365            let message = create_test_message();
366            let _ = middleware_half_open_test
367                .handle(message, Box::new(FailureHandler))
368                .await;
369        }
370        time::sleep(Duration::from_millis(150)).await;
371        assert_eq!(
372            middleware_half_open_test.get_state().await,
373            CircuitState::HalfOpen
374        );
375
376        let message_test_1 = create_test_message();
377        assert!(
378            middleware_half_open_test
379                .handle(message_test_1, Box::new(SuccessHandler))
380                .await
381                .is_ok()
382        );
383        assert_eq!(
384            middleware_half_open_test.get_state().await,
385            CircuitState::HalfOpen
386        ); // Still HalfOpen, 1 success recorded
387
388        let message_test_2 = create_test_message();
389        let result_test_2 = middleware_half_open_test
390            .handle(message_test_2, Box::new(SuccessHandler))
391            .await;
392        assert!(result_test_2.is_err()); // Second request should be rejected
393        assert!(matches!(
394            result_test_2,
395            Err(WorkerError::ProcessingFailed(_))
396        ));
397        assert_eq!(
398            middleware_half_open_test.get_state().await,
399            CircuitState::HalfOpen
400        ); // Still HalfOpen
401    }
402
403    #[tokio::test]
404    async fn test_circuit_closes_after_success_in_half_open() {
405        let middleware = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
406
407        // Open the circuit
408        for _ in 0..2 {
409            let message = create_test_message();
410            let _ = middleware.handle(message, Box::new(FailureHandler)).await;
411        }
412
413        // Wait for timeout
414        time::sleep(Duration::from_millis(150)).await;
415
416        // Success in HalfOpen should close circuit (with default success_threshold = 1)
417        let message = create_test_message();
418        middleware
419            .handle(message, Box::new(SuccessHandler))
420            .await
421            .unwrap();
422
423        assert_eq!(middleware.get_state().await, CircuitState::Closed);
424    }
425
426    #[tokio::test]
427    async fn test_circuit_reopens_on_failure_in_half_open() {
428        let middleware = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
429
430        // Open the circuit
431        for _ in 0..2 {
432            let message = create_test_message();
433            let _ = middleware.handle(message, Box::new(FailureHandler)).await;
434        }
435
436        // Wait for timeout
437        time::sleep(Duration::from_millis(150)).await;
438
439        // Failure in HalfOpen should reopen circuit
440        let message = create_test_message();
441        let _ = middleware.handle(message, Box::new(FailureHandler)).await;
442
443        assert_eq!(middleware.get_state().await, CircuitState::Open);
444    }
445}