use crate::App;
use std::{
fmt::{Debug, Formatter},
net::TcpListener,
};
use tokio::sync::oneshot;
#[cfg(feature = "ws")]
use {
super::ws::TestWebSocket,
crate::headers::SEC_WEBSOCKET_PROTOCOL,
crate::http::{HttpBody, Method, Request, Uri},
hyper_util::rt::{TokioExecutor, TokioIo},
tokio::net::TcpStream,
tokio_tungstenite::{
WebSocketStream,
tungstenite::{ClientRequestBuilder, protocol},
},
};
/// One-shot transformation applied to the freshly created `App` (set via `TestServerBuilder::configure`).
type AppSetupFn = Box<dyn FnOnce(App) -> App + Send + 'static>;
/// One-shot mutation of the `App`, e.g. registering routes (queued via `TestServerBuilder::setup`).
type ServerSetupFn = Box<dyn FnOnce(&mut App) + Send + 'static>;
/// Builder for [`TestServer`]: collects the app configuration and route setup
/// closures that are applied to the `App` inside the spawned server task.
pub struct TestServerBuilder {
    /// Optional transformation applied to the freshly created `App`.
    app_config: Option<AppSetupFn>,
    /// Setup closures (typically route registration), applied in insertion order.
    routes: Vec<ServerSetupFn>,
    /// Whether the built server reports `https://` URLs.
    is_https: bool,
}
impl Debug for TestServerBuilder {
    /// Formats the builder's plain settings.
    ///
    /// The setup closures are opaque, so only whether a configuration closure is set
    /// and how many route setup closures are queued are reported.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TestServerBuilder")
            .field("is_https", &self.is_https)
            .field("has_app_config", &self.app_config.is_some())
            .field("route_count", &self.routes.len())
            .finish_non_exhaustive()
    }
}
impl Default for TestServerBuilder {
fn default() -> Self {
Self::new()
}
}
/// A server running in a background tokio task for tests, told to bind on
/// `127.0.0.1:port`.
///
/// Dropping it sends the shutdown signal to the server task; `TestServer::shutdown`
/// sends the signal and additionally waits (with a timeout) for the task to finish.
#[derive(Debug)]
pub struct TestServer {
    /// Local TCP port the server was configured to bind on.
    pub port: u16,
    /// Whether `url` builds `https://` (instead of `http://`) URLs.
    is_https: bool,
    /// Sending half of the shutdown signal; `None` once it has been taken and sent.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Handle of the spawned server task; `None` after `shutdown` has taken it.
    server_handle: Option<tokio::task::JoinHandle<()>>,
}
impl TestServerBuilder {
    /// Creates a builder for a plain-HTTP server with no app configuration and no
    /// route setup closures.
    pub fn new() -> Self {
        Self {
            app_config: None,
            routes: Vec::new(),
            is_https: false,
        }
    }

    /// Sets a closure that transforms the freshly created [`App`] before any
    /// [`setup`](Self::setup) closures run.
    ///
    /// Calling this again replaces the previously configured closure.
    pub fn configure<F>(mut self, config: F) -> Self
    where
        F: FnOnce(App) -> App + Send + 'static,
    {
        self.app_config = Some(Box::new(config));
        self
    }

    /// Queues a closure that mutates the [`App`] (typically to register routes).
    ///
    /// Closures run in the order they were added, after the
    /// [`configure`](Self::configure) closure.
    pub fn setup<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&mut App) + Send + 'static,
    {
        self.routes.push(Box::new(f));
        self
    }

    /// Marks the server as HTTPS so that [`TestServer::url`] produces `https://` URLs.
    ///
    /// NOTE(review): this only affects the URL scheme reported by the test server;
    /// TLS itself presumably has to be enabled on the `App` via
    /// [`configure`](Self::configure) — confirm.
    pub fn with_https(mut self) -> Self {
        self.is_https = true;
        self
    }

    /// Spawns the server on a free `127.0.0.1` port and returns once it accepts TCP
    /// connections.
    ///
    /// Inside the spawned task the app is created bound to `127.0.0.1:<port>` with
    /// no-delay enabled and the greeter disabled, then the `configure` and `setup`
    /// closures are applied. The task ends when `App::run` completes or the shutdown
    /// channel held by the returned [`TestServer`] fires.
    ///
    /// # Panics
    /// Panics if the server task exits before accepting connections (for example
    /// because a setup closure panicked or the port was taken before the bind), or if
    /// it is not accepting connections within 5 seconds.
    pub async fn build(self) -> TestServer {
        let port = TestServer::get_free_port();
        let (tx, rx) = oneshot::channel();
        let (ready_tx, ready_rx) = oneshot::channel();
        let app_config = self.app_config;
        let routes = self.routes;
        let server_handle = tokio::spawn(async move {
            let mut app = App::new()
                .bind(format!("127.0.0.1:{}", port))
                .with_no_delay()
                .without_greeter();
            if let Some(config) = app_config {
                app = config(app);
            }
            for route in routes {
                route(&mut app);
            }
            let _ = ready_tx.send(());
            // Stop on whichever comes first: the server exiting on its own, or the
            // shutdown signal (which drops the `run` future).
            tokio::select! {
                _ = app.run() => {},
                _ = rx => {}
            }
        });
        // Resolves once the setup closures have run (or with an error if the task died
        // first; `wait_until_listening` reports that case).
        let _ = ready_rx.await;
        // `ready_tx` fires before `app.run()` has bound the socket, so without this the
        // first request of a test could race the bind and be refused.
        Self::wait_until_listening(port, &server_handle).await;
        TestServer {
            port,
            is_https: self.is_https,
            shutdown_tx: Some(tx),
            server_handle: Some(server_handle),
        }
    }

    /// Polls `127.0.0.1:<port>` until a TCP connection succeeds.
    ///
    /// # Panics
    /// Panics if `server_handle` finishes before the port accepts connections, or if
    /// the port is still not accepting connections after 5 seconds.
    async fn wait_until_listening(port: u16, server_handle: &tokio::task::JoinHandle<()>) {
        const STARTUP_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);
        const POLL_INTERVAL: tokio::time::Duration = tokio::time::Duration::from_millis(5);

        let deadline = tokio::time::Instant::now() + STARTUP_TIMEOUT;
        loop {
            // The probe connection is closed right away; only reachability matters.
            if tokio::net::TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
                return;
            }
            assert!(
                !server_handle.is_finished(),
                "test server task exited before listening on 127.0.0.1:{}",
                port
            );
            assert!(
                tokio::time::Instant::now() < deadline,
                "test server did not start listening on 127.0.0.1:{} within {:?}",
                port,
                STARTUP_TIMEOUT
            );
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    }
}
impl TestServer {
    /// Returns a [`TestServerBuilder`] for configuring a server before starting it.
    #[inline]
    pub fn builder() -> TestServerBuilder {
        TestServerBuilder::new()
    }

    /// Starts a server with default settings whose routes are registered by `setup`.
    ///
    /// Shorthand for `TestServer::builder().setup(setup).build().await`.
    #[inline]
    pub async fn spawn<F>(setup: F) -> Self
    where
        F: FnOnce(&mut App) + Send + 'static,
    {
        TestServerBuilder::new().setup(setup).build().await
    }

    /// Builds an absolute URL for `path` on this server.
    ///
    /// `path` is appended verbatim after the port, so it should start with `/`
    /// (e.g. `"/health"`).
    pub fn url(&self, path: &str) -> String {
        let protocol = if self.is_https { "https" } else { "http" };
        format!("{protocol}://127.0.0.1:{}{path}", self.port)
    }

    /// Returns a `reqwest` client builder whose HTTP version matches the enabled
    /// features: HTTP/1 only when `http1` is enabled without `http2`, otherwise
    /// HTTP/2 with prior knowledge.
    pub fn client_builder(&self) -> reqwest::ClientBuilder {
        if cfg!(all(feature = "http1", not(feature = "http2"))) {
            reqwest::Client::builder().http1_only()
        } else {
            reqwest::Client::builder().http2_prior_knowledge()
        }
    }

    /// Builds a `reqwest` client from [`client_builder`](Self::client_builder).
    ///
    /// # Panics
    /// Panics if `reqwest` fails to build the client.
    pub fn client(&self) -> reqwest::Client {
        self.client_builder()
            .build()
            .expect("failed to build reqwest client for test server")
    }

    /// Opens a WebSocket connection to `path` without offering any subprotocol.
    ///
    /// # Panics
    /// Panics if the connection or the WebSocket handshake fails.
    #[cfg(feature = "ws")]
    pub async fn ws(&self, path: &str) -> TestWebSocket {
        self.ws_with_protocols::<0>(path, []).await
    }

    /// Opens a WebSocket connection to `path`, offering `known_protocols` in the
    /// `Sec-WebSocket-Protocol` header (the header is omitted when the list is empty).
    ///
    /// Uses an HTTP/1.1 upgrade when `http1` is enabled without `http2`, otherwise an
    /// HTTP/2 extended CONNECT request.
    ///
    /// # Panics
    /// Panics if the connection or the WebSocket handshake fails.
    #[cfg(feature = "ws")]
    pub async fn ws_with_protocols<const N: usize>(
        &self,
        path: &str,
        known_protocols: [&'static str; N],
    ) -> TestWebSocket {
        if cfg!(all(feature = "http1", not(feature = "http2"))) {
            self.ws_http1(path, known_protocols).await
        } else {
            self.ws_http2(path, known_protocols).await
        }
    }

    /// Sends the shutdown signal and waits up to 5 seconds for the server task to
    /// finish.
    ///
    /// If the task has not finished by then, it is aborted rather than left running
    /// detached in the background.
    pub async fn shutdown(mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        if let Some(mut handle) = self.server_handle.take() {
            // Await by reference so the handle is still available for `abort` on timeout.
            let finished =
                tokio::time::timeout(tokio::time::Duration::from_secs(5), &mut handle).await;
            if finished.is_err() {
                handle.abort();
            }
        }
    }

    /// Asks the OS for a currently unused TCP port on `127.0.0.1`.
    ///
    /// The probe listener is closed before returning, so another process could take
    /// the port before the server binds it.
    ///
    /// # Panics
    /// Panics if the probe listener cannot be bound or its address cannot be read.
    #[inline]
    pub fn get_free_port() -> u16 {
        TcpListener::bind("127.0.0.1:0")
            .expect("failed to bind a probe listener on 127.0.0.1:0")
            .local_addr()
            .expect("failed to read the probe listener's local address")
            .port()
    }

    /// Performs an HTTP/1.1 WebSocket upgrade to `path`.
    #[cfg(feature = "ws")]
    async fn ws_http1<const N: usize>(
        &self,
        path: &str,
        known_protocols: [&'static str; N],
    ) -> TestWebSocket {
        let uri = format!("ws://127.0.0.1:{}{}", self.port, path);
        let mut request =
            ClientRequestBuilder::new(Uri::try_from(uri).expect("invalid WebSocket URI"));
        // An empty `Sec-WebSocket-Protocol` value is not a valid protocol list (it would
        // offer the bogus subprotocol `""`), so only send the header when there is
        // something to offer.
        if !known_protocols.is_empty() {
            request = request.with_header(
                SEC_WEBSOCKET_PROTOCOL.to_string(),
                known_protocols.join(","),
            );
        }
        let (ws, _) = tokio_tungstenite::connect_async(request)
            .await
            .expect("WebSocket handshake failed");
        TestWebSocket::from_http1(ws)
    }

    /// Opens a WebSocket over an HTTP/2 extended CONNECT request to `path`.
    #[cfg(feature = "ws")]
    async fn ws_http2<const N: usize>(
        &self,
        path: &str,
        known_protocols: [&'static str; N],
    ) -> TestWebSocket {
        let stream = TcpStream::connect(format!("127.0.0.1:{}", self.port))
            .await
            .expect("Failed to connect to test server");
        let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
            .handshake(TokioIo::new(stream))
            .await
            .expect("HTTP/2 handshake failed");
        // The connection future drives the HTTP/2 session and must be polled for
        // `sender` to make progress.
        tokio::spawn(async move {
            let _ = conn.await;
        });

        let mut builder = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri(path);
        // Same rule as the HTTP/1 path: never send an empty protocol list.
        if !known_protocols.is_empty() {
            builder = builder.header(SEC_WEBSOCKET_PROTOCOL, known_protocols.join(","));
        }
        let request = builder
            .body(HttpBody::empty())
            .expect("failed to build WebSocket CONNECT request");

        let mut response = sender
            .send_request(request)
            .await
            .expect("failed to send WebSocket CONNECT request");
        // Only a 2xx response to CONNECT establishes the tunnel; fail with a clear message
        // instead of an opaque upgrade error.
        assert!(
            response.status().is_success(),
            "WebSocket CONNECT to {} was rejected with status {}",
            path,
            response.status()
        );
        let upgraded = hyper::upgrade::on(&mut response)
            .await
            .expect("WebSocket upgrade failed");
        let ws =
            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), protocol::Role::Client, None)
                .await;
        TestWebSocket::from_http2(ws)
    }
}
impl Drop for TestServer {
    /// Sends the shutdown signal to the server task if it has not been sent yet
    /// (i.e. when the server is dropped without calling `shutdown`).
    fn drop(&mut self) {
        if let Some(stop_signal) = self.shutdown_tx.take() {
            // Best effort: the receiver is gone if the server task already exited.
            let _: Result<(), ()> = stop_signal.send(());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn it_starts_server_and_shuts_down() {
        TestServer::builder().build().await.shutdown().await;
    }

    #[tokio::test]
    async fn it_binds_server_to_free_port() {
        let server = TestServer::builder().build().await;
        let target = server.url("/");
        let response = server.client().get(target).send().await.unwrap();
        // No routes were registered, so any path is expected to be unknown.
        assert_eq!(response.status(), 404);
    }

    #[tokio::test]
    async fn it_drops_server_gracefully() {
        let server = TestServer::builder().build().await;
        drop(server);
    }
}