1pub mod crlf;
14pub mod custom;
15pub mod deserialization;
16pub mod log4shell;
17pub mod nosql;
18pub mod protocol;
19pub mod prototype;
20pub mod scanner;
21pub mod sensitive_path;
22pub mod shell;
23pub mod sqli;
24pub mod ssi;
25pub mod ssti;
26pub mod traversal;
27pub mod xss;
28pub mod xxe;
29
30use serde::{Deserialize, Serialize};
31
32use crate::{WafDecision, WafRequest};
33
/// Enables or disables the built-in detectors run by [`RuleEngine::inspect`].
///
/// Every detector toggle defaults to `true` when it is absent from a
/// deserialized config (and in `RuleConfig::default()`); `log_only` defaults
/// to `false`. Detectors without a toggle here (SSRF, Log4Shell, XXE, SSTI,
/// open redirect, ...) always run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleConfig {
    /// Run the SQL-injection detector on the path, query and body (raw and
    /// percent-decoded).
    #[serde(default = "default_true")]
    pub sql_injection: bool,
    /// Run the XSS detector on the path, query and body (raw and
    /// percent-decoded).
    #[serde(default = "default_true")]
    pub xss: bool,
    /// Run the path-traversal detector on the path and query (raw and
    /// percent-decoded; the path is also double-decoded).
    #[serde(default = "default_true")]
    pub path_traversal: bool,
    /// Run the shell-injection detector on the path, query and body, plus the
    /// percent-decoded query.
    #[serde(default = "default_true")]
    pub shell_injection: bool,
    /// Run `protocol::check_protocol` on the whole request.
    #[serde(default = "default_true")]
    pub protocol_violation: bool,
    /// Run `scanner::check_scanner` on the User-Agent.
    #[serde(default = "default_true")]
    pub scanner_detection: bool,
    /// Run `sensitive_path::check_sensitive_path` on the request path.
    #[serde(default = "default_true")]
    pub sensitive_path: bool,
    /// Check the path and query for CR/LF injection (`crlf::check_crlf`).
    #[serde(default = "default_true")]
    pub crlf_injection: bool,
    /// Flag method-override and URL-rewrite headers such as
    /// `X-HTTP-Method-Override` and `X-Original-URL`.
    #[serde(default = "default_true")]
    pub method_override: bool,
    /// When set, built-in rule matches are logged with `tracing::warn!` and the
    /// request is allowed instead of blocked. Custom-rule decisions are not
    /// affected by this flag.
    #[serde(default)]
    pub log_only: bool,
}
59
/// Serde default used by the detector toggles in `RuleConfig`: a missing
/// toggle means "enabled".
fn default_true() -> bool {
    true
}
63
64impl Default for RuleConfig {
65 fn default() -> Self {
66 Self {
67 sql_injection: true,
68 xss: true,
69 path_traversal: true,
70 shell_injection: true,
71 protocol_violation: true,
72 scanner_detection: true,
73 sensitive_path: true,
74 crlf_injection: true,
75 method_override: true,
76 log_only: false,
77 }
78 }
79}
80
/// Details of a single rule hit, going by the field names: which rule fired,
/// what it matched, and on which input.
///
/// NOTE(review): nothing in this module constructs it; presumably it is used
/// by callers or the rule submodules. Confirm before changing.
#[derive(Debug, Clone)]
pub struct WafRuleMatch {
    /// Name of the rule that matched.
    pub rule_name: String,
    /// The pattern (or a description of it) that matched.
    pub matched_pattern: String,
    /// The input text that triggered the match.
    pub matched_input: String,
}
88
/// The WAF rule engine: user-defined custom rules followed by the built-in
/// detectors enabled in a [`RuleConfig`]. See [`RuleEngine::inspect`].
pub struct RuleEngine {
    /// Which built-in detectors run, and whether their matches block or only log.
    config: RuleConfig,
    /// User-defined rules, checked before any built-in detector.
    custom_rules: custom::CustomRuleSet,
}
94
95impl RuleEngine {
96 pub fn new(config: RuleConfig, custom_rules: Vec<custom::CustomRule>) -> Self {
97 Self {
98 config,
99 custom_rules: custom::CustomRuleSet::new(custom_rules),
100 }
101 }
102
103 pub fn inspect(&self, req: &WafRequest) -> Option<WafDecision> {
106 if let Some(decision) = self.custom_rules.check(req) {
108 return Some(decision);
109 }
110
111 let inputs_to_check = self.collect_inputs(req);
113
114 if self.config.sql_injection {
116 for (label, input) in &inputs_to_check {
117 if let Some(desc) = sqli::check_sqli(input) {
118 return self.make_decision("sql_injection", &desc, label);
119 }
120 }
121 }
122
123 if self.config.xss {
125 for (label, input) in &inputs_to_check {
126 if let Some(desc) = xss::check_xss(input) {
127 return self.make_decision("xss", &desc, label);
128 }
129 }
130 }
131
132 if self.config.path_traversal {
134 if let Some(desc) = traversal::check_traversal(&req.path) {
135 return self.make_decision("path_traversal", &desc, "path");
136 }
137 if let Some(ref q) = req.query {
138 if let Some(desc) = traversal::check_traversal(q) {
139 return self.make_decision("path_traversal", &desc, "query");
140 }
141 }
142 }
143
144 if self.config.shell_injection {
146 for (label, input) in &inputs_to_check {
147 if let Some(desc) = shell::check_shell(input) {
148 return self.make_decision("shell_injection", &desc, label);
149 }
150 }
151 }
152
153 if self.config.protocol_violation {
155 if let Some(desc) = protocol::check_protocol(req) {
156 return self.make_decision("protocol_violation", &desc, "request");
157 }
158 }
159
160 if let Some(ref query) = req.query {
162 if let Some(finding) = check_ssrf(query) {
163 return self.make_decision("ssrf", &finding, "query");
164 }
165 }
166
167 if self.config.sensitive_path {
169 if let Some(desc) = sensitive_path::check_sensitive_path(&req.path) {
170 return self.make_decision("sensitive_path", &desc, "path");
171 }
172 }
173
174 if self.config.crlf_injection {
176 if req.path.contains('%') || req.path.contains('\r') || req.path.contains('\n') {
177 if let Some(desc) = crlf::check_crlf(&req.path) {
178 return self.make_decision("crlf_injection", &desc, "path");
179 }
180 }
181 if let Some(ref q) = req.query {
182 if q.contains('%') || q.contains('\r') || q.contains('\n') {
183 if let Some(desc) = crlf::check_crlf(q) {
184 return self.make_decision("crlf_injection", &desc, "query");
185 }
186 }
187 }
188 }
189
190 if self.config.method_override {
192 if let Some(desc) = check_method_override(req) {
193 return self.make_decision("method_override", &desc, "header");
194 }
195 }
196
197 for (label, input) in &inputs_to_check {
199 if input.contains("${") {
200 if let Some(desc) = log4shell::check_log4shell(input) {
201 return self.make_decision("log4shell", &desc, label);
202 }
203 }
204 }
205 if let Some(ref ua) = req.user_agent {
206 if ua.contains("${") {
207 if let Some(desc) = log4shell::check_log4shell(ua) {
208 return self.make_decision("log4shell", &desc, "user_agent");
209 }
210 }
211 }
212 for value in req.headers.values() {
213 if value.contains("${") {
214 if let Some(desc) = log4shell::check_log4shell(value) {
215 return self.make_decision("log4shell", &desc, "header");
216 }
217 }
218 }
219
220 if let Some(ref body) = req.body {
222 if body.contains("<!") {
223 if let Some(desc) = xxe::check_xxe(body) {
224 return self.make_decision("xxe", &desc, "body");
225 }
226 }
227 }
228
229 for (label, input) in &inputs_to_check {
231 if input.contains("{{")
232 || input.contains("${")
233 || input.contains("<%")
234 || input.contains("__")
235 {
236 if let Some(desc) = ssti::check_ssti(input) {
237 return self.make_decision("ssti", &desc, label);
238 }
239 }
240 }
241
242 for (label, input) in &inputs_to_check {
244 if input.contains("__proto__") || input.contains("constructor") {
245 if let Some(desc) = prototype::check_prototype(input) {
246 return self.make_decision("prototype_pollution", &desc, label);
247 }
248 }
249 }
250
251 for (label, input) in &inputs_to_check {
253 if input.contains("\"$") {
254 if let Some(desc) = nosql::check_nosql(input) {
255 return self.make_decision("nosql_injection", &desc, label);
256 }
257 }
258 }
259
260 for (label, input) in &inputs_to_check {
262 if input.contains("rO0AB")
263 || input.contains("aced")
264 || input.contains("AAEAAAD")
265 || (input.len() > 4
266 && input.as_bytes()[1] == b':'
267 && input.as_bytes()[0].is_ascii_uppercase())
268 {
269 if let Some(desc) = deserialization::check_deserialization(input) {
270 return self.make_decision("deserialization", &desc, label);
271 }
272 }
273 }
274
275 for (label, input) in &inputs_to_check {
277 if input.contains("<!--#") {
278 if let Some(desc) = ssi::check_ssi(input) {
279 return self.make_decision("ssi_injection", &desc, label);
280 }
281 }
282 }
283
284 if let Some(ref query) = req.query {
286 if let Some(desc) = check_open_redirect(query) {
287 return self.make_decision("open_redirect", &desc, "query");
288 }
289 }
290
291 if let Some(ref body) = req.body {
293 if body.contains("__schema") || body.contains("__type") {
294 return self.make_decision(
295 "graphql_introspection",
296 "GraphQL introspection query detected (__schema/__type)",
297 "body",
298 );
299 }
300 }
301
302 for (label, input) in &inputs_to_check {
304 if input.contains(")(") || input.contains("(|") || input.contains("(&") {
305 return self.make_decision(
306 "ldap_injection",
307 "LDAP filter injection pattern detected",
308 label,
309 );
310 }
311 }
312
313 for (name, value) in &req.headers {
317 if name.eq_ignore_ascii_case("authorization") && value.contains("eyJhbGciOiJub25lIi") {
318 return self.make_decision("jwt_attack", "JWT with alg:none detected", "header");
319 }
320 if name.eq_ignore_ascii_case("authorization") && value.starts_with("Bearer ") {
321 let token = &value[7..];
322 let parts: Vec<&str> = token.split('.').collect();
324 if parts.len() == 3 && parts[2].is_empty() {
325 return self.make_decision(
326 "jwt_attack",
327 "JWT with empty signature detected",
328 "header",
329 );
330 }
331 }
332 }
333
334 if self.config.scanner_detection {
336 if let Some(ref ua) = req.user_agent {
337 if let Some(desc) = scanner::check_scanner(ua) {
338 return self.make_decision("scanner_detection", &desc, "user_agent");
339 }
340 }
341 }
342
343 if let Some(decoded_path) = percent_decode(&req.path) {
346 if self.config.sql_injection {
347 if let Some(desc) = sqli::check_sqli(&decoded_path) {
348 return self.make_decision("sql_injection", &desc, "path(decoded)");
349 }
350 }
351 if self.config.xss {
352 if let Some(desc) = xss::check_xss(&decoded_path) {
353 return self.make_decision("xss", &desc, "path(decoded)");
354 }
355 }
356 if self.config.path_traversal {
357 if let Some(desc) = traversal::check_traversal(&decoded_path) {
358 return self.make_decision("path_traversal", &desc, "path(decoded)");
359 }
360 }
361 if let Some(double) = percent_decode(&decoded_path) {
363 if self.config.path_traversal {
364 if let Some(desc) = traversal::check_traversal(&double) {
365 return self.make_decision("path_traversal", &desc, "path(double-decoded)");
366 }
367 }
368 }
369 }
370 if let Some(ref q) = req.query {
371 if let Some(decoded_q) = percent_decode(q) {
372 if self.config.sql_injection {
373 if let Some(desc) = sqli::check_sqli(&decoded_q) {
374 return self.make_decision("sql_injection", &desc, "query(decoded)");
375 }
376 }
377 if self.config.xss {
378 if let Some(desc) = xss::check_xss(&decoded_q) {
379 return self.make_decision("xss", &desc, "query(decoded)");
380 }
381 }
382 if self.config.path_traversal {
383 if let Some(desc) = traversal::check_traversal(&decoded_q) {
384 return self.make_decision("path_traversal", &desc, "query(decoded)");
385 }
386 }
387 if self.config.shell_injection {
388 if let Some(desc) = shell::check_shell(&decoded_q) {
389 return self.make_decision("shell_injection", &desc, "query(decoded)");
390 }
391 }
392 if let Some(double_q) = percent_decode(&decoded_q) {
394 if self.config.xss {
395 if let Some(desc) = xss::check_xss(&double_q) {
396 return self.make_decision("xss", &desc, "query(double-decoded)");
397 }
398 }
399 if self.config.sql_injection {
400 if let Some(desc) = sqli::check_sqli(&double_q) {
401 return self.make_decision(
402 "sql_injection",
403 &desc,
404 "query(double-decoded)",
405 );
406 }
407 }
408 }
409 }
410 }
411
412 if let Some(ref body) = req.body {
414 if let Some(decoded_body) = percent_decode(body) {
415 if self.config.sql_injection {
416 if let Some(desc) = sqli::check_sqli(&decoded_body) {
417 return self.make_decision("sql_injection", &desc, "body(decoded)");
418 }
419 }
420 if self.config.xss {
421 if let Some(desc) = xss::check_xss(&decoded_body) {
422 return self.make_decision("xss", &desc, "body(decoded)");
423 }
424 }
425 if let Some(double_body) = percent_decode(&decoded_body) {
427 if self.config.sql_injection {
428 if let Some(desc) = sqli::check_sqli(&double_body) {
429 return self.make_decision("sql_injection", &desc, "body(double-decoded)");
430 }
431 }
432 if self.config.xss {
433 if let Some(desc) = xss::check_xss(&double_body) {
434 return self.make_decision("xss", &desc, "body(double-decoded)");
435 }
436 }
437 }
438 }
439 }
440
441 None
442 }
443
444 fn collect_inputs<'a>(&self, req: &'a WafRequest) -> Vec<(&'static str, &'a str)> {
446 let mut inputs = Vec::with_capacity(4);
447 inputs.push(("path", req.path.as_str()));
448 if let Some(ref q) = req.query {
449 inputs.push(("query", q.as_str()));
450 }
451 if let Some(ref body) = req.body {
452 inputs.push(("body", body.as_str()));
453 }
454 inputs
455 }
456
457 fn make_decision(&self, rule: &str, desc: &str, input_source: &str) -> Option<WafDecision> {
458 if self.config.log_only {
459 tracing::warn!(
460 rule = rule,
461 pattern = desc,
462 source = input_source,
463 "WAF rule match (log-only mode)"
464 );
465 None
466 } else {
467 Some(WafDecision::Block {
468 status: 403,
469 reason: format!("{rule}: {desc} (in {input_source})"),
470 rule: rule.into(),
471 })
472 }
473 }
474}
475
476fn check_open_redirect(query: &str) -> Option<String> {
478 use regex::RegexSet;
479 use std::sync::OnceLock;
480
481 static REDIRECT_PATTERNS: OnceLock<RegexSet> = OnceLock::new();
482 let patterns = REDIRECT_PATTERNS.get_or_init(|| {
483 RegexSet::new([
484 r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=//[a-zA-Z]",
486 r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=https?://[a-zA-Z]",
488 r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=/\\",
490 r"(?i)(redirect|redirect_uri|next|return_to|return_url|dest|destination|rurl|continue|login_to|logout|forward|goto|target_url|returnTo|RelayState)=https?://[^/]*@",
492 ])
493 .unwrap()
494 });
495
496 if patterns.is_match(query) {
497 Some("Open redirect: external URL in redirect parameter".into())
498 } else {
499 None
500 }
501}
502
503fn check_method_override(req: &WafRequest) -> Option<String> {
504 static OVERRIDE_HEADERS: &[&str] = &[
505 "x-http-method-override",
506 "x-http-method",
507 "x-method-override",
508 ];
509 static DANGEROUS_HEADERS: &[&str] = &[
510 "x-original-url",
511 "x-rewrite-url",
512 "x-forwarded-host",
513 "x-forwarded-scheme",
514 ];
515 for name in req.headers.keys() {
516 let lower = name.to_ascii_lowercase();
517 if OVERRIDE_HEADERS.contains(&lower.as_str()) {
518 return Some(format!("HTTP method override header detected: {name}"));
519 }
520 if DANGEROUS_HEADERS.contains(&lower.as_str()) {
521 return Some(format!(
522 "URL override header detected (cache poisoning): {name}"
523 ));
524 }
525 }
526 None
527}
528
/// Decodes one layer of `%XX` escapes in `input`.
///
/// Returns `None` when `input` has no `%` or no valid escape was decoded.
/// Otherwise it returns the decoded text, with byte sequences that are not
/// valid UTF-8 replaced by U+FFFD. Malformed escapes (`%zz`, a truncated
/// `%4`) are copied through unchanged, and `+` is not turned into a space.
fn percent_decode(input: &str) -> Option<String> {
    if !input.contains('%') {
        return None;
    }

    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut any_escape = false;
    let mut pos = 0;
    while pos < raw.len() {
        let escape = match raw.get(pos..pos + 3) {
            Some(&[b'%', high, low]) => hex_val(high).zip(hex_val(low)),
            _ => None,
        };
        match escape {
            Some((high, low)) => {
                decoded.push((high << 4) | low);
                any_escape = true;
                pos += 3;
            }
            None => {
                decoded.push(raw[pos]);
                pos += 1;
            }
        }
    }

    any_escape.then(|| String::from_utf8_lossy(&decoded).into_owned())
}

/// Value (0-15) of one ASCII hex digit; `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    (b as char).to_digit(16).map(|digit| digit as u8)
}
564
565fn check_ssrf(input: &str) -> Option<String> {
567 use regex::RegexSet;
568 use std::sync::OnceLock;
569
570 static SSRF_PATTERNS: OnceLock<RegexSet> = OnceLock::new();
571 let patterns = SSRF_PATTERNS.get_or_init(|| {
572 RegexSet::new([
573 r"(?i)(https?://|//)(localhost|127\.0\.0\.1|0\.0\.0\.0|\[::1\])",
574 r"(?i)(https?://|//)(10\.\d+\.\d+\.\d+|172\.(1[6-9]|2\d|3[01])\.\d+\.\d+|192\.168\.\d+\.\d+)",
575 r"(?i)(https?://|//)169\.254\.\d+\.\d+",
576 r"(?i)(file|gopher|dict|ftp)://",
577 r"(?i)(https?://|//)\d{8,10}(/|$|\s|:)",
578 r"(?i)(https?://|//)0\d+\.\d+\.\d+\.\d+",
579 r"(?i)(https?://|//)\[::ffff:",
580 r"(?i)(https?://|//)(127\.1|0\.0\.0\.0)(:|/|$)",
581 r"(?i)(https?://)\w+@(localhost|127\.|10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.)",
582 r"(?i)(https?://|//)(metadata\.google\.internal|metadata\.azure\.com)",
583 ]).unwrap()
584 });
585
586 if patterns.is_match(input) {
587 Some("SSRF: private/internal URL detected in parameters".into())
588 } else {
589 None
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a request from 127.0.0.1 with no headers.
    fn make_req(
        method: &str,
        path: &str,
        query: Option<&str>,
        body: Option<&str>,
        ua: Option<&str>,
    ) -> WafRequest {
        WafRequest {
            client_ip: "127.0.0.1".parse().unwrap(),
            method: method.into(),
            path: path.into(),
            query: query.map(String::from),
            headers: HashMap::new(),
            body: body.map(String::from),
            user_agent: ua.map(String::from),
        }
    }

    /// A bodiless GET with the given query string and no user agent.
    fn get_with_query(path: &str, query: &str) -> WafRequest {
        make_req("GET", path, Some(query), None, None)
    }

    /// Engine with every built-in rule enabled and no custom rules.
    fn default_engine() -> RuleEngine {
        RuleEngine::new(RuleConfig::default(), vec![])
    }

    /// `true` when `decision` is a block attributed to rule `expected`.
    fn blocked_by(decision: &Option<WafDecision>, expected: &str) -> bool {
        matches!(decision, Some(WafDecision::Block { rule, .. }) if rule == expected)
    }

    /// Asserts that the default engine blocks `req` under rule `expected`.
    #[track_caller]
    fn assert_default_blocks(req: &WafRequest, expected: &str) {
        let decision = default_engine().inspect(req);
        assert!(blocked_by(&decision, expected), "expected a `{expected}` block");
    }

    #[test]
    fn clean_request_passes() {
        let req = make_req("GET", "/api/users", None, None, Some("Mozilla/5.0"));
        assert!(default_engine().inspect(&req).is_none());
    }

    #[test]
    fn detects_sqli_in_query() {
        let req = get_with_query("/search", "q=1 UNION SELECT * FROM users");
        assert_default_blocks(&req, "sql_injection");
    }

    #[test]
    fn detects_sqli_in_body() {
        let req = make_req("POST", "/login", None, Some("user=admin' OR 1=1 --"), None);
        assert_default_blocks(&req, "sql_injection");
    }

    #[test]
    fn detects_xss_in_body() {
        let req = make_req("POST", "/comment", None, Some("<script>alert(1)</script>"), None);
        assert_default_blocks(&req, "xss");
    }

    #[test]
    fn detects_traversal_in_path() {
        let req = make_req("GET", "/static/../../etc/passwd", None, None, None);
        assert_default_blocks(&req, "path_traversal");
    }

    #[test]
    fn detects_scanner_ua() {
        let req = make_req("GET", "/", None, None, Some("sqlmap/1.5"));
        assert_default_blocks(&req, "scanner_detection");
    }

    #[test]
    fn detects_shell_injection_in_body() {
        let req = make_req("POST", "/exec", None, Some("; cat /etc/passwd"), None);
        assert_default_blocks(&req, "shell_injection");
    }

    #[test]
    fn detects_shell_injection_in_query() {
        let req = get_with_query("/search", "cmd=$(whoami)");
        assert_default_blocks(&req, "shell_injection");
    }

    #[test]
    fn detects_protocol_violation_null_byte() {
        // Any rule may claim the `%00` path; the test only requires that it is caught.
        let req = make_req("GET", "/file.php%00.jpg", None, None, None);
        assert!(default_engine().inspect(&req).is_some());
    }

    #[test]
    fn detects_protocol_violation_body_no_content_length() {
        let req = make_req("POST", "/api/data", None, Some(r#"{"key":"val"}"#), None);
        assert_default_blocks(&req, "protocol_violation");
    }

    #[test]
    fn protocol_violation_passes_with_content_length() {
        let mut req = make_req("POST", "/api/data", None, Some(r#"{"key":"val"}"#), None);
        req.headers.insert("Content-Length".into(), "13".into());
        assert!(default_engine().inspect(&req).is_none());
    }

    #[test]
    fn disabled_rules_skip() {
        // Every toggle is spelled out, so adding a field to `RuleConfig` fails
        // to compile here until this test is revisited.
        let config = RuleConfig {
            sql_injection: false,
            xss: false,
            path_traversal: false,
            shell_injection: false,
            protocol_violation: false,
            scanner_detection: false,
            sensitive_path: false,
            crlf_injection: false,
            method_override: false,
            log_only: false,
        };
        let engine = RuleEngine::new(config, vec![]);
        let req = make_req(
            "POST",
            "/../../etc/passwd",
            Some("q=UNION SELECT *"),
            Some("<script>alert(1)</script>"),
            Some("sqlmap/1.5"),
        );
        assert!(engine.inspect(&req).is_none());
    }

    #[test]
    fn shell_injection_disabled_allows() {
        let config = RuleConfig {
            shell_injection: false,
            ..Default::default()
        };
        let engine = RuleEngine::new(config, vec![]);
        let req = make_req("POST", "/api", None, Some("; cat /etc/passwd"), None);
        // Another rule may still block the request; it just must not be shell_injection.
        assert!(!blocked_by(&engine.inspect(&req), "shell_injection"));
    }

    #[test]
    fn detects_ssrf_localhost() {
        assert_default_blocks(&get_with_query("/proxy", "url=http://localhost/admin"), "ssrf");
    }

    #[test]
    fn detects_ssrf_127001() {
        let req = get_with_query("/fetch", "url=http://127.0.0.1:8080/secret");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn detects_ssrf_private_10_range() {
        let req = get_with_query("/proxy", "target=http://10.0.0.1/internal");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn detects_ssrf_private_192168() {
        let req = get_with_query("/proxy", "target=http://192.168.1.1/admin");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn detects_ssrf_private_172() {
        let req = get_with_query("/proxy", "url=http://172.16.0.1/meta");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn detects_ssrf_link_local_169254() {
        let req = get_with_query("/proxy", "url=http://169.254.169.254/latest/meta-data/");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn detects_ssrf_file_scheme() {
        let req = get_with_query("/read", "path=file:///tmp/data.txt");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn detects_ssrf_gopher_scheme() {
        let req = get_with_query("/fetch", "url=gopher://127.0.0.1:25/");
        assert_default_blocks(&req, "ssrf");
    }

    #[test]
    fn ssrf_allows_public_urls() {
        let req = get_with_query("/proxy", "url=https://example.com/page");
        assert!(default_engine().inspect(&req).is_none());
    }

    #[test]
    fn ssrf_check_function_directly() {
        let flagged = [
            "url=http://localhost/admin",
            "url=http://0.0.0.0:8080",
            "url=http://[::1]/secret",
            "url=dict://evil.com",
            "url=ftp://internal",
        ];
        for input in flagged {
            assert!(check_ssrf(input).is_some(), "{input} should be flagged");
        }
        for input in ["url=https://google.com", "search=hello+world"] {
            assert!(check_ssrf(input).is_none(), "{input} should pass");
        }
    }

    #[test]
    fn protocol_violation_disabled_allows() {
        let config = RuleConfig {
            protocol_violation: false,
            ..Default::default()
        };
        let engine = RuleEngine::new(config, vec![]);
        let req = make_req("POST", "/api/clean", None, Some("clean body"), None);
        assert!(!blocked_by(&engine.inspect(&req), "protocol_violation"));
    }

    #[test]
    fn log_only_mode_allows() {
        let config = RuleConfig {
            log_only: true,
            ..Default::default()
        };
        let engine = RuleEngine::new(config, vec![]);
        assert!(engine
            .inspect(&get_with_query("/search", "q=1 UNION SELECT *"))
            .is_none());
    }

    #[test]
    fn custom_rules_take_priority() {
        let block_posts = custom::CustomRule {
            name: "block-all-posts".into(),
            match_config: custom::MatchConfig {
                method: Some("POST".into()),
                ..Default::default()
            },
            action: custom::CustomRuleAction::Block,
            status: 405,
            reason: Some("POST not allowed".into()),
        };
        let engine = RuleEngine::new(RuleConfig::default(), vec![block_posts]);
        let req = make_req("POST", "/api/data", None, Some(r#"{"key":"value"}"#), None);
        // 405 comes from the custom rule; built-in blocks use 403.
        assert!(matches!(
            engine.inspect(&req),
            Some(WafDecision::Block { status: 405, .. })
        ));
    }
}