use memchr::memmem;
#[derive(thiserror::Error, PartialEq, Debug)]
pub enum Error {
#[error("the http head was not complete")]
IncompleteHead,
#[error("the http head contained non-ascii characters")]
NonAscii,
#[error("missing http header: {0}")]
MissingHeader(&'static str),
#[error("malformed http header")]
MalformedHeader,
#[error("invalid http header value: {0}")]
InvalidHeaderValue(&'static str),
}
pub trait HeadParser<'de>: Sized {
fn parse_head_section(head: &'de str) -> Result<Self, Error>;
fn parse_headers(head_and_body: &'de [u8]) -> Result<(Self, &'de [u8]), Error> {
let head_end = memmem::find(head_and_body, b"\r\n\r\n").ok_or(Error::IncompleteHead)?;
let head_bytes = &head_and_body[..head_end];
if !head_bytes.is_ascii() {
return Err(Error::NonAscii);
}
let head = unsafe { std::str::from_utf8_unchecked(head_bytes) };
let headers = Self::parse_head_section(head)?;
let body = &head_and_body[head_end + 4..];
Ok((headers, body))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`HeadParser`] implementation used to exercise `parse_headers`.
    /// It only extracts `Content-Length`.
    #[derive(Debug, PartialEq)]
    pub struct SimpleHeaders {
        pub content_length: usize,
    }

    impl<'de> HeadParser<'de> for SimpleHeaders {
        fn parse_head_section(head: &'de str) -> Result<Self, Error> {
            // HTTP field names are case-insensitive, so compare with
            // `eq_ignore_ascii_case`. Lines without a `:` (e.g. a request line) are
            // not header fields and are skipped. `split_once` keeps the whole value,
            // including any further `:` characters.
            let value = head
                .split("\r\n")
                .filter_map(|line| line.split_once(':'))
                .find(|(name, _)| name.eq_ignore_ascii_case("Content-Length"))
                .map(|(_, value)| value.trim())
                .ok_or(Error::MissingHeader("Content-Length"))?;
            let content_length = value
                .parse::<usize>()
                .map_err(|_| Error::InvalidHeaderValue("Content-Length"))?;
            Ok(SimpleHeaders { content_length })
        }
    }

    #[test]
    fn parse_valid_head() {
        let input_head = b"Content-Length: 5\r\nAnother-Header: value\r\n\r\nBodyHere";
        let (headers, body) = SimpleHeaders::parse_headers(input_head).unwrap();
        assert_eq!(headers, SimpleHeaders { content_length: 5 });
        assert_eq!(body, b"BodyHere");
    }

    #[test]
    fn parse_header_name_case_insensitively() {
        let input_head = b"content-length: 7\r\n\r\n";
        let (headers, body) = SimpleHeaders::parse_headers(input_head).unwrap();
        assert_eq!(headers, SimpleHeaders { content_length: 7 });
        assert_eq!(body, b"");
    }

    #[test]
    fn skip_lines_without_colon() {
        let input_head = b"GET / HTTP/1.1\r\nContent-Length: 3\r\n\r\nabc";
        let (headers, body) = SimpleHeaders::parse_headers(input_head).unwrap();
        assert_eq!(headers, SimpleHeaders { content_length: 3 });
        assert_eq!(body, b"abc");
    }

    #[test]
    fn split_at_first_head_terminator() {
        let input_head = b"Content-Length: 4\r\n\r\na\r\n\r\nb";
        let (headers, body) = SimpleHeaders::parse_headers(input_head).unwrap();
        assert_eq!(headers, SimpleHeaders { content_length: 4 });
        assert_eq!(body, b"a\r\n\r\nb");
    }

    #[test]
    fn error_on_non_ascii_head() {
        let input_head = b"Content-Length: 5\r\nNon-Ascii: \x80\x81\x82\r\n\r\nBodyHere";
        let result = SimpleHeaders::parse_headers(input_head);
        assert_eq!(result, Err(Error::NonAscii));
    }

    #[test]
    fn error_on_incomplete_head() {
        let input_head = b"Content-Length: 5\r\nAnother-Header: value\r\nBodyWithoutHeadDelimiter";
        let result = SimpleHeaders::parse_headers(input_head);
        assert_eq!(result, Err(Error::IncompleteHead));
    }

    #[test]
    fn error_on_missing_header() {
        let input_head = b"Wrong-Header: 5\r\nAnother-Header: value\r\n\r\nBodyHere";
        let result = SimpleHeaders::parse_headers(input_head);
        assert_eq!(result, Err(Error::MissingHeader("Content-Length")));
    }

    #[test]
    fn error_on_invalid_header_value() {
        let input_head = b"Content-Length: invalid_value\r\nAnother-Header: value\r\n\r\nBodyHere";
        let result = SimpleHeaders::parse_headers(input_head);
        assert_eq!(result, Err(Error::InvalidHeaderValue("Content-Length")));
    }

    #[test]
    fn error_on_value_containing_colon() {
        let input_head = b"Content-Length: 5:6\r\n\r\n";
        let result = SimpleHeaders::parse_headers(input_head);
        assert_eq!(result, Err(Error::InvalidHeaderValue("Content-Length")));
    }
}