use std::{future::Future, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::sync::Mutex;
use crate::resil::{WindowConfig, WindowSnapshot, breaker_state::CircuitBreakerState};
/// Coarse state of a circuit breaker, as reported by `state()` and by snapshots.
///
/// These are the conventional circuit-breaker states; the transitions between them are
/// implemented by `CircuitBreakerState`, not here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BreakerState {
    /// Conventionally: calls are admitted and their outcomes counted.
    Closed,
    /// Conventionally: calls are rejected (compare `BreakerError::Open`).
    Open,
    /// Conventionally: a limited number of trial calls are admitted.
    HalfOpen,
}
/// Basic thresholds for a circuit breaker.
///
/// Defaults: `failure_threshold = 5`, `reset_timeout = 30s`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BreakerConfig {
    /// Presumably the failure count at which the breaker trips; interpreted by
    /// `CircuitBreakerState` — confirm there.
    pub failure_threshold: u32,
    /// Presumably how long the breaker stays open before probing again; interpreted by
    /// `CircuitBreakerState` — confirm there.
    pub reset_timeout: Duration,
}

impl Default for BreakerConfig {
    fn default() -> Self {
        let reset_timeout = Duration::from_secs(30);
        let failure_threshold = 5;
        BreakerConfig {
            failure_threshold,
            reset_timeout,
        }
    }
}
/// Policy knobs passed to `CircuitBreakerState::new` together with a [`BreakerConfig`].
///
/// The per-field notes below are inferred from the field names; the authoritative semantics
/// live in `breaker_state` — confirm there before relying on them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BreakerPolicyConfig {
    /// Request-window configuration (default: `WindowConfig::default()`).
    pub window: WindowConfig,
    /// Presumably the minimum number of requests in the window before ratios are evaluated
    /// (default: 20).
    pub min_request_count: u64,
    /// Presumably the failure percentage that trips the breaker (default: 50).
    pub failure_ratio_percent: u8,
    /// Presumably the failure percentage at which requests begin to be dropped (default: 20).
    pub drop_ratio_percent: u8,
    /// Presumably the cap on trial calls while half-open (default: 1).
    pub half_open_max_calls: u32,
    /// Presumably the interval after which one call is let through despite rejection
    /// (default: 5 s).
    pub force_pass_interval: Duration,
    /// Toggles the SRE-style rejection path (default: `false`; `true` in
    /// [`BreakerPolicyConfig::google_sre`]).
    pub sre_rejection_enabled: bool,
    /// Presumably the SRE multiplier `K` in thousandths (default: 1500, i.e. 1.5).
    pub sre_k_millis: u32,
    /// Presumably the SRE protection floor (default: 5).
    pub sre_protection: u64,
}

impl Default for BreakerPolicyConfig {
    fn default() -> Self {
        Self {
            sre_rejection_enabled: false,
            sre_k_millis: 1_500,
            sre_protection: 5,
            window: WindowConfig::default(),
            min_request_count: 20,
            failure_ratio_percent: 50,
            drop_ratio_percent: 20,
            half_open_max_calls: 1,
            force_pass_interval: Duration::from_secs(5),
        }
    }
}

impl BreakerPolicyConfig {
    /// Preset named after Google SRE adaptive throttling.
    ///
    /// Starts from [`Default`] and changes exactly three fields:
    /// - `sre_rejection_enabled` is set to `true`;
    /// - `drop_ratio_percent` is set to `0`;
    /// - `failure_ratio_percent` is set to `100`.
    pub fn google_sre() -> Self {
        let base = Self::default();
        Self {
            failure_ratio_percent: 100,
            drop_ratio_percent: 0,
            sre_rejection_enabled: true,
            ..base
        }
    }
}
/// Point-in-time view of a breaker, as returned by `SharedCircuitBreaker::snapshot`.
#[derive(Clone, Debug, PartialEq)]
pub struct BreakerSnapshot {
    /// Breaker state when the snapshot was taken.
    pub state: BreakerState,
    /// Failure-streak counter as tracked by the state machine (presumably reset on success —
    /// confirm in `breaker_state`).
    pub consecutive_failures: u32,
    /// Calls currently admitted while half-open (presumably bounded by
    /// `BreakerPolicyConfig::half_open_max_calls` — confirm).
    pub half_open_in_flight: u32,
    /// Request-window statistics.
    pub window: WindowSnapshot,
}
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum BreakerError {
#[error("circuit breaker is open")]
Open,
#[error("circuit breaker dropped request")]
Dropped,
}
#[derive(Debug, Error, PartialEq, Eq)]
pub enum BreakerCallError<E> {
#[error(transparent)]
Rejected(#[from] BreakerError),
#[error("protected call failed: {0}")]
Inner(E),
}
#[derive(Debug)]
pub struct CircuitBreaker {
state: CircuitBreakerState,
}
impl CircuitBreaker {
pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
Self {
state: CircuitBreakerState::new(
BreakerConfig {
failure_threshold,
reset_timeout,
},
BreakerPolicyConfig::default(),
),
}
}
pub fn state(&mut self) -> BreakerState {
self.state.state()
}
pub fn allow(&mut self) -> bool {
self.state.allow().is_ok()
}
pub fn record_success(&mut self) {
self.state.record_success();
}
pub fn record_failure(&mut self) {
self.state.record_failure();
}
}
/// Cloneable circuit breaker that can be shared between tasks.
///
/// Every clone points at the same `CircuitBreakerState` behind an
/// `Arc<tokio::sync::Mutex<_>>`, so outcomes recorded through one clone are seen by all.
#[derive(Debug, Clone)]
pub struct SharedCircuitBreaker {
    state: Arc<Mutex<CircuitBreakerState>>,
}

impl SharedCircuitBreaker {
    /// Creates a breaker from `config` using [`BreakerPolicyConfig::default`].
    pub fn new(config: BreakerConfig) -> Self {
        Self::with_policy(config, BreakerPolicyConfig::default())
    }

    /// Creates a breaker from an explicit `config` and `policy`.
    pub fn with_policy(config: BreakerConfig, policy: BreakerPolicyConfig) -> Self {
        Self {
            state: Arc::new(Mutex::new(CircuitBreakerState::new(config, policy))),
        }
    }

    /// Asks the breaker to admit one call.
    ///
    /// Resolve the returned [`BreakerGuard`] with [`BreakerGuard::record_success`] or
    /// [`BreakerGuard::record_failure`]. A guard dropped unresolved is reported as a failure
    /// (best effort) by its `Drop` impl.
    ///
    /// # Errors
    ///
    /// Returns the [`BreakerError`] reported by the breaker state when the call is rejected.
    pub async fn allow(&self) -> Result<BreakerGuard, BreakerError> {
        // The lock guard is a temporary dropped at the end of this statement, so the lock is
        // not held while the caller runs the protected call.
        self.state.lock().await.allow()?;
        Ok(BreakerGuard {
            breaker: self.clone(),
            completed: false,
        })
    }

    /// Runs `request` through the breaker, counting every `Err` as a failure.
    ///
    /// # Errors
    ///
    /// - [`BreakerCallError::Rejected`] if the call was not admitted; `request` is then never
    ///   invoked.
    /// - [`BreakerCallError::Inner`] wrapping the error that `request` returned.
    pub async fn do_request<F, Fut, T, E>(&self, request: F) -> Result<T, BreakerCallError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        self.do_with_acceptable(request, |_| false).await
    }

    /// Like [`Self::do_request`], but a rejected call invokes `fallback` with the rejection
    /// and returns its result instead.
    ///
    /// # Errors
    ///
    /// Whatever `request` (when admitted) or `fallback` (when rejected) returns.
    pub async fn do_with_fallback<F, Fut, Fb, FbFut, T, E>(
        &self,
        request: F,
        fallback: Fb,
    ) -> Result<T, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        Fb: FnOnce(BreakerError) -> FbFut,
        FbFut: Future<Output = Result<T, E>>,
    {
        self.do_with_fallback_acceptable(request, fallback, |_| false).await
    }

    /// Combines [`Self::do_with_fallback`] and [`Self::do_with_acceptable`].
    ///
    /// Rejected calls go to `fallback`. Admitted calls whose error satisfies `acceptable` are
    /// recorded as successes; the error is still returned to the caller.
    ///
    /// # Errors
    ///
    /// Whatever `request` (when admitted) or `fallback` (when rejected) returns.
    pub async fn do_with_fallback_acceptable<F, Fut, Fb, FbFut, T, E, A>(
        &self,
        request: F,
        fallback: Fb,
        acceptable: A,
    ) -> Result<T, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        Fb: FnOnce(BreakerError) -> FbFut,
        FbFut: Future<Output = Result<T, E>>,
        A: FnOnce(&E) -> bool,
    {
        match self.allow().await {
            Ok(guard) => Self::run_admitted(guard, request, acceptable).await,
            Err(rejection) => fallback(rejection).await,
        }
    }

    /// Runs `request` through the breaker and treats some errors as successes.
    ///
    /// An error for which `acceptable` returns `true` is still returned to the caller, but it
    /// is recorded as a success. `acceptable` is called at most once, and only when `request`
    /// returns `Err`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::do_request`].
    pub async fn do_with_acceptable<F, Fut, T, E, A>(
        &self,
        request: F,
        acceptable: A,
    ) -> Result<T, BreakerCallError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        // `FnOnce` rather than `Fn`: it is called at most once, and every `Fn` still fits.
        A: FnOnce(&E) -> bool,
    {
        let guard = self.allow().await?;
        Self::run_admitted(guard, request, acceptable)
            .await
            .map_err(BreakerCallError::Inner)
    }

    /// Runs `request` under an already-admitted `guard` and resolves the guard from the result.
    ///
    /// `guard` is dropped unresolved if this future is dropped before completion, or if
    /// `request` or `acceptable` panics. Its `Drop` impl then reports a failure.
    async fn run_admitted<F, Fut, T, E, A>(
        guard: BreakerGuard,
        request: F,
        acceptable: A,
    ) -> Result<T, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        A: FnOnce(&E) -> bool,
    {
        let result = request().await;
        let counts_as_success = match &result {
            Ok(_) => true,
            Err(error) => acceptable(error),
        };
        if counts_as_success {
            guard.record_success().await;
        } else {
            guard.record_failure().await;
        }
        result
    }

    /// Returns the current breaker state.
    pub async fn state(&self) -> BreakerState {
        self.state.lock().await.state()
    }

    /// Returns a point-in-time [`BreakerSnapshot`].
    pub async fn snapshot(&self) -> BreakerSnapshot {
        self.state.lock().await.snapshot()
    }

    /// Records a success on the shared state; used by [`BreakerGuard`].
    async fn record_success(&self) {
        self.state.lock().await.record_success();
    }

    /// Records a failure on the shared state; used by [`BreakerGuard`].
    async fn record_failure(&self) {
        self.state.lock().await.record_failure();
    }
}
/// Permit for one call admitted by [`SharedCircuitBreaker::allow`].
///
/// Resolve it with [`BreakerGuard::record_success`] or [`BreakerGuard::record_failure`]. A
/// guard dropped unresolved is reported as a failure (best effort) by its `Drop` impl. That
/// includes a guard discarded immediately, as in `breaker.allow().await?;`, which is why the
/// type is `#[must_use]`.
#[derive(Debug)]
#[must_use = "an unresolved BreakerGuard is recorded as a failure when dropped"]
pub struct BreakerGuard {
    /// Breaker that receives the outcome.
    breaker: SharedCircuitBreaker,
    /// Set once an outcome has been recorded, so `Drop` does not report a failure.
    completed: bool,
}
impl BreakerGuard {
pub async fn record_success(mut self) {
self.breaker.record_success().await;
self.completed = true;
}
pub async fn record_failure(mut self) {
self.breaker.record_failure().await;
self.completed = true;
}
}
impl Drop for BreakerGuard {
    /// Reports an unresolved guard as a failure.
    ///
    /// A guard that reaches `drop` with `completed == false` was never resolved with
    /// `record_success`/`record_failure`. For example, the protected future was cancelled or
    /// panicked, or the guard was discarded. The admitted call is therefore counted as a
    /// failure.
    ///
    /// `drop` cannot `.await`, so when the breaker lock is free the failure is recorded
    /// synchronously via `try_lock`. It is then visible immediately, and no runtime is needed.
    /// Only under contention is recording deferred to a task on the current Tokio runtime.
    /// Outside a runtime that fallback is impossible, and the failure is lost.
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        if let Ok(mut state) = self.breaker.state.try_lock() {
            state.record_failure();
            return;
        }
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let breaker = self.breaker.clone();
            handle.spawn(async move {
                breaker.record_failure().await;
            });
        }
    }
}