use crate::canon::{Channel, canonicalize};
use crate::error::{Result, WafModelError};
use crate::normalize::{Transform, apply_chain};
use crate::outcome::Outcome;
use regex::bytes::Regex;
use wafrift_types::Request;
use wafrift_types::hash::{FNV_OFFSET_64, FNV_PRIME_64};
/// Something that can be asked whether a request is let through or blocked.
///
/// The implementations in this file keep a counter that goes up by one on
/// every `classify` call and expose it through `queries`.
pub trait WafOracle {
/// Classifies `req`, returning the resulting [`Outcome`] or an error.
///
/// Takes `&mut self` so an implementation can update internal state
/// (for example a query counter) while answering.
fn classify(&mut self, req: &Request) -> Result<Outcome>;
/// Returns the implementation's query count.
fn queries(&self) -> u64;
}
/// Compact set of [`Channel`] values packed into the bits of a `u16`.
///
/// A channel is a member when the bit whose index equals the channel's
/// numeric value (`ch as u16`) is set. The default value is the empty set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub struct ChannelSet(u16);
impl ChannelSet {
const fn bit(ch: Channel) -> u16 {
1 << (ch as u16)
}
#[must_use]
pub const fn none() -> Self {
ChannelSet(0)
}
#[must_use]
pub const fn all() -> Self {
ChannelSet(0x00FF)
}
#[must_use]
pub const fn with(self, ch: Channel) -> Self {
ChannelSet(self.0 | Self::bit(ch))
}
#[must_use]
pub const fn contains(self, ch: Channel) -> bool {
self.0 & Self::bit(ch) != 0
}
}
impl FromIterator<Channel> for ChannelSet {
    /// Collects every channel yielded by `iter` into one set.
    ///
    /// Starts from the empty set and adds the channels one by one, so an
    /// empty iterator produces `ChannelSet::none()`.
    fn from_iter<I: IntoIterator<Item = Channel>>(iter: I) -> Self {
        let mut set = ChannelSet::none();
        for ch in iter {
            set = set.with(ch);
        }
        set
    }
}
/// One scoring rule of a [`SimRegexWaf`].
///
/// During classification a rule is a hit when at least one request segment
/// on one of its `channels`, after being run through `transforms`, matches
/// `pattern`; a hit adds `score` to the request's running total.
#[derive(Debug, Clone)]
pub struct Rule {
/// Identifier of the rule; used in the ruleset fingerprint and in
/// rule-loading error messages.
pub id: String,
/// Channels whose segments this rule inspects.
pub channels: ChannelSet,
/// Transform chain applied to a segment's bytes before matching.
pub transforms: Vec<Transform>,
/// Byte-oriented regular expression tested against the transformed bytes.
pub pattern: Regex,
/// Amount added to the request's total when the rule hits.
pub score: u32,
}
/// Simulated WAF that scores requests against a list of regex [`Rule`]s and
/// blocks once the accumulated score reaches a threshold.
#[derive(Debug)]
pub struct SimRegexWaf {
// Rules evaluated, in order, for every classified request.
rules: Vec<Rule>,
// Total score at which a request is blocked (compared with `>=`).
threshold: u32,
// Number of `WafOracle::classify` calls served so far.
queries: u64,
}
impl SimRegexWaf {
    /// Creates a WAF from `rules` and a blocking `threshold`, with the query
    /// counter starting at zero.
    #[must_use]
    pub fn new(rules: Vec<Rule>, threshold: u32) -> Self {
        Self {
            queries: 0,
            threshold,
            rules,
        }
    }

    /// Number of rules currently loaded.
    #[must_use]
    pub fn rule_count(&self) -> usize {
        self.rules.len()
    }

    /// Total score at which a request is blocked.
    #[must_use]
    pub fn threshold(&self) -> u32 {
        self.threshold
    }

    /// Returns a 16-digit lowercase hex digest identifying this ruleset.
    ///
    /// Each rule contributes one `id|pattern|score` line. The lines are
    /// sorted, so rule order does not affect the result. They are joined
    /// with newlines, prefixed with `t=<threshold>;`, and the bytes are
    /// folded with an FNV-1a-style XOR-then-multiply step using
    /// `FNV_OFFSET_64` and `FNV_PRIME_64`.
    ///
    /// NOTE(review): `channels` and `transforms` do not feed the hash, so
    /// rulesets differing only in those fields share a fingerprint. Confirm
    /// this is intended.
    #[must_use]
    pub fn fingerprint(&self) -> String {
        let mut entries: Vec<String> = Vec::with_capacity(self.rules.len());
        for rule in &self.rules {
            entries.push(format!(
                "{}|{}|{}",
                rule.id,
                rule.pattern.as_str(),
                rule.score
            ));
        }
        entries.sort();

        let material = format!("t={};{}", self.threshold, entries.join("\n"));
        let digest: u64 = material.bytes().fold(FNV_OFFSET_64, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME_64)
        });
        format!("{digest:016x}")
    }

    /// Classifies `req` without touching the query counter.
    ///
    /// The request is canonicalized into segments. A rule hits when any
    /// segment on one of its channels matches the rule's pattern after the
    /// rule's transform chain has been applied. Each hitting rule adds its
    /// score once (saturating) to a running total, and the request is
    /// blocked as soon as that total reaches the threshold. If the rules run
    /// out first, the request passes.
    #[must_use]
    pub fn classify_uncounted(&self, req: &Request) -> Outcome {
        let canonical = canonicalize(req);
        let mut score_sum: u32 = 0;

        for rule in &self.rules {
            let mut matched = false;
            for segment in canonical.segments.iter() {
                if !rule.channels.contains(segment.channel) {
                    continue;
                }
                let normalized = apply_chain(&rule.transforms, &segment.bytes);
                if rule.pattern.is_match(&normalized) {
                    // One matching segment is enough; skip the rest.
                    matched = true;
                    break;
                }
            }
            if !matched {
                continue;
            }

            score_sum = score_sum.saturating_add(rule.score);
            if score_sum >= self.threshold {
                return Outcome::Block;
            }
        }

        Outcome::Pass
    }

    /// Read-only view of the loaded rules.
    #[must_use]
    pub fn rules(&self) -> &[Rule] {
        &self.rules
    }

    /// Builds a new WAF holding this WAF's rules followed by `extra`, with
    /// the same threshold and a fresh query counter. `self` is unchanged.
    #[must_use]
    pub fn with_rules_added(&self, extra: Vec<Rule>) -> SimRegexWaf {
        let combined: Vec<Rule> = self.rules.iter().cloned().chain(extra).collect();
        Self::new(combined, self.threshold)
    }

    /// Parses a ruleset from TOML text.
    ///
    /// The document must provide a `threshold` and a `rule` array whose
    /// entries carry `id`, `channels`, `transforms`, `pattern` and `score`.
    ///
    /// # Errors
    ///
    /// * `WafModelError::Artifact` if the TOML does not deserialize, or if
    ///   any pattern is longer than 16 KiB.
    /// * `WafModelError::BadRule` if a pattern fails to compile under the
    ///   configured regex size limit.
    ///
    /// Rules are processed in document order and the first failure is the
    /// one reported.
    pub fn from_toml(src: &str) -> Result<Self> {
        // Upper bound on the pattern text itself, checked before compiling.
        const MAX_PATTERN_LEN: usize = 16 * 1024;

        #[derive(serde::Deserialize)]
        struct RawRule {
            id: String,
            channels: Vec<Channel>,
            transforms: Vec<Transform>,
            pattern: String,
            score: u32,
        }

        #[derive(serde::Deserialize)]
        struct Doc {
            threshold: u32,
            rule: Vec<RawRule>,
        }

        // Validates and compiles a single raw entry into a `Rule`.
        let compile = |raw: RawRule| -> Result<Rule> {
            let pattern_len = raw.pattern.len();
            if pattern_len > MAX_PATTERN_LEN {
                return Err(WafModelError::Artifact(format!(
                    "rule {} pattern is {} bytes; max {} (defends against \
                     hostile ruleset compile-time blowup)",
                    raw.id, pattern_len, MAX_PATTERN_LEN
                )));
            }

            let built = regex::bytes::RegexBuilder::new(&raw.pattern)
                .size_limit(wafrift_types::REGEX_NFA_SIZE_LIMIT)
                .build();
            match built {
                Ok(pattern) => Ok(Rule {
                    id: raw.id,
                    channels: raw.channels.into_iter().collect(),
                    transforms: raw.transforms,
                    pattern,
                    score: raw.score,
                }),
                Err(source) => Err(WafModelError::BadRule {
                    rule: raw.id,
                    source,
                }),
            }
        };

        let parsed: Doc = toml::from_str(src)
            .map_err(|e| WafModelError::Artifact(format!("ruleset TOML: {e}")))?;

        let mut compiled = Vec::with_capacity(parsed.rule.len());
        for raw in parsed.rule {
            compiled.push(compile(raw)?);
        }
        Ok(Self::new(compiled, parsed.threshold))
    }
}
impl WafOracle for SimRegexWaf {
    /// Counts the query, then classifies `req` with `classify_uncounted`.
    ///
    /// This implementation never returns an error.
    fn classify(&mut self, req: &Request) -> Result<Outcome> {
        self.queries += 1;
        let verdict = self.classify_uncounted(req);
        Ok(verdict)
    }

    /// Number of `classify` calls made on this instance.
    fn queries(&self) -> u64 {
        self.queries
    }
}
/// Oracle that wraps a callable `F` and counts how many times it is queried.
pub struct FnOracle<F> {
// The wrapped callable that produces the classification.
f: F,
// Number of `classify` calls made so far.
queries: u64,
}
impl<F> FnOracle<F>
where
    F: FnMut(&Request) -> Result<Outcome>,
{
    /// Wraps `f` as an oracle whose query counter starts at zero.
    pub fn new(f: F) -> Self {
        Self { queries: 0, f }
    }
}
impl<F> WafOracle for FnOracle<F>
where
    F: FnMut(&Request) -> Result<Outcome>,
{
    /// Counts the query and forwards `req` to the wrapped callable,
    /// returning whatever it returns.
    ///
    /// The counter is bumped before the call, so a query that ends in an
    /// error is still counted.
    fn classify(&mut self, req: &Request) -> Result<Outcome> {
        self.queries += 1;
        let callback = &mut self.f;
        callback(req)
    }

    /// Number of `classify` calls made on this instance.
    fn queries(&self) -> u64 {
        self.queries
    }
}