use std::net::IpAddr;
use std::num::NonZeroU32;
use std::sync::Arc;
use governor::clock::DefaultClock;
use governor::middleware::NoOpMiddleware;
use governor::state::keyed::DashMapStateStore;
use governor::{Quota, RateLimiter};
use crate::{ProxyError, RateLimitConfig, Result};
/// Concrete `governor` limiter type used by [`IpRateLimiter`]: keyed by client
/// `IpAddr`, per-key state held in a `DashMapStateStore`, timed by
/// `DefaultClock`, and with no middleware attached.
type InnerLimiter = RateLimiter<
    IpAddr,
    DashMapStateStore<IpAddr>,
    DefaultClock,
    NoOpMiddleware,
>;

/// Rate limiter that tracks a separate quota for each client IP address.
///
/// Cloning is cheap: every clone holds the same `Arc`, so all clones observe
/// and update one shared set of per-IP state.
#[derive(Clone, Debug)]
pub struct IpRateLimiter {
    // Shared handle to the keyed limiter; cloned by reference count only.
    inner: Arc<InnerLimiter>,
}
impl IpRateLimiter {
    /// Builds a per-IP limiter from `config`.
    ///
    /// The quota is `config.requests_per_second` per second with a maximum
    /// burst of `config.burst`, and per-IP state is kept in a `DashMap`-backed
    /// keyed store (`RateLimiter::dashmap`).
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Internal` if `requests_per_second` or `burst` is
    /// zero, because `governor` quotas are built from `NonZeroU32` values.
    pub fn from_config(config: &RateLimitConfig) -> Result<Self> {
        let rps = NonZeroU32::new(config.requests_per_second)
            .ok_or_else(|| ProxyError::Internal("requests_per_second must be non-zero".into()))?;
        let burst = NonZeroU32::new(config.burst)
            .ok_or_else(|| ProxyError::Internal("burst must be non-zero".into()))?;
        let quota = Quota::per_second(rps).allow_burst(burst);
        Ok(Self {
            inner: Arc::new(RateLimiter::dashmap(quota)),
        })
    }

    /// Checks one request from `ip` against that IP's quota.
    ///
    /// Returns `Ok(())` if the request is allowed. If it is rejected, returns
    /// `Err(wait_ms)`: the time until `ip` may be allowed again, in whole
    /// milliseconds.
    ///
    /// The value is rounded **up**, so a rejected request never reports a
    /// `0` ms wait for a sub-millisecond remainder. It also means waiting
    /// the reported time is never too early. Values that do not fit in a
    /// `u64` saturate at `u64::MAX`.
    pub fn check(&self, ip: &IpAddr) -> std::result::Result<(), u64> {
        self.inner.check_key(ip).map_err(|not_until| {
            // NOTE(review): this reads "now" from a freshly constructed
            // `DefaultClock` rather than the limiter's own clock instance;
            // assumed to share a time base for the default clock — confirm.
            let now = governor::clock::Clock::now(&DefaultClock::default());
            let wait_nanos = not_until.wait_time_from(now).as_nanos();
            // Ceiling division: any partial millisecond counts as a full one.
            let wait_ms = (wait_nanos + 999_999) / 1_000_000;
            // `Duration` can exceed u64::MAX milliseconds; clamp rather than
            // let `as` silently truncate the high bits.
            wait_ms.min(u128::from(u64::MAX)) as u64
        })
    }

    /// Delegates to `governor`'s `retain_recent` on the underlying keyed store.
    ///
    /// NOTE(review): presumably this prunes per-IP entries that no longer carry
    /// meaningful rate-limit state and is called periodically to bound memory —
    /// confirm against the caller.
    pub fn retain_recent(&self) {
        self.inner.retain_recent();
    }

    /// Returns the number of IP keys currently held in the limiter's state store.
    pub fn tracked_ip_count(&self) -> usize {
        self.inner.len()
    }
}