use std::sync::atomic::{AtomicU16, Ordering};
use std::time::Duration;
use std::{env, fs};
use tokio::time::sleep;
/// First port handed out by [`get_test_port`]; each call advances the counter by one.
const FIRST_TEST_PORT: u16 = 26000;

/// Process-wide counter so tests running concurrently never pick the same port.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(FIRST_TEST_PORT);

/// Returns a port number not previously handed out to any other test in this process.
///
/// `Relaxed` is sufficient: callers only need each `fetch_add` to be atomic; the counter
/// does not publish any other memory.
///
/// # Panics
///
/// Panics once the counter wraps past `u16::MAX`, rather than silently handing out
/// port 0 (OS-assigned) or privileged low ports.
fn get_test_port() -> u16 {
    let port = PORT_COUNTER.fetch_add(1, Ordering::Relaxed);
    assert!(
        port >= FIRST_TEST_PORT,
        "test port counter wrapped around; too many test ports requested"
    );
    port
}
#[cfg(test)]
mod tests {
use super::*;
use potato::server::PipeContext;
use potato::utils::enums::HttpConnection;
use potato::{HttpMethod, HttpRequest, HttpResponse, HttpServer};
use std::borrow::Cow;
use std::collections::HashMap;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
/// Connects to `addr`, making up to ten attempts spaced 50 ms apart.
///
/// This absorbs the short window in which a freshly spawned server has not yet bound
/// its listener.
///
/// # Errors
///
/// Returns the error from the last failed attempt if all ten attempts fail.
async fn connect_with_retry(addr: &str) -> anyhow::Result<TcpStream> {
    let mut failure = None;
    let mut attempts_left = 10;
    while attempts_left > 0 {
        attempts_left -= 1;
        let err = match TcpStream::connect(addr).await {
            Ok(connected) => return Ok(connected),
            Err(err) => err,
        };
        failure = Some(err);
        sleep(Duration::from_millis(50)).await;
    }
    Err(failure.expect("retry loop must capture error").into())
}
/// Returns the bytes after the first blank line (`\r\n\r\n`) of a raw HTTP response,
/// i.e. the message body as it appeared on the wire.
///
/// # Errors
///
/// Fails when the response never terminates its header section.
fn response_body_bytes(response: &[u8]) -> anyhow::Result<&[u8]> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
    match response
        .windows(HEADER_TERMINATOR.len())
        .position(|window| window == HEADER_TERMINATOR)
    {
        Some(offset) => Ok(&response[offset + HEADER_TERMINATOR.len()..]),
        None => Err(anyhow::anyhow!("response missing header terminator")),
    }
}
fn response_body_data(res: potato::HttpResponse) -> Vec<u8> {
match res.body {
potato::HttpResponseBody::Data(data) => data,
potato::HttpResponseBody::Stream(_) => vec![],
}
}
/// Builds a bare `GET` request for `path` with no headers or body.
fn static_get_request(path: &str) -> HttpRequest {
    let mut request = HttpRequest::new();
    request.url_path = path.into();
    request.method = HttpMethod::GET;
    request
}
/// A `Connection` header listing both `keep-alive` and `close` must resolve to `Close`.
#[test]
fn test_connection_header_token_list_parsing_prefers_close() -> anyhow::Result<()> {
    let raw = "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: keep-alive, close\r\n\r\n";
    let (req, _) = match HttpRequest::from_headers_part(raw.as_bytes())? {
        Some(parsed) => parsed,
        None => anyhow::bail!("request headers should parse completely"),
    };
    assert_eq!(req.get_header_connection(), HttpConnection::Close);
    Ok(())
}
/// An `Upgrade` token in a mixed `Connection` list is recognised regardless of letter case.
#[test]
fn test_connection_header_token_list_detects_upgrade_case_insensitive() -> anyhow::Result<()> {
    let raw = "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: keep-alive, UpGrAdE\r\n\r\n";
    let (req, _) = match HttpRequest::from_headers_part(raw.as_bytes())? {
        Some(parsed) => parsed,
        None => anyhow::bail!("request headers should parse completely"),
    };
    assert_eq!(req.get_header_connection(), HttpConnection::Upgrade);
    Ok(())
}
/// WebSocket detection must still fire when `Upgrade` is only one of several
/// `Connection` tokens.
#[test]
fn test_websocket_detection_accepts_mixed_connection_tokens() -> anyhow::Result<()> {
    let handshake = concat!(
        "GET /ws HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Connection: keep-alive, Upgrade\r\n",
        "Upgrade: websocket\r\n",
        "Sec-WebSocket-Version: 13\r\n",
        "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n",
        "\r\n"
    );
    let (req, _) = match HttpRequest::from_headers_part(handshake.as_bytes())? {
        Some(parsed) => parsed,
        None => anyhow::bail!("request headers should parse completely"),
    };
    assert!(req.is_websocket());
    Ok(())
}
/// With `Transfer-Encoding: chunked` set, serialization must use chunk framing and must
/// not emit `Content-Length`.
#[test]
fn test_request_as_bytes_uses_chunked_when_transfer_encoding_present() -> anyhow::Result<()> {
    let mut req = HttpRequest::new();
    req.method = potato::HttpMethod::POST;
    req.url_path = "/chunked".into();
    for (name, value) in [("Host", "127.0.0.1"), ("Transfer-Encoding", "chunked")] {
        req.set_header(name, value);
    }
    req.body = b"Hello".to_vec().into();

    let wire = String::from_utf8(req.as_bytes())?;
    assert!(wire.contains("Transfer-Encoding: chunked\r\n"));
    assert!(!wire.contains("Content-Length:"));
    // One 5-byte chunk followed by the zero-length terminating chunk.
    assert!(wire.ends_with("\r\n\r\n5\r\nHello\r\n0\r\n\r\n"));
    Ok(())
}
/// Without a chunked `Transfer-Encoding`, the body is framed by `Content-Length` and
/// written verbatim.
#[test]
fn test_request_as_bytes_keeps_content_length_without_chunked_header() -> anyhow::Result<()> {
    const PAYLOAD: &[u8] = b"Hello";
    let mut req = HttpRequest::new();
    req.method = potato::HttpMethod::POST;
    req.url_path = "/plain".into();
    req.set_header("Host", "127.0.0.1");
    req.body = PAYLOAD.to_vec().into();

    let wire = String::from_utf8(req.as_bytes())?;
    assert!(!wire.contains("Transfer-Encoding: chunked\r\n"));
    assert!(wire.contains(&format!("Content-Length: {}\r\n", PAYLOAD.len())));
    assert!(wire.ends_with("\r\n\r\nHello"));
    Ok(())
}
/// Only trailer fields announced in the `Trailer` header are serialized after the final
/// chunk.
#[test]
fn test_request_as_bytes_writes_declared_chunked_trailers() -> anyhow::Result<()> {
    let mut req = HttpRequest::new();
    req.method = potato::HttpMethod::POST;
    req.url_path = "/chunked-trailer".into();
    for (name, value) in [
        ("Host", "127.0.0.1"),
        ("Transfer-Encoding", "chunked"),
        ("Trailer", "X-Trace"),
    ] {
        req.set_header(name, value);
    }
    req.body = b"Hello".to_vec().into();
    // `X-Ignored` is not listed in `Trailer`, so it must not reach the wire.
    req.set_trailer("X-Trace", "trace-1");
    req.set_trailer("X-Ignored", "should-not-send");

    let wire = String::from_utf8(req.as_bytes())?;
    assert!(wire.contains("Trailer: X-Trace\r\n"));
    assert!(!wire.contains("X-Ignored: should-not-send\r\n"));
    assert!(wire.ends_with("\r\n\r\n5\r\nHello\r\n0\r\nX-Trace: trace-1\r\n\r\n"));
    Ok(())
}
#[tokio::test]
async fn test_client_decodes_all_chunked_response_chunks() -> anyhow::Result<()> {
let port = get_test_port();
let server_addr = format!("127.0.0.1:{port}");
let listener = TcpListener::bind(&server_addr).await?;
let server_task = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await?;
let mut req_buf = [0_u8; 1024];
let _ = socket.read(&mut req_buf).await?;
let response = concat!(
"HTTP/1.1 200 OK\r\n",
"Transfer-Encoding: chunked\r\n",
"Connection: close\r\n",
"\r\n",
"5\r\nHello\r\n",
"1\r\n \r\n",
"5\r\nWorld\r\n",
"1\r\n!\r\n",
"0\r\n\r\n"
);
socket.write_all(response.as_bytes()).await?;
socket.shutdown().await?;
Ok::<(), anyhow::Error>(())
});
let url = format!("http://{}/chunked-resp", server_addr);
let res = potato::get(&url, vec![]).await?;
server_task.await??;
let body = match res.body {
potato::HttpResponseBody::Data(data) => data,
potato::HttpResponseBody::Stream(_) => vec![],
};
assert_eq!(body, b"Hello World!".to_vec());
Ok(())
}
#[tokio::test]
async fn test_client_decodes_chunked_response_trailers() -> anyhow::Result<()> {
let port = get_test_port();
let server_addr = format!("127.0.0.1:{port}");
let listener = TcpListener::bind(&server_addr).await?;
let server_task = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await?;
let mut req_buf = [0_u8; 1024];
let _ = socket.read(&mut req_buf).await?;
let response = concat!(
"HTTP/1.1 200 OK\r\n",
"Transfer-Encoding: chunked\r\n",
"Trailer: X-Trace\r\n",
"Connection: close\r\n",
"\r\n",
"5\r\nHello\r\n",
"0\r\n",
"X-Trace: trace-xyz\r\n",
"\r\n"
);
socket.write_all(response.as_bytes()).await?;
socket.shutdown().await?;
Ok::<(), anyhow::Error>(())
});
let url = format!("http://{}/chunked-trailer-resp", server_addr);
let res = potato::get(&url, vec![]).await?;
server_task.await??;
let body = match &res.body {
potato::HttpResponseBody::Data(data) => data.clone(),
potato::HttpResponseBody::Stream(_) => vec![],
};
assert_eq!(body, b"Hello".to_vec());
assert_eq!(res.get_trailer("X-Trace"), Some("trace-xyz"));
Ok(())
}
/// Smoke test: building a server and running `configure` on it must not panic.
#[tokio::test]
async fn test_server_creation_and_configuration() -> anyhow::Result<()> {
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    server.configure(|ctx| {
        ctx.use_handlers();
    });
    println!("✅ Server created and configured for: {addr}");
    Ok(())
}
/// Several servers can be constructed side by side on distinct ports.
#[tokio::test]
async fn test_multiple_servers_different_ports() -> anyhow::Result<()> {
    let ports = [get_test_port(), get_test_port(), get_test_port()];
    let addrs: Vec<String> = ports
        .iter()
        .map(|port| format!("127.0.0.1:{}", port))
        .collect();
    let _servers: Vec<_> = addrs.iter().map(|addr| HttpServer::new(addr)).collect();
    let [port1, port2, port3] = ports;
    assert_ne!(port1, port2);
    assert_ne!(port2, port3);
    assert_ne!(port1, port3);
    println!("✅ Created 3 servers on different ports");
    Ok(())
}
/// Sending on the shutdown channel must make `serve_http` return on its own.
///
/// Without waiting on the serve task, the test would pass even if the signal were
/// ignored. The wait is bounded so a hung server fails the test instead of stalling it.
#[tokio::test]
async fn test_server_shutdown_signal() -> anyhow::Result<()> {
    let port = get_test_port();
    let server_addr = format!("127.0.0.1:{port}");
    let mut server = HttpServer::new(&server_addr);
    let shutdown_tx = server
        .shutdown_signal()
        .expect("Failed to get shutdown signal");
    let mut server_handle = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(200)).await;
    let _ = shutdown_tx.send(());
    match tokio::time::timeout(Duration::from_secs(5), &mut server_handle).await {
        // The task finished; surface a panic inside it as a test failure.
        Ok(joined) => joined?,
        Err(_) => {
            server_handle.abort();
            anyhow::bail!("server did not stop within 5s of the shutdown signal");
        }
    }
    println!("✅ Server shutdown signal works");
    Ok(())
}
/// `configure` can be applied independently to multiple servers.
#[tokio::test]
async fn test_server_configuration_options() -> anyhow::Result<()> {
    let port = get_test_port();
    let server_addr = format!("127.0.0.1:{port}");
    let mut server = HttpServer::new(&server_addr);
    server.configure(|ctx| {
        ctx.use_handlers();
    });
    // The closure above enables handlers, so the log line must say so; it previously
    // claimed "handlers disabled".
    println!("✅ Server configuration: handlers enabled");
    let mut server2 = HttpServer::new(&format!("127.0.0.1:{}", get_test_port()));
    server2.configure(|ctx| {
        ctx.use_handlers();
    });
    println!("✅ Server configuration options applied");
    Ok(())
}
/// `HttpServer::new` accepts IPv4 loopback, hostname, and wildcard address strings.
///
/// The servers are only constructed here (no `serve_http` call), so the fixed port
/// numbers are not expected to conflict with `get_test_port`.
#[tokio::test]
async fn test_server_address_parsing() -> anyhow::Result<()> {
    // A fixed-size array is enough; `vec!` only added a heap allocation
    // (clippy::useless_vec).
    let addresses = ["127.0.0.1:25000", "localhost:25001", "0.0.0.0:25002"];
    for addr in addresses {
        let _server = HttpServer::new(addr);
        println!("✅ Server created with address: {}", addr);
    }
    Ok(())
}
/// Smoke test: a server can be constructed on a fresh test port.
#[tokio::test]
async fn test_server_protocols_availability() -> anyhow::Result<()> {
    let addr = format!("127.0.0.1:{}", get_test_port());
    let _server = HttpServer::new(&addr);
    println!("✅ Server protocol methods are available");
    Ok(())
}
/// `HttpServer::new` accepts both an owned `String` and a `&'static str` address.
#[tokio::test]
async fn test_server_from_string() -> anyhow::Result<()> {
    let addr_string = format!("127.0.0.1:{}", get_test_port());
    // Owned `String` argument.
    let _server1 = HttpServer::new(addr_string.clone());
    println!("✅ Server created from String: {addr_string}");
    // Borrowed string-literal argument.
    let _server2 = HttpServer::new("127.0.0.1:25010");
    println!("✅ Server created from &str");
    Ok(())
}
/// `shutdown_signal` hands out its sender exactly once; later calls yield `None`.
#[tokio::test]
async fn test_shutdown_signal_once() -> anyhow::Result<()> {
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    // Hold the first sender for the rest of the test.
    let _first_tx = server
        .shutdown_signal()
        .expect("First shutdown signal should succeed");
    let repeat = server.shutdown_signal();
    assert!(repeat.is_none(), "Second shutdown signal should return None");
    println!("✅ Shutdown signal correctly allows only one acquisition");
    Ok(())
}
/// A chunked request body is reassembled before it reaches the handler.
#[tokio::test]
async fn test_server_accepts_chunked_request_body() -> anyhow::Result<()> {
    #[potato::http_post("/chunked_req_echo")]
    async fn chunked_req_echo(req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text(String::from_utf8_lossy(&req.body).to_string())
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /chunked_req_echo HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Transfer-Encoding: chunked\r\n",
        "Connection: close\r\n",
        "\r\n",
        "5\r\n",
        "Hello\r\n",
        "6\r\n",
        " World\r\n",
        "0\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 200"));
    assert!(reply.contains("Hello World"));
    serve_task.abort();
    Ok(())
}
/// Trailer fields announced via `Trailer` are exposed to the handler.
#[tokio::test]
async fn test_server_accepts_declared_chunked_request_trailers() -> anyhow::Result<()> {
    #[potato::http_post("/chunked_req_trailer_echo")]
    async fn chunked_req_trailer_echo(req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text(req.get_trailer("X-Trace").unwrap_or(""))
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /chunked_req_trailer_echo HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Connection: close\r\n",
        "Transfer-Encoding: chunked\r\n",
        "Trailer: X-Trace\r\n",
        "\r\n",
        "5\r\nHello\r\n",
        "0\r\n",
        "X-Trace: trace-req\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 200 OK"));
    assert!(reply.contains("trace-req"));
    serve_task.abort();
    Ok(())
}
/// A trailer field that was not announced in `Trailer` is rejected with 400.
#[tokio::test]
async fn test_server_rejects_undeclared_chunked_request_trailers() -> anyhow::Result<()> {
    #[potato::http_post("/chunked_req_trailer_reject")]
    async fn chunked_req_trailer_reject(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("unreachable")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    // No `Trailer` header, yet an `X-Trace` trailer follows the last chunk.
    let payload = concat!(
        "POST /chunked_req_trailer_reject HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Connection: close\r\n",
        "Transfer-Encoding: chunked\r\n",
        "\r\n",
        "5\r\nHello\r\n",
        "0\r\n",
        "X-Trace: undeclared\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(reply.contains("unexpected trailer field"));
    serve_task.abort();
    Ok(())
}
/// `Expect: 100-continue` triggers an interim `100 Continue` before the body is sent.
/// The body that follows is then accepted and echoed normally.
#[tokio::test]
async fn test_server_expect_100_continue_sends_interim_and_accepts_body() -> anyhow::Result<()>
{
    #[potato::http_post("/expect_continue")]
    async fn expect_continue(req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text(String::from_utf8_lossy(&req.body).to_string())
    }
    const INTERIM: &str = "HTTP/1.1 100 Continue\r\n\r\n";
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    // Send only the head; the body goes out once the server grants permission.
    let head = concat!(
        "POST /expect_continue HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Expect: 100-continue\r\n",
        "Content-Length: 11\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(head.as_bytes()).await?;
    let mut interim = [0u8; INTERIM.len()];
    conn.read_exact(&mut interim).await?;
    assert_eq!(std::str::from_utf8(&interim)?, INTERIM);

    conn.write_all(b"Hello World").await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;
    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 200 OK"));
    assert!(reply.contains("Hello World"));
    serve_task.abort();
    Ok(())
}
/// An `Expect` value other than `100-continue` is refused with 417.
#[tokio::test]
async fn test_server_rejects_unsupported_expect_header_with_417() -> anyhow::Result<()> {
    #[potato::http_post("/expect_reject")]
    async fn expect_reject(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /expect_reject HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Expect: fancy-feature\r\n",
        "Content-Length: 5\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 417 Expectation Failed"));
    assert!(reply.contains("unsupported Expect header"));
    serve_task.abort();
    Ok(())
}
/// Multiple `Expect` headers are refused with 417 if any of them is unsupported, even
/// when one of them is `100-continue`.
#[tokio::test]
async fn test_server_rejects_duplicate_expect_when_any_token_unsupported() -> anyhow::Result<()>
{
    #[potato::http_post("/expect_duplicate")]
    async fn expect_duplicate(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /expect_duplicate HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Expect: 100-continue\r\n",
        "Expect: unknown-token\r\n",
        "Content-Length: 5\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 417 Expectation Failed"));
    assert!(reply.contains("unsupported Expect header"));
    serve_task.abort();
    Ok(())
}
/// `Expect` on an HTTP/1.0 request is refused with 417.
#[tokio::test]
async fn test_server_rejects_expect_on_http10_request() -> anyhow::Result<()> {
    #[potato::http_post("/expect_http10")]
    async fn expect_http10(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /expect_http10 HTTP/1.0\r\n",
        "Host: 127.0.0.1\r\n",
        "Expect: 100-continue\r\n",
        "Content-Length: 5\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 417 Expectation Failed"));
    assert!(reply.contains("HTTP versions below 1.1"));
    serve_task.abort();
    Ok(())
}
/// A request carrying both `Transfer-Encoding` and `Content-Length` is rejected with 400.
#[tokio::test]
async fn test_server_rejects_transfer_encoding_content_length_conflict() -> anyhow::Result<()> {
    #[potato::http_post("/chunked_req_conflict")]
    async fn chunked_req_conflict(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /chunked_req_conflict HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Transfer-Encoding: chunked\r\n",
        "Content-Length: 100\r\n",
        "Connection: close\r\n",
        "\r\n",
        "5\r\n",
        "Hello\r\n",
        "0\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(reply.contains("conflicting headers: Transfer-Encoding and Content-Length"));
    serve_task.abort();
    Ok(())
}
/// An HTTP/1.1 request with no `Host` header is rejected with 400.
#[tokio::test]
async fn test_server_rejects_http11_request_without_host() -> anyhow::Result<()> {
    #[potato::http_get("/host_required_missing")]
    async fn host_required_missing(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "GET /host_required_missing HTTP/1.1\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(reply.contains("missing required Host header"));
    serve_task.abort();
    Ok(())
}
/// An HTTP/1.1 request whose `Host` header is empty is rejected with 400.
#[tokio::test]
async fn test_server_rejects_http11_request_with_empty_host() -> anyhow::Result<()> {
    #[potato::http_get("/host_required_empty")]
    async fn host_required_empty(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "GET /host_required_empty HTTP/1.1\r\n",
        "Host: \r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(reply.contains("empty Host header"));
    serve_task.abort();
    Ok(())
}
/// An HTTP/1.1 request with more than one `Host` header is rejected with 400.
#[tokio::test]
async fn test_server_rejects_http11_request_with_duplicate_host() -> anyhow::Result<()> {
    #[potato::http_get("/host_required_duplicate")]
    async fn host_required_duplicate(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "GET /host_required_duplicate HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Host: example.com\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(reply.contains("multiple Host headers are not allowed"));
    serve_task.abort();
    Ok(())
}
/// An HTTP/1.1 request whose `Host` value is malformed (contains a space) is rejected
/// with 400.
#[tokio::test]
async fn test_server_rejects_http11_request_with_invalid_host() -> anyhow::Result<()> {
    #[potato::http_get("/host_required_invalid")]
    async fn host_required_invalid(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "GET /host_required_invalid HTTP/1.1\r\n",
        "Host: bad host\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(reply.contains("invalid Host header"));
    serve_task.abort();
    Ok(())
}
/// A form-urlencoded body sent with chunked framing still populates `body_pairs`.
#[tokio::test]
async fn test_chunked_form_body_keeps_existing_body_pair_parsing() -> anyhow::Result<()> {
    #[potato::http_post("/chunked_form_parse")]
    async fn chunked_form_parse(req: &mut HttpRequest) -> HttpResponse {
        let name = req.body_pairs.get("name").map_or("", |v| v.as_ref());
        HttpResponse::text(name.to_string())
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    // `a` is hex for 10, the length of `name=alice`.
    let payload = concat!(
        "POST /chunked_form_parse HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Content-Type: application/x-www-form-urlencoded\r\n",
        "Transfer-Encoding: chunked\r\n",
        "Connection: close\r\n",
        "\r\n",
        "a\r\n",
        "name=alice\r\n",
        "0\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 200"));
    assert!(reply.contains("alice"));
    serve_task.abort();
    Ok(())
}
/// A `Transfer-Encoding` chain containing an unsupported coding (`gzip`) is refused
/// with 501.
#[tokio::test]
async fn test_server_rejects_unsupported_transfer_encoding_chain() -> anyhow::Result<()> {
    #[potato::http_post("/chunked_te_chain")]
    async fn chunked_te_chain(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "POST /chunked_te_chain HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Transfer-Encoding: gzip, chunked\r\n",
        "Connection: close\r\n",
        "\r\n",
        "5\r\n",
        "Hello\r\n",
        "0\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 501 Not Implemented"));
    assert!(reply.contains("unsupported Transfer-Encoding"));
    serve_task.abort();
    Ok(())
}
/// `CONNECT` is answered with 501 rather than being dropped or routed.
#[tokio::test]
async fn test_server_rejects_connect_method_with_status() -> anyhow::Result<()> {
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = concat!(
        "CONNECT example.com:443 HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 501 Not Implemented"));
    assert!(reply.contains("CONNECT method is not implemented"));
    serve_task.abort();
    Ok(())
}
/// An absolute-form request target (`http://host/path?query`) is routed by its path,
/// even when the `Host` header names a different authority.
#[tokio::test]
async fn test_server_accepts_absolute_form_request_target() -> anyhow::Result<()> {
    #[potato::http_get("/abs_form_target")]
    async fn abs_form_target(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("absolute form ok")
    }
    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let serve_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    // Give the listener time to bind; `connect_with_retry` covers any remaining delay.
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    let payload = format!(
        "GET http://{addr}/abs_form_target?ok=1 HTTP/1.1\r\nHost: wrong.example\r\nConnection: close\r\n\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 200 OK"));
    assert!(reply.contains("absolute form ok"));
    serve_task.abort();
    Ok(())
}
/// The reverse proxy must not forward hop-by-hop request headers to the backend.
///
/// Stripped headers include any header named in the client's `Connection` field. The
/// backend echoes the names of any such headers it received, so an empty body means
/// nothing leaked.
#[tokio::test]
async fn test_reverse_proxy_strips_hop_by_hop_request_headers() -> anyhow::Result<()> {
    #[potato::http_get("/proxy_hop_headers_echo")]
    async fn proxy_hop_headers_echo(req: &mut HttpRequest) -> HttpResponse {
        let mut leaked = Vec::new();
        for header in [
            "Connection",
            "Keep-Alive",
            "Proxy-Authenticate",
            "Proxy-Authorization",
            "TE",
            "Trailer",
            "Transfer-Encoding",
            "Upgrade",
            "Proxy-Connection",
            "X-Remove-Me",
        ] {
            if req.get_header(header).is_some() {
                leaked.push(header);
            }
        }
        HttpResponse::text(leaked.join(","))
    }
    let backend_port = get_test_port();
    let backend_addr = format!("127.0.0.1:{}", backend_port);
    let mut backend_server = HttpServer::new(&backend_addr);
    let backend_handle = tokio::spawn(async move {
        let _ = backend_server.serve_http().await;
    });
    let proxy_port = get_test_port();
    let proxy_addr = format!("127.0.0.1:{}", proxy_port);
    let mut proxy_server = HttpServer::new(&proxy_addr);
    proxy_server.configure(|ctx| {
        ctx.use_reverse_proxy("/proxy", format!("http://{}", backend_addr), false);
    });
    let proxy_handle = tokio::spawn(async move {
        let _ = proxy_server.serve_http().await;
    });
    sleep(Duration::from_millis(350)).await;
    let mut stream = connect_with_retry(&proxy_addr).await?;
    let request = format!(
        concat!(
            "GET /proxy/proxy_hop_headers_echo HTTP/1.1\r\n",
            "Host: {}\r\n",
            "Connection: close, X-Remove-Me\r\n",
            "Keep-Alive: timeout=5\r\n",
            "Proxy-Authenticate: Basic realm=\"proxy\"\r\n",
            "Proxy-Authorization: Basic dGVzdDp0ZXN0\r\n",
            "TE: trailers\r\n",
            "Trailer: X-Trace\r\n",
            "Upgrade: websocket\r\n",
            "Proxy-Connection: keep-alive\r\n",
            "X-Remove-Me: leaked\r\n",
            "\r\n"
        ),
        proxy_addr
    );
    stream.write_all(request.as_bytes()).await?;
    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    let response_text = String::from_utf8_lossy(&response);
    assert!(response_text.starts_with("HTTP/1.1 200 OK"));
    let body = response_body_bytes(&response)?;
    // Report exactly which headers leaked instead of a bare `is_empty` failure.
    assert!(
        body.is_empty(),
        "proxy forwarded hop-by-hop request headers to the backend: {}",
        String::from_utf8_lossy(body)
    );
    proxy_handle.abort();
    backend_handle.abort();
    Ok(())
}
#[tokio::test]
async fn test_reverse_proxy_strips_hop_by_hop_response_headers() -> anyhow::Result<()> {
#[potato::http_get("/proxy_hop_response_headers")]
async fn proxy_hop_response_headers(_req: &mut HttpRequest) -> HttpResponse {
let mut res = HttpResponse::text("proxy hop headers stripped");
res.add_header("Connection".into(), "keep-alive, X-Upstream-Hop".into());
res.add_header("Keep-Alive".into(), "timeout=5".into());
res.add_header("Proxy-Authenticate".into(), "Basic realm=\"proxy\"".into());
res.add_header("Proxy-Authorization".into(), "Basic dGVzdDp0ZXN0".into());
res.add_header("TE".into(), "trailers".into());
res.add_header("Trailer".into(), "X-Trace".into());
res.add_header("Transfer-Encoding".into(), "chunked".into());
res.add_header("Upgrade".into(), "h2c".into());
res.add_header("Proxy-Connection".into(), "keep-alive".into());
res.add_header("X-Upstream-Hop".into(), "1".into());
res
}
let backend_port = get_test_port();
let backend_addr = format!("127.0.0.1:{}", backend_port);
let mut backend_server = HttpServer::new(&backend_addr);
let backend_handle = tokio::spawn(async move {
let _ = backend_server.serve_http().await;
});
let proxy_port = get_test_port();
let proxy_addr = format!("127.0.0.1:{}", proxy_port);
let mut proxy_server = HttpServer::new(&proxy_addr);
proxy_server.configure(|ctx| {
ctx.use_reverse_proxy("/proxy", format!("http://{}", backend_addr), false);
});
let proxy_handle = tokio::spawn(async move {
let _ = proxy_server.serve_http().await;
});
sleep(Duration::from_millis(350)).await;
let mut stream = connect_with_retry(&proxy_addr).await?;
let request = format!(
"GET /proxy/proxy_hop_response_headers HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
proxy_addr
);
stream.write_all(request.as_bytes()).await?;
let mut response = Vec::new();
stream.read_to_end(&mut response).await?;
let response_text = String::from_utf8_lossy(&response);
assert!(response_text.starts_with("HTTP/1.1 200 OK"));
assert!(!response_text.contains("Connection: keep-alive"));
assert!(!response_text.contains("Keep-Alive:"));
assert!(!response_text.contains("Proxy-Authenticate:"));
assert!(!response_text.contains("Proxy-Authorization:"));
assert!(!response_text.contains("TE:"));
assert!(!response_text.contains("Trailer:"));
assert!(!response_text.contains("Transfer-Encoding:"));
assert!(!response_text.contains("Upgrade:"));
assert!(!response_text.contains("Proxy-Connection:"));
assert!(!response_text.contains("X-Upstream-Hop:"));
assert!(response_text.contains("proxy hop headers stripped"));
proxy_handle.abort();
backend_handle.abort();
Ok(())
}
/// When forwarding, the proxy rewrites `Host` to the backend authority, keeping its
/// non-default port.
#[tokio::test]
async fn test_reverse_proxy_sets_host_with_non_default_port() -> anyhow::Result<()> {
    #[potato::http_get("/proxy_host_echo")]
    async fn proxy_host_echo(req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text(req.get_header("Host").unwrap_or(""))
    }
    let backend_addr = format!("127.0.0.1:{}", get_test_port());
    let mut backend_server = HttpServer::new(&backend_addr);
    let backend_task = tokio::spawn(async move {
        let _ = backend_server.serve_http().await;
    });

    let proxy_addr = format!("127.0.0.1:{}", get_test_port());
    let mut proxy_server = HttpServer::new(&proxy_addr);
    proxy_server.configure(|ctx| {
        ctx.use_reverse_proxy("/proxy", format!("http://{}", backend_addr), false);
    });
    let proxy_task = tokio::spawn(async move {
        let _ = proxy_server.serve_http().await;
    });
    // Both listeners need time to bind before the client connects.
    sleep(Duration::from_millis(350)).await;

    let mut conn = connect_with_retry(&proxy_addr).await?;
    let payload = format!(
        "GET /proxy/proxy_host_echo HTTP/1.1\r\nHost: {proxy_addr}\r\nConnection: close\r\n\r\n"
    );
    conn.write_all(payload.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let reply = String::from_utf8_lossy(&raw);
    assert!(reply.starts_with("HTTP/1.1 200 OK"));
    let echoed_host = response_body_bytes(&raw)?;
    assert_eq!(String::from_utf8_lossy(echoed_host), backend_addr);
    proxy_task.abort();
    backend_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_rejects_authority_form_for_non_connect() -> anyhow::Result<()> {
    // A bare `host:port` request-target is only valid for CONNECT; using it with
    // GET must yield 400 Bad Request.
    const REQUEST: &[u8] =
        b"GET example.com:443 HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n";

    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&server_addr).await?;
    conn.write_all(REQUEST).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let text = String::from_utf8_lossy(&raw);
    assert!(text.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(text.contains("authority-form request-target is only valid for CONNECT"));

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_rejects_origin_form_for_connect() -> anyhow::Result<()> {
    // CONNECT needs an authority-form target (host:port); an origin-form path
    // such as `/tunnel` must be answered with 400 Bad Request.
    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&server_addr).await?;
    // The two trailing empty entries produce the blank line that ends the headers.
    let request_lines = [
        "CONNECT /tunnel HTTP/1.1",
        "Host: 127.0.0.1",
        "Connection: close",
        "",
        "",
    ];
    conn.write_all(request_lines.join("\r\n").as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let text = String::from_utf8_lossy(&raw);
    assert!(text.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(text.contains("CONNECT requires authority-form request-target"));

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_options_asterisk_uses_server_wide_allow() -> anyhow::Result<()> {
    // Two routes registered with different methods; `OPTIONS *` should report both.
    #[potato::http_get("/asterisk_get")]
    async fn asterisk_get(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("ok")
    }
    #[potato::http_post("/asterisk_post")]
    async fn asterisk_post(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("ok")
    }
    let port = get_test_port();
    let server_addr = format!("127.0.0.1:{port}");
    let mut server = HttpServer::new(&server_addr);
    let server_handle = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    // `OPTIONS *` asks about the server as a whole rather than a single resource.
    let mut stream = connect_with_retry(&server_addr).await?;
    let request = concat!(
        "OPTIONS * HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    stream.write_all(request.as_bytes()).await?;
    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    let response_text = String::from_utf8_lossy(&response);
    assert!(response_text.starts_with("HTTP/1.1 200 OK"));
    // Inspect the Allow header's own method list. Searching the whole response
    // for "GET"/"POST" would also pass if those strings appeared anywhere else
    // (status line, other headers, body) while the Allow header lacked them.
    let header_block = response_text.split("\r\n\r\n").next().unwrap_or("");
    let allow = header_block
        .split("\r\n")
        .find_map(|line| line.strip_prefix("Allow:"))
        .ok_or_else(|| anyhow::anyhow!("OPTIONS * response missing Allow header"))?;
    // Allow is a comma-separated method list; tolerate optional whitespace.
    let methods: Vec<&str> = allow.split(',').map(str::trim).collect();
    assert!(methods.contains(&"GET"), "Allow header lacks GET: {allow:?}");
    assert!(methods.contains(&"POST"), "Allow header lacks POST: {allow:?}");

    // The asterisk-form target is only legal with OPTIONS.
    let mut stream = connect_with_retry(&server_addr).await?;
    let request = concat!(
        "GET * HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    stream.write_all(request.as_bytes()).await?;
    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    let response_text = String::from_utf8_lossy(&response);
    assert!(response_text.starts_with("HTTP/1.1 400 Bad Request"));
    assert!(response_text.contains("asterisk-form request-target requires OPTIONS"));
    server_handle.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_head_fallback_uses_get_status_without_body() -> anyhow::Result<()> {
#[potato::http_get("/head_fallback_get")]
async fn head_fallback_get(_req: &mut HttpRequest) -> HttpResponse {
HttpResponse::text("head fallback payload")
}
let port = get_test_port();
let server_addr = format!("127.0.0.1:{port}");
let mut server = HttpServer::new(&server_addr);
let server_handle = tokio::spawn(async move {
let _ = server.serve_http().await;
});
sleep(Duration::from_millis(300)).await;
let mut stream = connect_with_retry(&server_addr).await?;
let request = concat!(
"HEAD /head_fallback_get HTTP/1.1\r\n",
"Host: 127.0.0.1\r\n",
"Connection: close\r\n",
"\r\n"
);
stream.write_all(request.as_bytes()).await?;
let mut response = Vec::new();
stream.read_to_end(&mut response).await?;
let response_text = String::from_utf8_lossy(&response);
assert!(response_text.starts_with("HTTP/1.1 200 OK"));
assert!(!response_text.contains("head fallback payload"));
let header_end = response
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|p| p + 4)
.ok_or_else(|| anyhow::anyhow!("response missing header terminator"))?;
assert_eq!(response.len(), header_end);
let mut stream = connect_with_retry(&server_addr).await?;
let request = concat!(
"HEAD /head_fallback_missing HTTP/1.1\r\n",
"Host: 127.0.0.1\r\n",
"Connection: close\r\n",
"\r\n"
);
stream.write_all(request.as_bytes()).await?;
let mut response = Vec::new();
stream.read_to_end(&mut response).await?;
let response_text = String::from_utf8_lossy(&response);
assert!(response_text.starts_with("HTTP/1.1 404 Not Found"));
assert!(!response_text.contains("404 not found"));
let header_end = response
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|p| p + 4)
.ok_or_else(|| anyhow::anyhow!("response missing header terminator"))?;
assert_eq!(response.len(), header_end);
server_handle.abort();
Ok(())
}
#[tokio::test]
async fn test_server_head_route_takes_precedence_over_get_fallback() -> anyhow::Result<()> {
    // GET handler on the same path; it must NOT be used for HEAD here.
    #[potato::http_get("/head_precedence")]
    async fn head_precedence_get(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("from get")
    }
    // Explicit HEAD handler. Its distinctive 204 status proves it was the one chosen.
    #[potato::http_head("/head_precedence")]
    async fn head_precedence_head(_req: &mut HttpRequest) -> HttpResponse {
        let mut res = HttpResponse::empty();
        res.http_code = 204;
        res
    }

    let addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&addr).await?;
    conn.write_all(b"HEAD /head_precedence HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n")
        .await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let answered_by_head_route =
        String::from_utf8_lossy(&raw).starts_with("HTTP/1.1 204 No Content");
    assert!(answered_by_head_route);

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_head_handler_response_never_writes_body() -> anyhow::Result<()> {
    // The HEAD handler returns a 16-byte body. Only its length may reach the
    // client, via Content-Length; the bytes themselves must be dropped.
    #[potato::http_head("/head_no_body")]
    async fn head_no_body(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("must not be sent")
    }

    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let request: &[u8] =
        b"HEAD /head_no_body HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n";
    let mut conn = connect_with_retry(&server_addr).await?;
    conn.write_all(request).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let text = String::from_utf8_lossy(&raw);
    assert!(text.starts_with("HTTP/1.1 200 OK"));
    assert!(text.contains("Content-Length: 16\r\n"));
    let body = response_body_bytes(&raw)?;
    assert!(body.is_empty(), "HEAD response carried {} body bytes", body.len());

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_204_response_never_writes_body() -> anyhow::Result<()> {
    // The handler attaches a body and then forces status 204. The server must
    // drop both the body and its Content-Length header.
    #[potato::http_get("/status_204_no_body")]
    async fn status_204_no_body(_req: &mut HttpRequest) -> HttpResponse {
        let mut res = HttpResponse::text("must not be sent");
        res.http_code = 204;
        res
    }

    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let mut conn = connect_with_retry(&server_addr).await?;
    conn.write_all(
        concat!(
            "GET /status_204_no_body HTTP/1.1\r\n",
            "Host: 127.0.0.1\r\n",
            "Connection: close\r\n",
            "\r\n"
        )
        .as_bytes(),
    )
    .await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let text = String::from_utf8_lossy(&raw);
    assert!(text.starts_with("HTTP/1.1 204 No Content"));
    assert!(!text.contains("Content-Length:"));
    assert!(response_body_bytes(&raw)?.is_empty());

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_304_response_never_writes_body() -> anyhow::Result<()> {
    // The handler sets a body and then forces 304. Neither the body nor a
    // Content-Length header may be sent.
    #[potato::http_get("/status_304_no_body")]
    async fn status_304_no_body(_req: &mut HttpRequest) -> HttpResponse {
        let mut res = HttpResponse::text("must not be sent");
        res.http_code = 304;
        res
    }

    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    const REQUEST: &str =
        "GET /status_304_no_body HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n";
    let mut conn = connect_with_retry(&server_addr).await?;
    conn.write_all(REQUEST.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;

    let text = String::from_utf8_lossy(&raw);
    assert!(text.starts_with("HTTP/1.1 304 Not Modified"));
    assert!(!text.contains("Content-Length:"));
    let trailing = response_body_bytes(&raw)?;
    assert_eq!(trailing.len(), 0, "304 response carried a body");

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_server_rejects_conflicting_duplicate_content_length() -> anyhow::Result<()> {
    // Registered so the path exists; the malformed request below must never reach it.
    #[potato::http_post("/duplicate_content_length")]
    async fn duplicate_content_length(_req: &mut HttpRequest) -> HttpResponse {
        HttpResponse::text("should not reach")
    }

    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    // Two Content-Length headers that disagree (5 vs 6). The server is expected
    // to close the connection without writing a single byte back.
    let request = [
        "POST /duplicate_content_length HTTP/1.1\r\n",
        "Host: 127.0.0.1\r\n",
        "Content-Length: 5\r\n",
        "Content-Length: 6\r\n",
        "Connection: close\r\n",
        "\r\n",
        "Hello!",
    ]
    .concat();

    let mut conn = connect_with_retry(&server_addr).await?;
    conn.write_all(request.as_bytes()).await?;
    let mut raw = Vec::new();
    conn.read_to_end(&mut raw).await?;
    assert!(
        raw.is_empty(),
        "unexpected reply: {:?}",
        String::from_utf8_lossy(&raw)
    );

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_static_file_route_supports_range_partial_content() -> anyhow::Result<()> {
let port = get_test_port();
let server_addr = format!("127.0.0.1:{port}");
let temp_dir = env::temp_dir().join(format!("potato-range-{}", port));
fs::create_dir_all(&temp_dir)?;
let file_path = temp_dir.join("sample.txt");
let file_content = b"HelloRangeWorld";
fs::write(&file_path, file_content)?;
let mut server = HttpServer::new(&server_addr);
let static_root = temp_dir.canonicalize()?.to_string_lossy().to_string();
server.configure(|ctx| {
ctx.use_location_route("/static/", static_root.clone(), false);
});
let server_handle = tokio::spawn(async move {
let _ = server.serve_http().await;
});
sleep(Duration::from_millis(300)).await;
let headers = vec![potato::Headers::Custom((
"Range".to_string(),
"bytes=5-9".to_string(),
))];
let url = format!("http://{}/static/sample.txt", server_addr);
let res = potato::get(&url, headers).await?;
assert_eq!(res.http_code, 206);
assert_eq!(res.get_header("Accept-Ranges"), Some("bytes"));
assert_eq!(res.get_header("Content-Range"), Some("bytes 5-9/15"));
let body = match res.body {
potato::HttpResponseBody::Data(data) => data,
potato::HttpResponseBody::Stream(_) => vec![],
};
assert_eq!(body, b"Range".to_vec());
server_handle.abort();
_ = fs::remove_file(file_path);
_ = fs::remove_dir(temp_dir);
Ok(())
}
#[tokio::test]
async fn test_static_file_route_returns_416_for_unsatisfiable_range() -> anyhow::Result<()> {
    let port = get_test_port();
    let server_addr = format!("127.0.0.1:{port}");

    // Serve a 5-byte file, so a range starting at byte 99 lies past its end.
    let temp_dir = env::temp_dir().join(format!("potato-range-{port}"));
    let file_path = temp_dir.join("sample.txt");
    fs::create_dir_all(&temp_dir)?;
    fs::write(&file_path, b"short")?;

    let mut server = HttpServer::new(&server_addr);
    let static_root = temp_dir.canonicalize()?.to_string_lossy().into_owned();
    server.configure(|ctx| {
        ctx.use_location_route("/static/", static_root.clone(), false);
    });
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let url = format!("http://{server_addr}/static/sample.txt");
    let range = potato::Headers::Custom(("Range".to_string(), "bytes=99-100".to_string()));
    let res = potato::get(&url, vec![range]).await?;
    assert_eq!(res.http_code, 416);
    assert_eq!(res.get_header("Accept-Ranges"), Some("bytes"));
    // `*` marks "no satisfiable range"; 5 is the complete representation length.
    assert_eq!(res.get_header("Content-Range"), Some("bytes */5"));

    server_task.abort();
    // Best-effort cleanup; errors are ignored.
    _ = fs::remove_file(file_path);
    _ = fs::remove_dir(temp_dir);
    Ok(())
}
#[tokio::test]
async fn test_location_route_blocks_parent_path_segments() -> anyhow::Result<()> {
    // A timestamp suffix keeps the scratch directory distinct between runs.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)?
        .as_nanos();
    let temp_dir = env::temp_dir().join(format!("potato-location-route-block-{nanos}"));
    let static_root = temp_dir.join("wwwroot");
    // Sits next to (not inside) the served root. This is the traversal target.
    let outside_file = temp_dir.join("outside.txt");

    _ = fs::remove_dir_all(&temp_dir);
    fs::create_dir_all(&static_root)?;
    fs::write(&outside_file, b"outside")?;

    let mut ctx = PipeContext::new();
    ctx.use_location_route("/static/", static_root.to_string_lossy().into_owned(), false);

    // `..` would climb out of wwwroot; the route must refuse it.
    let mut req = static_get_request("/static/../outside.txt");
    let res = PipeContext::handle_request(&ctx, &mut req, 0).await;
    assert_eq!(res.http_code, 500);
    let message = String::from_utf8(response_body_data(res))?;
    assert_eq!(message, "url path over directory");

    _ = fs::remove_dir_all(temp_dir);
    Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn test_location_route_allows_symlink_root_and_symlink_children() -> anyhow::Result<()> {
    // Layout under the scratch directory:
    //   real-root/direct.txt
    //   real-root/assets  -> outside-assets   (symlinked child)
    //   outside-assets/linked.txt
    //   wwwroot           -> real-root        (symlinked root that gets served)
    let stamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)?
        .as_nanos();
    let temp_dir = env::temp_dir().join(format!("potato-location-route-symlink-{stamp}"));
    let real_root = temp_dir.join("real-root");
    let outside_assets = temp_dir.join("outside-assets");
    let symlink_root = temp_dir.join("wwwroot");
    let symlink_child = real_root.join("assets");

    _ = fs::remove_dir_all(&temp_dir);
    for dir in [&real_root, &outside_assets] {
        fs::create_dir_all(dir)?;
    }
    fs::write(real_root.join("direct.txt"), b"direct")?;
    fs::write(outside_assets.join("linked.txt"), b"linked")?;
    std::os::unix::fs::symlink(&real_root, &symlink_root)?;
    std::os::unix::fs::symlink(&outside_assets, &symlink_child)?;

    // The final `true` flag is what the sibling `*_when_disabled` test turns off.
    // Presumably it permits following symlinks — confirm against use_location_route.
    let mut ctx = PipeContext::new();
    ctx.use_location_route("/static/", symlink_root.to_string_lossy().into_owned(), true);

    // Both a plain file and a file reached through the symlinked child are served.
    for (path, expected) in [
        ("/static/direct.txt", b"direct".to_vec()),
        ("/static/assets/linked.txt", b"linked".to_vec()),
    ] {
        let mut req = static_get_request(path);
        let res = PipeContext::handle_request(&ctx, &mut req, 0).await;
        assert_eq!(res.http_code, 200);
        assert_eq!(response_body_data(res), expected);
    }

    // Explicit `..` segments are still rejected even with symlinks allowed.
    let mut traversal_req = static_get_request("/static/../outside-assets/linked.txt");
    let traversal_res = PipeContext::handle_request(&ctx, &mut traversal_req, 0).await;
    assert_eq!(traversal_res.http_code, 500);
    let message = String::from_utf8(response_body_data(traversal_res))?;
    assert_eq!(message, "url path over directory");

    // Best-effort cleanup; errors are ignored.
    _ = fs::remove_file(symlink_child);
    _ = fs::remove_file(symlink_root);
    _ = fs::remove_dir_all(temp_dir);
    Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn test_location_route_blocks_symlink_escape_when_disabled() -> anyhow::Result<()> {
    let stamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)?
        .as_nanos();
    let temp_dir = env::temp_dir().join(format!("potato-location-route-symlink-deny-{stamp}"));
    let real_root = temp_dir.join("real-root");
    let outside_assets = temp_dir.join("outside-assets");
    // A link inside the served root whose target lies outside it.
    let escaping_link = real_root.join("assets");

    _ = fs::remove_dir_all(&temp_dir);
    fs::create_dir_all(&real_root)?;
    fs::create_dir_all(&outside_assets)?;
    fs::write(real_root.join("direct.txt"), b"direct")?;
    fs::write(outside_assets.join("linked.txt"), b"linked")?;
    std::os::unix::fs::symlink(&outside_assets, &escaping_link)?;

    // Final flag `false`: the symlink-disabled configuration under test.
    let mut ctx = PipeContext::new();
    ctx.use_location_route("/static/", real_root.to_string_lossy().into_owned(), false);

    // Regular files under the root keep working...
    let mut plain_req = static_get_request("/static/direct.txt");
    let plain_res = PipeContext::handle_request(&ctx, &mut plain_req, 0).await;
    assert_eq!(plain_res.http_code, 200);
    assert_eq!(response_body_data(plain_res), b"direct".to_vec());

    // ...but following the escaping link is refused.
    let mut escape_req = static_get_request("/static/assets/linked.txt");
    let escape_res = PipeContext::handle_request(&ctx, &mut escape_req, 0).await;
    assert_eq!(escape_res.http_code, 500);
    let message = String::from_utf8(response_body_data(escape_res))?;
    assert_eq!(message, "url path over directory");

    // Best-effort cleanup; errors are ignored.
    _ = fs::remove_file(escaping_link);
    _ = fs::remove_dir_all(temp_dir);
    Ok(())
}
#[tokio::test]
async fn test_embedded_route_supports_range_partial_content() -> anyhow::Result<()> {
let port = get_test_port();
let server_addr = format!("127.0.0.1:{port}");
let mut server = HttpServer::new(&server_addr);
let mut assets: HashMap<String, Cow<'static, [u8]>> = HashMap::new();
assets.insert("sample.txt".to_string(), Cow::Borrowed(b"HelloRangeWorld"));
server.configure(|ctx| {
ctx.use_embedded_route("/embed", assets.clone());
});
let server_handle = tokio::spawn(async move {
let _ = server.serve_http().await;
});
sleep(Duration::from_millis(300)).await;
let headers = vec![potato::Headers::Custom((
"Range".to_string(),
"bytes=5-9".to_string(),
))];
let url = format!("http://{}/embed/sample.txt", server_addr);
let res = potato::get(&url, headers).await?;
assert_eq!(res.http_code, 206);
assert_eq!(res.get_header("Accept-Ranges"), Some("bytes"));
assert_eq!(res.get_header("Content-Range"), Some("bytes 5-9/15"));
let body = match res.body {
potato::HttpResponseBody::Data(data) => data,
potato::HttpResponseBody::Stream(_) => vec![],
};
assert_eq!(body, b"Range".to_vec());
server_handle.abort();
Ok(())
}
#[tokio::test]
async fn test_embedded_route_etag_roundtrip_returns_304() -> anyhow::Result<()> {
    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    let mut assets: HashMap<String, Cow<'static, [u8]>> = HashMap::with_capacity(1);
    assets.insert(String::from("etag.txt"), Cow::Borrowed(b"etag-body"));
    server.configure(|ctx| {
        ctx.use_embedded_route("/embed", assets.clone());
    });
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let url = format!("http://{server_addr}/embed/etag.txt");

    // First fetch: a full 200 response that must carry an ETag validator.
    let fresh = potato::get(&url, vec![]).await?;
    assert_eq!(fresh.http_code, 200);
    let etag = fresh
        .get_header("ETag")
        .map(str::to_owned)
        .ok_or_else(|| anyhow::anyhow!("embedded response missing ETag"))?;

    // Revalidate with If-None-Match. The unchanged asset yields 304 and echoes the ETag.
    let conditional = vec![potato::Headers::Custom((
        "If-None-Match".to_string(),
        etag.clone(),
    ))];
    let revalidated = potato::get(&url, conditional).await?;
    assert_eq!(revalidated.http_code, 304);
    assert_eq!(revalidated.get_header("ETag"), Some(etag.as_str()));

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_embedded_route_returns_416_for_unsatisfiable_range() -> anyhow::Result<()> {
    let server_addr = format!("127.0.0.1:{}", get_test_port());
    let mut server = HttpServer::new(&server_addr);
    // A single 5-byte asset, so a range starting at byte 99 can never be satisfied.
    let mut assets: HashMap<String, Cow<'static, [u8]>> = HashMap::new();
    assets.insert("sample.txt".into(), Cow::Borrowed(b"short"));
    server.configure(|ctx| {
        ctx.use_embedded_route("/embed", assets.clone());
    });
    let server_task = tokio::spawn(async move {
        let _ = server.serve_http().await;
    });
    sleep(Duration::from_millis(300)).await;

    let url = format!("http://{server_addr}/embed/sample.txt");
    let range = potato::Headers::Custom(("Range".to_string(), "bytes=99-100".to_string()));
    let res = potato::get(&url, vec![range]).await?;
    assert_eq!(res.http_code, 416);
    assert_eq!(res.get_header("Accept-Ranges"), Some("bytes"));
    // `*` marks "no satisfiable range"; 5 is the complete asset length.
    assert_eq!(res.get_header("Content-Range"), Some("bytes */5"));

    server_task.abort();
    Ok(())
}
}