use std::collections::HashMap;
use percent_encoding::percent_decode_str;
use serde::{Deserialize, Serialize};
/// A raw incoming HTTP request, borrowed from the caller's buffers.
///
/// This is the input to [`EvalContext::from_request`].
#[derive(Debug, Clone, Default)]
pub struct Request<'a> {
    /// HTTP method (e.g. `"POST"`).
    pub method: &'a str,
    /// Request path. When `query` is `None`, anything after the first `?` here is
    /// parsed as the query string.
    pub path: &'a str,
    /// Query string supplied separately from `path`; takes precedence over a `?` in `path`.
    pub query: Option<&'a str>,
    /// Request headers in received order (names are not normalized here).
    pub headers: Vec<Header<'a>>,
    /// Raw request body, if any.
    pub body: Option<&'a [u8]>,
    /// Client IP address as text.
    pub client_ip: &'a str,
    /// Copied into `EvalContext::is_static`; presumably marks static-asset requests — confirm with caller.
    pub is_static: bool,
}
/// One HTTP header, borrowed from the raw request.
#[derive(Debug, Clone)]
pub struct Header<'a> {
    /// Header name exactly as received (no case normalization).
    pub name: &'a str,
    /// Header value exactly as received.
    pub value: &'a str,
}

impl<'h> Header<'h> {
    /// Wraps a borrowed name/value pair without copying either string.
    pub fn new(name: &'h str, value: &'h str) -> Self {
        Header { name, value }
    }
}
/// Outcome of evaluating a single request.
///
/// `Verdict::default()` is an [`Action::Allow`] verdict with zero risk, no matched rules,
/// empty diagnostic lists, every optional field `None` and both flags `false`. The default is
/// derived, so it automatically stays in sync when fields are added.
#[derive(Debug, Clone, Default)]
pub struct Verdict {
    /// Final decision for the request.
    pub action: Action,
    /// Aggregate risk score for the request.
    pub risk_score: u16,
    /// Identifiers of the rules that matched.
    pub matched_rules: Vec<u32>,
    /// Risk attributed to the requesting entity (scale is defined by the producer).
    pub entity_risk: f64,
    /// Whether the requesting entity is blocked.
    pub entity_blocked: bool,
    /// Human-readable reason; presumably set when the request is blocked — confirm.
    pub block_reason: Option<String>,
    /// Per-rule breakdown of the risk that was added.
    pub risk_contributions: Vec<RiskContribution>,
    /// Endpoint template the request was associated with, if any.
    pub endpoint_template: Option<String>,
    /// Risk associated with that endpoint, if computed.
    pub endpoint_risk: Option<f32>,
    /// Anomaly score, if anomaly detection ran.
    pub anomaly_score: Option<f64>,
    /// Threshold actually applied to the anomaly score, if adjusted.
    pub adjusted_threshold: Option<f64>,
    /// Individual anomaly observations behind `anomaly_score`.
    pub anomaly_signals: Vec<AnomalySignal>,
    /// Set when evaluation stopped early (presumably on an `EvalContext` deadline — confirm).
    pub timed_out: bool,
    /// Number of rules evaluated, if tracked.
    pub rules_evaluated: Option<u32>,
}

/// Decision taken for a request. Defaults to [`Action::Allow`].
///
/// The explicit `u8` discriminants (`#[repr(u8)]`) make `action as u8` stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum Action {
    #[default]
    Allow = 0,
    Block = 1,
}
/// One rule's contribution to a request's risk.
#[derive(Debug, Clone)]
pub struct RiskContribution {
    /// Rule that produced this contribution.
    pub rule_id: u32,
    /// Risk before scaling.
    pub base_risk: f64,
    /// Scaling factor applied to `base_risk`.
    pub multiplier: f64,
    /// `base_risk * multiplier`, computed once at construction.
    pub final_risk: f64,
}

impl RiskContribution {
    /// Records a contribution, precomputing `final_risk = base_risk * multiplier`.
    #[inline]
    pub fn new(rule_id: u32, base_risk: f64, multiplier: f64) -> Self {
        let final_risk = base_risk * multiplier;
        RiskContribution {
            rule_id,
            base_risk,
            multiplier,
            final_risk,
        }
    }
}
/// A single anomaly observation recorded during evaluation.
#[derive(Debug, Clone)]
pub struct AnomalySignal {
    /// Kind of anomaly observed.
    pub signal_type: AnomalySignalType,
    /// Severity reported by the detector (scale not defined here — TODO confirm).
    pub severity: f32,
    /// Free-form description of the observation.
    pub detail: String,
}

impl AnomalySignal {
    /// Maps this signal onto the [`AnomalyType`] taxonomy used for risk scoring.
    ///
    /// Signal kinds without a dedicated counterpart fall back to [`AnomalyType::Custom`].
    pub fn to_anomaly_type(&self) -> AnomalyType {
        match self.signal_type {
            AnomalySignalType::PayloadSize => AnomalyType::OversizedRequest,
            AnomalySignalType::RequestRate => AnomalyType::VelocitySpike,
            // NOTE(review): error-rate signals are scored as timing anomalies; confirm intended.
            AnomalySignalType::ErrorRate | AnomalySignalType::TimingAnomaly => {
                AnomalyType::TimingAnomaly
            }
            AnomalySignalType::ParameterAnomaly
            | AnomalySignalType::ContentTypeAnomaly
            | AnomalySignalType::SchemaViolation => AnomalyType::Custom,
        }
    }
}
/// Kind of anomaly reported in an [`AnomalySignal`].
///
/// `AnomalySignal::to_anomaly_type` maps each kind onto an [`AnomalyType`] (noted per variant).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnomalySignalType {
    /// Maps to `AnomalyType::OversizedRequest`.
    PayloadSize,
    /// Maps to `AnomalyType::VelocitySpike`.
    RequestRate,
    /// Maps to `AnomalyType::TimingAnomaly`.
    ErrorRate,
    /// Maps to `AnomalyType::Custom`.
    ParameterAnomaly,
    /// Maps to `AnomalyType::Custom`.
    ContentTypeAnomaly,
    /// Maps to `AnomalyType::TimingAnomaly`.
    TimingAnomaly,
    /// Maps to `AnomalyType::Custom`.
    SchemaViolation,
}
/// How anomaly-based blocking decisions are applied; defaults to `Learning`.
///
/// NOTE(review): presumably `Learning` only observes and `Enforcement` actually blocks —
/// the consumer is not in this file; confirm against the scorer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum BlockingMode {
    #[default]
    Learning,
    Enforcement,
}
/// Tunables for risk scoring and anomaly-driven blocking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskConfig {
    /// Upper bound of the risk scale (100 by default, 1000 via `with_extended_range`).
    pub max_risk: f64,
    /// Enables repeat-match multipliers (presumably `repeat_multiplier`; the consumer is elsewhere).
    pub enable_repeat_multipliers: bool,
    /// Per-type risk values that take precedence over `AnomalyType::default_risk`.
    pub anomaly_risk_overrides: HashMap<AnomalyType, f64>,
    /// Anomaly threshold used for blocking decisions (10 by default).
    pub anomaly_blocking_threshold: f64,
    /// Whether anomaly blocking is enforced or only learned.
    pub blocking_mode: BlockingMode,
}

impl Default for RiskConfig {
    /// 0–100 risk scale, repeat multipliers on, no overrides, threshold 10, learning mode.
    fn default() -> Self {
        RiskConfig {
            blocking_mode: BlockingMode::Learning,
            anomaly_blocking_threshold: 10.0,
            anomaly_risk_overrides: HashMap::new(),
            enable_repeat_multipliers: true,
            max_risk: 100.0,
        }
    }
}

impl RiskConfig {
    /// Default configuration with the risk ceiling raised to 1000.
    pub fn with_extended_range() -> Self {
        RiskConfig {
            max_risk: 1000.0,
            ..RiskConfig::default()
        }
    }

    /// Risk for `anomaly_type`: the configured override if present, else its built-in default.
    #[inline]
    pub fn anomaly_risk(&self, anomaly_type: AnomalyType) -> f64 {
        match self.anomaly_risk_overrides.get(&anomaly_type) {
            Some(&risk) => risk,
            None => anomaly_type.default_risk(),
        }
    }

    /// Sets an override for `anomaly_type`, replacing any earlier one.
    pub fn set_anomaly_risk(&mut self, anomaly_type: AnomalyType, risk: f64) {
        let _previous = self.anomaly_risk_overrides.insert(anomaly_type, risk);
    }

    /// Removes the override for `anomaly_type` so its built-in default applies again.
    pub fn reset_anomaly_risk(&mut self, anomaly_type: AnomalyType) {
        let _previous = self.anomaly_risk_overrides.remove(&anomaly_type);
    }
}
/// Risk added to an entity because of one detected anomaly.
#[derive(Debug, Clone)]
pub struct AnomalyContribution {
    /// Which anomaly was detected.
    pub anomaly_type: AnomalyType,
    /// Risk added for it.
    pub risk: f64,
    /// Optional human-readable explanation.
    pub reason: Option<String>,
    /// Caller-supplied timestamp of when the contribution was applied (units set by caller).
    pub applied_at: u64,
}

impl AnomalyContribution {
    /// Builds a contribution stamped with the caller-provided `now`.
    pub fn new(anomaly_type: AnomalyType, risk: f64, reason: Option<String>, now: u64) -> Self {
        AnomalyContribution {
            applied_at: now,
            anomaly_type,
            reason,
            risk,
        }
    }
}
/// Escalation factor for a rule that has matched `match_count` times.
///
/// | matches | factor |
/// |---------|--------|
/// | 0–1     | 1.0    |
/// | 2–5     | 1.25   |
/// | 6–10    | 1.5    |
/// | 11+     | 2.0    |
#[inline]
pub fn repeat_multiplier(match_count: u32) -> f64 {
    if match_count <= 1 {
        1.0
    } else if match_count <= 5 {
        1.25
    } else if match_count <= 10 {
        1.5
    } else {
        2.0
    }
}
/// Category of behavioural anomaly that can add risk to an entity.
///
/// Discriminants are explicit under `#[repr(u8)]`: built-in kinds occupy 0–11 and
/// `Custom` sits apart at 255.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum AnomalyType {
    FingerprintChange = 0,
    SessionSharing = 1,
    TokenReuse = 2,
    VelocitySpike = 3,
    RotationPattern = 4,
    TimingAnomaly = 5,
    ImpossibleTravel = 6,
    OversizedRequest = 7,
    OversizedResponse = 8,
    BandwidthSpike = 9,
    ExfiltrationPattern = 10,
    UploadPattern = 11,
    Custom = 255,
}

impl AnomalyType {
    /// Built-in risk for this anomaly, used when `RiskConfig` has no override.
    ///
    /// Ranges from 50 (`SessionSharing`) down to 0 (`Custom`).
    #[inline]
    pub const fn default_risk(self) -> f64 {
        match self {
            Self::SessionSharing => 50.0,
            Self::ExfiltrationPattern | Self::TokenReuse => 40.0,
            Self::RotationPattern | Self::UploadPattern => 35.0,
            Self::FingerprintChange => 30.0,
            Self::BandwidthSpike | Self::ImpossibleTravel => 25.0,
            Self::OversizedRequest => 20.0,
            Self::OversizedResponse | Self::VelocitySpike => 15.0,
            Self::TimingAnomaly => 10.0,
            Self::Custom => 0.0,
        }
    }

    /// Stable `snake_case` identifier for this anomaly type.
    pub const fn name(self) -> &'static str {
        match self {
            Self::FingerprintChange => "fingerprint_change",
            Self::SessionSharing => "session_sharing",
            Self::TokenReuse => "token_reuse",
            Self::VelocitySpike => "velocity_spike",
            Self::RotationPattern => "rotation_pattern",
            Self::TimingAnomaly => "timing_anomaly",
            Self::ImpossibleTravel => "impossible_travel",
            Self::OversizedRequest => "oversized_request",
            Self::OversizedResponse => "oversized_response",
            Self::BandwidthSpike => "bandwidth_spike",
            Self::ExfiltrationPattern => "exfiltration_pattern",
            Self::UploadPattern => "upload_pattern",
            Self::Custom => "custom",
        }
    }
}
/// Normalized view of a request that rules are evaluated against.
///
/// Built by `EvalContext::from_request`; borrows from the originating [`Request`].
#[derive(Debug)]
pub struct EvalContext<'a> {
    /// Client IP, copied from `Request::client_ip`.
    pub ip: &'a str,
    /// HTTP method, copied from `Request::method`.
    pub method: &'a str,
    /// `Request::path` as given (a separately supplied `query` is not appended).
    pub url: &'a str,
    /// Header values keyed by lowercased name; when a name repeats, the last value wins.
    pub headers: HashMap<String, &'a str>,
    /// Flat list of argument strings: raw `key=value` query/form segments (still
    /// percent-encoded) followed by JSON and multipart values.
    pub args: Vec<String>,
    /// Decoded key/value pairs from the query string and from form, JSON and multipart bodies.
    pub arg_entries: Vec<ArgEntry>,
    /// Body as text, present only when the body is valid UTF-8.
    pub body_text: Option<&'a str>,
    /// Body bytes exactly as received.
    pub raw_body: Option<&'a [u8]>,
    /// Copied from `Request::is_static`.
    pub is_static: bool,
    /// Owned copy of the body text when it is UTF-8 and the Content-Type identifies JSON.
    pub json_text: Option<String>,
    /// Optional evaluation deadline checked by `is_deadline_exceeded`.
    pub deadline: Option<std::time::Instant>,
}
/// One decoded request argument.
#[derive(Debug, Clone)]
pub struct ArgEntry {
    /// Argument name: query/form key, JSON member name, or multipart field name.
    pub key: String,
    /// Argument value; empty for a query/form segment without `=`.
    pub value: String,
}
impl<'a> EvalContext<'a> {
    /// Builds an evaluation context from `req`, collecting every argument rules may inspect.
    ///
    /// * Header names are lowercased; when a name repeats, the last value wins.
    /// * Arguments come from the query string first, then from the body according to its
    ///   `Content-Type` (matched case-insensitively):
    ///   - URL-encoded forms (UTF-8 bodies only);
    ///   - JSON, which is flattened and whose raw text is also kept in `json_text`;
    ///   - `multipart/form-data`.
    /// * `deadline` is left unset; use [`EvalContext::from_request_with_deadline`] to set one.
    pub fn from_request(req: &'a Request<'a>) -> Self {
        let mut headers = HashMap::new();
        for h in &req.headers {
            headers.insert(h.name.to_ascii_lowercase(), h.value);
        }
        let (mut args, mut arg_entries) = parse_query_args(req.path, req.query);
        let body_text = req.body.and_then(|b| std::str::from_utf8(b).ok());

        // Media types are case-insensitive (RFC 9110 §8.3.1). Matching the raw header
        // would let `Application/JSON` or `MULTIPART/FORM-DATA` skip body inspection
        // entirely, so classify using a lowercased copy. The original value is kept for
        // boundary extraction because the multipart boundary itself is case-sensitive.
        let raw_content_type: Option<&'a str> = headers.get("content-type").copied();
        let content_type = raw_content_type
            .map(str::to_ascii_lowercase)
            .unwrap_or_default();
        let is_form = content_type.contains("application/x-www-form-urlencoded");
        let is_json = content_type.contains("application/json");
        let is_multipart = content_type.contains("multipart/form-data");

        if is_form {
            if let Some(text) = body_text {
                let (body_args, body_entries) = parse_query_args("", Some(text));
                args.extend(body_args);
                arg_entries.extend(body_entries);
            }
        }

        let json_text = match body_text {
            Some(text) if is_json => {
                // Invalid JSON contributes no arguments, but its raw text is still exposed.
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
                    flatten_json(&value, &mut args, &mut arg_entries);
                }
                Some(text.to_string())
            }
            _ => None,
        };

        if is_multipart {
            if let (Some(raw_body), Some(ct)) = (req.body, raw_content_type) {
                if let Some(boundary) = extract_multipart_boundary(ct) {
                    let (mp_args, mp_entries) = parse_multipart(raw_body, &boundary);
                    args.extend(mp_args);
                    arg_entries.extend(mp_entries);
                }
            }
        }

        Self {
            ip: req.client_ip,
            method: req.method,
            url: req.path,
            headers,
            args,
            arg_entries,
            body_text,
            raw_body: req.body,
            is_static: req.is_static,
            json_text,
            deadline: None,
        }
    }

    /// Same as [`EvalContext::from_request`], but with an evaluation deadline attached.
    pub fn from_request_with_deadline(req: &'a Request<'a>, deadline: std::time::Instant) -> Self {
        let mut ctx = Self::from_request(req);
        ctx.deadline = Some(deadline);
        ctx
    }

    /// Returns `true` once a deadline is set and the current instant has reached it.
    ///
    /// Always `false` when no deadline was set.
    #[inline]
    pub fn is_deadline_exceeded(&self) -> bool {
        matches!(self.deadline, Some(deadline) if std::time::Instant::now() >= deadline)
    }
}
/// Returns the `boundary` parameter of a `multipart/*` Content-Type header value.
///
/// Parameters are `;`-separated and the parameter name is matched case-insensitively.
/// Surrounding whitespace and double quotes are stripped from the value, which keeps its
/// case. Only the first `boundary` parameter is considered. The result is `None` if there
/// is no such parameter or if its value is empty.
fn extract_multipart_boundary(content_type: &str) -> Option<String> {
    for param in content_type.split(';') {
        if let Some((name, raw_value)) = param.trim().split_once('=') {
            if name.trim().eq_ignore_ascii_case("boundary") {
                let boundary = raw_value.trim().trim_matches('"');
                return if boundary.is_empty() {
                    None
                } else {
                    Some(boundary.to_string())
                };
            }
        }
    }
    None
}
/// Extracts named fields from a `multipart/form-data` body.
///
/// `boundary` is the bare boundary token (without the leading `--`). The body is decoded
/// lossily as UTF-8 and split on `--{boundary}`. Each part must separate its headers from
/// its content with a blank line (`\r\n\r\n`). Parts are ignored if they have no
/// `Content-Disposition` header carrying a `name` parameter; the preamble and the closing
/// `--` delimiter are ignored too. For each named part, the content is pushed to `args` and
/// also recorded as an [`ArgEntry`] keyed by the field name.
fn parse_multipart(raw_body: &[u8], boundary: &str) -> (Vec<String>, Vec<ArgEntry>) {
    let mut args = Vec::new();
    let mut entries = Vec::new();
    let body_str = String::from_utf8_lossy(raw_body);
    let marker = format!("--{}", boundary);
    for part in body_str.split(marker.as_str()) {
        // Each part lies between the CRLF that ends the previous delimiter line and the
        // CRLF that belongs to the next delimiter (RFC 2046 §5.1.1). Strip exactly one of
        // each, so no stray `\r` is left on the value and the content keeps its own line
        // breaks.
        let part = part.strip_prefix("\r\n").unwrap_or(part);
        let part = part.strip_suffix("\r\n").unwrap_or(part);
        if part.is_empty() || part == "--" {
            continue;
        }
        if let Some((headers, content)) = part.split_once("\r\n\r\n") {
            if let Some(key) = multipart_field_name(headers) {
                let value = content.to_string();
                args.push(value.clone());
                entries.push(ArgEntry { key, value });
            }
        }
    }
    (args, entries)
}

/// Returns the `name` parameter of the `Content-Disposition` line in a part's header block.
///
/// The header name and the parameter name are both matched case-insensitively. Whitespace
/// around `=` is tolerated, and surrounding double quotes are removed from the value.
fn multipart_field_name(headers: &str) -> Option<String> {
    const DISPOSITION: &[u8] = b"content-disposition";
    let disposition = headers.lines().find(|line| {
        line.as_bytes()
            .get(..DISPOSITION.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(DISPOSITION))
    })?;
    disposition.split(';').find_map(|param| {
        let (name, value) = param.split_once('=')?;
        if name.trim().eq_ignore_ascii_case("name") {
            Some(value.trim().trim_matches('"').to_string())
        } else {
            None
        }
    })
}
/// Deepest container nesting level (the document root is level 0) that `flatten_json` descends into.
const MAX_JSON_DEPTH: usize = 32;
/// Maximum number of object members / array elements `flatten_json` visits per document.
const MAX_JSON_ELEMENTS: usize = 1000;

/// Flattens the scalar leaves of a JSON document into `args` and `entries`.
///
/// Strings, numbers and booleans are recorded as text; `null` is skipped.
/// * An object member's entry key is the member name.
/// * A scalar inside an array inherits the key of the nearest enclosing object member, or
///   `""` when there is none (a top-level array, or a bare top-level scalar).
///
/// Containers nested deeper than `MAX_JSON_DEPTH` are not entered. At most
/// `MAX_JSON_ELEMENTS` members/elements are visited, so hostile documents stay bounded.
fn flatten_json(value: &serde_json::Value, args: &mut Vec<String>, entries: &mut Vec<ArgEntry>) {
    let mut element_count = 0usize;
    flatten_json_recursive(value, args, entries, 0, &mut element_count);
}

/// Recursive entry point for a value that has no enclosing key.
///
/// `depth` is the nesting level of `value`, and `element_count` is the traversal budget
/// shared across the whole document.
fn flatten_json_recursive(
    value: &serde_json::Value,
    args: &mut Vec<String>,
    entries: &mut Vec<ArgEntry>,
    depth: usize,
    element_count: &mut usize,
) {
    flatten_json_node(value, "", args, entries, depth, element_count);
}

/// Visits `value` under entry key `key`: records scalars and descends into containers.
fn flatten_json_node(
    value: &serde_json::Value,
    key: &str,
    args: &mut Vec<String>,
    entries: &mut Vec<ArgEntry>,
    depth: usize,
    element_count: &mut usize,
) {
    match value {
        serde_json::Value::Object(map) => {
            if depth > MAX_JSON_DEPTH {
                return;
            }
            for (member, child) in map {
                if !take_json_budget(element_count) {
                    return;
                }
                flatten_json_node(child, member, args, entries, depth + 1, element_count);
            }
        }
        serde_json::Value::Array(items) => {
            if depth > MAX_JSON_DEPTH {
                return;
            }
            for child in items {
                if !take_json_budget(element_count) {
                    return;
                }
                // Array elements have no name of their own; they inherit the enclosing key,
                // so e.g. `{"q": ["<script>"]}` is inspected as `q = <script>`.
                flatten_json_node(child, key, args, entries, depth + 1, element_count);
            }
        }
        serde_json::Value::String(s) => record_json_scalar(key, s.clone(), args, entries),
        serde_json::Value::Number(n) => record_json_scalar(key, n.to_string(), args, entries),
        serde_json::Value::Bool(b) => record_json_scalar(key, b.to_string(), args, entries),
        serde_json::Value::Null => {}
    }
}

/// Consumes one unit of the element budget; returns `false` once `MAX_JSON_ELEMENTS` is spent.
fn take_json_budget(element_count: &mut usize) -> bool {
    if *element_count >= MAX_JSON_ELEMENTS {
        return false;
    }
    *element_count += 1;
    true
}

/// Appends one scalar's text to `args` and as a `key`/`text` entry to `entries`.
fn record_json_scalar(key: &str, text: String, args: &mut Vec<String>, entries: &mut Vec<ArgEntry>) {
    args.push(text.clone());
    entries.push(ArgEntry {
        key: key.to_string(),
        value: text,
    });
}
/// Splits a URL-encoded query string into raw `args` and decoded [`ArgEntry`] pairs.
///
/// The query comes from `query` when given; otherwise it is whatever follows the first `?`
/// in `path`. With neither, both vectors are empty. Empty `&`-segments are skipped.
/// * `args` receives each segment verbatim, still percent-encoded.
/// * Entry keys and values are form-decoded: `+` becomes a space *before* percent-decoding
///   (so `%2B` survives as a literal `+`), and invalid UTF-8 is replaced lossily.
/// * A segment without `=` yields an entry with an empty value.
fn parse_query_args(path: &str, query: Option<&str>) -> (Vec<String>, Vec<ArgEntry>) {
    /// Form-decodes one key or value component.
    fn decode(component: &str) -> String {
        let spaced = component.replace('+', " ");
        percent_decode_str(&spaced).decode_utf8_lossy().into_owned()
    }

    let query_str = match query {
        Some(q) => q,
        None => match path.split_once('?') {
            Some((_, q)) => q,
            None => return (Vec::new(), Vec::new()),
        },
    };

    query_str
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (raw_key, raw_value) = segment.split_once('=').unwrap_or((segment, ""));
            let entry = ArgEntry {
                key: decode(raw_key),
                value: decode(raw_value),
            };
            (segment.to_string(), entry)
        })
        .unzip()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query pairs are split on `&` and decoded into ordered key/value entries.
    #[test]
    fn test_parse_query_args() {
        let (args, entries) = parse_query_args("/api/users?id=1&name=test", None);
        assert_eq!(args.len(), 2);
        let pairs: Vec<(&str, &str)> = entries
            .iter()
            .map(|entry| (entry.key.as_str(), entry.value.as_str()))
            .collect();
        assert_eq!(pairs, [("id", "1"), ("name", "test")]);
    }

    /// A JSON POST contributes its members alongside the query-string arguments.
    #[test]
    fn test_eval_context_from_request() {
        let body: &[u8] = b"{\"password\": \"test\"}";
        let req = Request {
            method: "POST",
            path: "/api/login?username=admin",
            headers: vec![Header::new("Content-Type", "application/json")],
            body: Some(body),
            client_ip: "192.168.1.1",
            ..Default::default()
        };
        let ctx = EvalContext::from_request(&req);
        assert_eq!((ctx.method, ctx.ip), ("POST", "192.168.1.1"));
        assert_eq!(ctx.arg_entries.len(), 2);
        assert!(ctx.json_text.is_some());
    }

    /// Spot-checks the built-in anomaly risks at the top, middle and bottom of the scale.
    #[test]
    fn test_anomaly_type_default_risk() {
        let cases = [
            (AnomalyType::SessionSharing, 50.0),
            (AnomalyType::ImpossibleTravel, 25.0),
            (AnomalyType::Custom, 0.0),
        ];
        for &(kind, expected) in cases.iter() {
            assert_eq!(kind.default_risk(), expected);
        }
    }
}