use jerrycan_core::{
App, CorsConfig, CorsOrigins, Json, Middleware, MiddlewareFuture, Multipart, Next, NoContent,
RequestCtx, Result, StreamBody, get, post,
};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Handler whose response body never finishes on its own.
///
/// A background task keeps pushing 64 KiB chunks of `b'x'` into the stream and
/// only stops once `send` resolves to `false`.
async fn endless() -> StreamBody {
    let (stream, sender) = StreamBody::channel();
    tokio::spawn(async move {
        let filler = vec![b'x'; 64 * 1024];
        loop {
            let delivered = sender.send(filler.clone()).await;
            if !delivered {
                break;
            }
        }
    });
    stream
}
/// Handler that streams exactly three 1 KiB chunks of `b'y'`, then ends the body.
///
/// `send` resolves to a `bool`; like `endless`, the producer treats `false` as
/// the signal to stop. Without that check, it would keep trying to push chunks
/// after the receiving side is gone.
async fn three() -> StreamBody {
    let (body, tx) = StreamBody::channel();
    tokio::spawn(async move {
        for _ in 0..3 {
            if !tx.send(vec![b'y'; 1024]).await {
                break;
            }
        }
    });
    body
}
fn decode_chunked(mut body: &str) -> usize {
let mut total = 0;
while let Some((len_line, rest)) = body.split_once("\r\n") {
let len = usize::from_str_radix(len_line.trim(), 16).expect("valid chunk size");
if len == 0 {
break;
}
total += len;
body = &rest[len + 2..];
}
total
}
/// A client that reads the first bytes of an endless stream and then stops
/// reading must be dropped by the server once the configured write-stall cap
/// (500 ms here) is exceeded, rather than pinning the connection forever.
#[tokio::test]
async fn stalled_reader_is_disconnected_after_write_stall_cap() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap().to_string();
    let stall_cap = Duration::from_millis(500);
    let app = App::new()
        .route("/endless", get(endless))
        .write_stall_timeout(stall_cap);
    let server = tokio::spawn(async move { app.serve_with(listener).await });

    let observe = async {
        let mut conn = tokio::net::TcpStream::connect(&addr).await.unwrap();
        conn.write_all(b"GET /endless HTTP/1.1\r\nhost: t\r\n\r\n")
            .await
            .unwrap();
        let mut scratch = [0u8; 8 * 1024];
        let first = conn.read(&mut scratch).await.unwrap();
        assert!(first > 0, "the stream must start before we stall");
        // Stop reading for well past the stall cap so the server's writes block.
        tokio::time::sleep(Duration::from_secs(3)).await;
        // Drain whatever was buffered; stop on EOF (`Ok(0)`) or a reset (`Err`).
        while matches!(conn.read(&mut scratch).await, Ok(n) if n > 0) {}
    };

    tokio::time::timeout(Duration::from_secs(10), observe)
        .await
        .expect("server must drop the stalled reader, not hang on the write");
    server.abort();
}
/// A client that keeps reading must receive the whole three-chunk stream and a
/// normal close, even though a tight write-stall cap is configured.
#[tokio::test]
async fn prompt_reader_gets_the_full_stream_and_is_not_disconnected() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap().to_string();
    let server = {
        let app = App::new()
            .route("/three", get(three))
            .write_stall_timeout(Duration::from_millis(500));
        tokio::spawn(async move { app.serve_with(listener).await })
    };

    let fetch = async {
        let mut conn = tokio::net::TcpStream::connect(&addr).await.unwrap();
        conn.write_all(b"GET /three HTTP/1.1\r\nhost: t\r\nconnection: close\r\n\r\n")
            .await
            .unwrap();
        let mut bytes = Vec::new();
        conn.read_to_end(&mut bytes).await.unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    };
    let raw = tokio::time::timeout(Duration::from_secs(10), fetch)
        .await
        .expect("a prompt reader must never be disconnected by the stall cap");

    assert!(
        raw.starts_with("HTTP/1.1 200"),
        "got: {}",
        &raw[..raw.len().min(64)]
    );
    let (_, chunked) = raw.split_once("\r\n\r\n").expect("response has headers");
    assert_eq!(
        decode_chunked(chunked),
        3 * 1024,
        "the full streamed payload must arrive"
    );
    server.abort();
}
/// Handler that deserializes the JSON request body and returns it unchanged.
async fn echo(Json(payload): Json<serde_json::Value>) -> Result<Json<serde_json::Value>> {
    let reply = Json(payload);
    Ok(reply)
}
/// A streamed JSON request body that arrives in three separately flushed
/// pieces, with pauses between them, must be fully drained and echoed back.
#[tokio::test]
async fn streamed_request_body_drains_over_a_real_socket_when_written_in_dribbles() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap().to_string();
    let app = App::new().route("/up", post(echo).stream_body().body_limit(1024));
    let server = tokio::spawn(async move { app.serve_with(listener).await });

    let payload = br#"{"hello":"streamed world"}"#;
    // Split at byte offsets 8 and 17 so the JSON arrives in three uneven pieces.
    let pieces = [&payload[..8], &payload[8..17], &payload[17..]];

    let exchange = async {
        let mut conn = tokio::net::TcpStream::connect(&addr).await.unwrap();
        let head = format!(
            "POST /up HTTP/1.1\r\nHost: l\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            payload.len()
        );
        conn.write_all(head.as_bytes()).await.unwrap();
        conn.flush().await.unwrap();
        for piece in pieces {
            conn.write_all(piece).await.unwrap();
            conn.flush().await.unwrap();
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        let mut bytes = Vec::new();
        conn.read_to_end(&mut bytes).await.unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    };
    let raw = tokio::time::timeout(Duration::from_secs(10), exchange)
        .await
        .expect("the dribbled streamed body must be drained and echoed, not hang");

    assert!(
        raw.starts_with("HTTP/1.1 200"),
        "got: {}",
        &raw[..raw.len().min(80)]
    );
    let (_, reply) = raw.split_once("\r\n\r\n").expect("response has headers");
    let echoed: serde_json::Value = serde_json::from_str(reply.trim()).expect("echoed JSON body");
    assert_eq!(echoed, serde_json::json!({"hello": "streamed world"}));
    server.abort();
}
/// Handler that reads every multipart part as text and responds with a JSON
/// array of `(name, text)` pairs, in the order the parts arrived.
///
/// Any error from advancing the multipart stream or reading a part's text is
/// propagated with `?`.
async fn upload(mut mp: Multipart) -> Result<Json<Vec<(String, String)>>> {
    let mut fields: Vec<(String, String)> = Vec::new();
    loop {
        let Some(part) = mp.next_part().await? else {
            break;
        };
        // Copy the name out first; reading the text may consume the part.
        let field_name = part.name().to_string();
        fields.push((field_name, part.text().await?));
    }
    Ok(Json(fields))
}
/// A multipart body dribbled five bytes at a time must still be parsed part by
/// part and echoed back as `(name, text)` pairs.
#[tokio::test]
async fn chunked_multipart_upload_over_a_live_socket() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap().to_string();
    let app = App::new().route("/upload", post(upload).stream_body().body_limit(4096));
    let server = tokio::spawn(async move { app.serve_with(listener).await });

    // One form field `a` = "hello", delimited by boundary `B`.
    let form = concat!(
        "--B\r\n",
        "content-disposition: form-data; name=\"a\"\r\n",
        "\r\n",
        "hello\r\n",
        "--B--\r\n",
    );

    let exchange = async {
        let mut conn = tokio::net::TcpStream::connect(&addr).await.unwrap();
        let head = format!(
            "POST /upload HTTP/1.1\r\nHost: l\r\nContent-Type: multipart/form-data; boundary=B\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            form.len()
        );
        conn.write_all(head.as_bytes()).await.unwrap();
        conn.flush().await.unwrap();
        for piece in form.as_bytes().chunks(5) {
            conn.write_all(piece).await.unwrap();
            conn.flush().await.unwrap();
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
        let mut bytes = Vec::new();
        conn.read_to_end(&mut bytes).await.unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    };
    let raw = tokio::time::timeout(Duration::from_secs(10), exchange)
        .await
        .expect("the dribbled multipart body must be parsed and echoed, not hang");

    assert!(
        raw.starts_with("HTTP/1.1 200"),
        "got: {}",
        &raw[..raw.len().min(80)]
    );
    let (_, reply) = raw.split_once("\r\n\r\n").expect("response has headers");
    let echoed: Vec<(String, String)> = serde_json::from_str(reply.trim()).expect("echoed pairs");
    assert_eq!(echoed, vec![("a".to_string(), "hello".to_string())]);
    server.abort();
}
/// Handler that ignores the request entirely and replies with [`NoContent`].
async fn sink() -> NoContent {
    NoContent
}
/// A cross-origin request whose body exceeds the route's 8-byte limit must be
/// rejected with 413. That rejection must still include the
/// `Access-Control-Allow-Origin` header for the allowed origin.
#[tokio::test]
async fn cross_origin_413_over_a_live_socket_carries_allow_origin() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap().to_string();
    let cors = CorsConfig::new(CorsOrigins::list(["https://app.example"]));
    let app = App::new().cors(cors).route("/upload", post(sink).body_limit(8));
    let server = tokio::spawn(async move { app.serve_with(listener).await });

    let oversized = [b'x'; 64];
    let exchange = async {
        let mut conn = tokio::net::TcpStream::connect(&addr).await.unwrap();
        let head = format!(
            "POST /upload HTTP/1.1\r\nHost: l\r\nOrigin: https://app.example\r\nContent-Type: application/octet-stream\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            oversized.len()
        );
        conn.write_all(head.as_bytes()).await.unwrap();
        conn.write_all(&oversized).await.unwrap();
        conn.flush().await.unwrap();
        let mut bytes = Vec::new();
        conn.read_to_end(&mut bytes).await.unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    };
    let raw = tokio::time::timeout(Duration::from_secs(10), exchange)
        .await
        .expect("the over-limit exchange must complete, not hang");

    assert!(
        raw.starts_with("HTTP/1.1 413"),
        "got: {}",
        &raw[..raw.len().min(80)]
    );
    let (headers, _) = raw.split_once("\r\n\r\n").expect("response has headers");
    let allow_origin_present = headers
        .lines()
        .any(|line| line.eq_ignore_ascii_case("access-control-allow-origin: https://app.example"));
    assert!(
        allow_origin_present,
        "the serve-level 413 must carry Allow-Origin so the browser surfaces it, got headers:\n{headers}"
    );
    server.abort();
}
/// Middleware that copies the client's peer address into an `x-peer` response
/// header.
///
/// No header is added when the context has no peer address, or when the
/// address text is not a valid header value.
struct StampPeer;

impl Middleware for StampPeer {
    fn handle<'a>(&'a self, ctx: &'a mut RequestCtx, next: Next<'a>) -> MiddlewareFuture<'a> {
        // Read the peer before `ctx` is handed off to the rest of the chain.
        let stamp = ctx
            .peer_addr()
            .and_then(|addr| http::HeaderValue::from_str(&addr.to_string()).ok());
        Box::pin(async move {
            let mut res = next.run(ctx).await;
            if let Some(value) = stamp {
                res.headers_mut().insert("x-peer", value);
            }
            res
        })
    }
}
/// Trivial handler whose response body is the static string `"ok"`.
async fn ok() -> &'static str {
    "ok"
}
/// The real TCP peer address must reach the request context. This is observed
/// through the `StampPeer` middleware, which echoes it as `x-peer`.
#[tokio::test]
async fn client_peer_address_flows_into_the_handler_over_a_real_socket() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap().to_string();
    let server = {
        let app = App::new().route("/whoami", get(ok)).middleware(StampPeer);
        tokio::spawn(async move { app.serve_with(listener).await })
    };

    let exchange = async {
        let mut conn = tokio::net::TcpStream::connect(&addr).await.unwrap();
        conn.write_all(b"GET /whoami HTTP/1.1\r\nhost: t\r\nconnection: close\r\n\r\n")
            .await
            .unwrap();
        let mut bytes = Vec::new();
        conn.read_to_end(&mut bytes).await.unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    };
    let raw = tokio::time::timeout(Duration::from_secs(10), exchange)
        .await
        .expect("the peer-stamping exchange must complete, not hang");

    assert!(
        raw.starts_with("HTTP/1.1 200"),
        "got: {}",
        &raw[..raw.len().min(64)]
    );
    let (headers, _) = raw.split_once("\r\n\r\n").expect("response has headers");
    let peer_line = headers
        .lines()
        .find(|line| line.to_ascii_lowercase().starts_with("x-peer:"))
        .expect("x-peer header present");
    assert!(
        peer_line.contains("127.0.0.1:"),
        "x-peer must carry the loopback peer address, got: {peer_line}"
    );
    server.abort();
}