use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use http::{Request, Response};
use parking_lot::Mutex;
use tower::{Layer, Service};
use crate::Body;
/// Position of a circuit breaker in its Closed -> Open -> HalfOpen cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests are forwarded and failures are counted.
    Closed,
    /// Tripped: requests are rejected until the recovery timeout has elapsed.
    Open,
    /// Recovering: trial requests decide whether the circuit closes or re-opens.
    HalfOpen,
}
/// Tuning knobs for the circuit breaker: when it trips, how long it stays open, and how many
/// successful probes are needed before it closes again.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while closed) that trip the circuit open.
    pub failure_threshold: u32,
    /// Time after the most recent failure before an open circuit lets a probe through.
    pub recovery_timeout: Duration,
    /// Successful half-open probes required before the circuit closes.
    pub half_open_max_probes: u32,
}

impl CircuitBreakerConfig {
    /// Builds a config that trips after `failure_threshold` consecutive failures, waits
    /// 30 seconds before probing, and closes after a single successful probe.
    pub fn new(failure_threshold: u32) -> Self {
        let recovery_timeout = Duration::from_secs(30);
        let half_open_max_probes = 1;
        Self {
            failure_threshold,
            recovery_timeout,
            half_open_max_probes,
        }
    }

    /// Returns this config with a different open-state recovery timeout.
    pub fn recovery_timeout(self, timeout: Duration) -> Self {
        Self {
            recovery_timeout: timeout,
            ..self
        }
    }

    /// Returns this config with a different number of probes required to close the circuit.
    pub fn half_open_max_probes(self, probes: u32) -> Self {
        Self {
            half_open_max_probes: probes,
            ..self
        }
    }
}

impl Default for CircuitBreakerConfig {
    /// Equivalent to `CircuitBreakerConfig::new(5)`.
    fn default() -> Self {
        CircuitBreakerConfig::new(5)
    }
}
/// Mutable breaker bookkeeping, shared behind a mutex by every clone of the service.
struct CircuitBreakerState {
    /// Current position in the Closed -> Open -> HalfOpen cycle.
    state: CircuitState,
    /// Consecutive failures seen while Closed; any success while Closed resets it.
    failure_count: u32,
    /// Successes recorded during the current HalfOpen period.
    success_count: u32,
    /// Time of the most recently recorded failure; the Open -> HalfOpen timer runs from here.
    last_failure_time: Option<Instant>,
    /// Probes admitted during the current HalfOpen period whose outcome is not yet recorded.
    in_flight_probes: u32,
    /// When the latest HalfOpen probe was admitted; used to reclaim stalled probe slots.
    last_probe_time: Option<Instant>,
}

impl CircuitBreakerState {
    /// Creates a closed circuit with no recorded history.
    fn new() -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            in_flight_probes: 0,
            last_probe_time: None,
        }
    }

    /// Decides whether a request may proceed.
    ///
    /// - Closed: always allowed.
    /// - Open: rejected until `recovery_timeout` has elapsed since the last failure, then the
    ///   circuit moves to HalfOpen and this request becomes the first probe.
    /// - HalfOpen: allowed only while a probe slot is free. Outstanding probes plus successes
    ///   may not exceed `half_open_max_probes` (minimum one). This stops every request that
    ///   arrives before the first probe completes from reaching a recovering backend.
    fn should_allow_request(&mut self, config: &CircuitBreakerConfig) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let recovered = matches!(
                    self.last_failure_time,
                    Some(last_failure) if last_failure.elapsed() >= config.recovery_timeout
                );
                if !recovered {
                    return false;
                }
                self.state = CircuitState::HalfOpen;
                self.success_count = 0;
                self.in_flight_probes = 0;
                self.last_probe_time = None;
                self.try_admit_probe(config)
            }
            CircuitState::HalfOpen => self.try_admit_probe(config),
        }
    }

    /// Admits a HalfOpen probe and reserves its slot if one is free.
    ///
    /// Outcomes are recorded only when a request's response future completes. A probe whose
    /// future is dropped would therefore hold its slot forever. To keep the circuit from
    /// wedging in HalfOpen, all slots are reclaimed once `recovery_timeout` has passed since
    /// the last probe was admitted. A zero `recovery_timeout` effectively disables the limit.
    fn try_admit_probe(&mut self, config: &CircuitBreakerConfig) -> bool {
        // Always allow at least one probe, even when configured with 0, so the circuit can
        // still recover.
        let budget = config.half_open_max_probes.max(1);
        let has_free_slot = self.success_count.saturating_add(self.in_flight_probes) < budget;
        if !has_free_slot {
            let stalled = matches!(
                self.last_probe_time,
                Some(admitted) if admitted.elapsed() >= config.recovery_timeout
            );
            if !stalled {
                return false;
            }
            // Treat the outstanding probes as abandoned.
            self.in_flight_probes = 0;
        }
        self.in_flight_probes = self.in_flight_probes.saturating_add(1);
        self.last_probe_time = Some(Instant::now());
        true
    }

    /// Records a successful request.
    ///
    /// While Closed this resets the consecutive-failure count. While HalfOpen it releases a
    /// probe slot and counts a success, and the circuit closes once `half_open_max_probes`
    /// successes have been seen. Successes while Open are ignored.
    fn record_success(&mut self, config: &CircuitBreakerConfig) {
        match self.state {
            CircuitState::Closed => {
                self.failure_count = 0;
            }
            CircuitState::HalfOpen => {
                self.in_flight_probes = self.in_flight_probes.saturating_sub(1);
                self.success_count = self.success_count.saturating_add(1);
                if self.success_count >= config.half_open_max_probes {
                    self.state = CircuitState::Closed;
                    self.failure_count = 0;
                    self.in_flight_probes = 0;
                    self.last_probe_time = None;
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed request and stamps `last_failure_time` in every state.
    ///
    /// While Closed, the circuit opens once `failure_threshold` consecutive failures have been
    /// seen. Any failure while HalfOpen re-opens it immediately and drops outstanding probe
    /// slots. Note that a failure recorded while Open still restarts the recovery timer.
    fn record_failure(&mut self, config: &CircuitBreakerConfig) {
        self.last_failure_time = Some(Instant::now());
        match self.state {
            CircuitState::Closed => {
                self.failure_count += 1;
                if self.failure_count >= config.failure_threshold {
                    self.state = CircuitState::Open;
                }
            }
            CircuitState::HalfOpen => {
                self.state = CircuitState::Open;
                self.in_flight_probes = 0;
                self.last_probe_time = None;
            }
            CircuitState::Open => {}
        }
    }
}
/// Tower layer that wraps services in a `CircuitBreakerService` configured by `config`.
///
/// Derives `Debug` so that, like `CircuitBreakerConfig`, it can be logged and embedded in
/// other `Debug` types.
#[derive(Clone, Debug)]
pub struct CircuitBreakerLayer {
    config: CircuitBreakerConfig,
}

impl CircuitBreakerLayer {
    /// Creates a layer whose wrapped services will use `config`.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self { config }
    }
}
impl<S> Layer<S> for CircuitBreakerLayer {
    type Service = CircuitBreakerService<S>;

    /// Wraps `inner` in a circuit breaker that starts Closed with no history.
    ///
    /// Every `layer` call allocates a separate state cell, so services built by different
    /// calls trip independently. Clones of one built service share its cell through the `Arc`.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        let state = Arc::new(Mutex::new(CircuitBreakerState::new()));
        CircuitBreakerService {
            inner,
            config,
            state,
        }
    }
}
/// Error returned in place of forwarding a request while the breaker is rejecting traffic.
#[derive(Debug)]
pub struct CircuitOpenError {
    /// Suggested delay before the caller tries again.
    pub retry_after: Duration,
}

impl std::error::Error for CircuitOpenError {}

impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { retry_after } = self;
        write!(f, "circuit breaker is open, retry after {:?}", retry_after)
    }
}
/// Middleware produced by `CircuitBreakerLayer` that fails fast while the circuit rejects
/// traffic and otherwise forwards to `inner`.
///
/// All clones share one breaker state through the `Arc`, so they trip and recover together.
#[derive(Clone)]
pub struct CircuitBreakerService<S> {
    /// The wrapped service that handles admitted requests.
    inner: S,
    /// Thresholds and timeouts copied from the layer.
    config: CircuitBreakerConfig,
    /// Breaker bookkeeping shared by every clone of this service.
    state: Arc<Mutex<CircuitBreakerState>>,
}

/// Boxed, `Send` future returned by `CircuitBreakerService::call`.
type BoxFut<T> = Pin<Box<dyn Future<Output = T> + Send>>;
impl<S, ResBody> Service<Request<Body>> for CircuitBreakerService<S>
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: Into<crate::error::BoxError> + Send,
S::Future: Send + 'static,
ResBody: Send + 'static,
{
type Response = Response<ResBody>;
type Error = crate::error::BoxError;
type Future = BoxFut<Result<Response<ResBody>, crate::error::BoxError>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
{
let mut state = self.state.lock();
if !state.should_allow_request(&self.config) {
let retry_after = self.config.recovery_timeout;
trace!(
"Circuit breaker is open for {} {}, failing fast (retry after {:?})",
req.method(),
req.uri(),
retry_after
);
return Box::pin(async move { Err(CircuitOpenError { retry_after }.into()) });
}
}
trace!("Circuit breaker request: {} {}", req.method(), req.uri());
let mut inner = self.inner.clone();
let state = self.state.clone();
let config = self.config.clone();
Box::pin(async move {
let result = inner.call(req).await;
match &result {
Ok(response) => {
let status = response.status();
if status.is_server_error() {
trace!("Circuit breaker recording server error: {}", status);
state.lock().record_failure(&config);
} else {
state.lock().record_success(&config);
}
}
Err(_) => {
state.lock().record_failure(&config);
}
}
result.map_err(Into::into)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_breaker_config_defaults() {
        let config = CircuitBreakerConfig::new(3);
        assert_eq!(config.failure_threshold, 3);
        assert_eq!(config.recovery_timeout, Duration::from_secs(30));
        assert_eq!(config.half_open_max_probes, 1);
        // `Default` trips after five failures and otherwise matches `new`.
        let default = CircuitBreakerConfig::default();
        assert_eq!(default.failure_threshold, 5);
        assert_eq!(default.recovery_timeout, Duration::from_secs(30));
        assert_eq!(default.half_open_max_probes, 1);
    }

    #[test]
    fn test_circuit_breaker_state_closed() {
        let config = CircuitBreakerConfig::new(3);
        let mut state = CircuitBreakerState::new();
        assert!(state.should_allow_request(&config));
        assert_eq!(state.state, CircuitState::Closed);
        state.record_failure(&config);
        assert_eq!(state.state, CircuitState::Closed);
        assert_eq!(state.failure_count, 1);
        state.record_failure(&config);
        assert_eq!(state.failure_count, 2);
        state.record_failure(&config);
        assert_eq!(state.state, CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_state_success_resets() {
        let config = CircuitBreakerConfig::new(3);
        let mut state = CircuitBreakerState::new();
        state.record_failure(&config);
        state.record_failure(&config);
        assert_eq!(state.failure_count, 2);
        state.record_success(&config);
        assert_eq!(state.failure_count, 0);
        assert_eq!(state.state, CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_open_blocks_requests() {
        let config = CircuitBreakerConfig::new(2);
        let mut state = CircuitBreakerState::new();
        state.record_failure(&config);
        state.record_failure(&config);
        assert_eq!(state.state, CircuitState::Open);
        assert!(!state.should_allow_request(&config));
    }

    #[test]
    fn test_circuit_breaker_half_open_success_closes() {
        // A zero recovery timeout makes the very next request a half-open probe.
        let config = CircuitBreakerConfig::new(1).recovery_timeout(Duration::ZERO);
        let mut state = CircuitBreakerState::new();
        state.record_failure(&config);
        assert_eq!(state.state, CircuitState::Open);
        assert!(state.should_allow_request(&config));
        assert_eq!(state.state, CircuitState::HalfOpen);
        state.record_success(&config);
        assert_eq!(state.state, CircuitState::Closed);
        assert_eq!(state.failure_count, 0);
    }

    #[test]
    fn test_circuit_breaker_half_open_failure_reopens() {
        let config = CircuitBreakerConfig::new(1).recovery_timeout(Duration::ZERO);
        let mut state = CircuitBreakerState::new();
        state.record_failure(&config);
        assert!(state.should_allow_request(&config));
        assert_eq!(state.state, CircuitState::HalfOpen);
        state.record_failure(&config);
        assert_eq!(state.state, CircuitState::Open);
    }

    #[test]
    fn test_circuit_open_error_display() {
        let err = CircuitOpenError {
            retry_after: Duration::from_secs(3),
        };
        assert_eq!(err.to_string(), "circuit breaker is open, retry after 3s");
    }
}