use bytes::Bytes;
use dashmap::DashMap;
use http::StatusCode;
use http_body_util::Full;
use rustapi_core::middleware::{BoxedNext, MiddlewareLayer};
use rustapi_core::{Request, Response, ResponseBody};
use std::collections::VecDeque;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Per-client rate-limit state; the active variant corresponds to the
/// store's [`RateLimitStrategy`].
#[derive(Debug, Clone)]
enum RateLimitEntry {
    /// Number of requests seen since `window_start` (fixed-window counting).
    FixedWindow { count: u32, window_start: Instant },
    /// Timestamps of the requests still inside the sliding window.
    SlidingWindow { requests: VecDeque<Instant> },
    /// Fractional token balance and the instant tokens were last refilled.
    TokenBucket { tokens: f64, last_refill: Instant },
}
/// Outcome of a single rate-limit check against one client's entry.
#[derive(Debug, Clone, Copy)]
struct RateLimitDecision {
    /// Whether the request may proceed.
    is_allowed: bool,
    /// Requests left in the current window/bucket after this one.
    remaining: u32,
    /// Suggested wait before the next attempt (used for Retry-After/Reset headers).
    retry_after: Duration,
}
/// Algorithm used to decide whether a client is over its request budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitStrategy {
    /// Counter that resets at fixed window boundaries (cheapest; allows bursts at edges).
    FixedWindow,
    /// Tracks individual request timestamps within a rolling window (most precise).
    SlidingWindow,
    /// Continuous refill of tokens at `requests / window` per second; allows bursts up to capacity.
    TokenBucket,
}
/// Concurrent per-IP rate-limit state shared by all requests hitting one layer.
#[derive(Debug)]
struct RateLimitStore {
    /// One entry per client IP; DashMap provides per-shard locking.
    /// NOTE(review): entries are never evicted, so memory grows with the
    /// number of distinct client IPs seen — confirm this is acceptable.
    entries: DashMap<IpAddr, RateLimitEntry>,
}
impl RateLimitStore {
    /// Creates an empty store with no tracked clients.
    fn new() -> Self {
        Self {
            entries: DashMap::new(),
        }
    }

    /// Returns `true` when the stored entry's variant corresponds to `strategy`.
    fn matches_strategy(entry: &RateLimitEntry, strategy: RateLimitStrategy) -> bool {
        matches!(
            (entry, strategy),
            (RateLimitEntry::FixedWindow { .. }, RateLimitStrategy::FixedWindow)
                | (RateLimitEntry::SlidingWindow { .. }, RateLimitStrategy::SlidingWindow)
                | (RateLimitEntry::TokenBucket { .. }, RateLimitStrategy::TokenBucket)
        )
    }

    /// Records one request attempt for `ip` under the given limit, window and
    /// strategy.
    ///
    /// Returns `(is_allowed, used, remaining, reset)` where `reset` is a Unix
    /// timestamp in seconds after which the budget is (partially) replenished.
    fn check_and_update(
        &self,
        ip: IpAddr,
        max_requests: u32,
        window: Duration,
        strategy: RateLimitStrategy,
    ) -> (bool, u32, u32, u64) {
        let now = Instant::now();
        let mut entry = self
            .entries
            .entry(ip)
            .or_insert_with(|| RateLimitStore::new_entry(strategy, max_requests, now));
        // BUGFIX: normalize a stale entry (created under a different strategy)
        // *before* matching. The previous implementation recursed into
        // `check_and_update` while still holding the DashMap shard guard,
        // which deadlocks when the recursive call re-locks the same shard.
        if !Self::matches_strategy(&*entry, strategy) {
            *entry = RateLimitStore::new_entry(strategy, max_requests, now);
        }
        let decision = match &mut *entry {
            RateLimitEntry::FixedWindow {
                count,
                window_start,
            } => {
                // Start a fresh window once the current one has fully elapsed.
                if now.duration_since(*window_start) >= window {
                    *count = 0;
                    *window_start = now;
                }
                // BUGFIX: saturate rather than wrap. In release builds a
                // wrapping u32 would eventually return to 0 after ~4 billion
                // rejected requests and re-admit a flooding client.
                *count = count.saturating_add(1);
                RateLimitDecision {
                    is_allowed: *count <= max_requests,
                    remaining: max_requests.saturating_sub(*count),
                    retry_after: window.saturating_sub(now.duration_since(*window_start)),
                }
            }
            RateLimitEntry::SlidingWindow { requests } => {
                // Evict timestamps that have aged out of the window.
                while let Some(oldest) = requests.front() {
                    if now.duration_since(*oldest) >= window {
                        requests.pop_front();
                    } else {
                        break;
                    }
                }
                let is_allowed = requests.len() < max_requests as usize;
                if is_allowed {
                    requests.push_back(now);
                }
                // The oldest surviving timestamp determines when a slot frees up.
                let retry_after = requests
                    .front()
                    .map(|oldest| window.saturating_sub(now.duration_since(*oldest)))
                    .unwrap_or(Duration::ZERO);
                RateLimitDecision {
                    is_allowed,
                    remaining: max_requests.saturating_sub(requests.len() as u32),
                    retry_after,
                }
            }
            RateLimitEntry::TokenBucket {
                tokens,
                last_refill,
            } => {
                // Tokens refill continuously at `max_requests / window` per
                // second; the EPSILON guard avoids division by zero windows.
                let refill_rate = max_requests as f64 / window.as_secs_f64().max(f64::EPSILON);
                let elapsed = now.duration_since(*last_refill).as_secs_f64();
                *tokens = (*tokens + elapsed * refill_rate).min(max_requests as f64);
                *last_refill = now;
                let is_allowed = *tokens >= 1.0;
                if is_allowed {
                    *tokens -= 1.0;
                }
                let remaining = tokens.floor().max(0.0).min(max_requests as f64) as u32;
                let retry_after = next_token_after(*tokens, max_requests, refill_rate);
                RateLimitDecision {
                    is_allowed,
                    remaining,
                    retry_after,
                }
            }
        };
        let reset = unix_timestamp_after(decision.retry_after);
        (
            decision.is_allowed,
            // "Used" count derived from remaining so it is consistent across
            // all three strategies.
            max_requests.saturating_sub(decision.remaining),
            decision.remaining,
            reset,
        )
    }

    /// Read-only snapshot of a client's current limit/remaining/reset values.
    ///
    /// Unlike [`Self::check_and_update`], this never mutates the entry; a
    /// strategy mismatch simply reports a full, untouched budget.
    #[allow(dead_code)]
    fn get_info(
        &self,
        ip: IpAddr,
        max_requests: u32,
        window: Duration,
        strategy: RateLimitStrategy,
    ) -> Option<RateLimitInfo> {
        let now = Instant::now();
        self.entries
            .get(&ip)
            .map(|entry| match (&*entry, strategy) {
                (
                    RateLimitEntry::FixedWindow {
                        count,
                        window_start,
                    },
                    RateLimitStrategy::FixedWindow,
                ) => {
                    // An expired window counts as empty even though the stored
                    // counter has not been reset yet.
                    let current_count = if now.duration_since(*window_start) >= window {
                        0
                    } else {
                        *count
                    };
                    RateLimitInfo {
                        limit: max_requests,
                        remaining: max_requests.saturating_sub(current_count),
                        reset: unix_timestamp_after(
                            window.saturating_sub(now.duration_since(*window_start)),
                        ),
                    }
                }
                (RateLimitEntry::SlidingWindow { requests }, RateLimitStrategy::SlidingWindow) => {
                    // Filter (rather than pop) expired timestamps: this is a
                    // read-only view of the deque.
                    let active = requests
                        .iter()
                        .copied()
                        .filter(|timestamp| now.duration_since(*timestamp) < window)
                        .collect::<Vec<_>>();
                    let retry_after = active
                        .first()
                        .map(|oldest| window.saturating_sub(now.duration_since(*oldest)))
                        .unwrap_or(Duration::ZERO);
                    RateLimitInfo {
                        limit: max_requests,
                        remaining: max_requests.saturating_sub(active.len() as u32),
                        reset: unix_timestamp_after(retry_after),
                    }
                }
                (
                    RateLimitEntry::TokenBucket {
                        tokens,
                        last_refill,
                    },
                    RateLimitStrategy::TokenBucket,
                ) => {
                    // Project the refill forward without committing it.
                    let refill_rate = max_requests as f64 / window.as_secs_f64().max(f64::EPSILON);
                    let elapsed = now.duration_since(*last_refill).as_secs_f64();
                    let available = (*tokens + elapsed * refill_rate).min(max_requests as f64);
                    let retry_after = next_token_after(available, max_requests, refill_rate);
                    RateLimitInfo {
                        limit: max_requests,
                        remaining: available.floor().max(0.0).min(max_requests as f64) as u32,
                        reset: unix_timestamp_after(retry_after),
                    }
                }
                // Strategy mismatch: report an untouched budget.
                _ => RateLimitInfo {
                    limit: max_requests,
                    remaining: max_requests,
                    reset: unix_timestamp_after(Duration::ZERO),
                },
            })
    }

    /// Builds the initial entry for a brand-new client under `strategy`.
    /// A token bucket starts full (burst up to `max_requests` immediately).
    fn new_entry(strategy: RateLimitStrategy, max_requests: u32, now: Instant) -> RateLimitEntry {
        match strategy {
            RateLimitStrategy::FixedWindow => RateLimitEntry::FixedWindow {
                count: 0,
                window_start: now,
            },
            RateLimitStrategy::SlidingWindow => RateLimitEntry::SlidingWindow {
                requests: VecDeque::new(),
            },
            RateLimitStrategy::TokenBucket => RateLimitEntry::TokenBucket {
                tokens: max_requests as f64,
                last_refill: now,
            },
        }
    }
}
/// Time until the token bucket gains its next whole token.
///
/// Returns `Duration::ZERO` when the bucket is already full or when the
/// refill rate is effectively zero (nothing to wait for).
fn next_token_after(tokens: f64, max_requests: u32, refill_rate: f64) -> Duration {
    let capacity = f64::from(max_requests);
    if refill_rate <= f64::EPSILON || tokens >= capacity {
        return Duration::ZERO;
    }
    // Only the fractional remainder matters: it measures how far we are from
    // the next full token. An (almost) exact integer balance needs one whole
    // token's worth of refill.
    let fraction = tokens.fract();
    let deficit = if fraction <= f64::EPSILON {
        1.0
    } else {
        1.0 - fraction
    };
    Duration::from_secs_f64((deficit / refill_rate).max(0.0))
}
/// Converts a relative wait into an absolute Unix timestamp (whole seconds),
/// as expected by the X-RateLimit-Reset header.
fn unix_timestamp_after(duration: Duration) -> u64 {
    let delta = duration_to_header_secs(duration);
    unix_now_secs() + delta
}
/// Current Unix time in whole seconds; falls back to 0 if the system clock
/// reports a time before the epoch.
fn unix_now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Renders a duration as header-friendly whole seconds: a zero duration maps
/// to 0, and any non-zero duration maps to at least 1 so sub-second waits are
/// never reported as "retry immediately".
fn duration_to_header_secs(duration: Duration) -> u64 {
    match duration.as_secs() {
        0 if duration.is_zero() => 0,
        0 => 1,
        secs => secs,
    }
}
/// Middleware layer enforcing a per-client-IP request budget.
///
/// Cloning the layer shares the underlying store (`Arc`), so clones enforce
/// a single common budget rather than independent ones.
#[derive(Clone)]
pub struct RateLimitLayer {
    /// Maximum requests (or bucket capacity) allowed per `window`.
    requests: u32,
    /// Time span over which `requests` applies / the bucket fully refills.
    window: Duration,
    /// Algorithm used when checking the budget.
    strategy: RateLimitStrategy,
    /// Shared per-IP state.
    store: Arc<RateLimitStore>,
}
impl RateLimitLayer {
    /// Shared constructor all public constructors delegate to; each layer
    /// gets its own fresh store.
    fn with_strategy(requests: u32, window: Duration, strategy: RateLimitStrategy) -> Self {
        Self {
            requests,
            window,
            strategy,
            store: Arc::new(RateLimitStore::new()),
        }
    }

    /// Fixed-window limiter: at most `requests` per `window`, counter resets
    /// at window boundaries.
    pub fn new(requests: u32, window: Duration) -> Self {
        Self::with_strategy(requests, window, RateLimitStrategy::FixedWindow)
    }

    /// Sliding-window limiter: at most `requests` within any rolling `window`.
    pub fn sliding_window(requests: u32, window: Duration) -> Self {
        Self::with_strategy(requests, window, RateLimitStrategy::SlidingWindow)
    }

    /// Token-bucket limiter: bucket of `capacity` tokens that fully refills
    /// over `refill_window`; the bucket starts full.
    pub fn token_bucket(capacity: u32, refill_window: Duration) -> Self {
        Self::with_strategy(capacity, refill_window, RateLimitStrategy::TokenBucket)
    }

    /// Configured request limit (or bucket capacity).
    pub fn requests(&self) -> u32 {
        self.requests
    }

    /// Configured window (or full-refill period).
    pub fn window(&self) -> Duration {
        self.window
    }

    /// Configured strategy.
    pub fn strategy(&self) -> RateLimitStrategy {
        self.strategy
    }

    #[cfg(test)]
    #[allow(dead_code)]
    pub(crate) fn store(&self) -> &Arc<RateLimitStore> {
        &self.store
    }

    /// Resolves the client IP used as the rate-limit key.
    ///
    /// SECURITY NOTE: both X-Forwarded-For and X-Real-IP are client-supplied
    /// and trivially spoofable unless a trusted proxy strips/sets them —
    /// deploy this layer behind such a proxy. When neither header yields a
    /// parseable IP, all requests fall back to 127.0.0.1 and share a single
    /// budget (no transport-level peer address is available here).
    fn extract_client_ip(req: &Request) -> IpAddr {
        // X-Forwarded-For may hold a proxy chain; the first entry is the
        // original client.
        if let Some(forwarded) = req.headers().get("x-forwarded-for") {
            if let Ok(forwarded_str) = forwarded.to_str() {
                if let Some(first_ip) = forwarded_str.split(',').next() {
                    if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
                        return ip;
                    }
                }
            }
        }
        if let Some(real_ip) = req.headers().get("x-real-ip") {
            if let Ok(ip_str) = real_ip.to_str() {
                if let Ok(ip) = ip_str.trim().parse::<IpAddr>() {
                    return ip;
                }
            }
        }
        "127.0.0.1".parse().unwrap()
    }
}
impl MiddlewareLayer for RateLimitLayer {
    /// Checks the client's budget before delegating to `next`.
    ///
    /// Allowed requests are forwarded and the response is annotated with
    /// X-RateLimit-Limit/Remaining/Reset headers; rejected requests are
    /// short-circuited with a 429 JSON error plus a Retry-After header.
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        // Capture owned copies so the future is 'static.
        let store = self.store.clone();
        let limit = self.requests;
        let window = self.window;
        let strategy = self.strategy;
        Box::pin(async move {
            let ip = RateLimitLayer::extract_client_ip(&req);
            let (allowed, _used, remaining, reset) =
                store.check_and_update(ip, limit, window, strategy);
            if allowed {
                let mut response = next(req).await;
                let headers = response.headers_mut();
                headers.insert("X-RateLimit-Limit", limit.to_string().parse().unwrap());
                headers.insert("X-RateLimit-Remaining", remaining.to_string().parse().unwrap());
                headers.insert("X-RateLimit-Reset", reset.to_string().parse().unwrap());
                return response;
            }
            // Over budget: build the 429 response without touching `next`.
            let retry_after = reset.saturating_sub(unix_now_secs());
            let payload = serde_json::json!({
                "error": {
                    "type": "rate_limit_exceeded",
                    "message": "Too many requests",
                    "retry_after": retry_after
                }
            });
            let bytes = Bytes::from(serde_json::to_vec(&payload).unwrap_or_default());
            http::Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .header(http::header::CONTENT_TYPE, "application/json")
                .header("X-RateLimit-Limit", limit.to_string())
                .header("X-RateLimit-Remaining", "0")
                .header("X-RateLimit-Reset", reset.to_string())
                .header("Retry-After", retry_after.to_string())
                .body(ResponseBody::Full(Full::new(bytes)))
                .unwrap()
        })
    }

    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}
/// Snapshot of a client's rate-limit status, mirroring the
/// X-RateLimit-{Limit,Remaining,Reset} response headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitInfo {
    /// Configured maximum requests per window.
    pub limit: u32,
    /// Requests still available in the current window/bucket.
    pub remaining: u32,
    /// Unix timestamp (seconds) when the budget replenishes.
    pub reset: u64,
}
/// Test-only helper that builds the same 429 response shape the middleware
/// emits when a client exceeds its budget.
#[cfg(test)]
#[allow(dead_code)]
fn create_rate_limit_response(limit: u32, reset: u64, retry_after: u64) -> Response {
    let payload = serde_json::json!({
        "error": {
            "type": "rate_limit_exceeded",
            "message": "Too many requests",
            "retry_after": retry_after
        }
    });
    let bytes = Bytes::from(serde_json::to_vec(&payload).unwrap_or_default());
    http::Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header(http::header::CONTENT_TYPE, "application/json")
        .header("X-RateLimit-Limit", limit.to_string())
        .header("X-RateLimit-Remaining", "0")
        .header("X-RateLimit-Reset", reset.to_string())
        .header("Retry-After", retry_after.to_string())
        .body(ResponseBody::Full(Full::new(bytes)))
        .unwrap()
}
#[cfg(test)]
mod tests {
    //! Unit and property-based tests for the rate-limit layer.
    //! Timing-sensitive tests use millisecond windows and `tokio::time::sleep`,
    //! so they can be flaky on heavily loaded machines.
    use super::*;
    use bytes::Bytes;
    use http::{Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use rustapi_core::middleware::LayerStack;
    use std::sync::Arc;
    /// Builds a GET /test request, optionally setting the client IP via the
    /// X-Forwarded-For header (which the layer trusts).
    fn create_test_request(ip: Option<&str>) -> Request {
        let uri: http::Uri = "/test".parse().unwrap();
        let mut builder = http::Request::builder().method(Method::GET).uri(uri);
        if let Some(ip_str) = ip {
            builder = builder.header("X-Forwarded-For", ip_str);
        }
        let req = builder.body(()).unwrap();
        Request::from_http_request(req, Bytes::new())
    }
    /// Downstream handler that always answers 200 OK with body "success".
    fn create_success_handler() -> BoxedNext {
        Arc::new(|_req: Request| {
            Box::pin(async {
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(ResponseBody::Full(Full::new(Bytes::from("success"))))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }
    /// Proptest strategy producing dotted-quad IPv4 strings.
    fn ipv4_strategy() -> impl Strategy<Value = String> {
        (1u8..255, 0u8..255, 0u8..255, 1u8..255)
            .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d))
    }
    /// Proptest strategy producing (max_requests, window_secs) pairs.
    fn rate_limit_config_strategy() -> impl Strategy<Value = (u32, u64)> {
        (1u32..100, 1u64..60)
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Within the limit, every response is 200 and the rate-limit headers
        // track the exact remaining budget.
        #[test]
        fn prop_rate_limit_state_tracking(
            (max_requests, window_secs) in rate_limit_config_strategy(),
            num_requests in 1u32..50,
            ip in ipv4_strategy(),
        ) {
            // Never issue more requests than the limit in this property.
            let num_requests = num_requests.min(max_requests);
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: std::result::Result<(), TestCaseError> = rt.block_on(async {
                let layer = RateLimitLayer::new(max_requests, Duration::from_secs(window_secs));
                let mut stack = LayerStack::new();
                stack.push(Box::new(layer));
                for k in 1..=num_requests {
                    let handler = create_success_handler();
                    let request = create_test_request(Some(&ip));
                    let response = stack.execute(request, handler).await;
                    prop_assert_eq!(
                        response.status(),
                        StatusCode::OK,
                        "Request {} of {} should be allowed",
                        k,
                        num_requests
                    );
                    let limit_header = response.headers().get("X-RateLimit-Limit");
                    prop_assert!(limit_header.is_some(), "X-RateLimit-Limit header should be present");
                    let limit_value: u32 = limit_header.unwrap().to_str().unwrap().parse().unwrap();
                    prop_assert_eq!(
                        limit_value,
                        max_requests,
                        "X-RateLimit-Limit should equal configured limit"
                    );
                    let remaining_header = response.headers().get("X-RateLimit-Remaining");
                    prop_assert!(remaining_header.is_some(), "X-RateLimit-Remaining header should be present");
                    let remaining_value: u32 = remaining_header.unwrap().to_str().unwrap().parse().unwrap();
                    let expected_remaining = max_requests.saturating_sub(k);
                    prop_assert_eq!(
                        remaining_value,
                        expected_remaining,
                        "X-RateLimit-Remaining should be {} after {} requests (limit: {})",
                        expected_remaining,
                        k,
                        max_requests
                    );
                    let reset_header = response.headers().get("X-RateLimit-Reset");
                    prop_assert!(reset_header.is_some(), "X-RateLimit-Reset header should be present");
                }
                Ok(())
            });
            result?;
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // The (limit+1)-th request is rejected with 429, Retry-After is set,
        // remaining reads 0, and the JSON body carries the error type.
        #[test]
        fn prop_rate_limit_enforcement(
            max_requests in 1u32..20,
            window_secs in 10u64..120,
            ip in ipv4_strategy(),
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: std::result::Result<(), TestCaseError> = rt.block_on(async {
                let layer = RateLimitLayer::new(max_requests, Duration::from_secs(window_secs));
                let mut stack = LayerStack::new();
                stack.push(Box::new(layer));
                // Exhaust the budget first.
                for k in 1..=max_requests {
                    let handler = create_success_handler();
                    let request = create_test_request(Some(&ip));
                    let response = stack.execute(request, handler).await;
                    prop_assert_eq!(
                        response.status(),
                        StatusCode::OK,
                        "Request {} of {} should be allowed",
                        k,
                        max_requests
                    );
                }
                let handler = create_success_handler();
                let request = create_test_request(Some(&ip));
                let response = stack.execute(request, handler).await;
                prop_assert_eq!(
                    response.status(),
                    StatusCode::TOO_MANY_REQUESTS,
                    "Request {} should be rejected with 429",
                    max_requests + 1
                );
                let retry_after = response.headers().get("Retry-After");
                prop_assert!(retry_after.is_some(), "Retry-After header should be present on 429 response");
                let remaining = response.headers().get("X-RateLimit-Remaining");
                prop_assert!(remaining.is_some(), "X-RateLimit-Remaining should be present");
                let remaining_value: u32 = remaining.unwrap().to_str().unwrap().parse().unwrap();
                prop_assert_eq!(remaining_value, 0, "X-RateLimit-Remaining should be 0 when limit exceeded");
                let body_bytes = {
                    use http_body_util::BodyExt;
                    let body = response.into_body();
                    body.collect().await.unwrap().to_bytes()
                };
                // Accept both compact and spaced JSON serializations.
                let body_str = String::from_utf8_lossy(&body_bytes);
                prop_assert!(
                    body_str.contains("\"type\":\"rate_limit_exceeded\"") ||
                    body_str.contains("\"type\": \"rate_limit_exceeded\""),
                    "Response body should contain error type 'rate_limit_exceeded', got: {}",
                    body_str
                );
                Ok(())
            });
            result?;
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        // After the fixed window elapses, the budget resets and requests are
        // allowed again with a fresh remaining count.
        #[test]
        fn prop_rate_limit_window_reset(
            max_requests in 1u32..10,
            ip in ipv4_strategy(),
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: std::result::Result<(), TestCaseError> = rt.block_on(async {
                let window = Duration::from_millis(10);
                let layer = RateLimitLayer::new(max_requests, window);
                let mut stack = LayerStack::new();
                stack.push(Box::new(layer));
                for _ in 0..max_requests {
                    let handler = create_success_handler();
                    let request = create_test_request(Some(&ip));
                    let response = stack.execute(request, handler).await;
                    prop_assert_eq!(response.status(), StatusCode::OK);
                }
                let handler = create_success_handler();
                let request = create_test_request(Some(&ip));
                let response = stack.execute(request, handler).await;
                prop_assert_eq!(
                    response.status(),
                    StatusCode::TOO_MANY_REQUESTS,
                    "Should be rate limited after exhausting limit"
                );
                // Sleep past the window so the counter resets.
                tokio::time::sleep(window + Duration::from_millis(5)).await;
                let handler = create_success_handler();
                let request = create_test_request(Some(&ip));
                let response = stack.execute(request, handler).await;
                prop_assert_eq!(
                    response.status(),
                    StatusCode::OK,
                    "Request should be allowed after window reset"
                );
                let remaining = response.headers().get("X-RateLimit-Remaining");
                prop_assert!(remaining.is_some());
                let remaining_value: u32 = remaining.unwrap().to_str().unwrap().parse().unwrap();
                prop_assert_eq!(
                    remaining_value,
                    max_requests - 1,
                    "Remaining should be {} after window reset and one request",
                    max_requests - 1
                );
                Ok(())
            });
            result?;
        }
    }
    // Constructor getters reflect the configured fixed-window parameters.
    #[test]
    fn test_rate_limit_layer_creation() {
        let layer = RateLimitLayer::new(100, Duration::from_secs(60));
        assert_eq!(layer.requests(), 100);
        assert_eq!(layer.window(), Duration::from_secs(60));
        assert_eq!(layer.strategy(), RateLimitStrategy::FixedWindow);
    }
    // Constructor getters reflect the configured sliding-window parameters.
    #[test]
    fn test_sliding_window_layer_creation() {
        let layer = RateLimitLayer::sliding_window(100, Duration::from_secs(60));
        assert_eq!(layer.requests(), 100);
        assert_eq!(layer.window(), Duration::from_secs(60));
        assert_eq!(layer.strategy(), RateLimitStrategy::SlidingWindow);
    }
    // Constructor getters reflect the configured token-bucket parameters.
    #[test]
    fn test_token_bucket_layer_creation() {
        let layer = RateLimitLayer::token_bucket(5, Duration::from_secs(10));
        assert_eq!(layer.requests(), 5);
        assert_eq!(layer.window(), Duration::from_secs(10));
        assert_eq!(layer.strategy(), RateLimitStrategy::TokenBucket);
    }
    // The first entry of a comma-separated X-Forwarded-For chain wins.
    #[test]
    fn test_extract_client_ip_from_x_forwarded_for() {
        let request = create_test_request(Some("192.168.1.1, 10.0.0.1"));
        let ip = RateLimitLayer::extract_client_ip(&request);
        assert_eq!(ip, "192.168.1.1".parse::<IpAddr>().unwrap());
    }
    // A single-value X-Forwarded-For is parsed directly.
    #[test]
    fn test_extract_client_ip_single_ip() {
        let request = create_test_request(Some("192.168.1.100"));
        let ip = RateLimitLayer::extract_client_ip(&request);
        assert_eq!(ip, "192.168.1.100".parse::<IpAddr>().unwrap());
    }
    // Without forwarding headers the layer falls back to 127.0.0.1.
    #[test]
    fn test_extract_client_ip_default() {
        let request = create_test_request(None);
        let ip = RateLimitLayer::extract_client_ip(&request);
        assert_eq!(ip, "127.0.0.1".parse::<IpAddr>().unwrap());
    }
    // Exhausting one IP's budget must not affect another IP's budget.
    #[test]
    fn test_different_ips_have_separate_limits() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let layer = RateLimitLayer::new(2, Duration::from_secs(60));
            let mut stack = LayerStack::new();
            stack.push(Box::new(layer));
            for _ in 0..2 {
                let handler = create_success_handler();
                let request = create_test_request(Some("192.168.1.1"));
                let response = stack.execute(request, handler).await;
                assert_eq!(response.status(), StatusCode::OK);
            }
            let handler = create_success_handler();
            let request = create_test_request(Some("192.168.1.1"));
            let response = stack.execute(request, handler).await;
            assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
            let handler = create_success_handler();
            let request = create_test_request(Some("192.168.1.2"));
            let response = stack.execute(request, handler).await;
            assert_eq!(response.status(), StatusCode::OK);
        });
    }
    // The 429 body is well-formed JSON with the documented error shape.
    #[test]
    fn test_rate_limit_response_body_format() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let layer = RateLimitLayer::new(1, Duration::from_secs(60));
            let mut stack = LayerStack::new();
            stack.push(Box::new(layer));
            let handler = create_success_handler();
            let request = create_test_request(Some("10.0.0.1"));
            let response = stack.execute(request, handler).await;
            assert_eq!(response.status(), StatusCode::OK);
            let handler = create_success_handler();
            let request = create_test_request(Some("10.0.0.1"));
            let response = stack.execute(request, handler).await;
            assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
            use http_body_util::BodyExt;
            let body = response.into_body();
            let body_bytes = body.collect().await.unwrap().to_bytes();
            let body_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
            assert_eq!(body_json["error"]["type"], "rate_limit_exceeded");
            assert_eq!(body_json["error"]["message"], "Too many requests");
            assert!(body_json["error"]["retry_after"].is_number());
        });
    }
    // Sliding window: the third request is allowed because the first
    // timestamp has aged out, leaving one free slot.
    #[test]
    fn test_sliding_window_keeps_recent_requests_in_window() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let layer = RateLimitLayer::sliding_window(2, Duration::from_millis(40));
            let mut stack = LayerStack::new();
            stack.push(Box::new(layer));
            let request = create_test_request(Some("10.0.0.2"));
            let response = stack.execute(request, create_success_handler()).await;
            assert_eq!(response.status(), StatusCode::OK);
            tokio::time::sleep(Duration::from_millis(35)).await;
            let request = create_test_request(Some("10.0.0.2"));
            let response = stack.execute(request, create_success_handler()).await;
            assert_eq!(response.status(), StatusCode::OK);
            tokio::time::sleep(Duration::from_millis(10)).await;
            let request = create_test_request(Some("10.0.0.2"));
            let response = stack.execute(request, create_success_handler()).await;
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response.headers().get("X-RateLimit-Remaining").unwrap(),
                "0"
            );
        });
    }
    // Token bucket: after draining the bucket, waiting for part of the
    // refill window restores at least one token.
    #[test]
    fn test_token_bucket_refills_after_wait() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let layer = RateLimitLayer::token_bucket(2, Duration::from_millis(40));
            let mut stack = LayerStack::new();
            stack.push(Box::new(layer));
            for _ in 0..2 {
                let request = create_test_request(Some("10.0.0.3"));
                let response = stack.execute(request, create_success_handler()).await;
                assert_eq!(response.status(), StatusCode::OK);
            }
            let request = create_test_request(Some("10.0.0.3"));
            let response = stack.execute(request, create_success_handler()).await;
            assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
            tokio::time::sleep(Duration::from_millis(25)).await;
            let request = create_test_request(Some("10.0.0.3"));
            let response = stack.execute(request, create_success_handler()).await;
            assert_eq!(response.status(), StatusCode::OK);
        });
    }
}