use std::net::{SocketAddr, TcpListener};
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use http::header::HeaderValue;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::client::conn::http1;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use serde::Deserialize;
use serde_json::json;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tower::{Service, ServiceBuilder};
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use hyperlite::{parse_json_body, serve, success, BoxBody, BoxError, Router};
mod test_helpers;
use test_helpers::*;
/// Returns a loopback address whose port the OS reported as free a moment ago.
///
/// A throwaway listener is bound to port 0 so the OS picks an ephemeral port,
/// and the listener is dropped before returning so `serve` can bind it.
/// NOTE: another process could claim the port between the drop and the
/// server's bind; `serve` only accepts an address, so that race is accepted.
fn available_addr() -> SocketAddr {
    let probe = TcpListener::bind(("127.0.0.1", 0)).expect("bind test listener");
    // `probe` goes out of scope when the function returns, releasing the port.
    probe.local_addr().expect("local addr")
}
/// Spawns `hyperlite::serve` for `service` on `addr` as a background tokio task.
///
/// The returned handle resolves to `serve`'s result; tests call `abort()` on it
/// when they are done with the server.
async fn spawn_server<S>(addr: SocketAddr, service: S) -> JoinHandle<Result<(), BoxError>>
where
    S: Clone + Send + Sync + 'static,
    S: Service<
        Request<BoxBody>,
        Response = Response<Full<Bytes>>,
        Error = std::convert::Infallible,
    >,
    S::Future: Send + 'static,
{
    // `serve` is only invoked once the spawned task is first polled.
    let server_task = async move { serve(addr, service).await };
    tokio::spawn(server_task)
}
async fn wait_for_server(addr: SocketAddr) {
for _ in 0..10 {
if tokio::net::TcpStream::connect(addr).await.is_ok() {
return;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
panic!("server did not start on {addr}");
}
/// Opens a fresh HTTP/1 client connection to `addr` and sends `request` over it.
///
/// The connection future is driven on its own spawned task. Errors from that
/// task are logged when the `tracing` feature is enabled and otherwise discarded.
///
/// # Panics
///
/// Panics if the TCP connect, the HTTP/1 handshake, or the request itself fails.
async fn send_request(addr: SocketAddr, request: Request<Full<Bytes>>) -> Response<Incoming> {
    let tcp = TcpStream::connect(addr)
        .await
        .expect("failed to connect to server");
    let (mut sender, connection) = http1::handshake(TokioIo::new(tcp))
        .await
        .expect("client handshake failed");

    // hyper's client requires the connection future to be polled for requests to make progress.
    tokio::spawn(async move {
        match connection.await {
            Ok(()) => {}
            Err(err) => {
                #[cfg(feature = "tracing")]
                tracing::error!(?err, "client connection error");
                #[cfg(not(feature = "tracing"))]
                let _ = err;
            }
        }
    });

    let pending = sender.send_request(request);
    pending.await.expect("request should succeed")
}
/// Sends `request` to `addr` via `send_request`, giving up after two seconds.
///
/// Only the wait for the response head is bounded; reading the body is not.
///
/// # Panics
///
/// Panics with "client request timed out" if the deadline elapses, and
/// otherwise under the same conditions as `send_request`.
async fn send_request_with_timeout(
    addr: SocketAddr,
    request: Request<Full<Bytes>>,
) -> Response<Incoming> {
    let budget = Duration::from_secs(2);
    let in_flight = send_request(addr, request);
    timeout(budget, in_flight)
        .await
        .expect("client request timed out")
}
/// Reads the full body of `response` into a single contiguous `Bytes` buffer.
///
/// # Panics
///
/// Panics if reading any body frame fails.
async fn response_bytes(response: Response<Incoming>) -> Bytes {
    let body = response.into_body();
    let collected = body
        .collect()
        .await
        .expect("failed to read response body");
    collected.to_bytes()
}
async fn response_json(response: Response<Incoming>) -> serde_json::Value {
let bytes = response_bytes(response).await;
serde_json::from_slice(&bytes).expect("invalid JSON body")
}
/// Request body expected by the `/json` route in
/// `test_connection_handler_body_conversion`, e.g. `{"name":"Alice"}`.
#[derive(Deserialize)]
struct EchoPayload {
    /// Value the handler echoes back as the response `data`.
    name: String,
}
/// POSTs a JSON payload and checks that the handler could deserialize it with
/// `parse_json_body` and echo the `name` field back as `data`.
#[tokio::test]
async fn test_connection_handler_body_conversion() {
    let router = Router::new(()).route(
        "/json",
        Method::POST,
        Arc::new(|req, _| {
            Box::pin(async move {
                let payload = parse_json_body::<EchoPayload>(req).await?;
                Ok(success(StatusCode::OK, payload.name))
            })
        }),
    );

    let addr = available_addr();
    let body = Full::new(Bytes::from_static(br#"{"name":"Alice"}"#));
    let request = Request::builder()
        .method(Method::POST)
        .uri(format!("http://{addr}/json"))
        .header("content-type", "application/json")
        .body(body)
        .unwrap();

    let server = spawn_server(addr, router).await;
    wait_for_server(addr).await;

    let response = send_request_with_timeout(addr, request).await;
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = response_json(response).await;
    assert_eq!(echoed["data"], "Alice");

    server.abort();
}
/// Calls a stateful route once and checks that the handler mutated the
/// `TestState` the test still holds a clone of.
#[tokio::test]
async fn test_connection_handler_service_call() {
    let state = TestState::new();
    let router = Router::new(state.clone()).route(
        "/count",
        Method::GET,
        Arc::new(|_, state| {
            Box::pin(async move {
                state.increment();
                Ok(success(StatusCode::OK, state.get()))
            })
        }),
    );

    let addr = available_addr();
    let request = Request::builder()
        .method(Method::GET)
        .uri(format!("http://{addr}/count"))
        .body(Full::new(Bytes::new()))
        .unwrap();

    let server = spawn_server(addr, router).await;
    wait_for_server(addr).await;

    let response = send_request_with_timeout(addr, request).await;
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(state.get(), 1);

    server.abort();
}
/// Checks that a request header sent by the client is visible to the handler.
///
/// Despite the name, what is exercised here is header propagation: the
/// handler reads `x-custom` from the request and echoes it back as `data`.
#[tokio::test]
async fn test_connection_handler_extensions_preserved() {
    let router = Router::new(()).route(
        "/headers",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                let value = req
                    .headers()
                    .get("x-custom")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("")
                    .to_string();
                Ok(success(StatusCode::OK, value))
            })
        }),
    );
    let addr = available_addr();
    let handle = spawn_server(addr, router).await;
    wait_for_server(addr).await;
    let uri = format!("http://{}:{}/headers", addr.ip(), addr.port());
    let request = Request::builder()
        .method(Method::GET)
        .uri(uri)
        .header("x-custom", "value")
        .body(Full::new(Bytes::new()))
        .unwrap();
    let response = send_request_with_timeout(addr, request).await;
    // Check the status first so an error response fails here with a clear
    // message instead of as a confusing mismatch on `data`.
    assert_eq!(response.status(), StatusCode::OK);
    let json = response_json(response).await;
    assert_eq!(json["data"], "value");
    handle.abort();
}
/// Sends two requests, each on its own client connection, and checks that
/// both reached handlers that mutate the same shared `TestState`.
#[tokio::test]
async fn test_connection_handler_clone() {
    let state = TestState::new();
    let router = Router::new(state.clone()).route(
        "/clone",
        Method::GET,
        Arc::new(|_, state| {
            Box::pin(async move {
                state.increment();
                Ok(success(StatusCode::OK, state.get()))
            })
        }),
    );

    let addr = available_addr();
    let server = spawn_server(addr, router).await;
    wait_for_server(addr).await;

    // The target URI does not change between iterations, so build it once.
    let target = format!("http://{addr}/clone");
    for _ in 0..2 {
        let request = Request::builder()
            .method(Method::GET)
            .uri(&target)
            .body(Full::new(Bytes::new()))
            .unwrap();
        let response = send_request_with_timeout(addr, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }
    assert_eq!(state.get(), 2);

    server.abort();
}
/// Starts the server on a fresh port and checks that `/health` answers 200.
#[tokio::test]
async fn test_serve_binds_to_port() {
    let router = Router::new(()).route(
        "/health",
        Method::GET,
        Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "ok")) })),
    );

    let addr = available_addr();
    let request = Request::builder()
        .method(Method::GET)
        .uri(format!("http://{addr}/health"))
        .body(Full::new(Bytes::new()))
        .unwrap();

    let server = spawn_server(addr, router).await;
    wait_for_server(addr).await;

    let status = send_request_with_timeout(addr, request).await.status();
    assert_eq!(status, StatusCode::OK);

    server.abort();
}
/// Checks that the server accepts a client connection and returns the
/// handler's payload (`"pong"`) as `data` with a 200 status.
#[tokio::test]
async fn test_serve_accepts_connections() {
    let router = Router::new(()).route(
        "/ping",
        Method::GET,
        Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "pong")) })),
    );
    let addr = available_addr();
    let handle = spawn_server(addr, router).await;
    wait_for_server(addr).await;
    let uri = format!("http://{}:{}/ping", addr.ip(), addr.port());
    let request = Request::builder()
        .method(Method::GET)
        .uri(uri)
        .body(Full::new(Bytes::new()))
        .unwrap();
    let response = send_request_with_timeout(addr, request).await;
    // Match the sibling tests: a non-200 response should fail on the status,
    // not surface as a confusing mismatch in the body.
    assert_eq!(response.status(), StatusCode::OK);
    let json = response_json(response).await;
    assert_eq!(json["data"], "pong");
    handle.abort();
}
/// Registers two routes and checks that each path dispatches to its own handler.
// NOTE(review): this test is `#[ignore]`d without a recorded reason; confirm
// why before relying on it (or un-ignoring it).
#[tokio::test]
#[ignore]
async fn test_serve_with_router() {
    let router = Router::new(())
        .route(
            "/one",
            Method::GET,
            Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "one")) })),
        )
        .route(
            "/two",
            Method::GET,
            Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "two")) })),
        );

    let addr = available_addr();
    let server = spawn_server(addr, router).await;
    wait_for_server(addr).await;

    let cases = [("/one", "one"), ("/two", "two")];
    for (route, want) in cases {
        let request = Request::builder()
            .method(Method::GET)
            .uri(format!("http://{addr}{route}"))
            .body(Full::new(Bytes::new()))
            .unwrap();
        let body = response_json(send_request_with_timeout(addr, request).await).await;
        assert_eq!(body["data"], want);
    }

    server.abort();
}
#[tokio::test]
#[ignore]
async fn test_serve_with_middleware() {
let router = Router::new(()).route(
"/middleware",
Method::GET,
Arc::new(|req, _| {
Box::pin(async move {
let request_id = req
.extensions()
.get::<tower_http::request_id::RequestId>()
.and_then(|id| id.header_value().to_str().ok())
.unwrap_or("")
.to_string();
Ok(success(StatusCode::OK, json!({ "request_id": request_id })))
})
}),
);
let service = ServiceBuilder::new()
.layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
.layer(PropagateRequestIdLayer::x_request_id())
.service(router);
let addr = available_addr();
let handle = spawn_server(addr, service).await;
wait_for_server(addr).await;
let uri = format!("http://{}:{}/middleware", addr.ip(), addr.port());
let request = Request::builder()
.method(Method::GET)
.uri(uri)
.body(Full::new(Bytes::new()))
.unwrap();
let response = send_request_with_timeout(addr, request).await;
let header = response
.headers()
.get("x-request-id")
.cloned()
.unwrap_or_else(|| HeaderValue::from_static(""));
let bytes = response_bytes(response).await;
let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(
json["data"]["request_id"].as_str().unwrap(),
header.to_str().unwrap()
);
handle.abort();
}