use std::fmt;
/// Maximum number of leading bytes that [`extract_http_host`] scans for the
/// CRLFCRLF header terminator. Input at least this long with no terminator
/// inside the scanned window is reported as `HeaderSectionOverflow`.
pub const MAX_HEADER_SECTION_LEN: usize = 8192;
/// HTTP/2 client connection preface. Input that starts with these bytes makes
/// `extract_http_host` return `Ok(None)` without any HTTP/1.x parsing.
const HTTP2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Reasons [`extract_http_host`] can fail to parse the buffered request bytes.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum HttpParseError {
/// No CRLFCRLF terminator was found in the scanned window and the input is
/// at least [`MAX_HEADER_SECTION_LEN`] bytes long.
HeaderSectionOverflow,
/// The request line is empty, or the input is shorter than the limit and
/// does not contain a CRLFCRLF terminator.
MalformedRequestLine,
/// The request line contains a byte outside printable ASCII (0x20..=0x7E).
InvalidByte,
/// The request line contains fewer than two spaces, i.e. it carries no
/// HTTP version token.
Http09NotSupported,
}
/// Human-readable descriptions for each parse failure.
impl fmt::Display for HttpParseError {
    /// Writes the message for this error variant to `f`.
    ///
    /// Only the overflow message needs formatting (it embeds the byte limit);
    /// the remaining variants map to fixed strings written verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::HeaderSectionOverflow => {
                return write!(
                    f,
                    "HTTP header section exceeded {MAX_HEADER_SECTION_LEN} bytes without CRLFCRLF"
                );
            }
            Self::MalformedRequestLine => "malformed HTTP/1.x request line",
            Self::InvalidByte => "invalid byte in request line",
            Self::Http09NotSupported => "HTTP/0.9 has no Host semantics",
        };
        f.write_str(message)
    }
}
// Empty impl: relies on the trait's default methods so `HttpParseError` can be
// used wherever a `std::error::Error` is expected.
impl std::error::Error for HttpParseError {}
/// Extracts the normalised `Host` header value from the first bytes of an
/// HTTP/1.x request.
///
/// Returns `Ok(None)` when the input starts with the HTTP/2 connection preface
/// or when the header section contains no `Host` header. Otherwise returns the
/// value of the first header whose name matches `Host` (ASCII
/// case-insensitively), passed through `normalise_host_value`.
///
/// Only the first [`MAX_HEADER_SECTION_LEN`] bytes are searched for the
/// CRLFCRLF terminator that ends the header section.
///
/// # Errors
///
/// * [`HttpParseError::HeaderSectionOverflow`] — no terminator in the scanned
///   window and the input is at least [`MAX_HEADER_SECTION_LEN`] bytes long.
/// * [`HttpParseError::MalformedRequestLine`] — no terminator and the input is
///   shorter than the limit, or the request line is empty.
/// * [`HttpParseError::InvalidByte`] — the request line fails
///   `validate_request_line`.
/// * [`HttpParseError::Http09NotSupported`] — the request line contains fewer
///   than two spaces.
pub fn extract_http_host(first_bytes: &[u8]) -> Result<Option<String>, HttpParseError> {
    if first_bytes.starts_with(HTTP2_PREFACE) {
        return Ok(None);
    }

    let total_len = first_bytes.len();
    let window = &first_bytes[..total_len.min(MAX_HEADER_SECTION_LEN)];
    let terminator = match find_crlf_crlf(window) {
        Some(pos) => pos,
        None if total_len >= MAX_HEADER_SECTION_LEN => {
            return Err(HttpParseError::HeaderSectionOverflow);
        }
        None => return Err(HttpParseError::MalformedRequestLine),
    };

    // Keep the first CRLF of the terminator so that every line in `head`,
    // including the last header, is CRLF-terminated.
    let head = &first_bytes[..terminator + 2];

    let request_line_len = find_crlf(head).ok_or(HttpParseError::MalformedRequestLine)?;
    let request_line = &head[..request_line_len];
    validate_request_line(request_line)?;

    // Fewer than two spaces means there is no version token on the line.
    let spaces = request_line.iter().filter(|&&b| b == b' ').count();
    if spaces < 2 {
        return Err(HttpParseError::Http09NotSupported);
    }

    // Walk the header lines by repeatedly slicing off one CRLF-terminated line.
    let mut remaining = &head[request_line_len + 2..];
    while let Some(line_len) = find_crlf(remaining) {
        if line_len == 0 {
            break;
        }
        let line = &remaining[..line_len];
        if let Some(colon) = line.iter().position(|&b| b == b':') {
            let (name, colon_and_value) = line.split_at(colon);
            if eq_ignore_ascii(name, b"Host") {
                // `colon_and_value` starts with the ':' itself; skip it.
                return Ok(Some(normalise_host_value(&colon_and_value[1..])));
            }
        }
        remaining = &remaining[line_len + 2..];
    }

    Ok(None)
}
/// Returns the index of the `\r` of the first CRLF pair in `buf`, or `None`
/// when `buf` contains no CRLF.
fn find_crlf(buf: &[u8]) -> Option<usize> {
    buf.windows(2)
        .position(|pair| matches!(pair, [b'\r', b'\n']))
}
/// Returns the index at which the first CRLFCRLF sequence in `buf` starts, or
/// `None` when `buf` contains no such sequence.
fn find_crlf_crlf(buf: &[u8]) -> Option<usize> {
    buf.windows(4)
        .position(|quad| matches!(quad, [b'\r', b'\n', b'\r', b'\n']))
}
/// Checks that a request line is non-empty and made only of printable ASCII.
///
/// # Errors
///
/// * `HttpParseError::MalformedRequestLine` when `line` is empty.
/// * `HttpParseError::InvalidByte` when any byte lies outside `0x20..=0x7E`;
///   that range already excludes NUL, CR and LF.
fn validate_request_line(line: &[u8]) -> Result<(), HttpParseError> {
    if line.is_empty() {
        return Err(HttpParseError::MalformedRequestLine);
    }
    let all_printable = line.iter().all(|&byte| matches!(byte, 0x20..=0x7E));
    if all_printable {
        Ok(())
    } else {
        Err(HttpParseError::InvalidByte)
    }
}
/// Returns `true` when `a` and `b` have the same length and are equal byte for
/// byte, ignoring ASCII letter case.
///
/// Delegates to the standard library's `<[u8]>::eq_ignore_ascii_case` instead
/// of re-implementing the length check and pairwise comparison by hand.
fn eq_ignore_ascii(a: &[u8], b: &[u8]) -> bool {
    a.eq_ignore_ascii_case(b)
}
/// Normalises a raw `Host` header value into a bare, lowercase host.
///
/// Steps, in order:
/// 1. trim leading and trailing spaces and tabs;
/// 2. for a value starting with `[`, keep everything up to and including the
///    first `]` (or the whole value if there is none); otherwise keep
///    everything before the first `:`;
/// 3. decode as UTF-8 lossily and lowercase ASCII letters;
/// 4. drop a single trailing `.` if present.
fn normalise_host_value(raw: &[u8]) -> String {
    let is_ows = |byte: &u8| matches!(*byte, b' ' | b'\t');

    // An all-whitespace (or empty) value collapses to an empty slice.
    let first = raw.iter().position(|b| !is_ows(b)).unwrap_or(raw.len());
    let past_last = raw
        .iter()
        .rposition(|b| !is_ows(b))
        .map_or(first, |last| last + 1);
    let trimmed = &raw[first..past_last];

    let host_len = if trimmed.first() == Some(&b'[') {
        // Bracketed literal: keep the brackets, drop anything after `]`.
        trimmed
            .iter()
            .position(|&b| b == b']')
            .map_or(trimmed.len(), |close| close + 1)
    } else {
        // Plain host: cut at the port separator when there is one.
        trimmed
            .iter()
            .position(|&b| b == b':')
            .unwrap_or(trimmed.len())
    };

    let mut host = String::from_utf8_lossy(&trimmed[..host_len]).into_owned();
    host.make_ascii_lowercase();
    if host.ends_with('.') {
        host.pop();
    }
    host
}
// Unit tests for `extract_http_host`, covering host extraction and
// normalisation as well as each error variant.
#[cfg(test)]
mod tests {
use super::*;
// A plain GET request yields the Host header value.
#[test]
fn extracts_get_host() {
let req = b"GET / HTTP/1.1\r\nHost: api.example.com\r\nUser-Agent: x\r\n\r\n";
assert_eq!(
extract_http_host(req).unwrap().as_deref(),
Some("api.example.com")
);
}
// The method and request target do not affect Host extraction.
#[test]
fn extracts_post_host() {
let req = b"POST /v1/x HTTP/1.1\r\nHost: api.example.com\r\nContent-Length: 0\r\n\r\n";
assert_eq!(
extract_http_host(req).unwrap().as_deref(),
Some("api.example.com")
);
}
// A well-formed request without a Host header is not an error.
#[test]
fn missing_host_returns_ok_none() {
let req = b"GET / HTTP/1.1\r\nAccept: */*\r\n\r\n";
assert_eq!(extract_http_host(req).unwrap(), None);
}
// The header name is matched regardless of ASCII case.
#[test]
fn case_insensitive_host_header() {
let req = b"GET / HTTP/1.1\r\nhost: api.example.com\r\n\r\n";
assert_eq!(
extract_http_host(req).unwrap().as_deref(),
Some("api.example.com")
);
let req2 = b"GET / HTTP/1.1\r\nHOST: api.example.com\r\n\r\n";
assert_eq!(
extract_http_host(req2).unwrap().as_deref(),
Some("api.example.com")
);
}
// A `:port` suffix is removed from the returned host.
#[test]
fn host_with_port_strips_port() {
let req = b"GET / HTTP/1.1\r\nHost: api.example.com:8443\r\n\r\n";
assert_eq!(
extract_http_host(req).unwrap().as_deref(),
Some("api.example.com")
);
}
// Bracketed IPv6 literals keep their brackets; only the port is dropped.
#[test]
fn ipv6_host_with_port_keeps_brackets() {
let req = b"GET / HTTP/1.1\r\nHost: [::1]:443\r\n\r\n";
assert_eq!(extract_http_host(req).unwrap().as_deref(), Some("[::1]"));
}
// A request line without a version token (fewer than two spaces) is
// reported as `Http09NotSupported`, not `MalformedRequestLine`.
#[test]
fn malformed_request_line_no_version() {
let req = b"GET /\r\nHost: x.example.com\r\n\r\n";
assert!(matches!(
extract_http_host(req),
Err(HttpParseError::Http09NotSupported)
));
}
// Input longer than the limit with no CRLFCRLF terminator overflows.
#[test]
fn oversized_header_section_rejected() {
let mut buf: Vec<u8> = Vec::with_capacity(MAX_HEADER_SECTION_LEN + 16);
buf.extend_from_slice(b"GET / HTTP/1.1\r\n");
buf.extend_from_slice(b"X-Pad: ");
while buf.len() < MAX_HEADER_SECTION_LEN + 8 {
buf.push(b'a');
}
assert_eq!(
extract_http_host(&buf),
Err(HttpParseError::HeaderSectionOverflow)
);
}
// Input starting with the HTTP/2 preface returns `Ok(None)`, even with
// trailing bytes after the preface.
#[test]
fn http2_preface_returns_ok_none() {
let bytes = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00";
assert_eq!(extract_http_host(bytes).unwrap(), None);
}
// A bare CR inside the request line must be rejected; either error
// variant is accepted by this test.
#[test]
fn crlf_injection_in_request_line_rejected() {
let req = b"GET /\rinjected HTTP/1.1\r\nHost: api.example.com\r\n\r\n";
assert!(matches!(
extract_http_host(req),
Err(HttpParseError::InvalidByte) | Err(HttpParseError::MalformedRequestLine)
));
}
// A NUL byte inside the request line is an invalid byte.
#[test]
fn nul_byte_in_request_line_rejected() {
let req = b"GET /\x00 HTTP/1.1\r\nHost: api.example.com\r\n\r\n";
assert!(matches!(
extract_http_host(req),
Err(HttpParseError::InvalidByte)
));
}
// A trailing dot on the host name is removed.
#[test]
fn trailing_dot_in_host_stripped() {
let req = b"GET / HTTP/1.1\r\nHost: api.example.com.\r\n\r\n";
assert_eq!(
extract_http_host(req).unwrap().as_deref(),
Some("api.example.com")
);
}
// The returned host is lowercased.
#[test]
fn host_lowercased() {
let req = b"GET / HTTP/1.1\r\nHost: API.Example.COM\r\n\r\n";
assert_eq!(
extract_http_host(req).unwrap().as_deref(),
Some("api.example.com")
);
}
// Input under the size limit that lacks the CRLFCRLF terminator is
// reported as `MalformedRequestLine`.
#[test]
fn incomplete_request_no_crlfcrlf_is_malformed() {
let req = b"GET / HTTP/1.1\r\nHost: api.example.com\r\n";
assert_eq!(
extract_http_host(req),
Err(HttpParseError::MalformedRequestLine)
);
}
}