use std::net::IpAddr;
use std::time::Instant;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use crate::{WafDecision, WafRequest};
/// Upper bound on the number of distinct client IPs kept in the guard's
/// connection table; `DdosGuard::record_connection` stops tracking new IPs
/// once this many are present (after attempting a cleanup).
const MAX_TRACKED_IPS: usize = 100 * 1_000;
/// Tunable limits for [`DdosGuard`].
///
/// Every field carries a `#[serde(default = ...)]`, so a partially specified
/// config deserializes with the missing fields at the same values
/// `DdosConfig::default()` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DdosConfig {
/// Per-client connection cap: `DdosGuard::check` blocks with HTTP 429 once the
/// client's tracked connection count reaches (is >=) this value. Default: 100.
#[serde(default = "default_max_connections")]
pub max_connections_per_ip: u32,
/// Global cap on `DdosGuard::check` calls per one-second window, shared by all
/// clients. Despite the name, every request that passes the per-IP check counts,
/// not only new connections. Exceeding it yields `WafDecision::RateLimit`
/// (subject to `throttle_threshold_pct`). Default: 500.
#[serde(default = "default_max_new_connections")]
pub max_new_connections_per_second: u32,
/// Maximum request body size in bytes, enforced against both the
/// `Content-Length` header and the actual body (HTTP 413). Default: 10 MiB.
#[serde(default = "default_max_body_size")]
pub max_request_body_size: usize,
/// Slowloris timeout in milliseconds. Default: 10 000.
/// NOTE(review): not read by `DdosGuard` in this file; presumably enforced by
/// the connection layer — confirm.
#[serde(default = "default_slowloris_timeout")]
pub slowloris_timeout_ms: u64,
/// Maximum number of request headers before `check` blocks with HTTP 431.
/// Default: 100.
#[serde(default = "default_header_limit")]
pub header_count_limit: usize,
/// Maximum total header size in bytes, counted as name + value + 4 per header;
/// exceeding it blocks with HTTP 431. Default: 16 384.
#[serde(default = "default_header_size_limit")]
pub header_size_limit: usize,
/// Consulted only after `max_new_connections_per_second` is exceeded: the
/// request is rate limited when `count * 100 / max_new_connections_per_second`
/// is at least this value. Because that ratio is then always above 100, values
/// of 100 or less add no extra condition. Default: 80.
#[serde(default = "default_throttle_threshold_pct")]
pub throttle_threshold_pct: u32,
}
/// Serde/`Default` fallback for `DdosConfig::max_connections_per_ip`.
fn default_max_connections() -> u32 {
    100
}

/// Serde/`Default` fallback for `DdosConfig::max_new_connections_per_second`.
fn default_max_new_connections() -> u32 {
    500
}

/// Serde/`Default` fallback for `DdosConfig::max_request_body_size` (10 MiB).
fn default_max_body_size() -> usize {
    10 << 20
}

/// Serde/`Default` fallback for `DdosConfig::slowloris_timeout_ms` (10 seconds).
fn default_slowloris_timeout() -> u64 {
    10 * 1_000
}

/// Serde/`Default` fallback for `DdosConfig::header_count_limit`.
fn default_header_limit() -> usize {
    100
}

/// Serde/`Default` fallback for `DdosConfig::header_size_limit` (16 KiB).
fn default_header_size_limit() -> usize {
    16 * 1024
}

/// Serde/`Default` fallback for `DdosConfig::throttle_threshold_pct`.
fn default_throttle_threshold_pct() -> u32 {
    80
}
impl Default for DdosConfig {
    /// Builds the configuration in which every field takes its `default_*`
    /// helper value — the same helpers serde uses for fields missing on input.
    fn default() -> Self {
        let max_connections_per_ip = default_max_connections();
        let max_new_connections_per_second = default_max_new_connections();
        let max_request_body_size = default_max_body_size();
        let slowloris_timeout_ms = default_slowloris_timeout();
        let header_count_limit = default_header_limit();
        let header_size_limit = default_header_size_limit();
        let throttle_threshold_pct = default_throttle_threshold_pct();

        Self {
            max_connections_per_ip,
            max_new_connections_per_second,
            max_request_body_size,
            slowloris_timeout_ms,
            header_count_limit,
            header_size_limit,
            throttle_threshold_pct,
        }
    }
}
/// Per-client bookkeeping stored in `DdosGuard`'s connection table.
struct IpConnectionInfo {
/// Connections recorded via `record_connection` and not yet released via
/// `release_connection`.
count: u32,
/// When `record_connection` last touched this entry; `cleanup` expires
/// entries based on it.
last_request: Instant,
}
/// Fixed one-second window counter shared by all clients for the global rate
/// check in `DdosGuard::check`.
struct GlobalRateInfo {
/// Number of requests counted in the current window.
count: u32,
/// Start of the current window; reset once at least one whole second has
/// elapsed since it.
window_start: Instant,
}
/// Lightweight DDoS heuristics: per-IP concurrent connection limits, a global
/// per-second request cap, and request body / header size limits.
///
/// All methods take `&self`; mutable state lives in a `DashMap` and a
/// `parking_lot::Mutex`.
pub struct DdosGuard {
/// Limits applied by `check`.
config: DdosConfig,
/// Per-client open-connection counts; `record_connection` stops adding new
/// IPs once `MAX_TRACKED_IPS` entries are present.
connections: DashMap<IpAddr, IpConnectionInfo>,
/// Global one-second window counter consulted by `check`.
global_rate: parking_lot::Mutex<GlobalRateInfo>,
}
impl DdosGuard {
pub fn new(config: DdosConfig) -> Self {
Self {
config,
connections: DashMap::new(),
global_rate: parking_lot::Mutex::new(GlobalRateInfo {
count: 0,
window_start: Instant::now(),
}),
}
}
pub fn record_connection(&self, ip: IpAddr) {
if !self.connections.contains_key(&ip) && self.connections.len() >= MAX_TRACKED_IPS {
self.cleanup(60);
if self.connections.len() >= MAX_TRACKED_IPS {
tracing::warn!(
ip = %ip,
tracked = self.connections.len(),
"DDoS guard: too many tracked IPs, skipping tracking for this IP"
);
return;
}
}
self.connections
.entry(ip)
.and_modify(|info| {
info.count += 1;
info.last_request = Instant::now();
})
.or_insert(IpConnectionInfo {
count: 1,
last_request: Instant::now(),
});
}
pub fn release_connection(&self, ip: IpAddr) {
if let Some(mut info) = self.connections.get_mut(&ip) {
info.count = info.count.saturating_sub(1);
if info.count == 0 {
drop(info);
self.connections.remove(&ip);
}
}
}
pub fn check(&self, req: &WafRequest) -> Option<WafDecision> {
if let Some(info) = self.connections.get(&req.client_ip) {
if info.count >= self.config.max_connections_per_ip {
return Some(WafDecision::Block {
status: 429,
reason: format!(
"too many concurrent connections from {} ({}/{})",
req.client_ip, info.count, self.config.max_connections_per_ip
),
rule: "ddos_connection_limit".into(),
});
}
}
{
let mut rate = self.global_rate.lock();
let now = Instant::now();
let elapsed = now.duration_since(rate.window_start);
if elapsed.as_secs() >= 1 {
rate.count = 1;
rate.window_start = now;
} else {
rate.count += 1;
if rate.count > self.config.max_new_connections_per_second {
let pct = (rate.count * 100) / self.config.max_new_connections_per_second;
if pct >= self.config.throttle_threshold_pct {
return Some(WafDecision::RateLimit { retry_after: 1 });
}
}
}
}
if let Some(content_length) = req
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
.and_then(|(_, v)| v.trim().parse::<usize>().ok())
{
if content_length > self.config.max_request_body_size {
return Some(WafDecision::Block {
status: 413,
reason: format!(
"request body too large ({} bytes via Content-Length, max {})",
content_length, self.config.max_request_body_size
),
rule: "ddos_body_size".into(),
});
}
}
if let Some(ref body) = req.body {
if body.len() > self.config.max_request_body_size {
return Some(WafDecision::Block {
status: 413,
reason: format!(
"request body too large ({} bytes, max {})",
body.len(),
self.config.max_request_body_size
),
rule: "ddos_body_size".into(),
});
}
}
if req.headers.len() > self.config.header_count_limit {
return Some(WafDecision::Block {
status: 431,
reason: format!(
"too many headers ({}, max {})",
req.headers.len(),
self.config.header_count_limit
),
rule: "ddos_header_count".into(),
});
}
let total_header_size: usize = req
.headers
.iter()
.map(|(k, v)| k.len() + v.len() + 4) .sum();
if total_header_size > self.config.header_size_limit {
return Some(WafDecision::Block {
status: 431,
reason: format!(
"headers too large ({total_header_size} bytes, max {})",
self.config.header_size_limit
),
rule: "ddos_header_size".into(),
});
}
None
}
pub fn cleanup(&self, max_age_secs: u64) {
let now = Instant::now();
self.connections
.retain(|_, info| now.duration_since(info.last_request).as_secs() < max_age_secs);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Header-less, body-less `GET /` request from `ip`.
    fn make_req(ip: &str) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: "/".into(),
            query: None,
            headers: HashMap::new(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    /// `POST /api/data` request from `ip` carrying `body`.
    fn make_req_with_body(ip: &str, body: &str) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "POST".into(),
            path: "/api/data".into(),
            query: None,
            headers: HashMap::new(),
            body: Some(body.into()),
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    /// `GET /` request from `ip` with the given headers.
    fn make_req_with_headers(ip: &str, headers: Vec<(&str, &str)>) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: "/".into(),
            query: None,
            headers: headers
                .into_iter()
                .map(|(k, v)| (k.into(), v.into()))
                .collect(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    #[test]
    fn clean_request_passes() {
        let guard = DdosGuard::new(DdosConfig::default());
        let req = make_req("10.0.0.1");
        assert!(guard.check(&req).is_none());
    }

    #[test]
    fn per_ip_connection_limit() {
        let config = DdosConfig {
            max_connections_per_ip: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        for _ in 0..3 {
            guard.record_connection(ip);
        }
        let req = make_req("10.0.0.1");
        let decision = guard.check(&req);
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 429, .. })
        ));
    }

    #[test]
    fn per_ip_limit_released() {
        let config = DdosConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip);
        guard.record_connection(ip);
        assert!(guard.check(&make_req("10.0.0.1")).is_some());
        guard.release_connection(ip);
        assert!(guard.check(&make_req("10.0.0.1")).is_none());
    }

    #[test]
    fn release_removes_entry_at_zero() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip);
        guard.record_connection(ip);
        guard.release_connection(ip);
        assert!(guard.connections.contains_key(&ip));
        guard.release_connection(ip);
        assert!(!guard.connections.contains_key(&ip));
        // Releasing an IP that is no longer tracked is a no-op.
        guard.release_connection(ip);
        assert!(!guard.connections.contains_key(&ip));
    }

    #[test]
    fn body_size_limit() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        assert!(guard
            .check(&make_req_with_body("10.0.0.1", "short"))
            .is_none());
        let big = "x".repeat(200);
        let decision = guard.check(&make_req_with_body("10.0.0.1", &big));
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 413, .. })
        ));
    }

    #[test]
    fn content_length_limit() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let small = make_req_with_headers("10.0.0.1", vec![("Content-Length", "50")]);
        assert!(guard.check(&small).is_none());
        // Header name matching is case-insensitive and the value is trimmed.
        let big = make_req_with_headers("10.0.0.1", vec![("content-length", " 200 ")]);
        assert!(matches!(
            guard.check(&big),
            Some(WafDecision::Block { status: 413, .. })
        ));
    }

    #[test]
    fn header_count_limit() {
        let config = DdosConfig {
            header_count_limit: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let headers = vec![
            ("H0", "v0"),
            ("H1", "v1"),
            ("H2", "v2"),
            ("H3", "v3"),
            ("H4", "v4"),
        ];
        let req = make_req_with_headers("10.0.0.1", headers);
        let decision = guard.check(&req);
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 431, .. })
        ));
    }

    #[test]
    fn header_size_limit() {
        let config = DdosConfig {
            header_size_limit: 50,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let big_value = "x".repeat(60);
        let headers = vec![("X-Big", big_value.as_str())];
        let req = make_req_with_headers("10.0.0.1", headers);
        let decision = guard.check(&req);
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 431, .. })
        ));
    }

    #[test]
    fn different_ips_independent() {
        let config = DdosConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip1);
        guard.record_connection(ip1);
        assert!(guard.check(&make_req("10.0.0.1")).is_some());
        assert!(guard.check(&make_req("10.0.0.2")).is_none());
    }

    #[test]
    fn cleanup_removes_stale() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip);
        assert!(guard.connections.contains_key(&ip));
        guard.cleanup(0);
        assert!(!guard.connections.contains_key(&ip));
    }

    #[test]
    fn global_rate_within_limit() {
        let config = DdosConfig {
            max_new_connections_per_second: 1000,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let req = make_req("10.0.0.1");
        assert!(guard.check(&req).is_none());
    }

    #[test]
    fn global_rate_limit_exceeded() {
        let config = DdosConfig {
            max_new_connections_per_second: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let req = make_req("10.0.0.1");
        assert!(guard.check(&req).is_none());
        assert!(guard.check(&req).is_none());
        assert!(matches!(
            guard.check(&req),
            Some(WafDecision::RateLimit { .. })
        ));
    }
}