use std::collections::HashSet;
use wafrift_wafmodel::Alphabet;
use crate::rule_corpus::RuleBucket;
/// Number of distinguished bytes kept by [`infer_alphabet_default`].
pub const DEFAULT_DISTINGUISHED_COUNT: usize = 8;
/// Separator/whitespace bytes (space, `=`, `&`, CR, LF, tab) that
/// [`distinguished_bytes`] never selects as distinguished symbols.
const HTTP_FILLER_BYTES: &[u8] = &[b' ', b'=', b'&', b'\r', b'\n', b'\t'];
/// Per-byte statistics produced by [`score_bytes`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ByteScore {
/// The byte value being scored.
pub byte: u8,
/// Fraction of blocked payloads that contain `byte` at least once
/// (`0.0` when the bucket has no blocked payloads).
pub block_presence: f64,
/// Fraction of bypassed payloads that contain `byte` at least once
/// (`0.0` when the bucket has no bypassed payloads).
pub bypass_presence: f64,
/// `|block_presence - bypass_presence|`, in `[0.0, 1.0]`; larger values mean
/// the byte better separates blocked payloads from bypassing ones.
pub discrimination: f64,
}
/// Scores every possible byte value by how well its presence separates
/// blocked payloads from bypassing ones.
///
/// For each byte `b`:
/// - `block_presence` is the fraction of `bucket.blocked` payloads that
///   contain `b` at least once.
/// - `bypass_presence` is the same fraction over `bucket.bypassed`.
/// - An empty side contributes `0.0`.
/// - `discrimination` is the absolute difference of the two fractions.
///
/// Returns all 256 scores, sorted by descending `discrimination`. Ties are
/// broken by ascending byte value, so the order is fully deterministic.
#[must_use]
pub fn score_bytes(bucket: &RuleBucket) -> Vec<ByteScore> {
    // One pass per payload instead of one pass per (byte, payload) pair.
    let block = presence_fractions(bucket.blocked.iter().map(|r| r.payload.as_bytes()));
    let bypass = presence_fractions(bucket.bypassed.iter().map(|r| r.payload.as_bytes()));

    let mut out: Vec<ByteScore> = (0u8..=u8::MAX)
        .map(|byte| {
            let idx = usize::from(byte);
            let block_presence = block[idx];
            let bypass_presence = bypass[idx];
            ByteScore {
                byte,
                block_presence,
                bypass_presence,
                discrimination: (block_presence - bypass_presence).abs(),
            }
        })
        .collect();

    // `discrimination` comes from `abs()` of a difference of fractions in
    // [0, 1], so it is never NaN or -0.0. That means `total_cmp` matches
    // numeric order. Bytes are unique, so the comparator has no ties and an
    // unstable sort gives the same result as a stable one.
    out.sort_unstable_by(|a, b| {
        b.discrimination
            .total_cmp(&a.discrimination)
            .then_with(|| a.byte.cmp(&b.byte))
    });
    out
}

/// For each byte value, returns the fraction of `payloads` that contain it
/// at least once. All entries are `0.0` when `payloads` yields nothing.
fn presence_fractions<'a>(payloads: impl Iterator<Item = &'a [u8]>) -> [f64; 256] {
    let mut hits = [0usize; 256];
    let mut total = 0usize;
    for payload in payloads {
        total += 1;
        // Count each byte at most once per payload (presence, not frequency).
        let mut seen = [false; 256];
        for &b in payload {
            seen[usize::from(b)] = true;
        }
        for (hit, &present) in hits.iter_mut().zip(seen.iter()) {
            if present {
                *hit += 1;
            }
        }
    }

    let mut fractions = [0.0f64; 256];
    if total > 0 {
        for (fraction, &count) in fractions.iter_mut().zip(hits.iter()) {
            *fraction = count as f64 / total as f64;
        }
    }
    fractions
}
#[must_use]
pub fn distinguished_bytes(bucket: &RuleBucket, k: usize) -> Vec<u8> {
let scored = score_bytes(bucket);
let filler: HashSet<u8> = HTTP_FILLER_BYTES.iter().copied().collect();
scored
.into_iter()
.filter(|s| !filler.contains(&s.byte))
.filter(|s| s.discrimination > 0.0)
.take(k)
.map(|s| s.byte)
.collect()
}
/// Chooses the byte that stands in for every byte not in `distinguished`.
///
/// First choice: the first ASCII letter (`A`–`Z`, then `a`–`z`) that meets
/// both conditions:
/// - it is not in `distinguished`;
/// - it does not occur in any blocked or bypassed payload of `bucket`.
///
/// When every letter is taken, the function falls back to `b'Z'`, as before.
/// If `Z` is itself distinguished, the fallback instead takes the first letter
/// not in `distinguished`, and failing that, the first byte value not in
/// `distinguished`. This keeps the catch-all from colliding with a real
/// distinguished symbol.
///
/// `b'Z'` is returned unconditionally only when `distinguished` covers all
/// 256 byte values.
#[must_use]
pub fn pick_catch_all(bucket: &RuleBucket, distinguished: &[u8]) -> u8 {
    let dist: HashSet<u8> = distinguished.iter().copied().collect();

    // Mark every byte that occurs anywhere in the corpus in a single pass,
    // rather than rescanning all payloads once per candidate letter.
    let mut in_corpus = [false; 256];
    let payloads = bucket
        .blocked
        .iter()
        .map(|r| r.payload.as_bytes())
        .chain(bucket.bypassed.iter().map(|r| r.payload.as_bytes()));
    for payload in payloads {
        for &b in payload {
            in_corpus[usize::from(b)] = true;
        }
    }

    let letters = || (b'A'..=b'Z').chain(b'a'..=b'z');
    letters()
        .find(|&c| !dist.contains(&c) && !in_corpus[usize::from(c)])
        .or_else(|| {
            // Every letter is used: keep the historical `Z` fallback, but
            // never return a byte that is already a distinguished symbol.
            std::iter::once(b'Z')
                .chain(letters())
                .chain(0u8..=u8::MAX)
                .find(|c| !dist.contains(c))
        })
        .unwrap_or(b'Z')
}
/// Infers an [`Alphabet`] for `bucket` that keeps at most `k` distinguished
/// bytes.
///
/// Returns `None` in either of these cases:
/// - the bucket has no blocked and no bypassed payloads;
/// - no byte qualifies as distinguished, e.g. when `k == 0` or every byte has
///   zero discrimination.
///
/// Otherwise the alphabet is built from the bytes chosen by
/// [`distinguished_bytes`] plus the catch-all chosen by [`pick_catch_all`].
#[must_use]
pub fn infer_alphabet(bucket: &RuleBucket, k: usize) -> Option<Alphabet> {
    let has_samples = !bucket.blocked.is_empty() || !bucket.bypassed.is_empty();
    if !has_samples {
        return None;
    }
    let symbols = distinguished_bytes(bucket, k);
    (!symbols.is_empty()).then(|| {
        let catch_all = pick_catch_all(bucket, &symbols);
        Alphabet::new(symbols, catch_all)
    })
}
/// Shorthand for [`infer_alphabet`] with `k` set to
/// [`DEFAULT_DISTINGUISHED_COUNT`].
#[must_use]
pub fn infer_alphabet_default(bucket: &RuleBucket) -> Option<Alphabet> {
    let k = DEFAULT_DISTINGUISHED_COUNT;
    infer_alphabet(bucket, k)
}
// Unit tests for byte scoring, distinguished-byte selection, catch-all
// picking and end-to-end alphabet inference.
#[cfg(test)]
mod tests {
use super::*;
use crate::coverage_feedback::PayloadClass;
use crate::rule_corpus::{RecordedAttempt, RecordedBypass, SubmissionStatus};
// Fixture helpers: every recorded entry uses the "sql" class and zeroed
// metadata; only the payload text matters to the functions under test.
fn cls() -> PayloadClass {
PayloadClass::new("sql")
}
fn attempt(payload: &str) -> RecordedAttempt {
RecordedAttempt {
payload: payload.to_string(),
payload_class: cls(),
encoding_chain: vec![],
response_hash: 0,
observed_at_secs: 0,
}
}
fn bypass(payload: &str) -> RecordedBypass {
RecordedBypass {
payload: payload.to_string(),
payload_class: cls(),
encoding_chain: vec![],
response_hash: 0,
observed_at_secs: 0,
submission: SubmissionStatus::default(),
delivery: String::new(),
}
}
// Builds a bucket whose blocked/bypassed lists hold the given payloads;
// all other bucket fields use their defaults.
fn bucket_with(blocked: Vec<&str>, bypassed: Vec<&str>) -> RuleBucket {
RuleBucket {
blocked: blocked.into_iter().map(attempt).collect(),
bypassed: bypassed.into_iter().map(bypass).collect(),
..Default::default()
}
}
// A bucket with no payloads at all yields no alphabet.
#[test]
fn empty_bucket_returns_none() {
let b = RuleBucket::default();
assert!(infer_alphabet_default(&b).is_none());
}
// score_bytes always scores every byte value, not just those present.
#[test]
fn score_bytes_returns_256_entries() {
let b = bucket_with(vec!["abc"], vec!["xyz"]);
let scored = score_bytes(&b);
assert_eq!(scored.len(), 256);
}
// '<' is in every blocked payload and no bypass, so it scores 1.0. It ties
// with '>' and wins the tie because it is the smaller byte value.
#[test]
fn highly_discriminative_byte_ranks_first() {
let b = bucket_with(
vec!["<script>", "<img>", "<svg>"],
vec!["benign", "safe", "ok"],
);
let scored = score_bytes(&b);
assert_eq!(scored[0].byte, b'<');
assert!((scored[0].discrimination - 1.0).abs() < 1e-9);
assert!((scored[0].block_presence - 1.0).abs() < 1e-9);
assert!(scored[0].bypass_presence.abs() < 1e-9);
}
// ' ' is perfectly discriminative here but is a filler byte, so it must
// be skipped.
#[test]
fn distinguished_excludes_filler_bytes() {
let b = bucket_with(vec!["a b", "c d", "e f"], vec!["abcdef", "qwert", "zxcvb"]);
let dist = distinguished_bytes(&b, 5);
assert!(!dist.contains(&b' '), "filler byte ' ' must be excluded");
}
#[test]
fn distinguished_count_capped_at_k() {
let b = bucket_with(vec!["<svg onload=alert('xss')>"], vec!["benign"]);
let dist = distinguished_bytes(&b, 3);
assert!(
dist.len() <= 3,
"should return ≤ k distinguished bytes, got {}",
dist.len()
);
}
// Same input must produce the same selection (relies on the byte tie-break).
#[test]
fn distinguished_is_deterministic() {
let b = bucket_with(
vec!["' OR 1=1", "UNION SELECT", "AND 1=1"],
vec!["normal", "input", "okay"],
);
let a = distinguished_bytes(&b, 5);
let b2 = distinguished_bytes(&b, 5);
assert_eq!(a, b2);
}
// The catch-all must avoid both distinguished bytes and bytes seen in the corpus.
#[test]
fn pick_catch_all_returns_unused_letter() {
let b = bucket_with(vec!["abc"], vec!["def"]);
let dist = vec![b'a', b'b', b'c'];
let ca = pick_catch_all(&b, &dist);
assert!(!dist.contains(&ca));
assert_ne!(ca, b'd');
assert_ne!(ca, b'e');
assert_ne!(ca, b'f');
}
// When the corpus uses all 52 ASCII letters, the fallback is 'Z'.
#[test]
fn pick_catch_all_falls_back_to_z_when_letters_exhausted() {
let all_letters: String = (b'A'..=b'Z')
.chain(b'a'..=b'z')
.map(|b| b as char)
.collect();
let b = bucket_with(vec![all_letters.as_str()], vec![]);
let dist = vec![];
let ca = pick_catch_all(&b, &dist);
assert_eq!(ca, b'Z');
}
// End-to-end: an SQLi-flavoured corpus should key on at least one SQLi byte.
#[test]
fn infer_alphabet_sql_bucket_includes_quote_or_dash() {
let b = bucket_with(
vec![
"' OR 1=1--",
"1' AND 1=1#",
"admin'--",
"UNION SELECT",
"SLEEP(5)",
],
vec!["normal_input", "search_term", "user_query"],
);
let alpha = infer_alphabet_default(&b).expect("alphabet");
let symbols = alpha.raw_symbols();
let has_sqli_byte =
symbols.contains(&b'\'') || symbols.contains(&b'-') || symbols.contains(&b'(');
assert!(
has_sqli_byte,
"inferred alphabet should include an SQLi-keying byte: {symbols:?}"
);
}
// End-to-end: an XSS-flavoured corpus should key on an angle bracket.
#[test]
fn infer_alphabet_xss_bucket_includes_angle_bracket() {
let b = bucket_with(
vec![
"<script>alert(1)</script>",
"<img src=x onerror=alert(1)>",
"<svg/onload=alert(1)>",
"<iframe src=javascript:alert(1)>",
],
vec!["plain text query", "ordinary input"],
);
let alpha = infer_alphabet_default(&b).expect("alphabet");
let symbols = alpha.raw_symbols();
assert!(
symbols.contains(&b'<') || symbols.contains(&b'>'),
"XSS alphabet must include < or >: {symbols:?}"
);
}
// Identical blocked and bypassed payloads give every byte zero
// discrimination, so no alphabet can be inferred.
#[test]
fn infer_alphabet_zero_discrimination_bucket_returns_none() {
let b = bucket_with(vec!["abc", "abc", "abc"], vec!["abc", "abc", "abc"]);
let result = infer_alphabet_default(&b);
assert!(
result.is_none(),
"zero-discrimination corpus should yield None"
);
}
// One-sided corpora still discriminate: an empty side scores 0.0 presence.
#[test]
fn infer_alphabet_only_blocks_no_bypasses_still_works() {
let b = bucket_with(
vec!["' OR 1=1", "UNION SELECT", "SLEEP(5)", "1' AND 1=1"],
vec![],
);
let alpha = infer_alphabet_default(&b);
assert!(alpha.is_some(), "blocks-only bucket must yield alphabet");
let alpha = alpha.unwrap();
assert!(alpha.len() >= 2);
}
#[test]
fn infer_alphabet_only_bypasses_no_blocks_still_works() {
let b = bucket_with(vec![], vec!["%27 OR 1%3d1", "<script>", "%3cscript%3e"]);
let alpha = infer_alphabet_default(&b);
assert!(alpha.is_some(), "bypasses-only bucket must yield alphabet");
}
// score_bytes output must be sorted by non-increasing discrimination.
#[test]
fn discrimination_ordering_is_monotone() {
let b = bucket_with(vec!["abc", "abd", "abe"], vec!["xyz"]);
let scored = score_bytes(&b);
for i in 1..scored.len() {
assert!(
scored[i - 1].discrimination >= scored[i].discrimination,
"scored list must be sorted descending by discrimination"
);
}
}
// Coarse timing check on a 1100-payload bucket.
// NOTE(review): wall-clock threshold; may be sensitive to debug builds or
// loaded CI machines.
#[test]
fn many_payloads_perf_smoke() {
let payloads: Vec<String> = (0..1000)
.map(|i| format!("' OR {i}=1-- comment{i}"))
.collect();
let bypasses_v: Vec<String> = (0..100).map(|i| format!("normal_input_{i}")).collect();
let bucket = RuleBucket {
blocked: payloads.iter().map(|s| attempt(s)).collect(),
bypassed: bypasses_v.iter().map(|s| bypass(s)).collect(),
..Default::default()
};
let start = std::time::Instant::now();
let alpha = infer_alphabet_default(&bucket).expect("alphabet");
let elapsed = start.elapsed();
assert!(alpha.len() > 1);
assert!(
elapsed.as_millis() < 500,
"1000-payload bucket too slow: {elapsed:?}"
);
}
// k == 0 selects nothing, so infer_alphabet has no symbols and returns None.
#[test]
fn k_zero_returns_empty_distinguished_and_no_alphabet() {
let b = bucket_with(vec!["' OR 1=1"], vec!["safe"]);
assert!(distinguished_bytes(&b, 0).is_empty());
assert!(infer_alphabet(&b, 0).is_none());
}
// Rebuilding an Alphabet from its raw symbols preserves those symbols.
#[test]
fn alphabet_round_trip_through_raw_symbols() {
let b = bucket_with(
vec!["' OR 1=1", "UNION SELECT", "<script>"],
vec!["normal_input"],
);
let alpha = infer_alphabet_default(&b).expect("alphabet");
let raw = alpha.raw_symbols().to_vec();
let restored = wafrift_wafmodel::Alphabet::from_raw_symbols(raw.clone());
assert_eq!(restored.raw_symbols(), raw.as_slice());
}
// The k cap holds even when many distinct bytes are perfectly discriminative.
#[test]
fn distinguished_payload_count_caps_at_k_even_with_many_unique_bytes() {
let b = bucket_with(vec!["<>'\"(){}[]!@#$%^&*+-/=?:;,."], vec!["plain"]);
assert!(distinguished_bytes(&b, 5).len() <= 5);
assert!(distinguished_bytes(&b, 10).len() <= 10);
assert!(distinguished_bytes(&b, 20).len() <= 20);
}
}