use std::{
collections::HashMap,
future::{ready, Ready},
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::{Duration, Instant},
};
use actix_service::{Service, Transform};
use actix_web::{
body::MessageBody,
dev::{forward_ready, ServiceRequest, ServiceResponse},
http::header,
Error, HttpResponse,
};
use parking_lot::Mutex;
/// Limits enforced by the rate-limiting middleware: at most `max_requests`
/// requests per client key within any sliding `window`.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per client key within `window`.
    pub max_requests: u32,
    /// Length of the sliding window over which requests are counted.
    pub window: Duration,
}

impl RateLimitConfig {
    /// Builds a configuration from an explicit request limit and window length.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        Self {
            window,
            max_requests,
        }
    }
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window.
    fn default() -> Self {
        Self::new(100, Duration::from_secs(60))
    }
}
#[derive(Clone)]
pub struct RateLimit {
config: RateLimitConfig,
}
impl RateLimit {
pub fn new(max_requests: u32, window: Duration) -> Self {
Self {
config: RateLimitConfig::new(max_requests, window),
}
}
pub fn with_config(config: RateLimitConfig) -> Self {
Self { config }
}
}
impl<S, B> Transform<S, ServiceRequest> for RateLimit
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: MessageBody + 'static,
{
type Response = ServiceResponse;
type Error = Error;
type Transform = RateLimitMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(RateLimitMiddleware {
service,
config: self.config.clone(),
buckets: Arc::new(Mutex::new(HashMap::new())),
sweep_counter: Arc::new(AtomicUsize::new(0)),
}))
}
}
/// Service produced by `RateLimit::new_transform`. It checks each request
/// against the sliding-window limit before delegating to the wrapped `service`.
pub struct RateLimitMiddleware<S> {
    /// The wrapped downstream service; called only for admitted requests.
    service: S,
    /// Request limit and window length applied to every client key.
    config: RateLimitConfig,
    /// Client key -> timestamps of the requests admitted for that key.
    buckets: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
    /// Counts `check` calls; used to decide when to sweep every bucket.
    sweep_counter: Arc<AtomicUsize>,
}
/// Outcome of a single rate-limit check for one request.
struct CheckResult {
    /// Whether the request fits within the client's window (and was recorded).
    allowed: bool,
    /// How many further requests the client could make right now.
    remaining: u32,
}

impl<S> RateLimitMiddleware<S> {
    /// Number of `check` calls between full sweeps of the bucket map.
    const SWEEP_INTERVAL: usize = 100;

    /// Applies the sliding-window limit to `key`, recording the request if admitted.
    ///
    /// A request is admitted while fewer than `max_requests` of the key's
    /// timestamps fall inside the last `window`. Rejected requests are not
    /// recorded. `remaining` is the number of further requests that would be
    /// admitted immediately after this one (0 when rejected).
    ///
    /// Every `SWEEP_INTERVAL`-th call (starting with the first) also prunes
    /// expired timestamps from all keys and drops keys left with none.
    fn check(&self, key: &str) -> CheckResult {
        let window = self.config.window;
        let max = self.config.max_requests;
        let mut buckets = self.buckets.lock();
        // Read the clock only once the lock is held. Otherwise a request that
        // waited on the lock is back-dated, and timestamps are not recorded in
        // the order requests are actually admitted.
        let now = Instant::now();
        let is_live = |t: &Instant| now.duration_since(*t) < window;

        if self.sweep_counter.fetch_add(1, Ordering::Relaxed) % Self::SWEEP_INTERVAL == 0 {
            buckets.retain(|_, timestamps| {
                timestamps.retain(is_live);
                !timestamps.is_empty()
            });
        }

        // Drops expired entries from one bucket, decides admission, and
        // records the request when admitted.
        let admit = |timestamps: &mut Vec<Instant>| {
            timestamps.retain(is_live);
            // Cannot truncate: a bucket never holds more than `max` (a u32)
            // entries, because we only push while below `max`.
            let used = timestamps.len() as u32;
            let allowed = used < max;
            if allowed {
                timestamps.push(now);
            }
            CheckResult {
                allowed,
                remaining: max.saturating_sub(used).saturating_sub(u32::from(allowed)),
            }
        };

        match buckets.get_mut(key) {
            // An existing bucket is never left empty here. Buckets are only
            // stored after admitting a request, so `max > 0`, and a fully
            // expired bucket therefore admits and records this request.
            Some(timestamps) => admit(timestamps),
            // Allocate an owned key only for clients without a live bucket,
            // and only store the bucket if it actually recorded a request.
            None => {
                let mut fresh = Vec::new();
                let result = admit(&mut fresh);
                if !fresh.is_empty() {
                    buckets.insert(key.to_owned(), fresh);
                }
                result
            }
        }
    }
}
impl<S, B> Service<ServiceRequest> for RateLimitMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: MessageBody + 'static,
{
type Response = ServiceResponse;
type Error = Error;
type Future =
std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let key = req
.connection_info()
.peer_addr()
.unwrap_or("unknown")
.to_string();
let result = self.check(&key);
if !result.allowed {
let (req, _) = req.into_parts();
let response = HttpResponse::TooManyRequests()
.insert_header(("x-ratelimit-remaining", "0"))
.insert_header(("x-ratelimit-limit", self.config.max_requests.to_string()))
.body("Rate limit exceeded");
return Box::pin(ready(Ok(ServiceResponse::new(req, response))));
}
let remaining = result.remaining;
let max = self.config.max_requests;
let fut = self.service.call(req);
Box::pin(async move {
let mut res = fut.await?.map_into_boxed_body();
res.headers_mut().insert(
header::HeaderName::from_static("x-ratelimit-remaining"),
header::HeaderValue::from(remaining),
);
res.headers_mut().insert(
header::HeaderName::from_static("x-ratelimit-limit"),
header::HeaderValue::from(max),
);
Ok(res)
})
}
}