use std::collections::BTreeMap;
use http::HeaderMap;
/// Maximum number of body bytes that `truncate_body_sample` keeps before it
/// cuts the sample and appends a truncation note.
pub const BODY_SAMPLE_CAP: usize = 256;

/// Lower-case header names that `filter_security_headers` always retains.
///
/// Names starting with `x-` are retained too, without being listed here
/// (see `is_relevant`). Entries are matched by exact string equality against
/// the lower-cased header name, so every entry must itself be lower-case.
const SECURITY_ALLOWLIST: &[&str] = &[
    // Credentials.
    "authorization",
    // Conditional and range requests.
    "if-match",
    "if-none-match",
    "if-modified-since",
    "if-unmodified-since",
    "if-range",
    "range",
    // Negotiation and representation type.
    "accept",
    "accept-language",
    "content-type",
    // Authentication challenge.
    "www-authenticate",
    // Cookies, request and response side.
    "cookie",
    "set-cookie",
];
/// Picks the headers worth recording out of `headers`, keyed by lower-case
/// name and ordered by name.
///
/// A header is kept when `is_relevant` accepts its lower-cased name (an
/// allowlisted name or any `x-` name).
///
/// When a name carries several values, only the first value for which
/// `HeaderValue::to_str` succeeds is kept. Values that fail that conversion
/// are skipped silently; per the `http` docs, these are values containing
/// bytes other than visible ASCII.
///
/// NOTE(review): later repeats (e.g. a second `set-cookie`) are dropped
/// rather than combined.
#[must_use]
pub fn filter_security_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
    let mut kept = BTreeMap::new();
    for (header_name, header_value) in headers {
        // `HeaderName::as_str` is documented as already lower-case; the
        // explicit lowering also gives us the owned key we need.
        let key = header_name.as_str().to_ascii_lowercase();
        // Skip names we don't track, and names that already have a value
        // (first accepted value wins).
        if kept.contains_key(&key) || !is_relevant(&key) {
            continue;
        }
        let Ok(text) = header_value.to_str() else {
            continue;
        };
        kept.insert(key, text.to_owned());
    }
    kept
}
/// Returns `true` when a header name should be retained.
///
/// `lower` must already be lower-case. It matches either any `x-` prefixed
/// name or an exact entry of `SECURITY_ALLOWLIST`.
fn is_relevant(lower: &str) -> bool {
    if lower.starts_with("x-") {
        return true;
    }
    SECURITY_ALLOWLIST.contains(&lower)
}
/// Renders a response body as a short, printable sample.
///
/// - Valid UTF-8 of at most `BODY_SAMPLE_CAP` bytes is returned unchanged.
/// - Longer valid UTF-8 is cut to at most `BODY_SAMPLE_CAP` bytes on a
///   character boundary. The cut is followed by `…` and a note giving the
///   full byte length.
/// - A body that is not valid UTF-8 anywhere is summarised as
///   `<N bytes, non-text>`, where `N` is the full byte length.
#[must_use]
pub fn truncate_body_sample(body: &[u8]) -> String {
    let Ok(text) = std::str::from_utf8(body) else {
        return format!("<{} bytes, non-text>", body.len());
    };
    if text.len() > BODY_SAMPLE_CAP {
        let cut = utf8_safe_truncate(text, BODY_SAMPLE_CAP);
        format!("{cut}\u{2026} (truncated, total {}b)", body.len())
    } else {
        text.to_owned()
    }
}
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary. The result may be empty.
fn utf8_safe_truncate(s: &str, max_bytes: usize) -> &str {
    let limit = max_bytes.min(s.len());
    // Scan downward from the limit. Index 0 is always a boundary, so `find`
    // always succeeds; the fallback only keeps the expression total.
    let end = (0..=limit)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..end]
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, HeaderName, HeaderValue};

    /// Builds a `HeaderMap` holding one value per name. `insert` replaces any
    /// earlier value for the same name; use `append` directly when a test
    /// needs repeated values.
    fn hm(pairs: &[(&str, &str)]) -> HeaderMap {
        let mut h = HeaderMap::new();
        for &(k, v) in pairs {
            h.insert(
                HeaderName::from_bytes(k.as_bytes()).expect("name"),
                HeaderValue::from_str(v).expect("value"),
            );
        }
        h
    }

    #[test]
    fn filter_security_headers_includes_authorization() {
        let h = hm(&[("authorization", "Bearer abc")]);
        let out = filter_security_headers(&h);
        assert_eq!(out.get("authorization"), Some(&"Bearer abc".to_owned()));
    }

    #[test]
    fn filter_security_headers_includes_if_match() {
        let h = hm(&[("if-match", "W/\"v1\"")]);
        let out = filter_security_headers(&h);
        assert_eq!(out.get("if-match"), Some(&"W/\"v1\"".to_owned()));
    }

    #[test]
    fn filter_security_headers_includes_x_custom_header() {
        let h = hm(&[("x-custom-header", "1")]);
        let out = filter_security_headers(&h);
        assert_eq!(out.get("x-custom-header"), Some(&"1".to_owned()));
    }

    #[test]
    fn filter_security_headers_excludes_user_agent() {
        let h = hm(&[("user-agent", "parlov/test")]);
        let out = filter_security_headers(&h);
        assert!(!out.contains_key("user-agent"));
    }

    #[test]
    fn filter_security_headers_excludes_host() {
        let h = hm(&[("host", "example.com")]);
        let out = filter_security_headers(&h);
        assert!(!out.contains_key("host"));
    }

    #[test]
    fn filter_security_headers_normalizes_case() {
        let h = hm(&[("If-Match", "x")]);
        let out = filter_security_headers(&h);
        assert!(out.contains_key("if-match"));
    }

    #[test]
    fn filter_security_headers_keeps_first_of_repeated_values() {
        let mut h = HeaderMap::new();
        h.append(
            HeaderName::from_static("set-cookie"),
            HeaderValue::from_static("a=1"),
        );
        h.append(
            HeaderName::from_static("set-cookie"),
            HeaderValue::from_static("b=2"),
        );
        let out = filter_security_headers(&h);
        assert_eq!(out.get("set-cookie"), Some(&"a=1".to_owned()));
    }

    #[test]
    fn filter_security_headers_skips_non_ascii_value_and_uses_next() {
        let mut h = HeaderMap::new();
        let name = HeaderName::from_static("x-note");
        // Bytes >= 0x80 are legal in a HeaderValue but make `to_str` fail.
        h.append(
            name.clone(),
            HeaderValue::from_bytes(b"caf\xc3\xa9").expect("value"),
        );
        h.append(name, HeaderValue::from_static("plain"));
        let out = filter_security_headers(&h);
        assert_eq!(out.get("x-note"), Some(&"plain".to_owned()));
    }

    #[test]
    fn truncate_body_sample_under_cap_full_content() {
        let body = b"hello world";
        let s = truncate_body_sample(body);
        assert_eq!(s, "hello world");
    }

    #[test]
    fn truncate_body_sample_exactly_cap_not_truncated() {
        let body = "a".repeat(BODY_SAMPLE_CAP);
        assert_eq!(truncate_body_sample(body.as_bytes()), body);
    }

    #[test]
    fn truncate_body_sample_over_cap_truncated_with_marker() {
        let body = vec![b'a'; BODY_SAMPLE_CAP + 50];
        let s = truncate_body_sample(&body);
        let expected = format!(
            "{}\u{2026} (truncated, total {}b)",
            "a".repeat(BODY_SAMPLE_CAP),
            BODY_SAMPLE_CAP + 50
        );
        assert_eq!(s, expected);
    }

    #[test]
    fn truncate_body_sample_utf8_boundary_safe() {
        let single = '\u{1F600}';
        let n = (BODY_SAMPLE_CAP / 4) + 5;
        // One leading ASCII byte shifts every 4-byte emoji off the 4-byte grid.
        // The cut at BODY_SAMPLE_CAP then lands inside a character, so the
        // boundary walk-back is actually exercised. Without the shift, a cap
        // that is a multiple of 4 would always cut cleanly and the test would
        // prove nothing.
        let body: String = std::iter::once('a')
            .chain(std::iter::repeat_n(single, n))
            .collect();
        assert!(
            !body.is_char_boundary(BODY_SAMPLE_CAP),
            "fixture must place the cap inside a character"
        );
        let s = truncate_body_sample(body.as_bytes());
        let prefix = s.split('\u{2026}').next().unwrap_or(&s);
        assert!(body.starts_with(prefix), "not a clean prefix: {prefix:?}");
        assert!(
            prefix.len() <= BODY_SAMPLE_CAP,
            "sample exceeds cap: {}",
            prefix.len()
        );
        assert!(
            prefix.len() + 4 > BODY_SAMPLE_CAP,
            "cut back further than one character: {}",
            prefix.len()
        );
        assert!(
            prefix.chars().skip(1).all(|c| c == single),
            "non-emoji char leaked: {prefix:?}"
        );
    }

    #[test]
    fn truncate_body_sample_non_utf8_byte_count_fallback() {
        let body = vec![0xff, 0xff, 0xff];
        let s = truncate_body_sample(&body);
        assert_eq!(s, "<3 bytes, non-text>");
    }

    #[test]
    fn truncate_body_sample_invalid_utf8_after_cap_is_non_text() {
        // The whole body is validated, so an invalid byte past the sampled
        // window still classifies the body as non-text.
        let mut body = vec![b'a'; BODY_SAMPLE_CAP + 10];
        body.push(0xff);
        assert_eq!(
            truncate_body_sample(&body),
            format!("<{} bytes, non-text>", BODY_SAMPLE_CAP + 11)
        );
    }
}