use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use percent_encoding::{AsciiSet, NON_ALPHANUMERIC, utf8_percent_encode};
use super::types::{HttpMethod, HttpRequest, HttpResponse};
/// Bytes that `url_encode` percent-encodes: everything except ASCII letters,
/// ASCII digits and the four marks `~`, `.`, `_`, `-` (the RFC 3986
/// "unreserved" characters).
const URL_ENCODE_SET: &AsciiSet = &NON_ALPHANUMERIC
    .remove(b'~')
    .remove(b'.')
    .remove(b'_')
    .remove(b'-');
/// Source of TCP listeners; implementations supply the concrete socket types.
///
/// `Send + Sync` lets a provider be shared between threads.
pub trait NetworkProvider: Send + Sync {
/// Listener type returned by [`NetworkProvider::tcp_listen`].
type Listener: TcpListenerTrait;
/// Stream type associated with this provider.
///
/// NOTE(review): this is not constrained to equal the listener's own
/// `TcpListenerTrait::Stream`; confirm implementations keep them identical.
type Stream: TcpStreamTrait;
/// Creates a listener for `addr` on `port`.
///
/// # Errors
/// Returns an [`IoError`] if the listener cannot be created.
fn tcp_listen(&self, addr: &str, port: u16) -> Result<Self::Listener, IoError>;
}
/// A listening socket that hands out connected streams.
pub trait TcpListenerTrait {
/// Stream type produced by [`TcpListenerTrait::accept`].
type Stream: TcpStreamTrait;
/// Accepts the next incoming connection.
///
/// NOTE(review): presumably blocks, or fails with `IoErrorKind::WouldBlock`
/// after `set_nonblocking(true)`; confirm against implementations.
fn accept(&self) -> Result<Self::Stream, IoError>;
/// Switches the listener between blocking and non-blocking mode.
fn set_nonblocking(&self, nonblocking: bool) -> Result<(), IoError>;
}
pub trait TcpStreamTrait {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError>;
fn write(&mut self, buf: &[u8]) -> Result<usize, IoError>;
fn write_all(&mut self, buf: &[u8]) -> Result<(), IoError> {
let mut written = 0;
while written < buf.len() {
written += self.write(&buf[written..])?;
}
Ok(())
}
fn flush(&mut self) -> Result<(), IoError>;
fn shutdown(&mut self) -> Result<(), IoError>;
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) -> Result<(), IoError>;
fn set_write_timeout(&mut self, timeout_ms: Option<u64>) -> Result<(), IoError>;
fn peer_addr(&self) -> Result<String, IoError>;
}
/// Error type returned by the stream and listener traits in this module.
///
/// Pairs a matchable [`IoErrorKind`] with a free-form message.
#[derive(Debug, Clone)]
pub struct IoError {
/// Category of the failure.
pub kind: IoErrorKind,
/// Human-readable detail, e.g. `"connection closed"`.
pub message: String,
}
/// Category of an [`IoError`]. The variant names mirror a subset of
/// `std::io::ErrorKind`. It is `Copy + Eq`, so callers can compare kinds directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IoErrorKind {
ConnectionRefused,
ConnectionReset,
ConnectionAborted,
NotConnected,
AddrInUse,
AddrNotAvailable,
WouldBlock,
TimedOut,
Interrupted,
InvalidInput,
InvalidData,
/// The peer closed the stream early. `parse_request` reports this when the
/// connection ends before the request head is complete.
UnexpectedEof,
Other,
}
impl IoError {
pub fn new(kind: IoErrorKind, message: impl Into<String>) -> Self {
Self {
kind,
message: message.into(),
}
}
}
/// Reasons `parse_request` can fail to produce an `HttpRequest`.
#[derive(Debug, Clone)]
pub enum HttpParseError {
/// The request line is missing or has fewer than two whitespace-separated parts.
InvalidRequestLine,
/// The method token was rejected by `HttpMethod::from_str`.
InvalidMethod,
/// The request head could not be interpreted (for example, it is not valid UTF-8).
InvalidHeader,
/// Named for a request without a `Host` header.
///
/// NOTE(review): `parse_request` in this file never returns it; confirm
/// whether any other code constructs it.
MissingHost,
/// The head reached `MAX_HEADER_SIZE` without terminating, or the declared
/// body length exceeds `MAX_REQUEST_SIZE`.
ContentTooLarge,
/// Error from the underlying stream, e.g. `IoErrorKind::UnexpectedEof` when
/// the peer closes mid-request.
Io(IoError),
}
impl From<IoError> for HttpParseError {
fn from(e: IoError) -> Self {
HttpParseError::Io(e)
}
}
/// Largest `Content-Length` accepted by `parse_request`: 16 MiB.
const MAX_REQUEST_SIZE: usize = 16 << 20;
/// Largest request head (request line plus headers) accepted by `parse_request`: 64 KiB.
const MAX_HEADER_SIZE: usize = 64 << 10;
pub fn parse_request<S: TcpStreamTrait>(stream: &mut S) -> Result<HttpRequest, HttpParseError> {
let mut header_buf = vec![0u8; MAX_HEADER_SIZE];
let mut header_len = 0;
loop {
if header_len >= MAX_HEADER_SIZE {
return Err(HttpParseError::ContentTooLarge);
}
let n = stream.read(&mut header_buf[header_len..header_len + 1])?;
if n == 0 {
return Err(HttpParseError::Io(IoError::new(
IoErrorKind::UnexpectedEof,
"connection closed",
)));
}
header_len += n;
if header_len >= 4 && &header_buf[header_len - 4..header_len] == b"\r\n\r\n" {
break;
}
}
let header_str = core::str::from_utf8(&header_buf[..header_len])
.map_err(|_| HttpParseError::InvalidHeader)?;
let mut lines = header_str.lines();
let request_line = lines.next().ok_or(HttpParseError::InvalidRequestLine)?;
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() < 2 {
return Err(HttpParseError::InvalidRequestLine);
}
let method = HttpMethod::from_str(parts[0]).ok_or(HttpParseError::InvalidMethod)?;
let path_and_query = parts[1];
let (path, query) = if let Some(idx) = path_and_query.find('?') {
let (p, q) = path_and_query.split_at(idx);
(p.to_string(), parse_query_string(&q[1..]))
} else {
(path_and_query.to_string(), BTreeMap::new())
};
let mut headers = BTreeMap::new();
for line in lines {
if line.is_empty() {
break;
}
if let Some(idx) = line.find(':') {
let name = line[..idx].trim().to_string();
let value = line[idx + 1..].trim().to_string();
headers.insert(name, value);
}
}
let content_length: usize = headers
.get("Content-Length")
.or_else(|| headers.get("content-length"))
.and_then(|s| s.parse().ok())
.unwrap_or(0);
if content_length > MAX_REQUEST_SIZE {
return Err(HttpParseError::ContentTooLarge);
}
let mut body = vec![0u8; content_length];
if content_length > 0 {
let mut read = 0;
while read < content_length {
let n = stream.read(&mut body[read..])?;
if n == 0 {
break;
}
read += n;
}
body.truncate(read);
}
Ok(HttpRequest {
method,
path,
query,
headers,
body,
})
}
/// Parses an `a=1&b=2` query string into a map of decoded keys and values.
///
/// Empty segments are ignored. A segment without `=` maps to an empty value.
/// Only the first `=` separates key from value. Later duplicates overwrite
/// earlier ones.
fn parse_query_string(query: &str) -> BTreeMap<String, String> {
    let mut out = BTreeMap::new();
    let segments = query.split('&').filter(|segment| !segment.is_empty());
    for segment in segments {
        let (raw_key, raw_value) = segment.split_once('=').unwrap_or((segment, ""));
        out.insert(url_decode(raw_key), url_decode(raw_value));
    }
    out
}
/// Decodes a percent-encoded query component (form-encoding style).
///
/// - `%XX` with two ASCII hex digits becomes the byte `0xXX`.
/// - `+` becomes a space.
/// - A `%` not followed by two hex digits is kept literally, and the characters
///   after it are processed normally (not discarded).
///
/// Decoded bytes are reassembled as UTF-8, so `%C3%A9` yields `é`. Byte
/// sequences that are not valid UTF-8 are replaced with U+FFFD.
fn url_decode(s: &str) -> String {
    // Value of an ASCII hex digit, or `None`. This also rejects `+`/`-`, which
    // `u8::from_str_radix` would accept as a sign.
    let hex = |b: u8| (b as char).to_digit(16).map(|d| d as u8);
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex);
                let lo = bytes.get(i + 2).copied().and_then(hex);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Not a valid escape: keep the '%' literally.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }
    // Escapes may spell multi-byte UTF-8 sequences, so decode at the byte level.
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}
/// Percent-encodes the UTF-8 bytes of `s`. Only ASCII letters, digits and
/// `-`, `.`, `_`, `~` are left as-is.
pub fn url_encode(s: &str) -> String {
    let pieces = utf8_percent_encode(s, URL_ENCODE_SET);
    pieces.collect()
}
/// Serializes `response` as an HTTP/1.1 message onto `stream`.
///
/// The message head is assembled in one buffer and sent with a single
/// `write_all`, which avoids one small write per header line. It contains the
/// status line, a `Content-Length` computed from `response.body`, and every
/// entry of `response.headers`. The body follows, then the stream is flushed.
///
/// CR and LF characters are removed from each header line before it is
/// written. A header value derived from request data therefore cannot inject
/// extra header lines or end the head early (response splitting).
///
/// NOTE(review): `Content-Length` is always emitted from the body length. If
/// `response.headers` also contains one, both are sent; confirm callers never
/// set it themselves.
///
/// # Errors
/// Propagates any [`IoError`] from writing or flushing.
pub fn write_response<S: TcpStreamTrait>(
    stream: &mut S,
    response: &HttpResponse,
) -> Result<(), IoError> {
    let mut head = String::with_capacity(256);
    head.push_str(&format!(
        "HTTP/1.1 {} {}\r\n",
        response.status,
        status_text(response.status)
    ));
    head.push_str(&format!("Content-Length: {}\r\n", response.body.len()));
    for (name, value) in &response.headers {
        let line = format!("{}: {}", name, value);
        head.extend(line.chars().filter(|c| *c != '\r' && *c != '\n'));
        head.push_str("\r\n");
    }
    head.push_str("\r\n");

    stream.write_all(head.as_bytes())?;
    stream.write_all(&response.body)?;
    stream.flush()
}
/// Returns the reason phrase for an HTTP status `code`, or `"Unknown"` for
/// codes not listed.
///
/// Phrases follow RFC 9110 (RFC 6585 for 429 and 431). The exception is 413,
/// which keeps the earlier "Payload Too Large" wording this server already
/// emits.
fn status_text(code: u16) -> &'static str {
    match code {
        // 1xx informational
        100 => "Continue",
        101 => "Switching Protocols",
        // 2xx success
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        // 3xx redirection
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        // 4xx client error
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        412 => "Precondition Failed",
        413 => "Payload Too Large",
        414 => "URI Too Long",
        415 => "Unsupported Media Type",
        416 => "Range Not Satisfiable",
        429 => "Too Many Requests",
        431 => "Request Header Fields Too Large",
        // 5xx server error
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        505 => "HTTP Version Not Supported",
        _ => "Unknown",
    }
}
/// An accepted stream wrapped for HTTP request/response exchange.
///
/// Construction applies 30-second read and write timeouts and caches the
/// peer address (`"unknown"` if it could not be obtained).
pub struct HttpConnection<S: TcpStreamTrait> {
    stream: S,
    peer_addr: String,
}

impl<S: TcpStreamTrait> HttpConnection<S> {
    /// Timeout applied to both reads and writes, in milliseconds.
    const IO_TIMEOUT_MS: u64 = 30_000;

    /// Wraps `stream`, recording its peer address and setting I/O timeouts.
    ///
    /// # Errors
    /// Returns the stream's error if either timeout cannot be set. Failing to
    /// read the peer address is not an error; `"unknown"` is stored instead.
    pub fn new(mut stream: S) -> Result<Self, IoError> {
        let peer_addr = match stream.peer_addr() {
            Ok(addr) => addr,
            Err(_) => String::from("unknown"),
        };
        let timeout = Some(Self::IO_TIMEOUT_MS);
        stream.set_read_timeout(timeout)?;
        stream.set_write_timeout(timeout)?;
        Ok(HttpConnection { peer_addr, stream })
    }

    /// Peer address captured when the connection was created.
    pub fn peer_addr(&self) -> &str {
        self.peer_addr.as_str()
    }

    /// Reads and parses the next request from the connection.
    pub fn read_request(&mut self) -> Result<HttpRequest, HttpParseError> {
        let stream = &mut self.stream;
        parse_request(stream)
    }

    /// Serializes `response` onto the connection.
    pub fn write_response(&mut self, response: &HttpResponse) -> Result<(), IoError> {
        let stream = &mut self.stream;
        write_response(stream, response)
    }

    /// Shuts the stream down, consuming the connection.
    pub fn close(mut self) -> Result<(), IoError> {
        let stream = &mut self.stream;
        stream.shutdown()
    }
}
// Unit tests for the query/URL helpers, status text, IoError, and the
// request/response codecs (driven through the in-memory `MockStream`).
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_query_string() {
let params = parse_query_string("foo=bar&baz=qux");
assert_eq!(params.get("foo"), Some(&"bar".to_string()));
assert_eq!(params.get("baz"), Some(&"qux".to_string()));
}
#[test]
fn test_parse_query_string_empty() {
let params = parse_query_string("");
assert!(params.is_empty());
}
#[test]
fn test_parse_query_string_encoded() {
let params = parse_query_string("path=%2Ftest%2Ffile");
assert_eq!(params.get("path"), Some(&"/test/file".to_string()));
}
#[test]
fn test_url_decode() {
assert_eq!(url_decode("hello%20world"), "hello world");
assert_eq!(url_decode("foo+bar"), "foo bar");
assert_eq!(url_decode("test%2Fpath"), "test/path");
}
#[test]
fn test_url_encode() {
assert_eq!(url_encode("hello world"), "hello%20world");
assert_eq!(url_encode("test/path"), "test%2Fpath");
// The unreserved marks must pass through unescaped.
assert_eq!(url_encode("a-b_c.d~e"), "a-b_c.d~e");
}
#[test]
fn test_status_text() {
assert_eq!(status_text(200), "OK");
assert_eq!(status_text(404), "Not Found");
assert_eq!(status_text(500), "Internal Server Error");
}
#[test]
fn test_io_error() {
let err = IoError::new(IoErrorKind::ConnectionRefused, "connection refused");
assert_eq!(err.kind, IoErrorKind::ConnectionRefused);
assert_eq!(err.message, "connection refused");
}
// In-memory stream: `read` serves `read_data` from `read_pos` and returns
// Ok(0) once exhausted; `write` appends everything to `write_data` and never
// fails. The remaining methods are no-ops.
struct MockStream {
read_data: Vec<u8>,
read_pos: usize,
write_data: Vec<u8>,
}
impl MockStream {
fn new(data: &[u8]) -> Self {
Self {
read_data: data.to_vec(),
read_pos: 0,
write_data: Vec::new(),
}
}
}
impl TcpStreamTrait for MockStream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
if self.read_pos >= self.read_data.len() {
return Ok(0);
}
let n = core::cmp::min(buf.len(), self.read_data.len() - self.read_pos);
buf[..n].copy_from_slice(&self.read_data[self.read_pos..self.read_pos + n]);
self.read_pos += n;
Ok(n)
}
fn write(&mut self, buf: &[u8]) -> Result<usize, IoError> {
self.write_data.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> Result<(), IoError> {
Ok(())
}
fn shutdown(&mut self) -> Result<(), IoError> {
Ok(())
}
fn set_read_timeout(&mut self, _timeout_ms: Option<u64>) -> Result<(), IoError> {
Ok(())
}
fn set_write_timeout(&mut self, _timeout_ms: Option<u64>) -> Result<(), IoError> {
Ok(())
}
fn peer_addr(&self) -> Result<String, IoError> {
Ok("127.0.0.1:12345".into())
}
}
// Each `\` at a line end continues the byte string without inserting a
// newline, so the request's only line breaks are the explicit `\r\n`s.
#[test]
fn test_parse_request() {
let request = b"GET /bucket/key?foo=bar HTTP/1.1\r\n\
Host: localhost\r\n\
Content-Length: 4\r\n\
\r\n\
test";
let mut stream = MockStream::new(request);
let req = parse_request(&mut stream).unwrap();
assert_eq!(req.method, HttpMethod::Get);
assert_eq!(req.path, "/bucket/key");
assert_eq!(req.query.get("foo"), Some(&"bar".to_string()));
assert_eq!(req.body, b"test");
}
#[test]
fn test_parse_request_no_body() {
let request = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
let mut stream = MockStream::new(request);
let req = parse_request(&mut stream).unwrap();
assert_eq!(req.method, HttpMethod::Get);
assert_eq!(req.path, "/");
assert!(req.body.is_empty());
}
// Checks the serialized response via substring matches rather than an exact
// byte comparison.
#[test]
fn test_write_response() {
let response = HttpResponse::ok()
.with_header("X-Test", "value")
.with_body(b"Hello".to_vec());
let mut stream = MockStream::new(b"");
write_response(&mut stream, &response).unwrap();
let output = String::from_utf8(stream.write_data).unwrap();
assert!(output.contains("HTTP/1.1 200 OK"));
assert!(output.contains("Content-Length: 5"));
assert!(output.contains("X-Test: value"));
assert!(output.ends_with("Hello"));
}
}