use std::collections::{BTreeMap, HashSet};
/// Upper bound on distinct (lowercased) header names `validate_headers` accepts.
const MAX_HEADERS: usize = 8;

/// Longest header name, in bytes, that `validate_headers` accepts.
const MAX_HEADER_NAME_LEN: usize = 256;

/// Longest header value, in bytes, that `validate_headers` accepts.
const MAX_HEADER_VALUE_LEN: usize = 512;

/// Header names that `validate_headers` rejects outright (after lowercasing).
const FORBIDDEN_HEADERS: &[&str] = &[
    // Connection-management / hop-by-hop headers.
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    // Routing and message framing.
    "host",
    "content-length",
];
/// Validates an untrusted header map and returns a copy keyed by lowercased names.
///
/// Every header is checked in this order, and the first failure is returned:
///
/// 1. the lowercased name is in `FORBIDDEN_HEADERS`;
/// 2. the name is longer than `MAX_HEADER_NAME_LEN` bytes;
/// 3. the value is longer than `MAX_HEADER_VALUE_LEN` bytes;
/// 4. the name fails `is_valid_header_name`;
/// 5. the value fails `is_valid_header_value`.
///
/// Only after every header has passed is the number of distinct lowercased
/// names compared against `MAX_HEADERS`.
///
/// Names that differ only in case collapse into a single entry. `BTreeMap`
/// visits keys in byte order (ASCII uppercase before lowercase) and a later
/// insert overwrites an earlier one, so the surviving value is the one whose
/// original spelling sorts last: `x-tenant-id` beats `X-Tenant-Id`, which
/// beats `X-TENANT-ID`.
///
/// # Errors
///
/// Returns the [`ValidationError`] variant for the first failed check. Header
/// names and values copied into an error are escaped with
/// [`str::escape_debug`], so CR/LF and other control characters taken from the
/// request are never emitted raw when the error is displayed or logged.
pub fn validate_headers(
    headers: &BTreeMap<String, String>,
) -> Result<BTreeMap<String, String>, ValidationError> {
    /// Renders untrusted header text for inclusion in an error payload.
    /// Valid names (ASCII alphanumerics, `-`, `_`) come back unchanged.
    fn escaped(text: &str) -> String {
        text.escape_debug().to_string()
    }

    let forbidden_set: HashSet<String> =
        FORBIDDEN_HEADERS.iter().map(|h| h.to_lowercase()).collect();
    let mut validated = BTreeMap::new();
    for (name, value) in headers {
        let name_lower = name.to_lowercase();
        if forbidden_set.contains(&name_lower) {
            return Err(ValidationError::ForbiddenHeader {
                header: escaped(name),
            });
        }
        if name.len() > MAX_HEADER_NAME_LEN {
            return Err(ValidationError::HeaderNameTooLong {
                header: escaped(name),
                max: MAX_HEADER_NAME_LEN,
            });
        }
        if value.len() > MAX_HEADER_VALUE_LEN {
            return Err(ValidationError::HeaderValueTooLong {
                header: escaped(name),
                max: MAX_HEADER_VALUE_LEN,
            });
        }
        if !is_valid_header_name(name) {
            return Err(ValidationError::InvalidHeaderName {
                header: escaped(name),
            });
        }
        if !is_valid_header_value(value) {
            return Err(ValidationError::InvalidHeaderValue {
                header: escaped(name),
                value: escaped(value),
            });
        }
        // A case-variant spelling visited later (in BTreeMap order) overwrites this one.
        validated.insert(name_lower, value.clone());
    }
    if validated.len() > MAX_HEADERS {
        return Err(ValidationError::TooManyHeaders {
            max: MAX_HEADERS,
            provided: validated.len(),
        });
    }
    Ok(validated)
}
/// Reports whether `name` is a non-empty run of ASCII letters, ASCII digits,
/// `-` and `_`.
///
/// This is narrower than the general HTTP `token` grammar, which also admits
/// punctuation such as `.`, `!` and `~`.
fn is_valid_header_name(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }
    // Working on bytes is safe here: every byte of a multi-byte UTF-8 character
    // is >= 0x80 and so falls outside all of the accepted ranges.
    name.bytes()
        .all(|byte| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_'))
}
/// Reports whether every character of `value` is printable ASCII or a tab.
///
/// Printable ASCII is exactly `' '` (0x20) through `'~'` (0x7E). CR, LF, NUL,
/// DEL and all non-ASCII characters are rejected. An empty value is accepted.
fn is_valid_header_value(value: &str) -> bool {
    value.chars().all(|ch| matches!(ch, '\t' | ' '..='~'))
}
/// Derives a stable pseudonym for `user_id`.
///
/// The input bytes are hashed with tiny_keccak's `Keccak::v256`. Only the
/// first 16 bytes of the 32-byte digest are kept, and they are hex-encoded
/// into a 32-character string.
///
/// NOTE(review): no key or salt is mixed in. Anyone who can guess a candidate
/// input (for example an email address) can recompute and match the output.
/// Treat this as pseudonymization, not anonymization.
pub fn hash_user_id(user_id: &str) -> String {
    use tiny_keccak::{Hasher, Keccak};

    // Half of the 256-bit digest is kept.
    const KEPT_BYTES: usize = 16;

    let mut keccak = Keccak::v256();
    keccak.update(user_id.as_bytes());

    let mut digest = [0u8; 32];
    keccak.finalize(&mut digest);

    let (kept, _discarded) = digest.split_at(KEPT_BYTES);
    hex::encode(kept)
}
/// Returns a copy of `headers` in which values of known PII-bearing headers
/// are replaced with [`hash_user_id`] pseudonyms.
///
/// Header names are matched ASCII case-insensitively, so `X-User-Email` is
/// treated the same as `x-user-email`. Keys in the returned map are the
/// caller's original keys, unchanged.
///
/// - `x-user-id`, `x-user-email`, `x-customer-email`: always hashed.
/// - `x-tenant-id`: passed through unchanged when it is exactly 32 ASCII hex
///   digits, the shape of a [`hash_user_id`] output. Re-processing an
///   already-hashed tenant id is therefore a no-op. Any other value is hashed.
/// - all other headers: copied verbatim.
pub fn process_headers_with_pii_protection(
    headers: &BTreeMap<String, String>,
) -> BTreeMap<String, String> {
    /// True when `value` looks like a `hash_user_id` digest: 32 ASCII hex digits.
    fn is_digest_shaped(value: &str) -> bool {
        value.len() == 32 && value.bytes().all(|b| b.is_ascii_hexdigit())
    }

    headers
        .iter()
        .map(|(name, value)| {
            // Case-insensitive match: an exact-case comparison would let
            // `X-User-Id` and similar spellings through unhashed.
            let processed_value = match name.to_ascii_lowercase().as_str() {
                "x-user-id" | "x-user-email" | "x-customer-email" => hash_user_id(value),
                "x-tenant-id" if is_digest_shaped(value) => value.clone(),
                "x-tenant-id" => hash_user_id(value),
                _ => value.clone(),
            };
            (name.clone(), processed_value)
        })
        .collect()
}
/// Reasons [`validate_headers`] can reject a header map.
///
/// The `#[error]` messages interpolate the carried header name (and, for
/// `InvalidHeaderValue`, the value) as stored in the variant. Whether that
/// text has been sanitized depends on the code that constructs the error.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// The header name is in the forbidden list.
    #[error("Forbidden header: {header}")]
    ForbiddenHeader { header: String },

    /// The header name exceeds `max` bytes.
    #[error("Header name too long: {header} (max: {max} bytes)")]
    HeaderNameTooLong { header: String, max: usize },

    /// The value of `header` exceeds `max` bytes.
    #[error("Header value too long for {header} (max: {max} bytes)")]
    HeaderValueTooLong { header: String, max: usize },

    /// The header name contains characters outside the accepted set, or is empty.
    #[error("Invalid header name: {header}")]
    InvalidHeaderName { header: String },

    /// The value of `header` contains characters outside the accepted set.
    #[error("Invalid header value for {header}: {value}")]
    InvalidHeaderValue { header: String, value: String },

    /// More distinct headers than `max` remained after validation. In
    /// `validate_headers`, `provided` counts distinct lowercased names.
    #[error("Too many headers provided: {provided} (max: {max})")]
    TooManyHeaders { max: usize, provided: usize },
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_headers_valid() {
        let mut headers = BTreeMap::new();
        headers.insert("X-Tenant-Id".to_string(), "abc123".to_string());
        headers.insert("X-User-Type".to_string(), "premium".to_string());
        let result = validate_headers(&headers);
        assert!(result.is_ok());
        let validated = result.unwrap();
        assert_eq!(validated.len(), 2);
        assert_eq!(validated["x-tenant-id"], "abc123");
        assert_eq!(validated["x-user-type"], "premium");
    }

    #[test]
    fn test_validate_headers_too_many() {
        let mut headers = BTreeMap::new();
        for i in 0..10 {
            headers.insert(format!("X-Header-{i}"), "value".to_string());
        }
        let result = validate_headers(&headers);
        assert!(matches!(
            result,
            Err(ValidationError::TooManyHeaders { .. })
        ));
    }

    #[test]
    fn test_validate_headers_exactly_max_allowed() {
        let mut headers = BTreeMap::new();
        for i in 0..MAX_HEADERS {
            headers.insert(format!("X-Header-{i}"), "value".to_string());
        }
        let validated = validate_headers(&headers).unwrap();
        assert_eq!(validated.len(), MAX_HEADERS);
    }

    #[test]
    fn test_validate_headers_limit_counts_distinct_names() {
        // MAX_HEADERS + 1 raw entries, but two differ only in case, so they
        // collapse to MAX_HEADERS distinct names and the map is accepted.
        let mut headers = BTreeMap::new();
        for i in 0..MAX_HEADERS - 1 {
            headers.insert(format!("X-Header-{i}"), "value".to_string());
        }
        headers.insert("X-Dup".to_string(), "a".to_string());
        headers.insert("x-dup".to_string(), "b".to_string());
        let validated = validate_headers(&headers).unwrap();
        assert_eq!(validated.len(), MAX_HEADERS);
        assert_eq!(validated["x-dup"], "b");
    }

    #[test]
    fn test_validate_headers_forbidden() {
        let mut headers = BTreeMap::new();
        headers.insert("Connection".to_string(), "close".to_string());
        let result = validate_headers(&headers);
        assert!(matches!(
            result,
            Err(ValidationError::ForbiddenHeader { .. })
        ));
    }

    #[test]
    fn test_validate_headers_forbidden_is_case_insensitive() {
        let mut headers = BTreeMap::new();
        headers.insert("TRANSFER-ENCODING".to_string(), "chunked".to_string());
        assert!(matches!(
            validate_headers(&headers),
            Err(ValidationError::ForbiddenHeader { .. })
        ));
    }

    #[test]
    fn test_validate_headers_invalid_name() {
        let mut headers = BTreeMap::new();
        headers.insert("X-Invalid Header".to_string(), "value".to_string());
        let result = validate_headers(&headers);
        assert!(matches!(
            result,
            Err(ValidationError::InvalidHeaderName { .. })
        ));
    }

    #[test]
    fn test_validate_headers_invalid_value() {
        let mut headers = BTreeMap::new();
        headers.insert("X-Test".to_string(), "bad\r\nvalue".to_string());
        assert!(matches!(
            validate_headers(&headers),
            Err(ValidationError::InvalidHeaderValue { .. })
        ));
    }

    #[test]
    fn test_validate_headers_name_too_long() {
        let mut headers = BTreeMap::new();
        let long_name = "X-".to_string() + &"a".repeat(300);
        headers.insert(long_name, "value".to_string());
        let result = validate_headers(&headers);
        assert!(matches!(
            result,
            Err(ValidationError::HeaderNameTooLong { .. })
        ));
    }

    #[test]
    fn test_validate_headers_value_too_long() {
        let mut headers = BTreeMap::new();
        let long_value = "a".repeat(600);
        headers.insert("X-Test".to_string(), long_value);
        let result = validate_headers(&headers);
        assert!(matches!(
            result,
            Err(ValidationError::HeaderValueTooLong { .. })
        ));
    }

    #[test]
    fn test_hash_user_id() {
        let user_id = "user123@example.com";
        let hash1 = hash_user_id(user_id);
        let hash2 = hash_user_id(user_id);
        assert_eq!(hash1, hash2);
        assert_eq!(hash1.len(), 32);
        assert!(hash1.chars().all(|c| c.is_ascii_hexdigit()));
        let hash3 = hash_user_id("different@example.com");
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_valid_header_name() {
        assert!(is_valid_header_name("X-Tenant-Id"));
        assert!(is_valid_header_name("X_User_Type"));
        assert!(is_valid_header_name("Authorization"));
        assert!(!is_valid_header_name(""));
        assert!(!is_valid_header_name("X Tenant Id"));
        assert!(!is_valid_header_name("X-Tenant:Id"));
        assert!(!is_valid_header_name("Caf\u{e9}"));
    }

    #[test]
    fn test_valid_header_value() {
        assert!(is_valid_header_value("abc123"));
        assert!(is_valid_header_value("Bearer token123"));
        assert!(is_valid_header_value("value with spaces"));
        assert!(is_valid_header_value("tab\tseparated"));
        assert!(!is_valid_header_value("value\nwith\nnewlines"));
        assert!(!is_valid_header_value("value\0with\0nulls"));
        assert!(!is_valid_header_value("del\x7f"));
        assert!(!is_valid_header_value("caf\u{e9}"));
    }

    #[test]
    fn test_header_case_insensitive_override() {
        let mut headers = BTreeMap::new();
        headers.insert("X-Tenant-Id".to_string(), "first_value".to_string());
        headers.insert("x-tenant-id".to_string(), "second_value".to_string());
        headers.insert("X-TENANT-ID".to_string(), "third_value".to_string());
        let result = validate_headers(&headers).unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("x-tenant-id"));
        let value = result.get("x-tenant-id").unwrap();
        // BTreeMap visits keys in byte order and ASCII uppercase sorts before
        // lowercase: "X-TENANT-ID", "X-Tenant-Id", "x-tenant-id". The last
        // insert wins, so the outcome is deterministic.
        assert_eq!(value, "second_value");
    }

    #[test]
    fn test_multiple_headers_case_insensitive() {
        let mut headers = BTreeMap::new();
        headers.insert("X-User-Id".to_string(), "user123".to_string());
        headers.insert("x-tenant-id".to_string(), "tenant456".to_string());
        headers.insert("X-TENANT-ID".to_string(), "tenant789".to_string());
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        let result = validate_headers(&headers).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.contains_key("x-user-id"));
        assert!(result.contains_key("x-tenant-id"));
        assert!(result.contains_key("content-type"));
        let tenant_value = result.get("x-tenant-id").unwrap();
        // "X-TENANT-ID" sorts before "x-tenant-id", so the lowercase spelling wins.
        assert_eq!(tenant_value, "tenant456");
    }

    #[test]
    fn test_process_headers_with_pii_protection() {
        let mut headers = BTreeMap::new();
        headers.insert("x-user-id".to_string(), "user123".to_string());
        headers.insert("x-user-email".to_string(), "a@example.com".to_string());
        headers.insert(
            "x-tenant-id".to_string(),
            "0123456789abcdef0123456789abcdef".to_string(),
        );
        headers.insert("content-type".to_string(), "application/json".to_string());
        let processed = process_headers_with_pii_protection(&headers);
        assert_eq!(processed.len(), 4);
        assert_eq!(processed["x-user-id"], hash_user_id("user123"));
        assert_eq!(processed["x-user-email"], hash_user_id("a@example.com"));
        assert_eq!(processed["x-tenant-id"], "0123456789abcdef0123456789abcdef");
        assert_eq!(processed["content-type"], "application/json");
    }

    #[test]
    fn test_tenant_id_is_hashed_unless_already_a_digest() {
        let mut headers = BTreeMap::new();
        headers.insert("x-tenant-id".to_string(), "acme-corp".to_string());
        let once = process_headers_with_pii_protection(&headers);
        assert_eq!(once["x-tenant-id"], hash_user_id("acme-corp"));
        // The digest is 32 hex characters, so a second pass leaves it unchanged.
        let twice = process_headers_with_pii_protection(&once);
        assert_eq!(twice, once);
    }
}