1use std::net::IpAddr;
9use std::time::Instant;
10
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13
14use crate::{WafDecision, WafRequest};
15
/// Soft cap on how many distinct client IPs `DdosGuard` keeps connection
/// bookkeeping for; `record_connection` stops tracking new IPs once the table
/// is this large (after trying to evict idle entries).
const MAX_TRACKED_IPS: usize = 100 * 1_000;
17
/// Tunable limits enforced by `DdosGuard`.
///
/// Every field carries a serde `default = "…"` attribute, so fields missing
/// from a deserialized config take the value of the named `default_*`
/// function; `DdosConfig::default()` uses the same functions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DdosConfig {
    /// Concurrent connections allowed per client IP. `DdosGuard::check`
    /// blocks with 429 once an IP's recorded connection count reaches this
    /// value. Default: 100.
    #[serde(default = "default_max_connections")]
    pub max_connections_per_ip: u32,
    /// Limit on requests counted in the global (all IPs combined) one-second
    /// window. The counter is advanced by each `DdosGuard::check` call that
    /// gets past the per-IP check, not by `record_connection`. Exceeding it
    /// yields `WafDecision::RateLimit` (see `throttle_threshold_pct`).
    /// Default: 500.
    #[serde(default = "default_max_new_connections")]
    pub max_new_connections_per_second: u32,
    /// Maximum request body size in bytes. This is checked against both the
    /// `Content-Length` header and the actual body length; larger requests
    /// are blocked with 413. Default: 10 MiB.
    #[serde(default = "default_max_body_size")]
    pub max_request_body_size: usize,
    /// Not read anywhere in this module. NOTE(review): presumably consumed by
    /// the server/connection layer as a slow-request timeout in milliseconds
    /// (per the name); confirm against callers. Default: 10_000.
    #[serde(default = "default_slowloris_timeout")]
    pub slowloris_timeout_ms: u64,
    /// Maximum number of request headers; more than this is blocked with 431.
    /// Default: 100.
    #[serde(default = "default_header_limit")]
    pub header_count_limit: usize,
    /// Maximum total header size in bytes, counted as name + value + 4 per
    /// header. Larger totals are blocked with 431. Default: 16_384.
    #[serde(default = "default_header_size_limit")]
    pub header_size_limit: usize,
    /// Percentage of `max_new_connections_per_second` that the window count
    /// must reach before rate limiting. It is only evaluated once the count
    /// already exceeds that limit, and the computed percentage is then at
    /// least 100. Any value of 100 or less (including the default of 80)
    /// therefore adds no extra restriction.
    #[serde(default = "default_throttle_threshold_pct")]
    pub throttle_threshold_pct: u32,
}
43
/// Serde / `Default` fallback for `DdosConfig::max_connections_per_ip`.
fn default_max_connections() -> u32 {
    100
}

/// Serde / `Default` fallback for `DdosConfig::max_new_connections_per_second`.
fn default_max_new_connections() -> u32 {
    500
}

/// Serde / `Default` fallback for `DdosConfig::max_request_body_size`: 10 MiB.
fn default_max_body_size() -> usize {
    10 << 20
}

/// Serde / `Default` fallback for `DdosConfig::slowloris_timeout_ms`: 10_000
/// (10 s, if the value is milliseconds as the field name suggests).
fn default_slowloris_timeout() -> u64 {
    10 * 1_000
}

/// Serde / `Default` fallback for `DdosConfig::header_count_limit`.
fn default_header_limit() -> usize {
    100
}

/// Serde / `Default` fallback for `DdosConfig::header_size_limit`: 16 KiB.
fn default_header_size_limit() -> usize {
    16 << 10
}

/// Serde / `Default` fallback for `DdosConfig::throttle_threshold_pct`.
fn default_throttle_threshold_pct() -> u32 {
    80
}
65
66impl Default for DdosConfig {
67 fn default() -> Self {
68 Self {
69 max_connections_per_ip: default_max_connections(),
70 max_new_connections_per_second: default_max_new_connections(),
71 max_request_body_size: default_max_body_size(),
72 slowloris_timeout_ms: default_slowloris_timeout(),
73 header_count_limit: default_header_limit(),
74 header_size_limit: default_header_size_limit(),
75 throttle_threshold_pct: default_throttle_threshold_pct(),
76 }
77 }
78}
79
/// Per-IP entry in `DdosGuard`'s connection table.
struct IpConnectionInfo {
    /// Time of the most recent `record_connection` for this IP; `cleanup`
    /// evicts entries by this timestamp.
    last_request: Instant,
    /// Connections recorded for this IP minus those released.
    count: u32,
}
84
/// State of the global one-second request-counting window used by
/// `DdosGuard::check`.
struct GlobalRateInfo {
    /// When the current window began.
    window_start: Instant,
    /// Requests counted since `window_start`.
    count: u32,
}
89
90pub struct DdosGuard {
92 config: DdosConfig,
93 connections: DashMap<IpAddr, IpConnectionInfo>,
94 global_rate: parking_lot::Mutex<GlobalRateInfo>,
95}
96
97impl DdosGuard {
98 pub fn new(config: DdosConfig) -> Self {
99 Self {
100 config,
101 connections: DashMap::new(),
102 global_rate: parking_lot::Mutex::new(GlobalRateInfo {
103 count: 0,
104 window_start: Instant::now(),
105 }),
106 }
107 }
108
109 pub fn record_connection(&self, ip: IpAddr) {
111 if !self.connections.contains_key(&ip) && self.connections.len() >= MAX_TRACKED_IPS {
113 self.cleanup(60);
114 if self.connections.len() >= MAX_TRACKED_IPS {
115 tracing::warn!(
116 ip = %ip,
117 tracked = self.connections.len(),
118 "DDoS guard: too many tracked IPs, skipping tracking for this IP"
119 );
120 return;
121 }
122 }
123
124 self.connections
125 .entry(ip)
126 .and_modify(|info| {
127 info.count += 1;
128 info.last_request = Instant::now();
129 })
130 .or_insert(IpConnectionInfo {
131 count: 1,
132 last_request: Instant::now(),
133 });
134 }
135
136 pub fn release_connection(&self, ip: IpAddr) {
138 if let Some(mut info) = self.connections.get_mut(&ip) {
139 info.count = info.count.saturating_sub(1);
140 if info.count == 0 {
141 drop(info);
142 self.connections.remove(&ip);
143 }
144 }
145 }
146
147 pub fn check(&self, req: &WafRequest) -> Option<WafDecision> {
149 if let Some(info) = self.connections.get(&req.client_ip) {
151 if info.count >= self.config.max_connections_per_ip {
152 return Some(WafDecision::Block {
153 status: 429,
154 reason: format!(
155 "too many concurrent connections from {} ({}/{})",
156 req.client_ip, info.count, self.config.max_connections_per_ip
157 ),
158 rule: "ddos_connection_limit".into(),
159 });
160 }
161 }
162
163 {
165 let mut rate = self.global_rate.lock();
166 let now = Instant::now();
167 let elapsed = now.duration_since(rate.window_start);
168 if elapsed.as_secs() >= 1 {
169 rate.count = 1;
171 rate.window_start = now;
172 } else {
173 rate.count += 1;
174 if rate.count > self.config.max_new_connections_per_second {
175 let pct = (rate.count * 100) / self.config.max_new_connections_per_second;
177 if pct >= self.config.throttle_threshold_pct {
178 return Some(WafDecision::RateLimit { retry_after: 1 });
179 }
180 }
181 }
182 }
183
184 if let Some(content_length) = req
186 .headers
187 .iter()
188 .find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
189 .and_then(|(_, v)| v.trim().parse::<usize>().ok())
190 {
191 if content_length > self.config.max_request_body_size {
192 return Some(WafDecision::Block {
193 status: 413,
194 reason: format!(
195 "request body too large ({} bytes via Content-Length, max {})",
196 content_length, self.config.max_request_body_size
197 ),
198 rule: "ddos_body_size".into(),
199 });
200 }
201 }
202
203 if let Some(ref body) = req.body {
205 if body.len() > self.config.max_request_body_size {
206 return Some(WafDecision::Block {
207 status: 413,
208 reason: format!(
209 "request body too large ({} bytes, max {})",
210 body.len(),
211 self.config.max_request_body_size
212 ),
213 rule: "ddos_body_size".into(),
214 });
215 }
216 }
217
218 if req.headers.len() > self.config.header_count_limit {
220 return Some(WafDecision::Block {
221 status: 431,
222 reason: format!(
223 "too many headers ({}, max {})",
224 req.headers.len(),
225 self.config.header_count_limit
226 ),
227 rule: "ddos_header_count".into(),
228 });
229 }
230
231 let total_header_size: usize = req
233 .headers
234 .iter()
235 .map(|(k, v)| k.len() + v.len() + 4) .sum();
237 if total_header_size > self.config.header_size_limit {
238 return Some(WafDecision::Block {
239 status: 431,
240 reason: format!(
241 "headers too large ({total_header_size} bytes, max {})",
242 self.config.header_size_limit
243 ),
244 rule: "ddos_header_size".into(),
245 });
246 }
247
248 None
249 }
250
251 pub fn cleanup(&self, max_age_secs: u64) {
253 let now = Instant::now();
254 self.connections
255 .retain(|_, info| now.duration_since(info.last_request).as_secs() < max_age_secs);
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// GET `/` from `ip` carrying exactly `headers` and no body.
    fn make_req_with_headers(ip: &str, headers: Vec<(&str, &str)>) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: "/".into(),
            query: None,
            headers: headers
                .into_iter()
                .map(|(k, v)| (k.into(), v.into()))
                .collect(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    /// GET `/` from `ip` with no headers and no body.
    fn make_req(ip: &str) -> WafRequest {
        make_req_with_headers(ip, Vec::new())
    }

    /// POST `/api/data` from `ip` with `body` and no headers.
    fn make_req_with_body(ip: &str, body: &str) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "POST".into(),
            path: "/api/data".into(),
            query: None,
            headers: HashMap::new(),
            body: Some(body.into()),
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    #[test]
    fn clean_request_passes() {
        let guard = DdosGuard::new(DdosConfig::default());
        let req = make_req("10.0.0.1");
        assert!(guard.check(&req).is_none());
    }

    #[test]
    fn per_ip_connection_limit() {
        let config = DdosConfig {
            max_connections_per_ip: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        for _ in 0..3 {
            guard.record_connection(ip);
        }

        let decision = guard.check(&make_req("10.0.0.1"));
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 429, .. })
        ));
    }

    #[test]
    fn per_ip_limit_released() {
        let config = DdosConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        guard.record_connection(ip);
        guard.record_connection(ip);
        assert!(guard.check(&make_req("10.0.0.1")).is_some());

        guard.release_connection(ip);
        assert!(guard.check(&make_req("10.0.0.1")).is_none());
    }

    #[test]
    fn releasing_last_connection_removes_entry() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip);
        guard.record_connection(ip);

        guard.release_connection(ip);
        assert!(guard.connections.contains_key(&ip));

        guard.release_connection(ip);
        assert!(!guard.connections.contains_key(&ip));
    }

    #[test]
    fn releasing_untracked_ip_is_noop() {
        let guard = DdosGuard::new(DdosConfig::default());
        guard.release_connection("10.0.0.9".parse().unwrap());
        assert!(guard.connections.is_empty());
    }

    #[test]
    fn body_size_limit() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        assert!(guard
            .check(&make_req_with_body("10.0.0.1", "short"))
            .is_none());

        let big = "x".repeat(200);
        let decision = guard.check(&make_req_with_body("10.0.0.1", &big));
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 413, .. })
        ));
    }

    #[test]
    fn content_length_over_limit_blocked() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let req = make_req_with_headers("10.0.0.1", vec![("Content-Length", "200")]);
        assert!(matches!(
            guard.check(&req),
            Some(WafDecision::Block { status: 413, .. })
        ));
    }

    #[test]
    fn content_length_is_case_insensitive_and_trimmed() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let small = make_req_with_headers("10.0.0.1", vec![("content-length", " 50 ")]);
        assert!(guard.check(&small).is_none());

        let big = make_req_with_headers("10.0.0.1", vec![("CONTENT-LENGTH", " 101 ")]);
        assert!(matches!(
            guard.check(&big),
            Some(WafDecision::Block { status: 413, .. })
        ));
    }

    #[test]
    fn header_count_limit() {
        let config = DdosConfig {
            header_count_limit: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let headers = vec![
            ("H0", "v0"),
            ("H1", "v1"),
            ("H2", "v2"),
            ("H3", "v3"),
            ("H4", "v4"),
        ];
        let decision = guard.check(&make_req_with_headers("10.0.0.1", headers));
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 431, .. })
        ));
    }

    #[test]
    fn header_count_at_limit_passes() {
        let config = DdosConfig {
            header_count_limit: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let headers = vec![("H0", "v0"), ("H1", "v1"), ("H2", "v2")];
        assert!(guard
            .check(&make_req_with_headers("10.0.0.1", headers))
            .is_none());
    }

    #[test]
    fn header_size_limit() {
        let config = DdosConfig {
            header_size_limit: 50,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let big_value = "x".repeat(60);
        let headers = vec![("X-Big", big_value.as_str())];
        let decision = guard.check(&make_req_with_headers("10.0.0.1", headers));
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 431, .. })
        ));
    }

    #[test]
    fn different_ips_independent() {
        let config = DdosConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip1);
        guard.record_connection(ip1);
        assert!(guard.check(&make_req("10.0.0.1")).is_some());
        assert!(guard.check(&make_req("10.0.0.2")).is_none());
    }

    #[test]
    fn cleanup_removes_stale() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip);

        assert!(guard.connections.contains_key(&ip));
        guard.cleanup(0);
        assert!(!guard.connections.contains_key(&ip));
    }

    #[test]
    fn global_rate_within_limit() {
        let config = DdosConfig {
            max_new_connections_per_second: 1000,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        assert!(guard.check(&make_req("10.0.0.1")).is_none());
    }

    #[test]
    fn global_rate_limit_exceeded() {
        let config = DdosConfig {
            max_new_connections_per_second: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        // All three checks land in the first one-second window.
        assert!(guard.check(&make_req("10.0.0.1")).is_none());
        assert!(guard.check(&make_req("10.0.0.2")).is_none());
        assert!(matches!(
            guard.check(&make_req("10.0.0.3")),
            Some(WafDecision::RateLimit { .. })
        ));
    }
}