use std::collections::{HashMap, HashSet};
/// Upper bound on the number of features `extract_features` returns for a single payload.
const MAX_FEATURES_PER_PAYLOAD: usize = 100;
/// Amount added to (on a block) or subtracted from (on a pass) a feature's weight per observation.
const WEIGHT_STEP: f64 = 1.0;
/// Extracts unique byte n-grams of sizes 2, 3 and 4 from `data` as `ng:`-prefixed features.
///
/// All 2-byte windows are visited first, then 3-byte, then 4-byte, each left to right.
/// A window is emitted only the first time its exact bytes are seen, and at most `limit`
/// n-grams are returned. Bytes are rendered with `String::from_utf8_lossy`, so a window
/// that splits a multi-byte UTF-8 character contains `U+FFFD`; distinct windows can
/// therefore render to the same string.
fn byte_ngrams(data: &[u8], limit: usize) -> Vec<String> {
    // Key on the window slice itself: its length is part of the key, so n-grams of
    // different sizes never collide. (A zero-padded `[u8; 4]` key would make e.g. `AB`
    // and `AB\0` indistinguishable and silently drop the longer one.)
    let mut seen: HashSet<&[u8]> = HashSet::new();
    let mut out = Vec::new();
    for size in [2usize, 3, 4] {
        // `windows` yields nothing when `data` is shorter than `size`.
        for window in data.windows(size) {
            if out.len() >= limit {
                return out;
            }
            if seen.insert(window) {
                out.push(format!("ng:{}", String::from_utf8_lossy(window)));
            }
        }
    }
    out
}
/// Lazily splits `payload` into `tok:`-prefixed token features.
///
/// Despite the name, pieces are separated not only by ASCII whitespace but also by every
/// character in the local `DELIMITERS` set (quotes, brackets, operators and similar
/// punctuation). Empty pieces produced by adjacent, leading or trailing separators are
/// skipped. Non-ASCII whitespace such as U+00A0 is *not* a separator.
fn whitespace_tokens(payload: &str) -> impl Iterator<Item = String> + '_ {
    // Non-whitespace characters that also end a token.
    const DELIMITERS: &str = "'\"`;,()[]{}<>=!&|+-*/\\?@#$%^~";

    let is_separator = |ch: char| ch.is_ascii_whitespace() || DELIMITERS.contains(ch);

    payload
        .split(is_separator)
        .filter(|piece| !piece.is_empty())
        .map(|piece| format!("tok:{piece}"))
}
fn extract_features(payload: &str) -> Vec<String> {
let mut seen: HashMap<String, ()> = HashMap::new();
let mut features = Vec::with_capacity(MAX_FEATURES_PER_PAYLOAD);
for tok in whitespace_tokens(payload) {
if features.len() >= MAX_FEATURES_PER_PAYLOAD {
return features;
}
if seen.insert(tok.clone(), ()).is_none() {
features.push(tok);
}
}
let remaining = MAX_FEATURES_PER_PAYLOAD.saturating_sub(features.len());
if remaining > 0 {
for ngram in byte_ngrams(payload.as_bytes(), remaining) {
if features.len() >= MAX_FEATURES_PER_PAYLOAD {
break;
}
if seen.insert(ngram.clone(), ()).is_none() {
features.push(ngram);
}
}
}
features
}
/// Scores payloads by summing learned per-feature weights.
///
/// Features of payloads passed to `observe_block` gain weight and features of payloads
/// passed to `observe_pass` lose weight, so a lower score means a payload resembles
/// previously passed payloads more than blocked ones.
#[derive(Debug, Clone)]
pub struct WafBoosterScorer {
    /// Learned weight per feature string (as produced by `extract_features`).
    feature_weights: HashMap<String, f64>,
    /// Multiplier applied to every stored weight before each observation; `1.0` disables
    /// decay. `new` checks and clamps this value, but assigning the public field directly
    /// bypasses that validation.
    pub decay: f64,
}
impl WafBoosterScorer {
    /// Creates an empty scorer whose weights are multiplied by `decay` before every observation.
    ///
    /// `decay` is expected to lie in `(0.0, 1.0]`; `1.0` disables decay. Out-of-range values
    /// trip a `debug_assert!` in debug builds and are clamped into `[f64::MIN_POSITIVE, 1.0]`
    /// in release builds. A NaN `decay` falls back to `1.0` (no decay) in release builds,
    /// because `f64::clamp` would pass NaN through and turn every weight into NaN.
    ///
    /// # Panics
    /// In debug builds only, when `decay` is not in `(0.0, 1.0]` (including NaN).
    #[must_use]
    pub fn new(decay: f64) -> Self {
        debug_assert!(
            decay > 0.0 && decay <= 1.0,
            "decay must be in (0.0, 1.0]; got {decay}"
        );
        // `f64::clamp` returns NaN unchanged, so NaN needs explicit handling.
        let decay = if decay.is_nan() {
            1.0
        } else {
            decay.clamp(f64::MIN_POSITIVE, 1.0)
        };
        Self {
            feature_weights: HashMap::new(),
            decay,
        }
    }

    /// Creates an empty scorer that never decays old weights (`decay == 1.0`).
    #[must_use]
    pub fn no_decay() -> Self {
        Self::new(1.0)
    }

    /// Multiplies every stored weight by `self.decay`.
    ///
    /// Skipped when `decay` is within `f64::EPSILON` of `1.0`, since the multiplication
    /// would be a no-op.
    fn apply_decay(&mut self) {
        if (self.decay - 1.0).abs() < f64::EPSILON {
            return;
        }
        for w in self.feature_weights.values_mut() {
            *w *= self.decay;
        }
    }

    /// Number of distinct features that currently have a stored weight.
    #[must_use]
    pub fn feature_count(&self) -> usize {
        self.feature_weights.len()
    }

    /// Stored weight of `feature`, or `0.0` if it has never been observed.
    #[must_use]
    pub fn weight_of(&self, feature: &str) -> f64 {
        self.feature_weights.get(feature).copied().unwrap_or(0.0)
    }

    /// Shared update step: decays existing weights, then adds `delta` to the weight of
    /// every feature extracted from `payload`. Unseen features start at `0.0`.
    fn observe(&mut self, payload: &str, delta: f64) {
        self.apply_decay();
        for feat in extract_features(payload) {
            *self.feature_weights.entry(feat).or_insert(0.0) += delta;
        }
    }

    /// Records that `payload` was blocked: each of its features gains `WEIGHT_STEP`.
    ///
    /// `_rule_id` is accepted for API compatibility but is currently ignored.
    pub fn observe_block(&mut self, payload: &str, _rule_id: Option<&str>) {
        self.observe(payload, WEIGHT_STEP);
    }

    /// Records that `payload` passed: each of its features loses `WEIGHT_STEP`.
    pub fn observe_pass(&mut self, payload: &str) {
        self.observe(payload, -WEIGHT_STEP);
    }

    /// Sums the stored weights of `payload`'s features; unseen features contribute nothing.
    ///
    /// Returns `0.0` immediately when nothing has been learned yet.
    #[must_use]
    pub fn score_candidate(&self, payload: &str) -> f64 {
        if self.feature_weights.is_empty() {
            return 0.0;
        }
        // Explicit `+0.0` seed so an all-unseen payload scores exactly `+0.0`, independent
        // of the neutral element used by the standard `Sum` impl for floats.
        extract_features(payload)
            .iter()
            .filter_map(|feat| self.feature_weights.get(feat))
            .fold(0.0, |acc, w| acc + w)
    }

    /// Scores every candidate and returns `(candidate, score)` pairs in ascending score order.
    ///
    /// The sort is stable, so tied candidates keep their input order. NaN scores are placed
    /// after all numeric scores. NaN can occur if `decay` was set to a non-finite value
    /// through the public field. Non-NaN scores are compared with `f64::total_cmp`, which
    /// keeps the comparator a total order so the sort can never panic.
    #[must_use]
    pub fn rank_candidates(&self, candidates: &[String]) -> Vec<(String, f64)> {
        let mut scored: Vec<(String, f64)> = candidates
            .iter()
            .map(|c| (c.clone(), self.score_candidate(c)))
            .collect();
        scored.sort_by(|a, b| {
            a.1.is_nan()
                .cmp(&b.1.is_nan())
                .then_with(|| a.1.total_cmp(&b.1))
        });
        scored
    }
}
impl Default for WafBoosterScorer {
    /// An empty scorer with `decay == 1.0`, i.e. the same as `WafBoosterScorer::no_decay()`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn feature_extraction_non_empty_for_attack_payload() {
        let feats = extract_features("' UNION SELECT 1,2--");
        assert!(
            !feats.is_empty(),
            "must extract features from a non-empty payload"
        );
    }

    #[test]
    fn feature_extraction_empty_payload_returns_empty() {
        let feats = extract_features("");
        assert!(feats.is_empty());
    }

    #[test]
    fn feature_count_capped_at_max() {
        let long = "A".repeat(500);
        let feats = extract_features(&long);
        assert!(
            feats.len() <= MAX_FEATURES_PER_PAYLOAD,
            "expected ≤ {MAX_FEATURES_PER_PAYLOAD}, got {}",
            feats.len()
        );
    }

    #[test]
    fn ngram_boundary_two_byte_minimum() {
        let ngrams = byte_ngrams(b"A", 100);
        assert!(ngrams.is_empty());
    }

    #[test]
    fn ngram_two_byte_ok() {
        let ngrams = byte_ngrams(b"AB", 100);
        assert!(!ngrams.is_empty());
        assert!(ngrams.iter().any(|n| n.contains("AB")));
    }

    #[test]
    fn feature_extraction_deduplicates() {
        let feats = extract_features("select select select");
        let count = feats.iter().filter(|f| f.contains("select")).count();
        assert_eq!(count, 1, "duplicate token extracted");
    }

    #[test]
    fn observe_block_raises_score() {
        let mut scorer = WafBoosterScorer::no_decay();
        let payload = "' UNION SELECT--";
        let before = scorer.score_candidate(payload);
        scorer.observe_block(payload, None);
        let after = scorer.score_candidate(payload);
        assert!(
            after > before,
            "block observation must raise score: {before} → {after}"
        );
    }

    #[test]
    fn observe_pass_lowers_score() {
        let mut scorer = WafBoosterScorer::no_decay();
        let payload = "hello world";
        scorer.observe_block(payload, None);
        let before = scorer.score_candidate(payload);
        scorer.observe_pass(payload);
        let after = scorer.score_candidate(payload);
        assert!(
            after < before,
            "pass observation must lower score: {before} → {after}"
        );
    }

    #[test]
    fn empty_scorer_returns_zero() {
        let scorer = WafBoosterScorer::no_decay();
        assert_eq!(scorer.score_candidate("anything"), 0.0);
        assert_eq!(scorer.score_candidate(""), 0.0);
    }

    #[test]
    fn score_zero_for_unseen_features() {
        let mut scorer = WafBoosterScorer::no_decay();
        scorer.observe_block("totally different payload xyz", None);
        // No token or n-gram of an all-digit payload occurs in the all-letter payload above.
        let score = scorer.score_candidate("12345678");
        assert_eq!(score, 0.0, "unseen features must not contribute to the score");
    }

    #[test]
    fn decay_shrinks_old_weights() {
        let mut scorer = WafBoosterScorer::new(0.5);
        let payload = "' UNION SELECT--";
        scorer.observe_block(payload, None);
        let after_block = scorer.score_candidate(payload);
        scorer.observe_block("unrelated thing", None);
        let after_decay = scorer.score_candidate(payload);
        assert!(
            after_decay < after_block,
            "decay must shrink old weights: {after_block} → {after_decay}"
        );
    }

    #[test]
    fn decay_one_means_no_shrinkage() {
        let mut scorer = WafBoosterScorer::new(1.0);
        let payload = "' UNION SELECT--";
        scorer.observe_block(payload, None);
        let s1 = scorer.score_candidate(payload);
        scorer.observe_block("completely unrelated qwerty", None);
        let s2 = scorer.score_candidate(payload);
        assert!(
            s2 >= s1,
            "decay=1.0 must not shrink existing weights: {s1} → {s2}"
        );
    }

    #[test]
    fn rank_candidates_lowest_score_first() {
        let mut scorer = WafBoosterScorer::no_decay();
        let blocked = "' UNION SELECT 1,2--".to_string();
        let safe = "hello world".to_string();
        scorer.observe_block(&blocked, None);
        scorer.observe_pass(&safe);
        let ranked = scorer.rank_candidates(&[blocked.clone(), safe.clone()]);
        assert_eq!(
            ranked.len(),
            2,
            "rank_candidates must return same count as input"
        );
        assert!(
            ranked[0].1 <= ranked[1].1,
            "candidates must be sorted ascending by score: {:?}",
            ranked
        );
        assert_eq!(
            ranked[0].0, safe,
            "safe payload must rank first (lower score)"
        );
    }

    #[test]
    fn rank_empty_input_returns_empty() {
        let scorer = WafBoosterScorer::no_decay();
        let ranked = scorer.rank_candidates(&[]);
        assert!(ranked.is_empty());
    }

    #[test]
    fn rank_single_candidate() {
        let mut scorer = WafBoosterScorer::no_decay();
        scorer.observe_block("test", None);
        let ranked = scorer.rank_candidates(&["test".to_string()]);
        assert_eq!(ranked.len(), 1);
    }

    #[test]
    fn rank_stable_for_ties() {
        let scorer = WafBoosterScorer::no_decay();
        let candidates: Vec<String> = (0..5).map(|i| format!("candidate_{i}")).collect();
        let ranked = scorer.rank_candidates(&candidates);
        let ranked_names: Vec<_> = ranked.iter().map(|(n, _)| n.clone()).collect();
        assert_eq!(
            ranked_names, candidates,
            "stable sort must preserve order on ties"
        );
    }

    #[test]
    fn very_long_payload_does_not_exceed_feature_cap() {
        let long_payload = "' OR 1=1-- ".repeat(200);
        let mut scorer = WafBoosterScorer::no_decay();
        scorer.observe_block(&long_payload, None);
        let score = scorer.score_candidate(&long_payload);
        assert!(score.is_finite(), "score must be finite for long payloads");
    }

    #[test]
    fn multiple_rule_ids_tracked_separately_via_observe_block() {
        let mut scorer = WafBoosterScorer::no_decay();
        let payload = "' UNION SELECT--";
        scorer.observe_block(payload, Some("942100"));
        scorer.observe_block(payload, Some("941100"));
        scorer.observe_block(payload, None);
        let score = scorer.score_candidate(payload);
        assert!(
            score > 0.0,
            "after 3 block observations score must be positive"
        );
    }

    #[test]
    fn weight_of_unseen_feature_is_zero() {
        let scorer = WafBoosterScorer::no_decay();
        assert_eq!(scorer.weight_of("tok:never_seen"), 0.0);
    }

    #[test]
    fn feature_count_grows_monotonically() {
        let mut scorer = WafBoosterScorer::no_decay();
        let before = scorer.feature_count();
        scorer.observe_block("' UNION SELECT 1,2--", None);
        let after = scorer.feature_count();
        assert!(after > before, "feature count must grow after observation");
    }

    #[test]
    fn score_is_additive_across_independent_features() {
        let mut scorer = WafBoosterScorer::no_decay();
        scorer.observe_block("alpha beta", None);
        scorer.observe_block("gamma delta", None);
        let s_alpha = scorer.score_candidate("alpha beta");
        let s_gamma = scorer.score_candidate("gamma delta");
        let s_both = scorer.score_candidate("alpha beta gamma delta");
        assert!(
            s_both >= s_alpha.max(s_gamma),
            "combined score {s_both} must be >= max({s_alpha}, {s_gamma})"
        );
    }
}