use crate::config::{NetworkConfig, RateLimitConfig, RateLimitAlgorithm};
use crate::error::{ServerError, ServerResult};
use crate::auth::TokenClaims;
use axum::{
extract::{Request, State},
http::{header, HeaderValue, StatusCode},
middleware::Next,
response::Response,
};
use dashmap::DashMap;
use serde::Serialize;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tower_http::{
cors::CorsLayer,
timeout::TimeoutLayer,
trace::TraceLayer,
};
use tower_http::classify::{SharedClassifier, ServerErrorsAsFailures};
use tracing::{info, warn, error};
/// Multi-algorithm request rate limiter with built-in DDoS heuristics.
///
/// One state map is kept per supported algorithm, keyed by client id;
/// `algorithm` selects which map `check_rate_limit` consults, and
/// `config` supplies the shared budget (`requests_per_minute`, `burst_size`).
#[derive(Clone)]
pub struct AdvancedRateLimiter {
config: RateLimitConfig,
// Active strategy; the other three storages stay empty in practice.
algorithm: RateLimitAlgorithm,
token_bucket_storage: Arc<DashMap<String, TokenBucketState>>,
sliding_window_storage: Arc<DashMap<String, SlidingWindowState>>,
fixed_window_storage: Arc<DashMap<String, FixedWindowState>>,
leaky_bucket_storage: Arc<DashMap<String, LeakyBucketState>>,
// Shared IP-reputation / global-throughput heuristics.
ddos_protection: Arc<DdosProtection>,
// Atomic counters surfaced via `get_metrics`.
metrics: Arc<RateLimitMetrics>,
}
/// Per-client token-bucket state: tokens refill over time up to `burst_capacity`.
#[derive(Clone, Debug)]
struct TokenBucketState {
// Tokens currently available; one is spent per admitted request.
tokens: u32,
// When the bucket was last credited.
last_refill: Instant,
// Upper bound on `tokens` (initialized from `burst_size`).
burst_capacity: u32,
}
/// Per-client sliding-window state: one timestamp per admitted request.
#[derive(Clone, Debug)]
struct SlidingWindowState {
// Timestamps of requests inside the trailing window; pruned on each check.
request_timestamps: Vec<Instant>,
window_size: Duration,
}
/// Per-client fixed-window state: a counter that resets each window.
#[derive(Clone, Debug)]
struct FixedWindowState {
request_count: u32,
// Start of the current window; rolled forward once `window_duration` elapses.
window_start: Instant,
window_duration: Duration,
}
/// Per-client leaky-bucket state: requests queue up and drain at a fixed rate.
#[derive(Clone, Debug)]
struct LeakyBucketState {
// Requests currently "in the bucket".
queue_size: u32,
last_leak: Instant,
// Capacity; requests beyond it are rejected.
max_queue_size: u32,
// Drain rate in requests per second (derived from requests_per_minute).
leak_rate: u32,
}
/// DDoS heuristics shared by all limiter clones: per-IP reputation records
/// plus a global request counter keyed by request arrival `Instant`.
#[derive(Clone)]
struct DdosProtection {
suspicious_ips: Arc<DashMap<String, SuspiciousIpInfo>>,
// One entry per request `Instant`; summed over a 60-second window.
global_requests: Arc<DashMap<Instant, u32>>,
config: DdosConfig,
}
/// Tracking record for a single source IP (or fallback client key).
#[derive(Clone, Debug)]
struct SuspiciousIpInfo {
requests_per_minute: u32,
failed_requests: u32,
last_activity: Instant,
// 0..=100; starts at 100, decreased on violations, recovered while idle.
reputation_score: u8,
blocked: bool,
// When the current block lifts; `None` when not blocked.
block_expiry: Option<Instant>,
}
/// Resolved DDoS-protection knobs (derived from `RateLimitConfig` in
/// `AdvancedRateLimiter::new` when not set explicitly).
#[derive(Debug, Clone)]
struct DdosConfig {
enabled: bool,
// Requests-per-second ceilings; compared against counts accumulated
// over a 60-second window by the checker.
global_rps_threshold: u32,
ip_rps_threshold: u32,
// Per-minute request count beyond which an IP is auto-blocked.
auto_block_threshold: u32,
block_duration: Duration,
// Reputation points recovered per idle hour.
reputation_decay_rate: u8,
}
/// Lock-free counters updated on every `check_rate_limit` call.
#[derive(Clone, Debug)]
struct RateLimitMetrics {
total_requests: Arc<std::sync::atomic::AtomicU64>,
allowed_requests: Arc<std::sync::atomic::AtomicU64>,
blocked_requests: Arc<std::sync::atomic::AtomicU64>,
// Requests rejected by DDoS protection (subset of rejections, counted separately).
ddos_blocks: Arc<std::sync::atomic::AtomicU64>,
}
/// Serializable point-in-time copy of [`RateLimitMetrics`], returned by
/// `AdvancedRateLimiter::get_metrics`.
#[derive(Debug, Clone, Serialize)]
pub struct RateLimitMetricsSnapshot {
pub total_requests: u64,
pub allowed_requests: u64,
pub blocked_requests: u64,
pub ddos_blocks: u64,
}
impl RateLimitMetrics {
    /// Creates a metrics container with every counter at zero.
    fn new() -> Self {
        let zeroed = || Arc::new(std::sync::atomic::AtomicU64::new(0));
        Self {
            total_requests: zeroed(),
            allowed_requests: zeroed(),
            blocked_requests: zeroed(),
            ddos_blocks: zeroed(),
        }
    }
}
impl DdosProtection {
/// Creates an empty protection state holder for the given configuration.
fn new(config: DdosConfig) -> Self {
    let suspicious_ips = Arc::new(DashMap::new());
    let global_requests = Arc::new(DashMap::new());
    Self {
        suspicious_ips,
        global_requests,
        config,
    }
}
/// Runs the DDoS heuristics for one request from `client_id`.
///
/// Checks, in order: an active per-IP block, the global throughput ceiling,
/// and the per-IP counters/reputation. Returns `ServerError::DdosBlocked`
/// or `ServerError::RateLimit` when the request must be rejected.
async fn check_ddos_protection(&self, client_id: &str) -> ServerResult<()> {
    if !self.config.enabled {
        return Ok(());
    }
    let now = Instant::now();
    let ip = self.extract_ip_from_client_id(client_id);
    // Handle an existing block first. We mutate in place via `get_mut`:
    // the original held a read guard from `get` while calling `insert` on
    // the same map, which can deadlock a DashMap shard.
    if let Some(mut ip_info) = self.suspicious_ips.get_mut(&ip) {
        if ip_info.blocked {
            match ip_info.block_expiry {
                Some(expiry) if now < expiry => {
                    warn!(
                        ip = %ip,
                        expiry = ?expiry,
                        "Blocked IP attempted request"
                    );
                    return Err(ServerError::DdosBlocked);
                }
                _ => {
                    // Block expired (or was left without an expiry): lift it.
                    ip_info.blocked = false;
                    ip_info.block_expiry = None;
                }
            }
        }
    }
    // Global throughput check over the trailing 60-second window.
    self.cleanup_global_requests(now);
    let requests_last_minute: u32 = self.global_requests.iter().map(|entry| *entry.value()).sum();
    // The sum covers a full minute, so scale the per-second threshold to the
    // same window; the original compared the minute total to the raw rps
    // value, which was 60x stricter than configured.
    if requests_last_minute > self.config.global_rps_threshold.saturating_mul(60) {
        warn!(
            requests_last_minute = requests_last_minute,
            rps_threshold = self.config.global_rps_threshold,
            "Global request rate threshold exceeded"
        );
        return Err(ServerError::DdosBlocked);
    }
    let mut ip_info = self.suspicious_ips
        .entry(ip.clone())
        .or_insert_with(|| SuspiciousIpInfo {
            requests_per_minute: 0,
            failed_requests: 0,
            last_activity: now,
            reputation_score: 100,
            blocked: false,
            block_expiry: None,
        });
    // Measure idle time BEFORE touching `last_activity`. The original
    // updated `last_activity` first, which made both the reputation
    // recovery and the per-minute counter reset dead code (elapsed time
    // was always zero).
    let idle = now.duration_since(ip_info.last_activity);
    if idle >= Duration::from_secs(60) {
        ip_info.requests_per_minute = 0;
    }
    let idle_hours = idle.as_secs() / 3600;
    if idle_hours > 0 {
        // Reputation slowly recovers while the IP stays quiet; saturating
        // math avoids the u8 overflow the original could hit.
        let recovered = (idle_hours.min(u8::MAX as u64) as u8)
            .saturating_mul(self.config.reputation_decay_rate);
        ip_info.reputation_score = ip_info.reputation_score.saturating_add(recovered).min(100);
    }
    ip_info.last_activity = now;
    ip_info.requests_per_minute += 1;
    if ip_info.requests_per_minute > self.config.ip_rps_threshold.saturating_mul(60) {
        warn!(
            ip = %ip,
            requests_per_minute = ip_info.requests_per_minute,
            threshold = self.config.ip_rps_threshold.saturating_mul(60),
            "IP request rate threshold exceeded"
        );
        ip_info.reputation_score = ip_info.reputation_score.saturating_sub(20);
        if ip_info.requests_per_minute > self.config.auto_block_threshold {
            ip_info.blocked = true;
            ip_info.block_expiry = Some(now + self.config.block_duration);
            warn!(
                ip = %ip,
                duration = ?self.config.block_duration,
                "IP auto-blocked due to excessive requests"
            );
            return Err(ServerError::DdosBlocked);
        }
    }
    // Low-reputation IPs get a much stricter (half) ceiling.
    if ip_info.reputation_score < 30
        && ip_info.requests_per_minute > self.config.ip_rps_threshold.saturating_mul(30)
    {
        warn!(
            ip = %ip,
            reputation = ip_info.reputation_score,
            "Low reputation IP exceeded strict rate limit"
        );
        return Err(ServerError::RateLimit);
    }
    Ok(())
}
/// Records a rejected request against the caller's IP record, degrading its
/// reputation; blocks the IP after more than 10 accumulated failures.
async fn record_failed_request(&self, client_id: &str) {
    let ip = self.extract_ip_from_client_id(client_id);
    if let Some(mut ip_info) = self.suspicious_ips.get_mut(&ip) {
        ip_info.failed_requests += 1;
        ip_info.reputation_score = ip_info.reputation_score.saturating_sub(5);
        if ip_info.failed_requests > 10 {
            warn!(
                ip = %ip,
                failed_requests = ip_info.failed_requests,
                "IP blocked due to excessive failed requests"
            );
            ip_info.blocked = true;
            ip_info.block_expiry = Some(Instant::now() + self.config.block_duration);
            // Reset so the NEXT block again requires 10 fresh failures.
            // The original kept accumulating, which re-blocked the IP on
            // the very first failure after every block expiry.
            ip_info.failed_requests = 0;
        }
    }
}
/// Evicts IP records idle for over an hour, plus stale global counters.
async fn cleanup(&self) {
    const IDLE_LIMIT: Duration = Duration::from_secs(3600);
    let now = Instant::now();
    self.suspicious_ips
        .retain(|_, info| now.duration_since(info.last_activity) < IDLE_LIMIT);
    self.cleanup_global_requests(now);
}
fn cleanup_global_requests(&self, now: Instant) {
let cutoff = now - Duration::from_secs(60); self.global_requests.retain(|×tamp, _| timestamp > cutoff);
}
/// Maps a middleware client id to the key used by the reputation tables:
/// `ip:<addr>` yields the address, authenticated `user:<id>` ids collapse
/// into a shared bucket, and anything else is used verbatim.
fn extract_ip_from_client_id(&self, client_id: &str) -> String {
    if let Some(addr) = client_id.strip_prefix("ip:") {
        addr.to_string()
    } else if client_id.starts_with("user:") {
        "unknown_user".to_string()
    } else {
        client_id.to_string()
    }
}
}
/// Bumps the global request counter for the current instant. Each distinct
/// `Instant` gets its own entry; the checker sums entries over a 60-second
/// window and `cleanup_global_requests` prunes the old ones.
fn track_global_request(ddos_protection: &Arc<DdosProtection>) {
    *ddos_protection
        .global_requests
        .entry(Instant::now())
        .or_default() += 1;
}
impl AdvancedRateLimiter {
/// Builds a limiter from `config`, deriving DDoS thresholds from the
/// per-minute budget when they are not set explicitly.
pub fn new(config: RateLimitConfig) -> Self {
    let ddos = &config.ddos_protection;
    let ddos_config = DdosConfig {
        enabled: ddos.enabled,
        // Derived defaults are clamped to at least 1: with small budgets
        // the original integer division produced a zero threshold, which
        // made DDoS protection reject every request.
        global_rps_threshold: ddos.global_rps_threshold
            .unwrap_or_else(|| (config.requests_per_minute / 60).max(1)),
        // NOTE(review): the /600 divisor (10x stricter per IP than global)
        // is preserved from the original — confirm it is intentional.
        ip_rps_threshold: ddos.ip_rps_threshold
            .unwrap_or_else(|| (config.requests_per_minute / 600).max(1)),
        auto_block_threshold: ddos.auto_block_threshold
            .unwrap_or_else(|| config.requests_per_minute * 2),
        block_duration: Duration::from_secs(ddos.block_duration_seconds),
        reputation_decay_rate: ddos.reputation_decay_rate,
    };
    Self {
        algorithm: config.algorithm.clone(),
        config,
        token_bucket_storage: Arc::new(DashMap::new()),
        sliding_window_storage: Arc::new(DashMap::new()),
        fixed_window_storage: Arc::new(DashMap::new()),
        leaky_bucket_storage: Arc::new(DashMap::new()),
        ddos_protection: Arc::new(DdosProtection::new(ddos_config)),
        metrics: Arc::new(RateLimitMetrics::new()),
    }
}
pub fn with_algorithm(config: RateLimitConfig, algorithm: RateLimitAlgorithm) -> Self {
let mut limiter = Self::new(config);
limiter.algorithm = algorithm;
limiter
}
/// Entry point for one request: updates metrics, runs DDoS protection,
/// then applies the configured rate-limit algorithm for `client_id`.
pub async fn check_rate_limit(&self, client_id: &str) -> ServerResult<()> {
    use std::sync::atomic::Ordering::Relaxed;
    if !self.config.enabled {
        return Ok(());
    }
    self.metrics.total_requests.fetch_add(1, Relaxed);
    track_global_request(&self.ddos_protection);
    if let Err(e) = self.ddos_protection.check_ddos_protection(client_id).await {
        self.metrics.ddos_blocks.fetch_add(1, Relaxed);
        return Err(e);
    }
    let outcome = match self.algorithm {
        RateLimitAlgorithm::TokenBucket => self.check_token_bucket(client_id).await,
        RateLimitAlgorithm::SlidingWindow => self.check_sliding_window(client_id).await,
        RateLimitAlgorithm::FixedWindow => self.check_fixed_window(client_id).await,
        RateLimitAlgorithm::LeakyBucket => self.check_leaky_bucket(client_id).await,
    };
    match outcome {
        Err(e) => {
            self.metrics.blocked_requests.fetch_add(1, Relaxed);
            // A limited request also degrades the caller's reputation.
            self.ddos_protection.record_failed_request(client_id).await;
            Err(e)
        }
        Ok(()) => {
            self.metrics.allowed_requests.fetch_add(1, Relaxed);
            Ok(())
        }
    }
}
async fn check_token_bucket(&self, client_id: &str) -> ServerResult<()> {
let now = Instant::now();
let mut state = self.token_bucket_storage
.entry(client_id.to_string())
.or_insert_with(|| TokenBucketState {
tokens: self.config.burst_size,
last_refill: now,
burst_capacity: self.config.burst_size,
});
let time_elapsed = now.duration_since(state.last_refill);
let tokens_to_add = (time_elapsed.as_secs() as u32 * self.config.requests_per_minute) / 60;
if tokens_to_add > 0 {
state.tokens = (state.tokens + tokens_to_add).min(state.burst_capacity);
state.last_refill = now;
}
if state.tokens == 0 {
return Err(ServerError::RateLimit);
}
state.tokens -= 1;
Ok(())
}
/// Sliding-window check: keeps per-client request timestamps for the
/// trailing 60 seconds and rejects once the count reaches the budget.
async fn check_sliding_window(&self, client_id: &str) -> ServerResult<()> {
    let now = Instant::now();
    let window_size = Duration::from_secs(60);
    let mut state = self.sliding_window_storage
        .entry(client_id.to_string())
        .or_insert_with(|| SlidingWindowState {
            request_timestamps: Vec::new(),
            window_size,
        });
    // Drop timestamps that have aged out of the window.
    // (This also repairs the mangled closure in the original source.)
    state.request_timestamps
        .retain(|&timestamp| now.duration_since(timestamp) < window_size);
    if state.request_timestamps.len() as u32 >= self.config.requests_per_minute {
        return Err(ServerError::RateLimit);
    }
    state.request_timestamps.push(now);
    Ok(())
}
/// Fixed-window check: counts requests in non-overlapping 60-second windows
/// anchored at each client's first request after a rollover.
async fn check_fixed_window(&self, client_id: &str) -> ServerResult<()> {
    let now = Instant::now();
    let window_duration = Duration::from_secs(60);
    let mut state = self.fixed_window_storage
        .entry(client_id.to_string())
        .or_insert_with(|| FixedWindowState {
            request_count: 0,
            window_start: now,
            window_duration,
        });
    // Roll over to a fresh window once the current one has fully elapsed.
    let window_expired = now.duration_since(state.window_start) >= window_duration;
    if window_expired {
        state.window_start = now;
        state.request_count = 0;
    }
    if state.request_count >= self.config.requests_per_minute {
        return Err(ServerError::RateLimit);
    }
    state.request_count += 1;
    Ok(())
}
/// Leaky-bucket check: queued requests drain at `requests_per_minute` and
/// the queue depth is capped at `burst_size`.
async fn check_leaky_bucket(&self, client_id: &str) -> ServerResult<()> {
    let now = Instant::now();
    let rate = self.config.requests_per_minute as u64;
    let max_queue_size = self.config.burst_size;
    let mut state = self.leaky_bucket_storage
        .entry(client_id.to_string())
        .or_insert_with(|| LeakyBucketState {
            queue_size: 0,
            last_leak: now,
            max_queue_size,
            // Kept for the headers/state struct; draining below no longer
            // relies on this per-second value.
            leak_rate: self.config.requests_per_minute / 60,
        });
    // Drain with millisecond precision. The original precomputed
    // `requests_per_minute / 60` with integer division, so any budget
    // below 60/min leaked at a rate of ZERO and the bucket, once full,
    // never drained again.
    let elapsed_ms = now.duration_since(state.last_leak).as_millis() as u64;
    let leaked = ((elapsed_ms * rate / 60_000) as u32).min(state.queue_size);
    if leaked > 0 {
        state.queue_size -= leaked;
        state.last_leak = now;
    }
    if state.queue_size >= max_queue_size {
        return Err(ServerError::RateLimit);
    }
    state.queue_size += 1;
    Ok(())
}
/// Builds the `(remaining, limit, reset)` header values for `client_id`,
/// or `None` when rate limiting is disabled.
pub async fn get_rate_limit_headers(&self, client_id: &str) -> Option<(HeaderValue, HeaderValue, HeaderValue)> {
    if !self.config.enabled {
        return None;
    }
    // Remaining budget comes from the active algorithm's state; clients
    // with no recorded state report the full budget.
    let remaining = match self.algorithm {
        RateLimitAlgorithm::TokenBucket => self
            .token_bucket_storage
            .get(client_id)
            .map_or(self.config.burst_size, |state| state.tokens),
        RateLimitAlgorithm::SlidingWindow => {
            if let Some(state) = self.sliding_window_storage.get(client_id) {
                let now = Instant::now();
                // Count only timestamps still inside the sliding window.
                // (This also repairs the mangled closure in the original.)
                let in_window = state
                    .request_timestamps
                    .iter()
                    .filter(|&&timestamp| now.duration_since(timestamp) < state.window_size)
                    .count() as u32;
                self.config.requests_per_minute.saturating_sub(in_window)
            } else {
                self.config.requests_per_minute
            }
        }
        RateLimitAlgorithm::FixedWindow => self
            .fixed_window_storage
            .get(client_id)
            .map_or(self.config.requests_per_minute, |state| {
                self.config.requests_per_minute.saturating_sub(state.request_count)
            }),
        RateLimitAlgorithm::LeakyBucket => self
            .leaky_bucket_storage
            .get(client_id)
            .map_or(self.config.burst_size, |state| {
                state.max_queue_size.saturating_sub(state.queue_size)
            }),
    };
    // NOTE(review): the reset time is approximated as "one minute from now"
    // regardless of the actual window boundary — confirm this is acceptable.
    let reset_timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs() + 60;
    Some((
        HeaderValue::from_str(&remaining.to_string()).ok()?,
        HeaderValue::from_str(&self.config.requests_per_minute.to_string()).ok()?,
        HeaderValue::from_str(&reset_timestamp.to_string()).ok()?,
    ))
}
/// Returns a point-in-time copy of the limiter's counters.
pub fn get_metrics(&self) -> RateLimitMetricsSnapshot {
    use std::sync::atomic::Ordering::Relaxed;
    RateLimitMetricsSnapshot {
        total_requests: self.metrics.total_requests.load(Relaxed),
        allowed_requests: self.metrics.allowed_requests.load(Relaxed),
        blocked_requests: self.metrics.blocked_requests.load(Relaxed),
        ddos_blocks: self.metrics.ddos_blocks.load(Relaxed),
    }
}
/// Evicts per-client state idle for over five minutes (sliding-window
/// entries are judged by their newest timestamp), then runs the DDoS
/// table cleanup.
pub async fn cleanup(&self) {
    let now = Instant::now();
    let stale_after = Duration::from_secs(300);
    let is_fresh = |t: Instant| now.duration_since(t) < stale_after;
    self.token_bucket_storage
        .retain(|_, state| is_fresh(state.last_refill));
    self.sliding_window_storage.retain(|_, state| {
        state.request_timestamps.last().map_or(false, |&t| is_fresh(t))
    });
    self.fixed_window_storage
        .retain(|_, state| is_fresh(state.window_start));
    self.leaky_bucket_storage
        .retain(|_, state| is_fresh(state.last_leak));
    self.ddos_protection.cleanup().await;
}
}
/// Axum middleware that enforces rate limiting before the handler runs and
/// decorates successful responses with `X-RateLimit-*` headers.
pub async fn advanced_rate_limit_middleware(
    State(rate_limiter): State<Arc<AdvancedRateLimiter>>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let client_id = extract_client_id(&request);
    if let Err(err) = rate_limiter.check_rate_limit(&client_id).await {
        // Both limiter rejections map to 429; anything else is unexpected.
        return match err {
            ServerError::RateLimit => {
                warn!(
                    client_id = %client_id,
                    "Rate limit exceeded"
                );
                Err(StatusCode::TOO_MANY_REQUESTS)
            }
            ServerError::DdosBlocked => {
                warn!(
                    client_id = %client_id,
                    "Request blocked by DDoS protection"
                );
                Err(StatusCode::TOO_MANY_REQUESTS)
            }
            _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
        };
    }
    let mut response = next.run(request).await;
    if let Some((remaining, limit, reset)) = rate_limiter.get_rate_limit_headers(&client_id).await {
        let headers = response.headers_mut();
        headers.insert("X-RateLimit-Remaining", remaining);
        headers.insert("X-RateLimit-Limit", limit);
        headers.insert("X-RateLimit-Reset", reset);
    }
    Ok(response)
}
/// Thin wrapper delegating to [`advanced_rate_limit_middleware`];
/// presumably kept as a backwards-compatible name for existing router
/// setups — TODO confirm callers before removing.
pub async fn rate_limit_middleware(
State(rate_limiter): State<Arc<AdvancedRateLimiter>>,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
advanced_rate_limit_middleware(State(rate_limiter), request, next).await
}
/// Logs every request with method, URI, status, latency, user agent and
/// (when authenticated) user id, at a level matching the response class:
/// 5xx -> error, 4xx -> warn, everything else -> info.
pub async fn request_logging_middleware(
    request: Request,
    next: Next,
) -> Response {
    let started_at = Instant::now();
    let req_method = request.method().clone();
    let req_uri = request.uri().clone();
    let agent = request
        .headers()
        .get(header::USER_AGENT)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("unknown")
        .to_string();
    // Present only when an auth layer has attached claims upstream.
    let uid = request
        .extensions()
        .get::<TokenClaims>()
        .map(|claims| claims.sub.clone());
    let response = next.run(request).await;
    let elapsed = started_at.elapsed();
    let status = response.status();
    if status.is_server_error() {
        error!(
            method = %req_method,
            uri = %req_uri,
            status = %status,
            duration_ms = elapsed.as_millis(),
            user_agent = %agent,
            user_id = ?uid,
            "Request completed with server error"
        );
    } else if status.is_client_error() {
        warn!(
            method = %req_method,
            uri = %req_uri,
            status = %status,
            duration_ms = elapsed.as_millis(),
            user_agent = %agent,
            user_id = ?uid,
            "Request completed with client error"
        );
    } else {
        info!(
            method = %req_method,
            uri = %req_uri,
            status = %status,
            duration_ms = elapsed.as_millis(),
            user_agent = %agent,
            user_id = ?uid,
            "Request completed successfully"
        );
    }
    response
}
/// Rejects requests whose declared `Content-Length` exceeds the configured
/// maximum body size. Missing or unparseable headers pass through
/// unchecked, exactly as before.
pub async fn request_size_middleware(
    State(network_config): State<NetworkConfig>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let declared_size = request
        .headers()
        .get(header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<usize>().ok());
    if let Some(size) = declared_size {
        if size > network_config.max_body_size {
            warn!(
                request_size = size,
                max_size = network_config.max_body_size,
                "Request payload too large"
            );
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
    }
    Ok(next.run(request).await)
}
/// Resolves a tenant id (token claims first, `X-Tenant-ID` header as
/// fallback) and stashes it in the request extensions for downstream
/// handlers. Requests without a tenant id pass through unchanged.
pub async fn tenant_isolation_middleware(
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let from_claims = request
        .extensions()
        .get::<TokenClaims>()
        .and_then(|claims| claims.tenant_id.clone());
    let tenant_id = from_claims.or_else(|| {
        request
            .headers()
            .get("X-Tenant-ID")
            .and_then(|value| value.to_str().ok())
            .map(str::to_string)
    });
    if let Some(tenant_id) = tenant_id {
        request.extensions_mut().insert(tenant_id);
    }
    Ok(next.run(request).await)
}
/// Adds a standard set of browser security headers to every response.
pub async fn security_headers_middleware(
    request: Request,
    next: Next,
) -> Response {
    const SECURITY_HEADERS: [(&str, &str); 6] = [
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Content-Security-Policy", "default-src 'self'"),
    ];
    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    for (name, value) in SECURITY_HEADERS {
        headers.insert(name, HeaderValue::from_static(value));
    }
    response
}
pub fn create_cors_layer(config: &crate::config::CorsConfig) -> CorsLayer {
use tower_http::cors::CorsLayer;
use http::{HeaderName, Method};
let origins: Vec<_> = config.allowed_origins.iter()
.filter_map(|s| s.as_str().parse().ok())
.collect();
let methods: Vec<Method> = config.allowed_methods.iter()
.filter_map(|s| s.as_str().parse().ok())
.collect();
let headers: Vec<HeaderName> = config.allowed_headers.iter()
.filter_map(|s| s.as_str().parse().ok())
.collect();
let mut cors_layer = CorsLayer::new()
.allow_origin(origins)
.allow_methods(methods)
.allow_headers(headers);
if config.allow_credentials {
cors_layer = cors_layer.allow_credentials(true);
}
cors_layer
}
/// Builds a tower-http [`TimeoutLayer`] with a `timeout_seconds`-second limit.
pub fn create_timeout_layer(timeout_seconds: u64) -> TimeoutLayer {
TimeoutLayer::new(Duration::from_secs(timeout_seconds))
}
/// Builds the default tower-http HTTP tracing layer (server errors are
/// classified as failures).
pub fn create_trace_layer() -> TraceLayer<SharedClassifier<ServerErrorsAsFailures>> {
TraceLayer::new_for_http()
}
/// Derives the rate-limiting key for a request: authenticated requests are
/// keyed per user (`user:<sub>`), anonymous ones per proxy-reported client
/// IP (`ip:<addr>`), falling back to the shared `"unknown"` bucket.
fn extract_client_id(request: &Request) -> String {
    if let Some(claims) = request.extensions().get::<TokenClaims>() {
        return format!("user:{}", claims.sub);
    }
    let headers = request.headers();
    // NOTE(review): both headers are client/proxy supplied and spoofable
    // unless a trusted proxy strips them — confirm deployment topology.
    let forwarded = headers
        .get("X-Forwarded-For")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(|ip| format!("ip:{}", ip.trim()));
    let real_ip = || {
        headers
            .get("X-Real-IP")
            .and_then(|value| value.to_str().ok())
            .map(|ip| format!("ip:{}", ip))
    };
    forwarded
        .or_else(real_ip)
        .unwrap_or_else(|| "unknown".to_string())
}
/// Zero-sized marker type grouping this module's middleware helpers;
/// carries no state.
pub struct MiddlewareStack;
impl MiddlewareStack {
/// Creates the (zero-sized) middleware stack handle.
pub fn new() -> Self {
Self
}
}
// Default simply delegates to `MiddlewareStack::new`.
impl Default for MiddlewareStack {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Method;
    use crate::config::DdosProtectionConfig;

    /// Authenticated requests must be keyed by the token subject.
    #[test]
    fn test_client_id_extraction() {
        let claims = TokenClaims {
            sub: "user123".to_string(),
            username: "testuser".to_string(),
            email: None,
            roles: vec![],
            tenant_id: None,
            iat: 0,
            exp: 0,
            jti: "test".to_string(),
        };
        // `Body::empty()` replaces the original `Empty::new()`, which
        // referred to a type that was never imported (compile error).
        let mut request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        request.extensions_mut().insert(claims);
        let client_id = extract_client_id(&request);
        assert_eq!(client_id, "user:user123");
    }

    /// A 2-token bucket admits two immediate requests and rejects the third.
    #[tokio::test]
    async fn test_rate_limiter() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_minute: 2,
            requests_per_hour: 10,
            burst_size: 2,
            algorithm: RateLimitAlgorithm::TokenBucket,
            ddos_protection: DdosProtectionConfig::default(),
        };
        let rate_limiter = AdvancedRateLimiter::new(config);
        rate_limiter.check_rate_limit("client1").await.unwrap();
        rate_limiter.check_rate_limit("client1").await.unwrap();
        assert!(matches!(
            rate_limiter.check_rate_limit("client1").await,
            Err(ServerError::RateLimit)
        ));
    }

    /// Smoke test: a default CORS config must produce a layer without panicking.
    #[test]
    fn test_cors_layer_creation() {
        let config = crate::config::CorsConfig::default();
        let _cors_layer = create_cors_layer(&config);
    }
}