use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use std::time::{Duration, Instant};
use crate::error::{WorkerError, WorkerResult};
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware};
/// Operating mode of a circuit breaker.
///
/// Equality between states is total, so `Eq` is derived alongside
/// `PartialEq`; this lets the type be used wherever an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are let through and failures are counted.
    Closed,
    /// Requests are rejected until the configured timeout has elapsed.
    Open,
    /// A trial request is let through to decide whether the circuit closes
    /// again or re-opens.
    HalfOpen,
}
/// Bookkeeping for one circuit breaker: its thresholds plus the counters and
/// timestamp that drive transitions between [`CircuitState`] values.
struct CircuitBreakerState {
    // Configuration.
    /// Failure count at which a closed circuit opens.
    max_failures: u32,
    /// How long the circuit stays open before a trial request is admitted.
    timeout: Duration,
    /// Successes needed while half-open before the circuit closes.
    success_threshold: u32,

    // Runtime state.
    /// Current operating mode.
    state: CircuitState,
    /// Failures recorded while closed; reset by a success while closed.
    failure_count: u32,
    /// When the circuit last opened; `None` while it has not opened or after
    /// it closed again.
    opened_at: Option<Instant>,
    /// Successes recorded since the circuit last entered the half-open state.
    half_open_successes: u32,
    /// Whether the half-open trial slot is currently taken.
    test_request_in_progress: bool,
}
impl CircuitBreakerState {
    /// Creates a breaker that starts `Closed` with every counter at zero.
    fn new(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
        Self {
            max_failures,
            timeout,
            success_threshold,
            state: CircuitState::Closed,
            failure_count: 0,
            opened_at: None,
            half_open_successes: 0,
            test_request_in_progress: false,
        }
    }

    /// Decides whether a request may proceed, updating the state as needed.
    ///
    /// * `Closed`: always admits.
    /// * `Open`: rejects until `timeout` has elapsed since `opened_at`; the
    ///   first call after that switches to `HalfOpen`, takes the trial slot
    ///   and admits the request.
    /// * `HalfOpen`: admits only if the trial slot is free, and takes it.
    fn should_allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let cooled_down = self
                    .opened_at
                    .is_some_and(|since| since.elapsed() >= self.timeout);
                if cooled_down {
                    self.state = CircuitState::HalfOpen;
                    self.half_open_successes = 0;
                    self.test_request_in_progress = true;
                }
                cooled_down
            }
            // Test-and-set: the slot ends up taken either way, and the
            // request is admitted only if it was free beforehand.
            CircuitState::HalfOpen => {
                !std::mem::replace(&mut self.test_request_in_progress, true)
            }
        }
    }

    /// Records a successful request.
    ///
    /// While `Closed` this resets the failure counter. While `HalfOpen` it
    /// counts towards `success_threshold`; once reached, the circuit closes
    /// and the trial slot is released. A success below the threshold leaves
    /// `test_request_in_progress` untouched — nothing in this impl clears it
    /// on that path. Successes reported while `Open` are ignored.
    fn record_success(&mut self) {
        match self.state {
            CircuitState::Closed => self.failure_count = 0,
            CircuitState::Open => {}
            CircuitState::HalfOpen => {
                self.half_open_successes += 1;
                if self.half_open_successes < self.success_threshold {
                    return;
                }
                self.state = CircuitState::Closed;
                self.failure_count = 0;
                self.opened_at = None;
                self.test_request_in_progress = false;
            }
        }
    }

    /// Records a failed request.
    ///
    /// While `Closed` the failure is counted and the circuit opens once
    /// `max_failures` is reached. While `HalfOpen` the circuit re-opens
    /// immediately, the trial slot is released and the half-open success
    /// counter is reset. Failures reported while `Open` are ignored.
    fn record_failure(&mut self) {
        match self.state {
            CircuitState::Closed => {
                self.failure_count += 1;
                if self.failure_count >= self.max_failures {
                    self.open_now();
                }
            }
            CircuitState::HalfOpen => {
                self.open_now();
                self.half_open_successes = 0;
                self.test_request_in_progress = false;
            }
            CircuitState::Open => {}
        }
    }

    /// Switches to `Open` and restarts the timeout clock.
    fn open_now(&mut self) {
        self.state = CircuitState::Open;
        self.opened_at = Some(Instant::now());
    }

    /// Returns the current state without triggering any transition.
    fn current_state(&self) -> &CircuitState {
        &self.state
    }
}
/// Middleware that guards the downstream handler with a circuit breaker.
pub struct CircuitBreakerMiddleware {
    /// Name reported through `Middleware::name` and used in rejection errors.
    name: String,
    /// Breaker bookkeeping behind an `Arc<Mutex<..>>`.
    state: Arc<Mutex<CircuitBreakerState>>,
}
/// Prints only the middleware name; the `state` field is left out of the
/// output.
impl std::fmt::Debug for CircuitBreakerMiddleware {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { name, .. } = self;
        let mut out = formatter.debug_struct("CircuitBreakerMiddleware");
        out.field("name", name);
        out.finish()
    }
}
impl CircuitBreakerMiddleware {
    /// Creates a breaker that opens after `max_failures` failures, stays open
    /// for `timeout`, and closes again after a single half-open success.
    pub fn new(max_failures: u32, timeout: Duration) -> Self {
        Self::with_threshold(max_failures, timeout, 1)
    }

    /// Like [`new`](Self::new), but requires `success_threshold` successes
    /// while half-open before the circuit closes.
    pub fn with_threshold(max_failures: u32, timeout: Duration, success_threshold: u32) -> Self {
        let breaker = CircuitBreakerState::new(max_failures, timeout, success_threshold);
        Self {
            name: format!("circuit-breaker-{}failures", max_failures),
            state: Arc::new(Mutex::new(breaker)),
        }
    }

    /// Returns the current circuit state.
    ///
    /// Not a pure read: if the circuit is open and its timeout has elapsed,
    /// this moves it to `HalfOpen` and resets the half-open success counter
    /// before reporting. The trial slot is not taken here.
    pub async fn get_state(&self) -> CircuitState {
        let mut guard = self.state.lock().await;
        let timeout = guard.timeout;
        let cooled_down = matches!(guard.current_state(), CircuitState::Open)
            && guard
                .opened_at
                .is_some_and(|since| since.elapsed() >= timeout);
        if cooled_down {
            guard.state = CircuitState::HalfOpen;
            guard.half_open_successes = 0;
        }
        guard.current_state().clone()
    }
}
#[async_trait]
impl Middleware for CircuitBreakerMiddleware {
fn name(&self) -> &str {
&self.name
}
async fn handle(
&self,
message: ReceivedMessage<serde_json::Value>,
next: Box<dyn MessageHandler>,
) -> WorkerResult<()> {
{
let mut state = self.state.lock().await;
if !state.should_allow_request() {
return Err(WorkerError::ProcessingFailed(format!(
"Circuit breaker '{}' is open, rejecting request",
self.name
)));
}
}
let result = next.handle(message).await;
{
let mut state = self.state.lock().await;
match result {
Ok(_) => state.record_success(),
Err(_) => state.record_failure(),
}
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time;
struct SuccessHandler;
#[async_trait]
impl MessageHandler for SuccessHandler {
async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
Ok(())
}
}
struct FailureHandler;
#[async_trait]
impl MessageHandler for FailureHandler {
async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
Err(WorkerError::ProcessingFailed("test failure".to_string()))
}
}
fn create_test_message() -> ReceivedMessage<serde_json::Value> {
use crate::message::{Message, MessageMetadata, AckHandle};
#[derive(Debug)]
struct MockAckHandle;
#[async_trait]
impl AckHandle for MockAckHandle {
async fn ack(&self) -> WorkerResult<()> {
Ok(())
}
async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
Ok(())
}
}
let message = Message {
id: "test-1".to_string(),
payload: serde_json::json!({"test": "data"}),
metadata: MessageMetadata::new("test-queue"),
};
ReceivedMessage::new(message, Arc::new(MockAckHandle))
}
#[tokio::test]
async fn test_circuit_closed_initially() {
let middleware = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
assert_eq!(middleware.get_state().await, CircuitState::Closed);
}
#[tokio::test]
async fn test_circuit_opens_after_max_failures() {
let middleware = CircuitBreakerMiddleware::new(3, Duration::from_secs(1));
for _ in 0..3 {
let message = create_test_message();
let _ = middleware.handle(message, Box::new(FailureHandler)).await;
}
assert_eq!(middleware.get_state().await, CircuitState::Open);
}
#[tokio::test]
async fn test_circuit_rejects_when_open() {
let middleware = CircuitBreakerMiddleware::new(2, Duration::from_secs(1));
for _ in 0..2 {
let message = create_test_message();
let _ = middleware.handle(message, Box::new(FailureHandler)).await;
}
let message = create_test_message();
let result = middleware.handle(message, Box::new(SuccessHandler)).await;
assert!(result.is_err());
assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
}
#[tokio::test]
async fn test_circuit_transitions_to_half_open_and_allows_one_request() {
let middleware = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
for _ in 0..2 {
let message = create_test_message();
let _ = middleware.handle(message, Box::new(FailureHandler)).await;
}
assert_eq!(middleware.get_state().await, CircuitState::Open);
time::sleep(Duration::from_millis(150)).await;
let message1 = create_test_message();
let result1 = middleware.handle(message1, Box::new(SuccessHandler)).await;
assert!(result1.is_ok());
assert_eq!(middleware.get_state().await, CircuitState::Closed);
let middleware_half_open_test = CircuitBreakerMiddleware::with_threshold(2, Duration::from_millis(100), 2); for _ in 0..2 {
let message = create_test_message();
let _ = middleware_half_open_test.handle(message, Box::new(FailureHandler)).await;
}
time::sleep(Duration::from_millis(150)).await;
assert_eq!(middleware_half_open_test.get_state().await, CircuitState::HalfOpen);
let message_test_1 = create_test_message();
assert!(middleware_half_open_test.handle(message_test_1, Box::new(SuccessHandler)).await.is_ok());
assert_eq!(middleware_half_open_test.get_state().await, CircuitState::HalfOpen);
let message_test_2 = create_test_message();
let result_test_2 = middleware_half_open_test.handle(message_test_2, Box::new(SuccessHandler)).await;
assert!(result_test_2.is_err()); assert!(matches!(result_test_2, Err(WorkerError::ProcessingFailed(_))));
assert_eq!(middleware_half_open_test.get_state().await, CircuitState::HalfOpen); }
#[tokio::test]
async fn test_circuit_closes_after_success_in_half_open() {
let middleware = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
for _ in 0..2 {
let message = create_test_message();
let _ = middleware.handle(message, Box::new(FailureHandler)).await;
}
time::sleep(Duration::from_millis(150)).await;
let message = create_test_message();
middleware.handle(message, Box::new(SuccessHandler)).await.unwrap();
assert_eq!(middleware.get_state().await, CircuitState::Closed);
}
#[tokio::test]
async fn test_circuit_reopens_on_failure_in_half_open() {
let middleware = CircuitBreakerMiddleware::new(2, Duration::from_millis(100));
for _ in 0..2 {
let message = create_test_message();
let _ = middleware.handle(message, Box::new(FailureHandler)).await;
}
time::sleep(Duration::from_millis(150)).await;
let message = create_test_message();
let _ = middleware.handle(message, Box::new(FailureHandler)).await;
assert_eq!(middleware.get_state().await, CircuitState::Open);
}
}