use std::collections::HashMap;
use std::fmt;
use std::net::IpAddr;
use std::sync::{
atomic::{AtomicU64, Ordering},
RwLock,
};
use crate::internal::common::epoch;
use crate::Status;
use crate::{
middleware::{MiddleResult, Middleware},
Content, Request, Response,
};
/// Callback that builds the response sent to a client that has hit the rate
/// limit. Returning `None` lets the request continue through the pipeline.
type Handler = Box<dyn Fn(&Request) -> Option<Response> + Sync + Send>;
/// Middleware that limits how many requests each client IP may make within a
/// fixed time window; counters for every IP are cleared together when the
/// window expires.
pub struct RateLimiter {
    /// Recorded request count at which an IP starts being passed to `handler`.
    req_limit: u64,
    /// Time (in `epoch().as_secs()` units) at which the counters were last cleared.
    last_reset: AtomicU64,
    /// Window length, in seconds, after which all counters are cleared.
    req_timeout: u64,
    /// Number of requests recorded per client IP in the current window.
    requests: RwLock<HashMap<IpAddr, u64>>,
    /// Builds the response for requests from an IP that is at its limit.
    handler: Handler,
}
impl RateLimiter {
    /// Creates a rate limiter with the defaults: a limit of 10 recorded requests
    /// per client IP in each 60-second window. Over-limit requests are answered
    /// with a plain-text (`Content::TXT`) `TooManyRequests` response.
    pub fn new() -> RateLimiter {
        RateLimiter {
            last_reset: AtomicU64::new(0),
            req_limit: 10,
            req_timeout: 60,
            requests: RwLock::new(HashMap::new()),
            handler: Box::new(|_| {
                Some(
                    Response::new()
                        .status(Status::TooManyRequests)
                        .text("Too Many Requests")
                        .content(Content::TXT),
                )
            }),
        }
    }

    /// Sets the number of recorded requests per IP at which further requests
    /// in the same window are handed to the over-limit handler.
    #[must_use]
    pub fn limit(self, limit: u64) -> RateLimiter {
        RateLimiter {
            req_limit: limit,
            ..self
        }
    }

    /// Sets the window length in seconds; once it elapses, all per-IP
    /// counters are cleared.
    #[must_use]
    pub fn timeout(self, timeout: u64) -> RateLimiter {
        RateLimiter {
            req_timeout: timeout,
            ..self
        }
    }

    /// Replaces the callback that builds the response for over-limit requests.
    #[must_use]
    pub fn handler(self, handler: Handler) -> RateLimiter {
        RateLimiter { handler, ..self }
    }

    /// Records one more request from `ip` in the current window.
    fn add_request(&self, ip: IpAddr) {
        // A poisoned lock only means another thread panicked mid-update; the
        // counters are still usable, so recover instead of failing every request.
        let mut requests = self.requests.write().unwrap_or_else(|e| e.into_inner());
        let count = requests.entry(ip).or_insert(0);
        *count = count.saturating_add(1);
    }

    /// Clears every counter if the current window has expired.
    fn check_reset(&self) {
        let now = epoch().as_secs();
        if !self.window_expired(now) {
            return;
        }

        let mut requests = self.requests.write().unwrap_or_else(|e| e.into_inner());
        // Re-check while holding the write lock: another thread may have reset
        // the window after our first check. Clearing a second time would discard
        // the requests it has recorded since then.
        if self.window_expired(now) {
            requests.clear();
            self.last_reset.store(now, Ordering::Release);
        }
    }

    /// Returns `true` when `req_timeout` seconds have passed since the last reset.
    fn window_expired(&self, now: u64) -> bool {
        self.last_reset
            .load(Ordering::Acquire)
            .saturating_add(self.req_timeout)
            <= now
    }

    /// Returns `true` when `ip` has already reached (or exceeded) `req_limit`
    /// recorded requests in the current window.
    fn is_over_limit(&self, ip: IpAddr) -> bool {
        let requests = self.requests.read().unwrap_or_else(|e| e.into_inner());
        requests.get(&ip).copied().unwrap_or(0) >= self.req_limit
    }
}
impl Middleware for RateLimiter {
    /// Hands the request to the over-limit handler when its client IP has
    /// already reached the limit for the current window; if the handler returns
    /// a response it is sent, otherwise the request continues.
    ///
    /// The window is refreshed first. Otherwise a client whose count is left
    /// over from an expired window would still be rejected on its next request.
    fn pre(&self, req: &mut Request) -> MiddleResult {
        self.check_reset();
        if self.is_over_limit(req.address.ip()) {
            if let Some(res) = (self.handler)(req) {
                return MiddleResult::Send(res);
            }
        }
        MiddleResult::Continue
    }

    /// Refreshes the window if it has expired, then counts this request
    /// against the client's IP.
    fn end(&self, req: &Request, _res: &Response) {
        self.check_reset();
        self.add_request(req.address.ip());
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for RateLimiter {
    /// Shows the limiter's settings and counters. The boxed handler is left out
    /// because `dyn Fn` trait objects do not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RateLimiter");
        out.field("req_limit", &self.req_limit);
        out.field("req_timeout", &self.req_timeout);
        out.field("last_reset", &self.last_reset);
        out.field("requests", &self.requests);
        out.finish()
    }
}