use rustapi_core::{
middleware::{BoxedNext, MiddlewareLayer},
Request, Response, ResponseBody,
};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// The three states of the circuit breaker's state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Requests are forwarded to the next handler; failed responses are counted
/// and the circuit opens once the configured failure threshold is reached.
Closed,
/// Requests are rejected with a 503 response until the configured timeout
/// has elapsed since the last recorded failure.
Open,
/// Trial state entered after the open timeout: requests are forwarded again;
/// enough successes close the circuit, a single failure re-opens it.
HalfOpen,
}
/// Tunable thresholds for the circuit breaker.
///
/// `Debug` is derived so the configuration can be logged and embedded in
/// other `Debug` types, as expected of public types.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures (since the last success while closed) at which a
    /// closed circuit opens.
    pub failure_threshold: usize,
    /// Minimum time since the last recorded failure before an open circuit
    /// lets a request through again and moves to half-open.
    pub timeout: Duration,
    /// Number of successes required in the half-open state to close the
    /// circuit again.
    pub success_threshold: usize,
}

impl Default for CircuitBreakerConfig {
    /// Returns the stock configuration: open after 5 failures, stay open for
    /// 60 seconds, close again after 2 half-open successes.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        }
    }
}
/// Mutable bookkeeping for one circuit breaker, kept behind the layer's lock.
struct CircuitBreakerState {
// Current position in the Closed / Open / HalfOpen state machine.
state: CircuitState,
// Failures recorded since the counter was last reset; it is reset to 0 on a
// success while Closed and when the circuit closes from HalfOpen.
failure_count: usize,
// Successes observed while HalfOpen; reset on every state transition.
success_count: usize,
// When the most recent failure was recorded; the Open timeout is measured
// from this instant. `None` until the first failure.
last_failure_time: Option<Instant>,
// Every request seen by the middleware, including ones rejected while Open.
total_requests: u64,
// Running total of recorded failures (cleared only by a full reset).
total_failures: u64,
// Running total of recorded successes (cleared only by a full reset).
total_successes: u64,
}
impl Default for CircuitBreakerState {
fn default() -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_failure_time: None,
total_requests: 0,
total_failures: 0,
total_successes: 0,
}
}
}
/// Middleware layer that wraps the downstream handler in a circuit breaker.
///
/// Cloning the layer clones the `Arc`, so all clones observe and update the
/// same breaker state.
#[derive(Clone)]
pub struct CircuitBreakerLayer {
// Thresholds and timeout governing the state transitions.
config: CircuitBreakerConfig,
// Breaker state shared between clones and in-flight requests.
state: Arc<RwLock<CircuitBreakerState>>,
}
impl CircuitBreakerLayer {
pub fn new() -> Self {
Self {
config: CircuitBreakerConfig::default(),
state: Arc::new(RwLock::new(CircuitBreakerState::default())),
}
}
pub fn failure_threshold(mut self, threshold: usize) -> Self {
self.config.failure_threshold = threshold;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = timeout;
self
}
pub fn success_threshold(mut self, threshold: usize) -> Self {
self.config.success_threshold = threshold;
self
}
pub async fn get_state(&self) -> CircuitState {
self.state.read().await.state
}
pub async fn get_stats(&self) -> CircuitBreakerStats {
let state = self.state.read().await;
CircuitBreakerStats {
state: state.state,
total_requests: state.total_requests,
total_failures: state.total_failures,
total_successes: state.total_successes,
failure_count: state.failure_count,
success_count: state.success_count,
}
}
pub async fn reset(&self) {
let mut state = self.state.write().await;
*state = CircuitBreakerState::default();
}
}
impl Default for CircuitBreakerLayer {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time copy of a circuit breaker's state and counters.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
/// State of the circuit when the snapshot was taken.
pub state: CircuitState,
/// Requests seen by the middleware, including ones rejected while open.
pub total_requests: u64,
/// Running total of recorded failures.
pub total_failures: u64,
/// Running total of recorded successes.
pub total_successes: u64,
/// Current value of the failure counter that drives the Closed -> Open transition.
pub failure_count: usize,
/// Current value of the success counter that drives the HalfOpen -> Closed transition.
pub success_count: usize,
}
impl MiddlewareLayer for CircuitBreakerLayer {
    /// Runs one request through the circuit breaker.
    ///
    /// 1. Admission: every request is counted. If the circuit is open and the
    ///    configured timeout has not yet elapsed since the last recorded
    ///    failure, the request is rejected with a 503 JSON response without
    ///    calling `next`. If the timeout has elapsed, the circuit moves to
    ///    half-open and the request is let through.
    /// 2. The request is forwarded to `next` with no lock held.
    /// 3. Outcome: a 5xx response is recorded as a failure; any other status
    ///    is recorded as a success.
    ///
    /// Only server errors count as failures. Redirects (3xx) and client
    /// errors (4xx) show that the service is answering; counting them would
    /// let any client open the circuit for everybody simply by sending
    /// requests that yield e.g. 404.
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        let config = self.config.clone();
        let state = self.state.clone();

        Box::pin(async move {
            // Admission decision. The write guard lives only inside this
            // block so it is released before the downstream handler runs.
            let rejected = {
                let mut guard = state.write().await;
                guard.total_requests += 1;

                match (guard.state, guard.last_failure_time) {
                    (CircuitState::Open, Some(last_failure)) => {
                        if last_failure.elapsed() >= config.timeout {
                            tracing::info!("Circuit breaker transitioning to HalfOpen");
                            guard.state = CircuitState::HalfOpen;
                            guard.success_count = 0;
                            false
                        } else {
                            true
                        }
                    }
                    // `record_failure` stamps the failure time whenever it
                    // opens the circuit, so this combination is not expected;
                    // such a request is forwarded rather than rejected.
                    (CircuitState::Open, None) => false,
                    (CircuitState::HalfOpen, _) | (CircuitState::Closed, _) => false,
                }
            };

            if rejected {
                return circuit_open_response();
            }

            let response = next(req).await;

            // The state may have changed while the request was in flight, so
            // the outcome is applied to whatever the state is now.
            let mut guard = state.write().await;
            if response.status().is_server_error() {
                record_failure(&mut guard, &config);
            } else {
                guard.total_successes += 1;
                match guard.state {
                    CircuitState::HalfOpen => {
                        guard.success_count += 1;
                        if guard.success_count >= config.success_threshold {
                            tracing::info!("Circuit breaker transitioning to Closed");
                            guard.state = CircuitState::Closed;
                            guard.failure_count = 0;
                            guard.success_count = 0;
                        }
                    }
                    // A success while closed clears the failure streak.
                    CircuitState::Closed => {
                        guard.failure_count = 0;
                    }
                    // A request admitted before the circuit opened finished
                    // successfully; it does not affect the open circuit.
                    CircuitState::Open => {}
                }
            }

            response
        })
    }

    /// Returns a boxed clone of this layer; the clone shares the same
    /// breaker state.
    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}

/// Builds the 503 JSON response returned while the circuit is open.
fn circuit_open_response() -> Response {
    let payload = serde_json::json!({
        "error": {
            "type": "service_unavailable",
            "message": "Circuit breaker is OPEN"
        }
    })
    .to_string();

    http::Response::builder()
        .status(503)
        .header("Content-Type", "application/json")
        .body(ResponseBody::Full(http_body_util::Full::new(
            bytes::Bytes::from(payload),
        )))
        .expect("a constant status code and header always form a valid response")
}
/// Records one failed request and applies the resulting state transition.
///
/// The failure counters are bumped and the failure timestamp refreshed
/// regardless of the current state. Then:
/// - Closed: the circuit opens once `failure_count` reaches
///   `config.failure_threshold`.
/// - HalfOpen: the circuit re-opens immediately and the success counter is
///   cleared.
/// - Open: no transition.
fn record_failure(state: &mut CircuitBreakerState, config: &CircuitBreakerConfig) {
    state.total_failures += 1;
    state.failure_count += 1;
    state.last_failure_time = Some(Instant::now());

    match state.state {
        CircuitState::HalfOpen => {
            tracing::warn!("Circuit breaker returning to OPEN state");
            state.state = CircuitState::Open;
            state.success_count = 0;
        }
        CircuitState::Closed => {
            let threshold_reached = state.failure_count >= config.failure_threshold;
            if threshold_reached {
                tracing::warn!(
                    "Circuit breaker OPENING after {} failures",
                    state.failure_count
                );
                state.state = CircuitState::Open;
            }
        }
        CircuitState::Open => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::sync::Arc;

    /// Builds a downstream handler that always answers with `status` and `body`.
    fn fixed_response(status: u16, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(ResponseBody::Full(http_body_util::Full::new(
                        bytes::Bytes::from(body),
                    )))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// Builds an empty `GET /` request.
    fn get_root() -> Request {
        let raw = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap();
        Request::from_http_request(raw, Bytes::new())
    }

    #[tokio::test]
    async fn circuit_breaker_opens_after_threshold() {
        let breaker = CircuitBreakerLayer::new()
            .failure_threshold(3)
            .timeout(Duration::from_secs(1));
        let next = fixed_response(500, "Error");

        // Three failing calls reach the threshold.
        for _ in 0..3 {
            let _ = breaker.call(get_root(), next.clone()).await;
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // While open, the breaker answers by itself with 503.
        let response = breaker.call(get_root(), next.clone()).await;
        assert_eq!(response.status(), 503);
    }

    #[tokio::test]
    async fn circuit_breaker_recovers() {
        let breaker = CircuitBreakerLayer::new()
            .failure_threshold(2)
            .timeout(Duration::from_millis(100))
            .success_threshold(2);

        // Trip the breaker.
        let fail_next = fixed_response(500, "Error");
        for _ in 0..2 {
            let _ = breaker.call(get_root(), fail_next.clone()).await;
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // Wait past the open timeout so the next request is let through.
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Two successful calls satisfy the success threshold.
        let success_next = fixed_response(200, "OK");
        for _ in 0..2 {
            let result = breaker.call(get_root(), success_next.clone()).await;
            assert!(result.status().is_success());
        }
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }
}