use std::borrow::Cow;
/// Reasons a header field is rejected by `Header::new`.
///
/// The `name` / `value` payloads are owned copies of the rejected input. The
/// `Display` output escapes non-printable bytes in those payloads and truncates
/// any payload longer than 64 bytes, marking the cut with `...`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HeaderError {
/// The field name is empty.
EmptyFieldName,
/// The field name contains an ASCII uppercase letter (checked after an
/// optional leading `:`).
UppercaseFieldName { name: Vec<u8> },
/// The field name contains `byte`, which is not an allowed lowercase token
/// character.
InvalidFieldNameByte { name: Vec<u8>, byte: u8 },
/// The field value contains `byte`, which is neither a visible character
/// (`0x21..=0x7E` or `0x80..=0xFF`) nor SP / HTAB.
InvalidFieldValueByte { name: Vec<u8>, byte: u8 },
/// The (non-empty) field value starts or ends with SP or HTAB.
FieldValueLeadingOrTrailingWhitespace { name: Vec<u8> },
/// The name starts with `:` but is not one of the recognized pseudo-headers.
UnknownPseudoHeader { name: Vec<u8> },
/// The value is not syntactically valid for the pseudo-header `name`.
InvalidPseudoHeaderValue { name: Vec<u8>, value: Vec<u8> },
}
/// Writes `bytes` to `f`, passing every byte through `std::ascii::escape_default`.
///
/// Control characters, quotes, backslashes and non-ASCII bytes therefore show
/// up as escape sequences (`\r`, `\n`, `\"`, `\xff`, ...). This keeps error
/// messages single-line and safe to log. Only the first 64 input bytes are
/// rendered; if any bytes remain, `...` is appended.
fn write_escaped(f: &mut core::fmt::Formatter<'_>, bytes: &[u8]) -> core::fmt::Result {
    use core::fmt::Write;

    // Upper bound on the number of *input* bytes rendered before eliding the rest.
    const MAX_DISPLAY_LEN: usize = 64;

    let (shown, hidden) = bytes.split_at(bytes.len().min(MAX_DISPLAY_LEN));
    shown
        .iter()
        .flat_map(|&byte| std::ascii::escape_default(byte))
        .try_for_each(|escaped| f.write_char(char::from(escaped)))?;

    if hidden.is_empty() {
        Ok(())
    } else {
        f.write_str("...")
    }
}
impl core::fmt::Display for HeaderError {
    /// Renders a one-line, human-readable description of the error.
    ///
    /// Byte payloads are routed through `write_escaped`, so CR/LF and other
    /// non-printable bytes are escaped and long payloads are truncated.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        /// Writes the escaped `name` followed by the closing parenthesis that
        /// ends the `(name: ...)` suffix used by the field-value variants.
        fn escaped_then_close(
            f: &mut core::fmt::Formatter<'_>,
            name: &[u8],
        ) -> core::fmt::Result {
            write_escaped(f, name)?;
            f.write_str(")")
        }

        match self {
            Self::EmptyFieldName => f.write_str("field name must not be empty"),
            Self::UppercaseFieldName { name } => {
                f.write_str("field name must be lowercase: ")?;
                write_escaped(f, name)
            }
            Self::InvalidFieldNameByte { name, byte } => {
                write!(f, "field name contains invalid byte 0x{byte:02x}: ")?;
                write_escaped(f, name)
            }
            Self::InvalidFieldValueByte { name, byte } => {
                write!(f, "field value contains invalid byte 0x{byte:02x} (name: ")?;
                escaped_then_close(f, name)
            }
            Self::FieldValueLeadingOrTrailingWhitespace { name } => {
                f.write_str("field value must not start or end with whitespace (name: ")?;
                escaped_then_close(f, name)
            }
            Self::UnknownPseudoHeader { name } => {
                f.write_str("unknown pseudo header: ")?;
                write_escaped(f, name)
            }
            Self::InvalidPseudoHeaderValue { name, value } => {
                f.write_str("invalid pseudo header value for ")?;
                write_escaped(f, name)?;
                f.write_str(": ")?;
                write_escaped(f, value)
            }
        }
    }
}

impl std::error::Error for HeaderError {}
/// Pseudo-header names (leading `:` included) accepted by `check_header`; any
/// other `:`-prefixed name is reported as `UnknownPseudoHeader`.
///
/// Matching is exact and byte-wise (see `is_known_pseudo_header`).
// NOTE(review): this looks like the HTTP/2-style request/response pseudo-headers
// plus `:protocol` from extended CONNECT (RFC 8441) — confirm no others are needed.
const KNOWN_PSEUDO_HEADERS: &[&[u8]] = &[
b":method",
b":scheme",
b":authority",
b":path",
b":status",
b":protocol",
];
/// Returns `true` if `b` may appear in a header field name.
///
/// The accepted set is the token alphabet (`tchar`) minus the uppercase ASCII
/// letters: `a`-`z`, `0`-`9` and ``! # $ % & ' * + - . ^ _ ` | ~``.
const fn is_lowercase_token_byte(b: u8) -> bool {
    b.is_ascii_lowercase()
        || b.is_ascii_digit()
        || matches!(
            b,
            b'!' | b'#'
                | b'$'
                | b'%'
                | b'&'
                | b'\''
                | b'*'
                | b'+'
                | b'-'
                | b'.'
                | b'^'
                | b'_'
                | b'`'
                | b'|'
                | b'~'
        )
}
/// Returns `true` for bytes allowed as "visible" field-value characters.
///
/// These are printable ASCII `0x21..=0x7E` plus every byte `>= 0x80`
/// (obs-text). SP, HTAB, the other control bytes and DEL (`0x7F`) return `false`.
const fn is_field_vchar(b: u8) -> bool {
    // Everything above SP is visible except DEL.
    b > b' ' && b != 0x7f
}
/// Returns `true` if `b` is a token character (`tchar`).
///
/// The set is the ASCII letters (either case), the digits and
/// ``! # $ % & ' * + - . ^ _ ` | ~``.
const fn is_tchar(b: u8) -> bool {
    b.is_ascii_alphanumeric()
        || matches!(
            b,
            b'!' | b'#'
                | b'$'
                | b'%'
                | b'&'
                | b'\''
                | b'*'
                | b'+'
                | b'-'
                | b'.'
                | b'^'
                | b'_'
                | b'`'
                | b'|'
                | b'~'
        )
}
/// Allocation-free outcome of `check_header`, usable from `const fn` code.
///
/// Mirrors `HeaderError`, but carries only the offending byte (where there is
/// one) instead of owned copies of the name/value. `Header::new` converts it
/// into a `HeaderError`; `Header::from_static` turns every failure into a panic.
enum CheckResult {
/// All checks passed.
Ok,
/// See `HeaderError::EmptyFieldName`.
EmptyFieldName,
/// See `HeaderError::UppercaseFieldName`.
UppercaseFieldName,
/// See `HeaderError::InvalidFieldNameByte`; holds the offending byte.
InvalidFieldNameByte(u8),
/// See `HeaderError::InvalidFieldValueByte`; holds the offending byte.
InvalidFieldValueByte(u8),
/// See `HeaderError::FieldValueLeadingOrTrailingWhitespace`.
FieldValueLeadingOrTrailingWhitespace,
/// See `HeaderError::UnknownPseudoHeader`.
UnknownPseudoHeader,
/// See `HeaderError::InvalidPseudoHeaderValue`.
InvalidPseudoHeaderValue,
}
/// Validates one header field (name + value) without allocating.
///
/// Checks run in a fixed order and the first failure is returned:
/// 1. the name is non-empty;
/// 2. every name byte after an optional leading `:` is a lowercase token byte.
///    An uppercase ASCII letter reports `UppercaseFieldName`; any other
///    non-token byte reports `InvalidFieldNameByte`;
/// 3. a `:`-prefixed name is listed in `KNOWN_PSEUDO_HEADERS`;
/// 4. every value byte is a visible character, SP or HTAB;
/// 5. a non-empty value neither starts nor ends with SP/HTAB;
/// 6. pseudo-header values satisfy `check_pseudo_value`.
///
/// This is a `const fn` because `Header::from_static` relies on it. That is why
/// it uses index loops instead of iterators.
const fn check_header(name: &[u8], value: &[u8]) -> CheckResult {
    // Split off the pseudo-header marker so the byte scan covers only the rest.
    let (is_pseudo, body) = match name {
        [] => return CheckResult::EmptyFieldName,
        [b':', rest @ ..] => (true, rest),
        _ => (false, name),
    };

    let mut pos = 0;
    while pos < body.len() {
        let byte = body[pos];
        match byte {
            b'A'..=b'Z' => return CheckResult::UppercaseFieldName,
            _ if !is_lowercase_token_byte(byte) => {
                return CheckResult::InvalidFieldNameByte(byte);
            }
            _ => {}
        }
        pos += 1;
    }

    // A bare ":" is not in the known list either, so it is rejected here too.
    if is_pseudo && !is_known_pseudo_header(name) {
        return CheckResult::UnknownPseudoHeader;
    }

    let mut pos = 0;
    while pos < value.len() {
        let byte = value[pos];
        let allowed = is_field_vchar(byte) || matches!(byte, b' ' | b'\t');
        if !allowed {
            return CheckResult::InvalidFieldValueByte(byte);
        }
        pos += 1;
    }

    // Every byte is now visible, SP or HTAB, so a non-visible edge byte is whitespace.
    if !value.is_empty() && !(is_field_vchar(value[0]) && is_field_vchar(value[value.len() - 1]))
    {
        return CheckResult::FieldValueLeadingOrTrailingWhitespace;
    }

    if is_pseudo {
        check_pseudo_value(name, value)
    } else {
        CheckResult::Ok
    }
}
/// Returns `true` when `name` (including its leading `:`) exactly matches an
/// entry of `KNOWN_PSEUDO_HEADERS`. The comparison is byte-wise and
/// case-sensitive.
const fn is_known_pseudo_header(name: &[u8]) -> bool {
    let mut remaining = KNOWN_PSEUDO_HEADERS;
    while let [candidate, rest @ ..] = remaining {
        if bytes_eq(*candidate, name) {
            return true;
        }
        remaining = rest;
    }
    false
}
/// Compares two byte slices for equality in a `const` context, where
/// `<[u8] as PartialEq>::eq` cannot be called.
const fn bytes_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Equal lengths: peel one byte off each side until both are exhausted.
    let (mut lhs, mut rhs) = (a, b);
    while let ([x, lhs_rest @ ..], [y, rhs_rest @ ..]) = (lhs, rhs) {
        if *x != *y {
            return false;
        }
        lhs = lhs_rest;
        rhs = rhs_rest;
    }
    true
}
/// Validates the value of a pseudo-header field; `check_header` calls this
/// only for names that passed `is_known_pseudo_header`.
///
/// * `:method` — one or more tchar (a method token).
/// * `:scheme` — an ASCII letter followed by letters, digits, `+`, `-` or `.`
///   (URI scheme syntax).
/// * `:status` — exactly three ASCII digits with a non-zero first digit,
///   i.e. `100..=999`.
/// * `:protocol` — an upgrade token, checked by `check_upgrade_token`.
/// * any other name (`:authority`, `:path`) — no additional constraint.
///
/// Returns `CheckResult::Ok` or `CheckResult::InvalidPseudoHeaderValue`.
const fn check_pseudo_value(name: &[u8], value: &[u8]) -> CheckResult {
    if bytes_eq(name, b":method") {
        if value.is_empty() {
            return CheckResult::InvalidPseudoHeaderValue;
        }
        let mut i = 0;
        while i < value.len() {
            if !is_tchar(value[i]) {
                return CheckResult::InvalidPseudoHeaderValue;
            }
            i += 1;
        }
    } else if bytes_eq(name, b":scheme") {
        // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
        if value.is_empty() || !value[0].is_ascii_alphabetic() {
            return CheckResult::InvalidPseudoHeaderValue;
        }
        let mut i = 1;
        while i < value.len() {
            let b = value[i];
            if !(b.is_ascii_alphanumeric() || b == b'+' || b == b'-' || b == b'.') {
                return CheckResult::InvalidPseudoHeaderValue;
            }
            i += 1;
        }
    } else if bytes_eq(name, b":status") {
        // The grammar is `status-code = 3DIGIT`, but RFC 9110 §15 only defines
        // codes in 100..=599, so a leading '0' can never be a valid code and is
        // rejected. 600..=999 is still accepted because implementations use that
        // range for non-HTTP (internal) status values.
        if value.len() != 3
            || !matches!(value[0], b'1'..=b'9')
            || !value[1].is_ascii_digit()
            || !value[2].is_ascii_digit()
        {
            return CheckResult::InvalidPseudoHeaderValue;
        }
    } else if bytes_eq(name, b":protocol") {
        return check_upgrade_token(value);
    }
    CheckResult::Ok
}
/// Validates an upgrade token of the form `name` or `name/version`.
///
/// Both parts must be non-empty runs of tchar. Because `/` is not a tchar, at
/// most one `/` can appear. Returns `CheckResult::Ok` on success and
/// `CheckResult::InvalidPseudoHeaderValue` otherwise, including for an empty
/// value.
const fn check_upgrade_token(value: &[u8]) -> CheckResult {
    /// True when every byte of `bytes[start..end]` is a tchar.
    const fn all_tchar(bytes: &[u8], start: usize, end: usize) -> bool {
        let mut k = start;
        while k < end {
            if !is_tchar(bytes[k]) {
                return false;
            }
            k += 1;
        }
        true
    }

    // Position of the first '/', if there is one.
    let mut slash = None;
    let mut k = 0;
    while k < value.len() {
        if value[k] == b'/' {
            slash = Some(k);
            break;
        }
        k += 1;
    }

    let well_formed = match slash {
        None => !value.is_empty() && all_tchar(value, 0, value.len()),
        Some(at) => {
            // Both the name (before '/') and the version (after it) must be non-empty.
            at > 0
                && at + 1 < value.len()
                && all_tchar(value, 0, at)
                && all_tchar(value, at + 1, value.len())
        }
    };

    if well_formed {
        CheckResult::Ok
    } else {
        CheckResult::InvalidPseudoHeaderValue
    }
}
/// A single header field: a name and a value, both stored as raw bytes.
///
/// Values built through `Header::new` or `Header::from_static` have passed
/// `check_header`. The crate-private `from_validated_parts_internal` performs
/// no check, so its callers must supply valid parts themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
// Owned when built by `new`, borrowed `'static` data when built by `from_static`.
name: Cow<'static, [u8]>,
// Same ownership rules as `name`.
value: Cow<'static, [u8]>,
}
impl Header {
pub fn new(name: impl AsRef<[u8]>, value: impl AsRef<[u8]>) -> Result<Self, HeaderError> {
let name = name.as_ref();
let value = value.as_ref();
match check_header(name, value) {
CheckResult::Ok => Ok(Self {
name: Cow::Owned(name.to_vec()),
value: Cow::Owned(value.to_vec()),
}),
CheckResult::EmptyFieldName => Err(HeaderError::EmptyFieldName),
CheckResult::UppercaseFieldName => Err(HeaderError::UppercaseFieldName {
name: name.to_vec(),
}),
CheckResult::InvalidFieldNameByte(byte) => Err(HeaderError::InvalidFieldNameByte {
name: name.to_vec(),
byte,
}),
CheckResult::InvalidFieldValueByte(byte) => Err(HeaderError::InvalidFieldValueByte {
name: name.to_vec(),
byte,
}),
CheckResult::FieldValueLeadingOrTrailingWhitespace => {
Err(HeaderError::FieldValueLeadingOrTrailingWhitespace {
name: name.to_vec(),
})
}
CheckResult::UnknownPseudoHeader => Err(HeaderError::UnknownPseudoHeader {
name: name.to_vec(),
}),
CheckResult::InvalidPseudoHeaderValue => Err(HeaderError::InvalidPseudoHeaderValue {
name: name.to_vec(),
value: value.to_vec(),
}),
}
}
#[track_caller]
pub const fn from_static(name: &'static [u8], value: &'static [u8]) -> Self {
match check_header(name, value) {
CheckResult::Ok => {}
CheckResult::EmptyFieldName => {
panic!("Header::from_static: field name must not be empty");
}
CheckResult::UppercaseFieldName => {
panic!("Header::from_static: field name must be lowercase");
}
CheckResult::InvalidFieldNameByte(_) => {
panic!("Header::from_static: field name contains invalid byte");
}
CheckResult::InvalidFieldValueByte(_) => {
panic!("Header::from_static: field value contains invalid byte");
}
CheckResult::FieldValueLeadingOrTrailingWhitespace => {
panic!("Header::from_static: field value must not start or end with whitespace");
}
CheckResult::UnknownPseudoHeader => {
panic!("Header::from_static: unknown pseudo header");
}
CheckResult::InvalidPseudoHeaderValue => {
panic!("Header::from_static: invalid pseudo header value");
}
}
Self {
name: Cow::Borrowed(name),
value: Cow::Borrowed(value),
}
}
pub(crate) fn from_validated_parts_internal(
name: Cow<'static, [u8]>,
value: Cow<'static, [u8]>,
) -> Self {
Self { name, value }
}
#[inline]
pub fn name(&self) -> &[u8] {
&self.name
}
#[inline]
pub fn value(&self) -> &[u8] {
&self.value
}
#[inline]
pub fn size(&self) -> usize {
self.name.len() + self.value.len() + 32
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for header validation, construction and error rendering.
    use super::*;

    #[test]
    fn test_new_simple_header() {
        let h = Header::new(b":method", b"GET").unwrap();
        assert_eq!(h.name(), b":method");
        assert_eq!(h.value(), b"GET");
    }

    #[test]
    fn test_new_empty_value_allowed() {
        let h = Header::new(b":authority", b"").unwrap();
        assert_eq!(h.value(), b"");
    }

    #[test]
    fn test_new_uppercase_field_name_rejected() {
        assert!(matches!(
            Header::new(b"Host", b"example.com"),
            Err(HeaderError::UppercaseFieldName { .. })
        ));
    }

    #[test]
    fn test_new_invalid_field_name_byte_rejected() {
        assert!(matches!(
            Header::new(b"x hdr", b"v"),
            Err(HeaderError::InvalidFieldNameByte { .. })
        ));
    }

    #[test]
    fn test_new_empty_field_name_rejected() {
        assert!(matches!(
            Header::new(b"", b"v"),
            Err(HeaderError::EmptyFieldName)
        ));
    }

    #[test]
    fn test_new_field_value_with_cr_rejected() {
        assert!(matches!(
            Header::new(b":path", b"/foo\r\nX-Inject: 1"),
            Err(HeaderError::InvalidFieldValueByte { byte: 0x0d, .. })
        ));
    }

    #[test]
    fn test_new_field_value_with_lf_rejected() {
        assert!(matches!(
            Header::new(b"x-h", b"v\nv"),
            Err(HeaderError::InvalidFieldValueByte { byte: 0x0a, .. })
        ));
    }

    #[test]
    fn test_new_field_value_with_nul_rejected() {
        assert!(matches!(
            Header::new(b"x-h", b"v\0v"),
            Err(HeaderError::InvalidFieldValueByte { byte: 0x00, .. })
        ));
    }

    #[test]
    fn test_new_field_value_leading_space_rejected() {
        assert!(matches!(
            Header::new(b"x-h", b" v"),
            Err(HeaderError::FieldValueLeadingOrTrailingWhitespace { .. })
        ));
    }

    #[test]
    fn test_new_field_value_trailing_tab_rejected() {
        assert!(matches!(
            Header::new(b"x-h", b"v\t"),
            Err(HeaderError::FieldValueLeadingOrTrailingWhitespace { .. })
        ));
    }

    #[test]
    fn test_new_unknown_pseudo_header_rejected() {
        assert!(matches!(
            Header::new(b":unknown", b"value"),
            Err(HeaderError::UnknownPseudoHeader { .. })
        ));
    }

    #[test]
    fn test_new_invalid_method_rejected() {
        assert!(matches!(
            Header::new(b":method", b"GET POST"),
            Err(HeaderError::InvalidPseudoHeaderValue { .. })
        ));
    }

    #[test]
    fn test_new_invalid_scheme_rejected() {
        assert!(matches!(
            Header::new(b":scheme", b"1http"),
            Err(HeaderError::InvalidPseudoHeaderValue { .. })
        ));
    }

    #[test]
    fn test_new_invalid_status_rejected() {
        assert!(matches!(
            Header::new(b":status", b"20"),
            Err(HeaderError::InvalidPseudoHeaderValue { .. })
        ));
        assert!(matches!(
            Header::new(b":status", b"abc"),
            Err(HeaderError::InvalidPseudoHeaderValue { .. })
        ));
    }

    #[test]
    fn test_from_static_ok() {
        const H: Header = Header::from_static(b":method", b"GET");
        assert_eq!(H.name(), b":method");
        assert_eq!(H.value(), b"GET");
    }

    #[test]
    fn test_size() {
        let h = Header::new(b":method", b"GET").unwrap();
        assert_eq!(h.size(), 7 + 3 + 32);
    }

    #[test]
    fn test_from_validated_parts_internal_skips_validation() {
        let h = Header::from_validated_parts_internal(
            Cow::Borrowed(b"Host"),
            Cow::Borrowed(b"example"),
        );
        assert_eq!(h.name(), b"Host");
        assert!(matches!(
            Header::new(h.name(), h.value()),
            Err(HeaderError::UppercaseFieldName { .. })
        ));
    }

    #[test]
    fn test_display_escapes_crlf_in_name() {
        let err = Header::new(b"x\r\nzz", b"v").unwrap_err();
        let msg = format!("{err}");
        assert!(!msg.contains('\r'), "CR must be escaped: {msg}");
        assert!(!msg.contains('\n'), "LF must be escaped: {msg}");
        assert!(msg.contains("\\r"), "should contain escaped \\r: {msg}");
        assert!(msg.contains("\\n"), "should contain escaped \\n: {msg}");
    }

    #[test]
    fn test_display_truncates_long_name() {
        let name = vec![b'A'; 65];
        let err = Header::new(&name, b"v").unwrap_err();
        let msg = format!("{err}");
        assert!(msg.ends_with("..."), "long name must be truncated: {msg}");
    }

    #[test]
    fn test_display_does_not_truncate_at_64_bytes() {
        let name = vec![b'A'; 64];
        let err = Header::new(&name, b"v").unwrap_err();
        let msg = format!("{err}");
        assert!(
            !msg.ends_with("..."),
            "64 bytes name must not be truncated: {msg}"
        );
    }

    // ---- Additional coverage for paths the tests above do not reach. ----

    #[test]
    fn test_new_protocol_accepts_upgrade_tokens() {
        let accepted: [&[u8]; 4] = [b"websocket", b"h2c", b"HTTP/2.0", b"a/b"];
        for &token in &accepted {
            assert!(
                Header::new(b":protocol", token).is_ok(),
                "{} should be accepted",
                String::from_utf8_lossy(token)
            );
        }
    }

    #[test]
    fn test_new_protocol_rejects_malformed_upgrade_tokens() {
        let rejected: [&[u8]; 5] = [b"", b"/b", b"a/", b"a/b/c", b"web socket"];
        for &token in &rejected {
            assert!(
                matches!(
                    Header::new(b":protocol", token),
                    Err(HeaderError::InvalidPseudoHeaderValue { .. })
                ),
                "{:?} should be rejected",
                String::from_utf8_lossy(token)
            );
        }
    }

    #[test]
    fn test_new_scheme_accepts_uri_schemes() {
        assert!(Header::new(b":scheme", b"https").is_ok());
        assert!(Header::new(b":scheme", b"svn+ssh").is_ok());
        assert!(Header::new(b":scheme", b"coap.tcp-x1").is_ok());
    }

    #[test]
    fn test_new_method_and_scheme_reject_empty_value() {
        assert!(matches!(
            Header::new(b":method", b""),
            Err(HeaderError::InvalidPseudoHeaderValue { .. })
        ));
        assert!(matches!(
            Header::new(b":scheme", b""),
            Err(HeaderError::InvalidPseudoHeaderValue { .. })
        ));
    }

    #[test]
    fn test_new_status_accepts_common_codes() {
        for code in [&b"200"[..], &b"404"[..], &b"599"[..]] {
            assert!(Header::new(b":status", code).is_ok());
        }
    }

    #[test]
    fn test_new_bare_colon_is_unknown_pseudo_header() {
        assert!(matches!(
            Header::new(b":", b"v"),
            Err(HeaderError::UnknownPseudoHeader { .. })
        ));
    }

    #[test]
    fn test_new_uppercase_pseudo_header_rejected() {
        assert!(matches!(
            Header::new(b":Method", b"GET"),
            Err(HeaderError::UppercaseFieldName { .. })
        ));
    }

    #[test]
    fn test_new_double_colon_name_rejected() {
        assert!(matches!(
            Header::new(b"::path", b"/"),
            Err(HeaderError::InvalidFieldNameByte { byte: b':', .. })
        ));
    }

    #[test]
    fn test_new_field_value_obs_text_allowed() {
        let h = Header::new(b"x-h", b"caf\xc3\xa9").unwrap();
        assert_eq!(h.value(), b"caf\xc3\xa9");
    }

    #[test]
    fn test_new_field_value_inner_whitespace_allowed() {
        assert!(Header::new(b"x-h", b"a b\tc").is_ok());
    }

    #[test]
    fn test_new_field_value_del_rejected() {
        assert!(matches!(
            Header::new(b"x-h", b"a\x7fb"),
            Err(HeaderError::InvalidFieldValueByte { byte: 0x7f, .. })
        ));
    }

    #[test]
    fn test_new_field_value_single_space_rejected() {
        assert!(matches!(
            Header::new(b"x-h", b" "),
            Err(HeaderError::FieldValueLeadingOrTrailingWhitespace { .. })
        ));
    }

    #[test]
    fn test_from_static_equals_new() {
        assert_eq!(
            Header::new(b"x-h", b"v").unwrap(),
            Header::from_static(b"x-h", b"v")
        );
    }

    #[test]
    #[should_panic(expected = "Header::from_static: field name must be lowercase")]
    fn test_from_static_panics_on_invalid_header() {
        let _ = Header::from_static(b"Host", b"v");
    }

    #[test]
    fn test_display_messages() {
        assert_eq!(
            HeaderError::EmptyFieldName.to_string(),
            "field name must not be empty"
        );
        assert_eq!(
            Header::new(b"Host", b"v").unwrap_err().to_string(),
            "field name must be lowercase: Host"
        );
        assert_eq!(
            Header::new(b"x hdr", b"v").unwrap_err().to_string(),
            "field name contains invalid byte 0x20: x hdr"
        );
        assert_eq!(
            Header::new(b"x-h", b"v\nv").unwrap_err().to_string(),
            "field value contains invalid byte 0x0a (name: x-h)"
        );
        assert_eq!(
            Header::new(b"x-h", b" v").unwrap_err().to_string(),
            "field value must not start or end with whitespace (name: x-h)"
        );
        assert_eq!(
            Header::new(b":foo", b"v").unwrap_err().to_string(),
            "unknown pseudo header: :foo"
        );
        assert_eq!(
            Header::new(b":method", b"GET POST").unwrap_err().to_string(),
            "invalid pseudo header value for :method: GET POST"
        );
    }

    #[test]
    fn test_header_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(HeaderError::EmptyFieldName);
        assert_eq!(err.to_string(), "field name must not be empty");
    }
}