use crate::response_fingerprint::{FingerprintDrift, ResponseFingerprint, compare, fingerprint};
use crate::waf_detect::DetectedWaf;
/// A named probe string paired with the attack category it belongs to.
#[derive(Clone, Debug)]
pub struct ProbePayload {
    /// Short identifier for the probe, e.g. `"xss_probe"`.
    pub label: &'static str,
    /// The raw probe text.
    pub payload: &'static str,
    /// Attack class this payload exercises.
    pub category: ProbeCategory,
}
/// Attack class exercised by a probe payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProbeCategory {
    /// Cross-site scripting.
    Xss,
    /// SQL injection.
    Sqli,
    /// Directory / path traversal.
    PathTraversal,
}
/// Outcome of one active probe: the baseline and probed response
/// fingerprints plus the drift computed between them.
#[derive(Clone, Debug)]
pub struct ProbeResult {
    /// Probe payload this result is associated with.
    pub payload: &'static ProbePayload,
    /// Fingerprint of the reference (baseline) response.
    pub baseline: ResponseFingerprint,
    /// Fingerprint of the response to the probe request.
    pub probed: ResponseFingerprint,
    /// Difference between `baseline` and `probed`.
    pub drift: FingerprintDrift,
}
/// Returns the built-in probe payloads: one XSS, one SQL-injection and one
/// path-traversal probe, each containing the marker text `waf_probe`.
#[must_use]
pub fn probe_set() -> &'static [ProbePayload] {
    static PROBES: [ProbePayload; 3] = [
        ProbePayload {
            label: "xss_probe",
            payload: "<script>alert('waf_probe')</script>",
            category: ProbeCategory::Xss,
        },
        ProbePayload {
            label: "sqli_probe",
            payload: "' OR '1'='1' -- waf_probe",
            category: ProbeCategory::Sqli,
        },
        ProbePayload {
            label: "path_traversal_probe",
            payload: "../../../etc/passwd?waf_probe=1",
            category: ProbeCategory::PathTraversal,
        },
    ];
    &PROBES
}
/// Fingerprints a baseline response and a probed response, then diffs them.
///
/// Both responses are fingerprinted with `fingerprint(status, headers, body)`
/// (baseline first), and the drift is `compare(&baseline, &probed)`. The
/// given `payload` is stored alongside both fingerprints and the drift in the
/// returned [`ProbeResult`].
#[must_use]
pub fn active_probe(
    payload: &'static ProbePayload,
    baseline_status: u16,
    baseline_headers: &[(String, String)],
    baseline_body: &[u8],
    probed_status: u16,
    probed_headers: &[(String, String)],
    probed_body: &[u8],
) -> ProbeResult {
    let (baseline, probed) = (
        fingerprint(baseline_status, baseline_headers, baseline_body),
        fingerprint(probed_status, probed_headers, probed_body),
    );
    // `compare` only borrows; both fingerprints are moved in afterwards.
    ProbeResult {
        drift: compare(&baseline, &probed),
        payload,
        baseline,
        probed,
    }
}
// --- Block-rate signal ---
/// Score added when every probe result was flagged as likely blocked.
const ALL_BLOCKED_WEIGHT: f64 = 0.5;
/// Inclusive minimum blocked fraction for the partial block-rate bonus.
const MAJORITY_BLOCKED_THRESHOLD: f64 = 0.5;
/// Score added when the blocked fraction reaches `MAJORITY_BLOCKED_THRESHOLD`
/// but not every probe was blocked.
const MAJORITY_BLOCKED_WEIGHT: f64 = 0.3;

// --- Average-drift signal ---
/// Inclusive minimum average drift score for the high-drift bonus.
const HIGH_DRIFT_THRESHOLD: f64 = 0.6;
/// Score added when the average drift reaches `HIGH_DRIFT_THRESHOLD`.
const HIGH_DRIFT_WEIGHT: f64 = 0.3;
/// Inclusive minimum average drift score for the moderate-drift bonus.
const MODERATE_DRIFT_THRESHOLD: f64 = 0.4;
/// Score added when the average drift reaches `MODERATE_DRIFT_THRESHOLD`
/// but not `HIGH_DRIFT_THRESHOLD`.
const MODERATE_DRIFT_WEIGHT: f64 = 0.2;

// --- Response-shape signals ---
/// Score added when all probed responses share one status code `>= 400`.
const UNIFORM_BLOCK_STATUS_WEIGHT: f64 = 0.1;
/// Score added when every probe reports a title-tag change.
const UNIVERSAL_TITLE_CHANGE_WEIGHT: f64 = 0.2;
/// Scores a batch of active-probe results and reports a generic WAF
/// detection when at least one blocking signal fires.
///
/// Four independent signals add to the score:
///
/// * **Block rate**: if every result has `drift.likely_blocked` set, adds
///   `ALL_BLOCKED_WEIGHT`. Otherwise, if the blocked fraction is at least
///   `MAJORITY_BLOCKED_THRESHOLD`, it adds `MAJORITY_BLOCKED_WEIGHT`.
/// * **Average drift**: if the mean of `drift.score` is at or above
///   `HIGH_DRIFT_THRESHOLD`, adds `HIGH_DRIFT_WEIGHT`. Otherwise, if it is at
///   or above `MODERATE_DRIFT_THRESHOLD`, it adds `MODERATE_DRIFT_WEIGHT`.
/// * **Uniform block status**: if every probed response has the same status
///   code and that code is `>= 400`, adds `UNIFORM_BLOCK_STATUS_WEIGHT`.
/// * **Title change**: if every result lists `"title_tag"` in
///   `drift.changed`, adds `UNIVERSAL_TITLE_CHANGE_WEIGHT`.
///
/// # Returns
///
/// When at least one signal fired, a single `"Active-Probe-Generic"`
/// [`DetectedWaf`]. Its `confidence` is the summed score clamped to at most
/// `1.0`, and its `indicators` hold one human-readable line per signal that
/// fired. Otherwise it returns an empty `Vec`, which includes the case of an
/// empty `results` slice.
#[must_use]
pub fn classify_drift(results: &[ProbeResult]) -> Vec<DetectedWaf> {
    // No probes means no evidence: every signal below would be zero.
    let Some(first) = results.first() else {
        return Vec::new();
    };
    let total = results.len();
    let mut score: f64 = 0.0;
    let mut indicators = Vec::new();

    // Compare counts directly instead of testing a float ratio against 1.0.
    let blocked_count = results.iter().filter(|r| r.drift.likely_blocked).count();
    if blocked_count == total {
        score += ALL_BLOCKED_WEIGHT;
        indicators.push("all probes blocked".into());
    } else if blocked_count as f64 / total as f64 >= MAJORITY_BLOCKED_THRESHOLD {
        score += MAJORITY_BLOCKED_WEIGHT;
        indicators.push("majority of probes blocked".into());
    }

    let avg_drift: f64 = results.iter().map(|r| r.drift.score).sum::<f64>() / total as f64;
    if avg_drift >= HIGH_DRIFT_THRESHOLD {
        score += HIGH_DRIFT_WEIGHT;
        indicators.push(format!("high avg drift {:.0}%", avg_drift * 100.0));
    } else if avg_drift >= MODERATE_DRIFT_THRESHOLD {
        score += MODERATE_DRIFT_WEIGHT;
        indicators.push(format!("moderate avg drift {:.0}%", avg_drift * 100.0));
    }

    // Every probe answered with one identical error status (no set allocation
    // needed: compare everything against the first result).
    let status = first.probed.status;
    if status >= 400 && results.iter().all(|r| r.probed.status == status) {
        score += UNIFORM_BLOCK_STATUS_WEIGHT;
        indicators.push(format!("uniform block status {status}"));
    }

    if results
        .iter()
        .all(|r| r.drift.changed.contains(&"title_tag"))
    {
        score += UNIVERSAL_TITLE_CHANGE_WEIGHT;
        indicators.push("title changed on every probe".into());
    }

    if score > 0.0 {
        vec![DetectedWaf {
            name: "Active-Probe-Generic".into(),
            confidence: score.min(1.0),
            indicators,
        }]
    } else {
        Vec::new()
    }
}