use reinhardt_http::Handler;
use reinhardt_http::{Request, Response};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Algorithm used to decide whether a request still fits within the configured limit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitStrategy {
/// Counts accepted requests per client; the counter restarts once `window_duration`
/// has elapsed since the start of the client's current window.
FixedWindow,
/// Remembers the time of every accepted request per client and only counts the
/// ones that are younger than `window_duration`.
SlidingWindow,
}
/// Settings that control how [`RateLimitHandler`] throttles clients.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
/// Maximum number of requests accepted per client within one window.
pub max_requests: usize,
/// Length of the rate-limiting window.
pub window_duration: Duration,
/// Algorithm used to count requests against the limit.
pub strategy: RateLimitStrategy,
/// Proxies whose `X-Forwarded-For` / `X-Real-IP` headers are honoured.
///
/// Each entry is interpreted either as a CIDR network (e.g. `10.0.0.0/8`) or as a
/// single IP address; entries that parse as neither never match any peer.
pub trusted_proxies: Vec<String>,
}
impl RateLimitConfig {
    /// Builds a configuration with an explicit limit, window length and strategy.
    ///
    /// The list of trusted proxies starts out empty; add entries with
    /// [`RateLimitConfig::with_trusted_proxies`].
    pub fn new(
        max_requests: usize,
        window_duration: Duration,
        strategy: RateLimitStrategy,
    ) -> Self {
        let trusted_proxies = Vec::new();
        Self {
            max_requests,
            window_duration,
            strategy,
            trusted_proxies,
        }
    }

    /// Fixed-window configuration allowing `max_requests` per 60 seconds.
    pub fn per_minute(max_requests: usize) -> Self {
        let one_minute = Duration::from_secs(60);
        Self::new(max_requests, one_minute, RateLimitStrategy::FixedWindow)
    }

    /// Fixed-window configuration allowing `max_requests` per 3600 seconds.
    pub fn per_hour(max_requests: usize) -> Self {
        let one_hour = Duration::from_secs(3600);
        Self::new(max_requests, one_hour, RateLimitStrategy::FixedWindow)
    }

    /// Replaces the trusted proxy list and returns the updated configuration.
    pub fn with_trusted_proxies(self, proxies: Vec<String>) -> Self {
        Self {
            trusted_proxies: proxies,
            ..self
        }
    }
}
/// Per-client bookkeeping for the fixed-window strategy.
#[derive(Debug, Clone)]
struct RateLimitEntry {
/// Number of requests accepted since `window_start`.
count: usize,
/// Moment at which the client's current window began.
window_start: Instant,
}
/// Per-client bookkeeping for the sliding-window strategy.
#[derive(Debug, Clone)]
struct SlidingWindowEntry {
/// Times at which requests were accepted; new timestamps are appended at the end,
/// so the last element is the most recently accepted request.
timestamps: Vec<Instant>,
}
/// Handler wrapper that rejects requests from clients exceeding the configured rate.
pub struct RateLimitHandler {
// Handler that receives every request that passes the rate limit.
inner: Arc<dyn Handler>,
// Limit, window length, strategy and trusted proxy list.
config: RateLimitConfig,
// Fixed-window state keyed by client IP; only touched by the fixed-window strategy.
limits: Arc<RwLock<HashMap<IpAddr, RateLimitEntry>>>,
// Sliding-window state keyed by client IP; only touched by the sliding-window strategy.
sliding_limits: Arc<RwLock<HashMap<IpAddr, SlidingWindowEntry>>>,
}
impl RateLimitHandler {
    /// Number of tracked clients above which stale entries are purged before a
    /// request is evaluated.
    const CLEANUP_THRESHOLD: usize = 1024;

    /// Wraps `inner` so that requests are forwarded to it only while the client
    /// stays within the limit described by `config`.
    pub fn new(inner: Arc<dyn Handler>, config: RateLimitConfig) -> Self {
        Self {
            inner,
            config,
            limits: Arc::new(RwLock::new(HashMap::new())),
            sliding_limits: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records a request from `ip` and reports whether it is within the limit,
    /// using the strategy selected in the configuration.
    async fn is_allowed(&self, ip: IpAddr) -> bool {
        match self.config.strategy {
            RateLimitStrategy::FixedWindow => self.is_allowed_fixed_window(ip).await,
            RateLimitStrategy::SlidingWindow => self.is_allowed_sliding_window(ip).await,
        }
    }

    /// Fixed-window check: counts accepted requests per client and restarts the
    /// counter once a full window has elapsed since the window started.
    ///
    /// Rejected requests do not increase the counter.
    async fn is_allowed_fixed_window(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let window = self.config.window_duration;
        let mut limits = self.limits.write().await;

        // NOTE: once the map holds more than `CLEANUP_THRESHOLD` clients, this scan
        // runs on every call while the write lock is held.
        if limits.len() > Self::CLEANUP_THRESHOLD {
            // `saturating_mul` instead of `*`: plain multiplication panics when the
            // configured window is so large that doubling it overflows `Duration`.
            let max_age = window.saturating_mul(2);
            limits.retain(|_, entry| now.duration_since(entry.window_start) < max_age);
        }

        let entry = limits.entry(ip).or_insert_with(|| RateLimitEntry {
            count: 0,
            window_start: now,
        });

        // The previous window is over: start a fresh one at the current instant.
        if now.duration_since(entry.window_start) >= window {
            entry.count = 0;
            entry.window_start = now;
        }

        if entry.count < self.config.max_requests {
            entry.count += 1;
            true
        } else {
            false
        }
    }

    /// Sliding-window check: keeps the timestamps of accepted requests per client
    /// and allows a new request only while fewer than `max_requests` of them are
    /// younger than the window.
    ///
    /// Rejected requests are not recorded.
    async fn is_allowed_sliding_window(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let window = self.config.window_duration;
        let mut limits = self.sliding_limits.write().await;

        if limits.len() > Self::CLEANUP_THRESHOLD {
            // See `is_allowed_fixed_window` for why this must not use `window * 2`.
            let max_age = window.saturating_mul(2);
            // Clients without any recorded request are dropped as well.
            limits.retain(|_, entry| {
                entry
                    .timestamps
                    .last()
                    .is_some_and(|&last| now.duration_since(last) < max_age)
            });
        }

        let entry = limits.entry(ip).or_insert_with(|| SlidingWindowEntry {
            timestamps: Vec::new(),
        });

        // Forget requests that have slid out of the window.
        entry
            .timestamps
            .retain(|&ts| now.duration_since(ts) < window);

        if entry.timestamps.len() < self.config.max_requests {
            entry.timestamps.push(now);
            true
        } else {
            false
        }
    }

    /// Determines the IP address the request is rate-limited under.
    ///
    /// Forwarding headers are consulted only when the direct peer is a trusted
    /// proxy: first the left-most `X-Forwarded-For` entry, then `X-Real-IP`.
    /// Otherwise, or when neither header holds a valid address, the peer address
    /// is used. Requests without a peer address are attributed to `127.0.0.1`.
    fn extract_client_ip(&self, request: &Request) -> IpAddr {
        let peer_ip = request.remote_addr.map(|addr| addr.ip());

        if peer_ip.is_some_and(|ip| self.is_trusted_proxy(ip)) {
            // NOTE(review): the left-most X-Forwarded-For entry originates from the
            // client and can be forged unless the trusted proxy overwrites the
            // header — confirm the proxy configuration guarantees that.
            if let Some(xff) = request.headers.get("X-Forwarded-For")
                && let Ok(xff_str) = xff.to_str()
                && let Some(first_ip) = xff_str.split(',').next()
                && let Ok(ip) = first_ip.trim().parse::<IpAddr>()
            {
                return ip;
            }
            // Trimmed for consistency with the X-Forwarded-For handling above, so
            // surrounding whitespace does not make a valid address unparseable.
            if let Some(xri) = request.headers.get("X-Real-IP")
                && let Ok(ip_str) = xri.to_str()
                && let Ok(ip) = ip_str.trim().parse::<IpAddr>()
            {
                return ip;
            }
        }

        // Infallible constant instead of parsing a string literal and unwrapping.
        peer_ip.unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST))
    }

    /// Returns `true` when `ip` matches an entry of `trusted_proxies`.
    ///
    /// An entry that parses as a CIDR network matches every address inside it;
    /// otherwise the entry must be a single IP address equal to `ip`. Entries that
    /// are neither are ignored.
    fn is_trusted_proxy(&self, ip: IpAddr) -> bool {
        self.config.trusted_proxies.iter().any(|proxy| {
            if let Ok(network) = proxy.parse::<ipnet::IpNet>() {
                network.contains(&ip)
            } else {
                proxy
                    .parse::<IpAddr>()
                    .is_ok_and(|proxy_ip| proxy_ip == ip)
            }
        })
    }
}
#[async_trait::async_trait]
impl Handler for RateLimitHandler {
    /// Forwards the request to the wrapped handler while the client is within its
    /// limit; otherwise answers with `429 Too Many Requests` without calling it.
    async fn handle(&self, request: Request) -> reinhardt_core::exception::Result<Response> {
        let client_ip = self.extract_client_ip(&request);

        // Guard clause: over-limit clients never reach the inner handler.
        if !self.is_allowed(client_ip).await {
            let rejection =
                Response::new(http::StatusCode::TOO_MANY_REQUESTS).with_body("Rate limit exceeded");
            return Ok(rejection);
        }

        self.inner.handle(request).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Awaits `condition` every `interval` until it yields `true` or `timeout`
    /// has elapsed, returning an error message in the latter case.
    async fn poll_until<F, Fut>(
        timeout: std::time::Duration,
        interval: std::time::Duration,
        mut condition: F,
    ) -> Result<(), String>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let start = std::time::Instant::now();
        while start.elapsed() < timeout {
            if condition().await {
                return Ok(());
            }
            tokio::time::sleep(interval).await;
        }
        Err(format!("Timeout after {:?} waiting for condition", timeout))
    }

    /// Inner handler that always answers successfully.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
            Ok(Response::ok().with_body("Success"))
        }
    }

    /// Wraps a fresh `TestHandler` in a rate limiter using `config`.
    fn make_limiter(config: RateLimitConfig) -> RateLimitHandler {
        RateLimitHandler::new(Arc::new(TestHandler), config)
    }

    /// Rate limiter (10 requests per minute) that trusts the given proxies.
    fn make_limiter_trusting(proxies: &[&str]) -> RateLimitHandler {
        let proxies = proxies.iter().map(|proxy| String::from(*proxy)).collect();
        make_limiter(RateLimitConfig::per_minute(10).with_trusted_proxies(proxies))
    }

    /// Builds a `GET /` request with the given headers and, optionally, a peer
    /// socket address written as `ip:port`.
    fn get_request(headers: http::HeaderMap, remote_addr: Option<&str>) -> Request {
        let mut request = Request::builder()
            .method(http::Method::GET)
            .uri("/")
            .version(http::Version::HTTP_11)
            .headers(headers)
            .body(bytes::Bytes::new())
            .build()
            .unwrap();
        if let Some(addr) = remote_addr {
            request.remote_addr = Some(addr.parse().unwrap());
        }
        request
    }

    /// Sends a header-less `GET /` without a peer address through `limiter`.
    async fn send_get(limiter: &RateLimitHandler) -> Response {
        limiter
            .handle(get_request(http::HeaderMap::new(), None))
            .await
            .unwrap()
    }

    /// Header map containing exactly one header.
    fn single_header(name: &'static str, value: &str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    /// Parses an IP address literal used as an expected value.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_rate_limit_config_creation() {
        let config = RateLimitConfig::per_minute(60);
        assert_eq!(config.max_requests, 60);
        assert_eq!(config.window_duration, Duration::from_secs(60));

        let config = RateLimitConfig::per_hour(1000);
        assert_eq!(config.max_requests, 1000);
        assert_eq!(config.window_duration, Duration::from_secs(3600));
    }

    #[test]
    fn test_rate_limit_handler_creation() {
        let _rate_limit_handler = make_limiter(RateLimitConfig::per_minute(10));
    }

    #[tokio::test]
    async fn test_requests_within_limit() {
        let rate_limit_handler = make_limiter(RateLimitConfig::per_minute(5));
        for _ in 0..5 {
            let response = send_get(&rate_limit_handler).await;
            assert_eq!(response.status, http::StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_requests_exceed_limit() {
        let rate_limit_handler = make_limiter(RateLimitConfig::per_minute(3));
        for _ in 0..3 {
            let response = send_get(&rate_limit_handler).await;
            assert_eq!(response.status, http::StatusCode::OK);
        }
        let response = send_get(&rate_limit_handler).await;
        assert_eq!(response.status, http::StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset() {
        let rate_limit_handler = make_limiter(RateLimitConfig::new(
            2,
            Duration::from_millis(100),
            RateLimitStrategy::FixedWindow,
        ));
        for _ in 0..2 {
            let response = send_get(&rate_limit_handler).await;
            assert_eq!(response.status, http::StatusCode::OK);
        }

        // Rejected polls do not consume quota, so the first poll after the window
        // has elapsed must succeed.
        let handler_ref = &rate_limit_handler;
        poll_until(
            Duration::from_millis(200),
            Duration::from_millis(10),
            move || async move { send_get(handler_ref).await.status == http::StatusCode::OK },
        )
        .await
        .expect("Window should reset within 200ms");
    }

    #[tokio::test]
    async fn test_sliding_window_requests_within_limit() {
        let rate_limit_handler = make_limiter(RateLimitConfig::new(
            3,
            Duration::from_millis(200),
            RateLimitStrategy::SlidingWindow,
        ));
        for _ in 0..3 {
            let response = send_get(&rate_limit_handler).await;
            assert_eq!(response.status, http::StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_sliding_window_requests_exceed_limit() {
        let rate_limit_handler = make_limiter(RateLimitConfig::new(
            2,
            Duration::from_millis(200),
            RateLimitStrategy::SlidingWindow,
        ));
        for _ in 0..2 {
            let response = send_get(&rate_limit_handler).await;
            assert_eq!(response.status, http::StatusCode::OK);
        }
        let response = send_get(&rate_limit_handler).await;
        assert_eq!(response.status, http::StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn test_sliding_window_expires_old_requests() {
        let rate_limit_handler = make_limiter(RateLimitConfig::new(
            2,
            Duration::from_millis(100),
            RateLimitStrategy::SlidingWindow,
        ));
        for _ in 0..2 {
            let response = send_get(&rate_limit_handler).await;
            assert_eq!(response.status, http::StatusCode::OK);
        }

        let handler_ref = &rate_limit_handler;
        poll_until(
            Duration::from_millis(200),
            Duration::from_millis(10),
            move || async move { send_get(handler_ref).await.status == http::StatusCode::OK },
        )
        .await
        .expect("Sliding window should allow requests after old timestamps expire");
    }

    #[test]
    fn test_extract_client_ip_from_trusted_xff() {
        let rate_limit_handler = make_limiter_trusting(&["10.0.0.1"]);
        let request = get_request(
            single_header("X-Forwarded-For", "192.168.1.100, 10.0.0.1, 172.16.0.1"),
            Some("10.0.0.1:12345"),
        );
        assert_eq!(
            rate_limit_handler.extract_client_ip(&request),
            ip("192.168.1.100")
        );
    }

    #[test]
    fn test_extract_client_ip_ignores_untrusted_xff() {
        let rate_limit_handler = make_limiter_trusting(&["10.0.0.1"]);
        let request = get_request(
            single_header("X-Forwarded-For", "192.168.1.100"),
            Some("203.0.113.42:54321"),
        );
        assert_eq!(
            rate_limit_handler.extract_client_ip(&request),
            ip("203.0.113.42")
        );
    }

    #[test]
    fn test_extract_client_ip_from_trusted_x_real_ip() {
        let rate_limit_handler = make_limiter_trusting(&["10.0.0.0/8"]);
        let request = get_request(
            single_header("X-Real-IP", "203.0.113.42"),
            Some("10.0.0.5:8080"),
        );
        assert_eq!(
            rate_limit_handler.extract_client_ip(&request),
            ip("203.0.113.42")
        );
    }

    #[test]
    fn test_extract_client_ip_fallback_to_localhost() {
        let rate_limit_handler = make_limiter(RateLimitConfig::per_minute(10));
        let request = get_request(http::HeaderMap::new(), None);
        assert_eq!(
            rate_limit_handler.extract_client_ip(&request),
            ip("127.0.0.1")
        );
    }

    #[test]
    fn test_extract_client_ip_no_trusted_proxies() {
        let rate_limit_handler = make_limiter(RateLimitConfig::per_minute(10));
        let request = get_request(
            single_header("X-Forwarded-For", "192.168.1.100"),
            Some("203.0.113.1:8080"),
        );
        assert_eq!(
            rate_limit_handler.extract_client_ip(&request),
            ip("203.0.113.1")
        );
    }

    #[test]
    fn test_extract_client_ip_with_invalid_header() {
        let rate_limit_handler = make_limiter_trusting(&["10.0.0.1"]);
        let request = get_request(
            single_header("X-Forwarded-For", "invalid-ip"),
            Some("10.0.0.1:8080"),
        );
        assert_eq!(
            rate_limit_handler.extract_client_ip(&request),
            ip("10.0.0.1")
        );
    }
}