use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context as TaskContext, Poll};
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tower::{Layer, Service};
/// Mutable bookkeeping behind a circuit breaker; lives behind the shared mutex.
#[derive(Debug)]
struct CircuitState {
/// Phase as last *stored*. The Open -> HalfOpen transition is only written
/// when `try_acquire` observes that the open window has elapsed.
phase: Phase,
/// Failures counted while closed; cleared by a success recorded while closed.
consecutive_failures: u32,
/// Successes counted while half-open, compared against the success threshold.
consecutive_successes: u32,
/// Calls admitted while half-open whose outcome has not been recorded yet.
half_open_in_flight: u32,
/// When the breaker was last put into (or re-armed in) the open phase;
/// `None` initially and after the breaker closes again.
opened_at: Option<Instant>,
}
/// Phase of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Phase {
/// Calls are admitted; failures are being counted.
Closed,
/// Calls are rejected until the configured open duration has elapsed.
Open,
/// A limited number of calls is admitted to test whether to close again.
HalfOpen,
}
impl CircuitState {
    /// Initial bookkeeping: closed, every counter at zero, never opened.
    fn new() -> Self {
        let phase = Phase::Closed;
        let opened_at = None;
        Self {
            phase,
            opened_at,
            consecutive_failures: 0,
            consecutive_successes: 0,
            half_open_in_flight: 0,
        }
    }
}
/// Tunable thresholds for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures, counted while closed, that trip the breaker open.
    pub failure_threshold: u32,
    /// Consecutive successes, counted while half-open, that close the breaker.
    pub success_threshold: u32,
    /// How long the breaker stays open before it is considered half-open.
    pub open_duration: Duration,
    /// Upper bound on calls admitted at the same time while half-open.
    pub half_open_max_calls: u32,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, two successes to close, a 30 second open
    /// window and a single half-open call at a time.
    fn default() -> Self {
        let open_duration = Duration::from_secs(30);
        Self {
            open_duration,
            failure_threshold: 5,
            success_threshold: 2,
            half_open_max_calls: 1,
        }
    }
}
/// Three-phase (closed / open / half-open) circuit breaker.
///
/// The state sits behind an `Arc`, so cloning a breaker yields another handle
/// to the *same* state; the configuration is copied into each clone.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    state: Arc<Mutex<CircuitState>>,
    config: CircuitBreakerConfig,
}

impl CircuitBreaker {
    /// Creates a closed breaker governed by `config`.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            state: Arc::new(Mutex::new(CircuitState::new())),
            config,
        }
    }

    /// Creates a closed breaker using [`CircuitBreakerConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }

    /// Returns the phase a caller would observe right now.
    ///
    /// An open breaker whose open window has elapsed reports
    /// [`Phase::HalfOpen`] here even though the stored phase is only updated
    /// by the next [`try_acquire`](Self::try_acquire).
    pub async fn phase(&self) -> Phase {
        let state = self.state.lock().await;
        self.effective_phase(&state)
    }

    /// Derives the observable phase from the stored one: `Open` becomes
    /// `HalfOpen` once `open_duration` has passed since `opened_at`.
    fn effective_phase(&self, state: &CircuitState) -> Phase {
        match (state.phase, state.opened_at) {
            (Phase::Open, Some(opened_at))
                if opened_at.elapsed() >= self.config.open_duration =>
            {
                Phase::HalfOpen
            }
            (stored, _) => stored,
        }
    }

    /// Asks for permission to perform one call.
    ///
    /// Returns `true` when closed, `false` when open. When half-open, at most
    /// `half_open_max_calls` calls are admitted at a time; each admitted call
    /// occupies a slot until `record_success` or `record_failure` is called.
    /// A configured limit of `0` is treated as `1`.
    pub async fn try_acquire(&self) -> bool {
        let mut state = self.state.lock().await;
        match self.effective_phase(&state) {
            Phase::Closed => true,
            Phase::Open => false,
            Phase::HalfOpen => {
                // First caller after the open window elapsed: store the
                // transition and start the half-open round from clean counters.
                if state.phase == Phase::Open {
                    state.phase = Phase::HalfOpen;
                    state.consecutive_failures = 0;
                    state.consecutive_successes = 0;
                    state.half_open_in_flight = 0;
                }
                // With a limit of zero no call could ever be admitted, so no
                // outcome could ever be recorded and the breaker would stay
                // half-open, rejecting everything, forever. Always allow at
                // least one trial call.
                let slot_limit = self.config.half_open_max_calls.max(1);
                if state.half_open_in_flight < slot_limit {
                    state.half_open_in_flight += 1;
                    true
                } else {
                    false
                }
            }
        }
    }

    /// Records a successful call.
    ///
    /// * Closed: clears the consecutive-failure count.
    /// * Half-open: frees one slot and counts the success; reaching
    ///   `success_threshold` closes the breaker.
    /// * Open (stored phase): ignored.
    pub async fn record_success(&self) {
        let mut state = self.state.lock().await;
        match state.phase {
            Phase::Closed => {
                state.consecutive_failures = 0;
            }
            Phase::HalfOpen => {
                state.half_open_in_flight = state.half_open_in_flight.saturating_sub(1);
                state.consecutive_successes += 1;
                if state.consecutive_successes >= self.config.success_threshold {
                    state.phase = Phase::Closed;
                    state.consecutive_failures = 0;
                    state.consecutive_successes = 0;
                    state.opened_at = None;
                }
            }
            Phase::Open => {
                // A success reported while the stored phase is open does not
                // change anything.
            }
        }
    }

    /// Records a failed call.
    ///
    /// * Closed: counts the failure; reaching `failure_threshold` opens the
    ///   breaker and starts the open window.
    /// * Half-open: frees one slot and re-opens the breaker immediately.
    /// * Open (stored phase): restarts the open window from now.
    pub async fn record_failure(&self) {
        let mut state = self.state.lock().await;
        match state.phase {
            Phase::Closed => {
                state.consecutive_failures += 1;
                if state.consecutive_failures >= self.config.failure_threshold {
                    state.phase = Phase::Open;
                    state.opened_at = Some(Instant::now());
                    state.consecutive_successes = 0;
                }
            }
            Phase::HalfOpen => {
                state.half_open_in_flight = state.half_open_in_flight.saturating_sub(1);
                state.phase = Phase::Open;
                state.opened_at = Some(Instant::now());
                state.consecutive_failures = 0;
                state.consecutive_successes = 0;
            }
            Phase::Open => {
                state.opened_at = Some(Instant::now());
            }
        }
    }

    /// Discards all bookkeeping and returns the breaker to its initial,
    /// closed state.
    pub async fn force_reset(&self) {
        let mut state = self.state.lock().await;
        *state = CircuitState::new();
    }
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::with_defaults()
}
}
/// Default failure classifier: a response counts as a failure when its
/// numeric status code is 500 or above.
fn default_is_failure<B>(resp: &http::Response<B>) -> bool {
    let code = resp.status().as_u16();
    code >= 500
}
/// Tower layer that wraps services with a shared [`CircuitBreaker`].
///
/// Every service produced by one layer (and by its clones) reports to the
/// same breaker and uses the same failure predicate.
#[derive(Clone)]
pub struct CircuitBreakerLayer {
    breaker: Arc<CircuitBreaker>,
    is_failure: Arc<dyn Fn(&http::Response<bytes::Bytes>) -> bool + Send + Sync>,
}

impl std::fmt::Debug for CircuitBreakerLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The predicate is an opaque closure, so only the breaker is printed.
        let mut out = f.debug_struct("CircuitBreakerLayer");
        out.field("breaker", &self.breaker);
        out.finish()
    }
}

impl CircuitBreakerLayer {
    /// Creates a layer around `breaker` that classifies responses with the
    /// default predicate (status code 500 or above is a failure).
    pub fn new(breaker: Arc<CircuitBreaker>) -> Self {
        let is_failure: Arc<dyn Fn(&http::Response<bytes::Bytes>) -> bool + Send + Sync> =
            Arc::new(default_is_failure::<bytes::Bytes>);
        Self {
            breaker,
            is_failure,
        }
    }

    /// Replaces the predicate that decides whether a response is a failure.
    pub fn with_failure_predicate<F>(self, f: F) -> Self
    where
        F: Fn(&http::Response<bytes::Bytes>) -> bool + Send + Sync + 'static,
    {
        Self {
            is_failure: Arc::new(f),
            ..self
        }
    }
}
impl<S> Layer<S> for CircuitBreakerLayer {
    type Service = CircuitBreakerService<S>;

    /// Wraps `inner`, handing it shared handles to this layer's breaker and
    /// failure predicate.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            breaker,
            is_failure,
        } = self;
        CircuitBreakerService {
            inner,
            breaker: Arc::clone(breaker),
            is_failure: Arc::clone(is_failure),
        }
    }
}
#[derive(Clone)]
pub struct CircuitBreakerService<S> {
inner: S,
breaker: Arc<CircuitBreaker>,
is_failure: Arc<dyn Fn(&http::Response<bytes::Bytes>) -> bool + Send + Sync>,
}
impl<S, B> Service<http::Request<B>> for CircuitBreakerService<S>
where
S: Service<http::Request<B>> + Clone + Send + 'static,
S::Response: Into<http::Response<bytes::Bytes>> + Send + 'static,
S::Error: Send + 'static,
S::Future: Send + 'static,
B: Send + 'static,
{
type Response = http::Response<bytes::Bytes>;
type Error = S::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: http::Request<B>) -> Self::Future {
let breaker = Arc::clone(&self.breaker);
let is_failure = Arc::clone(&self.is_failure);
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
Box::pin(async move {
let allowed = breaker.try_acquire().await;
if !allowed {
let resp = http::Response::builder()
.status(http::StatusCode::SERVICE_UNAVAILABLE)
.header(http::header::CONTENT_TYPE, "text/plain")
.body(bytes::Bytes::from_static(
b"Service Unavailable (circuit open)",
))
.expect("building 503 cannot fail");
return Ok(resp);
}
match inner.call(req).await {
Err(e) => {
breaker.record_failure().await;
Err(e)
}
Ok(resp) => {
let resp: http::Response<bytes::Bytes> = resp.into();
if (is_failure)(&resp) {
breaker.record_failure().await;
} else {
breaker.record_success().await;
}
Ok(resp)
}
}
})
}
}
// Unit tests for the breaker state machine and for the tower layer built on it.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tower::{ServiceBuilder, ServiceExt};
// Config with a 50 ms open window and a caller-chosen failure threshold.
// NOTE(review): tests that use this config without `tokio::time::pause()`
// run on real time and assume they reach their assertions within 50 ms.
fn fast_config(failures: u32) -> CircuitBreakerConfig {
CircuitBreakerConfig {
failure_threshold: failures,
success_threshold: 2,
open_duration: Duration::from_millis(50),
half_open_max_calls: 1,
}
}
// A freshly built breaker reports the closed phase.
#[tokio::test]
async fn test_breaker_starts_closed() {
let cb = CircuitBreaker::with_defaults();
assert_eq!(cb.phase().await, Phase::Closed);
}
// The breaker stays closed below the failure threshold and opens on reaching it.
#[tokio::test]
async fn test_breaker_opens_after_threshold_failures() {
let cb = CircuitBreaker::new(fast_config(3));
for _ in 0..2 {
cb.record_failure().await;
assert_eq!(cb.phase().await, Phase::Closed);
}
cb.record_failure().await;
assert_eq!(cb.phase().await, Phase::Open);
}
// While open, the layer answers 503 itself and never invokes the inner service.
#[tokio::test]
async fn test_breaker_open_rejects_503_without_calling_inner() {
// Counts how many times the inner service is actually invoked.
let call_count = Arc::new(AtomicUsize::new(0));
let cc = Arc::clone(&call_count);
let inner = tower::service_fn(move |_req: http::Request<String>| {
cc.fetch_add(1, Ordering::SeqCst);
async {
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(200)
.body(bytes::Bytes::new())
.unwrap(),
)
}
});
// Threshold of one: a single recorded failure opens the breaker.
let cb = Arc::new(CircuitBreaker::new(fast_config(1)));
cb.record_failure().await;
assert_eq!(cb.phase().await, Phase::Open);
let layer = CircuitBreakerLayer::new(Arc::clone(&cb));
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
let resp = svc
.ready()
.await
.unwrap()
.call(http::Request::builder().body(String::new()).unwrap())
.await
.unwrap();
assert_eq!(resp.status(), http::StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(call_count.load(Ordering::SeqCst), 0);
}
// Once the open window has elapsed, the reported phase becomes half-open.
#[tokio::test]
async fn test_breaker_open_to_half_open_after_duration() {
// Paused clock: time only moves through the explicit `advance` below.
tokio::time::pause();
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 1,
open_duration: Duration::from_millis(100),
half_open_max_calls: 1,
});
cb.record_failure().await;
assert_eq!(cb.phase().await, Phase::Open);
tokio::time::advance(Duration::from_millis(110)).await;
assert_eq!(cb.phase().await, Phase::HalfOpen);
}
// Two successful half-open calls (success_threshold = 2) close the breaker.
#[tokio::test]
async fn test_breaker_half_open_on_success_closes() {
tokio::time::pause();
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 2,
open_duration: Duration::from_millis(100),
half_open_max_calls: 2,
});
cb.record_failure().await;
assert_eq!(cb.phase().await, Phase::Open);
tokio::time::advance(Duration::from_millis(110)).await;
assert_eq!(cb.phase().await, Phase::HalfOpen);
// First success: still below the threshold, so the breaker stays half-open.
assert!(cb.try_acquire().await);
cb.record_success().await;
assert_eq!(cb.phase().await, Phase::HalfOpen);
// Second success reaches the threshold and closes the breaker.
assert!(cb.try_acquire().await);
cb.record_success().await;
assert_eq!(cb.phase().await, Phase::Closed);
}
// A failure recorded for a half-open call sends the breaker straight back to open.
#[tokio::test]
async fn test_breaker_half_open_on_failure_reopens() {
tokio::time::pause();
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 2,
open_duration: Duration::from_millis(100),
half_open_max_calls: 1,
});
cb.record_failure().await;
tokio::time::advance(Duration::from_millis(110)).await;
assert_eq!(cb.phase().await, Phase::HalfOpen);
assert!(cb.try_acquire().await);
cb.record_failure().await;
assert_eq!(cb.phase().await, Phase::Open);
}
// With a closed breaker the layer passes the request through to the inner service.
#[tokio::test]
async fn test_breaker_layer_closed_forwards_request() {
let inner = tower::service_fn(|_req: http::Request<String>| async {
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(200)
.body(bytes::Bytes::from_static(b"ok"))
.unwrap(),
)
});
let cb = Arc::new(CircuitBreaker::with_defaults());
let layer = CircuitBreakerLayer::new(Arc::clone(&cb));
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
let resp = svc
.ready()
.await
.unwrap()
.call(http::Request::builder().body(String::new()).unwrap())
.await
.unwrap();
assert_eq!(resp.status(), 200);
}
// Three 500 responses through the layer open the breaker; the next request
// is answered with 503 and does not reach the inner service.
#[tokio::test]
async fn test_breaker_layer_5xx_trips_breaker() {
let call_count = Arc::new(AtomicUsize::new(0));
let cc = Arc::clone(&call_count);
let inner = tower::service_fn(move |_req: http::Request<String>| {
cc.fetch_add(1, Ordering::SeqCst);
async {
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(500)
.body(bytes::Bytes::new())
.unwrap(),
)
}
});
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 3,
..Default::default()
}));
let layer = CircuitBreakerLayer::new(Arc::clone(&cb));
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
// Each of these calls reaches the inner service and comes back as a 500.
for _ in 0..3 {
let _ = svc
.ready()
.await
.unwrap()
.call(http::Request::builder().body(String::new()).unwrap())
.await;
}
assert_eq!(cb.phase().await, Phase::Open);
// The inner call count must not change for the rejected request.
let n_before = call_count.load(Ordering::SeqCst);
let resp = svc
.ready()
.await
.unwrap()
.call(http::Request::builder().body(String::new()).unwrap())
.await
.unwrap();
assert_eq!(resp.status(), http::StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(call_count.load(Ordering::SeqCst), n_before); }
}