use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use crate::profiler::entropy::shannon_entropy;
use crate::profiler::header_types::{
HeaderAnomaly, HeaderAnomalyResult, HeaderBaseline, ValueStats,
};
/// Default cap on the number of endpoint baselines kept in memory.
const DEFAULT_MAX_ENDPOINTS: usize = 10_000;
/// Default number of learned samples before an endpoint baseline is mature
/// enough for `analyze` to report anomalies.
const DEFAULT_MIN_SAMPLES: u64 = 50;
/// A header present in at least this fraction of learned samples is
/// classified as "required" during recategorization.
const REQUIRED_HEADER_THRESHOLD: f64 = 0.95;
/// Absolute entropy z-score above which a header value is flagged.
const ENTROPY_Z_THRESHOLD: f64 = 3.0;
/// Upper bound on distinct headers whose stats are tracked per endpoint
/// (bounds memory growth from attacker-supplied header names).
const MAX_HEADERS_PER_ENDPOINT: usize = 100;
/// Tolerance factor handed to `ValueStats::is_length_in_range`; presumably
/// widens the learned min/max length window — confirm in header_types.
const LENGTH_TOLERANCE_FACTOR: f64 = 1.5;
/// Learns per-endpoint HTTP header baselines and flags deviations from them.
///
/// Cheap to share: the `Clone` impl clones only the `Arc`, so all handles
/// observe and mutate the same baseline map.
#[derive(Debug)]
pub struct HeaderProfiler {
    // Per-endpoint learned baselines, keyed by endpoint string.
    baselines: Arc<DashMap<String, HeaderBaseline>>,
    // Maximum number of endpoints tracked before eviction kicks in.
    max_endpoints: usize,
    // Minimum samples before a baseline is considered mature for analysis.
    min_samples: u64,
}
impl HeaderProfiler {
    /// Creates a profiler with the default endpoint cap and maturity threshold.
    pub fn new() -> Self {
        Self {
            baselines: Arc::new(DashMap::with_capacity(1000)),
            max_endpoints: DEFAULT_MAX_ENDPOINTS,
            min_samples: DEFAULT_MIN_SAMPLES,
        }
    }

    /// Creates a profiler with explicit limits.
    ///
    /// `max_endpoints` bounds how many endpoints are tracked; `min_samples`
    /// is the sample count a baseline needs before `analyze` reports
    /// anomalies. The initial map capacity is clamped to 10_000 so a very
    /// large `max_endpoints` does not preallocate excessively.
    pub fn with_config(max_endpoints: usize, min_samples: u64) -> Self {
        Self {
            baselines: Arc::new(DashMap::with_capacity(max_endpoints.min(10000))),
            max_endpoints,
            min_samples,
        }
    }

    /// Folds one observed request's headers into the endpoint's baseline.
    ///
    /// Header names are lowercased so all matching is case-insensitive.
    /// If the endpoint is new and the map is at capacity, the
    /// least-recently-updated baseline is evicted first. At most
    /// `MAX_HEADERS_PER_ENDPOINT` distinct headers accumulate stats per
    /// endpoint; further new names are ignored. Required/optional header
    /// categories are recomputed on every 10th sample once the baseline has
    /// reached `min_samples`.
    pub fn learn(&self, endpoint: &str, headers: &[(String, String)]) {
        // Evict BEFORE taking the entry guard below: iterating the map while
        // holding a shard write guard is a DashMap deadlock hazard.
        if self.baselines.len() >= self.max_endpoints && !self.baselines.contains_key(endpoint) {
            self.evict_oldest();
        }
        let mut baseline = self
            .baselines
            .entry(endpoint.to_string())
            .or_insert_with(|| HeaderBaseline::new(endpoint.to_string()));
        let present_headers: HashSet<String> =
            headers.iter().map(|(k, _)| k.to_lowercase()).collect();
        for (header_name, header_value) in headers {
            let header_name = header_name.to_lowercase();
            // Cap distinct tracked headers; new names past the cap are
            // skipped rather than evicting existing stats.
            if baseline.header_value_stats.len() >= MAX_HEADERS_PER_ENDPOINT
                && !baseline.header_value_stats.contains_key(&header_name)
            {
                continue;
            }
            let entropy = shannon_entropy(header_value);
            let length = header_value.len();
            baseline
                .header_value_stats
                .entry(header_name.clone())
                .or_insert_with(ValueStats::new)
                .update(length, entropy);
        }
        baseline.sample_count += 1;
        baseline.last_updated = Instant::now();
        // Recategorizing on every sample would be wasteful; do it every 10th
        // sample once the baseline is mature.
        if baseline.sample_count >= self.min_samples && baseline.sample_count % 10 == 0 {
            self.recalculate_header_categories(&mut baseline, &present_headers);
        }
    }

    /// Compares one request's headers against the endpoint's learned baseline.
    ///
    /// Returns an empty result when the endpoint is unknown or its baseline
    /// has not reached `min_samples`. Otherwise reports three classes of
    /// anomaly: required headers that are missing, headers never seen during
    /// learning, and per-value length/entropy deviations (value checks only
    /// run once that header's own stats have `min_samples / 2` samples).
    pub fn analyze(&self, endpoint: &str, headers: &[(String, String)]) -> HeaderAnomalyResult {
        let baseline = match self.baselines.get(endpoint) {
            Some(b) => b,
            None => return HeaderAnomalyResult::none(),
        };
        if !baseline.is_mature(self.min_samples) {
            return HeaderAnomalyResult::none();
        }
        let mut result = HeaderAnomalyResult::new();
        // Lowercase once so all comparisons are case-insensitive.
        let present_headers: HashSet<String> =
            headers.iter().map(|(k, _)| k.to_lowercase()).collect();
        // 1) Required headers absent from this request.
        for required_header in &baseline.required_headers {
            if !present_headers.contains(required_header) {
                result.add(HeaderAnomaly::MissingRequired {
                    header: required_header.clone(),
                });
            }
        }
        // 2) Headers never observed for this endpoint during learning.
        for (header_name, _) in headers {
            let header_name = header_name.to_lowercase();
            if !baseline.is_known(&header_name) {
                result.add(HeaderAnomaly::UnexpectedHeader {
                    header: header_name.clone(),
                });
            }
        }
        // 3) Per-value checks against learned length/entropy statistics.
        for (header_name, header_value) in headers {
            let header_name = header_name.to_lowercase();
            if let Some(stats) = baseline.get_stats(&header_name) {
                // Skip headers whose own stats are still immature.
                if stats.is_mature(self.min_samples / 2) {
                    let length = header_value.len();
                    if !stats.is_length_in_range(length, LENGTH_TOLERANCE_FACTOR) {
                        result.add(HeaderAnomaly::LengthAnomaly {
                            header: header_name.clone(),
                            length,
                            expected_range: (stats.min_length, stats.max_length),
                        });
                    }
                    let entropy = shannon_entropy(header_value);
                    let z_score = stats.entropy_z_score(entropy);
                    if z_score.abs() > ENTROPY_Z_THRESHOLD {
                        result.add(HeaderAnomaly::EntropyAnomaly {
                            header: header_name.clone(),
                            entropy,
                            expected_mean: stats.entropy_mean,
                        });
                    }
                }
            }
        }
        result
    }

    /// Returns an owned snapshot of the endpoint's baseline, if tracked.
    pub fn get_baseline(&self, endpoint: &str) -> Option<HeaderBaseline> {
        self.baselines.get(endpoint).map(|b| b.clone())
    }

    /// Number of endpoints currently tracked.
    #[inline]
    pub fn endpoint_count(&self) -> usize {
        self.baselines.len()
    }

    /// Configured endpoint cap.
    #[inline]
    pub fn max_endpoints(&self) -> usize {
        self.max_endpoints
    }

    /// Configured maturity threshold.
    #[inline]
    pub fn min_samples(&self) -> u64 {
        self.min_samples
    }

    /// Drops every learned baseline.
    pub fn clear(&self) {
        self.baselines.clear();
    }

    /// Aggregates counters across all tracked endpoints.
    ///
    /// The snapshot is not atomic: concurrent `learn` calls may land between
    /// iteration steps, so totals are best-effort.
    pub fn stats(&self) -> HeaderProfilerStats {
        let mut total_samples = 0u64;
        let mut total_headers = 0usize;
        let mut mature_endpoints = 0usize;
        for entry in self.baselines.iter() {
            total_samples += entry.sample_count;
            total_headers += entry.header_value_stats.len();
            if entry.is_mature(self.min_samples) {
                mature_endpoints += 1;
            }
        }
        HeaderProfilerStats {
            endpoint_count: self.baselines.len(),
            mature_endpoints,
            total_samples,
            total_headers,
            max_endpoints: self.max_endpoints,
        }
    }

    /// Rebuilds the required/optional header sets from accumulated stats.
    ///
    /// A header is "required" when its observation frequency reaches
    /// `REQUIRED_HEADER_THRESHOLD`; every other tracked header is "optional".
    /// Headers in the current request that have no stats (e.g. dropped by the
    /// per-endpoint cap) are recorded as optional — presumably so `is_known`
    /// treats them as known; confirm against `HeaderBaseline` in header_types.
    fn recalculate_header_categories(
        &self,
        baseline: &mut HeaderBaseline,
        current_headers: &HashSet<String>,
    ) {
        let sample_count = baseline.sample_count;
        let mut new_required = HashSet::with_capacity(baseline.header_value_stats.len());
        let mut new_optional = HashSet::with_capacity(baseline.header_value_stats.len());
        for (header_name, stats) in &baseline.header_value_stats {
            let frequency = stats.total_samples as f64 / sample_count as f64;
            if frequency >= REQUIRED_HEADER_THRESHOLD {
                new_required.insert(header_name.clone());
            } else {
                new_optional.insert(header_name.clone());
            }
        }
        for header in current_headers {
            if !new_required.contains(header) && !new_optional.contains(header) {
                new_optional.insert(header.to_string());
            }
        }
        baseline.required_headers = new_required;
        baseline.optional_headers = new_optional;
    }

    /// Removes the baseline with the oldest `last_updated` timestamp.
    ///
    /// Linear scan over all entries; only called when the endpoint cap is
    /// hit. Between the scan and the remove another thread may touch the
    /// chosen entry, so this is best-effort LRU.
    fn evict_oldest(&self) {
        let mut oldest_key: Option<String> = None;
        let mut oldest_time = Instant::now();
        for entry in self.baselines.iter() {
            if entry.last_updated < oldest_time {
                oldest_time = entry.last_updated;
                oldest_key = Some(entry.key().clone());
            }
        }
        if let Some(key) = oldest_key {
            self.baselines.remove(&key);
        }
    }
}
impl Default for HeaderProfiler {
fn default() -> Self {
Self::new()
}
}
impl Clone for HeaderProfiler {
fn clone(&self) -> Self {
Self {
baselines: Arc::clone(&self.baselines),
max_endpoints: self.max_endpoints,
min_samples: self.min_samples,
}
}
}
/// Point-in-time summary of a [`HeaderProfiler`]'s learned state.
#[derive(Debug, Clone)]
pub struct HeaderProfilerStats {
    /// Number of endpoints currently tracked.
    pub endpoint_count: usize,
    /// Endpoints whose baselines have reached the maturity threshold.
    pub mature_endpoints: usize,
    /// Total learned samples summed across all endpoints.
    pub total_samples: u64,
    /// Total distinct (endpoint, header) stat entries.
    pub total_headers: usize,
    /// Configured endpoint cap.
    pub max_endpoints: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned (name, value) header pairs from string literals.
    fn make_headers(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    /// A fresh profiler is empty and carries the default limits.
    #[test]
    fn test_profiler_new() {
        let profiler = HeaderProfiler::new();
        assert_eq!(profiler.endpoint_count(), 0);
        assert_eq!(profiler.max_endpoints(), DEFAULT_MAX_ENDPOINTS);
        assert_eq!(profiler.min_samples(), DEFAULT_MIN_SAMPLES);
    }

    /// Custom limits are stored verbatim.
    #[test]
    fn test_profiler_with_config() {
        let profiler = HeaderProfiler::with_config(100, 10);
        assert_eq!(profiler.max_endpoints(), 100);
        assert_eq!(profiler.min_samples(), 10);
    }

    /// First learn() call creates a baseline with one sample and one stats
    /// entry per distinct header.
    #[test]
    fn test_profiler_learn_creates_baseline() {
        let profiler = HeaderProfiler::new();
        let headers = make_headers(&[
            ("Content-Type", "application/json"),
            ("Authorization", "Bearer token123"),
        ]);
        profiler.learn("/api/users", &headers);
        assert_eq!(profiler.endpoint_count(), 1);
        let baseline = profiler.get_baseline("/api/users").unwrap();
        assert_eq!(baseline.sample_count, 1);
        assert_eq!(baseline.header_value_stats.len(), 2);
    }

    /// Repeated learn() calls accumulate sample counts (header names are
    /// stored lowercased).
    #[test]
    fn test_profiler_learn_accumulates() {
        let profiler = HeaderProfiler::new();
        for i in 0..10 {
            let headers = make_headers(&[
                ("Content-Type", "application/json"),
                ("X-Request-ID", &format!("req-{}", i)),
            ]);
            profiler.learn("/api/test", &headers);
        }
        let baseline = profiler.get_baseline("/api/test").unwrap();
        assert_eq!(baseline.sample_count, 10);
        let ct_stats = baseline.get_stats("content-type").unwrap();
        assert_eq!(ct_stats.total_samples, 10);
    }

    /// Unknown endpoints produce no anomalies.
    #[test]
    fn test_profiler_analyze_no_baseline() {
        let profiler = HeaderProfiler::new();
        let headers = make_headers(&[("Content-Type", "application/json")]);
        let result = profiler.analyze("/unknown", &headers);
        assert!(!result.has_anomalies());
    }

    /// Baselines below min_samples are never analyzed.
    #[test]
    fn test_profiler_analyze_immature_baseline() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..5 {
            let headers = make_headers(&[("Content-Type", "application/json")]);
            profiler.learn("/api/test", &headers);
        }
        let headers = make_headers(&[("Content-Type", "application/json")]);
        let result = profiler.analyze("/api/test", &headers);
        assert!(!result.has_anomalies());
    }

    /// Dropping a header seen in 100% of samples is flagged as MissingRequired.
    #[test]
    fn test_detect_missing_required_header() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..50 {
            let headers = make_headers(&[
                ("Content-Type", "application/json"),
                ("Authorization", "Bearer token"),
            ]);
            profiler.learn("/api/secure", &headers);
        }
        let headers = make_headers(&[("Content-Type", "application/json")]);
        let result = profiler.analyze("/api/secure", &headers);
        assert!(result.has_anomalies());
        let missing = result.anomalies.iter().find(
            |a| matches!(a, HeaderAnomaly::MissingRequired { header } if header == "authorization"),
        );
        assert!(missing.is_some());
    }

    /// A header never seen during learning is flagged as UnexpectedHeader.
    #[test]
    fn test_detect_unexpected_header() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..50 {
            let headers = make_headers(&[("Content-Type", "application/json")]);
            profiler.learn("/api/test", &headers);
        }
        let headers = make_headers(&[
            ("Content-Type", "application/json"),
            ("X-Evil-Header", "malicious"),
        ]);
        let result = profiler.analyze("/api/test", &headers);
        assert!(result.has_anomalies());
        let unexpected = result.anomalies.iter().find(|a| {
            matches!(a, HeaderAnomaly::UnexpectedHeader { header } if header == "x-evil-header")
        });
        assert!(unexpected.is_some());
    }

    /// A value far outside the learned length range is flagged.
    #[test]
    fn test_detect_length_anomaly() {
        let profiler = HeaderProfiler::with_config(100, 20);
        for _ in 0..50 {
            let headers = make_headers(&[("Authorization", "Bearer short_token")]);
            profiler.learn("/api/auth", &headers);
        }
        let long_token = "a".repeat(10000);
        let headers = make_headers(&[("Authorization", &format!("Bearer {}", long_token))]);
        let result = profiler.analyze("/api/auth", &headers);
        assert!(result.has_anomalies());
        let length_anomaly = result.anomalies.iter().find(|a| {
            matches!(a, HeaderAnomaly::LengthAnomaly { header, .. } if header == "authorization")
        });
        assert!(length_anomaly.is_some());
    }

    /// Entropy detection is statistical, so the anomaly is not required to
    /// fire; but if it does, it must be attributed to the header we sent.
    /// (This test previously contained empty `if` blocks and asserted
    /// nothing at all.)
    #[test]
    fn test_detect_entropy_anomaly() {
        let profiler = HeaderProfiler::with_config(100, 30);
        for i in 0..60 {
            let headers = make_headers(&[("X-Token", &format!("user-token-{:05}", i))]);
            profiler.learn("/api/token", &headers);
        }
        let high_entropy = "xK9mNqR5vL8jYpW2eTfGhIuB7cDaZoS4";
        let headers = make_headers(&[("X-Token", high_entropy)]);
        let result = profiler.analyze("/api/token", &headers);
        for anomaly in &result.anomalies {
            if let HeaderAnomaly::EntropyAnomaly { header, .. } = anomaly {
                assert_eq!(header, "x-token");
            }
        }
    }

    /// Multiple anomalies yield a positive, capped risk contribution.
    #[test]
    fn test_risk_contribution_accumulates() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..50 {
            let headers = make_headers(&[
                ("Content-Type", "application/json"),
                ("Authorization", "Bearer token"),
            ]);
            profiler.learn("/api/risk", &headers);
        }
        let headers = make_headers(&[("X-Unexpected-1", "value"), ("X-Unexpected-2", "value")]);
        let result = profiler.analyze("/api/risk", &headers);
        assert!(result.has_anomalies());
        assert!(result.risk_contribution > 0);
        assert!(result.risk_contribution <= 50);
    }

    /// Exceeding the endpoint cap evicts the least-recently-updated baseline.
    /// Sleeps keep the Instant timestamps strictly ordered.
    #[test]
    fn test_lru_eviction() {
        let profiler = HeaderProfiler::with_config(3, 10);
        profiler.learn("/api/1", &make_headers(&[("X", "1")]));
        std::thread::sleep(std::time::Duration::from_millis(10));
        profiler.learn("/api/2", &make_headers(&[("X", "2")]));
        std::thread::sleep(std::time::Duration::from_millis(10));
        profiler.learn("/api/3", &make_headers(&[("X", "3")]));
        assert_eq!(profiler.endpoint_count(), 3);
        profiler.learn("/api/4", &make_headers(&[("X", "4")]));
        assert_eq!(profiler.endpoint_count(), 3);
        assert!(profiler.get_baseline("/api/1").is_none());
        assert!(profiler.get_baseline("/api/4").is_some());
    }

    /// Concurrent learners on distinct endpoints neither panic nor lose data.
    #[test]
    fn test_concurrent_learn() {
        use std::thread;
        let profiler = Arc::new(HeaderProfiler::new());
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let p = Arc::clone(&profiler);
                thread::spawn(move || {
                    for j in 0..100 {
                        let headers = make_headers(&[
                            ("Thread", &format!("{}", i)),
                            ("Request", &format!("{}", j)),
                        ]);
                        p.learn(&format!("/api/thread-{}", i), &headers);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(profiler.endpoint_count(), 4);
    }

    /// Mixed learn/analyze traffic on one endpoint must not deadlock.
    #[test]
    fn test_concurrent_learn_and_analyze() {
        use std::thread;
        let profiler = Arc::new(HeaderProfiler::with_config(100, 10));
        for _ in 0..20 {
            profiler.learn(
                "/api/concurrent",
                &make_headers(&[("Content-Type", "application/json")]),
            );
        }
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let p = Arc::clone(&profiler);
                thread::spawn(move || {
                    for _ in 0..50 {
                        if i % 2 == 0 {
                            p.learn(
                                "/api/concurrent",
                                &make_headers(&[("Content-Type", "application/json")]),
                            );
                        } else {
                            let _ = p.analyze(
                                "/api/concurrent",
                                &make_headers(&[("Content-Type", "application/json")]),
                            );
                        }
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        let baseline = profiler.get_baseline("/api/concurrent").unwrap();
        assert!(baseline.sample_count > 20);
    }

    /// stats() aggregates counts and classifies maturity per endpoint.
    #[test]
    fn test_profiler_stats() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..50 {
            profiler.learn(
                "/api/mature",
                &make_headers(&[("Content-Type", "application/json")]),
            );
        }
        for _ in 0..5 {
            profiler.learn("/api/immature", &make_headers(&[("X-Token", "test")]));
        }
        let stats = profiler.stats();
        assert_eq!(stats.endpoint_count, 2);
        assert_eq!(stats.mature_endpoints, 1);
        assert_eq!(stats.total_samples, 55);
        assert_eq!(stats.total_headers, 2);
    }

    /// clear() drops every baseline.
    #[test]
    fn test_profiler_clear() {
        let profiler = HeaderProfiler::new();
        profiler.learn("/api/1", &make_headers(&[("X", "1")]));
        profiler.learn("/api/2", &make_headers(&[("X", "2")]));
        assert_eq!(profiler.endpoint_count(), 2);
        profiler.clear();
        assert_eq!(profiler.endpoint_count(), 0);
    }

    /// Clones share the underlying map, so learns through one handle are
    /// visible through the other.
    #[test]
    fn test_profiler_clone_shares_state() {
        let profiler1 = HeaderProfiler::new();
        profiler1.learn("/api/shared", &make_headers(&[("X", "1")]));
        let profiler2 = profiler1.clone();
        profiler2.learn("/api/shared", &make_headers(&[("X", "2")]));
        let baseline = profiler1.get_baseline("/api/shared").unwrap();
        assert_eq!(baseline.sample_count, 2);
    }

    /// Analysis must not depend on header order.
    #[test]
    fn test_header_ordering_is_ignored() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..20 {
            profiler.learn(
                "/api/order",
                &make_headers(&[("A", "1"), ("B", "2"), ("C", "3")]),
            );
        }
        let result = profiler.analyze(
            "/api/order",
            &make_headers(&[("C", "3"), ("A", "1"), ("B", "2")]),
        );
        assert!(!result.has_anomalies());
    }

    /// Header matching must be case-insensitive.
    #[test]
    fn test_header_case_sensitivity() {
        let profiler = HeaderProfiler::with_config(100, 10);
        for _ in 0..20 {
            profiler.learn("/api/case", &make_headers(&[("X-Custom", "value")]));
        }
        let result = profiler.analyze("/api/case", &make_headers(&[("x-custom", "value")]));
        assert!(
            !result.has_anomalies(),
            "Header analysis should be case-insensitive"
        );
    }
}