use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use super::limits::GuardRailViolation;
/// Token-bucket rate limiter that keeps one bucket per client id.
///
/// All clients share the same `limit_qps` setting.
#[derive(Debug)]
pub struct RateLimiter {
/// Allowed requests per second; `check` uses it both as the per-second refill rate and as
/// the bucket capacity.
/// NOTE(review): stored as an atomic, presumably so it can be changed at runtime; no writer
/// other than `new` is visible in this section.
limit_qps: std::sync::atomic::AtomicU32,
/// Buckets keyed by client id, created lazily on first `check` / `exhaust`.
/// NOTE(review): nothing visible here removes entries, so the map grows with every distinct
/// client id — confirm eviction happens elsewhere if ids can be attacker-controlled.
clients: parking_lot::RwLock<HashMap<String, TokenBucket>>,
}
/// Per-client token state owned by `RateLimiter`.
#[derive(Debug)]
struct TokenBucket {
/// Tokens currently available. Refills are capped at the limit; the value may be negative
/// because `RateLimiter::exhaust` sets it far below zero.
tokens: f64,
/// Time of the last refill or reset; the next refill adds tokens for the time elapsed since.
last_update: Instant,
}
impl RateLimiter {
    /// Builds a limiter allowing each client `limit_qps` requests per second.
    ///
    /// Clients seen for the first time start with a full bucket of `limit_qps` tokens.
    #[must_use]
    pub fn new(limit_qps: u32) -> Self {
        let limit = std::sync::atomic::AtomicU32::new(limit_qps);
        let buckets = parking_lot::RwLock::new(HashMap::new());
        Self {
            limit_qps: limit,
            clients: buckets,
        }
    }

    /// Takes one token from `client_id`'s bucket.
    ///
    /// Before the decision, the bucket is topped up by `elapsed_seconds * limit_qps` tokens
    /// and clamped to a capacity of `limit_qps`.
    ///
    /// # Errors
    ///
    /// Returns [`GuardRailViolation::RateLimitExceeded`] when less than one whole token is
    /// available after the refill.
    pub fn check(&self, client_id: &str) -> Result<(), GuardRailViolation> {
        let mut guard = self.clients.write();
        let at = Instant::now();
        let qps = self.limit_qps.load(std::sync::atomic::Ordering::Relaxed);
        let capacity = f64::from(qps);

        let entry = guard
            .entry(client_id.to_string())
            .or_insert(TokenBucket {
                tokens: capacity,
                last_update: at,
            });

        // Continuous refill since the last visit, never exceeding the capacity.
        let refill = at.duration_since(entry.last_update).as_secs_f64() * capacity;
        entry.tokens = capacity.min(entry.tokens + refill);
        entry.last_update = at;

        if entry.tokens < 1.0 {
            return Err(GuardRailViolation::RateLimitExceeded { limit_qps: qps });
        }
        entry.tokens -= 1.0;
        Ok(())
    }

    /// Puts `client_id`'s bucket deep into deficit so that its following `check` calls fail.
    ///
    /// NOTE(review): the deficit is a fixed 1,000,000 tokens, so recovery takes roughly
    /// `1_000_000 / limit_qps` seconds (never when `limit_qps` is 0) — confirm this is intended.
    pub fn exhaust(&self, client_id: &str) {
        const DEFICIT: f64 = -1_000_000.0;

        let mut guard = self.clients.write();
        let stamp = Instant::now();
        // Overwrites any existing bucket; the map ends up identical to an upsert.
        guard.insert(
            client_id.to_string(),
            TokenBucket {
                tokens: DEFICIT,
                last_update: stamp,
            },
        );
    }
}
/// Lifecycle state of a `CircuitBreaker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitState {
/// Normal operation: `check` admits requests and failures are counted.
Closed,
/// Tripped: `check` rejects requests until the recovery period has elapsed.
Open,
/// Probing after recovery: `check` admits requests; a recorded success closes the breaker.
HalfOpen,
}
/// Circuit breaker whose mutable state lives behind locks and atomics, so every method
/// takes `&self`.
#[derive(Debug)]
pub struct CircuitBreaker {
/// Current breaker state.
state: parking_lot::RwLock<CircuitState>,
/// Failures recorded since the last success; compared against `failure_threshold`.
failure_count: AtomicU64,
/// Failure count at which a closed breaker opens.
failure_threshold: u32,
/// Seconds an open breaker waits before `check` admits requests again.
recovery_seconds: u64,
/// When the breaker last opened; `None` until the first trip.
/// NOTE(review): guarded by a different lock than `state`; code needing both should
/// acquire `state` first to avoid lock-order inversion.
opened_at: parking_lot::RwLock<Option<Instant>>,
}
impl CircuitBreaker {
    /// Creates a breaker in the `Closed` state.
    ///
    /// * `failure_threshold` – failures that trip a closed breaker open.
    /// * `recovery_seconds` – how long the breaker stays open before `check` admits a probe.
    #[must_use]
    pub fn new(failure_threshold: u32, recovery_seconds: u64) -> Self {
        Self {
            state: parking_lot::RwLock::new(CircuitState::Closed),
            failure_count: AtomicU64::new(0),
            failure_threshold,
            recovery_seconds,
            opened_at: parking_lot::RwLock::new(None),
        }
    }

    /// Decides whether a request may proceed.
    ///
    /// `Closed` and `HalfOpen` admit every request. `Open` rejects until `recovery_seconds`
    /// have elapsed since the breaker opened, then switches to `HalfOpen` and admits.
    ///
    /// Locks are always acquired in the order `state` → `opened_at`, matching
    /// `record_failure`, so the two cannot deadlock.
    ///
    /// # Errors
    ///
    /// Returns [`GuardRailViolation::CircuitOpen`] with the whole seconds left in the cooldown
    /// while the breaker is open.
    pub fn check(&self) -> Result<(), GuardRailViolation> {
        // Fast path under shared locks: most calls either pass (not open) or are rejected
        // (still cooling down) without exclusive access.
        {
            let state = self.state.read();
            if *state != CircuitState::Open {
                return Ok(());
            }
            if let Some(remaining) = self.remaining_cooldown() {
                return Err(GuardRailViolation::CircuitOpen {
                    recovery_in_seconds: remaining,
                });
            }
        }

        // The cooldown looked finished. Re-evaluate under the exclusive lock: another thread
        // may already have moved to HalfOpen, or a failed probe may have re-opened the breaker
        // with a fresh `opened_at` that must not be overwritten.
        let mut state = self.state.write();
        if *state != CircuitState::Open {
            return Ok(());
        }
        if let Some(remaining) = self.remaining_cooldown() {
            return Err(GuardRailViolation::CircuitOpen {
                recovery_in_seconds: remaining,
            });
        }
        *state = CircuitState::HalfOpen;
        Ok(())
    }

    /// Whole seconds left in the open-state cooldown, or `None` once it has elapsed.
    ///
    /// A missing `opened_at` (not produced by this type, which always sets it when opening)
    /// also yields `None`. Callers must already hold the `state` lock so the lock order stays
    /// `state` → `opened_at`.
    fn remaining_cooldown(&self) -> Option<u64> {
        // Copy the timestamp out so the `opened_at` guard is released at the end of this
        // statement rather than being held across further lock acquisitions.
        let opened_at = (*self.opened_at.read())?;
        let elapsed = opened_at.elapsed().as_secs();
        self.recovery_seconds
            .checked_sub(elapsed)
            .filter(|&left| left > 0)
    }

    /// Records a successful call: resets the failure count and closes a half-open breaker.
    ///
    /// An open breaker stays open.
    pub fn record_success(&self) {
        let mut state = self.state.write();
        // Reset under the state lock so it is ordered with `record_failure`'s
        // increment-and-trip sequence.
        self.failure_count.store(0, Ordering::Relaxed);
        if *state == CircuitState::HalfOpen {
            *state = CircuitState::Closed;
        }
    }

    /// Records a failed call and opens the breaker when warranted.
    ///
    /// A closed breaker opens once the failure count reaches `failure_threshold`. A
    /// half-open breaker re-opens on any failure, because its probe failed. Opening stamps
    /// `opened_at` with the current time.
    pub fn record_failure(&self) {
        let mut state = self.state.write();
        let count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        let trip = match *state {
            CircuitState::Closed => count >= u64::from(self.failure_threshold),
            // A success recorded while Open may have reset the counter below the threshold;
            // a failed probe must still re-open immediately.
            CircuitState::HalfOpen => true,
            CircuitState::Open => false,
        };
        if trip {
            *state = CircuitState::Open;
            *self.opened_at.write() = Some(Instant::now());
        }
    }

    /// Returns a snapshot of the current state; it may change as soon as the lock is released.
    #[must_use]
    pub fn state(&self) -> CircuitState {
        *self.state.read()
    }
}