Skip to main content

mockforge_bench/
wafbench.rs

1//! WAFBench YAML parser for importing CRS (Core Rule Set) attack patterns
2//!
3//! This module parses WAFBench YAML test files from the Microsoft WAFBench project
4//! (<https://github.com/microsoft/WAFBench>) and converts them into security test payloads
5//! compatible with MockForge's security testing framework.
6//!
7//! # WAFBench YAML Format
8//!
9//! WAFBench test files follow this structure:
10//! ```yaml
11//! meta:
12//!   author: "author-name"
13//!   description: "Tests for rule XXXXXX"
14//!   enabled: true
15//!   name: "XXXXXX.yaml"
16//!
17//! tests:
18//!   - desc: "Attack scenario description"
19//!     test_title: "XXXXXX-N"
20//!     stages:
21//!       - input:
22//!           dest_addr: "127.0.0.1"
23//!           headers:
24//!             Host: "localhost"
25//!             User-Agent: "Mozilla/5.0"
26//!           method: "GET"
27//!           port: 80
28//!           uri: "/path?param=<script>alert(1)</script>"
29//!         output:
30//!           status: [200, 403, 404]
31//! ```
32//!
33//! # Usage
34//!
35//! ```bash
36//! mockforge bench spec.yaml --wafbench-dir ./wafbench/REQUEST-941-*
37//! ```
38
39use crate::error::{BenchError, Result};
40use crate::security_payloads::{
41    PayloadLocation as SecurityPayloadLocation, SecurityCategory, SecurityPayload,
42};
43use glob::glob;
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46use std::path::Path;
47
48/// WAFBench test file metadata
49#[derive(Debug, Clone, Deserialize, Serialize)]
50pub struct WafBenchMeta {
51    /// Author of the test file
52    pub author: Option<String>,
53    /// Description of what the tests cover
54    pub description: Option<String>,
55    /// Whether the tests are enabled
56    #[serde(default = "default_enabled")]
57    pub enabled: bool,
58    /// Name of the test file
59    pub name: Option<String>,
60}
61
/// Serde default for [`WafBenchMeta::enabled`]: a file without `enabled:` is treated as enabled.
const fn default_enabled() -> bool {
    const ENABLED_UNLESS_STATED: bool = true;
    ENABLED_UNLESS_STATED
}
65
66/// A single WAFBench test case
67#[derive(Debug, Clone, Deserialize, Serialize)]
68pub struct WafBenchTest {
69    /// Description of the attack scenario
70    pub desc: Option<String>,
71    /// Unique test identifier (e.g., "941100-1")
72    pub test_title: String,
73    /// Test stages (request/response pairs)
74    #[serde(default)]
75    pub stages: Vec<WafBenchStage>,
76}
77
78/// A test stage containing input (request) and expected output (response)
79/// Supports both direct format and CRS v3.3 format with nested `stage:` wrapper
80#[derive(Debug, Clone, Deserialize, Serialize)]
81pub struct WafBenchStage {
82    /// The request configuration (direct format)
83    pub input: Option<WafBenchInput>,
84    /// Expected response (direct format)
85    pub output: Option<WafBenchOutput>,
86    /// Nested stage for CRS v3.3 format (stage: { input: ..., output: ... })
87    pub stage: Option<WafBenchStageInner>,
88}
89
90/// Inner stage structure for CRS v3.3 format
91#[derive(Debug, Clone, Deserialize, Serialize)]
92pub struct WafBenchStageInner {
93    /// The request configuration
94    pub input: WafBenchInput,
95    /// Expected response
96    pub output: Option<WafBenchOutput>,
97}
98
99impl WafBenchStage {
100    /// Get the input from either direct or nested format
101    pub fn get_input(&self) -> Option<&WafBenchInput> {
102        // Prefer nested stage format (CRS v3.3), fall back to direct format
103        if let Some(stage) = &self.stage {
104            Some(&stage.input)
105        } else {
106            self.input.as_ref()
107        }
108    }
109
110    /// Get the output from either direct or nested format
111    pub fn get_output(&self) -> Option<&WafBenchOutput> {
112        // Prefer nested stage format (CRS v3.3), fall back to direct format
113        if let Some(stage) = &self.stage {
114            stage.output.as_ref()
115        } else {
116            self.output.as_ref()
117        }
118    }
119}
120
121/// Request configuration for a WAFBench test
122#[derive(Debug, Clone, Deserialize, Serialize)]
123pub struct WafBenchInput {
124    /// Target address
125    pub dest_addr: Option<String>,
126    /// HTTP headers
127    #[serde(default)]
128    pub headers: HashMap<String, String>,
129    /// HTTP method
130    #[serde(default = "default_method")]
131    pub method: String,
132    /// Target port
133    #[serde(default = "default_port")]
134    pub port: u16,
135    /// Request URI (may contain attack payloads)
136    pub uri: Option<String>,
137    /// Request body data
138    pub data: Option<String>,
139    /// Protocol version
140    pub version: Option<String>,
141}
142
/// Serde default for [`WafBenchInput::method`] when a stage omits `method:`.
fn default_method() -> String {
    String::from("GET")
}
146
/// Serde default for [`WafBenchInput::port`] when a stage omits `port:`.
fn default_port() -> u16 {
    const PLAIN_HTTP_PORT: u16 = 80;
    PLAIN_HTTP_PORT
}
150
151/// Expected response for a WAFBench test
152#[derive(Debug, Clone, Deserialize, Serialize)]
153pub struct WafBenchOutput {
154    /// Expected HTTP status codes (any match is valid)
155    #[serde(default)]
156    pub status: Vec<u16>,
157    /// Expected response headers
158    #[serde(default)]
159    pub response_headers: HashMap<String, String>,
160    /// Log contains patterns (can be string or array in different formats)
161    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
162    pub log_contains: Vec<String>,
163    /// Log does not contain patterns (can be string or array in different formats)
164    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
165    pub no_log_contains: Vec<String>,
166}
167
168/// Deserialize a field that can be either a single string or a Vec of strings
169fn deserialize_string_or_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
170where
171    D: serde::Deserializer<'de>,
172{
173    use serde::de::{self, Visitor};
174
175    struct StringOrVec;
176
177    impl<'de> Visitor<'de> for StringOrVec {
178        type Value = Vec<String>;
179
180        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
181            formatter.write_str("string or array of strings")
182        }
183
184        fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
185        where
186            E: de::Error,
187        {
188            Ok(vec![value.to_string()])
189        }
190
191        fn visit_string<E>(self, value: String) -> std::result::Result<Self::Value, E>
192        where
193            E: de::Error,
194        {
195            Ok(vec![value])
196        }
197
198        fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
199        where
200            A: de::SeqAccess<'de>,
201        {
202            let mut vec = Vec::new();
203            while let Some(value) = seq.next_element::<String>()? {
204                vec.push(value);
205            }
206            Ok(vec)
207        }
208
209        fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
210        where
211            E: de::Error,
212        {
213            Ok(Vec::new())
214        }
215
216        fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
217        where
218            E: de::Error,
219        {
220            Ok(Vec::new())
221        }
222    }
223
224    deserializer.deserialize_any(StringOrVec)
225}
226
227/// Complete WAFBench test file structure
228#[derive(Debug, Clone, Deserialize, Serialize)]
229pub struct WafBenchFile {
230    /// Test file metadata
231    pub meta: WafBenchMeta,
232    /// Test cases
233    #[serde(default)]
234    pub tests: Vec<WafBenchTest>,
235}
236
237/// A parsed WAFBench test case ready for use in security testing
238#[derive(Debug, Clone)]
239pub struct WafBenchTestCase {
240    /// Test identifier
241    pub test_id: String,
242    /// Description
243    pub description: String,
244    /// CRS rule ID (e.g., 941100)
245    pub rule_id: String,
246    /// Security category
247    pub category: SecurityCategory,
248    /// HTTP method
249    pub method: String,
250    /// Attack payloads extracted from the test
251    pub payloads: Vec<WafBenchPayload>,
252    /// Expected to be blocked (403)
253    pub expects_block: bool,
254}
255
256/// A specific payload from a WAFBench test
257#[derive(Debug, Clone)]
258pub struct WafBenchPayload {
259    /// The payload location (uri, header, body)
260    pub location: PayloadLocation,
261    /// The actual payload string
262    pub value: String,
263    /// Header name if location is Header
264    pub header_name: Option<String>,
265}
266
/// The part of an HTTP request that carries a WAFBench payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadLocation {
    /// Carried in the request URI (path and/or query string).
    Uri,
    /// Carried as the value of an HTTP header.
    Header,
    /// Carried in the request body.
    Body,
}

impl std::fmt::Display for PayloadLocation {
    /// Renders the location as a lowercase label: `uri`, `header` or `body`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Uri => "uri",
            Self::Header => "header",
            Self::Body => "body",
        };
        f.write_str(label)
    }
}
287
288/// WAFBench loader and parser
289pub struct WafBenchLoader {
290    /// Loaded test cases
291    test_cases: Vec<WafBenchTestCase>,
292    /// Statistics
293    stats: WafBenchStats,
294}
295
296/// Statistics about loaded WAFBench tests
297#[derive(Debug, Clone, Default)]
298pub struct WafBenchStats {
299    /// Number of files processed
300    pub files_processed: usize,
301    /// Number of test cases loaded
302    pub test_cases_loaded: usize,
303    /// Number of payloads extracted
304    pub payloads_extracted: usize,
305    /// Tests by category
306    pub by_category: HashMap<SecurityCategory, usize>,
307    /// Files that failed to parse
308    pub parse_errors: Vec<String>,
309}
310
311impl WafBenchLoader {
312    /// Create a new empty loader
313    pub fn new() -> Self {
314        Self {
315            test_cases: Vec::new(),
316            stats: WafBenchStats::default(),
317        }
318    }
319
320    /// Load WAFBench tests from a directory pattern (supports glob)
321    ///
322    /// # Arguments
323    /// * `pattern` - Glob pattern like `./wafbench/REQUEST-941-*` or a direct path
324    ///
325    /// # Example
326    /// ```ignore
327    /// let loader = WafBenchLoader::new();
328    /// loader.load_from_pattern("./wafbench/REQUEST-941-APPLICATION-ATTACK-XSS/**/*.yaml")?;
329    /// ```
330    pub fn load_from_pattern(&mut self, pattern: &str) -> Result<()> {
331        // If pattern doesn't contain wildcards, check if it's a file or directory
332        if !pattern.contains('*') && !pattern.contains('?') {
333            let path = Path::new(pattern);
334            if path.is_file() {
335                // Load single file directly
336                return self.load_file(path);
337            } else if path.is_dir() {
338                return self.load_from_directory(path);
339            } else {
340                return Err(BenchError::Other(format!(
341                    "WAFBench path does not exist: {}",
342                    pattern
343                )));
344            }
345        }
346
347        // Use glob to find matching files
348        let entries = glob(pattern).map_err(|e| {
349            BenchError::Other(format!("Invalid WAFBench pattern '{}': {}", pattern, e))
350        })?;
351
352        for entry in entries {
353            match entry {
354                Ok(path) => {
355                    if path.is_file()
356                        && path.extension().is_some_and(|ext| ext == "yaml" || ext == "yml")
357                    {
358                        if let Err(e) = self.load_file(&path) {
359                            self.stats.parse_errors.push(format!("{}: {}", path.display(), e));
360                        }
361                    } else if path.is_dir() {
362                        if let Err(e) = self.load_from_directory(&path) {
363                            self.stats.parse_errors.push(format!("{}: {}", path.display(), e));
364                        }
365                    }
366                }
367                Err(e) => {
368                    self.stats.parse_errors.push(format!("Glob error: {}", e));
369                }
370            }
371        }
372
373        Ok(())
374    }
375
376    /// Load WAFBench tests from a directory (recursive)
377    pub fn load_from_directory(&mut self, dir: &Path) -> Result<()> {
378        if !dir.is_dir() {
379            return Err(BenchError::Other(format!(
380                "WAFBench path is not a directory: {}",
381                dir.display()
382            )));
383        }
384
385        self.load_directory_recursive(dir)?;
386        Ok(())
387    }
388
389    fn load_directory_recursive(&mut self, dir: &Path) -> Result<()> {
390        let entries = std::fs::read_dir(dir)
391            .map_err(|e| BenchError::Other(format!("Failed to read WAFBench directory: {}", e)))?;
392
393        for entry in entries.flatten() {
394            let path = entry.path();
395            if path.is_dir() {
396                // Recurse into subdirectories
397                self.load_directory_recursive(&path)?;
398            } else if path.extension().is_some_and(|ext| ext == "yaml" || ext == "yml") {
399                if let Err(e) = self.load_file(&path) {
400                    self.stats.parse_errors.push(format!("{}: {}", path.display(), e));
401                }
402            }
403        }
404
405        Ok(())
406    }
407
408    /// Load a single WAFBench YAML file
409    pub fn load_file(&mut self, path: &Path) -> Result<()> {
410        let content = std::fs::read_to_string(path).map_err(|e| {
411            BenchError::Other(format!("Failed to read WAFBench file {}: {}", path.display(), e))
412        })?;
413
414        let wafbench_file: WafBenchFile = serde_yaml::from_str(&content).map_err(|e| {
415            BenchError::Other(format!("Failed to parse WAFBench YAML {}: {}", path.display(), e))
416        })?;
417
418        // Skip disabled test files
419        if !wafbench_file.meta.enabled {
420            return Ok(());
421        }
422
423        self.stats.files_processed += 1;
424
425        // Determine the rule category from the file path or name
426        let category = self.detect_category(path, &wafbench_file.meta);
427
428        // Parse each test case
429        for test in wafbench_file.tests {
430            if let Some(test_case) = self.parse_test_case(&test, category) {
431                self.stats.payloads_extracted += test_case.payloads.len();
432                *self.stats.by_category.entry(category).or_insert(0) += 1;
433                self.test_cases.push(test_case);
434                self.stats.test_cases_loaded += 1;
435            }
436        }
437
438        Ok(())
439    }
440
441    /// Detect the security category from the file path
442    fn detect_category(&self, path: &Path, _meta: &WafBenchMeta) -> SecurityCategory {
443        let path_str = path.to_string_lossy().to_uppercase();
444
445        if path_str.contains("XSS") || path_str.contains("941") {
446            SecurityCategory::Xss
447        } else if path_str.contains("SQLI") || path_str.contains("942") {
448            SecurityCategory::SqlInjection
449        } else if path_str.contains("RCE") || path_str.contains("932") {
450            SecurityCategory::CommandInjection
451        } else if path_str.contains("LFI") || path_str.contains("930") {
452            SecurityCategory::PathTraversal
453        } else if path_str.contains("LDAP") {
454            SecurityCategory::LdapInjection
455        } else if path_str.contains("XXE") || path_str.contains("XML") {
456            SecurityCategory::Xxe
457        } else if path_str.contains("TEMPLATE") || path_str.contains("SSTI") {
458            SecurityCategory::Ssti
459        } else {
460            // Default to XSS as it's the most common in WAFBench
461            SecurityCategory::Xss
462        }
463    }
464
465    /// Parse a single test case into our format
466    fn parse_test_case(
467        &self,
468        test: &WafBenchTest,
469        category: SecurityCategory,
470    ) -> Option<WafBenchTestCase> {
471        // Extract rule ID from test_title (e.g., "941100-1" -> "941100")
472        let rule_id = test.test_title.split('-').next().unwrap_or(&test.test_title).to_string();
473
474        let mut payloads = Vec::new();
475        let mut method = "GET".to_string();
476        let mut expects_block = false;
477
478        for stage in &test.stages {
479            // Get input from either direct or nested format (CRS v3.3 compatibility)
480            let Some(input) = stage.get_input() else {
481                continue;
482            };
483
484            method = input.method.clone();
485
486            // Check if this test expects a block (403)
487            if let Some(output) = stage.get_output() {
488                if output.status.contains(&403) {
489                    expects_block = true;
490                }
491            }
492
493            // Extract payload from URI — CRS test files are attack payloads by
494            // definition, so we accept all values without filtering. Previously
495            // a narrow looks_like_attack() check discarded exotic payloads like
496            // VML, VBScript, UTF-7, JSFuck, and bracket-notation XSS.
497            if let Some(uri) = &input.uri {
498                if !uri.is_empty() {
499                    payloads.push(WafBenchPayload {
500                        location: PayloadLocation::Uri,
501                        value: uri.clone(),
502                        header_name: None,
503                    });
504                }
505            }
506
507            // Extract payloads from headers
508            for (header_name, header_value) in &input.headers {
509                if !header_value.is_empty() {
510                    payloads.push(WafBenchPayload {
511                        location: PayloadLocation::Header,
512                        value: header_value.clone(),
513                        header_name: Some(header_name.clone()),
514                    });
515                }
516            }
517
518            // Extract payload from body
519            if let Some(data) = &input.data {
520                if !data.is_empty() {
521                    payloads.push(WafBenchPayload {
522                        location: PayloadLocation::Body,
523                        value: data.clone(),
524                        header_name: None,
525                    });
526                }
527            }
528        }
529
530        // If no payloads found, still include the test but with full URI as payload
531        if payloads.is_empty() {
532            if let Some(stage) = test.stages.first() {
533                if let Some(input) = stage.get_input() {
534                    if let Some(uri) = &input.uri {
535                        payloads.push(WafBenchPayload {
536                            location: PayloadLocation::Uri,
537                            value: uri.clone(),
538                            header_name: None,
539                        });
540                    }
541                }
542            }
543        }
544
545        if payloads.is_empty() {
546            return None;
547        }
548
549        let description = test.desc.clone().unwrap_or_else(|| format!("CRS Rule {} test", rule_id));
550
551        Some(WafBenchTestCase {
552            test_id: test.test_title.clone(),
553            description,
554            rule_id,
555            category,
556            method,
557            payloads,
558            expects_block,
559        })
560    }
561
562    /// Check if a string looks like an attack payload (used in tests)
563    #[cfg(test)]
564    fn looks_like_attack(&self, s: &str) -> bool {
565        // Common attack patterns
566        let attack_patterns = [
567            "<script",
568            "javascript:",
569            "onerror=",
570            "onload=",
571            "onclick=",
572            "onfocus=",
573            "onmouseover=",
574            "eval(",
575            "alert(",
576            "document.",
577            "window.",
578            "'--",
579            "' OR ",
580            "' AND ",
581            "1=1",
582            "UNION SELECT",
583            "CONCAT(",
584            "CHAR(",
585            "../",
586            "..\\",
587            "/etc/passwd",
588            "cmd.exe",
589            "powershell",
590            "; ls",
591            "| cat",
592            "${",
593            "{{",
594            "<%",
595            "<?",
596            "<!ENTITY",
597            "SYSTEM \"",
598        ];
599
600        let lower = s.to_lowercase();
601        attack_patterns.iter().any(|p| lower.contains(&p.to_lowercase()))
602    }
603
604    /// Get all loaded test cases
605    pub fn test_cases(&self) -> &[WafBenchTestCase] {
606        &self.test_cases
607    }
608
609    /// Get statistics about loaded tests
610    pub fn stats(&self) -> &WafBenchStats {
611        &self.stats
612    }
613
614    /// Decode a form-URL-encoded body payload.
615    /// Replaces `+` with space (form-encoding convention), then decodes `%XX` sequences.
616    /// Strips form field name prefix (e.g., `var=;;dd foo bar` → `;;dd foo bar`)
617    /// since JSON injection puts the value in a field, not the form key.
618    fn decode_form_encoded_body(value: &str) -> String {
619        // Replace + with space first (form-encoding convention)
620        let plus_decoded = value.replace('+', " ");
621        // Then decode %XX sequences
622        let decoded = urlencoding::decode(&plus_decoded)
623            .map(|s| s.into_owned())
624            .unwrap_or(plus_decoded);
625        // Strip form field name prefix (e.g., "var=value" → "value")
626        // CRS test data like "var=;;dd foo bar" has the form key included,
627        // but we inject only the value into a JSON field.
628        Self::strip_form_key(&decoded)
629    }
630
631    /// Strip a single leading form key from a form-encoded value.
632    /// `"var=;;dd foo bar"` → `";;dd foo bar"`
633    /// `"pay=exec (@\n"` → `"exec (@\n"`
634    /// Values without `=` or starting with special chars are returned as-is.
635    fn strip_form_key(value: &str) -> String {
636        // Only strip if the prefix before the first = looks like a form field name
637        // (alphanumeric/underscore chars). Don't strip if the = is part of the attack.
638        if let Some(eq_pos) = value.find('=') {
639            let key = &value[..eq_pos];
640            // Form field names are alphanumeric with underscores
641            if !key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
642                return value[eq_pos + 1..].to_string();
643            }
644        }
645        value.to_string()
646    }
647
648    /// Normalize a form body value to valid `application/x-www-form-urlencoded` format.
649    ///
650    /// CRS YAML `data` fields may be pre-encoded (`var=%3B%3Bdd+foo+bar`) or decoded
651    /// (`var=;;dd foo bar`). This function ensures the output is always properly encoded
652    /// so WAFs can parse it into ARGS and fire rules like 942432.
653    ///
654    /// Strategy: decode fully first (handling `+` as space and `%XX` sequences), then
655    /// re-encode. Pre-encoded input round-trips correctly; decoded input gets encoded.
656    fn ensure_form_encoded(value: &str) -> String {
657        value
658            .split('&')
659            .map(|pair| {
660                if let Some(eq_pos) = pair.find('=') {
661                    let key = &pair[..eq_pos];
662                    let val = &pair[eq_pos + 1..];
663                    // Decode: + → space, then %XX → chars
664                    let key_plus = key.replace('+', " ");
665                    let val_plus = val.replace('+', " ");
666                    let decoded_key = urlencoding::decode(&key_plus).unwrap_or(key.into());
667                    let decoded_val = urlencoding::decode(&val_plus).unwrap_or(val.into());
668                    // Re-encode with form-encoding (spaces as +)
669                    let enc_key = urlencoding::encode(&decoded_key).replace("%20", "+");
670                    let enc_val = urlencoding::encode(&decoded_val).replace("%20", "+");
671                    format!("{enc_key}={enc_val}")
672                } else {
673                    // No key=value structure — encode the whole thing
674                    let pair_plus = pair.replace('+', " ");
675                    let decoded = urlencoding::decode(&pair_plus).unwrap_or(pair.into());
676                    urlencoding::encode(&decoded).replace("%20", "+").to_string()
677                }
678            })
679            .collect::<Vec<_>>()
680            .join("&")
681    }
682
683    /// Convert loaded tests to SecurityPayload format for use with existing security testing
684    pub fn to_security_payloads(&self) -> Vec<SecurityPayload> {
685        let mut payloads = Vec::new();
686
687        for test_case in &self.test_cases {
688            // Assign group_id when a test case has multiple payloads
689            let group_id = if test_case.payloads.len() > 1 {
690                Some(test_case.test_id.clone())
691            } else {
692                None
693            };
694
695            for payload in &test_case.payloads {
696                // Extract just the attack payload part if possible
697                let payload_str = match payload.location {
698                    PayloadLocation::Body => {
699                        // Form-URL-decode body payloads so WAFs see the real characters
700                        Self::decode_form_encoded_body(&payload.value)
701                    }
702                    PayloadLocation::Uri => {
703                        // Extract attack payload from URI, URL-decode, strip path prefix
704                        self.extract_uri_payload(&payload.value)
705                    }
706                    PayloadLocation::Header => {
707                        // Headers are used as-is (Cookie values, User-Agent, etc.)
708                        payload.value.clone()
709                    }
710                };
711
712                // Convert local PayloadLocation to SecurityPayloadLocation
713                let location = match payload.location {
714                    PayloadLocation::Uri => SecurityPayloadLocation::Uri,
715                    PayloadLocation::Header => SecurityPayloadLocation::Header,
716                    PayloadLocation::Body => SecurityPayloadLocation::Body,
717                };
718
719                let mut sec_payload = SecurityPayload::new(
720                    payload_str,
721                    test_case.category,
722                    format!(
723                        "[WAFBench {}] {} ({})",
724                        test_case.rule_id, test_case.description, payload.location
725                    ),
726                )
727                .high_risk()
728                .with_location(location);
729
730                // Add header name for header payloads
731                if let Some(header_name) = &payload.header_name {
732                    sec_payload = sec_payload.with_header_name(header_name.clone());
733                }
734
735                // Add group ID for multi-part test cases
736                if let Some(gid) = &group_id {
737                    sec_payload = sec_payload.with_group_id(gid.clone());
738                }
739
740                // URI payloads without '?' are path-only attacks (e.g., 942101: POST /1234%20OR%201=1)
741                // These need to replace the request path so WAF inspects via REQUEST_FILENAME
742                if payload.location == PayloadLocation::Uri && !payload.value.contains('?') {
743                    sec_payload = sec_payload.with_inject_as_path();
744                }
745
746                // Body payloads: normalize to valid form-encoded format for WAF ARGS parsing
747                // (e.g., 942432: data "var=%3B%3Bdd+foo+bar" or decoded "var=;;dd foo bar")
748                if payload.location == PayloadLocation::Body {
749                    sec_payload = sec_payload
750                        .with_form_encoded_body(Self::ensure_form_encoded(&payload.value));
751                }
752
753                payloads.push(sec_payload);
754            }
755        }
756
757        payloads
758    }
759
760    /// Extract the actual attack payload from a URI.
761    ///
762    /// For URIs with query parameters (e.g., `/?var=EXECUTE%20IMMEDIATE%20%22`),
763    /// extracts and URL-decodes the first parameter value.
764    ///
765    /// For path-only URIs (e.g., `/1234%20OR%201=1`), URL-decodes the path and
766    /// strips the leading `/` which is a URI artifact, not part of the attack.
767    fn extract_uri_payload(&self, value: &str) -> String {
768        // If it's a URI with query params, extract the first parameter value
769        // (URL-decoded). CRS test files put the attack in query params.
770        if value.contains('?') {
771            if let Some(query) = value.split('?').nth(1) {
772                for param in query.split('&') {
773                    if let Some(val) = param.split('=').nth(1) {
774                        let decoded = urlencoding::decode(val).unwrap_or_else(|_| val.into());
775                        if !decoded.is_empty() {
776                            return decoded.to_string();
777                        }
778                    }
779                }
780            }
781        }
782
783        // For path-only URIs, URL-decode and strip leading /
784        // e.g., /1234%20OR%201=1 → 1234 OR 1=1
785        let decoded = urlencoding::decode(value)
786            .map(|s| s.into_owned())
787            .unwrap_or_else(|_| value.to_string());
788        let trimmed = decoded.trim_start_matches('/');
789        if trimmed.is_empty() {
790            // Don't return empty string for bare "/" paths
791            return decoded;
792        }
793        trimmed.to_string()
794    }
795}
796
797impl Default for WafBenchLoader {
798    fn default() -> Self {
799        Self::new()
800    }
801}
802
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a WAFBench YAML fixture, panicking with a descriptive message if it is malformed.
    fn parse_file(yaml: &str) -> WafBenchFile {
        serde_yaml::from_str(yaml).expect("test fixture should be valid WAFBench YAML")
    }

    /// Load every test case from a YAML fixture and convert it into security payloads.
    ///
    /// Bumps `files_processed` once (as if one file had been read). Pushes each case that
    /// `parse_test_case` accepts onto `test_cases`. Returns the result of
    /// `to_security_payloads()`.
    fn payloads_from_yaml(yaml: &str, category: SecurityCategory) -> Vec<SecurityPayload> {
        let file = parse_file(yaml);
        let mut loader = WafBenchLoader::new();
        loader.stats.files_processed += 1;
        for test in &file.tests {
            if let Some(test_case) = loader.parse_test_case(test, category) {
                loader.test_cases.push(test_case);
            }
        }
        loader.to_security_payloads()
    }

    #[test]
    fn test_parse_wafbench_yaml() {
        let yaml = r#"
meta:
  author: test
  description: Test XSS rules
  enabled: true
  name: test.yaml

tests:
  - desc: "XSS in URI parameter"
    test_title: "941100-1"
    stages:
      - input:
          dest_addr: "127.0.0.1"
          headers:
            Host: "localhost"
            User-Agent: "Mozilla/5.0"
          method: "GET"
          port: 80
          uri: "/test?param=<script>alert(1)</script>"
        output:
          status: [403]
"#;

        let file = parse_file(yaml);
        assert!(file.meta.enabled);
        assert_eq!(file.tests.len(), 1);
        assert_eq!(file.tests[0].test_title, "941100-1");
    }

    #[test]
    fn test_detect_category() {
        let loader = WafBenchLoader::new();
        let meta = WafBenchMeta {
            author: None,
            description: None,
            enabled: true,
            name: None,
        };

        assert_eq!(
            loader.detect_category(Path::new("/wafbench/REQUEST-941-XSS/test.yaml"), &meta),
            SecurityCategory::Xss
        );

        assert_eq!(
            loader.detect_category(Path::new("/wafbench/REQUEST-942-SQLI/test.yaml"), &meta),
            SecurityCategory::SqlInjection
        );
    }

    #[test]
    fn test_looks_like_attack() {
        let loader = WafBenchLoader::new();

        assert!(loader.looks_like_attack("<script>alert(1)</script>"));
        assert!(loader.looks_like_attack("' OR '1'='1"));
        assert!(loader.looks_like_attack("../../../etc/passwd"));
        assert!(loader.looks_like_attack("; ls -la"));
        assert!(!loader.looks_like_attack("normal text"));
        assert!(!loader.looks_like_attack("hello world"));
    }

    #[test]
    fn test_extract_uri_payload_with_query_params() {
        let loader = WafBenchLoader::new();

        // URI with query params: extracts and decodes the parameter value
        let uri = "/test?param=%3Cscript%3Ealert(1)%3C/script%3E";
        let payload = loader.extract_uri_payload(uri);
        assert_eq!(payload, "<script>alert(1)</script>");
    }

    #[test]
    fn test_extract_uri_payload_path_only() {
        let loader = WafBenchLoader::new();

        // Path-only URI: URL-decodes and strips leading /
        let uri = "/1234%20OR%201=1";
        let payload = loader.extract_uri_payload(uri);
        assert_eq!(payload, "1234 OR 1=1");

        // Path with quotes and special chars
        let uri2 = "/foo')waitfor%20delay'5%3a0%3a20'--";
        let payload2 = loader.extract_uri_payload(uri2);
        assert_eq!(payload2, "foo')waitfor delay'5:0:20'--");

        // Bare slash returns "/" (not empty)
        let uri3 = "/";
        let payload3 = loader.extract_uri_payload(uri3);
        assert_eq!(payload3, "/");
    }

    #[test]
    fn test_group_id_assigned_for_multi_part_test_cases() {
        let yaml = r#"
meta:
  author: test
  description: Multi-part test
  enabled: true
  name: test.yaml

tests:
  - desc: "Multi-part attack with URI and header"
    test_title: "942290-1"
    stages:
      - input:
          dest_addr: "127.0.0.1"
          headers:
            Host: "localhost"
            User-Agent: "ModSecurity CRS 3 Tests"
          method: "GET"
          port: 80
          uri: "/test?param=attack"
        output:
          status: [403]
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::SqlInjection);
        // The test case carries a URI and two headers. At least two of them are expected
        // to become payloads, and every payload produced must share the test's group_id.
        assert!(payloads.len() >= 2, "Should have at least 2 payloads");
        let group_ids: Vec<_> = payloads.iter().map(|p| p.group_id.clone()).collect();
        assert!(
            group_ids.iter().all(|g| g.is_some()),
            "All payloads in multi-part test should have group_id"
        );
        assert!(
            group_ids.iter().all(|g| g.as_deref() == Some("942290-1")),
            "All payloads should share the same group_id"
        );
    }

    #[test]
    fn test_single_payload_no_group_id() {
        let yaml = r#"
meta:
  author: test
  description: Single payload test
  enabled: true
  name: test.yaml

tests:
  - desc: "Simple XSS"
    test_title: "941100-1"
    stages:
      - input:
          dest_addr: "127.0.0.1"
          headers: {}
          method: "GET"
          port: 80
          uri: "/test?param=<script>alert(1)</script>"
        output:
          status: [403]
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::Xss);
        assert_eq!(payloads.len(), 1, "Should have exactly 1 payload");
        assert!(payloads[0].group_id.is_none(), "Single-payload test should NOT have group_id");
    }

    #[test]
    fn test_body_payload_form_url_decoded() {
        let yaml = r#"
meta:
  author: test
  description: Body payload test
  enabled: true
  name: test.yaml

tests:
  - desc: "SQL injection in body"
    test_title: "942240-1"
    stages:
      - stage:
          input:
            dest_addr: 127.0.0.1
            headers:
              Host: localhost
            method: POST
            port: 80
            uri: "/"
            data: "%22+WAITFOR+DELAY+%270%3A0%3A5%27"
          output:
            log_contains: id "942240"
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::SqlInjection);
        // Find the body payload
        let body_payload = payloads
            .iter()
            .find(|p| p.location == SecurityPayloadLocation::Body)
            .expect("Should have a body payload");

        // The body payload should be form-URL-decoded
        assert!(
            body_payload.payload.contains('"'),
            "Body payload should have decoded %22 to double-quote: {}",
            body_payload.payload
        );
        assert!(
            body_payload.payload.contains(' '),
            "Body payload should have decoded + to space: {}",
            body_payload.payload
        );
        assert!(
            !body_payload.payload.contains("%22"),
            "Body payload should NOT contain literal %22: {}",
            body_payload.payload
        );
    }

    #[test]
    fn test_decode_form_encoded_body() {
        // Basic decoding
        assert_eq!(
            WafBenchLoader::decode_form_encoded_body("%22+WAITFOR+DELAY+%27%0A"),
            "\" WAITFOR DELAY '\n"
        );
        assert_eq!(WafBenchLoader::decode_form_encoded_body("normal+text"), "normal text");
        assert_eq!(
            WafBenchLoader::decode_form_encoded_body("no+encoding+needed"),
            "no encoding needed"
        );
        // Form key stripping: var=value → value
        assert_eq!(
            WafBenchLoader::decode_form_encoded_body("var%3D%3B%3Bdd+foo+bar"),
            ";;dd foo bar"
        );
        // Form key stripping: pay=exec → exec
        assert_eq!(WafBenchLoader::decode_form_encoded_body("pay%3Dexec+%28%40%0A"), "exec (@\n");
        // No form key: starts with special char → returned as-is
        assert_eq!(WafBenchLoader::decode_form_encoded_body("%22+WAITFOR"), "\" WAITFOR");
    }

    #[test]
    fn test_strip_form_key() {
        // Standard form key=value
        assert_eq!(WafBenchLoader::strip_form_key("var=;;dd foo bar"), ";;dd foo bar");
        assert_eq!(WafBenchLoader::strip_form_key("pay=exec (@\n"), "exec (@\n");
        assert_eq!(WafBenchLoader::strip_form_key("pay=DECLARE/**/@x\n"), "DECLARE/**/@x\n");
        // No form key (starts with special char)
        assert_eq!(WafBenchLoader::strip_form_key("\" WAITFOR DELAY '\n"), "\" WAITFOR DELAY '\n");
        // = inside attack payload, key is not alphanumeric
        assert_eq!(WafBenchLoader::strip_form_key("' OR 1=1"), "' OR 1=1");
        // Empty input
        assert_eq!(WafBenchLoader::strip_form_key(""), "");
        // Only key, no value
        assert_eq!(WafBenchLoader::strip_form_key("var="), "");
    }

    #[test]
    fn test_ensure_form_encoded() {
        // Pre-encoded input round-trips correctly
        assert_eq!(
            WafBenchLoader::ensure_form_encoded("var=%3B%3Bdd+foo+bar"),
            "var=%3B%3Bdd+foo+bar"
        );
        // Decoded input gets properly encoded
        assert_eq!(WafBenchLoader::ensure_form_encoded("var=;;dd foo bar"), "var=%3B%3Bdd+foo+bar");
        // Multi-field form
        assert_eq!(
            WafBenchLoader::ensure_form_encoded("var=-------------------&var2=whatever"),
            "var=-------------------&var2=whatever"
        );
        // Already-encoded multi-field
        assert_eq!(
            WafBenchLoader::ensure_form_encoded("key=%22value%22&other=test+data"),
            "key=%22value%22&other=test+data"
        );
        // Decoded multi-field
        assert_eq!(
            WafBenchLoader::ensure_form_encoded("key=\"value\"&other=test data"),
            "key=%22value%22&other=test+data"
        );
        // No key=value structure
        assert_eq!(WafBenchLoader::ensure_form_encoded("plain text"), "plain+text");
        // Empty string
        assert_eq!(WafBenchLoader::ensure_form_encoded(""), "");
    }

    #[test]
    fn test_uri_path_only_gets_inject_as_path() {
        let yaml = r#"
meta:
  author: test
  description: Path injection test
  enabled: true
  name: test.yaml

tests:
  - desc: "Path-based SQL injection"
    test_title: "942101-1"
    stages:
      - stage:
          input:
            dest_addr: 127.0.0.1
            headers:
              Host: localhost
            method: POST
            port: 80
            uri: "/1234%20OR%201=1"
          output:
            log_contains: id "942101"
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::SqlInjection);
        let uri_payload = payloads
            .iter()
            .find(|p| p.location == SecurityPayloadLocation::Uri)
            .expect("Should have URI payload");

        assert_eq!(
            uri_payload.inject_as_path,
            Some(true),
            "Path-only URI should have inject_as_path=true"
        );
    }

    #[test]
    fn test_uri_with_query_no_inject_as_path() {
        let yaml = r#"
meta:
  author: test
  description: Query param test
  enabled: true
  name: test.yaml

tests:
  - desc: "Query-param SQL injection"
    test_title: "942100-1"
    stages:
      - stage:
          input:
            dest_addr: 127.0.0.1
            headers: {}
            method: GET
            port: 80
            uri: "/test?param=1+OR+1%3D1"
          output:
            log_contains: id "942100"
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::SqlInjection);
        let uri_payload = payloads
            .iter()
            .find(|p| p.location == SecurityPayloadLocation::Uri)
            .expect("Should have URI payload");

        assert!(
            uri_payload.inject_as_path.is_none(),
            "URI with query params should NOT have inject_as_path"
        );
    }

    #[test]
    fn test_body_payload_gets_form_encoded_body() {
        let yaml = r#"
meta:
  author: test
  description: Form body test
  enabled: true
  name: test.yaml

tests:
  - desc: "Form-encoded body attack"
    test_title: "942432-1"
    stages:
      - stage:
          input:
            dest_addr: 127.0.0.1
            headers:
              Host: localhost
            method: POST
            port: 80
            uri: "/"
            data: "var=%3B%3Bdd+foo+bar"
          output:
            log_contains: id "942432"
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::SqlInjection);
        let body_payload = payloads
            .iter()
            .find(|p| p.location == SecurityPayloadLocation::Body)
            .expect("Should have body payload");

        // Pre-encoded CRS YAML value round-trips through ensure_form_encoded
        assert_eq!(
            body_payload.form_encoded_body.as_deref(),
            Some("var=%3B%3Bdd+foo+bar"),
            "form_encoded_body should be set and properly URL-encoded"
        );
    }

    #[test]
    fn test_body_payload_decoded_yaml_gets_encoded() {
        // CRS YAML with already-decoded data value (some CRS distributions)
        let yaml = r#"
meta:
  author: test
  description: Form body test (decoded)
  enabled: true
  name: test.yaml

tests:
  - desc: "Form-encoded body attack (decoded)"
    test_title: "942432-2"
    stages:
      - stage:
          input:
            dest_addr: 127.0.0.1
            headers:
              Host: localhost
            method: POST
            port: 80
            uri: "/"
            data: "var=;;dd foo bar"
          output:
            log_contains: id "942432"
"#;

        let payloads = payloads_from_yaml(yaml, SecurityCategory::SqlInjection);
        let body_payload = payloads
            .iter()
            .find(|p| p.location == SecurityPayloadLocation::Body)
            .expect("Should have body payload");

        // Decoded input must be re-encoded for WAF ARGS parsing
        let encoded = body_payload
            .form_encoded_body
            .as_deref()
            .expect("Body payload should have form_encoded_body set");
        assert!(
            encoded.contains("%3B%3B") || encoded.contains("%3b%3b"),
            "Semicolons must be URL-encoded: {encoded}"
        );
        assert!(!encoded.contains(' '), "Spaces must be encoded as + in form body: {encoded}");
        assert!(encoded.starts_with("var="), "Form key must be preserved: {encoded}");
    }

    #[test]
    fn test_parse_crs_v33_format() {
        // CRS v3.3/master uses a nested stage: wrapper
        let yaml = r#"
meta:
  author: "Christian Folini"
  description: Various SQL injection tests
  enabled: true
  name: 942100.yaml

tests:
  - test_title: 942100-1
    desc: "Simple SQL Injection"
    stages:
      - stage:
          input:
            dest_addr: 127.0.0.1
            headers:
              Host: localhost
            method: POST
            port: 80
            uri: "/"
            data: "var=1234 OR 1=1"
            version: HTTP/1.0
          output:
            log_contains: id "942100"
"#;

        let file = parse_file(yaml);
        assert!(file.meta.enabled);
        assert_eq!(file.tests.len(), 1);
        assert_eq!(file.tests[0].test_title, "942100-1");

        // Verify we can get the input from nested format
        let stage = &file.tests[0].stages[0];
        let input = stage.get_input().expect("Should have input");
        assert_eq!(input.method, "POST");
        assert_eq!(input.data.as_deref(), Some("var=1234 OR 1=1"));
    }
}
1365        let input = stage.get_input().expect("Should have input");
1366        assert_eq!(input.method, "POST");
1367        assert_eq!(input.data.as_deref(), Some("var=1234 OR 1=1"));
1368    }
1369}