Skip to main content

bext_waf/rules/
mod.rs

//! Attack-pattern detection engine — orchestrates 16 specialised rule modules
//! plus several inline checks (20+ security checks in total).
//!
//! The [`RuleEngine`] runs its detectors against the request's path, query
//! string, headers, and body.  Detection categories include SQL injection,
//! XSS, path traversal, shell injection, CRLF injection, XXE, SSTI, prototype
//! pollution, NoSQL injection, unsafe deserialisation, SSI, Log4Shell, SSRF,
//! open redirect, GraphQL introspection, LDAP injection, HTTP method override,
//! protocol violations, sensitive path access, scanner detection, and JWT
//! `alg:none` / empty-signature attacks.  Only some categories can be switched
//! off through [`RuleConfig`]; the others always run.  All regex patterns are
//! compiled once via `OnceLock<RegexSet>`, so no pattern is recompiled per
//! request (matching itself still runs on every request).
12
13pub mod crlf;
14pub mod custom;
15pub mod deserialization;
16pub mod log4shell;
17pub mod nosql;
18pub mod protocol;
19pub mod prototype;
20pub mod scanner;
21pub mod sensitive_path;
22pub mod shell;
23pub mod sqli;
24pub mod ssi;
25pub mod ssti;
26pub mod traversal;
27pub mod xss;
28pub mod xxe;
29
30use serde::{Deserialize, Serialize};
31
32use crate::{WafDecision, WafRequest};
33
/// Which built-in rule categories are enabled, plus the log-only switch.
///
/// Category flags missing from a deserialized config default to `true`
/// (`default_true`); a missing `log_only` defaults to `false`. Detectors that
/// have no flag here (SSRF, Log4Shell, XXE, SSTI, prototype pollution, NoSQL,
/// deserialisation, SSI, open redirect, GraphQL introspection, LDAP, JWT) are
/// always evaluated by the rule engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleConfig {
    /// SQL-injection detection on path, query and body, including their
    /// percent-decoded forms.
    #[serde(default = "default_true")]
    pub sql_injection: bool,
    /// Cross-site-scripting detection on path, query and body, including their
    /// percent-decoded forms.
    #[serde(default = "default_true")]
    pub xss: bool,
    /// Path-traversal detection on path and query (raw and percent-decoded; the
    /// path is additionally double-decoded).
    #[serde(default = "default_true")]
    pub path_traversal: bool,
    /// Shell-injection detection on path, query and body, plus the decoded query.
    #[serde(default = "default_true")]
    pub shell_injection: bool,
    /// Request-level checks delegated to `protocol::check_protocol`.
    #[serde(default = "default_true")]
    pub protocol_violation: bool,
    /// Scanner detection based on the `User-Agent`.
    #[serde(default = "default_true")]
    pub scanner_detection: bool,
    /// Sensitive-path detection (path only).
    #[serde(default = "default_true")]
    pub sensitive_path: bool,
    /// CRLF-injection detection on path and query.
    #[serde(default = "default_true")]
    pub crlf_injection: bool,
    /// Reject method-override and URL/host-override headers.
    #[serde(default = "default_true")]
    pub method_override: bool,
    /// If `log_only` is true, matches are logged but not blocked.
    /// Inspection still stops at the first built-in match.
    #[serde(default)]
    pub log_only: bool,
}
59
60fn default_true() -> bool {
61    true
62}
63
64impl Default for RuleConfig {
65    fn default() -> Self {
66        Self {
67            sql_injection: true,
68            xss: true,
69            path_traversal: true,
70            shell_injection: true,
71            protocol_violation: true,
72            scanner_detection: true,
73            sensitive_path: true,
74            crlf_injection: true,
75            method_override: true,
76            log_only: false,
77        }
78    }
79}
80
/// Details of a rule match.
///
/// NOTE(review): nothing in this file constructs or reads this type — the rule
/// engine reports matches through `WafDecision` instead. Confirm whether other
/// modules still use it before relying on it.
#[derive(Debug, Clone)]
pub struct WafRuleMatch {
    /// Name of the rule that fired (presumably a category name such as
    /// `"sql_injection"`; confirm with producers).
    pub rule_name: String,
    /// Description of the pattern that matched.
    pub matched_pattern: String,
    /// The input text the pattern was matched against.
    pub matched_input: String,
}
88
/// The rule engine runs all enabled detectors against a request.
///
/// Built-in categories are toggled by [`RuleConfig`]; user-supplied custom rules
/// are checked before any built-in detector.
pub struct RuleEngine {
    /// Category toggles and the log-only switch.
    config: RuleConfig,
    /// User-defined rules, evaluated first on every inspection.
    custom_rules: custom::CustomRuleSet,
}
94
95impl RuleEngine {
96    pub fn new(config: RuleConfig, custom_rules: Vec<custom::CustomRule>) -> Self {
97        Self {
98            config,
99            custom_rules: custom::CustomRuleSet::new(custom_rules),
100        }
101    }
102
103    /// Inspect a request against all enabled rule categories.
104    /// Returns the first match as a `WafDecision`, or `None` if clean.
105    pub fn inspect(&self, req: &WafRequest) -> Option<WafDecision> {
106        // 1. Custom rules first.
107        if let Some(decision) = self.custom_rules.check(req) {
108            return Some(decision);
109        }
110
111        // Collect raw inputs (zero-copy).
112        let inputs_to_check = self.collect_inputs(req);
113
114        // 2. SQL injection.
115        if self.config.sql_injection {
116            for (label, input) in &inputs_to_check {
117                if let Some(desc) = sqli::check_sqli(input) {
118                    return self.make_decision("sql_injection", &desc, label);
119                }
120            }
121        }
122
123        // 3. XSS.
124        if self.config.xss {
125            for (label, input) in &inputs_to_check {
126                if let Some(desc) = xss::check_xss(input) {
127                    return self.make_decision("xss", &desc, label);
128                }
129            }
130        }
131
132        // 4. Path traversal.
133        if self.config.path_traversal {
134            if let Some(desc) = traversal::check_traversal(&req.path) {
135                return self.make_decision("path_traversal", &desc, "path");
136            }
137            if let Some(ref q) = req.query {
138                if let Some(desc) = traversal::check_traversal(q) {
139                    return self.make_decision("path_traversal", &desc, "query");
140                }
141            }
142        }
143
144        // 5. Shell injection.
145        if self.config.shell_injection {
146            for (label, input) in &inputs_to_check {
147                if let Some(desc) = shell::check_shell(input) {
148                    return self.make_decision("shell_injection", &desc, label);
149                }
150            }
151        }
152
153        // 6. Protocol violation.
154        if self.config.protocol_violation {
155            if let Some(desc) = protocol::check_protocol(req) {
156                return self.make_decision("protocol_violation", &desc, "request");
157            }
158        }
159
160        // 6b. SSRF.
161        if let Some(ref query) = req.query {
162            if let Some(finding) = check_ssrf(query) {
163                return self.make_decision("ssrf", &finding, "query");
164            }
165        }
166
167        // 7. Sensitive path (cheap regex, only on path).
168        if self.config.sensitive_path {
169            if let Some(desc) = sensitive_path::check_sensitive_path(&req.path) {
170                return self.make_decision("sensitive_path", &desc, "path");
171            }
172        }
173
174        // 8. CRLF injection (only if path/query contains %).
175        if self.config.crlf_injection {
176            if req.path.contains('%') || req.path.contains('\r') || req.path.contains('\n') {
177                if let Some(desc) = crlf::check_crlf(&req.path) {
178                    return self.make_decision("crlf_injection", &desc, "path");
179                }
180            }
181            if let Some(ref q) = req.query {
182                if q.contains('%') || q.contains('\r') || q.contains('\n') {
183                    if let Some(desc) = crlf::check_crlf(q) {
184                        return self.make_decision("crlf_injection", &desc, "query");
185                    }
186                }
187            }
188        }
189
190        // 9. Method override headers.
191        if self.config.method_override {
192            if let Some(desc) = check_method_override(req) {
193                return self.make_decision("method_override", &desc, "header");
194            }
195        }
196
197        // 10. Log4Shell — only if input contains "${" (one branch, no alloc).
198        for (label, input) in &inputs_to_check {
199            if input.contains("${") {
200                if let Some(desc) = log4shell::check_log4shell(input) {
201                    return self.make_decision("log4shell", &desc, label);
202                }
203            }
204        }
205        if let Some(ref ua) = req.user_agent {
206            if ua.contains("${") {
207                if let Some(desc) = log4shell::check_log4shell(ua) {
208                    return self.make_decision("log4shell", &desc, "user_agent");
209                }
210            }
211        }
212        for value in req.headers.values() {
213            if value.contains("${") {
214                if let Some(desc) = log4shell::check_log4shell(value) {
215                    return self.make_decision("log4shell", &desc, "header");
216                }
217            }
218        }
219
220        // 11. XXE — only if body contains "<!" (one branch).
221        if let Some(ref body) = req.body {
222            if body.contains("<!") {
223                if let Some(desc) = xxe::check_xxe(body) {
224                    return self.make_decision("xxe", &desc, "body");
225                }
226            }
227        }
228
229        // 12. SSTI — only if input contains "{{" or "${" or "<%" or "__" (cheap).
230        for (label, input) in &inputs_to_check {
231            if input.contains("{{")
232                || input.contains("${")
233                || input.contains("<%")
234                || input.contains("__")
235            {
236                if let Some(desc) = ssti::check_ssti(input) {
237                    return self.make_decision("ssti", &desc, label);
238                }
239            }
240        }
241
242        // 13. Prototype pollution — only if input contains "__proto__" or "constructor" (cheap).
243        for (label, input) in &inputs_to_check {
244            if input.contains("__proto__") || input.contains("constructor") {
245                if let Some(desc) = prototype::check_prototype(input) {
246                    return self.make_decision("prototype_pollution", &desc, label);
247                }
248            }
249        }
250
251        // 14. NoSQL injection — only if input contains "$" in JSON context (cheap).
252        for (label, input) in &inputs_to_check {
253            if input.contains("\"$") {
254                if let Some(desc) = nosql::check_nosql(input) {
255                    return self.make_decision("nosql_injection", &desc, label);
256                }
257            }
258        }
259
260        // 15. Deserialization — only if input contains magic bytes (cheap).
261        for (label, input) in &inputs_to_check {
262            if input.contains("rO0AB")
263                || input.contains("aced")
264                || input.contains("AAEAAAD")
265                || (input.len() > 4
266                    && input.as_bytes()[1] == b':'
267                    && input.as_bytes()[0].is_ascii_uppercase())
268            {
269                if let Some(desc) = deserialization::check_deserialization(input) {
270                    return self.make_decision("deserialization", &desc, label);
271                }
272            }
273        }
274
275        // 16. SSI injection — only if input contains "<!--#" (cheap).
276        for (label, input) in &inputs_to_check {
277            if input.contains("<!--#") {
278                if let Some(desc) = ssi::check_ssi(input) {
279                    return self.make_decision("ssi_injection", &desc, label);
280                }
281            }
282        }
283
284        // 17. Open redirect — check query params for external URLs.
285        if let Some(ref query) = req.query {
286            if let Some(desc) = check_open_redirect(query) {
287                return self.make_decision("open_redirect", &desc, "query");
288            }
289        }
290
291        // 18. GraphQL introspection — only if body contains "__schema" or "__type".
292        if let Some(ref body) = req.body {
293            if body.contains("__schema") || body.contains("__type") {
294                return self.make_decision(
295                    "graphql_introspection",
296                    "GraphQL introspection query detected (__schema/__type)",
297                    "body",
298                );
299            }
300        }
301
302        // 19. LDAP injection — only if input contains ")(", "(|", "(&", or bare "*" auth.
303        for (label, input) in &inputs_to_check {
304            if input.contains(")(") || input.contains("(|") || input.contains("(&") {
305                return self.make_decision(
306                    "ldap_injection",
307                    "LDAP filter injection pattern detected",
308                    label,
309                );
310            }
311        }
312
313        // 20. JWT alg:none — check Authorization header for known bad base64 prefix.
314        // "eyJhbGciOiJub25lIi" = base64({"alg":"none")
315        // Also detect empty signature (JWT ending with '.')
316        for (name, value) in &req.headers {
317            if name.eq_ignore_ascii_case("authorization") && value.contains("eyJhbGciOiJub25lIi") {
318                return self.make_decision("jwt_attack", "JWT with alg:none detected", "header");
319            }
320            if name.eq_ignore_ascii_case("authorization") && value.starts_with("Bearer ") {
321                let token = &value[7..];
322                // JWT format: header.payload.signature — empty sig means trailing dot
323                let parts: Vec<&str> = token.split('.').collect();
324                if parts.len() == 3 && parts[2].is_empty() {
325                    return self.make_decision(
326                        "jwt_attack",
327                        "JWT with empty signature detected",
328                        "header",
329                    );
330                }
331            }
332        }
333
334        // 21. Scanner detection (user-agent).
335        if self.config.scanner_detection {
336            if let Some(ref ua) = req.user_agent {
337                if let Some(desc) = scanner::check_scanner(ua) {
338                    return self.make_decision("scanner_detection", &desc, "user_agent");
339                }
340            }
341        }
342
343        // --- Second pass: URL-decoded path+query only (not body). ---
344        // Only allocates when % is present. Catches encoding evasion.
345        if let Some(decoded_path) = percent_decode(&req.path) {
346            if self.config.sql_injection {
347                if let Some(desc) = sqli::check_sqli(&decoded_path) {
348                    return self.make_decision("sql_injection", &desc, "path(decoded)");
349                }
350            }
351            if self.config.xss {
352                if let Some(desc) = xss::check_xss(&decoded_path) {
353                    return self.make_decision("xss", &desc, "path(decoded)");
354                }
355            }
356            if self.config.path_traversal {
357                if let Some(desc) = traversal::check_traversal(&decoded_path) {
358                    return self.make_decision("path_traversal", &desc, "path(decoded)");
359                }
360            }
361            // Double-decode
362            if let Some(double) = percent_decode(&decoded_path) {
363                if self.config.path_traversal {
364                    if let Some(desc) = traversal::check_traversal(&double) {
365                        return self.make_decision("path_traversal", &desc, "path(double-decoded)");
366                    }
367                }
368            }
369        }
370        if let Some(ref q) = req.query {
371            if let Some(decoded_q) = percent_decode(q) {
372                if self.config.sql_injection {
373                    if let Some(desc) = sqli::check_sqli(&decoded_q) {
374                        return self.make_decision("sql_injection", &desc, "query(decoded)");
375                    }
376                }
377                if self.config.xss {
378                    if let Some(desc) = xss::check_xss(&decoded_q) {
379                        return self.make_decision("xss", &desc, "query(decoded)");
380                    }
381                }
382                if self.config.path_traversal {
383                    if let Some(desc) = traversal::check_traversal(&decoded_q) {
384                        return self.make_decision("path_traversal", &desc, "query(decoded)");
385                    }
386                }
387                if self.config.shell_injection {
388                    if let Some(desc) = shell::check_shell(&decoded_q) {
389                        return self.make_decision("shell_injection", &desc, "query(decoded)");
390                    }
391                }
392                // Double-decode query
393                if let Some(double_q) = percent_decode(&decoded_q) {
394                    if self.config.xss {
395                        if let Some(desc) = xss::check_xss(&double_q) {
396                            return self.make_decision("xss", &desc, "query(double-decoded)");
397                        }
398                    }
399                    if self.config.sql_injection {
400                        if let Some(desc) = sqli::check_sqli(&double_q) {
401                            return self.make_decision(
402                                "sql_injection",
403                                &desc,
404                                "query(double-decoded)",
405                            );
406                        }
407                    }
408                }
409            }
410        }
411
412        // --- Third pass: URL-decoded body (catches %25-encoded attacks in POST data). ---
413        if let Some(ref body) = req.body {
414            if let Some(decoded_body) = percent_decode(body) {
415                if self.config.sql_injection {
416                    if let Some(desc) = sqli::check_sqli(&decoded_body) {
417                        return self.make_decision("sql_injection", &desc, "body(decoded)");
418                    }
419                }
420                if self.config.xss {
421                    if let Some(desc) = xss::check_xss(&decoded_body) {
422                        return self.make_decision("xss", &desc, "body(decoded)");
423                    }
424                }
425                // Double-decode body
426                if let Some(double_body) = percent_decode(&decoded_body) {
427                    if self.config.sql_injection {
428                        if let Some(desc) = sqli::check_sqli(&double_body) {
429                            return self.make_decision("sql_injection", &desc, "body(double-decoded)");
430                        }
431                    }
432                    if self.config.xss {
433                        if let Some(desc) = xss::check_xss(&double_body) {
434                            return self.make_decision("xss", &desc, "body(double-decoded)");
435                        }
436                    }
437                }
438            }
439        }
440
441        None
442    }
443
444    /// Collect all text inputs from the request that should be scanned.
445    fn collect_inputs<'a>(&self, req: &'a WafRequest) -> Vec<(&'static str, &'a str)> {
446        let mut inputs = Vec::with_capacity(4);
447        inputs.push(("path", req.path.as_str()));
448        if let Some(ref q) = req.query {
449            inputs.push(("query", q.as_str()));
450        }
451        if let Some(ref body) = req.body {
452            inputs.push(("body", body.as_str()));
453        }
454        inputs
455    }
456
457    fn make_decision(&self, rule: &str, desc: &str, input_source: &str) -> Option<WafDecision> {
458        if self.config.log_only {
459            tracing::warn!(
460                rule = rule,
461                pattern = desc,
462                source = input_source,
463                "WAF rule match (log-only mode)"
464            );
465            None
466        } else {
467            Some(WafDecision::Block {
468                status: 403,
469                reason: format!("{rule}: {desc} (in {input_source})"),
470                rule: rule.into(),
471            })
472        }
473    }
474}
475
476/// Check query string for open redirect patterns (external URLs in redirect params).
477fn check_open_redirect(query: &str) -> Option<String> {
478    use regex::RegexSet;
479    use std::sync::OnceLock;
480
481    static REDIRECT_PATTERNS: OnceLock<RegexSet> = OnceLock::new();
482    let patterns = REDIRECT_PATTERNS.get_or_init(|| {
483        RegexSet::new([
484            // Protocol-relative: redirect=//evil.com
485            r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=//[a-zA-Z]",
486            // Full URL: redirect=https://evil.com
487            r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=https?://[a-zA-Z]",
488            // Backslash trick: redirect=/\evil.com
489            r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=/\\",
490            // Userinfo trick: redirect=http://legit@evil.com
491            r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=https?://[^/]*@",
492        ])
493        .unwrap()
494    });
495
496    if patterns.is_match(query) {
497        Some("Open redirect: external URL in redirect parameter".into())
498    } else {
499        None
500    }
501}
502
503fn check_method_override(req: &WafRequest) -> Option<String> {
504    static OVERRIDE_HEADERS: &[&str] = &[
505        "x-http-method-override",
506        "x-http-method",
507        "x-method-override",
508    ];
509    static DANGEROUS_HEADERS: &[&str] = &[
510        "x-original-url",
511        "x-rewrite-url",
512        "x-forwarded-host",
513        "x-forwarded-scheme",
514    ];
515    for name in req.headers.keys() {
516        let lower = name.to_ascii_lowercase();
517        if OVERRIDE_HEADERS.contains(&lower.as_str()) {
518            return Some(format!("HTTP method override header detected: {name}"));
519        }
520        if DANGEROUS_HEADERS.contains(&lower.as_str()) {
521            return Some(format!(
522                "URL override header detected (cache poisoning): {name}"
523            ));
524        }
525    }
526    None
527}
528
/// Percent-decode `input` a single time, working on raw bytes.
///
/// Returns `None` when there is nothing to decode (no `%` at all, or no `%`
/// followed by two hex digits), so callers only re-scan inputs that changed.
/// Malformed escapes (`%zz`, a trailing `%4`) are kept verbatim, `+` is left
/// alone, and decoded bytes that are not valid UTF-8 become U+FFFD.
fn percent_decode(input: &str) -> Option<String> {
    if !input.contains('%') {
        return None;
    }

    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut saw_escape = false;
    let mut cursor = 0;

    while cursor < raw.len() {
        // A valid escape is exactly `%` followed by two hex digits.
        let escape = match raw.get(cursor..cursor + 3) {
            Some(&[b'%', hi_byte, lo_byte]) => hex_val(hi_byte).zip(hex_val(lo_byte)),
            _ => None,
        };
        if let Some((hi, lo)) = escape {
            decoded.push((hi << 4) | lo);
            saw_escape = true;
            cursor += 3;
        } else {
            decoded.push(raw[cursor]);
            cursor += 1;
        }
    }

    saw_escape.then(|| String::from_utf8_lossy(&decoded).into_owned())
}

/// Value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly those three ASCII ranges; the result is
    // below 16, so narrowing it to `u8` is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
564
565/// Check for SSRF patterns in input (private/internal URLs, dangerous schemes).
566fn check_ssrf(input: &str) -> Option<String> {
567    use regex::RegexSet;
568    use std::sync::OnceLock;
569
570    static SSRF_PATTERNS: OnceLock<RegexSet> = OnceLock::new();
571    let patterns = SSRF_PATTERNS.get_or_init(|| {
572        RegexSet::new([
573            r"(?i)(https?://|//)(localhost|127\.0\.0\.1|0\.0\.0\.0|\[::1\])",
574            r"(?i)(https?://|//)(10\.\d+\.\d+\.\d+|172\.(1[6-9]|2\d|3[01])\.\d+\.\d+|192\.168\.\d+\.\d+)",
575            r"(?i)(https?://|//)169\.254\.\d+\.\d+",
576            r"(?i)(file|gopher|dict|ftp)://",
577            r"(?i)(https?://|//)\d{8,10}(/|$|\s|:)",
578            r"(?i)(https?://|//)0\d+\.\d+\.\d+\.\d+",
579            r"(?i)(https?://|//)\[::ffff:",
580            r"(?i)(https?://|//)(127\.1|0\.0\.0\.0)(:|/|$)",
581            r"(?i)(https?://)\w+@(localhost|127\.|10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.)",
582            r"(?i)(https?://|//)(metadata\.google\.internal|metadata\.azure\.com)",
583        ]).unwrap()
584    });
585
586    if patterns.is_match(input) {
587        Some("SSRF: private/internal URL detected in parameters".into())
588    } else {
589        None
590    }
591}
592
593#[cfg(test)]
594mod tests {
595    use super::*;
596    use std::collections::HashMap;
597
598    fn make_req(
599        method: &str,
600        path: &str,
601        query: Option<&str>,
602        body: Option<&str>,
603        ua: Option<&str>,
604    ) -> WafRequest {
605        WafRequest {
606            client_ip: "127.0.0.1".parse().unwrap(),
607            method: method.into(),
608            path: path.into(),
609            query: query.map(String::from),
610            headers: HashMap::new(),
611            body: body.map(String::from),
612            user_agent: ua.map(String::from),
613        }
614    }
615
616    #[test]
617    fn clean_request_passes() {
618        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
619        let req = make_req("GET", "/api/users", None, None, Some("Mozilla/5.0"));
620        assert!(engine.inspect(&req).is_none());
621    }
622
623    #[test]
624    fn detects_sqli_in_query() {
625        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
626        let req = make_req(
627            "GET",
628            "/search",
629            Some("q=1 UNION SELECT * FROM users"),
630            None,
631            None,
632        );
633        let decision = engine.inspect(&req);
634        assert!(
635            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "sql_injection")
636        );
637    }
638
639    #[test]
640    fn detects_sqli_in_body() {
641        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
642        let req = make_req("POST", "/login", None, Some("user=admin' OR 1=1 --"), None);
643        let decision = engine.inspect(&req);
644        assert!(
645            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "sql_injection")
646        );
647    }
648
649    #[test]
650    fn detects_xss_in_body() {
651        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
652        let req = make_req(
653            "POST",
654            "/comment",
655            None,
656            Some("<script>alert(1)</script>"),
657            None,
658        );
659        let decision = engine.inspect(&req);
660        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "xss"));
661    }
662
663    #[test]
664    fn detects_traversal_in_path() {
665        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
666        let req = make_req("GET", "/static/../../etc/passwd", None, None, None);
667        let decision = engine.inspect(&req);
668        assert!(
669            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "path_traversal")
670        );
671    }
672
673    #[test]
674    fn detects_scanner_ua() {
675        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
676        let req = make_req("GET", "/", None, None, Some("sqlmap/1.5"));
677        let decision = engine.inspect(&req);
678        assert!(
679            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "scanner_detection")
680        );
681    }
682
683    #[test]
684    fn detects_shell_injection_in_body() {
685        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
686        let req = make_req("POST", "/exec", None, Some("; cat /etc/passwd"), None);
687        let decision = engine.inspect(&req);
688        assert!(
689            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "shell_injection")
690        );
691    }
692
693    #[test]
694    fn detects_shell_injection_in_query() {
695        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
696        let req = make_req("GET", "/search", Some("cmd=$(whoami)"), None, None);
697        let decision = engine.inspect(&req);
698        assert!(
699            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "shell_injection")
700        );
701    }
702
703    #[test]
704    fn detects_protocol_violation_null_byte() {
705        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
706        let req = make_req("GET", "/file.php%00.jpg", None, None, None);
707        let decision = engine.inspect(&req);
708        // Null byte in path should be caught by either traversal or protocol violation.
709        assert!(decision.is_some());
710    }
711
712    #[test]
713    fn detects_protocol_violation_body_no_content_length() {
714        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
715        // POST with body but no Content-Length header.
716        let req = make_req("POST", "/api/data", None, Some(r#"{"key":"val"}"#), None);
717        let decision = engine.inspect(&req);
718        assert!(
719            matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "protocol_violation")
720        );
721    }
722
723    #[test]
724    fn protocol_violation_passes_with_content_length() {
725        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
726        let req = WafRequest {
727            client_ip: "127.0.0.1".parse().unwrap(),
728            method: "POST".into(),
729            path: "/api/data".into(),
730            query: None,
731            headers: {
732                let mut h = HashMap::new();
733                h.insert("Content-Length".into(), "13".into());
734                h
735            },
736            body: Some(r#"{"key":"val"}"#.into()),
737            user_agent: None,
738        };
739        assert!(engine.inspect(&req).is_none());
740    }
741
742    #[test]
743    fn disabled_rules_skip() {
744        let config = RuleConfig {
745            sql_injection: false,
746            xss: false,
747            path_traversal: false,
748            shell_injection: false,
749            protocol_violation: false,
750            scanner_detection: false,
751            sensitive_path: false,
752            crlf_injection: false,
753            method_override: false,
754            log_only: false,
755        };
756        let engine = RuleEngine::new(config, vec![]);
757        let req = make_req(
758            "POST",
759            "/../../etc/passwd",
760            Some("q=UNION SELECT *"),
761            Some("<script>alert(1)</script>"),
762            Some("sqlmap/1.5"),
763        );
764        assert!(engine.inspect(&req).is_none());
765    }
766
767    #[test]
768    fn shell_injection_disabled_allows() {
769        let config = RuleConfig {
770            shell_injection: false,
771            ..Default::default()
772        };
773        let engine = RuleEngine::new(config, vec![]);
774        let req = make_req("POST", "/api", None, Some("; cat /etc/passwd"), None);
775        // Should not trigger shell_injection when disabled (may still trigger traversal for /etc/passwd).
776        let decision = engine.inspect(&req);
777        if let Some(WafDecision::Block { rule, .. }) = &decision {
778            assert_ne!(rule, "shell_injection");
779        }
780    }
781
782    // ---- SSRF detection tests ----
783
784    #[test]
785    fn detects_ssrf_localhost() {
786        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
787        let req = make_req(
788            "GET",
789            "/proxy",
790            Some("url=http://localhost/admin"),
791            None,
792            None,
793        );
794        let decision = engine.inspect(&req);
795        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
796    }
797
798    #[test]
799    fn detects_ssrf_127001() {
800        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
801        let req = make_req(
802            "GET",
803            "/fetch",
804            Some("url=http://127.0.0.1:8080/secret"),
805            None,
806            None,
807        );
808        let decision = engine.inspect(&req);
809        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
810    }
811
812    #[test]
813    fn detects_ssrf_private_10_range() {
814        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
815        let req = make_req(
816            "GET",
817            "/proxy",
818            Some("target=http://10.0.0.1/internal"),
819            None,
820            None,
821        );
822        let decision = engine.inspect(&req);
823        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
824    }
825
826    #[test]
827    fn detects_ssrf_private_192168() {
828        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
829        let req = make_req(
830            "GET",
831            "/proxy",
832            Some("target=http://192.168.1.1/admin"),
833            None,
834            None,
835        );
836        let decision = engine.inspect(&req);
837        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
838    }
839
840    #[test]
841    fn detects_ssrf_private_172() {
842        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
843        let req = make_req(
844            "GET",
845            "/proxy",
846            Some("url=http://172.16.0.1/meta"),
847            None,
848            None,
849        );
850        let decision = engine.inspect(&req);
851        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
852    }
853
854    #[test]
855    fn detects_ssrf_link_local_169254() {
856        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
857        let req = make_req(
858            "GET",
859            "/proxy",
860            Some("url=http://169.254.169.254/latest/meta-data/"),
861            None,
862            None,
863        );
864        let decision = engine.inspect(&req);
865        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
866    }
867
868    #[test]
869    fn detects_ssrf_file_scheme() {
870        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
871        // Use a path that won't trigger traversal or shell rules before SSRF
872        let req = make_req(
873            "GET",
874            "/read",
875            Some("path=file:///tmp/data.txt"),
876            None,
877            None,
878        );
879        let decision = engine.inspect(&req);
880        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
881    }
882
883    #[test]
884    fn detects_ssrf_gopher_scheme() {
885        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
886        let req = make_req(
887            "GET",
888            "/fetch",
889            Some("url=gopher://127.0.0.1:25/"),
890            None,
891            None,
892        );
893        let decision = engine.inspect(&req);
894        assert!(matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == "ssrf"));
895    }
896
897    #[test]
898    fn ssrf_allows_public_urls() {
899        let engine = RuleEngine::new(RuleConfig::default(), vec![]);
900        let req = make_req(
901            "GET",
902            "/proxy",
903            Some("url=https://example.com/page"),
904            None,
905            None,
906        );
907        assert!(engine.inspect(&req).is_none());
908    }
909
910    #[test]
911    fn ssrf_check_function_directly() {
912        assert!(check_ssrf("url=http://localhost/admin").is_some());
913        assert!(check_ssrf("url=http://0.0.0.0:8080").is_some());
914        assert!(check_ssrf("url=http://[::1]/secret").is_some());
915        assert!(check_ssrf("url=dict://evil.com").is_some());
916        assert!(check_ssrf("url=ftp://internal").is_some());
917        assert!(check_ssrf("url=https://google.com").is_none());
918        assert!(check_ssrf("search=hello+world").is_none());
919    }
920
921    #[test]
922    fn protocol_violation_disabled_allows() {
923        let config = RuleConfig {
924            protocol_violation: false,
925            ..Default::default()
926        };
927        let engine = RuleEngine::new(config, vec![]);
928        // POST with body but no Content-Length — should pass if protocol_violation is disabled.
929        let req = make_req("POST", "/api/clean", None, Some("clean body"), None);
930        let decision = engine.inspect(&req);
931        if let Some(WafDecision::Block { rule, .. }) = &decision {
932            assert_ne!(rule, "protocol_violation");
933        }
934    }
935
936    #[test]
937    fn log_only_mode_allows() {
938        let config = RuleConfig {
939            log_only: true,
940            ..Default::default()
941        };
942        let engine = RuleEngine::new(config, vec![]);
943        let req = make_req("GET", "/search", Some("q=1 UNION SELECT *"), None, None);
944        // In log-only mode, no blocking.
945        assert!(engine.inspect(&req).is_none());
946    }
947
948    #[test]
949    fn custom_rules_take_priority() {
950        let custom = vec![custom::CustomRule {
951            name: "block-all-posts".into(),
952            match_config: custom::MatchConfig {
953                method: Some("POST".into()),
954                ..Default::default()
955            },
956            action: custom::CustomRuleAction::Block,
957            status: 405,
958            reason: Some("POST not allowed".into()),
959        }];
960        let engine = RuleEngine::new(RuleConfig::default(), custom);
961        // Even a clean POST request is blocked by custom rule.
962        let req = make_req("POST", "/api/data", None, Some(r#"{"key":"value"}"#), None);
963        let decision = engine.inspect(&req);
964        assert!(matches!(
965            decision,
966            Some(WafDecision::Block { status: 405, .. })
967        ));
968    }
969}