use std::{
collections::HashMap,
io::{self},
str::FromStr,
};
use http::{Method, Version};
use url::form_urlencoded;
/// Stage the incremental parser is currently in.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ParseState {
    /// Waiting for the `METHOD request-target HTTP-version` line.
    RequestLine,
    /// Reading `Name: value` header lines up to the blank line that ends them.
    Headers,
    /// Reading the `Content-Length`-delimited body.
    Body,
    /// The whole request has been parsed.
    Complete,
    /// Parsing failed; `parse` reports this message as an `InvalidData` error.
    Error(String),
}

impl Default for ParseState {
    /// A fresh parser starts by expecting the request line.
    fn default() -> Self {
        Self::RequestLine
    }
}

/// An HTTP request assembled incrementally from byte chunks via [`HttpRequest::parse`].
#[derive(Debug, Default)]
pub struct HttpRequest {
    method: Method,
    /// Request target exactly as sent on the request line, including any query string.
    path: String,
    version: Version,
    /// Headers keyed by the name as sent (not case-normalized); a repeated name keeps the
    /// last value.
    headers: HashMap<String, String>,
    body: Vec<u8>,
    parse_state: ParseState,
    /// Bytes received but not yet consumed by the current parsing stage.
    buffer: Vec<u8>,
    /// Offset in `buffer` just past the last header line that has already been stored.
    header_index: usize,
    /// Number of body bytes copied into `body` so far.
    body_bytes_read: usize,
}
impl HttpRequest {
pub fn new() -> Self {
Self::default()
}
pub fn parse(&mut self, input: &[u8]) -> io::Result<bool> {
if input.is_empty() {
return Ok(true);
}
self.buffer.extend_from_slice(input);
loop {
match self.parse_state {
ParseState::RequestLine => {
if !self.parse_request_line()? {
return Ok(true);
}
}
ParseState::Headers => {
if !self.parse_headers()? {
return Ok(true);
}
}
ParseState::Body => {
if !self.parse_body()? {
return Ok(true);
}
}
ParseState::Complete => {
return Ok(false);
}
ParseState::Error(ref err) => {
return Err(io::Error::new(io::ErrorKind::InvalidData, err.clone()));
}
}
}
}
fn parse_request_line(&mut self) -> io::Result<bool> {
if let Some(line_end) = find_line_end(&self.buffer) {
let line = &self.buffer[..line_end];
let line_str = std::str::from_utf8(line)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let parts: Vec<&str> = line_str.split_whitespace().collect();
if parts.len() != 3 {
self.parse_state = ParseState::Error("Invalid request line".into());
return Ok(false);
}
self.method = Method::from_str(parts[0])
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
self.path = parts[1].to_string();
self.version = parse_http_version(parts[2]).unwrap_or_else(|| {
self.parse_state = ParseState::Error("Invalid HTTP version".into());
Version::HTTP_09
});
self.buffer.drain(..line_end + 2);
self.parse_state = ParseState::Headers;
Ok(true)
} else {
Ok(false)
}
}
fn parse_headers(&mut self) -> io::Result<bool> {
while let Some(line_end) = find_line_end(&self.buffer[self.header_index..]) {
let absolute_end = self.header_index + line_end;
if line_end == 0 {
self.buffer.drain(..self.header_index + 2);
self.header_index = 0;
self.parse_state = if self.headers.contains_key("Content-Length") {
ParseState::Body
} else {
ParseState::Complete
};
return Ok(true);
}
let line = &self.buffer[self.header_index..absolute_end];
let line_str = std::str::from_utf8(line)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
if let Some((key, value)) = parse_header_line(line_str) {
self.headers.insert(key.to_string(), value.to_string());
}
self.header_index = absolute_end + 2;
}
if self.buffer.len() - self.header_index > 8192 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Headers too large",
));
}
Ok(false)
}
fn parse_body(&mut self) -> io::Result<bool> {
let content_length: usize = match self.headers.get("Content-Length") {
Some(len) => len.parse().map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "Invalid Content-Length")
})?,
None => {
self.parse_state = ParseState::Error("Missing Content-Length".into());
return Ok(false);
}
};
let bytes_remaining = content_length.saturating_sub(self.body_bytes_read);
let bytes_available = self.buffer.len();
if bytes_remaining > 0 && bytes_available > 0 {
let bytes_to_read = bytes_remaining.min(bytes_available);
self.body.extend_from_slice(&self.buffer[..bytes_to_read]);
self.buffer.drain(..bytes_to_read);
self.body_bytes_read += bytes_to_read;
if self.body_bytes_read >= content_length {
self.parse_state = ParseState::Complete;
}
}
Ok(self.parse_state == ParseState::Complete)
}
pub fn method(&self) -> &Method {
&self.method
}
pub fn path(&self) -> &str {
&self.path
}
pub fn version(&self) -> &Version {
&self.version
}
pub fn headers(&self) -> &HashMap<String, String> {
&self.headers
}
pub fn body(&self) -> &[u8] {
&self.body
}
pub fn is_complete(&self) -> bool {
matches!(self.parse_state, ParseState::Complete)
}
pub fn query_params(&self) -> HashMap<String, Vec<String>> {
let mut params: HashMap<String, Vec<String>> = HashMap::new();
if let Some(query_start) = self.path.find('?') {
let query_string = &self.path[query_start + 1..];
for (key, value) in form_urlencoded::parse(query_string.as_bytes()) {
params
.entry(key.into_owned())
.or_default()
.push(value.into_owned());
}
}
params
}
}
/// Returns the textual form of `version` (for example `"HTTP/1.1"`), or `"Unknown"` for any
/// version not listed in the table below.
pub fn http_version_to_string(version: &Version) -> &'static str {
    const LABELS: [(Version, &str); 5] = [
        (Version::HTTP_09, "HTTP/0.9"),
        (Version::HTTP_10, "HTTP/1.0"),
        (Version::HTTP_11, "HTTP/1.1"),
        (Version::HTTP_2, "HTTP/2"),
        (Version::HTTP_3, "HTTP/3"),
    ];
    LABELS
        .iter()
        .find(|(candidate, _)| candidate == version)
        .map_or("Unknown", |&(_, label)| label)
}
/// Parses the HTTP-version token of a request line, ignoring letter case.
///
/// Besides the canonical spellings, `HTTP/2.0` and `HTTP/3.0` are accepted as aliases of
/// `HTTP/2` and `HTTP/3`. Any other token yields `None`.
fn parse_http_version(version_str: &str) -> Option<Version> {
    let normalized = version_str.to_uppercase();
    let version = match normalized.as_str() {
        "HTTP/0.9" => Version::HTTP_09,
        "HTTP/1.0" => Version::HTTP_10,
        "HTTP/1.1" => Version::HTTP_11,
        "HTTP/2" | "HTTP/2.0" => Version::HTTP_2,
        "HTTP/3" | "HTTP/3.0" => Version::HTTP_3,
        _ => return None,
    };
    Some(version)
}
/// Returns the index of the first `\r` that is immediately followed by `\n`. This equals the
/// length of the first CRLF-terminated line in `buf`. Returns `None` when `buf` contains no
/// CRLF.
fn find_line_end(buf: &[u8]) -> Option<usize> {
    let successors = buf.iter().skip(1);
    buf.iter()
        .zip(successors)
        .position(|(&current, &next)| current == b'\r' && next == b'\n')
}
/// Splits a `Name: value` header line at its first `:` and trims surrounding whitespace from
/// both halves. Returns `None` when the line contains no `:`.
fn parse_header_line(line: &str) -> Option<(&str, &str)> {
    let colon = line.find(':')?;
    let name = line[..colon].trim();
    let value = line[colon + 1..].trim();
    Some((name, value))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_request_line() {
        let mut request = HttpRequest::new();
        let input = b"GET /index.html HTTP/1.1\r\n";
        assert!(request.parse(input).unwrap());
        assert_eq!(*request.method(), Method::GET);
        assert_eq!(request.path(), "/index.html");
        assert_eq!(*request.version(), Version::HTTP_11);
    }

    #[test]
    fn test_parse_headers() {
        let mut request = HttpRequest::new();
        let input = b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 5\r\n\r\n";
        assert!(request.parse(input).unwrap());
        assert_eq!(request.headers().get("Host").unwrap(), "example.com");
        assert_eq!(request.headers().get("Content-Length").unwrap(), "5");
    }

    #[test]
    fn test_parse_body() {
        let mut request = HttpRequest::new();
        let part1 = b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\nHell";
        let part2 = b"o";
        assert!(request.parse(part1).unwrap());
        assert!(!request.is_complete());
        assert!(!request.parse(part2).unwrap());
        assert!(request.is_complete());
        assert_eq!(request.body(), b"Hello");
    }

    /// Header values are trimmed, and lines without a `:` are skipped rather than rejected.
    #[test]
    fn test_header_trimming_and_malformed_lines() {
        let mut request = HttpRequest::new();
        let input = b"GET / HTTP/1.1\r\nX-Test:   padded  \r\nno-colon-line\r\n\r\n";
        assert!(!request.parse(input).unwrap());
        assert!(request.is_complete());
        assert_eq!(request.headers().len(), 1);
        assert_eq!(request.headers().get("X-Test").unwrap(), "padded");
    }

    /// Repeated keys collect every value in order, and percent-encoding is decoded.
    #[test]
    fn test_query_params() {
        let mut request = HttpRequest::new();
        let input = b"GET /search?q=rust&tag=a&tag=b%20c HTTP/1.1\r\n\r\n";
        assert!(!request.parse(input).unwrap());
        let params = request.query_params();
        assert_eq!(params.len(), 2);
        assert_eq!(params["q"], vec!["rust".to_string()]);
        assert_eq!(params["tag"], vec!["a".to_string(), "b c".to_string()]);

        let mut plain = HttpRequest::new();
        assert!(!plain.parse(b"GET /index.html HTTP/1.1\r\n\r\n").unwrap());
        assert!(plain.query_params().is_empty());
    }

    #[test]
    fn test_http_version_helpers() {
        assert_eq!(parse_http_version("http/1.0"), Some(Version::HTTP_10));
        assert_eq!(parse_http_version("HTTP/2.0"), Some(Version::HTTP_2));
        assert_eq!(parse_http_version("HTTP/4"), None);
        assert_eq!(http_version_to_string(&Version::HTTP_11), "HTTP/1.1");
        assert_eq!(http_version_to_string(&Version::HTTP_3), "HTTP/3");
    }
}