use crate::error::ServerError;
use std::fmt;
use std::io::{self, BufRead, BufReader};
use std::net::TcpStream;
use std::time::Duration;
/// Upper bound, in bytes, on the raw request line; the comparison in
/// `Request::from_stream` is made before the line terminator is trimmed.
const MAX_REQUEST_LINE_LENGTH: usize = 8190;
/// Number of whitespace-separated parts expected in a request line
/// (method, path, version); only used when formatting the error message.
const REQUEST_PARTS: usize = 3;
/// Read timeout, in seconds, installed on the stream before parsing starts.
const TIMEOUT_SECONDS: u64 = 30;
/// Maximum number of header lines accepted for one request.
const MAX_HEADER_COUNT: usize = 100;
/// Upper bound, in bytes, on a single header line after trailing whitespace
/// (including the line terminator) has been trimmed.
const MAX_HEADER_LINE_LENGTH: usize = 8192;
/// Upper bound, in bytes, on the whole header section as read from the
/// stream (line terminators and the closing blank line are counted).
const MAX_HEADER_BYTES: usize = 64 * 1024;
/// Converts the I/O error returned while configuring the read timeout into
/// the crate's invalid-request error, keeping the original error text.
fn map_timeout_error(error: io::Error) -> ServerError {
    let message = format!("Failed to set read timeout: {}", error);
    ServerError::invalid_request(message)
}
/// Converts an I/O error raised while reading from the stream into the
/// crate's invalid-request error, keeping the original error text.
fn map_read_error(error: io::Error) -> ServerError {
    let message = format!("Failed to read request line: {}", error);
    ServerError::invalid_request(message)
}
/// The parsed head of an HTTP request: the request line plus its headers.
///
/// Values built by `Request::from_stream` have already passed that
/// function's validation; the fields are public, so values constructed by
/// hand carry no such guarantee.
#[doc(alias = "http request")]
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
/// Method token exactly as it appeared on the request line (validation is
/// case-insensitive and the original casing is kept).
pub method: String,
/// Request target as received; `from_stream` only accepts targets that
/// start with `/`, or `*` when the method is `OPTIONS`.
pub path: String,
/// Protocol version token as received; `from_stream` accepts `HTTP/1.0`
/// and `HTTP/1.1`, compared case-insensitively.
pub version: String,
/// Header `(name, value)` pairs in the order they were read. `from_stream`
/// stores names trimmed and lower-cased, and values trimmed.
pub headers: Vec<(String, String)>,
}
impl Request {
#[doc(alias = "parse")]
#[doc(alias = "from tcp")]
pub fn from_stream(
stream: &TcpStream,
) -> Result<Self, ServerError> {
stream
.set_read_timeout(Some(Duration::from_secs(
TIMEOUT_SECONDS,
)))
.map_err(map_timeout_error)?;
let mut buf_reader = BufReader::new(stream);
let mut request_line = String::new();
let _ = buf_reader
.read_line(&mut request_line)
.map_err(map_read_error)?;
let trimmed_request_line = request_line.trim_end();
if request_line.len() > MAX_REQUEST_LINE_LENGTH {
return Err(ServerError::invalid_request(format!(
"Request line too long: {} characters (max {})",
request_line.len(),
MAX_REQUEST_LINE_LENGTH
)));
}
let mut parts = trimmed_request_line.split_whitespace();
let Some(method_part) = parts.next() else {
return Err(ServerError::invalid_request(
"Invalid request line: missing method",
));
};
let Some(path_part) = parts.next() else {
return Err(ServerError::invalid_request(
"Invalid request line: missing path",
));
};
let Some(version_part) = parts.next() else {
return Err(ServerError::invalid_request(
"Invalid request line: missing HTTP version",
));
};
if parts.next().is_some() {
return Err(ServerError::invalid_request(format!(
"Invalid request line: expected {} parts",
REQUEST_PARTS
)));
}
let method = method_part.to_string();
if !Self::is_valid_method(&method) {
return Err(ServerError::invalid_request(format!(
"Invalid HTTP method: {}",
method
)));
}
let path = path_part.to_string();
let is_options_asterisk =
method.eq_ignore_ascii_case("OPTIONS") && path == "*";
if !path.starts_with('/') && !is_options_asterisk {
return Err(ServerError::invalid_request(
"Invalid path: must start with '/' (or be '*' for OPTIONS)",
));
}
let version = version_part.to_string();
if !Self::is_valid_version(&version) {
return Err(ServerError::invalid_request(format!(
"Invalid HTTP version: {}",
version
)));
}
let headers = Self::read_headers(&mut buf_reader)?;
Ok(Request {
method,
path,
version,
headers,
})
}
pub fn method(&self) -> &str {
&self.method
}
pub fn path(&self) -> &str {
&self.path
}
pub fn version(&self) -> &str {
&self.version
}
#[doc(alias = "header lookup")]
pub fn header(&self, name: &str) -> Option<&str> {
self.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case(name))
.map(|(_, v)| v.as_str())
}
pub fn headers(&self) -> &[(String, String)] {
&self.headers
}
fn is_valid_method(method: &str) -> bool {
matches!(
method.to_ascii_uppercase().as_str(),
"GET"
| "POST"
| "PUT"
| "DELETE"
| "HEAD"
| "OPTIONS"
| "PATCH"
)
}
fn is_valid_version(version: &str) -> bool {
version.eq_ignore_ascii_case("HTTP/1.0")
|| version.eq_ignore_ascii_case("HTTP/1.1")
}
fn read_headers<R: BufRead>(
reader: &mut R,
) -> Result<Vec<(String, String)>, ServerError> {
let mut headers: Vec<(String, String)> = Vec::with_capacity(16);
let mut total_bytes = 0_usize;
let mut line = String::new();
loop {
line.clear();
let bytes =
reader.read_line(&mut line).map_err(map_read_error)?;
if bytes == 0 {
break;
}
total_bytes = total_bytes.saturating_add(bytes);
if total_bytes > MAX_HEADER_BYTES {
return Err(ServerError::invalid_request(
"Header section too large",
));
}
let trimmed = line.trim_end();
if trimmed.is_empty() {
break;
}
if trimmed.len() > MAX_HEADER_LINE_LENGTH {
return Err(ServerError::invalid_request(
"Header line too long",
));
}
let bytes = trimmed.as_bytes();
let colon =
memchr::memchr(b':', bytes).ok_or_else(|| {
ServerError::invalid_request(
"Malformed header line",
)
})?;
let (name, value) = trimmed.split_at(colon);
let value = &value[1..];
if headers.len() >= MAX_HEADER_COUNT {
return Err(ServerError::invalid_request(
"Too many request headers",
));
}
headers.push((
name.trim().to_ascii_lowercase(),
value.trim().to_string(),
));
}
Ok(headers)
}
}
/// Formats the request as its request line: method, path and version
/// separated by single spaces. Headers are not included.
impl fmt::Display for Request {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            method,
            path,
            version,
            ..
        } = self;
        write!(f, "{} {} {}", method, path, version)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Serves `bytes` from a throwaway loopback listener and parses whatever
    /// arrives on the connecting socket with `Request::from_stream`.
    ///
    /// The writer thread closes its socket after writing, so the parser sees
    /// end of input right after the payload.
    fn run_request_bytes(
        bytes: Vec<u8>,
    ) -> Result<Request, ServerError> {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let _ = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            // Best effort: the parser may stop reading (and the test may
            // finish) before the whole payload has been written.
            let _ = stream.write_all(&bytes);
        });
        let stream = TcpStream::connect(addr).unwrap();
        Request::from_stream(&stream)
    }

    /// Asserts that `result` is an `InvalidRequest` error.
    fn assert_invalid_request(result: Result<Request, ServerError>) {
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    #[test]
    fn test_valid_request() {
        let request =
            run_request_bytes(b"GET /index.html HTTP/1.1\r\n".to_vec())
                .unwrap();
        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
        assert!(request.headers().is_empty());
    }

    #[test]
    fn test_invalid_method() {
        assert_invalid_request(run_request_bytes(
            b"INVALID /index.html HTTP/1.1\r\n".to_vec(),
        ));
    }

    #[test]
    fn test_max_length_request() {
        // "GET " + path + " HTTP/1.1\r\n" adds 15 bytes, so the raw line is
        // one byte below the limit.
        let long_path = "/".repeat(MAX_REQUEST_LINE_LENGTH - 16);
        let request = format!("GET {} HTTP/1.1\r\n", long_path);
        let result = run_request_bytes(request.into_bytes());
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().path().len(),
            MAX_REQUEST_LINE_LENGTH - 16
        );
    }

    #[test]
    fn test_oversized_request() {
        // Same framing as above, but the raw line ends up two bytes over.
        let long_path = "/".repeat(MAX_REQUEST_LINE_LENGTH - 13);
        let request = format!("GET {} HTTP/1.1\r\n", long_path);
        let result = run_request_bytes(request.into_bytes());
        assert!(
            result.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Request line too long:"),
            "Unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn test_unterminated_oversized_request_line_returns_error() {
        // No line terminator at all: the parser must still give up with a
        // length error.
        let payload = vec![b'A'; MAX_REQUEST_LINE_LENGTH * 2];
        let err = run_request_bytes(payload).unwrap_err();
        assert!(
            err.to_string().contains("Request line too long:"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_invalid_request(run_request_bytes(
            b"GET index.html HTTP/1.1\r\n".to_vec(),
        ));
    }

    #[test]
    fn test_invalid_version() {
        assert_invalid_request(run_request_bytes(
            b"GET /index.html HTTP/2.0\r\n".to_vec(),
        ));
    }

    #[test]
    fn test_head_request() {
        let request =
            run_request_bytes(b"HEAD /index.html HTTP/1.1\r\n".to_vec())
                .unwrap();
        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_options_request() {
        let request =
            run_request_bytes(b"OPTIONS * HTTP/1.1\r\n".to_vec()).unwrap();
        assert_eq!(request.method(), "OPTIONS");
        assert_eq!(request.path(), "*");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_internal_error_mapping_helpers() {
        let timeout_err =
            io::Error::new(io::ErrorKind::TimedOut, "timeout");
        let mapped = map_timeout_error(timeout_err);
        assert!(
            mapped.to_string().contains("Failed to set read timeout")
        );
        let read_err =
            io::Error::new(io::ErrorKind::UnexpectedEof, "eof");
        let mapped = map_read_error(read_err);
        assert!(
            mapped.to_string().contains("Failed to read request line")
        );
    }

    #[test]
    fn test_parses_headers() {
        let request = run_request_bytes(
            b"GET /index.html HTTP/1.1\r\nHost: localhost\r\nRange: bytes=0-1\r\n\r\n"
                .to_vec(),
        )
        .unwrap();
        assert_eq!(request.header("host"), Some("localhost"));
        assert_eq!(request.header("range"), Some("bytes=0-1"));
    }

    #[test]
    fn test_header_lookup_is_case_insensitive() {
        let request = run_request_bytes(
            b"GET / HTTP/1.1\r\nContent-Type: text/plain\r\n\r\n".to_vec(),
        )
        .unwrap();
        assert_eq!(request.header("CONTENT-TYPE"), Some("text/plain"));
        assert_eq!(request.header("content-type"), Some("text/plain"));
        assert_eq!(request.header("missing"), None);
        // Names are stored lower-cased.
        assert_eq!(request.headers()[0].0, "content-type");
    }

    #[test]
    fn test_missing_method_returns_error() {
        let err = run_request_bytes(b"\r\n".to_vec()).unwrap_err();
        assert!(
            err.to_string().contains("missing method"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_too_many_parts_returns_error() {
        let err =
            run_request_bytes(b"GET / HTTP/1.1 extra\r\n".to_vec())
                .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("expected") && msg.contains("parts"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_malformed_header_returns_error() {
        let err = run_request_bytes(
            b"GET / HTTP/1.1\r\nmissing-colon-line\r\n\r\n".to_vec(),
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("Malformed header line"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_header_line_too_long_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\nX: ");
        req.extend(std::iter::repeat_n(b'A', MAX_HEADER_LINE_LENGTH));
        req.extend_from_slice(b"\r\n\r\n");
        let err = run_request_bytes(req).unwrap_err();
        assert!(
            err.to_string().contains("Header line too long"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_header_section_too_large_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\n");
        // Each line stays under the per-line limit; ten of them together
        // exceed the section budget.
        let filler: String = "A".repeat(8000);
        for i in 0..10 {
            req.extend_from_slice(
                format!("H{i}: {filler}\r\n").as_bytes(),
            );
        }
        req.extend_from_slice(b"\r\n");
        let err = run_request_bytes(req).unwrap_err();
        assert!(
            err.to_string().contains("Header section too large"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_too_many_headers_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\n");
        for i in 0..=MAX_HEADER_COUNT {
            req.extend_from_slice(format!("H{i}: v\r\n").as_bytes());
        }
        req.extend_from_slice(b"\r\n");
        let err = run_request_bytes(req).unwrap_err();
        assert!(
            err.to_string().contains("Too many request headers"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_missing_http_version_returns_error() {
        let err = run_request_bytes(b"GET /\r\n".to_vec()).unwrap_err();
        assert!(
            err.to_string().contains("missing HTTP version"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_request_display_formats_method_path_version() {
        let request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: Vec::new(),
        };
        assert_eq!(format!("{request}"), "GET /index.html HTTP/1.1");
    }
}