use once_cell::sync::Lazy;
use regex::{Regex, RegexSet};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::RwLock;
/// Process-wide rule engine, built lazily on first access.
///
/// If [`RuleEngine::load_embedded`] fails, a warning is logged and an empty
/// engine (`RuleEngine::default()`) is installed instead, so detection simply
/// reports nothing rather than panicking.
static RULE_DB: Lazy<RwLock<RuleEngine>> = Lazy::new(|| {
    let engine = match RuleEngine::load_embedded() {
        Ok(loaded) => loaded,
        Err(e) => {
            tracing::warn!("Failed to load embedded WAF rules: {e}");
            RuleEngine::default()
        }
    };
    RwLock::new(engine)
});
/// Compiled WAF fingerprint database plus a precomputed body-matching index.
///
/// `rules` maps a WAF name to its compiled rule and `names` records those
/// names in the order they were first loaded. The private `body_*` fields are
/// a parallel index over every signature that carries a body regex; they are
/// populated by `compile_body_regex_set`.
#[derive(Clone, Debug, Default)]
pub struct RuleEngine {
    /// Compiled rules keyed by WAF name.
    pub rules: HashMap<String, CompiledWafRule>,
    /// WAF names in first-seen load order; detection iterates in this order.
    pub names: Vec<String>,
    /// One `RegexSet` over all indexed body regexes (`None` when nothing is indexed).
    body_regex_set: Option<RegexSet>,
    /// Entry `i` says which rule/signature set pattern `i` came from.
    body_pattern_map: Vec<BodyPatternRef>,
    /// Entry `i` is set pattern `i` compiled on its own, used to locate the match.
    body_regexes: Vec<Regex>,
}
/// Back-reference from one pattern of the body `RegexSet` to the signature it
/// was taken from.
#[derive(Clone, Debug)]
struct BodyPatternRef {
    /// Name of the WAF rule that owns the signature.
    waf_name: String,
    /// Position of the signature inside that rule's `signatures`.
    // `allow(dead_code)` keeps the lint quiet if no code path reads this index.
    #[allow(dead_code)]
    sig_index: usize,
    /// Weight credited to the WAF when this body pattern matches.
    weight: f64,
}
/// A WAF fingerprint after all of its regexes have been compiled.
#[derive(Clone, Debug)]
pub struct CompiledWafRule {
    /// Display name; also the key under which the rule is stored in `RuleEngine::rules`.
    pub name: String,
    /// Vendor string copied from the rule file.
    // `allow(dead_code)`: carried along for reporting, detection never reads it.
    #[allow(dead_code)]
    pub vendor: String,
    /// Minimum accumulated signature weight for this WAF to be reported.
    pub confidence_threshold: f64,
    /// Names of evasion techniques suggested for this WAF.
    pub evasions: Vec<String>,
    /// `source` string copied from the rule file (presumably provenance of the
    /// fingerprint — NOTE(review): confirm against the rule files).
    // `allow(dead_code)`: carried along for reporting, detection never reads it.
    #[allow(dead_code)]
    pub source: String,
    /// Fingerprint checks; each adds its `weight` to the score when it matches.
    pub signatures: Vec<CompiledSignature>,
}
/// One compiled fingerprint check. Matchers left as `None` are not evaluated.
#[derive(Clone, Debug)]
pub struct CompiledSignature {
    /// Header the `header_regex` applies to, lower-cased at compile time. When
    /// absent (or empty) the regex is tried against every header value.
    pub header_name: Option<String>,
    /// Pattern searched for in header values.
    pub header_regex: Option<Regex>,
    /// Pattern tested against `Set-Cookie` header values.
    pub cookie_regex: Option<Regex>,
    /// Pattern searched for in the response body.
    pub body_regex: Option<Regex>,
    /// Exact HTTP status code that counts as a match.
    pub status_code: Option<u16>,
    /// Score contribution when the signature matches.
    pub weight: f64,
}
/// Top-level shape of a rule TOML document: zero or more `[[waf]]` tables.
#[derive(Clone, Debug, Deserialize)]
struct RawRuleDb {
    /// Missing key deserializes to an empty list.
    #[serde(default)]
    waf: Vec<RawWafRule>,
}
/// One `[[waf]]` table exactly as written in a rule file, before regex compilation.
#[derive(Clone, Debug, Deserialize)]
struct RawWafRule {
    name: String,
    vendor: String,
    /// Falls back to `default_threshold()` (0.3) when omitted.
    #[serde(default = "default_threshold")]
    confidence_threshold: f64,
    #[serde(default)]
    evasions: Vec<String>,
    #[serde(default)]
    source: String,
    /// `[[waf.signature]]` entries; the singular name matches the TOML key.
    #[serde(default)]
    signature: Vec<RawSignature>,
}
/// One `[[waf.signature]]` entry with its patterns still as source strings.
#[derive(Clone, Debug, Deserialize)]
struct RawSignature {
    header_name: Option<String>,
    header_regex: Option<String>,
    cookie_regex: Option<String>,
    body_regex: Option<String>,
    status_code: Option<u16>,
    /// Falls back to `default_weight()` (0.4) when omitted.
    #[serde(default = "default_weight")]
    weight: f64,
}
/// Serde default for a rule's `confidence_threshold` when the file omits it.
const fn default_threshold() -> f64 {
    0.3
}
/// Serde default for a signature's `weight` when the file omits it.
const fn default_weight() -> f64 {
    0.4
}
/// Rule database compiled into the binary with `include_str!`.
///
/// The file is read from Cargo's `OUT_DIR`, so a build script must write
/// `embedded_detect_rules.toml` there. NOTE(review): presumably `build.rs`
/// assembles it from `rules/detect/*.toml` — confirm against the build script.
const EMBEDDED_RULES_TOML: &str =
include_str!(concat!(env!("OUT_DIR"), "/embedded_detect_rules.toml"));
impl RuleEngine {
pub fn load_embedded() -> Result<Self, RulesError> {
let mut engine = RuleEngine {
rules: HashMap::new(),
names: Vec::new(),
body_regex_set: None,
body_pattern_map: Vec::new(),
body_regexes: Vec::new(),
};
let embedded_ok =
engine.load_from_str(EMBEDDED_RULES_TOML).is_ok() && !engine.rules.is_empty();
if !embedded_ok {
let candidates = [
std::path::PathBuf::from("rules/detect"),
std::path::PathBuf::from("../rules/detect"),
std::path::PathBuf::from("../../rules/detect"),
];
let mut loaded = false;
for dir in &candidates {
if dir.is_dir() {
engine.load_directory(dir)?;
loaded = true;
break;
}
}
if !loaded {
return Err(RulesError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"rules/detect directory not found and no embedded rules available",
)));
}
}
engine.compile_body_regex_set()?;
Ok(engine)
}
pub fn load_from_str(&mut self, toml_content: &str) -> Result<(), RulesError> {
let raw: RawRuleDb = toml::from_str(toml_content)
.map_err(|e| RulesError::Parse(format!("embedded rules: {e}")))?;
for waf in raw.waf {
let compiled = Self::compile_waf(waf)
.map_err(|e| RulesError::Parse(format!("embedded rules: {e}")))?;
let key = compiled.name.clone();
if !self.rules.contains_key(&key) {
self.names.push(key.clone());
}
self.rules.insert(key, compiled);
}
Ok(())
}
pub fn load_directory(&mut self, path: &std::path::Path) -> Result<(), RulesError> {
let mut entries: Vec<_> = std::fs::read_dir(path)?
.filter_map(|e| e.ok())
.filter(|e| {
e.path()
.extension()
.map(|ext| ext.eq_ignore_ascii_case("toml"))
.unwrap_or(false)
})
.map(|e| e.path())
.collect();
entries.sort();
for entry in entries {
let content = std::fs::read_to_string(&entry)?;
let raw: RawRuleDb = toml::from_str(&content)
.map_err(|e| RulesError::Parse(format!("{}: {e}", entry.display())))?;
for waf in raw.waf {
let compiled = Self::compile_waf(waf)
.map_err(|e| RulesError::Parse(format!("{}: {e}", entry.display())))?;
let key = compiled.name.clone();
if !self.rules.contains_key(&key) {
self.names.push(key.clone());
}
self.rules.insert(key, compiled);
}
}
Ok(())
}
fn compile_waf(raw: RawWafRule) -> Result<CompiledWafRule, String> {
let mut signatures = Vec::with_capacity(raw.signature.len());
for sig in raw.signature {
let header_regex = sig
.header_regex
.as_ref()
.map(|p| Regex::new(p).map_err(|e| format!("bad header regex '{p}': {e}")))
.transpose()?;
let cookie_regex = sig
.cookie_regex
.as_ref()
.map(|p| Regex::new(p).map_err(|e| format!("bad cookie regex '{p}': {e}")))
.transpose()?;
let body_regex = sig
.body_regex
.as_ref()
.map(|p| Regex::new(p).map_err(|e| format!("bad body regex '{p}': {e}")))
.transpose()?;
signatures.push(CompiledSignature {
header_name: sig.header_name.map(|s| s.to_ascii_lowercase()),
header_regex,
cookie_regex,
body_regex,
status_code: sig.status_code,
weight: sig.weight,
});
}
Ok(CompiledWafRule {
name: raw.name,
vendor: raw.vendor,
confidence_threshold: raw.confidence_threshold,
evasions: raw.evasions,
source: raw.source,
signatures,
})
}
fn compile_body_regex_set(&mut self) -> Result<(), RulesError> {
let mut patterns: Vec<String> = Vec::new();
let mut map: Vec<BodyPatternRef> = Vec::new();
let mut regexes: Vec<Regex> = Vec::new();
for name in &self.names {
let rule = &self.rules[name];
for (sig_idx, sig) in rule.signatures.iter().enumerate() {
if let Some(ref re) = sig.body_regex {
patterns.push(re.as_str().to_string());
map.push(BodyPatternRef {
waf_name: name.clone(),
sig_index: sig_idx,
weight: sig.weight,
});
regexes.push(re.clone());
}
}
}
if !patterns.is_empty() {
let set = RegexSet::new(&patterns)
.map_err(|e| RulesError::Parse(format!("failed to compile body RegexSet: {e}")))?;
self.body_regex_set = Some(set);
}
self.body_pattern_map = map;
self.body_regexes = regexes;
Ok(())
}
pub fn detect(
&self,
status: u16,
headers: &[(String, String)],
body: &str,
) -> Vec<DetectedWaf> {
let body_hits: Vec<usize> = self
.body_regex_set
.as_ref()
.map(|set| set.matches(body).into_iter().collect())
.unwrap_or_default();
let mut waf_scores: HashMap<&str, (f64, Vec<String>)> = HashMap::new();
for &pattern_idx in &body_hits {
let pref = &self.body_pattern_map[pattern_idx];
let entry = waf_scores
.entry(&pref.waf_name)
.or_insert_with(|| (0.0, Vec::new()));
entry.0 += pref.weight;
if let Some(m) = self.body_regexes[pattern_idx].find(body) {
let snippet = &body[m.start()..m.end().min(m.start() + 40)];
entry.1.push(format!("body: {snippet}"));
}
}
for name in &self.names {
let rule = &self.rules[name];
for sig in &rule.signatures {
if sig.header_regex.is_none()
&& sig.cookie_regex.is_none()
&& sig.status_code.is_none()
{
continue;
}
let mut matched = false;
let entry = waf_scores.entry(name).or_insert_with(|| (0.0, Vec::new()));
if let Some(expected) = sig.status_code
&& status == expected
{
matched = true;
entry.1.push(format!("status: {status}"));
}
if let Some(ref re) = sig.header_regex {
let hname = sig.header_name.as_deref().unwrap_or("");
for (k, v) in headers {
if (hname.is_empty() || k.eq_ignore_ascii_case(hname))
&& let Some(m) = re.find(v)
{
matched = true;
entry.1.push(format!(
"header {k}: {}",
&v[m.start()..m.end().min(m.start() + 40)]
));
break;
}
}
}
if let Some(ref re) = sig.cookie_regex {
for (k, v) in headers {
if k.eq_ignore_ascii_case("set-cookie") && re.is_match(v) {
matched = true;
entry.1.push(format!("cookie: {k}"));
break;
}
}
}
if matched {
entry.0 += sig.weight;
}
}
}
let mut results: Vec<DetectedWaf> = waf_scores
.into_iter()
.filter_map(|(name, (score, indicators))| {
let rule = &self.rules[name];
if score >= rule.confidence_threshold {
Some(DetectedWaf {
name: name.to_string(),
confidence: score.min(1.0),
indicators,
})
} else {
None
}
})
.collect();
results.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
#[must_use]
#[allow(dead_code)]
pub fn evasions_for(&self, name: &str) -> Vec<&str> {
self.rules
.get(name)
.map(|r| r.evasions.iter().map(String::as_str).collect())
.unwrap_or_default()
}
#[must_use]
#[allow(dead_code)]
pub fn len(&self) -> usize {
self.rules.len()
}
#[must_use]
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.rules.is_empty()
}
}
/// A WAF reported by detection.
#[derive(Clone, Debug)]
pub struct DetectedWaf {
    /// Name of the matching rule.
    pub name: String,
    /// Summed weight of the matching signatures, capped at `1.0`.
    pub confidence: f64,
    /// Human-readable descriptions of what matched (e.g. `"status: 403"`).
    pub indicators: Vec<String>,
}
/// Failure while locating, reading, or compiling detection rules.
#[derive(Debug)]
pub enum RulesError {
    /// Reading a rule directory or file failed, or no rule source was found.
    Io(std::io::Error),
    /// A rule document was malformed or held an invalid regex; the message names the source.
    Parse(String),
}
impl std::fmt::Display for RulesError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RulesError::Io(e) => write!(f, "io error: {e}"),
RulesError::Parse(s) => write!(f, "parse error: {s}"),
}
}
}
impl std::error::Error for RulesError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
RulesError::Io(e) => Some(e),
RulesError::Parse(_) => None,
}
}
}
impl From<std::io::Error> for RulesError {
fn from(e: std::io::Error) -> Self {
RulesError::Io(e)
}
}
/// Runs `f` with shared (read) access to the global rule engine and returns its result.
///
/// Within this module the only writer is [`reload`], which swaps in a whole
/// new engine. A poisoned lock therefore cannot expose a half-updated engine,
/// so poisoning is recovered from instead of panicking.
///
/// Do not call [`reload`] from inside `f`. The read guard is held while `f`
/// runs, and taking the write lock on the same thread may deadlock or panic.
pub fn with_engine<F, R>(f: F) -> R
where
    F: FnOnce(&RuleEngine) -> R,
{
    let guard = RULE_DB
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    f(&guard)
}
/// Rebuilds the global rule engine and swaps it in.
///
/// The new engine is built before the write lock is taken, so readers are
/// blocked only for the swap itself. If loading fails, the current engine is
/// kept. The whole value is replaced, so a poisoned lock is safe to recover.
///
/// NOTE(review): this goes through [`RuleEngine::load_embedded`], which
/// prefers the compiled-in rules. On-disk edits to `rules/detect` are only
/// picked up when the embedded set is unusable.
///
/// # Errors
///
/// Propagates any [`RulesError`] from [`RuleEngine::load_embedded`].
pub fn reload() -> Result<(), RulesError> {
    let fresh = RuleEngine::load_embedded()?;
    *RULE_DB
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner) = fresh;
    Ok(())
}
/// Fingerprints one HTTP response against the global rule engine.
///
/// Thin wrapper over [`RuleEngine::detect`]; see that method for scoring details.
#[must_use]
pub fn detect(status: u16, headers: &[(String, String)], body: &str) -> Vec<DetectedWaf> {
    with_engine(|rules| rules.detect(status, headers, body))
}
/// Returns the names of all loaded WAF rules, in load order.
#[must_use]
pub fn supported_wafs() -> Vec<String> {
    with_engine(|rules| rules.names.to_vec())
}
#[must_use]
pub fn suggest_evasion(waf_name: &str) -> Vec<&'static str> {
let list: Vec<String> = with_engine(|engine| {
engine
.rules
.get(waf_name)
.map(|r| r.evasions.clone())
.unwrap_or_else(|| {
vec![
"CaseAlternation".into(),
"SqlCommentInsertion".into(),
"DoubleUrlEncode".into(),
"ContentTypeSwitch".into(),
]
})
});
let leaked: Vec<&'static str> = list
.into_iter()
.map(|s| -> &'static str { Box::leak(s.into_boxed_str()) })
.collect();
leaked
}
/// Tuning knobs for [`detect_with_config`].
#[derive(Clone, Copy, Debug)]
pub struct DetectConfig {
    /// Minimum confidence a result needs in order to be returned.
    pub threshold: f64,
    /// Confidence gaps smaller than this mark neighbouring results as ambiguous.
    pub ambiguity_delta: f64,
}
impl Default for DetectConfig {
    /// Keeps results scoring at least `0.3`. Results whose confidences differ
    /// by less than `0.15` are treated as ambiguous.
    fn default() -> Self {
        let threshold = 0.3;
        let ambiguity_delta = 0.15;
        Self {
            threshold,
            ambiguity_delta,
        }
    }
}
/// Runs [`detect`] and post-processes the results with `config`.
///
/// Results below `config.threshold` are dropped. The rest are already sorted
/// by descending confidence. Of these, the top result is kept together with
/// each following result whose gap to its immediate predecessor is below
/// `config.ambiguity_delta`; the run stops at the first larger gap. A clear
/// winner therefore yields a single result, while near-ties are all reported.
#[must_use]
pub fn detect_with_config(
    status: u16,
    headers: &[(String, String)],
    body: &str,
    config: DetectConfig,
) -> Vec<DetectedWaf> {
    let mut results = detect(status, headers, body);
    results.retain(|r| r.confidence >= config.threshold);
    // The leading run of near-ties: the top result, plus one per consecutive
    // pair (starting from the top) whose confidence gap is below the delta.
    let keep = 1 + results
        .windows(2)
        .take_while(|pair| pair[0].confidence - pair[1].confidence < config.ambiguity_delta)
        .count();
    results.truncate(keep);
    results
}