use wafrift_types::{EvasionConfig, WafClass};
use wafrift_wafmodel::ensemble_dilution::{SubScoreEstimator, dilute};
/// Default dilution threshold exported by this module.
///
/// NOTE(review): nothing in the visible part of this file reads this constant;
/// presumably callers pass it as the `threshold` argument of
/// [`compute_dilution_score`] / [`dilution_adjusted_fitness`] — confirm against callers.
pub const DEFAULT_DILUTION_THRESHOLD: f64 = 25.0;
/// First argument handed to `SubScoreEstimator::new` by [`default_estimator`].
///
/// NOTE(review): presumably the starting coefficient for every rule group — confirm
/// against the `SubScoreEstimator` documentation.
pub const DEFAULT_INITIAL_COEFF: f64 = 5.0;
/// Second argument handed to `SubScoreEstimator::new` by [`default_estimator`].
///
/// NOTE(review): presumably a smoothing / learning-rate factor — confirm against the
/// `SubScoreEstimator` documentation.
pub const DEFAULT_ALPHA: f64 = 0.1;
/// Reports whether the WAF identified by `waf_name` is classified as an ensemble WAF.
///
/// The name is mapped to a [`WafClass`] with [`WafClass::from_waf_name`], and the
/// answer of that class's `is_ensemble()` is returned unchanged.
#[must_use]
pub fn is_ensemble_waf(waf_name: &str) -> bool {
    let class = WafClass::from_waf_name(waf_name);
    class.is_ensemble()
}
/// Blends an oracle fitness value with the ensemble-dilution score of `payload`.
///
/// The blend weight `w` is `config.dilution_weight` clamped to `[0.0, 1.0]`, and the
/// result is `oracle_fitness * (1 - w) + dilution_score * w`, clamped to `[0.0, 1.0]`.
///
/// `oracle_fitness` is returned as-is (not clamped) when the blend cannot or should not
/// be applied:
///
/// * the clamped weight is `0.0` (this includes every negative weight),
/// * the weight is NaN,
/// * `waf_name` is not an ensemble WAF according to [`is_ensemble_waf`],
/// * the dilution score itself is NaN.
///
/// # Arguments
///
/// * `oracle_fitness` - fitness value to adjust.
/// * `payload` - payload forwarded to [`compute_dilution_score`].
/// * `estimator` - sub-score estimator forwarded to [`compute_dilution_score`].
/// * `threshold` - threshold forwarded to [`compute_dilution_score`].
/// * `config` - only `dilution_weight` is read.
/// * `waf_name` - WAF name used to decide whether dilution applies at all.
#[must_use]
pub fn dilution_adjusted_fitness(
    oracle_fitness: f64,
    payload: &str,
    estimator: &SubScoreEstimator,
    threshold: f64,
    config: &EvasionConfig,
    waf_name: &str,
) -> f64 {
    let w = config.dilution_weight.clamp(0.0, 1.0);
    // `f64::clamp` hands a NaN receiver straight back, and NaN != 0.0, so without the
    // explicit check a NaN weight would reach the blend and turn the fitness into NaN.
    if w.is_nan() || w == 0.0 || !is_ensemble_waf(waf_name) {
        return oracle_fitness;
    }
    let dilution_score = compute_dilution_score(payload, estimator, threshold);
    // A NaN score would survive both the blend and the final clamp; treat it like a
    // disabled dilution term instead of poisoning the caller's fitness value.
    if dilution_score.is_nan() {
        return oracle_fitness;
    }
    (oracle_fitness * (1.0 - w) + dilution_score * w).clamp(0.0, 1.0)
}
/// Scores how promising `payload` looks from a dilution point of view.
///
/// `dilute(payload, estimator, threshold)` is run once and its outcome is mapped to a
/// score as follows (first matching rule wins):
///
/// | outcome                                   | score                                   |
/// |-------------------------------------------|-----------------------------------------|
/// | `dilute` returns `None`                   | `0.5`                                   |
/// | result has `plausible_bypass` set         | `1.0`                                   |
/// | `threshold <= 0.0`                        | `0.0`                                   |
/// | otherwise                                 | `(threshold - predicted_total) / threshold`, clamped to `[0.0, 0.99]` |
///
/// The `0.99` cap keeps every non-bypass score strictly below the `1.0` reserved for a
/// plausible bypass.
///
/// NOTE(review): a NaN `threshold` or NaN `predicted_total` makes the last rule yield
/// NaN, because `f64::clamp` passes NaN through.
#[must_use]
pub fn compute_dilution_score(payload: &str, estimator: &SubScoreEstimator, threshold: f64) -> f64 {
    match dilute(payload, estimator, threshold) {
        None => 0.5,
        Some(outcome) if outcome.plausible_bypass => 1.0,
        Some(_) if threshold <= 0.0 => 0.0,
        Some(outcome) => {
            let headroom = threshold - outcome.strategy.predicted_total;
            (headroom / threshold).clamp(0.0, 0.99)
        }
    }
}
/// Builds a [`SubScoreEstimator`] from this module's defaults.
///
/// Equivalent to `SubScoreEstimator::new(DEFAULT_INITIAL_COEFF, DEFAULT_ALPHA)`.
#[must_use]
pub fn default_estimator() -> SubScoreEstimator {
    let initial_coeff = DEFAULT_INITIAL_COEFF;
    let alpha = DEFAULT_ALPHA;
    SubScoreEstimator::new(initial_coeff, alpha)
}
#[cfg(test)]
mod tests {
    use super::*;
    use wafrift_types::EvasionConfig;
    use wafrift_wafmodel::ensemble_dilution::RuleGroup;

    /// SQL-injection style payload shared by most tests below.
    const SQLI_PAYLOAD: &str = "' UNION SELECT--";

    /// Threshold passed to every scoring call in this module.
    const TEST_THRESHOLD: f64 = 40.0;

    /// Estimator with the module defaults.
    fn fresh_estimator() -> SubScoreEstimator {
        default_estimator()
    }

    /// Default config with only `dilution_weight` overridden.
    fn config_with_weight(dilution_weight: f64) -> EvasionConfig {
        EvasionConfig {
            dilution_weight,
            ..Default::default()
        }
    }

    /// Default estimator whose `SqlInjection` coefficient is overwritten with `coeff`.
    fn estimator_with_sqli_coeff(coeff: f64) -> SubScoreEstimator {
        let mut est = default_estimator();
        *est.coeffs.get_mut(&RuleGroup::SqlInjection).unwrap() = coeff;
        est
    }

    /// `true` when `a` and `b` differ by strictly less than `tolerance`.
    fn within(a: f64, b: f64, tolerance: f64) -> bool {
        (a - b).abs() < tolerance
    }

    // ---- is_ensemble_waf -------------------------------------------------

    #[test]
    fn ensemble_waf_cloudflare_detected() {
        assert!(is_ensemble_waf("Cloudflare WAF"));
        assert!(is_ensemble_waf("cloudflare"));
    }

    #[test]
    fn ensemble_waf_aws_detected() {
        assert!(is_ensemble_waf("AWS WAF"));
        assert!(is_ensemble_waf("amazon"));
    }

    #[test]
    fn ensemble_waf_owasp_crs_detected() {
        assert!(is_ensemble_waf("OWASP CRS"));
        assert!(is_ensemble_waf("crs"));
    }

    #[test]
    fn ensemble_waf_ml_backed_not_ensemble() {
        assert!(!is_ensemble_waf("AWS Bot Control"));
        assert!(!is_ensemble_waf("Cloudflare Bot Management"));
        assert!(!is_ensemble_waf("Akamai Bot Manager"));
    }

    #[test]
    fn ensemble_waf_unknown_not_ensemble() {
        assert!(!is_ensemble_waf("SomeRandomVendor"));
    }

    // ---- compute_dilution_score ------------------------------------------

    #[test]
    fn dilution_score_known_attack_payload() {
        let est = fresh_estimator();
        let score = compute_dilution_score(SQLI_PAYLOAD, &est, TEST_THRESHOLD);
        let unit_interval = 0.0..=1.0;
        assert!(
            unit_interval.contains(&score),
            "score must be in [0,1]: {score}"
        );
    }

    #[test]
    fn dilution_score_benign_payload_neutral() {
        let est = fresh_estimator();
        let score = compute_dilution_score("hello world", &est, TEST_THRESHOLD);
        let unit_interval = 0.0..=1.0;
        assert!(
            unit_interval.contains(&score),
            "score must be in [0,1]: {score}"
        );
    }

    // ---- dilution_adjusted_fitness ---------------------------------------

    #[test]
    fn dilution_weight_zero_returns_oracle_fitness() {
        let config = config_with_weight(0.0);
        let adjusted = dilution_adjusted_fitness(
            0.7,
            SQLI_PAYLOAD,
            &fresh_estimator(),
            TEST_THRESHOLD,
            &config,
            "Cloudflare WAF",
        );
        assert!(
            within(adjusted, 0.7, 1e-9),
            "must equal oracle_fitness exactly"
        );
    }

    #[test]
    fn dilution_weight_one_returns_pure_dilution() {
        let config = config_with_weight(1.0);
        let dilution_only = compute_dilution_score(SQLI_PAYLOAD, &fresh_estimator(), TEST_THRESHOLD);
        // Oracle fitness of 0.0 so that only the dilution term can contribute.
        let adjusted = dilution_adjusted_fitness(
            0.0,
            SQLI_PAYLOAD,
            &fresh_estimator(),
            TEST_THRESHOLD,
            &config,
            "Cloudflare WAF",
        );
        assert!(
            within(adjusted, dilution_only, 1e-9),
            "weight=1.0 must equal pure dilution score"
        );
    }

    #[test]
    fn dilution_gating_no_effect_on_non_ensemble_waf() {
        let config = config_with_weight(1.0);
        let adjusted = dilution_adjusted_fitness(
            0.55,
            SQLI_PAYLOAD,
            &fresh_estimator(),
            TEST_THRESHOLD,
            &config,
            "SomeRandomVendor",
        );
        assert!(
            within(adjusted, 0.55, 1e-9),
            "non-ensemble WAF must not be affected by dilution_weight"
        );
    }

    #[test]
    fn dilution_adjusted_clamps_to_unit_interval() {
        let config = config_with_weight(0.3);
        let adj = dilution_adjusted_fitness(
            1.0,
            "' UNION SELECT<script>alert(1)</script>",
            &fresh_estimator(),
            TEST_THRESHOLD,
            &config,
            "cloudflare",
        );
        let unit_interval = 0.0..=1.0;
        assert!(unit_interval.contains(&adj), "clamped to [0,1]: {adj}");
    }

    // ---- determinism and monotonicity ------------------------------------

    #[test]
    fn compute_dilution_score_deterministic_same_input() {
        let est = fresh_estimator();
        let first = compute_dilution_score(SQLI_PAYLOAD, &est, TEST_THRESHOLD);
        let second = compute_dilution_score(SQLI_PAYLOAD, &est, TEST_THRESHOLD);
        assert!(
            within(first, second, 1e-12),
            "dilution scoring must be deterministic"
        );
    }

    #[test]
    fn high_coeff_payload_scores_lower_than_low_coeff() {
        let est_high = estimator_with_sqli_coeff(50.0);
        let est_low = estimator_with_sqli_coeff(1.0);
        let score_high = compute_dilution_score(SQLI_PAYLOAD, &est_high, TEST_THRESHOLD);
        let score_low = compute_dilution_score(SQLI_PAYLOAD, &est_low, TEST_THRESHOLD);
        assert!(
            score_low >= score_high,
            "low-contribution group should score >= high-contribution group"
        );
    }
}