use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use typeway_core::*;
use typeway_macros::*;
use typeway_server::router::DEFAULT_MAX_BODY_SIZE;
use typeway_server::*;
// Path type for the `echo` segment; the POST tests in this file target `/echo`.
typeway_path!(type EchoPath = "echo");
// Path type for the `hello` segment; the GET test in this file targets `/hello`.
typeway_path!(type HelloPath = "hello");
/// API made of a single POST endpoint on `EchoPath`, served by `echo_body`
/// (which takes a `String` body and returns a `String`).
type EchoAPI = (PostEndpoint<EchoPath, String, String>,);
/// Handler for the echo endpoint: returns the received request body unchanged.
async fn echo_body(body: String) -> String {
    body
}
/// Serves `router` over HTTP/1 on a loopback port and returns that port.
///
/// A detached background task accepts connections in an endless loop and
/// drives each accepted connection on its own spawned task.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address cannot be
/// read. A failing `accept` panics inside the background task.
async fn spawn_router(router: Router) -> u16 {
    // Binding to port 0 leaves the port choice to the OS.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let shared_router = Arc::new(router);

    tokio::spawn(async move {
        loop {
            let (socket, _peer) = listener.accept().await.unwrap();
            let io = hyper_util::rt::TokioIo::new(socket);
            // Each connection gets its own service holding a handle to the same router.
            let service = hyper_util::service::TowerToHyperService::new(RouterService::new(
                Arc::clone(&shared_router),
            ));
            tokio::spawn(async move {
                // The outcome of serving a connection is deliberately discarded.
                let _ = hyper::server::conn::http1::Builder::new()
                    .serve_connection(io, service)
                    .await;
            });
        }
    });

    // Short pause before handing the port back — presumably to let the accept
    // task get scheduled first (TODO confirm it is still needed).
    tokio::time::sleep(Duration::from_millis(50)).await;
    port
}
/// Starts an `EchoAPI` server without overriding the body-size limit and
/// returns the port it listens on.
async fn start_default_server() -> u16 {
    let router = Server::<EchoAPI>::new((bind::<_, _, _>(echo_body),)).into_router();
    spawn_router(router).await
}
/// Starts an `EchoAPI` server configured with `max_body_size(max)` and
/// returns the port it listens on.
async fn start_custom_limit_server(max: usize) -> u16 {
    let router = Server::<EchoAPI>::new((bind::<_, _, _>(echo_body),))
        .max_body_size(max)
        .into_router();
    spawn_router(router).await
}
/// API with two endpoints: a GET on `HelloPath` and a POST on `EchoPath`.
/// `start_get_post_server` binds them to `hello` and `echo_body` respectively.
type GetAndPostAPI = (
GetEndpoint<HelloPath, String>,
PostEndpoint<EchoPath, String, String>,
);
/// Handler for the GET endpoint: always yields the fixed text `hello`.
async fn hello() -> &'static str {
    "hello"
}
/// Starts a `GetAndPostAPI` server (handlers `hello` and `echo_body`)
/// configured with `max_body_size(max)` and returns the port it listens on.
async fn start_get_post_server(max: usize) -> u16 {
    let router =
        Server::<GetAndPostAPI>::new((bind::<_, _, _>(hello), bind::<_, _, _>(echo_body)))
            .max_body_size(max)
            .into_router();
    spawn_router(router).await
}
/// Writes `request` verbatim to a new TCP connection to `127.0.0.1:{port}`
/// and returns everything read back, decoded lossily as UTF-8.
///
/// Reading is best-effort: it stops at end of stream, on the first read
/// error, or when a single read does not finish within three seconds. Whatever
/// was collected up to that point is returned, so the result may be empty.
///
/// # Panics
///
/// Panics if connecting or writing the request fails.
async fn send_raw(port: u16, request: &[u8]) -> String {
    let address = format!("127.0.0.1:{port}");
    let mut conn = TcpStream::connect(address).await.unwrap();
    conn.write_all(request).await.unwrap();

    let per_read_timeout = Duration::from_secs(3);
    let mut chunk = vec![0u8; 16384];
    let mut collected = Vec::new();
    // Any outcome other than a successful read (timeout or I/O error) ends the loop.
    while let Ok(Ok(len)) = tokio::time::timeout(per_read_timeout, conn.read(&mut chunk)).await {
        if len == 0 {
            // A zero-length read means the peer closed the stream.
            break;
        }
        collected.extend_from_slice(&chunk[..len]);
    }

    String::from_utf8_lossy(&collected).into_owned()
}
/// Pulls the numeric status code out of a raw HTTP response.
///
/// The code is taken to be the second whitespace-separated token of `raw`
/// (e.g. `200` in `HTTP/1.1 200 OK`).
///
/// # Panics
///
/// Panics if `raw` has fewer than two whitespace-separated tokens, or if the
/// second token does not parse as a `u16`.
fn extract_status(raw: &str) -> u16 {
    let mut tokens = raw.split_whitespace();
    // The first token is the protocol version; it is skipped, not validated.
    let _version = tokens.next();
    let code = tokens.next().expect("missing status code in response");
    code.parse().expect("non-numeric status code")
}
/// A body of exactly `DEFAULT_MAX_BODY_SIZE` bytes must be accepted by the
/// default server and echoed back intact.
#[tokio::test]
async fn exact_limit_boundary_succeeds() {
    let port = start_default_server().await;
    let url = format!("http://127.0.0.1:{port}/echo");
    let payload = "x".repeat(DEFAULT_MAX_BODY_SIZE);

    let client = reqwest::Client::new();
    let response = client
        .post(url)
        .body(payload.clone())
        .send()
        .await
        .unwrap();

    assert_eq!(response.status(), 200, "body exactly at limit should succeed");
    let echoed = response.text().await.unwrap();
    assert_eq!(echoed, payload);
}
/// A body one byte larger than `DEFAULT_MAX_BODY_SIZE` must be answered with 413.
#[tokio::test]
async fn one_byte_over_limit_returns_413() {
    let port = start_default_server().await;
    let url = format!("http://127.0.0.1:{port}/echo");
    let oversized = "x".repeat(DEFAULT_MAX_BODY_SIZE + 1);

    let client = reqwest::Client::new();
    let response = client.post(url).body(oversized).send().await.unwrap();

    assert_eq!(
        response.status(),
        413,
        "body one byte over limit should be rejected"
    );
}
/// With a 1024-byte limit configured, a 1025-byte body must be answered with 413.
#[tokio::test]
async fn custom_limit_enforcement() {
    const LIMIT: usize = 1024;

    let port = start_custom_limit_server(LIMIT).await;
    let url = format!("http://127.0.0.1:{port}/echo");
    let oversized = "x".repeat(LIMIT + 1);

    let client = reqwest::Client::new();
    let response = client.post(url).body(oversized).send().await.unwrap();

    assert_eq!(
        response.status(),
        413,
        "1025 bytes should exceed custom 1024-byte limit"
    );
}
/// With a 1024-byte limit configured, a body of exactly 1024 bytes must be
/// accepted and echoed back intact.
#[tokio::test]
async fn custom_limit_exact_boundary_succeeds() {
    const LIMIT: usize = 1024;

    let port = start_custom_limit_server(LIMIT).await;
    let url = format!("http://127.0.0.1:{port}/echo");
    let payload = "x".repeat(LIMIT);

    let client = reqwest::Client::new();
    let response = client
        .post(url)
        .body(payload.clone())
        .send()
        .await
        .unwrap();

    assert_eq!(
        response.status(),
        200,
        "body exactly at custom limit should succeed"
    );
    let echoed = response.text().await.unwrap();
    assert_eq!(echoed, payload);
}
/// An empty POST body must not produce a 5xx response; any status below 500
/// is accepted.
#[tokio::test]
async fn zero_byte_body_on_post_does_not_crash() {
    let port = start_default_server().await;
    let url = format!("http://127.0.0.1:{port}/echo");

    let client = reqwest::Client::new();
    let response = client
        .post(url)
        .header("content-type", "application/json")
        .body("")
        .send()
        .await
        .unwrap();

    let status = response.status().as_u16();
    assert!(
        status < 500,
        "empty body should not cause a server error, got {status}"
    );
}
/// A request that declares a 1 GiB `Content-Length` but sends only five body
/// bytes must not be answered with 200 by a server limited to 1024 bytes.
///
/// If nothing at all is read back (e.g. the connection is dropped or the
/// read times out), the test passes without asserting anything.
#[tokio::test]
async fn large_content_length_header_does_not_allocate() {
    let port = start_custom_limit_server(1024).await;
    let request = format!(
        "POST /echo HTTP/1.1\r\n\
         Host: 127.0.0.1:{port}\r\n\
         Connection: close\r\n\
         Content-Length: 1073741824\r\n\
         \r\n\
         hello"
    );

    let raw = send_raw(port, request.as_bytes()).await;
    if raw.is_empty() {
        // No response to inspect; treated as acceptable.
        return;
    }

    let status = extract_status(&raw);
    assert_ne!(
        status, 200,
        "server should not accept a request claiming 1 GB body"
    );
}
/// A chunked request whose chunks total 100 bytes must be answered with 413
/// by a server limited to 64 bytes.
#[tokio::test]
async fn chunked_encoding_exceeding_limit_returns_413() {
    let port = start_custom_limit_server(64).await;

    // Ten identical 10-byte chunks followed by the terminating zero-length chunk.
    let chunk = "a".repeat(10);
    let encoded_chunk = format!("{:x}\r\n{}\r\n", chunk.len(), chunk);
    let mut body = encoded_chunk.repeat(10);
    body.push_str("0\r\n\r\n");

    let request = format!(
        "POST /echo HTTP/1.1\r\n\
         Host: 127.0.0.1:{port}\r\n\
         Connection: close\r\n\
         Transfer-Encoding: chunked\r\n\
         \r\n\
         {body}"
    );

    let raw = send_raw(port, request.as_bytes()).await;
    assert!(
        !raw.is_empty(),
        "server should respond, not drop connection"
    );

    let status = extract_status(&raw);
    assert_eq!(
        status, 413,
        "chunked body exceeding limit should be rejected with 413"
    );
}
/// A chunked request whose chunks total 50 bytes must be answered with 200
/// by a server limited to 256 bytes.
#[tokio::test]
async fn chunked_encoding_within_limit_succeeds() {
    let port = start_custom_limit_server(256).await;

    // Five identical 10-byte chunks followed by the terminating zero-length chunk.
    let chunk = "b".repeat(10);
    let encoded_chunk = format!("{:x}\r\n{}\r\n", chunk.len(), chunk);
    let mut body = encoded_chunk.repeat(5);
    body.push_str("0\r\n\r\n");

    let request = format!(
        "POST /echo HTTP/1.1\r\n\
         Host: 127.0.0.1:{port}\r\n\
         Connection: close\r\n\
         Transfer-Encoding: chunked\r\n\
         \r\n\
         {body}"
    );

    let raw = send_raw(port, request.as_bytes()).await;
    assert!(!raw.is_empty(), "server should respond");

    let status = extract_status(&raw);
    assert_eq!(
        status, 200,
        "chunked body within limit should succeed, got {status}"
    );
}
/// A GET to `/hello` must still return 200 and `hello` when the server's
/// body-size limit is set to a single byte.
#[tokio::test]
async fn get_with_small_body_limit_works() {
    let port = start_get_post_server(1).await;
    let url = format!("http://127.0.0.1:{port}/hello");

    let response = reqwest::get(url).await.unwrap();

    assert_eq!(response.status(), 200);
    let text = response.text().await.unwrap();
    assert_eq!(text, "hello");
}