use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::types::error::VaultError;
/// Tuning parameters for the circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (counted while closed) at which the breaker opens.
    /// Values of 0 and 1 both open the breaker on the first failure.
    pub failure_threshold: u32,
    /// How long the breaker stays open before a probe call is let through.
    pub reset_timeout: Duration,
}
impl Default for CircuitBreakerConfig {
    /// Opens after 5 consecutive failures and waits 30 seconds before
    /// admitting a probe call.
    fn default() -> Self {
        let reset_timeout = Duration::from_secs(30);
        Self {
            failure_threshold: 5,
            reset_timeout,
        }
    }
}
/// Internal state machine of the circuit breaker.
enum State {
    /// Calls are allowed; counts failures seen since the last success.
    Closed { consecutive_failures: u32 },
    /// Calls are rejected until `reset_timeout` has elapsed since `since`.
    Open { since: Instant },
    /// A probe call was admitted at `since`. Further calls are rejected until
    /// the probe reports back, or until `reset_timeout` elapses again without
    /// a report (in which case another probe is admitted).
    HalfOpen { since: Instant },
}

/// Mutex-guarded circuit breaker.
///
/// Callers invoke [`CircuitBreaker::check`] before each guarded call and then
/// report the outcome with [`CircuitBreaker::record_success`] or
/// [`CircuitBreaker::record_failure`].
pub(crate) struct CircuitBreaker {
    config: CircuitBreakerConfig,
    state: Mutex<State>,
}

impl CircuitBreaker {
    /// Creates a breaker in the closed state with no recorded failures.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: Mutex::new(State::Closed {
                consecutive_failures: 0,
            }),
        }
    }

    /// Decides whether a call may proceed.
    ///
    /// - Closed: always allowed.
    /// - Open: rejected until `reset_timeout` has elapsed since the breaker
    ///   opened; the first call after that moves the breaker to half-open and
    ///   is admitted as the probe.
    /// - Half-open: rejected while the probe is outstanding. If the probe has
    ///   not reported back within `reset_timeout`, a new probe is admitted, so
    ///   a lost probe (cancelled caller, early return without `record_*`)
    ///   cannot wedge the breaker in half-open forever.
    ///
    /// # Errors
    ///
    /// Returns `VaultError::CircuitOpen` when the call is rejected and
    /// `VaultError::LockPoisoned` if the internal mutex is poisoned.
    pub fn check(&self) -> Result<(), VaultError> {
        let mut state = self.state.lock().map_err(|_| VaultError::LockPoisoned)?;
        match *state {
            State::Closed { .. } => Ok(()),
            // Open and half-open share the same rule: once `reset_timeout` has
            // passed since the last transition, let exactly one probe through
            // and restart the timer for it.
            State::Open { since } | State::HalfOpen { since } => {
                if since.elapsed() >= self.config.reset_timeout {
                    *state = State::HalfOpen {
                        since: Instant::now(),
                    };
                    Ok(())
                } else {
                    Err(VaultError::CircuitOpen)
                }
            }
        }
    }

    /// Reports a successful call: the breaker closes and the failure count is
    /// reset, regardless of the current state (including a late success from a
    /// call admitted before the breaker opened).
    ///
    /// If the internal mutex is poisoned the report is silently dropped.
    pub fn record_success(&self) {
        if let Ok(mut state) = self.state.lock() {
            *state = State::Closed {
                consecutive_failures: 0,
            };
        }
    }

    /// Reports a failed call.
    ///
    /// - Closed: increments the consecutive-failure count and opens the
    ///   breaker once it reaches `failure_threshold`.
    /// - Half-open: the probe failed, so the breaker re-opens with a fresh timer.
    /// - Open: ignored; stray failures do not extend the open period.
    ///
    /// If the internal mutex is poisoned the report is silently dropped.
    pub fn record_failure(&self) {
        if let Ok(mut state) = self.state.lock() {
            match *state {
                State::Closed {
                    consecutive_failures,
                } => {
                    let new_count = consecutive_failures + 1;
                    if new_count >= self.config.failure_threshold {
                        *state = State::Open {
                            since: Instant::now(),
                        };
                    } else {
                        *state = State::Closed {
                            consecutive_failures: new_count,
                        };
                    }
                }
                State::HalfOpen { .. } => {
                    *state = State::Open {
                        since: Instant::now(),
                    };
                }
                State::Open { .. } => {
                    // Already open: keep the original `since` so late failures
                    // from in-flight calls don't postpone recovery.
                }
            }
        }
    }
}