use std::fmt;
use http::HeaderValue;
/// Largest header value, in bytes, that [`HeaderEscapeGuard::header_value`] accepts.
///
/// Longer inputs are rejected with [`EscapeError::OversizeForBoundary`] before any
/// byte-level inspection takes place.
pub const MAX_HEADER_VALUE_BYTES: usize = 8_192;
/// Reasons [`HeaderEscapeGuard::header_value`] refuses to build a header value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EscapeError {
    /// The input contains a carriage return (`\r`) or a line feed (`\n`).
    ContainsCrlf,
    /// The input contains a NUL (`0x00`) byte.
    ContainsNull,
    /// The input contains a horizontal tab (`\t`).
    ContainsTab,
    /// The input contains another ASCII control byte or DEL (`0x7F`); the payload is that byte.
    /// The guard also reports `0` here if `http` itself rejects a value the scan accepted.
    ContainsNonPrintable(u8),
    /// The input is longer than [`MAX_HEADER_VALUE_BYTES`]; the payload is its length in bytes.
    OversizeForBoundary(usize),
}
impl fmt::Display for EscapeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ContainsCrlf => {
f.write_str("header value contains CR or LF (would smuggle a second header line)")
}
Self::ContainsNull => f.write_str(
"header value contains NUL (proxies and intermediaries truncate on NUL)",
),
Self::ContainsTab => f.write_str(
"header value contains TAB (downstream log pipelines split on whitespace)",
),
Self::ContainsNonPrintable(b) => {
write!(f, "header value contains non-printable byte 0x{b:02X}")
}
Self::OversizeForBoundary(n) => write!(
f,
"header value length {n} exceeds the {MAX_HEADER_VALUE_BYTES}-byte boundary limit"
),
}
}
}
impl std::error::Error for EscapeError {}
/// Stateless namespace for building outgoing header values from untrusted strings.
pub struct HeaderEscapeGuard;

impl HeaderEscapeGuard {
    /// Validates `s` and converts it into a [`HeaderValue`].
    ///
    /// Checks run in this order:
    /// 1. An input longer than [`MAX_HEADER_VALUE_BYTES`] bytes is rejected with
    ///    [`EscapeError::OversizeForBoundary`] carrying its length, before any byte is examined.
    /// 2. Bytes are scanned left to right, and the first offending byte decides the error:
    ///    CR or LF gives `ContainsCrlf`, NUL gives `ContainsNull`, TAB gives `ContainsTab`,
    ///    and any other ASCII control byte or DEL gives `ContainsNonPrintable(byte)`.
    ///
    /// Printable ASCII and non-ASCII UTF-8 bytes (0x80 and above) pass the scan.
    ///
    /// # Errors
    ///
    /// Returns the [`EscapeError`] variant described above. `ContainsNonPrintable(0)` is
    /// returned if `HeaderValue::from_bytes` rejects a value that passed the scan.
    pub fn header_value(s: &str) -> Result<HeaderValue, EscapeError> {
        let raw = s.as_bytes();
        let len = raw.len();

        // Length gate first: an oversize input is reported as such even if it also
        // contains forbidden bytes.
        if len > MAX_HEADER_VALUE_BYTES {
            return Err(EscapeError::OversizeForBoundary(len));
        }

        if let Some(rejection) = raw.iter().copied().find_map(Self::classify_byte) {
            return Err(rejection);
        }

        // NOTE(review): presumably unreachable, since the scan above already refuses every
        // control byte and DEL; confirm against the pinned `http` version's validation rules.
        HeaderValue::from_bytes(raw).map_err(|_| EscapeError::ContainsNonPrintable(0))
    }

    /// Maps a single byte to the rejection it triggers, or `None` when the byte is allowed.
    fn classify_byte(byte: u8) -> Option<EscapeError> {
        let rejection = match byte {
            b'\r' | b'\n' => EscapeError::ContainsCrlf,
            0x00 => EscapeError::ContainsNull,
            b'\t' => EscapeError::ContainsTab,
            // Remaining C0 controls (TAB 0x09, LF 0x0A and CR 0x0D are handled above) plus DEL.
            0x01..=0x08 | 0x0B | 0x0C | 0x0E..=0x1F | 0x7F => {
                EscapeError::ContainsNonPrintable(byte)
            }
            _ => return None,
        };
        Some(rejection)
    }
}
#[cfg(test)]
mod tests {
    // Unit and table-driven tests for `HeaderEscapeGuard::header_value`: accepted inputs,
    // each rejected byte class, the length limit, scan precedence, and `Display` text.
    use super::*;

    #[test]
    fn accepts_simple_ascii() {
        let v = HeaderEscapeGuard::header_value("application/json").unwrap();
        assert_eq!(v.as_bytes(), b"application/json");
    }

    #[test]
    fn accepts_empty_string() {
        let v = HeaderEscapeGuard::header_value("").unwrap();
        assert_eq!(v.as_bytes(), b"");
    }

    #[test]
    fn accepts_value_with_spaces_and_punctuation() {
        let v = HeaderEscapeGuard::header_value("text/html; charset=utf-8, q=0.9").unwrap();
        assert_eq!(v.as_bytes(), b"text/html; charset=utf-8, q=0.9");
    }

    #[test]
    fn accepts_max_length_value() {
        let s = "a".repeat(MAX_HEADER_VALUE_BYTES);
        let v = HeaderEscapeGuard::header_value(&s).unwrap();
        assert_eq!(v.as_bytes().len(), MAX_HEADER_VALUE_BYTES);
    }

    #[test]
    fn accepts_high_bit_bytes() {
        let v = HeaderEscapeGuard::header_value("café").unwrap();
        assert_eq!(v.as_bytes(), "café".as_bytes());
    }

    #[test]
    fn rejects_carriage_return() {
        assert_eq!(
            HeaderEscapeGuard::header_value("evil\rinjected"),
            Err(EscapeError::ContainsCrlf)
        );
    }

    #[test]
    fn rejects_line_feed() {
        assert_eq!(
            HeaderEscapeGuard::header_value("evil\ninjected"),
            Err(EscapeError::ContainsCrlf)
        );
    }

    #[test]
    fn rejects_crlf_pair_for_response_splitting() {
        let payload = "ok\r\nX-Forged: 1\r\n\r\n<html>pwned</html>";
        assert_eq!(
            HeaderEscapeGuard::header_value(payload),
            Err(EscapeError::ContainsCrlf)
        );
    }

    #[test]
    fn rejects_nul() {
        assert_eq!(
            HeaderEscapeGuard::header_value("trunc\0ate"),
            Err(EscapeError::ContainsNull)
        );
    }

    #[test]
    fn rejects_tab() {
        assert_eq!(
            HeaderEscapeGuard::header_value("split\tlog"),
            Err(EscapeError::ContainsTab)
        );
    }

    #[test]
    fn rejects_backspace() {
        assert_eq!(
            HeaderEscapeGuard::header_value("over\u{0008}type"),
            Err(EscapeError::ContainsNonPrintable(0x08))
        );
    }

    #[test]
    fn rejects_bell() {
        assert_eq!(
            HeaderEscapeGuard::header_value("ding\u{0007}!"),
            Err(EscapeError::ContainsNonPrintable(0x07))
        );
    }

    #[test]
    fn rejects_form_feed() {
        assert_eq!(
            HeaderEscapeGuard::header_value("page\u{000C}break"),
            Err(EscapeError::ContainsNonPrintable(0x0C))
        );
    }

    #[test]
    fn rejects_vertical_tab() {
        assert_eq!(
            HeaderEscapeGuard::header_value("vert\u{000B}tab"),
            Err(EscapeError::ContainsNonPrintable(0x0B))
        );
    }

    #[test]
    fn rejects_escape_byte() {
        assert_eq!(
            HeaderEscapeGuard::header_value("\u{001B}[31mred"),
            Err(EscapeError::ContainsNonPrintable(0x1B))
        );
    }

    #[test]
    fn rejects_del_byte() {
        assert_eq!(
            HeaderEscapeGuard::header_value("hello\u{007F}"),
            Err(EscapeError::ContainsNonPrintable(0x7F))
        );
    }

    #[test]
    fn rejects_oversize() {
        let s = "a".repeat(MAX_HEADER_VALUE_BYTES + 1);
        assert_eq!(
            HeaderEscapeGuard::header_value(&s),
            Err(EscapeError::OversizeForBoundary(MAX_HEADER_VALUE_BYTES + 1))
        );
    }

    #[test]
    fn oversize_check_runs_before_byte_scan() {
        let mut s = String::with_capacity(MAX_HEADER_VALUE_BYTES + 4);
        for _ in 0..(MAX_HEADER_VALUE_BYTES / 2 + 1) {
            s.push_str("\r\n");
        }
        let n = s.len();
        assert_eq!(
            HeaderEscapeGuard::header_value(&s),
            Err(EscapeError::OversizeForBoundary(n))
        );
    }

    #[test]
    fn reports_first_offending_byte() {
        // The scan stops at the first rejected byte, so the reported class is whichever
        // forbidden byte appears earliest in the input.
        assert_eq!(
            HeaderEscapeGuard::header_value("a\tb\r\nc"),
            Err(EscapeError::ContainsTab)
        );
        assert_eq!(
            HeaderEscapeGuard::header_value("a\r\nb\tc"),
            Err(EscapeError::ContainsCrlf)
        );
        assert_eq!(
            HeaderEscapeGuard::header_value("a\u{0007}b\0c"),
            Err(EscapeError::ContainsNonPrintable(0x07))
        );
    }

    #[test]
    fn error_display_mentions_byte_class() {
        assert!(EscapeError::ContainsCrlf.to_string().contains("CR or LF"));
        assert!(EscapeError::ContainsNull.to_string().contains("NUL"));
        assert!(EscapeError::ContainsTab.to_string().contains("TAB"));
        assert!(EscapeError::ContainsNonPrintable(0x07)
            .to_string()
            .contains("0x07"));
        assert!(EscapeError::OversizeForBoundary(99_999)
            .to_string()
            .contains("99999"));
    }

    #[test]
    fn snapshot_known_fixtures() {
        let cases: &[(&str, Result<&[u8], EscapeError>)] = &[
            ("application/json", Ok(b"application/json")),
            (
                "max-age=31536000; includeSubDomains",
                Ok(b"max-age=31536000; includeSubDomains"),
            ),
            ("nosniff", Ok(b"nosniff")),
            ("DENY", Ok(b"DENY")),
            ("\"abc-123\"", Ok(b"\"abc-123\"")),
            ("evil\r\nLocation: /pwned", Err(EscapeError::ContainsCrlf)),
            ("set-cookie\nset-cookie", Err(EscapeError::ContainsCrlf)),
            (
                "bell\x07alarm",
                Err(EscapeError::ContainsNonPrintable(0x07)),
            ),
            ("trunc\0ate", Err(EscapeError::ContainsNull)),
            ("split\there", Err(EscapeError::ContainsTab)),
        ];
        for (input, expected) in cases {
            let got = HeaderEscapeGuard::header_value(input);
            match (expected, &got) {
                (Ok(bytes), Ok(v)) => {
                    assert_eq!(v.as_bytes(), *bytes, "input {input:?} produced wrong bytes")
                }
                (Err(want), Err(got_err)) => {
                    assert_eq!(want, got_err, "input {input:?} produced wrong error")
                }
                (Ok(_), Err(e)) => panic!("input {input:?} unexpectedly rejected: {e:?}"),
                (Err(want), Ok(v)) => panic!(
                    "input {input:?} unexpectedly accepted (bytes={:?}); wanted {want:?}",
                    v.as_bytes()
                ),
            }
        }
    }

    #[test]
    fn fuzz_every_single_byte_position() {
        // Insert each forbidden byte (every C0 control plus DEL) at EVERY insertion index,
        // including index 0 (prefix) and `base.len()` (suffix), so no position is skipped.
        let base = b"abcdefghij";
        let forbidden = (0u8..=0x1F).chain(std::iter::once(0x7F));
        for byte in forbidden {
            for pos in 0..=base.len() {
                let mut bytes = base.to_vec();
                bytes.insert(pos, byte);
                let s = String::from_utf8(bytes).expect("ASCII bytes are always valid UTF-8");
                let got = HeaderEscapeGuard::header_value(&s);
                let want = match byte {
                    b'\r' | b'\n' => EscapeError::ContainsCrlf,
                    0 => EscapeError::ContainsNull,
                    b'\t' => EscapeError::ContainsTab,
                    _ => EscapeError::ContainsNonPrintable(byte),
                };
                assert_eq!(got, Err(want), "byte 0x{byte:02X} at pos {pos}");
            }
        }
    }

    #[test]
    fn fuzz_every_printable_ascii_accepted() {
        for byte in 0x20u8..0x7F {
            let s = format!("x{}y", byte as char);
            assert!(
                HeaderEscapeGuard::header_value(&s).is_ok(),
                "byte 0x{byte:02X} should be accepted",
            );
        }
    }

    #[test]
    fn fuzz_every_high_bit_byte_accepted() {
        // Iterates code points U+0080..=U+00FF; each encodes to two UTF-8 bytes >= 0x80.
        for codepoint in 0x80u32..=0xFF {
            let s = char::from_u32(codepoint).unwrap().to_string();
            let v = HeaderEscapeGuard::header_value(&s).unwrap();
            assert_eq!(v.as_bytes(), s.as_bytes());
        }
    }

    #[test]
    fn fuzz_oversize_boundary() {
        let exact = "a".repeat(MAX_HEADER_VALUE_BYTES);
        assert!(HeaderEscapeGuard::header_value(&exact).is_ok());
        let over = "a".repeat(MAX_HEADER_VALUE_BYTES + 1);
        assert_eq!(
            HeaderEscapeGuard::header_value(&over),
            Err(EscapeError::OversizeForBoundary(MAX_HEADER_VALUE_BYTES + 1))
        );
    }

    #[test]
    fn fuzz_concatenation_attacks() {
        let trailers = [
            "\r\n",
            "\n",
            "\r",
            "\r\nX-Forged: 1",
            "\r\nLocation: http://attacker/",
            "\r\n\r\n<html>",
        ];
        for trailer in trailers {
            let payload = format!("application/json{trailer}");
            assert_eq!(
                HeaderEscapeGuard::header_value(&payload),
                Err(EscapeError::ContainsCrlf),
                "payload {payload:?} must reject"
            );
        }
    }
}