use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet};
/// Attack-class label for a payload, stored as a plain string.
///
/// [`PayloadClass::new`] trims the label and ASCII-lowercases it before
/// storing it. The inner field is public, so a value constructed directly
/// as `PayloadClass(..)` is stored verbatim without that normalisation.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PayloadClass(pub String);
impl PayloadClass {
    /// Ordered signature table used by [`PayloadClass::from_payload`].
    ///
    /// Each entry is `(class label, substrings)`. Entries are tried top to
    /// bottom and the first entry with any matching substring wins, so the
    /// order is significant: the specific `log4shell` and `xxe` markers are
    /// listed *before* the generic `ldap` / `ssrf` URL schemes. Otherwise a
    /// payload such as `${jndi:ldap://host/a}` would be labelled `ldap`, and
    /// an XML entity pointing at an `http://` URL would be labelled `ssrf`.
    const SIGNATURES: &'static [(&'static str, &'static [&'static str])] = &[
        (
            "sql",
            &[
                "select", "union", "insert", "update", "delete", "drop", "' or ", "or 1=1",
            ],
        ),
        (
            "xss",
            &["<script", "onerror", "onload", "javascript:", "alert("],
        ),
        ("path", &["../", "..\\", "%2e%2e", "etc/passwd"]),
        ("cmdi", &["$(", "`", "|bash", "cmd.exe", "/bin/sh"]),
        ("ssti", &["{{", "{%", "#{", "${'"]),
        ("log4shell", &["${jndi:", "log4j", "log4shell"]),
        ("xxe", &["<!entity", "<!doctype", "xxe"]),
        ("ldap", &["ldap://", "(uid=", "(cn="]),
        ("ssrf", &["http://", "https://", "ssrf"]),
    ];

    /// Builds a class from a raw label, trimming surrounding whitespace and
    /// folding ASCII letters to lowercase.
    #[must_use]
    pub fn new(raw: &str) -> Self {
        Self(raw.trim().to_ascii_lowercase())
    }

    /// Classifies a payload by case-insensitive substring matching.
    ///
    /// The payload is ASCII-lowercased and checked against
    /// `SIGNATURES` in order; the label of the first entry with a matching
    /// substring is returned. A payload that matches nothing is classified
    /// as `"unknown"`.
    #[must_use]
    pub fn from_payload(payload: &str) -> Self {
        let lower = payload.to_ascii_lowercase();
        for &(label, needles) in Self::SIGNATURES {
            if needles.iter().any(|needle| lower.contains(*needle)) {
                return Self::new(label);
            }
        }
        Self::new("unknown")
    }

    /// Returns the stored label as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
/// Rule identifier stored as a plain string.
///
/// [`RuleId::new`] trims the input, ASCII-lowercases it and strips one
/// leading `rule_` or `rule-` prefix. The inner field is public, so a value
/// constructed directly as `RuleId(..)` is stored verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct RuleId(pub String);
impl RuleId {
    /// Normalises a raw rule identifier.
    ///
    /// Surrounding whitespace is trimmed, ASCII letters are lowercased, and
    /// a single leading `rule_` or `rule-` prefix is removed if present.
    /// Any other text is kept as-is.
    #[must_use]
    pub fn new(raw: &str) -> Self {
        let normalised = raw.trim().to_ascii_lowercase();
        // The two prefixes are mutually exclusive, so at most one can match.
        let bare = ["rule_", "rule-"]
            .iter()
            .find_map(|prefix| normalised.strip_prefix(*prefix))
            .unwrap_or(normalised.as_str());
        Self(bare.to_owned())
    }

    /// Returns the normalised identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl Default for RuleId {
fn default() -> Self {
Self(String::new())
}
}
/// Accumulated record of which rules fired for which payloads.
///
/// Both maps are filled by `RuleCoverage::record`; being `BTreeMap`s, they
/// iterate in sorted key order.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct RuleCoverage {
// Rule id -> payload fingerprints recorded against that rule. A fingerprint
// is the payload's first 64 non-control characters.
pub by_rule: BTreeMap<RuleId, BTreeSet<String>>,
// Payload class -> rule ids seen for that class. Payloads recorded without
// a rule id contribute the `__unblocked__` sentinel id here.
pub by_class: BTreeMap<PayloadClass, BTreeSet<RuleId>>,
}
impl RuleCoverage {
    /// Reserved rule id stored in `by_class` for payloads recorded without a
    /// rule id. It is excluded from every "triggered rule" count.
    const UNBLOCKED: &'static str = "__unblocked__";

    /// Creates an empty coverage tracker.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one observation.
    ///
    /// The payload is classified with [`PayloadClass::from_payload`].
    ///
    /// * `rule_id = Some(raw)`: the id is normalised with [`RuleId::new`],
    ///   the payload fingerprint (its first 64 non-control characters) is
    ///   added to `by_rule` under that id, and the id is added to the
    ///   class's set in `by_class`.
    /// * `rule_id = None`: only `by_class` is touched; the `__unblocked__`
    ///   sentinel id is added to the class's set.
    pub fn record(&mut self, payload: &str, rule_id: Option<&str>) {
        let cls = PayloadClass::from_payload(payload);
        let rid = match rule_id {
            Some(raw) => {
                let rid = RuleId::new(raw);
                // The fingerprint is only stored for rule hits, so it is
                // only built on this path.
                let fingerprint: String = payload
                    .chars()
                    .filter(|c| !c.is_control())
                    .take(64)
                    .collect();
                self.by_rule
                    .entry(rid.clone())
                    .or_default()
                    .insert(fingerprint);
                rid
            }
            None => RuleId::new(Self::UNBLOCKED),
        };
        self.by_class.entry(cls).or_default().insert(rid);
    }

    /// Renders a plain-text report, one item per line, joined with `\n`.
    ///
    /// Layout: a header with the number of triggered rules (as returned by
    /// [`RuleCoverage::rule_count`], so the sentinel is never counted), the
    /// number of payload classes, a `rule_id<TAB>payloads_observed` table
    /// with one row per `by_rule` entry, and a per-class summary listing
    /// every rule id stored for that class (including the sentinel).
    #[must_use]
    pub fn coverage_report(&self) -> String {
        // 3 header lines + 1 section marker + one line per rule and per class.
        let mut lines = Vec::with_capacity(self.by_rule.len() + self.by_class.len() + 4);
        lines.push(format!(
            "# wafrift rule-coverage report — {} distinct rules triggered",
            self.rule_count()
        ));
        lines.push(format!(
            "# payload classes observed: {}",
            self.by_class.len()
        ));
        lines.push("# rule_id\tpayloads_observed".to_string());
        lines.extend(
            self.by_rule
                .iter()
                .map(|(rule_id, payloads)| format!("{}\t{}", rule_id.as_str(), payloads.len())),
        );
        lines.push("# per-class summary".to_string());
        for (cls, rules) in &self.by_class {
            let ids: Vec<&str> = rules.iter().map(RuleId::as_str).collect();
            lines.push(format!(
                "# {}: {} rule(s) — {}",
                cls.as_str(),
                rules.len(),
                ids.join(", ")
            ));
        }
        lines.join("\n")
    }

    /// Number of distinct rule ids in `by_rule`, excluding the sentinel.
    #[must_use]
    pub fn rule_count(&self) -> usize {
        self.triggered().count()
    }

    /// Rule ids in `by_rule`, in sorted order, excluding the sentinel.
    #[must_use]
    pub fn triggered_rules(&self) -> Vec<&RuleId> {
        self.triggered().collect()
    }

    /// Serialises the tracker as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if serialisation fails.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }

    /// Restores a tracker from JSON text.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `json` cannot be deserialised into
    /// a `RuleCoverage`.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(json)
    }

    /// Iterates over the keys of `by_rule` that are not the sentinel id.
    fn triggered(&self) -> impl Iterator<Item = &RuleId> {
        self.by_rule.keys().filter(|rule| rule.0 != Self::UNBLOCKED)
    }
}
/// Builds the `(payload class, rule id)` pair describing one observation.
///
/// The class comes from [`PayloadClass::from_payload`]; the rule id, when
/// given, is normalised with [`RuleId::new`], and `None` is passed through
/// unchanged. NOTE(review): the name suggests the pair is used as a
/// MAP-Elites behaviour descriptor — confirm against callers.
#[must_use]
pub fn map_elites_descriptor(
    payload: &str,
    rule_id: Option<&str>,
) -> (PayloadClass, Option<RuleId>) {
    (
        PayloadClass::from_payload(payload),
        rule_id.map(RuleId::new),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh tracker has no rules and no classes.
    #[test]
    fn empty_coverage_has_no_rules() {
        let cov = RuleCoverage::new();
        assert_eq!(cov.rule_count(), 0);
        assert!(cov.by_rule.is_empty());
        assert!(cov.by_class.is_empty());
    }

    /// One recorded hit shows up under both its rule id and its class.
    #[test]
    fn single_rule_id_recorded_correctly() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        assert_eq!(cov.rule_count(), 1);
        let rid = RuleId::new("942100");
        assert!(cov.by_rule.contains_key(&rid));
        let cls = PayloadClass::new("sql");
        assert!(cov.by_class.contains_key(&cls));
    }

    /// Payloads of different classes land in different `by_class` entries.
    #[test]
    fn mixed_classes_produce_distinct_cells() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("<script>alert(1)</script>", Some("941100"));
        cov.record("../../../etc/passwd", Some("930100"));
        assert_eq!(cov.rule_count(), 3);
        assert!(cov.by_class.contains_key(&PayloadClass::new("sql")));
        assert!(cov.by_class.contains_key(&PayloadClass::new("xss")));
        assert!(cov.by_class.contains_key(&PayloadClass::new("path")));
    }

    /// The descriptor is a pure function of its inputs.
    #[test]
    fn descriptor_is_stable_for_same_input() {
        let payload = "' UNION SELECT 1,2,3--";
        let rule = Some("942190");
        let d1 = map_elites_descriptor(payload, rule);
        let d2 = map_elites_descriptor(payload, rule);
        assert_eq!(d1, d2);
    }

    /// A missing rule id stays `None` in the descriptor.
    #[test]
    fn descriptor_without_rule_id_has_none_dimension() {
        let (cls, rid) = map_elites_descriptor("' OR 1=1--", None);
        assert_eq!(cls, PayloadClass::new("sql"));
        assert!(rid.is_none());
    }

    /// Serialising and deserialising keeps the same number of entries.
    #[test]
    fn json_roundtrip_preserves_coverage() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("<script>alert(1)</script>", Some("941100"));
        let json = cov.to_json().expect("serialization must not fail");
        let restored = RuleCoverage::from_json(&json).expect("deserialization must not fail");
        assert_eq!(restored.rule_count(), cov.rule_count());
        assert_eq!(restored.by_rule.len(), cov.by_rule.len());
        assert_eq!(restored.by_class.len(), cov.by_class.len());
    }

    /// Case and the `rule_` / `rule-` prefixes do not affect identity.
    #[test]
    fn rule_id_case_folding_normalises() {
        let r1 = RuleId::new("942100");
        let r2 = RuleId::new("RULE_942100");
        let r3 = RuleId::new("rule-942100");
        assert_eq!(r1, r2);
        assert_eq!(r1, r3);
    }

    /// Prefixes other than `rule_` / `rule-` are kept.
    #[test]
    fn rule_id_without_rule_prefix_preserved() {
        let r = RuleId::new("sql_942100");
        assert_eq!(r.as_str(), "sql_942100");
    }

    #[test]
    fn payload_class_detects_sql() {
        let cls = PayloadClass::from_payload("' UNION SELECT username, password FROM users--");
        assert_eq!(cls, PayloadClass::new("sql"));
    }

    #[test]
    fn payload_class_detects_xss() {
        let cls = PayloadClass::from_payload("<script>alert(document.cookie)</script>");
        assert_eq!(cls, PayloadClass::new("xss"));
    }

    /// Distinct payloads against one rule accumulate under that single rule.
    #[test]
    fn same_rule_accumulates_multiple_payloads() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("' OR 'x'='x'--", Some("942100"));
        cov.record("1 AND 1=1--", Some("942100"));
        let rid = RuleId::new("942100");
        assert_eq!(cov.by_rule[&rid].len(), 3);
        assert_eq!(cov.rule_count(), 1);
    }

    /// The report contains the rule's table row with its payload count.
    #[test]
    fn coverage_report_contains_triggered_rule() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        let report = cov.coverage_report();
        assert!(report.contains("942100"), "report must mention rule 942100");
        // Match the whole row: a bare "1" would already be satisfied by the
        // digits of the rule id itself.
        assert!(
            report.lines().any(|line| line == "942100\t1"),
            "report must show payload count"
        );
    }
}