use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant};

use tokio::sync::RwLock;

use crate::error::DoOverError;
use crate::policy::Policy;
#[derive(Clone, Copy)]
enum State {
Closed,
Open,
HalfOpen,
}
pub struct CircuitBreaker {
failure_threshold: usize,
reset_timeout: Duration,
failures: AtomicUsize,
opened_at: RwLock<Option<Instant>>,
state: RwLock<State>,
}
impl Clone for CircuitBreaker {
fn clone(&self) -> Self {
Self {
failure_threshold: self.failure_threshold,
reset_timeout: self.reset_timeout,
failures: AtomicUsize::new(self.failures.load(Ordering::Relaxed)),
opened_at: RwLock::new(*self.opened_at.blocking_read()),
state: RwLock::new(*self.state.blocking_read()),
}
}
}
impl CircuitBreaker {
pub fn new(failure_threshold: usize, reset_timeout: Duration) -> Self {
Self {
failure_threshold,
reset_timeout,
failures: AtomicUsize::new(0),
opened_at: RwLock::new(None),
state: RwLock::new(State::Closed),
}
}
}
#[async_trait::async_trait]
impl<E> Policy<DoOverError<E>> for CircuitBreaker
where
E: Send + Sync,
{
async fn execute<F, Fut, T>(&self, f: F) -> Result<T, DoOverError<E>>
where
F: Fn() -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<T, DoOverError<E>>> + Send,
T: Send,
{
{
let state = self.state.read().await;
if matches!(*state, State::Open) {
let opened = self.opened_at.read().await;
if let Some(t) = *opened {
if t.elapsed() >= self.reset_timeout {
drop(opened);
*self.state.write().await = State::HalfOpen;
} else {
return Err(DoOverError::CircuitOpen);
}
}
}
}
match f().await {
Ok(v) => {
self.failures.store(0, Ordering::Relaxed);
*self.state.write().await = State::Closed;
Ok(v)
}
Err(e) => {
let count = self.failures.fetch_add(1, Ordering::Relaxed) + 1;
if count >= self.failure_threshold {
*self.state.write().await = State::Open;
*self.opened_at.write().await = Some(Instant::now());
}
Err(e)
}
}
}
}