use std::thread;
use async_net::TcpListener;
use futures_lite::future::block_on;
use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
use ugi::{Client, Version};
/// Drives anything convertible into a future to completion on the current thread and returns
/// its output.
///
/// Values such as `client.get(url)` implement [`std::future::IntoFuture`]. This helper lets the
/// tests write `run(client.get(url))` instead of wrapping every call in an `async` block.
fn run<T>(value: T) -> T::Output
where
    T: std::future::IntoFuture,
{
    let future = std::future::IntoFuture::into_future(value);
    block_on(future)
}
/// A `204 No Content` response has no body even when the server sends `Content-Length: 4`
/// followed by the stray `oops` bytes. The client must report an empty body and still complete
/// the next request.
#[test]
fn http1_204_response_ignores_illegal_body_bytes() {
    let scripted = vec![
        b"HTTP/1.1 204 No Content\r\nContent-Length: 4\r\nConnection: close\r\n\r\noops".to_vec(),
        b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok".to_vec(),
    ];
    let base = block_on(spawn_http_sequence_server(scripted)).unwrap();
    let client = Client::builder().build().unwrap();

    let no_content = run(client.get(format!("{base}/no-content"))).unwrap();
    assert_eq!(no_content.version(), Version::Http11);
    let body = run(no_content.text()).unwrap();
    assert_eq!(body, "");

    let follow_up = run(client.get(format!("{base}/after"))).unwrap();
    let follow_up_body = run(follow_up.text()).unwrap();
    assert_eq!(follow_up_body, "ok");
}
/// A `304 Not Modified` response has no body even when it advertises `Content-Length: 3`. The
/// trailing `bad` bytes must be ignored, and the client must still complete a later request.
#[test]
fn http1_304_response_ignores_illegal_body_bytes() {
    let script = vec![
        b"HTTP/1.1 304 Not Modified\r\nContent-Length: 3\r\nConnection: close\r\n\r\nbad".to_vec(),
        b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok".to_vec(),
    ];
    let base = block_on(spawn_http_sequence_server(script)).unwrap();
    let client = Client::builder().build().unwrap();

    let cached_url = format!("{base}/cached");
    let not_modified = run(client.get(cached_url)).unwrap();
    assert_eq!(run(not_modified.text()).unwrap(), "");

    let after_url = format!("{base}/after");
    let follow_up = run(client.get(after_url)).unwrap();
    assert_eq!(run(follow_up.text()).unwrap(), "ok");
}
/// A response to `HEAD` never carries a body. The `oops` bytes after the header block must be
/// ignored even though `Content-Length: 4` is advertised, and a following `GET` still succeeds.
#[test]
fn http1_head_response_ignores_illegal_body_bytes() {
    let head_reply =
        b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: close\r\n\r\noops".to_vec();
    let get_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok".to_vec();
    let base = block_on(spawn_http_sequence_server(vec![head_reply, get_reply])).unwrap();
    let client = Client::builder().build().unwrap();

    let head_response = run(client.head(format!("{base}/head"))).unwrap();
    let head_body = run(head_response.text()).unwrap();
    assert_eq!(head_body, "");

    let get_response = run(client.get(format!("{base}/after"))).unwrap();
    let get_body = run(get_response.text()).unwrap();
    assert_eq!(get_body, "ok");
}
/// An interim `100 Continue` response sent ahead of the real answer on the same connection is
/// skipped. The caller only sees the final `200 OK` and its body.
#[test]
fn http1_skips_informational_response_before_final_response() {
    let interim: &[u8] = b"HTTP/1.1 100 Continue\r\nHint: wait\r\n\r\n";
    let final_reply: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
    // Both responses travel back-to-back over a single connection.
    let wire = [interim, final_reply].concat();
    let base = block_on(spawn_http_sequence_server(vec![wire])).unwrap();

    let response = run(ugi::get(format!("{base}/informational"))).unwrap();
    let status = response.status().as_u16();
    assert_eq!(status, 200);
    assert_eq!(run(response.text()).unwrap(), "ok");
}
/// A chunked body whose chunk-size line is not hexadecimal (`xyz`) fails when read. The client
/// reports the failure as a transport error.
#[test]
fn http1_invalid_chunk_size_returns_error() {
    let malformed =
        b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\nxyz\r\nhello\r\n0\r\n\r\n".to_vec();
    let base = block_on(spawn_http_sequence_server(vec![malformed])).unwrap();

    let response = run(ugi::get(format!("{base}/chunked"))).unwrap();
    let err = run(response.text()).unwrap_err();
    let kind = err.kind();
    assert_eq!(kind, &ugi::ErrorKind::Transport);
}
/// An h2c upgrade attempt must fail with a transport error when the server answers
/// `101 Switching Protocols` without an `Upgrade` response header.
///
/// The request the server received is forwarded back to the test thread and checked there.
/// An assertion inside the server thread would only panic that background thread. The dropped
/// connection would then still produce a client error, and the test would pass even though the
/// request never offered `h2c`.
#[test]
fn h2c_upgrade_missing_upgrade_header_returns_error() {
    let (request_tx, request_rx) = std::sync::mpsc::channel::<String>();
    let base = block_on(spawn_http_server(move |request| {
        // A send error only means the test thread already stopped waiting; nothing to do here.
        let _ = request_tx.send(request);
        b"HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\n\r\n".to_vec()
    }))
    .unwrap();
    let result = run(ugi::get(format!("{base}/h2c")).http2_only());

    // Bounded wait: if the client never reached the server, fail instead of hanging forever.
    let request = request_rx
        .recv_timeout(std::time::Duration::from_secs(5))
        .expect("test server never received the h2c upgrade request");
    // Header names are case-insensitive; the value may list several protocols.
    let offers_h2c = request.lines().any(|line| {
        matches!(
            line.split_once(':'),
            Some((name, value))
                if name.trim().eq_ignore_ascii_case("upgrade")
                    && value.split(',').any(|proto| proto.trim() == "h2c")
        )
    });
    assert!(offers_h2c, "request did not offer `Upgrade: h2c`:\n{request}");

    match result {
        Err(err) => assert_eq!(err.kind(), &ugi::ErrorKind::Transport),
        Ok(_) => panic!("h2c upgrade without an `Upgrade` response header unexpectedly succeeded"),
    }
}
/// Starts a throwaway HTTP server on an ephemeral `127.0.0.1` port. It answers successive
/// connections with the scripted `responses`, one raw response per accepted connection.
///
/// For each connection the server reads the request head (up to the blank line ending the
/// headers, or EOF), writes the next scripted bytes verbatim, flushes, and drops the stream.
/// Once every scripted response has been served, the background thread exits.
///
/// Returns the base URL, e.g. `http://127.0.0.1:<port>`.
///
/// # Errors
///
/// Returns a `Transport` error if the listener cannot be bound or its local address cannot be
/// read.
///
/// # Panics
///
/// I/O failures on an accepted connection panic the background thread only. They do not
/// directly fail the calling test.
async fn spawn_http_sequence_server(responses: Vec<Vec<u8>>) -> ugi::Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect test server",
            err,
        )
    })?;
    thread::spawn(move || {
        block_on(async move {
            for response in responses {
                let (mut stream, _) = listener.accept().await.unwrap();
                // Drain the whole request head before answering. A single read may return
                // only part of it, and closing a TCP socket while received data is still
                // unread can trigger a reset that discards the response before the client
                // reads it.
                let mut head = Vec::new();
                let mut chunk = vec![0_u8; 4096];
                loop {
                    let read = stream.read(&mut chunk).await.unwrap();
                    if read == 0 {
                        break;
                    }
                    head.extend_from_slice(&chunk[..read]);
                    if head.windows(4).any(|window| window == b"\r\n\r\n") {
                        break;
                    }
                }
                stream.write_all(&response).await.unwrap();
                stream.flush().await.unwrap();
            }
        });
    });
    Ok(format!("http://{}", addr))
}
/// Starts a single-shot HTTP server on an ephemeral `127.0.0.1` port. It hands the first
/// request it receives to `handler` and writes back whatever bytes the handler returns.
///
/// The server accepts exactly one connection. It reads the request head (up to the blank line
/// ending the headers, or EOF) and passes it to `handler` as lossily decoded UTF-8. It then
/// writes and flushes the handler's response and drops the stream.
///
/// Returns the base URL, e.g. `http://127.0.0.1:<port>`.
///
/// # Errors
///
/// Returns a `Transport` error if the listener cannot be bound or its local address cannot be
/// read.
///
/// # Panics
///
/// I/O failures on the accepted connection, and panics raised by `handler`, only panic the
/// background thread. They are not propagated to the calling test.
async fn spawn_http_server<F>(handler: F) -> ugi::Result<String>
where
    F: FnOnce(String) -> Vec<u8> + Send + 'static,
{
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect test server",
            err,
        )
    })?;
    thread::spawn(move || {
        block_on(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            // Read the complete request head. A single read can return a partial request,
            // which would hand the handler a truncated string. It could also leave unread
            // bytes behind, and closing the socket then risks a reset.
            let mut head = Vec::new();
            let mut chunk = vec![0_u8; 4096];
            loop {
                let read = stream.read(&mut chunk).await.unwrap();
                if read == 0 {
                    break;
                }
                head.extend_from_slice(&chunk[..read]);
                if head.windows(4).any(|window| window == b"\r\n\r\n") {
                    break;
                }
            }
            let response = handler(String::from_utf8_lossy(&head).into_owned());
            stream.write_all(&response).await.unwrap();
            stream.flush().await.unwrap();
        });
    });
    Ok(format!("http://{}", addr))
}