Skip to main content

foxtive_worker/middleware/
circuit_breaker.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::Mutex;
5
6use crate::error::WorkerError;
7use crate::message::ReceivedMessage;
8use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
9
/// The three states a circuit breaker can be in.
///
/// `Eq` is derived alongside `PartialEq` because equality on this field-less
/// enum is total; this lets the type be used wherever an `Eq` bound is
/// required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitState {
    /// Circuit is closed, requests flow normally.
    Closed,

    /// Circuit is open, requests are rejected.
    Open,

    /// Circuit is half-open, testing if the service recovered.
    HalfOpen,
}
22
23/// Circuit breaker configuration and state.
24struct CircuitBreakerState {
25    /// Current state of the circuit breaker
26    state: CircuitState,
27    /// Number of consecutive failures
28    failure_count: u32,
29    /// Maximum failures before opening circuit
30    max_failures: u32,
31    /// Time to wait before transitioning from Open to HalfOpen
32    timeout: Duration,
33    /// When the circuit was opened
34    opened_at: Option<Instant>,
35    /// Number of successes in HalfOpen state needed to close
36    success_threshold: u32,
37    /// Current successes in HalfOpen state
38    half_open_successes: u32,
39    /// Flag to ensure only one request is allowed through in HalfOpen state
40    test_request_in_progress: bool,
41}
42
43impl CircuitBreakerState {
44    fn new(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
45        Self {
46            state: CircuitState::Closed,
47            failure_count: 0,
48            max_failures,
49            timeout,
50            opened_at: None,
51            success_threshold,
52            half_open_successes: 0,
53            test_request_in_progress: false, // Initialize the new flag
54        }
55    }
56
57    fn should_allow_request(&mut self) -> bool {
58        match self.state {
59            CircuitState::Closed => true,
60            CircuitState::Open => {
61                // Check if timeout has elapsed
62                if let Some(opened_at) = self.opened_at
63                    && opened_at.elapsed() >= self.timeout
64                {
65                    // Transition to HalfOpen and allow one test request
66                    self.state = CircuitState::HalfOpen;
67                    self.half_open_successes = 0;
68                    self.test_request_in_progress = true; // This request is the test request
69                    return true;
70                }
71                false // Still in Open state, timeout not elapsed
72            }
73            CircuitState::HalfOpen => {
74                // Only allow one request through in HalfOpen state
75                if !self.test_request_in_progress {
76                    self.test_request_in_progress = true;
77                    true
78                } else {
79                    false // Another request is already testing
80                }
81            }
82        }
83    }
84
85    fn record_success(&mut self) {
86        match self.state {
87            CircuitState::Closed => {
88                // Reset failure count on success
89                self.failure_count = 0;
90            }
91            CircuitState::HalfOpen => {
92                // Don't reset test_request_in_progress here - keep blocking additional requests
93                self.half_open_successes += 1;
94                if self.half_open_successes >= self.success_threshold {
95                    // Transition to Closed
96                    self.state = CircuitState::Closed;
97                    self.failure_count = 0;
98                    self.opened_at = None;
99                    self.test_request_in_progress = false; // Reset flag when closing circuit
100                }
101            }
102            CircuitState::Open => {
103                // Should not happen if `should_allow_request` is respected
104            }
105        }
106    }
107
108    fn record_failure(&mut self) {
109        match self.state {
110            CircuitState::Closed => {
111                self.failure_count += 1;
112                if self.failure_count >= self.max_failures {
113                    // Open the circuit
114                    self.state = CircuitState::Open;
115                    self.opened_at = Some(Instant::now());
116                }
117            }
118            CircuitState::HalfOpen => {
119                // Any failure in HalfOpen reopens the circuit
120                self.test_request_in_progress = false; // Test request completed (with failure)
121                self.state = CircuitState::Open;
122                self.opened_at = Some(Instant::now());
123                self.half_open_successes = 0;
124            }
125            CircuitState::Open => {
126                // Already open, do nothing
127            }
128        }
129    }
130
131    fn current_state(&self) -> &CircuitState {
132        &self.state
133    }
134}
135
136/// Middleware that implements the circuit breaker pattern.
137///
138/// The circuit breaker protects downstream services from being overwhelmed
139/// when they're failing. It has three states:
140/// - **Closed**: Normal operation, requests pass through
141/// - **Open**: Requests are rejected immediately (fail fast)
142/// - **HalfOpen**: Testing if service recovered with limited requests
143///
144/// # Example
145/// ```rust,no_run
146/// use foxtive_worker::CircuitBreakerMiddleware;
147/// use std::time::Duration;
148///
149/// // Open circuit after 5 failures, retry after 30 seconds
150/// let middleware = CircuitBreakerMiddleware::new(5, Duration::from_secs(30));
151/// ```
152pub struct CircuitBreakerMiddleware {
153    state: Arc<Mutex<CircuitBreakerState>>,
154    name: String,
155}
156
157impl std::fmt::Debug for CircuitBreakerMiddleware {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("CircuitBreakerMiddleware")
160            .field("name", &self.name)
161            .finish()
162    }
163}
164
165impl CircuitBreakerMiddleware {
166    /// Create a new circuit breaker middleware.
167    ///
168    /// # Arguments
169    /// * `max_failures` - Number of consecutive failures before opening circuit
170    /// * `timeout` - Time to wait before transitioning from Open to HalfOpen
171    pub fn new(max_failures: u32, timeout: Duration) -> Self {
172        Self {
173            state: Arc::new(Mutex::new(CircuitBreakerState::new(
174                max_failures,
175                timeout,
176                1, // Default: 1 success to close
177            ))),
178            name: format!("circuit-breaker-{}failures", max_failures),
179        }
180    }
181
182    /// Create a circuit breaker with custom success threshold.
183    pub fn with_threshold(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
184        Self {
185            state: Arc::new(Mutex::new(CircuitBreakerState::new(
186                max_failures,
187                timeout,
188                success_threshold,
189            ))),
190            name: format!("circuit-breaker-{}failures", max_failures),
191        }
192    }
193
194    /// Get the current circuit state.
195    pub async fn get_state(&self) -> CircuitState {
196        let mut state = self.state.lock().await;
197        // Check if Open state should transition to HalfOpen
198        if state.current_state() == &CircuitState::Open
199            && let Some(opened_at) = state.opened_at
200            && opened_at.elapsed() >= state.timeout
201        {
202            state.state = CircuitState::HalfOpen;
203            state.half_open_successes = 0;
204            // Don't reset test_request_in_progress here - it will be set by should_allow_request
205        }
206        state.current_state().clone()
207    }
208}
209
210#[async_trait]
211impl Middleware for CircuitBreakerMiddleware {
212    fn name(&self) -> &str {
213        &self.name
214    }
215
216    async fn handle(
217        &self,
218        message: ReceivedMessage<serde_json::Value>,
219        next: Box<dyn MessageHandler>,
220    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
221        // Check if request should be allowed
222        {
223            let mut state = self.state.lock().await;
224            if !state.should_allow_request() {
225                return Err(WorkerError::ProcessingFailed(format!(
226                    "Circuit breaker '{}' is open, rejecting request",
227                    self.name
228                )));
229            }
230        }
231
232        // Process the message
233        let result = next.handle(message).await;
234
235        // Record success or failure
236        {
237            let mut state = self.state.lock().await;
238            match &result {
239                Ok(MiddlewareResult::Continue) | Ok(MiddlewareResult::Acknowledged) => state.record_success(),
240                Err(_) => state.record_failure(),
241            }
242        }
243
244        result
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;

    /// Downstream handler stub that always succeeds.
    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Downstream handler stub that always fails.
    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed("test failure".to_string()))
        }
    }

    /// Builds a received message with a fixed payload and an ack handle whose
    /// `ack`/`nack` both succeed without doing anything.
    fn create_test_message() -> ReceivedMessage<serde_json::Value> {
        use crate::message::{AckHandle, Message, MessageMetadata};

        #[derive(Debug)]
        struct MockAckHandle;

        #[async_trait]
        impl AckHandle for MockAckHandle {
            async fn ack(&self) -> crate::WorkerResult<()> {
                Ok(())
            }

            async fn nack(&self, _requeue: bool) -> crate::WorkerResult<()> {
                Ok(())
            }
        }

        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(message, Arc::new(MockAckHandle))
    }

    /// Sends `failures` messages through a failing handler, ignoring the
    /// individual results.
    async fn trip_breaker(middleware: &CircuitBreakerMiddleware, failures: usize) {
        for _ in 0..failures {
            let _ = middleware
                .handle(create_test_message(), Box::new(FailureHandler))
                .await;
        }
    }

    #[tokio::test]
    async fn test_circuit_closed_initially() {
        let breaker = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_opens_after_max_failures() {
        let breaker = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));

        // Three failures reach the configured limit.
        trip_breaker(&breaker, 3).await;

        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_rejects_when_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_secs(1));

        // Open the circuit.
        trip_breaker(&breaker, 2).await;

        // Even a request that would succeed must now be rejected.
        let rejected = breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await;
        assert!(rejected.is_err());
        assert!(matches!(rejected, Err(WorkerError::ProcessingFailed(_))));
    }

    #[tokio::test]
    async fn test_circuit_transitions_to_half_open_and_allows_one_request() {
        // Part 1: default threshold, a single successful probe closes the circuit.
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));

        trip_breaker(&breaker, 2).await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // Let the open timeout pass.
        time::sleep(Duration::from_millis(150)).await;

        // The first request after the timeout is admitted as the probe.
        let probe = breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await;
        assert!(probe.is_ok());
        // One success is enough with the default threshold.
        assert_eq!(breaker.get_state().await, CircuitState::Closed);

        // Part 2: threshold of two, so the circuit stays HalfOpen after the
        // first probe and a second request must be rejected.
        let strict_breaker =
            CircuitBreakerMiddleware::with_threshold(2, Duration::from_millis(100), 2);
        trip_breaker(&strict_breaker, 2).await;
        time::sleep(Duration::from_millis(150)).await;
        assert_eq!(strict_breaker.get_state().await, CircuitState::HalfOpen);

        let first_probe = strict_breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await;
        assert!(first_probe.is_ok());
        // One success recorded, one more would be needed to close.
        assert_eq!(strict_breaker.get_state().await, CircuitState::HalfOpen);

        // NOTE(review): this pins the current behaviour in which no further
        // request is admitted after the first HalfOpen success - confirm that
        // a threshold above one is meant to behave this way.
        let second_probe = strict_breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await;
        assert!(second_probe.is_err());
        assert!(matches!(
            second_probe,
            Err(WorkerError::ProcessingFailed(_))
        ));
        assert_eq!(strict_breaker.get_state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_closes_after_success_in_half_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));

        // Open the circuit, then wait out the timeout.
        trip_breaker(&breaker, 2).await;
        time::sleep(Duration::from_millis(150)).await;

        // A success in HalfOpen closes the circuit (default success_threshold = 1).
        breaker
            .handle(create_test_message(), Box::new(SuccessHandler))
            .await
            .unwrap();

        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_reopens_on_failure_in_half_open() {
        let breaker = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));

        // Open the circuit, then wait out the timeout.
        trip_breaker(&breaker, 2).await;
        time::sleep(Duration::from_millis(150)).await;

        // A failing probe in HalfOpen must reopen the circuit.
        trip_breaker(&breaker, 1).await;

        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }
}