#[cfg(test)]
mod tests;
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
use crate::application::Application;
use crate::middleware::Middleware;
use crate::request::Request;
use crate::response::Response;
use crate::server::ConnectionInfo;
/// Observable state of a single backend's circuit breaker.
///
/// Fieldless, so it is `Copy` and can be compared with `==` / used as an `Eq` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakerState {
    /// Requests are allowed; consecutive failures are being counted toward the threshold.
    Closed,
    /// Requests are rejected until the recovery period has elapsed.
    Open,
    /// The recovery period has elapsed: requests are allowed, and the next recorded
    /// success closes the breaker while the next recorded failure re-opens it.
    HalfOpen,
}
/// Breaker bookkeeping kept for one backend.
struct BackendEntry {
    /// Current breaker state for this backend.
    state: BreakerState,
    /// Failures counted while `Closed`; zeroed again by a recorded success or a reset.
    failures: u32,
    /// Instant the breaker (re)entered or was refreshed in `Open`; `None` in any other state.
    opened_at: Option<Instant>,
}

impl BackendEntry {
    /// Builds a fresh entry: closed, with no failures and no open timestamp.
    fn new() -> Self {
        let state = BreakerState::Closed;
        BackendEntry {
            opened_at: None,
            failures: 0,
            state,
        }
    }
}
/// Circuit breaker that tracks each backend independently, keyed by backend name.
///
/// Per-backend transitions:
///
/// * `Closed` → `Open` once `failure_threshold` failures have been recorded without an
///   intervening success.
/// * `Open` → `HalfOpen` once `recovery` has elapsed; this is only noticed when
///   [`CircuitBreaker::is_available`] is called.
/// * `HalfOpen` → `Closed` on the next recorded success, or back to `Open` on the next
///   recorded failure.
///
/// A backend that has never been seen is treated as `Closed`.
pub struct CircuitBreaker {
    /// Bookkeeping per backend; the mutating methods create an entry on first mention.
    backends: HashMap<String, BackendEntry>,
    /// Number of failures (while closed) that trips the breaker open.
    failure_threshold: u32,
    /// How long an open breaker keeps rejecting before a backend may be tried again.
    recovery: Duration,
}

impl CircuitBreaker {
    /// Creates a breaker that opens after `failure_threshold` failures and stays open for
    /// `recovery_secs` seconds before letting requests through again.
    pub fn new(failure_threshold: u32, recovery_secs: u64) -> Self {
        let recovery = Duration::from_secs(recovery_secs);
        Self {
            failure_threshold,
            recovery,
            backends: HashMap::new(),
        }
    }

    /// Returns the entry for `backend`, inserting a closed one if it is not tracked yet.
    fn entry_mut(&mut self, backend: &str) -> &mut BackendEntry {
        self.backends
            .entry(backend.to_string())
            .or_insert_with(BackendEntry::new)
    }

    /// Reports whether a request to `backend` may proceed.
    ///
    /// `Closed` and `HalfOpen` always allow the request. `Open` rejects it unless the
    /// recovery period has elapsed, in which case the breaker moves to `HalfOpen` and the
    /// request is allowed.
    ///
    /// NOTE(review): `HalfOpen` admits every request, not a single probe, until a success
    /// or failure is recorded.
    pub fn is_available(&mut self, backend: &str) -> bool {
        // Copy the field out first: `entry_mut` borrows all of `self` mutably.
        let recovery = self.recovery;
        let entry = self.entry_mut(backend);

        if entry.state != BreakerState::Open {
            return true;
        }

        let recovered = matches!(entry.opened_at, Some(since) if since.elapsed() >= recovery);
        if recovered {
            entry.state = BreakerState::HalfOpen;
            entry.opened_at = None;
        }
        recovered
    }

    /// Records a successful call: the backend returns to `Closed` with its counters cleared.
    pub fn record_success(&mut self, backend: &str) {
        *self.entry_mut(backend) = BackendEntry::new();
    }

    /// Records a failed call to `backend`.
    ///
    /// * `Closed`: counts the failure and opens the breaker once the threshold is reached.
    /// * `HalfOpen`: the trial failed, so the breaker re-opens.
    /// * `Open`: the recovery clock restarts from now.
    pub fn record_failure(&mut self, backend: &str) {
        let threshold = self.failure_threshold;
        let entry = self.entry_mut(backend);

        let trips = match entry.state {
            BreakerState::Closed => {
                entry.failures += 1;
                entry.failures >= threshold
            }
            BreakerState::HalfOpen | BreakerState::Open => true,
        };

        if trips {
            entry.state = BreakerState::Open;
            entry.opened_at = Some(Instant::now());
        }
    }

    /// Forces `backend` back to `Closed`, clearing its failure count and open timestamp.
    pub fn reset(&mut self, backend: &str) {
        *self.entry_mut(backend) = BackendEntry::new();
    }

    /// Returns the recorded state of `backend` (`Closed` if it has never been seen).
    ///
    /// This is a read-only view and does not advance `Open` to `HalfOpen`. A breaker whose
    /// recovery period has elapsed is still reported `Open` until
    /// [`CircuitBreaker::is_available`] is called.
    pub fn state(&self, backend: &str) -> BreakerState {
        match self.backends.get(backend) {
            Some(entry) => entry.state.clone(),
            None => BreakerState::Closed,
        }
    }
}
/// Lazily initialized storage backing [`global`].
static GLOBAL_BREAKER: OnceLock<Mutex<CircuitBreaker>> = OnceLock::new();

/// Returns the process-wide [`CircuitBreaker`], creating it on first use.
///
/// The shared instance trips after 5 failures and waits 30 seconds before a backend
/// may be tried again.
pub fn global() -> &'static Mutex<CircuitBreaker> {
    const FAILURE_THRESHOLD: u32 = 5;
    const RECOVERY_SECS: u64 = 30;

    GLOBAL_BREAKER.get_or_init(|| {
        let breaker = CircuitBreaker::new(FAILURE_THRESHOLD, RECOVERY_SECS);
        Mutex::new(breaker)
    })
}
/// Middleware that re-executes the downstream application when it answers with a
/// retryable status code.
///
/// Defaults (see [`RetryLayer::new`]): up to 3 retries on `502`, `503` and `504`.
#[derive(Debug, Clone)]
pub struct RetryLayer {
    /// Maximum number of additional attempts after the first one.
    max_retries: u32,
    /// Response status codes that trigger another attempt.
    retry_on: Vec<i16>,
}

impl RetryLayer {
    /// Creates a layer that retries up to 3 times on `502`, `503` and `504`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_retries: 3,
            retry_on: vec![502, 503, 504],
        }
    }

    /// Sets the maximum number of retries after the initial attempt (`0` disables retrying).
    ///
    /// Consumes and returns the builder. Ignoring the result would silently drop the setting,
    /// hence `#[must_use]`.
    #[must_use]
    pub fn max_retries(mut self, n: u32) -> Self {
        self.max_retries = n;
        self
    }

    /// Replaces the list of status codes that trigger a retry.
    ///
    /// Consumes and returns the builder. Ignoring the result would silently drop the setting,
    /// hence `#[must_use]`.
    #[must_use]
    pub fn retry_on(mut self, codes: Vec<i16>) -> Self {
        self.retry_on = codes;
        self
    }
}

impl Default for RetryLayer {
    /// Equivalent to [`RetryLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Middleware for RetryLayer {
    /// Runs the downstream application, then re-runs it up to `max_retries` more times for
    /// as long as the latest response's status code is listed in `retry_on`.
    ///
    /// Attempts are made back to back, with no delay. An `Err` from any attempt is returned
    /// immediately, discarding any earlier response.
    ///
    /// NOTE(review): requests are replayed regardless of method. Confirm the downstream
    /// application is safe to re-execute, e.g. for non-idempotent requests.
    fn handle(
        &self,
        request: &Request,
        connection: &ConnectionInfo,
        next: &dyn Application,
    ) -> Result<Response, String> {
        let mut latest = next.execute(request, connection)?;
        for _ in 0..self.max_retries {
            if !self.retry_on.contains(&latest.status_code) {
                break;
            }
            latest = next.execute(request, connection)?;
        }
        Ok(latest)
    }
}