use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use base64::Engine;
use parking_lot::Mutex;
use percent_encoding::percent_decode_str;
use regex::Regex;
use serde_json::Value;
use crate::state::{IpSet, RegexPatternSet, WebAcl};
/// Header name `x-fakecloud-geo-country`.
/// NOTE(review): presumably callers read this header to fill in
/// `WafRequest::country`; the visible part of this file never reads it — confirm.
pub const FAKECLOUD_GEO_COUNTRY_HEADER: &str = "x-fakecloud-geo-country";
/// Evaluation window, in seconds, applied to a `RateBasedStatement` that does
/// not carry an `EvaluationWindowSec` value.
pub const DEFAULT_RATE_WINDOW_SECS: u64 = 300;
/// Borrowed view of one HTTP request, as inspected by the rule evaluator.
#[derive(Debug, Clone)]
pub struct WafRequest<'a> {
/// HTTP method; inspected by `Method` field matches and `HTTPMethod` custom keys.
pub method: &'a str,
/// Request URI; inspected by `UriPath` field matches.
pub uri: &'a str,
/// Header `(name, value)` pairs; names are compared ASCII case-insensitively.
pub headers: &'a [(String, String)],
/// Raw request body; inspected by `Body` / `JsonBody` field matches.
pub body: &'a [u8],
/// Raw query string; split on `&` / `=` when a single argument is looked up.
pub query: &'a str,
/// Client address used for IP-set matching and `IP` rate aggregation.
pub source_ip: IpAddr,
/// Country code used by `GeoMatchStatement`; `None` never matches.
pub country: Option<&'a str>,
/// Body size reported to `SizeConstraintStatement` for `Body` / `JsonBody`.
/// Stored separately from `body`, so it may differ from `body.len()`.
pub body_size_bytes: u64,
}
impl<'a> WafRequest<'a> {
pub fn from_parts(
method: &'a str,
uri: &'a str,
headers: &'a [(String, String)],
body: &'a [u8],
query: &'a str,
source_ip: IpAddr,
) -> Self {
Self {
method,
uri,
headers,
body,
query,
source_ip,
country: None,
body_size_bytes: body.len() as u64,
}
}
}
/// Outcome a rule (or the web ACL default) assigns to a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WafAction {
/// Terminating: let the request through.
Allow,
/// Terminating: reject the request.
Block,
/// Non-terminating: the rule is recorded and evaluation continues.
Count,
/// Terminating; reported as blocked by the evaluator.
Captcha,
/// Terminating; reported as blocked by the evaluator.
Challenge,
}
impl WafAction {
pub fn as_str(self) -> &'static str {
match self {
WafAction::Allow => "Allow",
WafAction::Block => "Block",
WafAction::Count => "Count",
WafAction::Captcha => "Captcha",
WafAction::Challenge => "Challenge",
}
}
}
/// Reduced evaluation result: the final action plus the matched Count rules.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WafEvaluation {
/// Final action: the terminating rule's action, or the web ACL default.
pub action: WafAction,
/// Names of matching rules whose action was `Count`, in evaluation order.
pub count_rules: Vec<String>,
}
/// Full evaluation result for one request against one web ACL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WafVerdict {
/// Final action: the terminating rule's action, or the web ACL default.
pub action: WafAction,
/// `Name` of the rule that terminated evaluation; `None` when the default
/// action applied (or the terminating rule had no string `Name`).
pub terminating_rule_id: Option<String>,
/// Labels emitted by matching rules, de-duplicated, in first-emission order.
pub labels: Vec<String>,
/// `true` for a terminating Block / Captcha / Challenge, or a default Block.
pub blocked: bool,
/// Names of matching rules whose action was `Count`, in evaluation order.
pub count_rules: Vec<String>,
/// `CustomResponseBodyKey` from the terminating action's `CustomResponse`, if any.
pub custom_response_body_key: Option<String>,
/// `ResponseCode` from the terminating action's `CustomResponse`, if it fits in `u16`.
pub custom_response_status: Option<u16>,
}
/// In-memory hit log backing `RateBasedStatement` evaluation.
#[derive(Debug, Default)]
pub struct RateLimiter {
// Keyed by (rule id, aggregate key); each value holds the timestamps
// (the caller-supplied `now` values) of hits still inside the window.
inner: Mutex<HashMap<(String, String), Vec<i64>>>,
}
impl RateLimiter {
    /// Creates a limiter with no recorded hits.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one hit for `(rule_id, agg_key)` at time `now` and returns the
    /// number of hits inside the window, including this one.
    ///
    /// Hits whose timestamp is `<= now - window_secs` are discarded before
    /// counting. `now` and `window_secs` use the same unit (the evaluator
    /// passes epoch seconds).
    pub fn record(&self, rule_id: &str, agg_key: &str, window_secs: u64, now: i64) -> u64 {
        // A plain `as i64` cast turns windows above `i64::MAX` into negative
        // numbers, which pushes the cutoff into the future and empties the
        // bucket on every call. Clamp the window and saturate the subtraction
        // so oversized windows simply keep every hit.
        let window = i64::try_from(window_secs).unwrap_or(i64::MAX);
        let cutoff = now.saturating_sub(window);

        let mut guard = self.inner.lock();
        let bucket = guard
            .entry((rule_id.to_string(), agg_key.to_string()))
            .or_default();
        bucket.retain(|t| *t > cutoff);
        bucket.push(now);
        bucket.len() as u64
    }

    /// Forgets every recorded hit for every rule and key.
    pub fn clear(&self) {
        self.inner.lock().clear();
    }
}
pub fn evaluate(
req: &WafRequest,
web_acl: &WebAcl,
ipsets: &HashMap<String, IpSet>,
regex_sets: &HashMap<String, RegexPatternSet>,
) -> WafAction {
evaluate_detailed(req, web_acl, ipsets, regex_sets).action
}
pub fn evaluate_detailed(
req: &WafRequest,
web_acl: &WebAcl,
ipsets: &HashMap<String, IpSet>,
regex_sets: &HashMap<String, RegexPatternSet>,
) -> WafEvaluation {
let verdict = evaluate_inner(req, web_acl, ipsets, regex_sets, None, current_epoch_secs());
WafEvaluation {
action: verdict.action,
count_rules: verdict.count_rules,
}
}
/// Evaluates `request` against `web_acl` with rate limiting enabled and
/// returns the full verdict.
///
/// `now_epoch_secs` is the timestamp recorded in `rate_limiter` for any
/// rate-based statement that counts this request.
pub fn evaluate_web_acl(
    web_acl: &WebAcl,
    request: &WafRequest,
    ipsets: &HashMap<String, IpSet>,
    regex_sets: &HashMap<String, RegexPatternSet>,
    rate_limiter: &RateLimiter,
    now_epoch_secs: i64,
) -> WafVerdict {
    let limiter = Some(rate_limiter);
    evaluate_inner(request, web_acl, ipsets, regex_sets, limiter, now_epoch_secs)
}
impl WebAcl {
    /// Method form of the free function `evaluate`: returns only the final
    /// action for `req`.
    pub fn evaluate(
        &self,
        req: &WafRequest,
        ipsets: &HashMap<String, IpSet>,
        regex_sets: &HashMap<String, RegexPatternSet>,
    ) -> WafAction {
        self.evaluate_detailed(req, ipsets, regex_sets).action
    }

    /// Method form of the free function `evaluate_detailed`: returns the
    /// final action plus the names of matching Count rules.
    pub fn evaluate_detailed(
        &self,
        req: &WafRequest,
        ipsets: &HashMap<String, IpSet>,
        regex_sets: &HashMap<String, RegexPatternSet>,
    ) -> WafEvaluation {
        let acl: &WebAcl = self;
        evaluate_detailed(req, acl, ipsets, regex_sets)
    }
}
/// Returns the system time as whole seconds since the Unix epoch, or `0`
/// when the system clock reports a time before the epoch.
fn current_epoch_secs() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Core evaluation loop shared by every public entry point.
///
/// Rules are visited in ascending `Priority` (missing priority counts as 0;
/// ties keep their declared order). For each rule whose `Statement` matches:
/// its `RuleLabels` are emitted, a `Count` action is recorded and evaluation
/// continues, and any other recognised action terminates evaluation. A
/// matching rule without a recognised `Action` only contributes labels.
/// When no rule terminates, the web ACL default action applies.
fn evaluate_inner(
    req: &WafRequest,
    web_acl: &WebAcl,
    ipsets: &HashMap<String, IpSet>,
    regex_sets: &HashMap<String, RegexPatternSet>,
    rate_limiter: Option<&RateLimiter>,
    now_epoch_secs: i64,
) -> WafVerdict {
    let mut ordered: Vec<&Value> = web_acl.rules.iter().collect();
    ordered.sort_by_key(|rule| rule.get("Priority").and_then(Value::as_i64).unwrap_or(0));

    // `seen_labels` answers LabelMatchStatement lookups; `label_order`
    // remembers the first-emission order for the verdict.
    let mut seen_labels: HashSet<String> = HashSet::new();
    let mut label_order: Vec<String> = Vec::new();
    let mut counted: Vec<String> = Vec::new();

    for rule in ordered {
        let rule_name = rule.get("Name").and_then(Value::as_str);
        let statement = match rule.get("Statement") {
            Some(statement) => statement,
            None => continue,
        };

        let matched = {
            let ctx = StmtCtx {
                req,
                ipsets,
                regex_sets,
                labels: &seen_labels,
                rule_id: rule_name.unwrap_or("").to_owned(),
                rate_limiter,
                now_epoch_secs,
            };
            eval_statement(statement, &ctx)
        };
        if !matched {
            continue;
        }

        // Labels are emitted on match regardless of the rule's action.
        let rule_labels = rule.get("RuleLabels").and_then(Value::as_array);
        for name in rule_labels
            .into_iter()
            .flatten()
            .filter_map(|label| label.get("Name").and_then(Value::as_str))
        {
            if seen_labels.insert(name.to_owned()) {
                label_order.push(name.to_owned());
            }
        }

        let action = match rule.get("Action").and_then(rule_action) {
            Some(action) => action,
            None => continue,
        };
        if action == WafAction::Count {
            if let Some(name) = rule_name {
                counted.push(name.to_owned());
            }
            continue;
        }

        let (custom_response_body_key, custom_response_status) = block_custom_response(rule);
        let blocked = match action {
            WafAction::Block | WafAction::Captcha | WafAction::Challenge => true,
            WafAction::Allow | WafAction::Count => false,
        };
        return WafVerdict {
            action,
            terminating_rule_id: rule_name.map(str::to_owned),
            labels: label_order,
            blocked,
            count_rules: counted,
            custom_response_body_key,
            custom_response_status,
        };
    }

    let fallback = default_action(web_acl);
    WafVerdict {
        action: fallback,
        terminating_rule_id: None,
        labels: label_order,
        blocked: fallback == WafAction::Block,
        count_rules: counted,
        custom_response_body_key: None,
        custom_response_status: None,
    }
}
/// Extracts `(CustomResponseBodyKey, ResponseCode)` from a rule's action.
///
/// Looks under `Action.Block`, `Action.Captcha` and `Action.Challenge`, in
/// that order, and uses the first one carrying a `CustomResponse`. Returns
/// `(None, None)` when there is none; a `ResponseCode` that does not fit in
/// `u16` yields `None` for the status.
fn block_custom_response(rule: &Value) -> (Option<String>, Option<u16>) {
    let Some(action) = rule.get("Action").and_then(Value::as_object) else {
        return (None, None);
    };
    let spec = ["Block", "Captcha", "Challenge"]
        .iter()
        .find_map(|key| action.get(*key).and_then(|v| v.get("CustomResponse")));
    match spec {
        None => (None, None),
        Some(spec) => {
            let body_key = spec
                .get("CustomResponseBodyKey")
                .and_then(Value::as_str)
                .map(str::to_owned);
            let status = spec
                .get("ResponseCode")
                .and_then(Value::as_u64)
                .and_then(|code| u16::try_from(code).ok());
            (body_key, status)
        }
    }
}
/// Returns the web ACL default action: `Block` when `default_action` has a
/// `Block` key, `Allow` otherwise.
fn default_action(web_acl: &WebAcl) -> WafAction {
    match web_acl.default_action.get("Block") {
        Some(_) => WafAction::Block,
        None => WafAction::Allow,
    }
}
/// Maps a rule's `Action` object to a [`WafAction`].
///
/// Keys are probed in the order Allow, Block, Count, Captcha, Challenge and
/// the first one present wins; `None` when none of them is present.
fn rule_action(action: &Value) -> Option<WafAction> {
    const PROBE_ORDER: [(&str, WafAction); 5] = [
        ("Allow", WafAction::Allow),
        ("Block", WafAction::Block),
        ("Count", WafAction::Count),
        ("Captcha", WafAction::Captcha),
        ("Challenge", WafAction::Challenge),
    ];
    PROBE_ORDER
        .iter()
        .find(|(key, _)| action.get(*key).is_some())
        .map(|(_, mapped)| *mapped)
}
/// Everything a statement evaluator may need while evaluating one rule.
struct StmtCtx<'a> {
// Request under inspection.
req: &'a WafRequest<'a>,
// IP sets, looked up by the statement's `ARN`.
ipsets: &'a HashMap<String, IpSet>,
// Regex pattern sets, looked up by the statement's `ARN`.
regex_sets: &'a HashMap<String, RegexPatternSet>,
// Labels emitted by rules that matched earlier in this evaluation.
labels: &'a HashSet<String>,
// `Name` of the rule being evaluated ("" when absent); used as the
// rate-limiter bucket id.
rule_id: String,
// `None` disables rate-based statements (they evaluate to false).
rate_limiter: Option<&'a RateLimiter>,
// Timestamp recorded for rate-based hits.
now_epoch_secs: i64,
}
/// Evaluates one statement object against the request in `ctx`.
///
/// The statement kind is identified by which key is present; keys are
/// probed in a fixed order and the first recognised one decides the result.
/// A non-object value or an unrecognised statement evaluates to `false`.
fn eval_statement(stmt: &Value, ctx: &StmtCtx) -> bool {
    let Some(obj) = stmt.as_object() else {
        return false;
    };
    let req = ctx.req;

    if let Some(inner) = obj.get("ByteMatchStatement") {
        eval_byte_match(inner, req)
    } else if let Some(inner) = obj.get("SqliMatchStatement") {
        eval_sqli_match(inner, req)
    } else if let Some(inner) = obj.get("XssMatchStatement") {
        eval_xss_match(inner, req)
    } else if let Some(inner) = obj.get("GeoMatchStatement") {
        eval_geo_match(inner, req)
    } else if let Some(inner) = obj.get("IPSetReferenceStatement") {
        eval_ipset_ref(inner, req, ctx.ipsets)
    } else if let Some(inner) = obj.get("RegexPatternSetReferenceStatement") {
        eval_regex_set_ref(inner, req, ctx.regex_sets)
    } else if let Some(inner) = obj.get("RegexMatchStatement") {
        eval_regex_match(inner, req)
    } else if let Some(inner) = obj.get("AndStatement") {
        eval_and(inner, ctx)
    } else if let Some(inner) = obj.get("OrStatement") {
        eval_or(inner, ctx)
    } else if let Some(inner) = obj.get("NotStatement") {
        eval_not(inner, ctx)
    } else if let Some(inner) = obj.get("LabelMatchStatement") {
        eval_label_match(inner, ctx.labels)
    } else if let Some(inner) = obj.get("SizeConstraintStatement") {
        eval_size_constraint(inner, req)
    } else if let Some(inner) = obj.get("RateBasedStatement") {
        eval_rate_based(inner, ctx)
    } else if let Some(inner) = obj.get("ManagedRuleGroupStatement") {
        eval_managed_rule_group(inner, req)
    } else {
        false
    }
}
/// `AndStatement`: true when `Statements` is a non-empty array and every
/// nested statement matches. Evaluation stops at the first non-match.
fn eval_and(stmt: &Value, ctx: &StmtCtx) -> bool {
    match stmt.get("Statements").and_then(Value::as_array) {
        Some(children) if !children.is_empty() => {
            children.iter().all(|child| eval_statement(child, ctx))
        }
        _ => false,
    }
}
/// `OrStatement`: true when at least one nested statement in `Statements`
/// matches. Evaluation stops at the first match; a missing or non-array
/// `Statements` yields `false`.
fn eval_or(stmt: &Value, ctx: &StmtCtx) -> bool {
    stmt.get("Statements")
        .and_then(Value::as_array)
        .map_or(false, |children| {
            children.iter().any(|child| eval_statement(child, ctx))
        })
}
/// `NotStatement`: the negation of the nested `Statement`. A missing nested
/// statement yields `false` (not `true`).
fn eval_not(stmt: &Value, ctx: &StmtCtx) -> bool {
    match stmt.get("Statement") {
        Some(inner) => !eval_statement(inner, ctx),
        None => false,
    }
}
/// `ByteMatchStatement`: true when any value of the selected field, after
/// text transformations, satisfies `PositionalConstraint` against the
/// search string.
///
/// `SearchString` is first tried as standard base64; if that fails, its raw
/// bytes are used as the needle. A missing search string, an empty needle or
/// a missing constraint yields `false`.
fn eval_byte_match(stmt: &Value, req: &WafRequest) -> bool {
    let search = match stmt.get("SearchString").and_then(Value::as_str) {
        Some(search) => search,
        None => return false,
    };
    let needle: Vec<u8> = match base64::engine::general_purpose::STANDARD.decode(search) {
        Ok(decoded) => decoded,
        Err(_) => search.as_bytes().to_vec(),
    };
    if needle.is_empty() {
        return false;
    }
    let constraint = match stmt.get("PositionalConstraint").and_then(Value::as_str) {
        Some(constraint) => constraint,
        None => return false,
    };

    let xforms = stmt.get("TextTransformations");
    for raw in collect_fields(stmt.get("FieldToMatch"), req) {
        let candidate = apply_transformations(&raw, xforms);
        if positional_match(&candidate, &needle, constraint) {
            return true;
        }
    }
    false
}
/// Applies a positional constraint name (`EXACTLY`, `STARTS_WITH`,
/// `ENDS_WITH`, `CONTAINS`, `CONTAINS_WORD`) to `haystack` and `needle`.
/// Any other constraint name yields `false`.
fn positional_match(haystack: &[u8], needle: &[u8], constraint: &str) -> bool {
    if constraint == "EXACTLY" {
        haystack == needle
    } else if constraint == "STARTS_WITH" {
        haystack.starts_with(needle)
    } else if constraint == "ENDS_WITH" {
        haystack.ends_with(needle)
    } else if constraint == "CONTAINS" {
        bytes_contains(haystack, needle)
    } else if constraint == "CONTAINS_WORD" {
        contains_word(haystack, needle)
    } else {
        false
    }
}
/// Reports whether `needle` occurs as a contiguous run inside `haystack`.
/// An empty needle never matches.
fn bytes_contains(haystack: &[u8], needle: &[u8]) -> bool {
    let width = needle.len();
    if width == 0 || haystack.len() < width {
        return false;
    }
    let last_start = haystack.len() - width;
    (0..=last_start).any(|start| &haystack[start..start + width] == needle)
}
/// Reports whether `needle` occurs in `haystack` as a whole word: some
/// occurrence must be bounded on each side by the start/end of the haystack
/// or by a byte that is not a word byte (see `is_word_byte`).
fn contains_word(haystack: &[u8], needle: &[u8]) -> bool {
    // Also rules out an empty needle and a needle longer than the haystack,
    // so the subtraction below cannot underflow.
    if !bytes_contains(haystack, needle) {
        return false;
    }
    let width = needle.len();
    for start in 0..=haystack.len() - width {
        let end = start + width;
        if &haystack[start..end] != needle {
            continue;
        }
        let boundary_before = start == 0 || !is_word_byte(haystack[start - 1]);
        let boundary_after = end == haystack.len() || !is_word_byte(haystack[end]);
        if boundary_before && boundary_after {
            return true;
        }
    }
    false
}
/// A word byte is an ASCII letter, an ASCII digit, or `_`.
fn is_word_byte(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
/// `SqliMatchStatement`: a token-based heuristic. True when any value of the
/// selected field, after text transformations and ASCII lowercasing,
/// contains one of a fixed list of SQL-injection markers.
fn eval_sqli_match(stmt: &Value, req: &WafRequest) -> bool {
    const SQLI_TOKENS: [&[u8]; 8] = [
        b"union select",
        b"or 1=1",
        b"' or '1'='1",
        b"'; drop",
        b"--",
        b"/*",
        b"*/",
        b"xp_cmdshell",
    ];

    let xforms = stmt.get("TextTransformations");
    collect_fields(stmt.get("FieldToMatch"), req)
        .into_iter()
        .any(|raw| {
            let normalized = lowercase_bytes(&apply_transformations(&raw, xforms));
            SQLI_TOKENS
                .iter()
                .any(|token| bytes_contains(&normalized, token))
        })
}
/// `XssMatchStatement`: a token-based heuristic. True when any value of the
/// selected field, after text transformations and ASCII lowercasing,
/// contains one of a fixed list of script-injection markers.
fn eval_xss_match(stmt: &Value, req: &WafRequest) -> bool {
    const XSS_TOKENS: [&[u8]; 7] = [
        b"<script",
        b"</script",
        b"javascript:",
        b"onerror=",
        b"onload=",
        b"onclick=",
        b"<iframe",
    ];

    let xforms = stmt.get("TextTransformations");
    collect_fields(stmt.get("FieldToMatch"), req)
        .into_iter()
        .any(|raw| {
            let normalized = lowercase_bytes(&apply_transformations(&raw, xforms));
            XSS_TOKENS
                .iter()
                .any(|token| bytes_contains(&normalized, token))
        })
}
/// `GeoMatchStatement`: true when the request's country equals (ASCII
/// case-insensitively) one of the string entries in `CountryCodes`.
/// A request without a country never matches.
fn eval_geo_match(stmt: &Value, req: &WafRequest) -> bool {
    let (Some(country), Some(codes)) = (
        req.country,
        stmt.get("CountryCodes").and_then(Value::as_array),
    ) else {
        return false;
    };
    for code in codes {
        // Non-string entries are ignored.
        if let Some(code) = code.as_str() {
            if code.eq_ignore_ascii_case(country) {
                return true;
            }
        }
    }
    false
}
/// `IPSetReferenceStatement`: true when the request's source IP falls inside
/// any address/CIDR of the IP set named by `ARN`. A missing `ARN` or an
/// unknown set yields `false`.
fn eval_ipset_ref(stmt: &Value, req: &WafRequest, ipsets: &HashMap<String, IpSet>) -> bool {
    stmt.get("ARN")
        .and_then(Value::as_str)
        .and_then(|arn| ipsets.get(arn))
        .map_or(false, |set| {
            set.addresses
                .iter()
                .any(|cidr| cidr_contains(cidr, &req.source_ip))
        })
}
/// `RegexPatternSetReferenceStatement`: true when any value of the selected
/// field, after text transformations, is valid UTF-8 and matches at least
/// one pattern of the set named by `ARN`.
///
/// Patterns that fail to compile are skipped; if none compile (or the set
/// is unknown) the result is `false`.
fn eval_regex_set_ref(
    stmt: &Value,
    req: &WafRequest,
    regex_sets: &HashMap<String, RegexPatternSet>,
) -> bool {
    let set = match stmt
        .get("ARN")
        .and_then(Value::as_str)
        .and_then(|arn| regex_sets.get(arn))
    {
        Some(set) => set,
        None => return false,
    };

    let mut compiled: Vec<Regex> = Vec::new();
    for entry in &set.regular_expressions {
        let Some(source) = entry.get("RegexString").and_then(Value::as_str) else {
            continue;
        };
        if let Ok(re) = Regex::new(source) {
            compiled.push(re);
        }
    }
    if compiled.is_empty() {
        return false;
    }

    let xforms = stmt.get("TextTransformations");
    for raw in collect_fields(stmt.get("FieldToMatch"), req) {
        let candidate = apply_transformations(&raw, xforms);
        // Values that are not valid UTF-8 cannot be matched and are skipped.
        if let Ok(text) = std::str::from_utf8(&candidate) {
            if compiled.iter().any(|re| re.is_match(text)) {
                return true;
            }
        }
    }
    false
}
/// `RegexMatchStatement`: true when any value of the selected field, after
/// text transformations, is valid UTF-8 and matches `RegexString`.
/// A missing or uncompilable pattern yields `false`.
fn eval_regex_match(stmt: &Value, req: &WafRequest) -> bool {
    let re = match stmt
        .get("RegexString")
        .and_then(Value::as_str)
        .map(Regex::new)
    {
        Some(Ok(re)) => re,
        _ => return false,
    };

    let xforms = stmt.get("TextTransformations");
    for raw in collect_fields(stmt.get("FieldToMatch"), req) {
        let candidate = apply_transformations(&raw, xforms);
        if let Ok(text) = std::str::from_utf8(&candidate) {
            if re.is_match(text) {
                return true;
            }
        }
    }
    false
}
/// `LabelMatchStatement`: with `Scope` equal to `NAMESPACE`, true when any
/// emitted label starts with `Key`; with any other (or missing) scope, true
/// when a label equals `Key` exactly. A missing `Key` yields `false`.
fn eval_label_match(stmt: &Value, labels: &HashSet<String>) -> bool {
    let Some(key) = stmt.get("Key").and_then(Value::as_str) else {
        return false;
    };
    let namespace_scope = stmt.get("Scope").and_then(Value::as_str) == Some("NAMESPACE");
    if namespace_scope {
        labels.iter().any(|label| label.starts_with(key))
    } else {
        labels.contains(key)
    }
}
/// `SizeConstraintStatement`: compares the size of the selected field (see
/// `field_size`) with `Size` using `ComparisonOperator`
/// (`EQ`, `NE`, `LE`, `LT`, `GE`, `GT`).
///
/// A missing operator, a missing or negative `Size`, or an unknown operator
/// yields `false`.
fn eval_size_constraint(stmt: &Value, req: &WafRequest) -> bool {
    use std::cmp::Ordering;

    let Some(op) = stmt.get("ComparisonOperator").and_then(Value::as_str) else {
        return false;
    };
    let limit = match stmt.get("Size").and_then(Value::as_i64) {
        Some(limit) if limit >= 0 => limit,
        _ => return false,
    };

    let measured = field_size(
        stmt.get("FieldToMatch"),
        req,
        stmt.get("TextTransformations"),
    ) as i64;
    let ordering = measured.cmp(&limit);
    match op {
        "EQ" => ordering == Ordering::Equal,
        "NE" => ordering != Ordering::Equal,
        "LE" => ordering != Ordering::Greater,
        "LT" => ordering == Ordering::Less,
        "GE" => ordering != Ordering::Less,
        "GT" => ordering == Ordering::Greater,
        _ => false,
    }
}
fn field_size(field: Option<&Value>, req: &WafRequest, xforms: Option<&Value>) -> u64 {
let Some(field) = field else {
return 0;
};
let Some(obj) = field.as_object() else {
return 0;
};
if obj.contains_key("Body") || obj.contains_key("JsonBody") {
return req.body_size_bytes;
}
collect_fields(Some(field), req)
.iter()
.map(|raw| apply_transformations(raw, xforms).len() as u64)
.max()
.unwrap_or(0)
}
/// Result of resolving the aggregation key of a rate-based statement.
enum AggregateKey {
/// Key under which the request is counted in the rate limiter.
Key(String),
/// No key could be derived; the request is not counted and the carried
/// boolean is returned as the statement's match result.
Fallback(bool),
}
/// `RateBasedStatement`: records the request in the rate limiter and matches
/// once the count inside the window exceeds `Limit`.
///
/// Yields `false` without recording when no limiter is configured, `Limit`
/// is missing or zero, or the `ScopeDownStatement` (if any) does not match.
/// `EvaluationWindowSec` defaults to `DEFAULT_RATE_WINDOW_SECS` and
/// `AggregateKeyType` to `IP`. When no aggregation key can be derived, the
/// fallback verdict is returned without recording.
fn eval_rate_based(stmt: &Value, ctx: &StmtCtx) -> bool {
    let (Some(limiter), Some(limit)) = (
        ctx.rate_limiter,
        stmt.get("Limit").and_then(Value::as_u64),
    ) else {
        return false;
    };
    if limit == 0 {
        return false;
    }
    let window = match stmt.get("EvaluationWindowSec").and_then(Value::as_u64) {
        Some(secs) => secs,
        None => DEFAULT_RATE_WINDOW_SECS,
    };

    // The scope-down statement is checked before anything is recorded so
    // out-of-scope requests never count toward the limit.
    let in_scope = stmt
        .get("ScopeDownStatement")
        .map_or(true, |scope_down| eval_statement(scope_down, ctx));
    if !in_scope {
        return false;
    }

    let agg_type = stmt
        .get("AggregateKeyType")
        .and_then(Value::as_str)
        .unwrap_or("IP");
    match rate_aggregate_key(agg_type, ctx.req, stmt) {
        AggregateKey::Fallback(verdict) => verdict,
        AggregateKey::Key(key) => {
            limiter.record(&ctx.rule_id, &key, window, ctx.now_epoch_secs) > limit
        }
    }
}
/// Resolves the rate-limiter key for aggregation type `agg`.
///
/// * `IP`: the source IP.
/// * `FORWARDED_IP`: the first comma-separated entry of the configured
///   header (`ForwardedIPConfig.HeaderName`, default `X-Forwarded-For`), if
///   it parses as an IP address; otherwise a fallback that matches only when
///   `FallbackBehavior` is `MATCH`.
/// * `CONSTANT`: one shared key for all requests.
/// * `CUSTOM_KEYS`: the values of the `CustomKeys` that are present, joined
///   with `|`; a non-matching fallback when none is present.
/// * anything else: a non-matching fallback.
fn rate_aggregate_key(agg: &str, req: &WafRequest, stmt: &Value) -> AggregateKey {
    match agg {
        "IP" => AggregateKey::Key(req.source_ip.to_string()),
        "CONSTANT" => AggregateKey::Key("__constant__".to_string()),
        "FORWARDED_IP" => {
            let cfg = stmt.get("ForwardedIPConfig");
            let header_name = cfg
                .and_then(|c| c.get("HeaderName"))
                .and_then(Value::as_str)
                .unwrap_or("X-Forwarded-For");
            let fallback_match = cfg
                .and_then(|c| c.get("FallbackBehavior"))
                .and_then(Value::as_str)
                .map_or(false, |behavior| behavior.eq_ignore_ascii_case("MATCH"));

            // Only the first hop of the first matching header is considered.
            let first_hop = req
                .headers
                .iter()
                .find(|(name, _)| name.eq_ignore_ascii_case(header_name))
                .map(|(_, value)| value.split(',').next().unwrap_or(value).trim());
            match first_hop {
                Some(candidate) if candidate.parse::<std::net::IpAddr>().is_ok() => {
                    AggregateKey::Key(candidate.to_string())
                }
                _ => AggregateKey::Fallback(fallback_match),
            }
        }
        "CUSTOM_KEYS" => {
            let parts: Vec<String> = match stmt.get("CustomKeys").and_then(Value::as_array) {
                Some(keys) => keys
                    .iter()
                    .filter_map(|key| custom_key_value(key, req))
                    .collect(),
                None => Vec::new(),
            };
            if parts.is_empty() {
                AggregateKey::Fallback(false)
            } else {
                AggregateKey::Key(parts.join("|"))
            }
        }
        _ => AggregateKey::Fallback(false),
    }
}
/// Resolves one `CustomKeys` entry to its value for `req`.
///
/// Supported entries, probed in this order: `Header` (first header with the
/// given name), `QueryArgument` (first value of the argument, if valid
/// UTF-8), `HTTPMethod`, `UriPath`, `IP`. Returns `None` when the entry is
/// unsupported or the requested header/argument is absent.
fn custom_key_value(key: &Value, req: &WafRequest) -> Option<String> {
    if let Some(header) = key.get("Header") {
        let wanted = header.get("Name").and_then(Value::as_str)?;
        for (name, value) in req.headers {
            if name.eq_ignore_ascii_case(wanted) {
                return Some(value.clone());
            }
        }
        return None;
    }
    if let Some(argument) = key.get("QueryArgument") {
        let wanted = argument.get("Name").and_then(Value::as_str)?;
        let first = query_arg_values(req.query, wanted).into_iter().next()?;
        return String::from_utf8(first).ok();
    }

    if key.get("HTTPMethod").is_some() {
        Some(req.method.to_owned())
    } else if key.get("UriPath").is_some() {
        Some(req.uri.to_owned())
    } else if key.get("IP").is_some() {
        Some(req.source_ip.to_string())
    } else {
        None
    }
}
/// `ManagedRuleGroupStatement`: simplified stand-ins for a few managed rule
/// groups. Only vendor `AWS` (case-insensitive) is recognised:
///
/// * `AWSManagedRulesCommonRuleSet`: see `common_ruleset_match`.
/// * `AWSManagedRulesSQLiRuleSet`: the lowercased query string contains one
///   of a few SQL-injection markers.
/// * `AWSManagedRulesKnownBadInputsRuleSet`: a `User-Agent` header is
///   present with an empty value.
///
/// Any other vendor or group name yields `false`.
fn eval_managed_rule_group(stmt: &Value, req: &WafRequest) -> bool {
    let vendor_is_aws = stmt
        .get("VendorName")
        .and_then(Value::as_str)
        .map_or(false, |vendor| vendor.eq_ignore_ascii_case("AWS"));
    if !vendor_is_aws {
        return false;
    }

    let group = stmt.get("Name").and_then(Value::as_str).unwrap_or("");
    if group == "AWSManagedRulesCommonRuleSet" {
        common_ruleset_match(req)
    } else if group == "AWSManagedRulesSQLiRuleSet" {
        const MARKERS: [&[u8]; 4] = [b"union select", b"or 1=1", b"' or '1'='1", b"--"];
        // Only the raw query string is inspected; no transformations apply.
        let query = lowercase_bytes(req.query.as_bytes());
        MARKERS.iter().any(|marker| bytes_contains(&query, marker))
    } else if group == "AWSManagedRulesKnownBadInputsRuleSet" {
        req.headers
            .iter()
            .any(|(name, value)| value.is_empty() && name.eq_ignore_ascii_case("user-agent"))
    } else {
        false
    }
}
/// Stand-in for the common managed rule set: matches when the URI starts
/// with one of a fixed list of sensitive path prefixes, or when the request
/// carries a non-empty `x-crs-test` header.
fn common_ruleset_match(req: &WafRequest) -> bool {
    let hits_sensitive_path = [
        "/admin",
        "/wp-admin",
        "/phpmyadmin",
        "/.env",
        "/.git",
        "/cgi-bin/",
    ]
    .iter()
    .any(|prefix| req.uri.starts_with(*prefix));

    hits_sensitive_path
        || req
            .headers
            .iter()
            .any(|(name, value)| !value.is_empty() && name.eq_ignore_ascii_case("x-crs-test"))
}
/// Extracts the raw byte values selected by a `FieldToMatch` object.
///
/// Selectors are probed in a fixed order and the first one present wins:
/// `Method`, `UriPath`, `QueryString`, `Body`/`JsonBody` (one value each),
/// `SingleHeader` (every header with that name), `AllHeaders`/`Headers`
/// (one `name:value` entry per header), `SingleQueryArgument` (every value
/// of that argument). A missing, non-object or unrecognised selector yields
/// an empty list.
fn collect_fields(field: Option<&Value>, req: &WafRequest) -> Vec<Vec<u8>> {
    let obj = match field.and_then(Value::as_object) {
        Some(obj) => obj,
        None => return Vec::new(),
    };
    let has = |key: &str| obj.contains_key(key);

    if has("Method") {
        return vec![req.method.as_bytes().to_vec()];
    }
    if has("UriPath") {
        return vec![req.uri.as_bytes().to_vec()];
    }
    if has("QueryString") {
        return vec![req.query.as_bytes().to_vec()];
    }
    if has("Body") || has("JsonBody") {
        return vec![req.body.to_vec()];
    }

    // A `SingleHeader` without a string `Name` falls through to the
    // selectors below rather than returning early.
    let single_header = obj
        .get("SingleHeader")
        .and_then(|header| header.get("Name"))
        .and_then(Value::as_str);
    if let Some(wanted) = single_header {
        let mut values = Vec::new();
        for (name, value) in req.headers {
            if name.eq_ignore_ascii_case(wanted) {
                values.push(value.as_bytes().to_vec());
            }
        }
        return values;
    }

    if has("AllHeaders") || has("Headers") {
        let mut lines = Vec::with_capacity(req.headers.len());
        for (name, value) in req.headers {
            let mut line = Vec::with_capacity(name.len() + 1 + value.len());
            line.extend_from_slice(name.as_bytes());
            line.push(b':');
            line.extend_from_slice(value.as_bytes());
            lines.push(line);
        }
        return lines;
    }

    let single_argument = obj
        .get("SingleQueryArgument")
        .and_then(|argument| argument.get("Name"))
        .and_then(Value::as_str);
    match single_argument {
        Some(name) => query_arg_values(req.query, name),
        None => Vec::new(),
    }
}
/// Returns the raw (not URL-decoded) values of every query argument whose
/// name equals `name`, ASCII case-insensitively, in order of appearance.
///
/// The query is split on `&`; each pair is split at its first `=`. A pair
/// without `=` has an empty value.
fn query_arg_values(query: &str, name: &str) -> Vec<Vec<u8>> {
    let mut values = Vec::new();
    for pair in query.split('&') {
        let (key, value) = match pair.split_once('=') {
            Some(split) => split,
            None => (pair, ""),
        };
        if key.eq_ignore_ascii_case(name) {
            values.push(value.as_bytes().to_vec());
        }
    }
    values
}
/// Applies a statement's `TextTransformations` to `raw` and returns the
/// transformed bytes.
///
/// Transformations run in ascending `Priority` (missing priority counts as
/// 0; ties keep their declared order). Supported types are `LOWERCASE`,
/// `URL_DECODE` and `COMPRESS_WHITE_SPACE`; `NONE`, unknown types and
/// entries without a string `Type` leave the value unchanged. When `xforms`
/// is absent or not an array, `raw` is returned as-is.
fn apply_transformations(raw: &[u8], xforms: Option<&Value>) -> Vec<u8> {
    let Some(arr) = xforms.and_then(Value::as_array) else {
        return raw.to_vec();
    };
    let mut ordered: Vec<&Value> = arr.iter().collect();
    ordered.sort_by_key(|t| t.get("Priority").and_then(Value::as_i64).unwrap_or(0));

    let mut current = raw.to_vec();
    for t in ordered {
        let Some(kind) = t.get("Type").and_then(Value::as_str) else {
            continue;
        };
        current = match kind {
            "LOWERCASE" => lowercase_bytes(&current),
            "URL_DECODE" => url_decode_bytes(&current),
            "COMPRESS_WHITE_SPACE" => compress_whitespace(&current),
            // "NONE" and any unrecognised transformation are no-ops.
            _ => current,
        };
    }
    current
}
/// Returns a copy of `input` with ASCII uppercase letters lowercased; all
/// other bytes (including non-ASCII) are left untouched.
fn lowercase_bytes(input: &[u8]) -> Vec<u8> {
    input.to_ascii_lowercase()
}
fn url_decode_bytes(input: &[u8]) -> Vec<u8> {
let Ok(s) = std::str::from_utf8(input) else {
return input.to_vec();
};
percent_decode_str(s).collect()
}
/// Collapses every run of whitespace bytes (space, tab, LF, CR, form feed,
/// vertical tab) into a single space; all other bytes are copied unchanged.
fn compress_whitespace(input: &[u8]) -> Vec<u8> {
    let mut out: Vec<u8> = Vec::with_capacity(input.len());
    let mut in_run = false;
    for byte in input.iter().copied() {
        match byte {
            b' ' | b'\t' | b'\n' | b'\r' | 0x0b | 0x0c => {
                // Only the first byte of a run emits a space.
                if !in_run {
                    out.push(b' ');
                    in_run = true;
                }
            }
            other => {
                out.push(other);
                in_run = false;
            }
        }
    }
    out
}
/// Reports whether `ip` lies inside `cidr`.
///
/// `cidr` is either `address/prefix` or a bare address (compared for
/// equality). A malformed network or prefix, a prefix longer than the
/// address family allows (32 for IPv4, 128 for IPv6), or a family mismatch
/// between network and `ip` yields `false`.
fn cidr_contains(cidr: &str, ip: &IpAddr) -> bool {
    let (network, prefix_len) = match cidr.split_once('/') {
        Some(parts) => parts,
        None => return net_str_eq(cidr, ip),
    };
    let prefix: u8 = match prefix_len.parse() {
        Ok(prefix) => prefix,
        Err(_) => return false,
    };
    let Ok(network) = network.parse::<IpAddr>() else {
        return false;
    };
    match (network, ip) {
        (IpAddr::V4(net), IpAddr::V4(addr)) => {
            prefix <= 32 && mask_match(&net.octets(), &addr.octets(), prefix)
        }
        (IpAddr::V6(net), IpAddr::V6(addr)) => {
            prefix <= 128 && mask_match(&net.octets(), &addr.octets(), prefix)
        }
        _ => false,
    }
}
/// Reports whether `s` parses as an IP address equal to `ip`; unparsable
/// input yields `false`.
fn net_str_eq(s: &str, ip: &IpAddr) -> bool {
    match s.parse::<IpAddr>() {
        Ok(parsed) => &parsed == ip,
        Err(_) => false,
    }
}
/// Reports whether the first `prefix` bits of `net` and `addr` are equal.
///
/// Returns `false` when the slices differ in length or when `prefix` is
/// longer than the address (more than `8 * net.len()` bits). Rejecting an
/// oversized prefix up front keeps the partial-byte lookup below in bounds.
fn mask_match(net: &[u8], addr: &[u8], prefix: u8) -> bool {
    if net.len() != addr.len() || usize::from(prefix) > net.len() * 8 {
        return false;
    }
    let full_bytes = usize::from(prefix / 8);
    let extra_bits = prefix % 8;

    if net[..full_bytes] != addr[..full_bytes] {
        return false;
    }
    if extra_bits == 0 {
        return true;
    }
    // `extra_bits != 0` together with the bound check above guarantees
    // `full_bytes < net.len()`.
    let mask = 0xffu8 << (8 - extra_bits);
    (net[full_bytes] & mask) == (addr[full_bytes] & mask)
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use serde_json::json;
use std::collections::BTreeMap;
use std::net::Ipv4Addr;
fn make_acl(default: Value, rules: Vec<Value>) -> WebAcl {
WebAcl {
id: "id".into(),
name: "acl".into(),
arn: "arn:aws:wafv2:us-east-1:000000000000:regional/webacl/acl/id".into(),
scope: "REGIONAL".into(),
default_action: default,
description: None,
rules,
visibility_config: json!({}),
capacity: 0,
lock_token: "lt".into(),
label_namespace: "awswaf:000000000000:webacl:acl:".into(),
custom_response_bodies: BTreeMap::new(),
captcha_config: None,
challenge_config: None,
token_domains: Vec::new(),
association_config: None,
data_protection_config: None,
on_source_d_do_s_protection_config: None,
application_config: None,
retrofitted_by_firewall_manager: false,
pre_process_firewall_manager_rule_groups: Vec::new(),
post_process_firewall_manager_rule_groups: Vec::new(),
managed_by_firewall_manager: false,
created_time: Utc::now(),
}
}
fn req(uri: &'static str) -> WafRequest<'static> {
WafRequest {
method: "GET",
uri,
headers: &[],
body: b"",
query: "",
source_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
country: None,
body_size_bytes: 0,
}
}
fn byte_match_uri_contains(needle: &str, action: Value) -> Value {
json!({
"Name": "r",
"Priority": 0,
"Action": action,
"VisibilityConfig": {},
"Statement": {
"ByteMatchStatement": {
"SearchString": needle,
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
"PositionalConstraint": "CONTAINS",
}
}
})
}
#[test]
fn byte_match_contains_returns_block_when_default_allow_with_match() {
let acl = make_acl(
json!({"Allow": {}}),
vec![byte_match_uri_contains("/admin", json!({"Block": {}}))],
);
let action = evaluate(&req("/admin/users"), &acl, &HashMap::new(), &HashMap::new());
assert_eq!(action, WafAction::Block);
}
#[test]
fn byte_match_no_match_returns_default_allow() {
let acl = make_acl(
json!({"Allow": {}}),
vec![byte_match_uri_contains("/admin", json!({"Block": {}}))],
);
let action = evaluate(&req("/public"), &acl, &HashMap::new(), &HashMap::new());
assert_eq!(action, WafAction::Allow);
}
#[test]
fn byte_match_exactly_post_admin() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "exactly",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"ByteMatchStatement": {
"SearchString": "/admin",
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
"PositionalConstraint": "EXACTLY",
}},
})],
);
let mut r = req("/admin");
r.method = "POST";
assert_eq!(
evaluate(&r, &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block,
"POST /admin EXACTLY should block"
);
assert_eq!(
evaluate(&req("/admin/users"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Allow,
"/admin/users is not EXACTLY /admin"
);
}
#[test]
fn ip_set_match_blocks_listed_ip() {
let arn = "arn:aws:wafv2:us-east-1:000000000000:regional/ipset/blocked/abc".to_string();
let mut sets = HashMap::new();
sets.insert(
arn.clone(),
IpSet {
id: "abc".into(),
name: "blocked".into(),
arn: arn.clone(),
scope: "REGIONAL".into(),
description: None,
ip_address_version: "IPV4".into(),
addresses: vec!["10.0.0.0/8".into()],
lock_token: "lt".into(),
created_time: Utc::now(),
},
);
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"IPSetReferenceStatement": {"ARN": arn}},
})],
);
let mut r = req("/");
r.source_ip = IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3));
assert_eq!(evaluate(&r, &acl, &sets, &HashMap::new()), WafAction::Block);
}
#[test]
fn geo_match_country_code_blocks() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"GeoMatchStatement": {"CountryCodes": ["DE"]}},
})],
);
let mut r = req("/");
r.country = Some("DE");
assert_eq!(
evaluate(&r, &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block,
"country=DE matches"
);
let mut r2 = req("/");
r2.country = Some("US");
assert_eq!(
evaluate(&r2, &acl, &HashMap::new(), &HashMap::new()),
WafAction::Allow,
"country=US does not match"
);
assert_eq!(
evaluate(&req("/"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Allow,
"missing country header -> no match"
);
}
#[test]
fn regex_match_uri_path() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"RegexMatchStatement": {
"RegexString": "^/api/v[0-9]+/admin$",
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
}},
})],
);
assert_eq!(
evaluate(
&req("/api/v2/admin"),
&acl,
&HashMap::new(),
&HashMap::new()
),
WafAction::Block
);
assert_eq!(
evaluate(
&req("/api/v2/admin/x"),
&acl,
&HashMap::new(),
&HashMap::new()
),
WafAction::Allow
);
}
#[test]
fn and_statement_requires_all() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"AndStatement": {"Statements": [
{"ByteMatchStatement": {
"SearchString": "/admin",
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
"PositionalConstraint": "CONTAINS",
}},
{"GeoMatchStatement": {"CountryCodes": ["US"]}},
]}},
})],
);
assert_eq!(
evaluate(&req("/admin"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Allow
);
}
#[test]
fn or_statement_takes_first_match() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"OrStatement": {"Statements": [
{"ByteMatchStatement": {
"SearchString": "/admin",
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
"PositionalConstraint": "CONTAINS",
}},
{"GeoMatchStatement": {"CountryCodes": ["US"]}},
]}},
})],
);
assert_eq!(
evaluate(&req("/admin/x"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
}
#[test]
fn not_statement_inverts() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"NotStatement": {"Statement": {
"ByteMatchStatement": {
"SearchString": "/admin",
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
"PositionalConstraint": "CONTAINS",
}
}}},
})],
);
assert_eq!(
evaluate(&req("/public"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
}
#[test]
fn regex_pattern_set_reference() {
let arn =
"arn:aws:wafv2:us-east-1:000000000000:regional/regexpatternset/rps/abc".to_string();
let mut sets = HashMap::new();
sets.insert(
arn.clone(),
RegexPatternSet {
id: "abc".into(),
name: "rps".into(),
arn: arn.clone(),
scope: "REGIONAL".into(),
description: None,
regular_expressions: vec![
json!({"RegexString": "^/admin"}),
json!({"RegexString": "^/internal"}),
],
lock_token: "lt".into(),
created_time: Utc::now(),
},
);
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "r",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"RegexPatternSetReferenceStatement": {
"ARN": arn,
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
}},
})],
);
assert_eq!(
evaluate(&req("/internal/x"), &acl, &HashMap::new(), &sets),
WafAction::Block
);
}
#[test]
fn default_action_block_when_no_rules_match_and_default_block() {
let acl = make_acl(json!({"Block": {}}), vec![]);
assert_eq!(
evaluate(&req("/anything"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
}
#[test]
fn count_action_does_not_terminate() {
let acl = make_acl(
json!({"Allow": {}}),
vec![byte_match_uri_contains("/admin", json!({"Count": {}})), {
let mut r = byte_match_uri_contains("/admin", json!({"Block": {}}));
r["Priority"] = json!(1);
r["Name"] = json!("r2");
r
}],
);
assert_eq!(
evaluate(&req("/admin/x"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
}
#[test]
fn label_match_after_earlier_rule_emits_label() {
let acl = make_acl(
json!({"Allow": {}}),
vec![
json!({
"Name": "tag",
"Priority": 0,
"Action": {"Count": {}},
"VisibilityConfig": {},
"RuleLabels": [{"Name": "awswaf:custom:admin"}],
"Statement": {"ByteMatchStatement": {
"SearchString": "/admin",
"FieldToMatch": {"UriPath": {}},
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
"PositionalConstraint": "CONTAINS",
}},
}),
json!({
"Name": "block-by-label",
"Priority": 1,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"LabelMatchStatement": {
"Scope": "LABEL",
"Key": "awswaf:custom:admin",
}},
}),
],
);
assert_eq!(
evaluate(&req("/admin/x"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
}
#[test]
fn size_constraint_body_too_large_blocks() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "size",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"SizeConstraintStatement": {
"FieldToMatch": {"Body": {}},
"ComparisonOperator": "GT",
"Size": 1024,
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
}},
})],
);
let mut r = req("/upload");
r.body_size_bytes = 2048;
assert_eq!(
evaluate(&r, &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
let mut r_small = req("/upload");
r_small.body_size_bytes = 512;
assert_eq!(
evaluate(&r_small, &acl, &HashMap::new(), &HashMap::new()),
WafAction::Allow
);
}
#[test]
fn size_constraint_uri_path_le() {
let acl = make_acl(
json!({"Allow": {}}),
vec![json!({
"Name": "small-uri",
"Priority": 0,
"Action": {"Block": {}},
"VisibilityConfig": {},
"Statement": {"SizeConstraintStatement": {
"FieldToMatch": {"UriPath": {}},
"ComparisonOperator": "LT",
"Size": 5,
"TextTransformations": [{"Priority": 0, "Type": "NONE"}],
}},
})],
);
assert_eq!(
evaluate(&req("/api"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Block
);
assert_eq!(
evaluate(&req("/admin"), &acl, &HashMap::new(), &HashMap::new()),
WafAction::Allow
);
}
/// With `Limit: 1000` keyed on the source IP, the first 1000 evaluations at
/// the same timestamp are allowed and the next one is blocked, with `rate`
/// reported as the terminating rule.
#[test]
fn rate_based_blocks_after_limit() {
    let rate_rule = json!({
        "Name": "rate",
        "Priority": 0,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"RateBasedStatement": {
            "Limit": 1000,
            "AggregateKeyType": "IP",
            "EvaluationWindowSec": 300,
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![rate_rule]);
    let limiter = RateLimiter::new();
    let now = 1_700_000_000;

    let mut client = req("/api");
    client.source_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

    // One evaluation of the same request against the shared limiter.
    let hit = || evaluate_web_acl(&acl, &client, &HashMap::new(), &HashMap::new(), &limiter, now);

    (0..1000).for_each(|_| {
        let within_limit = hit();
        assert_eq!(within_limit.action, WafAction::Allow);
        assert!(!within_limit.blocked);
    });

    let over_limit = hit();
    assert_eq!(over_limit.action, WafAction::Block);
    assert_eq!(over_limit.terminating_rule_id.as_deref(), Some("rate"));
    assert!(over_limit.blocked);
}
/// With `Limit: 2` and a 300-second window, the third request at the same
/// timestamp is blocked, and a request 301 seconds later is allowed again.
#[test]
fn rate_based_window_rolls_over() {
    let rate_rule = json!({
        "Name": "rate",
        "Priority": 0,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"RateBasedStatement": {
            "Limit": 2,
            "AggregateKeyType": "IP",
            "EvaluationWindowSec": 300,
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![rate_rule]);
    let limiter = RateLimiter::new();
    let request = req("/api");
    let window_start = 1_700_000_000;

    // Two requests to reach the limit; their verdicts are not inspected.
    for _ in 0..2 {
        evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, window_start);
    }

    let third = evaluate_web_acl(
        &acl,
        &request,
        &HashMap::new(),
        &HashMap::new(),
        &limiter,
        window_start,
    );
    assert_eq!(third.action, WafAction::Block, "3rd in window blocks");

    // One second past the 300-second evaluation window.
    let after_window = window_start + 301;
    let fourth = evaluate_web_acl(
        &acl,
        &request,
        &HashMap::new(),
        &HashMap::new(),
        &limiter,
        after_window,
    );
    assert_eq!(fourth.action, WafAction::Allow, "after window rolls, allowed");
}
/// With `Limit: 1` keyed on the source IP, two different source addresses
/// each get their own allowance: each one's first request is allowed and
/// each one's second request is blocked.
#[test]
fn rate_based_per_ip_independent_buckets() {
    let rate_rule = json!({
        "Name": "rate",
        "Priority": 0,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"RateBasedStatement": {
            "Limit": 1,
            "AggregateKeyType": "IP",
            "EvaluationWindowSec": 60,
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![rate_rule]);
    let limiter = RateLimiter::new();
    let now = 1_700_000_000;

    // Builds an `/api` request originating from 10.0.0.<last_octet>.
    let client = |last_octet: u8| {
        let mut request = req("/api");
        request.source_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, last_octet));
        request
    };
    let (a, b) = (client(1), client(2));

    let action_for = |request: &WafRequest<'_>| {
        evaluate_web_acl(&acl, request, &HashMap::new(), &HashMap::new(), &limiter, now).action
    };

    // First hit from each address is within its own limit.
    assert_eq!(action_for(&a), WafAction::Allow);
    assert_eq!(action_for(&b), WafAction::Allow);
    // Second hit from each address exceeds it.
    assert_eq!(action_for(&a), WafAction::Block);
    assert_eq!(action_for(&b), WafAction::Block);
}
/// A `FORWARDED_IP` rate rule with no `FallbackBehavior` configured does not
/// match a request that lacks the forwarded-IP header, so the default Allow
/// applies.
#[test]
fn rate_based_forwarded_ip_missing_header_no_match_default() {
    let rate_rule = json!({
        "Name": "rate",
        "Priority": 0,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"RateBasedStatement": {
            "Limit": 1,
            "AggregateKeyType": "FORWARDED_IP",
            "EvaluationWindowSec": 60,
            "ForwardedIPConfig": {"HeaderName": "X-Forwarded-For"},
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![rate_rule]);
    let limiter = RateLimiter::new();

    let request = req("/");
    let verdict = evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, 0);

    assert_eq!(
        verdict.action,
        WafAction::Allow,
        "missing XFF defaults to NO_MATCH"
    );
}
/// A `FORWARDED_IP` rate rule with `FallbackBehavior: MATCH` blocks a request
/// that lacks the forwarded-IP header on the very first evaluation.
#[test]
fn rate_based_forwarded_ip_missing_header_with_match_fallback_blocks() {
    let rate_rule = json!({
        "Name": "rate",
        "Priority": 0,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"RateBasedStatement": {
            "Limit": 1,
            "AggregateKeyType": "FORWARDED_IP",
            "EvaluationWindowSec": 60,
            "ForwardedIPConfig": {
                "HeaderName": "X-Forwarded-For",
                "FallbackBehavior": "MATCH"
            },
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![rate_rule]);
    let limiter = RateLimiter::new();

    let request = req("/");
    let verdict = evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, 0);

    assert_eq!(
        verdict.action,
        WafAction::Block,
        "missing XFF with FallbackBehavior=MATCH blocks regardless of count"
    );
}
/// When the forwarded-IP header holds a value that is not an IP address, the
/// verdict follows `FallbackBehavior` rather than the rate counter:
/// `NO_MATCH` allows, `MATCH` blocks on the first request despite a limit of
/// 1000.
#[test]
fn rate_based_forwarded_ip_malformed_value_uses_fallback_not_counter() {
    // Builds the same rate-based ACL, varying only the fallback behaviour.
    let acl_with_fallback = |fallback: &str| {
        let rate_rule = json!({
            "Name": "rate",
            "Priority": 0,
            "Action": {"Block": {}},
            "VisibilityConfig": {},
            "Statement": {"RateBasedStatement": {
                "Limit": 1000,
                "AggregateKeyType": "FORWARDED_IP",
                "EvaluationWindowSec": 60,
                "ForwardedIPConfig": {
                    "HeaderName": "X-Forwarded-For",
                    "FallbackBehavior": fallback,
                },
            }},
        });
        make_acl(json!({"Allow": {}}), vec![rate_rule])
    };

    let limiter = RateLimiter::new();
    let headers = vec![(String::from("X-Forwarded-For"), String::from("not-an-ip"))];
    // `from_parts` leaves `country` unset and derives `body_size_bytes` (0)
    // from the empty body.
    let request = WafRequest::from_parts(
        "GET",
        "/",
        &headers,
        b"",
        "",
        IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
    );

    let no_match_acl = acl_with_fallback("NO_MATCH");
    let no_match_action = evaluate_web_acl(
        &no_match_acl,
        &request,
        &HashMap::new(),
        &HashMap::new(),
        &limiter,
        0,
    )
    .action;
    assert_eq!(
        no_match_action,
        WafAction::Allow,
        "malformed XFF + NO_MATCH falls through"
    );

    let match_acl = acl_with_fallback("MATCH");
    let match_action = evaluate_web_acl(
        &match_acl,
        &request,
        &HashMap::new(),
        &HashMap::new(),
        &limiter,
        0,
    )
    .action;
    assert_eq!(
        match_action,
        WafAction::Block,
        "malformed XFF + MATCH blocks even though limit is 1000"
    );
}
/// With a parseable `X-Forwarded-For` value and `Limit: 1`, the first request
/// is allowed and the second identical request is blocked.
#[test]
fn rate_based_forwarded_ip_aggregate() {
    let rate_rule = json!({
        "Name": "rate",
        "Priority": 0,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"RateBasedStatement": {
            "Limit": 1,
            "AggregateKeyType": "FORWARDED_IP",
            "EvaluationWindowSec": 60,
            "ForwardedIPConfig": {"HeaderName": "X-Forwarded-For", "FallbackBehavior": "MATCH"},
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![rate_rule]);
    let limiter = RateLimiter::new();
    let now = 1_700_000_000;

    let headers = vec![(String::from("X-Forwarded-For"), String::from("203.0.113.5"))];
    // `from_parts` leaves `country` unset and derives `body_size_bytes` (0)
    // from the empty body.
    let request = WafRequest::from_parts(
        "GET",
        "/",
        &headers,
        b"",
        "",
        IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
    );

    let hit = || {
        evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, now).action
    };

    assert_eq!(hit(), WafAction::Allow);
    assert_eq!(hit(), WafAction::Block);
}
/// Referencing the `AWSManagedRulesCommonRuleSet` managed group blocks
/// `/admin/users` and `/wp-admin` while `/index.html` is allowed.
#[test]
fn managed_rule_group_common_set_blocks_admin() {
    // NOTE(review): this rule carries both `OverrideAction` and `Action`;
    // kept exactly as written — confirm which one the evaluator consults for
    // managed rule groups.
    let managed_rule = json!({
        "Name": "crs",
        "Priority": 0,
        "OverrideAction": {"None": {}},
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"ManagedRuleGroupStatement": {
            "VendorName": "AWS",
            "Name": "AWSManagedRulesCommonRuleSet",
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![managed_rule]);

    let action_for =
        |path: &'static str| evaluate(&req(path), &acl, &HashMap::new(), &HashMap::new());

    assert_eq!(action_for("/admin/users"), WafAction::Block);
    assert_eq!(action_for("/wp-admin"), WafAction::Block);
    assert_eq!(action_for("/index.html"), WafAction::Allow);
}
/// A matching rule whose action is `Captcha` yields a `Captcha` verdict that
/// is flagged as blocked.
#[test]
fn captcha_action_marks_blocked() {
    let captcha_rule = byte_match_uri_contains("/admin", json!({"Captcha": {}}));
    let acl = make_acl(json!({"Allow": {}}), vec![captcha_rule]);
    let limiter = RateLimiter::new();

    let request = req("/admin/x");
    let verdict = evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, 0);

    assert_eq!(verdict.action, WafAction::Captcha);
    assert!(verdict.blocked, "Captcha is non-2xx — counts as blocked");
}
/// A matching rule whose action is `Challenge` yields a `Challenge` verdict
/// that is flagged as blocked.
#[test]
fn challenge_action_marks_blocked() {
    let challenge_rule = byte_match_uri_contains("/admin", json!({"Challenge": {}}));
    let acl = make_acl(json!({"Allow": {}}), vec![challenge_rule]);
    let limiter = RateLimiter::new();

    let request = req("/admin/x");
    let verdict = evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, 0);

    assert_eq!(verdict.action, WafAction::Challenge);
    assert!(verdict.blocked);
}
/// The verdict for a request that is first counted-and-labelled and then
/// blocked by label reports the blocking rule as terminating, the emitted
/// label, and the name of the Count rule that matched on the way.
#[test]
fn verdict_carries_terminating_rule_id_and_labels() {
    // Priority 0: counts URIs containing `/admin` and tags them with a label.
    let tagging_rule = json!({
        "Name": "tag",
        "Priority": 0,
        "Action": {"Count": {}},
        "VisibilityConfig": {},
        "RuleLabels": [{"Name": "awswaf:custom:admin"}],
        "Statement": {"ByteMatchStatement": {
            "SearchString": "/admin",
            "FieldToMatch": {"UriPath": {}},
            "TextTransformations": [{"Priority": 0, "Type": "NONE"}],
            "PositionalConstraint": "CONTAINS",
        }},
    });
    // Priority 1: blocks any request carrying that label.
    let label_gate = json!({
        "Name": "block-by-label",
        "Priority": 1,
        "Action": {"Block": {}},
        "VisibilityConfig": {},
        "Statement": {"LabelMatchStatement": {
            "Scope": "LABEL",
            "Key": "awswaf:custom:admin",
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![tagging_rule, label_gate]);
    let limiter = RateLimiter::new();

    let request = req("/admin/x");
    let verdict = evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, 0);

    assert_eq!(verdict.action, WafAction::Block);
    assert_eq!(verdict.terminating_rule_id.as_deref(), Some("block-by-label"));
    assert_eq!(verdict.labels, vec!["awswaf:custom:admin"]);
    assert_eq!(verdict.count_rules, vec!["tag"]);
}
/// A Block action configured with a `CustomResponse` surfaces its response
/// code and body key on the verdict.
#[test]
fn block_with_custom_response_propagates_to_verdict() {
    let block_rule = json!({
        "Name": "block",
        "Priority": 0,
        "Action": {"Block": {"CustomResponse": {
            "ResponseCode": 429,
            "CustomResponseBodyKey": "rate-limited",
        }}},
        "VisibilityConfig": {},
        "Statement": {"ByteMatchStatement": {
            "SearchString": "/admin",
            "FieldToMatch": {"UriPath": {}},
            "TextTransformations": [{"Priority": 0, "Type": "NONE"}],
            "PositionalConstraint": "CONTAINS",
        }},
    });
    let acl = make_acl(json!({"Allow": {}}), vec![block_rule]);
    let limiter = RateLimiter::new();

    let request = req("/admin");
    let verdict = evaluate_web_acl(&acl, &request, &HashMap::new(), &HashMap::new(), &limiter, 0);

    assert_eq!(verdict.action, WafAction::Block);
    assert_eq!(verdict.custom_response_status, Some(429));
    assert_eq!(verdict.custom_response_body_key.as_deref(), Some("rate-limited"));
}
}