use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use axum::http::{HeaderName, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use dashmap::DashMap;
use parking_lot::Mutex;
use serde::Serialize;
/// Tunable settings for the rate limiter.
///
/// Request limits are per-minute refill rates; burst sizes are the capacity of
/// the corresponding token bucket.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Master switch. When `false`, the limiter's checks return unlimited results
    /// and auth failures are not recorded.
    pub enabled: bool,
    /// Requests per minute allowed for a single client IP.
    pub per_ip_requests: u32,
    /// Maximum burst (bucket capacity) for a single client IP.
    pub per_ip_burst: u32,
    /// Requests per minute allowed for a single tenant.
    pub per_tenant_requests: u32,
    /// Maximum burst (bucket capacity) for a single tenant.
    pub per_tenant_burst: u32,
    /// Number of authentication failures that triggers an IP ban.
    pub auth_fail_limit: u32,
    /// How long an IP stays banned once `auth_fail_limit` is reached.
    pub auth_fail_ban_duration: Duration,
    /// Whether client IPs may be taken from proxy headers. Only stored and exposed
    /// here; presumably consumed by the IP-extraction layer (TODO confirm).
    pub trust_proxy_headers: bool,
}

impl Default for RateLimitConfig {
    /// Moderate production defaults: 1000 req/min per IP (burst 100),
    /// 10 000 req/min per tenant (burst 1000), and a 5 minute ban after 10
    /// authentication failures. Proxy headers are not trusted.
    fn default() -> Self {
        let ban_duration = Duration::from_secs(300);
        Self {
            enabled: true,
            per_ip_requests: 1000,
            per_ip_burst: 100,
            per_tenant_requests: 10_000,
            per_tenant_burst: 1000,
            auth_fail_limit: 10,
            auth_fail_ban_duration: ban_duration,
            trust_proxy_headers: false,
        }
    }
}
impl RateLimitConfig {
    /// Returns a builder seeded with the [`Default`] configuration.
    pub fn builder() -> RateLimitConfigBuilder {
        RateLimitConfigBuilder::default()
    }

    /// Returns a configuration with rate limiting switched off; every other
    /// field keeps its default value.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }

    /// Returns a tighter preset: 100 req/min per IP (burst 20), 1000 req/min
    /// per tenant (burst 100), and a 10 minute ban after 5 authentication failures.
    pub fn strict() -> Self {
        Self {
            enabled: true,
            per_ip_requests: 100,
            per_ip_burst: 20,
            per_tenant_requests: 1000,
            per_tenant_burst: 100,
            auth_fail_limit: 5,
            auth_fail_ban_duration: Duration::from_secs(600),
            trust_proxy_headers: false,
        }
    }

    /// Converts the on-disk rate-limit section into a runtime configuration.
    ///
    /// `requests_per_second` is scaled to a per-minute rate for the per-IP limit.
    /// Tenant limits are ten times the per-IP values. All multiplications
    /// saturate at `u32::MAX`, so an oversized file value cannot overflow and
    /// wrap around to a tiny limit in release builds (or panic in debug builds).
    ///
    /// Always returns `Some`; the `Option` is kept for API compatibility.
    pub fn from_file_config(file_config: &crate::config::RateLimitConfigFile) -> Option<Self> {
        if !file_config.enabled {
            return Some(Self::disabled());
        }
        // Per-second rate from the file -> per-minute rate used by the buckets.
        let per_ip_requests = file_config.requests_per_second.saturating_mul(60);
        Some(Self {
            enabled: true,
            per_ip_requests,
            per_ip_burst: file_config.burst_size,
            per_tenant_requests: per_ip_requests.saturating_mul(10),
            per_tenant_burst: file_config.burst_size.saturating_mul(10),
            auth_fail_limit: file_config.max_auth_failures,
            auth_fail_ban_duration: Duration::from_secs(file_config.lockout_duration_seconds),
            trust_proxy_headers: file_config.trust_proxy_headers,
        })
    }
}
/// Fluent builder for [`RateLimitConfig`], starting from its `Default` values.
///
/// Each setter consumes the builder and returns the updated one. A setter
/// whose result is dropped silently loses its setting, hence `#[must_use]`.
#[must_use = "builder setters return the updated builder; call `build()` to obtain the config"]
#[derive(Default)]
pub struct RateLimitConfigBuilder {
    config: RateLimitConfig,
}

impl RateLimitConfigBuilder {
    /// Sets whether rate limiting is enabled.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.config.enabled = enabled;
        self
    }

    /// Sets the per-IP requests-per-minute limit.
    pub fn per_ip_requests(mut self, requests: u32) -> Self {
        self.config.per_ip_requests = requests;
        self
    }

    /// Sets the per-IP burst (bucket capacity).
    pub fn per_ip_burst(mut self, burst: u32) -> Self {
        self.config.per_ip_burst = burst;
        self
    }

    /// Sets the per-tenant requests-per-minute limit.
    pub fn per_tenant_requests(mut self, requests: u32) -> Self {
        self.config.per_tenant_requests = requests;
        self
    }

    /// Sets the per-tenant burst (bucket capacity).
    pub fn per_tenant_burst(mut self, burst: u32) -> Self {
        self.config.per_tenant_burst = burst;
        self
    }

    /// Sets how many authentication failures trigger an IP ban.
    pub fn auth_fail_limit(mut self, limit: u32) -> Self {
        self.config.auth_fail_limit = limit;
        self
    }

    /// Sets how long a banned IP stays banned.
    pub fn auth_fail_ban_duration(mut self, duration: Duration) -> Self {
        self.config.auth_fail_ban_duration = duration;
        self
    }

    /// Sets whether proxy headers are trusted for client-IP extraction.
    pub fn trust_proxy_headers(mut self, trust: bool) -> Self {
        self.config.trust_proxy_headers = trust;
        self
    }

    /// Finishes building and returns the configuration.
    #[must_use]
    pub fn build(self) -> RateLimitConfig {
        self.config
    }
}
/// Token bucket holding up to `capacity` tokens that refills continuously at
/// `refill_rate` tokens per second. Each admitted request spends one token.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum token balance (the burst size).
    capacity: u32,
    /// Current, possibly fractional, token balance.
    tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: std::time::Instant,
}

impl TokenBucket {
    /// Creates a full bucket of `capacity` tokens that regains
    /// `requests_per_minute` tokens per minute.
    fn new(capacity: u32, requests_per_minute: u32) -> Self {
        let per_second = f64::from(requests_per_minute) / 60.0;
        Self {
            capacity,
            tokens: f64::from(capacity),
            refill_rate: per_second,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Refills, then spends one token if available. Returns whether the request
    /// is admitted.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        let admitted = self.tokens >= 1.0;
        if admitted {
            self.tokens -= 1.0;
        }
        admitted
    }

    /// Credits tokens earned since `last_refill`, capped at `capacity`.
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        let gained = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = f64::min(self.tokens + gained, f64::from(self.capacity));
        self.last_refill = now;
    }

    /// Whole tokens currently available; fractional tokens are dropped.
    fn remaining(&self) -> u32 {
        self.tokens as u32
    }

    /// Seconds (rounded up) until one full token is available, or 0 if one
    /// already is. Uses the balance as of the last refill.
    fn retry_after(&self) -> u64 {
        if self.tokens >= 1.0 {
            return 0;
        }
        let deficit = 1.0 - self.tokens;
        (deficit / self.refill_rate).ceil() as u64
    }
}
/// Per-IP authentication-failure counter and ban state.
#[derive(Debug)]
struct AuthFailureEntry {
    /// Failures counted in the current window.
    failures: u32,
    /// Start of the current counting window (when this entry was created or reset).
    first_failure: std::time::Instant,
    /// If set, the IP is banned until this instant.
    ban_until: Option<std::time::Instant>,
}

impl AuthFailureEntry {
    /// Creates an entry with no failures and no ban, whose window starts now.
    fn new() -> Self {
        Self {
            failures: 0,
            first_failure: std::time::Instant::now(),
            ban_until: None,
        }
    }

    /// Returns `true` while a ban is set and has not yet expired.
    fn is_banned(&self) -> bool {
        matches!(self.ban_until, Some(until) if std::time::Instant::now() < until)
    }

    /// Seconds left on an active ban, rounded up, or `None` when not banned.
    ///
    /// Rounding up ensures a ban with only a fraction of a second left never
    /// reports `Some(0)`. That value would tell clients to retry immediately
    /// while they are still banned.
    fn ban_remaining_secs(&self) -> Option<u64> {
        let until = self.ban_until?;
        let now = std::time::Instant::now();
        if now >= until {
            return None;
        }
        let left = until - now;
        Some(left.as_secs() + u64::from(left.subsec_nanos() > 0))
    }
}
pub struct RateLimiter {
config: RateLimitConfig,
ip_limiters: DashMap<IpAddr, Mutex<TokenBucket>>,
tenant_limiters: DashMap<String, Mutex<TokenBucket>>,
auth_failures: DashMap<IpAddr, Mutex<AuthFailureEntry>>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
ip_limiters: DashMap::new(),
tenant_limiters: DashMap::new(),
auth_failures: DashMap::new(),
}
}
pub fn default_limiter() -> Arc<Self> {
Arc::new(Self::new(RateLimitConfig::default()))
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn trust_proxy_headers(&self) -> bool {
self.config.trust_proxy_headers
}
pub fn is_ip_banned(&self, ip: &IpAddr) -> bool {
self.auth_failures
.get(ip)
.map(|entry| entry.lock().is_banned())
.unwrap_or(false)
}
pub fn ip_ban_remaining(&self, ip: &IpAddr) -> Option<u64> {
self.auth_failures
.get(ip)
.and_then(|entry| entry.lock().ban_remaining_secs())
}
pub fn record_auth_failure(&self, ip: &IpAddr) {
if !self.config.enabled {
return;
}
let entry = self
.auth_failures
.entry(*ip)
.or_insert_with(|| Mutex::new(AuthFailureEntry::new()));
let mut entry = entry.lock();
if entry.first_failure.elapsed() > Duration::from_secs(3600) {
*entry = AuthFailureEntry::new();
}
entry.failures = entry.failures.saturating_add(1);
if entry.failures >= self.config.auth_fail_limit {
entry.ban_until = Some(std::time::Instant::now() + self.config.auth_fail_ban_duration);
tracing::warn!(
ip = %ip,
failures = entry.failures,
ban_duration_secs = self.config.auth_fail_ban_duration.as_secs(),
"IP banned due to excessive auth failures"
);
}
}
pub fn record_auth_success(&self, ip: &IpAddr) {
if !self.config.enabled {
return;
}
self.auth_failures.remove(ip);
}
pub fn check_ip_limit(&self, ip: &IpAddr) -> Result<RateLimitInfo, RateLimitExceeded> {
if !self.config.enabled {
return Ok(RateLimitInfo::unlimited());
}
if self.is_ip_banned(ip) {
let retry_after = self.ip_ban_remaining(ip).unwrap_or(60);
return Err(RateLimitExceeded {
limit: self.config.per_ip_requests,
remaining: 0,
retry_after,
limit_type: LimitType::IpBanned,
});
}
let entry = self.ip_limiters.entry(*ip).or_insert_with(|| {
Mutex::new(TokenBucket::new(
self.config.per_ip_burst,
self.config.per_ip_requests,
))
});
let mut bucket = entry.lock();
if bucket.try_acquire() {
Ok(RateLimitInfo {
limit: self.config.per_ip_requests,
remaining: bucket.remaining(),
reset_secs: 60, limit_type: LimitType::PerIp,
})
} else {
Err(RateLimitExceeded {
limit: self.config.per_ip_requests,
remaining: 0,
retry_after: bucket.retry_after(),
limit_type: LimitType::PerIp,
})
}
}
pub fn check_tenant_limit(&self, tenant_id: &str) -> Result<RateLimitInfo, RateLimitExceeded> {
if !self.config.enabled {
return Ok(RateLimitInfo::unlimited());
}
let entry = self
.tenant_limiters
.entry(tenant_id.to_string())
.or_insert_with(|| {
Mutex::new(TokenBucket::new(
self.config.per_tenant_burst,
self.config.per_tenant_requests,
))
});
let mut bucket = entry.lock();
if bucket.try_acquire() {
Ok(RateLimitInfo {
limit: self.config.per_tenant_requests,
remaining: bucket.remaining(),
reset_secs: 60,
limit_type: LimitType::PerTenant,
})
} else {
Err(RateLimitExceeded {
limit: self.config.per_tenant_requests,
remaining: 0,
retry_after: bucket.retry_after(),
limit_type: LimitType::PerTenant,
})
}
}
pub fn check_limits(
&self,
ip: &IpAddr,
tenant_id: Option<&str>,
) -> Result<RateLimitInfo, RateLimitExceeded> {
let ip_result = self.check_ip_limit(ip)?;
if let Some(tenant_id) = tenant_id {
let tenant_result = self.check_tenant_limit(tenant_id)?;
if tenant_result.remaining < ip_result.remaining {
return Ok(tenant_result);
}
}
Ok(ip_result)
}
pub fn peek_ip_limit(&self, ip: &IpAddr) -> Option<RateLimitInfo> {
if !self.config.enabled {
return Some(RateLimitInfo::unlimited());
}
self.ip_limiters.get(ip).map(|entry| {
let bucket = entry.lock();
RateLimitInfo {
limit: self.config.per_ip_requests,
remaining: bucket.remaining(),
reset_secs: 60,
limit_type: LimitType::PerIp,
}
})
}
pub fn cleanup(&self) {
self.auth_failures.retain(|_, entry| {
let entry = entry.lock();
entry.is_banned() || entry.first_failure.elapsed() < Duration::from_secs(3600)
});
let stale_threshold = Duration::from_secs(300);
self.ip_limiters.retain(|_, bucket| {
let bucket = bucket.lock();
bucket.last_refill.elapsed() < stale_threshold
});
self.tenant_limiters.retain(|_, bucket| {
let bucket = bucket.lock();
bucket.last_refill.elapsed() < stale_threshold
});
tracing::debug!(
ip_limiters = self.ip_limiters.len(),
tenant_limiters = self.tenant_limiters.len(),
auth_failures = self.auth_failures.len(),
"Rate limiter cleanup completed"
);
}
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
}
/// Outcome of a successful rate-limit check, used to populate `x-ratelimit-*`
/// response headers.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Configured per-minute request limit of the checked bucket.
    pub limit: u32,
    /// Requests still available in the checked bucket.
    pub remaining: u32,
    /// Seconds reported in the `x-ratelimit-reset` header.
    pub reset_secs: u64,
    /// Which limit this describes; `LimitType::None` means no limiting applied.
    pub limit_type: LimitType,
}

impl RateLimitInfo {
    /// Info returned when rate limiting is disabled: maximal values and no headers.
    fn unlimited() -> Self {
        Self {
            limit: u32::MAX,
            remaining: u32::MAX,
            reset_secs: 0,
            limit_type: LimitType::None,
        }
    }

    /// Builds the `x-ratelimit-limit`, `x-ratelimit-remaining` and
    /// `x-ratelimit-reset` headers. Returns an empty list for `LimitType::None`.
    pub fn headers(&self) -> Vec<(HeaderName, HeaderValue)> {
        if matches!(self.limit_type, LimitType::None) {
            return vec![];
        }
        // Integer -> HeaderValue conversions are infallible, so no `unwrap` is needed.
        vec![
            (
                HeaderName::from_static("x-ratelimit-limit"),
                HeaderValue::from(self.limit),
            ),
            (
                HeaderName::from_static("x-ratelimit-remaining"),
                HeaderValue::from(self.remaining),
            ),
            (
                HeaderName::from_static("x-ratelimit-reset"),
                HeaderValue::from(self.reset_secs),
            ),
        ]
    }
}
/// Identifies which limit produced a rate-limit result or rejection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LimitType {
    /// No limit applied (rate limiting disabled).
    None,
    /// Per-client-IP token bucket.
    PerIp,
    /// Per-tenant token bucket.
    PerTenant,
    /// Client IP temporarily banned after repeated authentication failures.
    IpBanned,
}
/// Rejection produced when a request exceeds a rate limit or comes from a
/// banned IP. Renders as a `429 Too Many Requests` response via `IntoResponse`.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Configured per-minute request limit of the relevant bucket (the per-IP
    /// limit for bans).
    pub limit: u32,
    /// Requests left (0 when produced by `RateLimiter`).
    pub remaining: u32,
    /// Seconds the client should wait before retrying.
    pub retry_after: u64,
    /// Which limit rejected the request.
    pub limit_type: LimitType,
}

impl RateLimitExceeded {
    /// Builds the `x-ratelimit-limit`, `x-ratelimit-remaining` and `retry-after` headers.
    pub fn headers(&self) -> Vec<(HeaderName, HeaderValue)> {
        // Integer -> HeaderValue conversions are infallible, so no `unwrap` is needed.
        vec![
            (
                HeaderName::from_static("x-ratelimit-limit"),
                HeaderValue::from(self.limit),
            ),
            (
                HeaderName::from_static("x-ratelimit-remaining"),
                HeaderValue::from(self.remaining),
            ),
            (
                HeaderName::from_static("retry-after"),
                HeaderValue::from(self.retry_after),
            ),
        ]
    }
}
/// JSON body of a rate-limit rejection, shaped as `{"error": {...}}`.
#[derive(Serialize, Debug)]
pub struct RateLimitErrorResponse {
    /// Error details, serialized under the `error` key.
    pub error: RateLimitErrorBody,
}

/// Details of a rate-limit rejection. Field order here is the JSON key order.
#[derive(Serialize, Debug)]
pub struct RateLimitErrorBody {
    /// HTTP status code, repeated in the body.
    pub code: u16,
    /// Human-readable explanation.
    pub message: String,
    /// Error type name; serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Seconds the client should wait before retrying.
    pub retry_after: u64,
}
impl IntoResponse for RateLimitExceeded {
    /// Renders a `429 Too Many Requests` response with a JSON error body and
    /// the headers from `RateLimitExceeded::headers`.
    fn into_response(self) -> Response {
        let retry = self.retry_after;
        let message = match self.limit_type {
            LimitType::None => String::from("Rate limit exceeded"),
            LimitType::IpBanned => format!(
                "IP temporarily banned due to excessive authentication failures. Retry after {} seconds.",
                retry
            ),
            LimitType::PerIp | LimitType::PerTenant => {
                // Both bucket limits share one message template; only the subject differs.
                let subject = if self.limit_type == LimitType::PerIp {
                    "IP address"
                } else {
                    "tenant"
                };
                format!(
                    "Rate limit exceeded for {}. Limit: {} requests/minute. Retry after {} seconds.",
                    subject, self.limit, retry
                )
            }
        };

        let status = StatusCode::TOO_MANY_REQUESTS;
        let extra_headers = self.headers();
        let payload = RateLimitErrorResponse {
            error: RateLimitErrorBody {
                code: status.as_u16(),
                message,
                error_type: String::from("RateLimitExceededException"),
                retry_after: retry,
            },
        };

        let mut response = (status, Json(payload)).into_response();
        let header_map = response.headers_mut();
        for (name, value) in extra_headers {
            header_map.insert(name, value);
        }
        response
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Fixed client address shared by the single-IP tests.
    fn test_ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
    }

    #[test]
    fn test_rate_limiter_disabled() {
        let limiter = RateLimiter::new(RateLimitConfig::disabled());
        let ip = test_ip();
        for _ in 0..10000 {
            assert!(limiter.check_ip_limit(&ip).is_ok());
        }
    }

    #[test]
    fn test_rate_limiter_basic() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_ip_requests(10)
            .per_ip_burst(5)
            .build();
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        // The burst of 5 is admitted immediately; the 6th request has no token left.
        for i in 0..5 {
            let result = limiter.check_ip_limit(&ip);
            assert!(result.is_ok(), "Request {} should succeed", i);
        }
        let result = limiter.check_ip_limit(&ip);
        assert!(result.is_err(), "Request 6 should be rate limited");
    }

    #[test]
    fn test_rate_limit_info_headers() {
        let info = RateLimitInfo {
            limit: 1000,
            remaining: 500,
            reset_secs: 60,
            limit_type: LimitType::PerIp,
        };
        let headers = info.headers();
        assert_eq!(headers.len(), 3);
        assert_eq!(headers[0].0.as_str(), "x-ratelimit-limit");
        assert_eq!(headers[0].1.to_str().unwrap(), "1000");
        assert_eq!(headers[1].0.as_str(), "x-ratelimit-remaining");
        assert_eq!(headers[1].1.to_str().unwrap(), "500");
        assert_eq!(headers[2].0.as_str(), "x-ratelimit-reset");
        assert_eq!(headers[2].1.to_str().unwrap(), "60");
    }

    #[test]
    fn test_unlimited_info_has_no_headers() {
        assert!(RateLimitInfo::unlimited().headers().is_empty());
    }

    #[test]
    fn test_rate_limit_exceeded_headers() {
        let exceeded = RateLimitExceeded {
            limit: 100,
            remaining: 0,
            retry_after: 7,
            limit_type: LimitType::PerTenant,
        };
        let headers = exceeded.headers();
        let names: Vec<&str> = headers.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["x-ratelimit-limit", "x-ratelimit-remaining", "retry-after"]);
        assert_eq!(headers[0].1.to_str().unwrap(), "100");
        assert_eq!(headers[1].1.to_str().unwrap(), "0");
        assert_eq!(headers[2].1.to_str().unwrap(), "7");
    }

    #[test]
    fn test_auth_failure_tracking() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .auth_fail_limit(3)
            .auth_fail_ban_duration(Duration::from_secs(10))
            .build();
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        limiter.record_auth_failure(&ip);
        assert!(!limiter.is_ip_banned(&ip));
        limiter.record_auth_failure(&ip);
        assert!(!limiter.is_ip_banned(&ip));
        limiter.record_auth_failure(&ip);
        assert!(
            limiter.is_ip_banned(&ip),
            "IP should be banned after 3 failures"
        );
        let result = limiter.check_ip_limit(&ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().limit_type, LimitType::IpBanned);
    }

    #[test]
    fn test_ip_ban_reports_positive_retry_after() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .auth_fail_limit(1)
            .auth_fail_ban_duration(Duration::from_secs(10))
            .build();
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        limiter.record_auth_failure(&ip);
        let err = limiter.check_ip_limit(&ip).unwrap_err();
        assert_eq!(err.limit_type, LimitType::IpBanned);
        assert!(
            (1..=10).contains(&err.retry_after),
            "retry_after {} should be within the ban duration",
            err.retry_after
        );
    }

    #[test]
    fn test_auth_success_resets_failures() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .auth_fail_limit(5)
            .build();
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        limiter.record_auth_failure(&ip);
        limiter.record_auth_failure(&ip);
        limiter.record_auth_success(&ip);
        assert!(!limiter.is_ip_banned(&ip));
    }

    #[test]
    fn test_disabled_limiter_ignores_auth_failures_and_limits() {
        let limiter = RateLimiter::new(RateLimitConfig::disabled());
        let ip = test_ip();
        for _ in 0..100 {
            limiter.record_auth_failure(&ip);
        }
        assert!(!limiter.is_ip_banned(&ip));
        let info = limiter.check_limits(&ip, Some("tenant")).unwrap();
        assert_eq!(info.limit_type, LimitType::None);
        assert!(info.headers().is_empty());
        let peek = limiter.peek_ip_limit(&ip).expect("disabled limiter always peeks");
        assert_eq!(peek.limit_type, LimitType::None);
    }

    #[test]
    fn test_tenant_rate_limiting() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_tenant_requests(5)
            .per_tenant_burst(3)
            .build();
        let limiter = RateLimiter::new(config);
        let tenant = "test-tenant";
        for i in 0..3 {
            let result = limiter.check_tenant_limit(tenant);
            assert!(result.is_ok(), "Request {} should succeed", i);
        }
        let result = limiter.check_tenant_limit(tenant);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().limit_type, LimitType::PerTenant);
    }

    #[test]
    fn test_combined_limits() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_ip_requests(100)
            .per_ip_burst(50)
            .per_tenant_requests(10)
            .per_tenant_burst(5)
            .build();
        let limiter = RateLimiter::new(config);
        let ip = test_ip();
        let tenant = "test-tenant";
        // The tenant bucket (burst 5) is the tighter limit, so it trips first.
        for _ in 0..5 {
            let result = limiter.check_limits(&ip, Some(tenant));
            assert!(result.is_ok());
        }
        let result = limiter.check_limits(&ip, Some(tenant));
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().limit_type, LimitType::PerTenant);
    }

    #[test]
    fn test_check_limits_without_tenant_uses_ip_limit() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let info = limiter.check_limits(&test_ip(), None).unwrap();
        assert_eq!(info.limit_type, LimitType::PerIp);
        assert_eq!(info.limit, 1000);
        assert_eq!(info.remaining, 99);
    }

    #[test]
    fn test_rate_limit_response() {
        let exceeded = RateLimitExceeded {
            limit: 1000,
            remaining: 0,
            retry_after: 60,
            limit_type: LimitType::PerIp,
        };
        let response = exceeded.into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(response.headers().contains_key("retry-after"));
        assert!(response.headers().contains_key("x-ratelimit-limit"));
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_ip_requests(500)
            .per_ip_burst(50)
            .per_tenant_requests(5000)
            .per_tenant_burst(500)
            .auth_fail_limit(10)
            .auth_fail_ban_duration(Duration::from_secs(600))
            .trust_proxy_headers(true)
            .build();
        assert!(config.enabled);
        assert_eq!(config.per_ip_requests, 500);
        assert_eq!(config.per_ip_burst, 50);
        assert_eq!(config.per_tenant_requests, 5000);
        assert_eq!(config.per_tenant_burst, 500);
        assert_eq!(config.auth_fail_limit, 10);
        assert_eq!(config.auth_fail_ban_duration, Duration::from_secs(600));
        assert!(config.trust_proxy_headers);
    }

    #[test]
    fn test_default_config() {
        let config = RateLimitConfig::default();
        assert!(config.enabled);
        assert_eq!(config.per_ip_requests, 1000);
        assert_eq!(config.per_ip_burst, 100);
        assert_eq!(config.per_tenant_requests, 10000);
        assert_eq!(config.per_tenant_burst, 1000);
        assert_eq!(config.auth_fail_limit, 10);
        assert_eq!(config.auth_fail_ban_duration, Duration::from_secs(300));
        assert!(!config.trust_proxy_headers);
    }

    #[test]
    fn test_strict_config() {
        let config = RateLimitConfig::strict();
        assert!(config.enabled);
        assert_eq!(config.per_ip_requests, 100);
        assert_eq!(config.auth_fail_limit, 5);
    }

    #[test]
    fn test_cleanup_removes_stale_entries() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_ip_requests(1000)
            .per_ip_burst(100)
            .per_tenant_requests(1000)
            .per_tenant_burst(100)
            .build();
        let limiter = RateLimiter::new(config);
        let ip1 = "192.168.1.1".parse().unwrap();
        let ip2 = "192.168.1.2".parse().unwrap();
        limiter.check_ip_limit(&ip1).unwrap();
        limiter.check_ip_limit(&ip2).unwrap();
        limiter.check_tenant_limit("tenant1").unwrap();
        limiter.check_tenant_limit("tenant2").unwrap();
        assert_eq!(limiter.ip_limiters.len(), 2);
        assert_eq!(limiter.tenant_limiters.len(), 2);
        // Freshly used buckets are not idle, so cleanup must keep them.
        limiter.cleanup();
        assert_eq!(limiter.ip_limiters.len(), 2);
        assert_eq!(limiter.tenant_limiters.len(), 2);
    }

    #[test]
    fn test_peek_ip_limit_does_not_consume_token() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_ip_requests(10)
            .per_ip_burst(10)
            .build();
        let limiter = RateLimiter::new(config);
        let ip: std::net::IpAddr = "10.0.0.1".parse().unwrap();
        let info = limiter.check_ip_limit(&ip).unwrap();
        let remaining_after_check = info.remaining;
        let peek_info = limiter.peek_ip_limit(&ip).expect("bucket should exist");
        assert_eq!(peek_info.remaining, remaining_after_check);
        let peek_info2 = limiter.peek_ip_limit(&ip).expect("bucket should exist");
        assert_eq!(peek_info2.remaining, remaining_after_check);
    }

    #[test]
    fn test_peek_ip_limit_returns_none_for_unknown_ip() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .per_ip_requests(10)
            .per_ip_burst(10)
            .build();
        let limiter = RateLimiter::new(config);
        let ip: std::net::IpAddr = "10.0.0.99".parse().unwrap();
        assert!(limiter.peek_ip_limit(&ip).is_none());
    }
}