use crate::{
error::Error,
header::{
HttpHeader,
headers::{CONTENT_LENGTH, CONTENT_TYPE},
},
method::HttpMethod,
protocol::{self, DOUBLE_CRLF_LEN, MAX_HEADERS},
};
use heapless::Vec;
/// A single `name=value` pair borrowed from a request's query string.
///
/// Both fields are raw slices of the request target; no percent-decoding is
/// applied (use [`percent_decode`] on them if needed).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueryPair<'a> {
    /// Text before the first `=` of the segment (the whole segment if it has no `=`).
    pub name: &'a str,
    /// Text after the first `=`, or `""` when the segment has no `=`.
    pub value: &'a str,
}
/// Iterator over the `&`-separated pairs of a query string.
///
/// Empty segments (e.g. from `a=1&&b=2` or a trailing `&`) are skipped. Once it
/// has returned `None` it keeps returning `None`.
#[derive(Debug, Clone)]
pub struct QueryPairs<'a> {
    /// The not-yet-consumed tail of the query string.
    remaining: &'a str,
}
impl<'a> Iterator for QueryPairs<'a> {
    type Item = QueryPair<'a>;
    fn next(&mut self) -> Option<Self::Item> {
        while !self.remaining.is_empty() {
            let (segment, rest) = self
                .remaining
                .split_once('&')
                .unwrap_or((self.remaining, ""));
            self.remaining = rest;
            if segment.is_empty() {
                continue;
            }
            // Split on the first `=` only, so values may themselves contain `=`.
            let (name, value) = segment.split_once('=').unwrap_or((segment, ""));
            return Some(QueryPair { name, value });
        }
        None
    }
}
// Once `remaining` is empty, `next` can only ever return `None`.
impl core::iter::FusedIterator for QueryPairs<'_> {}
/// Iterator over the values of every query pair whose name equals `name`, in order.
#[derive(Debug, Clone)]
pub struct QueryValues<'a, 'n> {
    pairs: QueryPairs<'a>,
    name: &'n str,
}
impl<'a> Iterator for QueryValues<'a, '_> {
    type Item = &'a str;
    fn next(&mut self) -> Option<Self::Item> {
        let name = self.name;
        self.pairs
            .find(|pair| pair.name == name)
            .map(|pair| pair.value)
    }
}
// `find` over a fused iterator is itself fused.
impl core::iter::FusedIterator for QueryValues<'_, '_> {}
/// A parsed HTTP request whose text fields borrow from the caller's input.
#[derive(Debug)]
pub struct HttpRequest<'a> {
/// Method parsed from the first token of the request line.
pub method: HttpMethod,
/// Raw request target (second token of the request line), including any `?query` suffix.
pub path: &'a str,
/// Protocol version token (third token of the request line); its format is not validated.
pub version: &'a str,
/// Header fields in the order received, bounded by `MAX_HEADERS`.
pub headers: Vec<HttpHeader<'a>, MAX_HEADERS>,
/// Message body as supplied by the caller (via `TryFrom<&[u8]>`: everything after the
/// header terminator). Its length is not checked against `Content-Length`.
pub body: &'a [u8],
}
impl<'a> HttpRequest<'a> {
#[must_use]
pub const fn target(&self) -> &'a str {
self.path
}
#[must_use]
pub fn route_path(&self) -> &'a str {
self.path
.split_once('?')
.map_or(self.path, |(path, _query)| path)
}
#[must_use]
pub fn query_string(&self) -> Option<&'a str> {
self.path.split_once('?').map(|(_path, query)| query)
}
#[must_use]
pub fn query_pairs(&self) -> QueryPairs<'a> {
QueryPairs {
remaining: self.query_string().unwrap_or(""),
}
}
#[must_use]
pub fn query(&self, name: &str) -> Option<&'a str> {
self.query_first(name)
}
#[must_use]
pub fn query_first(&self, name: &str) -> Option<&'a str> {
self.query_pairs()
.find(|pair| pair.name == name)
.map(|pair| pair.value)
}
#[must_use]
pub fn query_last(&self, name: &str) -> Option<&'a str> {
self.query_pairs()
.filter(|pair| pair.name == name)
.map(|pair| pair.value)
.last()
}
#[must_use]
pub fn query_all<'n>(&self, name: &'n str) -> QueryValues<'a, 'n> {
QueryValues {
pairs: self.query_pairs(),
name,
}
}
#[must_use]
pub fn query_indexed(&self, name: &str, index: usize) -> Option<&'a str> {
self.query_pairs()
.find(|pair| query_name_matches_index(pair.name, name, index))
.map(|pair| pair.value)
}
#[must_use]
pub fn query_param(&self, name: &str) -> Option<&'a str> {
self.query(name)
}
#[must_use]
pub fn header(&self, name: &str) -> Option<&'a str> {
self.headers
.iter()
.find(|header| header.name.eq_ignore_ascii_case(name))
.map(|header| header.value)
}
pub fn body_str(&self) -> Result<&'a str, Error> {
core::str::from_utf8(self.body)
.map_err(|_| Error::InvalidResponse("Invalid UTF-8 in request body"))
}
#[must_use]
pub fn content_type(&self) -> Option<&'a str> {
self.header(CONTENT_TYPE)
}
#[must_use]
pub fn content_length(&self) -> Option<usize> {
self.header(CONTENT_LENGTH)?.parse().ok()
}
pub fn parse_from(headers_str: &'a str, body: &'a [u8]) -> Result<Self, Error> {
let mut lines = headers_str.lines();
let request_line = lines
.next()
.ok_or(Error::InvalidResponse("Missing request line"))?;
let mut parts = request_line.split_whitespace();
let method_str = parts
.next()
.ok_or(Error::InvalidResponse("Missing method"))?;
let path = parts.next().ok_or(Error::InvalidResponse("Missing path"))?;
let version = parts
.next()
.ok_or(Error::InvalidResponse("Missing version"))?;
let method = HttpMethod::try_from(method_str)
.map_err(|_| Error::InvalidResponse("Unknown HTTP method"))?;
let mut headers = Vec::new();
for line in lines {
if line.is_empty() {
break;
}
if let Some(colon_pos) = line.find(':') {
let name = line[..colon_pos].trim();
let value = line[colon_pos + 1..].trim();
let header = HttpHeader::new(name, value);
headers
.push(header)
.map_err(|_| Error::InvalidResponse("Too many headers"))?;
}
}
Ok(HttpRequest {
method,
path,
version,
headers,
body,
})
}
}
impl<'a> TryFrom<&'a [u8]> for HttpRequest<'a> {
type Error = Error;
fn try_from(buffer: &'a [u8]) -> Result<Self, Self::Error> {
let end_of_headers = protocol::find_double_crlf(buffer)
.ok_or(Error::InvalidResponse("Incomplete request headers"))?;
let headers_str = core::str::from_utf8(&buffer[..end_of_headers])
.map_err(|_| Error::InvalidResponse("Invalid UTF-8 in request"))?;
let body = &buffer[end_of_headers + DOUBLE_CRLF_LEN..];
Self::parse_from(headers_str, body)
}
}
pub fn percent_decode<'a>(input: &str, out: &'a mut [u8]) -> Result<&'a str, Error> {
let mut written = 0;
let mut bytes = input.as_bytes().iter().copied();
while let Some(byte) = bytes.next() {
let decoded = match byte {
b'+' => b' ',
b'%' => {
let hi = bytes
.next()
.ok_or(Error::InvalidResponse("Incomplete percent escape"))?;
let lo = bytes
.next()
.ok_or(Error::InvalidResponse("Incomplete percent escape"))?;
(hex_value(hi).ok_or(Error::InvalidResponse("Invalid percent escape"))? << 4)
| hex_value(lo).ok_or(Error::InvalidResponse("Invalid percent escape"))?
}
byte => byte,
};
let slot = out.get_mut(written).ok_or(Error::BufferOverflow)?;
*slot = decoded;
written += 1;
}
core::str::from_utf8(&out[..written])
.map_err(|_| Error::InvalidResponse("Invalid UTF-8 in percent-decoded value"))
}
/// Returns `true` when `query_name` is exactly `name` followed by `[<digits>]` whose
/// decimal value equals `index` (e.g. `f[0]` for `name = "f"`, `index = 0`).
///
/// Only plain ASCII digits are accepted inside the brackets. `usize::from_str` would
/// also accept a leading `+`, which would make `f[+1]` alias `f[1]`.
fn query_name_matches_index(query_name: &str, name: &str, index: usize) -> bool {
    let Some(index_str) = query_name
        .strip_prefix(name)
        .and_then(|rest| rest.strip_prefix('['))
        .and_then(|rest| rest.strip_suffix(']'))
    else {
        return false;
    };
    !index_str.is_empty()
        && index_str.bytes().all(|b| b.is_ascii_digit())
        && index_str.parse::<usize>().ok() == Some(index)
}
/// Maps an ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`; any other byte
/// yields `None`.
const fn hex_value(byte: u8) -> Option<u8> {
    if byte.is_ascii_digit() {
        return Some(byte - b'0');
    }
    // Setting bit 5 folds 'A'..='F' onto 'a'..='f'; no other byte lands in that range.
    let folded = byte | 0x20;
    if matches!(folded, b'a'..=b'f') {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpMethod;
    #[test]
    fn test_parse_request_get() {
        let request_str =
            "GET /index.html HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\n";
        let body = b"";
        let request = HttpRequest::parse_from(request_str, body).unwrap();
        assert_eq!(request.method, HttpMethod::GET);
        assert_eq!(request.path, "/index.html");
        assert_eq!(request.version, "HTTP/1.1");
        assert_eq!(request.headers.len(), 2);
        assert_eq!(request.body, b"");
    }
    #[test]
    fn test_request_query_helpers() {
        let request_str =
            "GET /search?q=rust&page=1&flag&a=&a=2 HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let request = HttpRequest::parse_from(request_str, b"").unwrap();
        assert_eq!(request.target(), "/search?q=rust&page=1&flag&a=&a=2");
        assert_eq!(request.route_path(), "/search");
        assert_eq!(request.query_string(), Some("q=rust&page=1&flag&a=&a=2"));
        assert_eq!(request.query("q"), Some("rust"));
        assert_eq!(request.query_param("page"), Some("1"));
        assert_eq!(request.query("flag"), Some(""));
        assert_eq!(request.query_first("a"), Some(""));
        assert_eq!(request.query_last("a"), Some("2"));
        assert_eq!(request.query("missing"), None);
        let mut values = request.query_all("a");
        assert_eq!(values.next(), Some(""));
        assert_eq!(values.next(), Some("2"));
        assert_eq!(values.next(), None);
    }
    #[test]
    fn test_request_query_duplicate_and_bracket_keys() {
        let request_str = "GET /items?a=1&a=2&f[0]=1&f[1]=2 HTTP/1.1\r\n\r\n";
        let request = HttpRequest::parse_from(request_str, b"").unwrap();
        assert_eq!(request.query_first("a"), Some("1"));
        assert_eq!(request.query_last("a"), Some("2"));
        assert_eq!(request.query("f[0]"), Some("1"));
        assert_eq!(request.query("f[1]"), Some("2"));
        assert_eq!(request.query("f"), None);
        assert_eq!(request.query_indexed("f", 0), Some("1"));
        assert_eq!(request.query_indexed("f", 1), Some("2"));
        assert_eq!(request.query_indexed("f", 2), None);
        let mut pairs = request.query_pairs();
        assert_eq!(
            pairs.next(),
            Some(QueryPair {
                name: "a",
                value: "1"
            })
        );
        assert_eq!(
            pairs.next(),
            Some(QueryPair {
                name: "a",
                value: "2"
            })
        );
    }
    #[test]
    fn test_request_header_and_body_helpers() {
        let request_str =
            "POST /submit HTTP/1.1\r\ncontent-type: text/plain\r\nContent-Length: 5\r\n\r\n";
        let request = HttpRequest::parse_from(request_str, b"hello").unwrap();
        assert_eq!(request.header("Content-Type"), Some("text/plain"));
        assert_eq!(request.content_type(), Some("text/plain"));
        assert_eq!(request.content_length(), Some(5));
        assert_eq!(request.body_str().unwrap(), "hello");
    }
    #[test]
    fn test_percent_decode_query_component() {
        let mut out = [0u8; 32];
        assert_eq!(
            percent_decode("hello+world%21%2F", &mut out).unwrap(),
            "hello world!/"
        );
        assert!(percent_decode("bad%", &mut out).is_err());
        assert!(percent_decode("bad%xx", &mut out).is_err());
        assert!(percent_decode("toolong", &mut [0u8; 3]).is_err());
    }
    #[test]
    fn test_parse_request_post_with_body() {
        // `{"key":"value"}` is 15 bytes; the fixture's Content-Length must agree.
        let request_str = "POST /api/data HTTP/1.1\r\nContent-Type: application/json\r\nContent-Length: 15\r\n\r\n";
        let body = b"{\"key\":\"value\"}";
        let request = HttpRequest::parse_from(request_str, body).unwrap();
        assert_eq!(request.method, HttpMethod::POST);
        assert_eq!(request.path, "/api/data");
        assert_eq!(request.version, "HTTP/1.1");
        assert_eq!(request.headers.len(), 2);
        assert_eq!(request.body, b"{\"key\":\"value\"}");
        assert_eq!(request.content_length(), Some(request.body.len()));
        let content_type_header = request
            .headers
            .iter()
            .find(|h| h.name == "Content-Type")
            .unwrap();
        assert_eq!(content_type_header.value, "application/json");
    }
    #[test]
    fn test_parse_request_invalid_method() {
        let request_str = "INVALID /path HTTP/1.1\r\n\r\n";
        let body = b"";
        let result = HttpRequest::parse_from(request_str, body);
        assert!(result.is_err());
    }
    #[test]
    fn test_parse_request_missing_parts() {
        let request_str = "GET HTTP/1.1\r\n\r\n";
        let body = b"";
        let result = HttpRequest::parse_from(request_str, body);
        assert!(result.is_err());
        let request_str = "GET /path\r\n\r\n";
        let result = HttpRequest::parse_from(request_str, body);
        assert!(result.is_err());
        let request_str = "";
        let result = HttpRequest::parse_from(request_str, body);
        assert!(result.is_err());
    }
    #[test]
    fn test_parse_request_all_http_methods() {
        let methods = [
            ("GET", HttpMethod::GET),
            ("POST", HttpMethod::POST),
            ("PUT", HttpMethod::PUT),
            ("DELETE", HttpMethod::DELETE),
            ("PATCH", HttpMethod::PATCH),
            ("HEAD", HttpMethod::HEAD),
            ("OPTIONS", HttpMethod::OPTIONS),
            ("TRACE", HttpMethod::TRACE),
            ("CONNECT", HttpMethod::CONNECT),
        ];
        for (method_str, expected_method) in &methods {
            let request_str = format!("{method_str} /path HTTP/1.1\r\n\r\n");
            let request = HttpRequest::parse_from(&request_str, b"").unwrap();
            assert_eq!(request.method, *expected_method);
        }
    }
    #[test]
    fn test_find_double_crlf() {
        use crate::protocol::find_double_crlf;
        let data = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\nBody";
        assert_eq!(find_double_crlf(data), Some(33));
        let data = b"\r\n\r\nBody";
        assert_eq!(find_double_crlf(data), Some(0));
        let data = b"Headers\r\n\r\n";
        assert_eq!(find_double_crlf(data), Some(7));
        let data = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
        assert_eq!(find_double_crlf(data), None);
        let data = b"\r\n\r";
        assert_eq!(find_double_crlf(data), None);
        let data = b"";
        assert_eq!(find_double_crlf(data), None);
    }
    #[test]
    fn test_try_from_complete_request() {
        let buffer = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\n";
        let request = HttpRequest::try_from(buffer.as_slice()).unwrap();
        assert_eq!(request.method, HttpMethod::GET);
        assert_eq!(request.path, "/index.html");
        assert_eq!(request.version, "HTTP/1.1");
        assert_eq!(request.headers.len(), 2);
        assert_eq!(request.body, b"");
    }
    #[test]
    fn test_try_from_request_with_body() {
        let buffer =
            b"POST /api/data HTTP/1.1\r\nContent-Type: application/json\r\n\r\n{\"key\":\"value\"}";
        let request = HttpRequest::try_from(buffer.as_slice()).unwrap();
        assert_eq!(request.method, HttpMethod::POST);
        assert_eq!(request.path, "/api/data");
        assert_eq!(request.version, "HTTP/1.1");
        assert_eq!(request.headers.len(), 1);
        assert_eq!(request.body, b"{\"key\":\"value\"}");
    }
    #[test]
    fn test_try_from_incomplete_headers() {
        let buffer = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n";
        let result = HttpRequest::try_from(buffer.as_slice());
        assert!(result.is_err());
    }
    #[test]
    fn test_try_from_invalid_utf8() {
        let mut buffer: Vec<u8, 128> = Vec::new();
        let _ = buffer.extend_from_slice(b"GET /index.html HTTP/1.1\r\nHost: ");
        let _ = buffer.push(0xFF);
        let _ = buffer.extend_from_slice(b"\r\n\r\n");
        let result = HttpRequest::try_from(buffer.as_slice());
        assert!(result.is_err());
    }
}