Skip to main content

bext_waf/
ddos.rs

1//! DDoS mitigation — connection, body size, and header abuse guards.
2//!
3//! Enforces per-IP concurrent connection limits, global new-connection rate
//! throttling, request body size caps, and header count/size limits. The
//! per-IP tracking table is bounded at 100 000 entries to prevent the guard
//! itself from becoming an OOM vector.
7
8use std::net::IpAddr;
9use std::time::Instant;
10
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13
14use crate::{WafDecision, WafRequest};
15
/// Upper bound on the number of distinct client IPs held in the per-IP
/// connection table, so the guard's own memory use stays bounded.
const MAX_TRACKED_IPS: usize = 100 * 1_000;
17
18/// Configuration for DDoS mitigation.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct DdosConfig {
21    /// Maximum concurrent connections per IP.
22    #[serde(default = "default_max_connections")]
23    pub max_connections_per_ip: u32,
24    /// Maximum new connections per second (global).
25    #[serde(default = "default_max_new_connections")]
26    pub max_new_connections_per_second: u32,
27    /// Maximum request body size in bytes.
28    #[serde(default = "default_max_body_size")]
29    pub max_request_body_size: usize,
30    /// Slowloris timeout in milliseconds (not enforced at WAF layer, but exposed for server config).
31    #[serde(default = "default_slowloris_timeout")]
32    pub slowloris_timeout_ms: u64,
33    /// Maximum number of headers per request.
34    #[serde(default = "default_header_limit")]
35    pub header_count_limit: usize,
36    /// Maximum total header size in bytes.
37    #[serde(default = "default_header_size_limit")]
38    pub header_size_limit: usize,
39    /// Connection rate threshold to start auto-throttling (percentage of max, 0-100).
40    #[serde(default = "default_throttle_threshold_pct")]
41    pub throttle_threshold_pct: u32,
42}
43
/// Default for `max_connections_per_ip`: 100 concurrent connections.
fn default_max_connections() -> u32 {
    100_u32
}
/// Default for `max_new_connections_per_second`: 500 per second.
fn default_max_new_connections() -> u32 {
    500_u32
}
/// Default for `max_request_body_size`: 10 MiB.
fn default_max_body_size() -> usize {
    10 << 20
}
/// Default for `slowloris_timeout_ms`: 10 seconds.
fn default_slowloris_timeout() -> u64 {
    10 * 1_000
}
/// Default for `header_count_limit`: 100 headers.
fn default_header_limit() -> usize {
    100_usize
}
/// Default for `header_size_limit`: 16 KiB.
fn default_header_size_limit() -> usize {
    16 * 1024
}
/// Default for `throttle_threshold_pct`: 80 percent.
fn default_throttle_threshold_pct() -> u32 {
    80_u32
}
65
66impl Default for DdosConfig {
67    fn default() -> Self {
68        Self {
69            max_connections_per_ip: default_max_connections(),
70            max_new_connections_per_second: default_max_new_connections(),
71            max_request_body_size: default_max_body_size(),
72            slowloris_timeout_ms: default_slowloris_timeout(),
73            header_count_limit: default_header_limit(),
74            header_size_limit: default_header_size_limit(),
75            throttle_threshold_pct: default_throttle_threshold_pct(),
76        }
77    }
78}
79
/// Per-IP tracking entry.
#[derive(Debug, Clone, Copy)]
struct IpConnectionInfo {
    /// Number of connections recorded and not yet released.
    count: u32,
    /// When a connection was last recorded for this IP; used by cleanup.
    last_request: Instant,
}
84
/// Fixed-window counter for the global new-connection rate.
#[derive(Debug, Clone, Copy)]
struct GlobalRateInfo {
    /// Requests counted in the current window.
    count: u32,
    /// When the current window began.
    window_start: Instant,
}
89
90/// DDoS mitigation guard.
91pub struct DdosGuard {
92    config: DdosConfig,
93    connections: DashMap<IpAddr, IpConnectionInfo>,
94    global_rate: parking_lot::Mutex<GlobalRateInfo>,
95}
96
97impl DdosGuard {
98    pub fn new(config: DdosConfig) -> Self {
99        Self {
100            config,
101            connections: DashMap::new(),
102            global_rate: parking_lot::Mutex::new(GlobalRateInfo {
103                count: 0,
104                window_start: Instant::now(),
105            }),
106        }
107    }
108
109    /// Record a new connection from an IP. Call this when a request arrives.
110    pub fn record_connection(&self, ip: IpAddr) {
111        // Guard against unbounded DashMap growth
112        if !self.connections.contains_key(&ip) && self.connections.len() >= MAX_TRACKED_IPS {
113            self.cleanup(60);
114            if self.connections.len() >= MAX_TRACKED_IPS {
115                tracing::warn!(
116                    ip = %ip,
117                    tracked = self.connections.len(),
118                    "DDoS guard: too many tracked IPs, skipping tracking for this IP"
119                );
120                return;
121            }
122        }
123
124        self.connections
125            .entry(ip)
126            .and_modify(|info| {
127                info.count += 1;
128                info.last_request = Instant::now();
129            })
130            .or_insert(IpConnectionInfo {
131                count: 1,
132                last_request: Instant::now(),
133            });
134    }
135
136    /// Release a connection from an IP. Call this when a request completes.
137    pub fn release_connection(&self, ip: IpAddr) {
138        if let Some(mut info) = self.connections.get_mut(&ip) {
139            info.count = info.count.saturating_sub(1);
140            if info.count == 0 {
141                drop(info);
142                self.connections.remove(&ip);
143            }
144        }
145    }
146
147    /// Check a request against DDoS limits.
148    pub fn check(&self, req: &WafRequest) -> Option<WafDecision> {
149        // 1. Check per-IP connection limit.
150        if let Some(info) = self.connections.get(&req.client_ip) {
151            if info.count >= self.config.max_connections_per_ip {
152                return Some(WafDecision::Block {
153                    status: 429,
154                    reason: format!(
155                        "too many concurrent connections from {} ({}/{})",
156                        req.client_ip, info.count, self.config.max_connections_per_ip
157                    ),
158                    rule: "ddos_connection_limit".into(),
159                });
160            }
161        }
162
163        // 2. Check global new-connection rate.
164        {
165            let mut rate = self.global_rate.lock();
166            let now = Instant::now();
167            let elapsed = now.duration_since(rate.window_start);
168            if elapsed.as_secs() >= 1 {
169                // Reset window.
170                rate.count = 1;
171                rate.window_start = now;
172            } else {
173                rate.count += 1;
174                if rate.count > self.config.max_new_connections_per_second {
175                    // Auto-throttle: check if we're above threshold.
176                    let pct = (rate.count * 100) / self.config.max_new_connections_per_second;
177                    if pct >= self.config.throttle_threshold_pct {
178                        return Some(WafDecision::RateLimit { retry_after: 1 });
179                    }
180                }
181            }
182        }
183
184        // 3a. Check Content-Length header against body size limit.
185        if let Some(content_length) = req
186            .headers
187            .iter()
188            .find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
189            .and_then(|(_, v)| v.trim().parse::<usize>().ok())
190        {
191            if content_length > self.config.max_request_body_size {
192                return Some(WafDecision::Block {
193                    status: 413,
194                    reason: format!(
195                        "request body too large ({} bytes via Content-Length, max {})",
196                        content_length, self.config.max_request_body_size
197                    ),
198                    rule: "ddos_body_size".into(),
199                });
200            }
201        }
202
203        // 3b. Check actual body size.
204        if let Some(ref body) = req.body {
205            if body.len() > self.config.max_request_body_size {
206                return Some(WafDecision::Block {
207                    status: 413,
208                    reason: format!(
209                        "request body too large ({} bytes, max {})",
210                        body.len(),
211                        self.config.max_request_body_size
212                    ),
213                    rule: "ddos_body_size".into(),
214                });
215            }
216        }
217
218        // 4. Check header count.
219        if req.headers.len() > self.config.header_count_limit {
220            return Some(WafDecision::Block {
221                status: 431,
222                reason: format!(
223                    "too many headers ({}, max {})",
224                    req.headers.len(),
225                    self.config.header_count_limit
226                ),
227                rule: "ddos_header_count".into(),
228            });
229        }
230
231        // 5. Check total header size.
232        let total_header_size: usize = req
233            .headers
234            .iter()
235            .map(|(k, v)| k.len() + v.len() + 4) // +4 for ": " and "\r\n"
236            .sum();
237        if total_header_size > self.config.header_size_limit {
238            return Some(WafDecision::Block {
239                status: 431,
240                reason: format!(
241                    "headers too large ({total_header_size} bytes, max {})",
242                    self.config.header_size_limit
243                ),
244                rule: "ddos_header_size".into(),
245            });
246        }
247
248        None
249    }
250
251    /// Clean up stale connection entries (connections older than `max_age_secs`).
252    pub fn cleanup(&self, max_age_secs: u64) {
253        let now = Instant::now();
254        self.connections
255            .retain(|_, info| now.duration_since(info.last_request).as_secs() < max_age_secs);
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn make_req(ip: &str) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: "/".into(),
            query: None,
            headers: HashMap::new(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    fn make_req_with_body(ip: &str, body: &str) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "POST".into(),
            path: "/api/data".into(),
            query: None,
            headers: HashMap::new(),
            body: Some(body.into()),
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    fn make_req_with_headers(ip: &str, headers: Vec<(&str, &str)>) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: "/".into(),
            query: None,
            headers: headers
                .into_iter()
                .map(|(k, v)| (k.into(), v.into()))
                .collect(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    #[test]
    fn clean_request_passes() {
        let guard = DdosGuard::new(DdosConfig::default());
        let req = make_req("10.0.0.1");
        assert!(guard.check(&req).is_none());
    }

    #[test]
    fn per_ip_connection_limit() {
        let config = DdosConfig {
            max_connections_per_ip: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // Record 3 connections.
        for _ in 0..3 {
            guard.record_connection(ip);
        }

        let req = make_req("10.0.0.1");
        let decision = guard.check(&req);
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 429, .. })
        ));
    }

    #[test]
    fn per_ip_limit_released() {
        let config = DdosConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        guard.record_connection(ip);
        guard.record_connection(ip);
        // At limit.
        assert!(guard.check(&make_req("10.0.0.1")).is_some());

        // Release one.
        guard.release_connection(ip);
        assert!(guard.check(&make_req("10.0.0.1")).is_none());
    }

    #[test]
    fn release_to_zero_removes_entry() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        guard.record_connection(ip);
        assert!(guard.connections.contains_key(&ip));
        guard.release_connection(ip);
        assert!(!guard.connections.contains_key(&ip));
    }

    #[test]
    fn release_untracked_ip_is_noop() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.9".parse().unwrap();

        guard.release_connection(ip);
        assert!(!guard.connections.contains_key(&ip));
        assert!(guard.check(&make_req("10.0.0.9")).is_none());
    }

    #[test]
    fn body_size_limit() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        // Small body: OK.
        assert!(guard
            .check(&make_req_with_body("10.0.0.1", "short"))
            .is_none());

        // Large body: blocked.
        let big = "x".repeat(200);
        let decision = guard.check(&make_req_with_body("10.0.0.1", &big));
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 413, .. })
        ));
    }

    #[test]
    fn body_exactly_at_limit_passes() {
        let config = DdosConfig {
            max_request_body_size: 5,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        assert!(guard
            .check(&make_req_with_body("10.0.0.1", "12345"))
            .is_none());
    }

    #[test]
    fn declared_content_length_over_limit_blocked() {
        let config = DdosConfig {
            max_request_body_size: 100,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        // Header lookup is case-insensitive.
        let req = make_req_with_headers("10.0.0.1", vec![("content-LENGTH", "1000")]);
        assert!(matches!(
            guard.check(&req),
            Some(WafDecision::Block { status: 413, .. })
        ));

        let ok = make_req_with_headers("10.0.0.1", vec![("Content-Length", "50")]);
        assert!(guard.check(&ok).is_none());
    }

    #[test]
    fn header_count_limit() {
        let config = DdosConfig {
            header_count_limit: 3,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let headers = vec![
            ("H0", "v0"),
            ("H1", "v1"),
            ("H2", "v2"),
            ("H3", "v3"),
            ("H4", "v4"),
        ];
        let req = make_req_with_headers("10.0.0.1", headers);
        let decision = guard.check(&req);
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 431, .. })
        ));
    }

    #[test]
    fn header_size_limit() {
        let config = DdosConfig {
            header_size_limit: 50,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        // Create headers that exceed 50 bytes total.
        let big_value = "x".repeat(60);
        let headers = vec![("X-Big", big_value.as_str())];
        let req = make_req_with_headers("10.0.0.1", headers);
        let decision = guard.check(&req);
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 431, .. })
        ));
    }

    #[test]
    fn different_ips_independent() {
        let config = DdosConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);

        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip1);
        guard.record_connection(ip1);
        // ip1 at limit.
        assert!(guard.check(&make_req("10.0.0.1")).is_some());
        // ip2 is fine.
        assert!(guard.check(&make_req("10.0.0.2")).is_none());
    }

    #[test]
    fn cleanup_removes_stale() {
        let guard = DdosGuard::new(DdosConfig::default());
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        guard.record_connection(ip);

        assert!(guard.connections.contains_key(&ip));
        // Cleanup with 0 age removes everything.
        guard.cleanup(0);
        assert!(!guard.connections.contains_key(&ip));
    }

    #[test]
    fn global_rate_within_limit() {
        let config = DdosConfig {
            max_new_connections_per_second: 1000,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let req = make_req("10.0.0.1");
        // Single request should be fine.
        assert!(guard.check(&req).is_none());
    }

    #[test]
    fn global_rate_limit_exceeded() {
        let config = DdosConfig {
            max_new_connections_per_second: 2,
            ..Default::default()
        };
        let guard = DdosGuard::new(config);
        let req = make_req("10.0.0.1");

        // The first two requests in the window fit under the maximum.
        assert!(guard.check(&req).is_none());
        assert!(guard.check(&req).is_none());
        // The third exceeds it (150% >= default 80% threshold).
        assert!(matches!(
            guard.check(&req),
            Some(WafDecision::RateLimit { .. })
        ));
    }
}