use std::collections::HashMap;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use axum::response::IntoResponse;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::RwLock;
use tower::{Layer, Service};
use tracing::{debug, info, warn};
/// Errors that the rate limiter can report.
#[derive(Error, Debug, Clone)]
pub enum RateLimitError {
    /// The client's token bucket holds less than one token.
    #[error("Rate limit exceeded. Retry after {retry_after} seconds")]
    RateLimitExceeded {
        /// Whole seconds until the bucket is expected to hold a token again.
        retry_after: u64,
        /// The configured `requests_per_minute` (as populated by `RateLimiter::check_limit`).
        limit: u32,
        /// Tokens left; `RateLimiter::check_limit` always reports 0 here.
        remaining: u32,
        /// Unix timestamp (seconds) at which a token is expected to be available.
        reset_at: u64,
    },
    /// No client IP address could be derived from the request.
    /// NOTE(review): not constructed anywhere in this file; the middleware passes
    /// such requests through instead.
    #[error("Could not determine client IP address")]
    IpExtractionFailed,
    /// Catch-all for unexpected internal failures, carrying a description.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Internal rate limiter error: {0}")]
    Internal(String),
}
/// Tunable settings for the rate limiter and the middleware built on top of it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Sustained rate per client IP. It is converted into a refill rate of
    /// `requests_per_minute / 60` tokens per second and reported to clients as the limit.
    pub requests_per_minute: u32,
    /// Capacity of each client's token bucket, i.e. how many requests may be made back-to-back.
    pub burst_size: u32,
    /// Delay between sweeps of the background stale-bucket cleanup task.
    pub cleanup_interval: Duration,
    /// Idle time after which a client's bucket counts as stale and is removed by a sweep.
    pub bucket_ttl: Duration,
    /// When `true`, the first (leftmost) `X-Forwarded-For` entry is used as the client IP.
    /// NOTE(review): that entry can be set by the client itself; enable only behind a proxy
    /// that overwrites or sanitizes the header.
    pub trust_x_forwarded_for: bool,
    /// When `true`, `X-Real-IP` is used as the client IP (consulted after `X-Forwarded-For`).
    pub trust_x_real_ip: bool,
    /// Client IPs that bypass rate limiting entirely.
    pub whitelist: Vec<IpAddr>,
    /// Master switch; when `false`, every limit check succeeds.
    pub enabled: bool,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_minute: 100,
burst_size: 10,
cleanup_interval: Duration::from_secs(60),
bucket_ttl: Duration::from_secs(300),
trust_x_forwarded_for: false,
trust_x_real_ip: false,
whitelist: Vec::new(),
enabled: true,
}
}
}
impl RateLimitConfig {
    /// Creates a configuration allowing `requests_per_minute` sustained requests per client;
    /// every other field comes from [`RateLimitConfig::default`].
    #[must_use]
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            requests_per_minute,
            ..Default::default()
        }
    }

    /// Sets the token-bucket capacity (maximum back-to-back requests per client).
    ///
    /// Like every builder method here, this consumes `self` and returns the updated
    /// config; `#[must_use]` turns a discarded result (a silently lost setting) into a warning.
    #[must_use]
    pub fn with_burst_size(mut self, burst_size: u32) -> Self {
        self.burst_size = burst_size;
        self
    }

    /// Sets the delay between background stale-bucket sweeps.
    #[must_use]
    pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
        self.cleanup_interval = interval;
        self
    }

    /// Sets how long a bucket may sit idle before a sweep removes it.
    #[must_use]
    pub fn with_bucket_ttl(mut self, ttl: Duration) -> Self {
        self.bucket_ttl = ttl;
        self
    }

    /// Uses the first `X-Forwarded-For` entry as the client IP.
    /// Only safe behind a proxy that controls that header.
    #[must_use]
    pub fn trust_forwarded_for(mut self) -> Self {
        self.trust_x_forwarded_for = true;
        self
    }

    /// Uses the `X-Real-IP` header as the client IP.
    #[must_use]
    pub fn trust_real_ip(mut self) -> Self {
        self.trust_x_real_ip = true;
        self
    }

    /// Adds `ip` to the set of clients that bypass rate limiting.
    #[must_use]
    pub fn whitelist_ip(mut self, ip: IpAddr) -> Self {
        self.whitelist.push(ip);
        self
    }

    /// Turns rate limiting off entirely; every limit check then succeeds.
    #[must_use]
    pub fn disabled(mut self) -> Self {
        self.enabled = false;
        self
    }

    /// Token refill rate derived from `requests_per_minute`.
    fn tokens_per_second(&self) -> f64 {
        self.requests_per_minute as f64 / 60.0
    }
}
/// A single client's token bucket: holds at most `capacity` tokens and regains them
/// continuously at `refill_rate` tokens per second of elapsed time.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available (may be fractional between whole refills).
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// When the token count was last brought up to date.
    last_update: Instant,
    /// When the bucket was last touched; compared against a TTL for eviction.
    last_access: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: u32, refill_rate: f64) -> Self {
        let created = Instant::now();
        let full = f64::from(capacity);
        Self {
            tokens: full,
            capacity: full,
            refill_rate,
            last_update: created,
            last_access: created,
        }
    }

    /// Credits tokens for the time elapsed since the last update (capped at `capacity`)
    /// and marks the bucket as accessed now.
    fn refill(&mut self) {
        let now = Instant::now();
        let gained = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        self.tokens = (self.tokens + gained).min(self.capacity);
        self.last_update = now;
        self.last_access = now;
    }

    /// Refills, then takes one token if a whole one is available.
    /// Returns whether the token was taken.
    fn try_consume(&mut self) -> bool {
        self.refill();
        let allowed = self.tokens >= 1.0;
        if allowed {
            self.tokens -= 1.0;
        }
        allowed
    }

    /// Refills, then reports the number of whole tokens available.
    fn remaining(&mut self) -> u32 {
        self.refill();
        self.tokens.floor() as u32
    }

    /// Whole seconds until one full token is available, based on the current count
    /// (no refill is applied first). A zero refill rate divides to infinity, which the
    /// `as u64` cast saturates to `u64::MAX`.
    fn seconds_until_token(&self) -> u64 {
        match 1.0 - self.tokens {
            deficit if deficit <= 0.0 => 0,
            deficit => (deficit / self.refill_rate).ceil() as u64,
        }
    }

    /// True when the bucket has not been touched for longer than `ttl`.
    fn is_stale(&self, ttl: Duration) -> bool {
        ttl < self.last_access.elapsed()
    }
}
/// Per-IP token-bucket rate limiter shared by every request passing through the middleware.
///
/// Buckets are created lazily by `check_limit` and removed by `reset`, `reset_all`,
/// or the stale-bucket sweep.
pub struct RateLimiter {
    /// One bucket per client IP. Every mutating path takes the write lock.
    buckets: RwLock<HashMap<IpAddr, TokenBucket>>,
    /// Configuration supplied at construction (exposed read-only via `config()`).
    config: RateLimitConfig,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
info!(
"Creating rate limiter: {} req/min, burst: {}",
config.requests_per_minute, config.burst_size
);
Self {
buckets: RwLock::new(HashMap::new()),
config,
}
}
pub fn default_limiter() -> Self {
Self::new(RateLimitConfig::default())
}
pub async fn check_limit(&self, ip: IpAddr) -> Result<(), RateLimitError> {
if !self.config.enabled {
return Ok(());
}
if self.config.whitelist.contains(&ip) {
debug!("IP {} is whitelisted, bypassing rate limit", ip);
return Ok(());
}
let mut buckets = self.buckets.write().await;
let bucket = buckets.entry(ip).or_insert_with(|| {
TokenBucket::new(self.config.burst_size, self.config.tokens_per_second())
});
if bucket.try_consume() {
debug!("Rate limit check passed for {}: {} tokens remaining", ip, bucket.remaining());
Ok(())
} else {
let retry_after = bucket.seconds_until_token();
let reset_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
+ retry_after;
warn!("Rate limit exceeded for {}: retry after {}s", ip, retry_after);
Err(RateLimitError::RateLimitExceeded {
retry_after,
limit: self.config.requests_per_minute,
remaining: 0,
reset_at,
})
}
}
pub async fn get_remaining(&self, ip: IpAddr) -> u32 {
if !self.config.enabled {
return u32::MAX;
}
if self.config.whitelist.contains(&ip) {
return u32::MAX;
}
let mut buckets = self.buckets.write().await;
if let Some(bucket) = buckets.get_mut(&ip) {
bucket.remaining()
} else {
self.config.burst_size
}
}
pub async fn reset(&self, ip: IpAddr) {
let mut buckets = self.buckets.write().await;
buckets.remove(&ip);
info!("Rate limit reset for {}", ip);
}
pub async fn reset_all(&self) {
let mut buckets = self.buckets.write().await;
buckets.clear();
info!("All rate limits reset");
}
pub async fn get_limit_info(&self, ip: IpAddr) -> RateLimitInfo {
let remaining = self.get_remaining(ip).await;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let reset_at = (now / 60 + 1) * 60;
RateLimitInfo {
limit: self.config.requests_per_minute,
remaining,
reset_at,
}
}
pub fn start_cleanup_task(self: Arc<Self>) {
let limiter = self.clone();
let interval = self.config.cleanup_interval;
let ttl = self.config.bucket_ttl;
tokio::spawn(async move {
info!("Starting rate limiter cleanup task (interval: {:?})", interval);
loop {
tokio::time::sleep(interval).await;
limiter.cleanup_stale_buckets(ttl).await;
}
});
}
async fn cleanup_stale_buckets(&self, ttl: Duration) {
let mut buckets = self.buckets.write().await;
let before_count = buckets.len();
buckets.retain(|ip, bucket| {
let keep = !bucket.is_stale(ttl);
if !keep {
debug!("Removing stale bucket for {}", ip);
}
keep
});
let removed = before_count - buckets.len();
if removed > 0 {
info!("Cleaned up {} stale rate limit buckets", removed);
}
}
pub async fn bucket_count(&self) -> usize {
self.buckets.read().await.len()
}
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
}
/// Point-in-time view of a client's quota, used to fill the `X-RateLimit-*` response headers.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitInfo {
    /// The configured requests per minute.
    pub limit: u32,
    /// Whole tokens the client can still spend (`u32::MAX` when the client is exempt).
    pub remaining: u32,
    /// Reported reset time, as a Unix timestamp in seconds.
    pub reset_at: u64,
}
/// Tower layer that wraps services in [`RateLimitService`], all sharing one limiter.
#[derive(Clone)]
pub struct RateLimitLayer {
    /// Limiter handed to every service this layer produces.
    limiter: Arc<RateLimiter>,
}

impl RateLimitLayer {
    /// Creates a layer that enforces limits using the given shared `limiter`.
    pub fn new(limiter: Arc<RateLimiter>) -> Self {
        RateLimitLayer { limiter }
    }
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
limiter: self.limiter.clone(),
}
}
}
/// Middleware service produced by `RateLimitLayer`. It checks each request's client IP
/// against the shared limiter; allowed requests are forwarded to `inner`, and rejected
/// ones receive a 429 response.
#[derive(Clone)]
pub struct RateLimitService<S> {
    /// The wrapped service that handles allowed requests.
    inner: S,
    /// Limiter shared with the originating layer and its other services.
    limiter: Arc<RateLimiter>,
}
impl<S> Service<Request<Body>> for RateLimitService<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let limiter = self.limiter.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let ip = extract_client_ip(&req, &limiter.config);
match ip {
Some(ip) => {
match limiter.check_limit(ip).await {
Ok(()) => {
let response = inner.call(req).await?;
let info = limiter.get_limit_info(ip).await;
Ok(add_rate_limit_headers(response, &info))
}
Err(RateLimitError::RateLimitExceeded {
retry_after,
limit,
remaining,
reset_at,
}) => {
Ok(rate_limit_response(retry_after, limit, remaining, reset_at))
}
Err(e) => {
warn!("Rate limiter error: {}", e);
let response = inner.call(req).await?;
Ok(response)
}
}
}
None => {
warn!("Could not extract client IP for rate limiting");
let response = inner.call(req).await?;
Ok(response)
}
}
})
}
}
fn extract_client_ip(req: &Request<Body>, config: &RateLimitConfig) -> Option<IpAddr> {
if config.trust_x_forwarded_for {
if let Some(xff) = req.headers().get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
if let Some(first_ip) = xff_str.split(',').next() {
if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
return Some(ip);
}
}
}
}
}
if config.trust_x_real_ip {
if let Some(real_ip) = req.headers().get("x-real-ip") {
if let Ok(ip_str) = real_ip.to_str() {
if let Ok(ip) = ip_str.trim().parse::<IpAddr>() {
return Some(ip);
}
}
}
}
req.extensions()
.get::<std::net::SocketAddr>()
.map(|addr| addr.ip())
}
/// Adds `X-RateLimit-Limit`, `X-RateLimit-Remaining` and `X-RateLimit-Reset` headers
/// (taken from `info`) to `response`, replacing any existing values.
fn add_rate_limit_headers(mut response: Response<Body>, info: &RateLimitInfo) -> Response<Body> {
    let headers = response.headers_mut();
    // `HeaderValue` implements `From<u32>` and `From<u64>`, so the numbers convert directly.
    // This avoids formatting to a String, re-parsing it, and an `unwrap` that could never fire.
    headers.insert("X-RateLimit-Limit", info.limit.into());
    headers.insert("X-RateLimit-Remaining", info.remaining.into());
    headers.insert("X-RateLimit-Reset", info.reset_at.into());
    response
}
/// Builds the `429 Too Many Requests` response for a limited client.
///
/// The JSON body and the `X-RateLimit-*` / `Retry-After` headers carry the same values,
/// so clients can rely on either.
fn rate_limit_response(retry_after: u64, limit: u32, remaining: u32, reset_at: u64) -> Response<Body> {
    let body = serde_json::json!({
        "error": "rate_limit_exceeded",
        "message": format!("Rate limit exceeded. Please retry after {} seconds.", retry_after),
        "limit": limit,
        "remaining": remaining,
        "reset_at": reset_at,
        "retry_after": retry_after
    });
    Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header("Content-Type", "application/json")
        .header("X-RateLimit-Limit", limit.to_string())
        .header("X-RateLimit-Remaining", remaining.to_string())
        .header("X-RateLimit-Reset", reset_at.to_string())
        .header("Retry-After", retry_after.to_string())
        .body(Body::from(body.to_string()))
        .expect("static status, static header names and decimal header values are always valid")
}
/// Builds a rate-limiting layer backed by a fresh limiter for `config`.
///
/// No background cleanup is started and the limiter handle is not returned, so idle
/// per-IP buckets are never evicted. Use `rate_limit_layer_with_cleanup` when that matters.
pub fn rate_limit_layer(config: RateLimitConfig) -> RateLimitLayer {
    RateLimitLayer::new(Arc::new(RateLimiter::new(config)))
}
pub fn rate_limit_layer_with_cleanup(config: RateLimitConfig) -> (RateLimitLayer, Arc<RateLimiter>) {
let limiter = Arc::new(RateLimiter::new(config));
limiter.clone().start_cleanup_task();
(RateLimitLayer::new(limiter.clone()), limiter)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Default client address for single-IP tests.
    fn test_ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_requests_under_limit() {
        let config = RateLimitConfig::new(100).with_burst_size(10);
        let limiter = RateLimiter::new(config);
        for i in 0..10 {
            let result = limiter.check_limit(test_ip()).await;
            assert!(result.is_ok(), "Request {} should be allowed", i);
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_blocks_when_exceeded() {
        let config = RateLimitConfig::new(60).with_burst_size(5);
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        for _ in 0..5 {
            let _ = limiter.check_limit(ip).await;
        }
        let result = limiter.check_limit(ip).await;
        assert!(result.is_err());
        if let Err(RateLimitError::RateLimitExceeded { retry_after, .. }) = result {
            assert!(retry_after > 0);
        } else {
            panic!("Expected RateLimitExceeded error");
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_refills_over_time() {
        // 600 req/min = 10 tokens/s, so 200 ms regains about two tokens.
        let config = RateLimitConfig::new(600).with_burst_size(5);
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        for _ in 0..5 {
            let _ = limiter.check_limit(ip).await;
        }
        tokio::time::sleep(Duration::from_millis(200)).await;
        let result = limiter.check_limit(ip).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limiter_whitelist() {
        let whitelisted_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let config = RateLimitConfig::new(60)
            .with_burst_size(1)
            .whitelist_ip(whitelisted_ip);
        let limiter = RateLimiter::new(config);
        for _ in 0..100 {
            let result = limiter.check_limit(whitelisted_ip).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_whitelisted_ip_creates_no_bucket() {
        let whitelisted_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let limiter = RateLimiter::new(RateLimitConfig::new(60).whitelist_ip(whitelisted_ip));
        limiter.check_limit(whitelisted_ip).await.unwrap();
        assert_eq!(limiter.bucket_count().await, 0);
    }

    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let config = RateLimitConfig::new(1).with_burst_size(1).disabled();
        let limiter = RateLimiter::new(config);
        for _ in 0..100 {
            let result = limiter.check_limit(test_ip()).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_get_remaining() {
        let config = RateLimitConfig::new(60).with_burst_size(10);
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        assert_eq!(limiter.get_remaining(ip).await, 10);
        limiter.check_limit(ip).await.unwrap();
        assert_eq!(limiter.get_remaining(ip).await, 9);
        limiter.check_limit(ip).await.unwrap();
        assert_eq!(limiter.get_remaining(ip).await, 8);
    }

    #[tokio::test]
    async fn test_get_remaining_exempt_clients_report_max() {
        let whitelisted_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let limiter = RateLimiter::new(RateLimitConfig::new(60).whitelist_ip(whitelisted_ip));
        assert_eq!(limiter.get_remaining(whitelisted_ip).await, u32::MAX);

        let disabled = RateLimiter::new(RateLimitConfig::new(60).disabled());
        assert_eq!(disabled.get_remaining(test_ip()).await, u32::MAX);
    }

    #[tokio::test]
    async fn test_reset() {
        let config = RateLimitConfig::new(60).with_burst_size(5);
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        for _ in 0..5 {
            let _ = limiter.check_limit(ip).await;
        }
        assert_eq!(limiter.get_remaining(ip).await, 0);
        limiter.reset(ip).await;
        assert_eq!(limiter.get_remaining(ip).await, 5);
    }

    #[tokio::test]
    async fn test_reset_all_clears_every_bucket() {
        let limiter = RateLimiter::new(RateLimitConfig::new(60).with_burst_size(1));
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        limiter.check_limit(ip1).await.unwrap();
        limiter.check_limit(ip2).await.unwrap();
        assert!(limiter.check_limit(ip1).await.is_err());
        assert_eq!(limiter.bucket_count().await, 2);
        limiter.reset_all().await;
        assert_eq!(limiter.bucket_count().await, 0);
        assert!(limiter.check_limit(ip1).await.is_ok());
    }

    #[tokio::test]
    async fn test_multiple_ips() {
        let config = RateLimitConfig::new(60).with_burst_size(2);
        let limiter = RateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        limiter.check_limit(ip1).await.unwrap();
        limiter.check_limit(ip1).await.unwrap();
        assert!(limiter.check_limit(ip1).await.is_err());
        assert!(limiter.check_limit(ip2).await.is_ok());
        assert!(limiter.check_limit(ip2).await.is_ok());
    }

    #[tokio::test]
    async fn test_cleanup_stale_buckets() {
        let config = RateLimitConfig::new(60)
            .with_burst_size(5)
            .with_bucket_ttl(Duration::from_millis(100));
        let limiter = RateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        limiter.check_limit(ip1).await.unwrap();
        limiter.check_limit(ip2).await.unwrap();
        assert_eq!(limiter.bucket_count().await, 2);
        tokio::time::sleep(Duration::from_millis(150)).await;
        limiter.cleanup_stale_buckets(Duration::from_millis(100)).await;
        assert_eq!(limiter.bucket_count().await, 0);
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::new(200)
            .with_burst_size(20)
            .with_cleanup_interval(Duration::from_secs(120))
            .with_bucket_ttl(Duration::from_secs(600))
            .trust_forwarded_for()
            .trust_real_ip();
        assert_eq!(config.requests_per_minute, 200);
        assert_eq!(config.burst_size, 20);
        assert_eq!(config.cleanup_interval, Duration::from_secs(120));
        assert_eq!(config.bucket_ttl, Duration::from_secs(600));
        assert!(config.trust_x_forwarded_for);
        assert!(config.trust_x_real_ip);
    }

    #[test]
    fn test_tokens_per_second() {
        let config = RateLimitConfig::new(60);
        assert!((config.tokens_per_second() - 1.0).abs() < 0.001);
        let config = RateLimitConfig::new(120);
        assert!((config.tokens_per_second() - 2.0).abs() < 0.001);
        let config = RateLimitConfig::new(30);
        assert!((config.tokens_per_second() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_extract_client_ip_honours_trusted_headers_only() {
        let req = Request::builder()
            .header("x-forwarded-for", "203.0.113.7, 10.0.0.1")
            .header("x-real-ip", "198.51.100.2")
            .body(Body::empty())
            .unwrap();

        let xff = RateLimitConfig::default().trust_forwarded_for();
        assert_eq!(
            extract_client_ip(&req, &xff),
            Some(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7)))
        );

        let real = RateLimitConfig::default().trust_real_ip();
        assert_eq!(
            extract_client_ip(&req, &real),
            Some(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 2)))
        );

        // Untrusted headers are ignored; with no peer address in the extensions there is
        // nothing to key on.
        assert_eq!(extract_client_ip(&req, &RateLimitConfig::default()), None);
    }

    #[test]
    fn test_add_rate_limit_headers() {
        let info = RateLimitInfo {
            limit: 100,
            remaining: 42,
            reset_at: 1_700_000_000,
        };
        let response = add_rate_limit_headers(Response::new(Body::empty()), &info);
        assert_eq!(response.headers()["x-ratelimit-limit"], "100");
        assert_eq!(response.headers()["x-ratelimit-remaining"], "42");
        assert_eq!(response.headers()["x-ratelimit-reset"], "1700000000");
    }

    #[test]
    fn test_rate_limit_response_shape() {
        let response = rate_limit_response(7, 60, 0, 1_000);
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(response.headers()["content-type"], "application/json");
        assert_eq!(response.headers()["retry-after"], "7");
        assert_eq!(response.headers()["x-ratelimit-limit"], "60");
        assert_eq!(response.headers()["x-ratelimit-remaining"], "0");
        assert_eq!(response.headers()["x-ratelimit-reset"], "1000");
    }
}