use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use bytes::Bytes;
use http::header::HOST;
use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
use http_body_util::{BodyExt, Full};
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use super::cookie::CookieJar;
use super::recorder::LogRecorder;
use super::request::{PendingBody, TestRequestBuilder};
use super::response::TestResponse;
use super::websocket::TestWebSocketBuilder;
use super::TestOverrides;
use crate::app::{App, AppInner, TestApp};
use crate::body::{box_body, BoxError, ReqBody};
use crate::error::{Error, Result};
use crate::state::StateMap;
/// Type-erased, pinned response body handed to streaming consumers (e.g.
/// `super::sse::TestSseStream`); data frames are `Bytes` and errors are boxed.
pub(crate) type StreamingBody =
Pin<Box<dyn http_body::Body<Data = Bytes, Error = BoxError> + Send>>;
/// Deferred registration of a test resource into the app's `StateMap`; these run inside
/// `TestClientBuilder::build` while the test app is being built.
type ResourceRegister = Box<dyn FnOnce(&mut StateMap)>;
/// One per-request header plus whether the caller explicitly opted in to sending it even if
/// it is security-sensitive (checked by `Shared::reject_in_process_sensitive_headers`).
#[derive(Clone)]
pub(crate) struct TestHeader {
    pub(crate) name: HeaderName,
    pub(crate) value: HeaderValue,
    pub(crate) unsafe_allowed: bool,
}

impl TestHeader {
    /// Creates a header that remains subject to the in-process sensitive-header check.
    pub(crate) fn safe(name: HeaderName, value: HeaderValue) -> Self {
        Self::with_opt_in(name, value, false)
    }

    /// Creates a header that bypasses the in-process sensitive-header check.
    pub(crate) fn unsafe_allowed(name: HeaderName, value: HeaderValue) -> Self {
        Self::with_opt_in(name, value, true)
    }

    /// Shared constructor behind `safe` and `unsafe_allowed`.
    fn with_opt_in(name: HeaderName, value: HeaderValue, unsafe_allowed: bool) -> Self {
        Self {
            name,
            value,
            unsafe_allowed,
        }
    }
}
/// Header names that in-process clients refuse unless the caller opts in through
/// `unsafe_header`/`unsafe_default_header` (see `sensitive_header_error`).
///
/// Matching is an exact string comparison with `HeaderName::as_str()`, which the `http`
/// crate normalizes to lowercase, so entries must stay lowercase.
/// NOTE(review): the rationale is presumably that apps derive host/client/scheme information
/// from these headers, so faking them in-process could mask bugs — confirm.
const SENSITIVE_TEST_HEADERS: [&str; 5] = [
"host",
"forwarded",
"x-forwarded-for",
"x-forwarded-host",
"x-forwarded-proto",
];
pub(crate) enum Transport {
InProcess(Arc<AppInner>),
RealPort(SocketAddr),
}
impl Transport {
pub(crate) fn address(&self) -> Option<SocketAddr> {
match self {
Transport::InProcess(_) => None,
Transport::RealPort(addr) => Some(*addr),
}
}
pub(crate) async fn execute(
&self,
request: http::Request<ReqBody>,
) -> Result<(StatusCode, HeaderMap, Bytes)> {
match self {
Transport::InProcess(app) => {
let response = app.clone().handle(request).await;
let (parts, body) = response.into_parts();
let bytes = collect_body(body).await?;
Ok((parts.status, parts.headers, bytes))
}
Transport::RealPort(addr) => {
let response = send_over_socket(*addr, request).await?;
let (parts, body) = response.into_parts();
let bytes = collect_body(body).await?;
Ok((parts.status, parts.headers, bytes))
}
}
}
pub(crate) async fn execute_streaming(
&self,
request: http::Request<ReqBody>,
) -> Result<(StatusCode, HeaderMap, StreamingBody)> {
match self {
Transport::InProcess(app) => {
let response = app.clone().handle(request).await;
let (parts, body) = response.into_parts();
Ok((parts.status, parts.headers, Box::pin(body)))
}
Transport::RealPort(addr) => {
let response = send_over_socket(*addr, request).await?;
let (parts, body) = response.into_parts();
let boxed: StreamingBody =
Box::pin(body.map_err(|error| Box::new(error) as BoxError));
Ok((parts.status, parts.headers, boxed))
}
}
}
}
async fn collect_body<B>(body: B) -> Result<Bytes>
where
B: http_body::Body<Data = Bytes>,
B::Error: std::fmt::Display,
{
let collected = body
.collect()
.await
.map_err(|error| Error::internal(format!("failed to read response body: {error}")))?;
Ok(collected.to_bytes())
}
async fn send_over_socket(
addr: SocketAddr,
mut request: http::Request<ReqBody>,
) -> Result<http::Response<hyper::body::Incoming>> {
if !request.headers().contains_key(HOST) {
if let Ok(value) = HeaderValue::from_str(&addr.to_string()) {
request.headers_mut().insert(HOST, value);
}
}
let stream = TcpStream::connect(addr)
.await
.map_err(|error| Error::internal(format!("failed to connect to {addr}: {error}")))?;
let io = TokioIo::new(stream);
let (mut sender, connection) = hyper::client::conn::http1::handshake(io)
.await
.map_err(|error| Error::internal(format!("client handshake failed: {error}")))?;
tokio::spawn(async move {
let _ = connection.await;
});
sender
.send_request(request)
.await
.map_err(|error| Error::internal(format!("request failed: {error}")))
}
/// State shared (behind an `Arc`) between a `TestClient` and the request/WebSocket builders
/// it hands out.
pub(crate) struct Shared {
/// How requests reach the app: a direct in-process call or a real TCP socket.
pub(crate) transport: Transport,
/// Headers added to every request. `TestClientBuilder::default_header` diverts
/// security-sensitive names away from this map.
pub(crate) default_headers: HeaderMap,
/// Explicitly opted-in headers. They are applied after `default_headers`, and their
/// names are not checked against the sensitive-header list.
pub(crate) unsafe_default_headers: HeaderMap,
/// Cookie jar applied to outgoing requests and updated from response headers.
pub(crate) cookies: Mutex<CookieJar>,
}
impl Shared {
pub(crate) async fn send(
&self,
method: Method,
path: String,
query: Vec<(String, String)>,
headers: Vec<TestHeader>,
body: PendingBody,
) -> Result<TestResponse> {
let request = self.build_request(method, &path, &query, headers, body)?;
let (status, headers, bytes) = self.transport.execute(request).await?;
self.cookies
.lock()
.expect("cookie jar mutex poisoned")
.store(&headers);
Ok(TestResponse {
status,
headers,
body: bytes,
})
}
pub(crate) async fn open_sse(
&self,
method: Method,
path: String,
query: Vec<(String, String)>,
headers: Vec<TestHeader>,
) -> Result<super::sse::TestSseStream> {
let request = self.build_request(method, &path, &query, headers, PendingBody::default())?;
let (_status, headers, body) = self.transport.execute_streaming(request).await?;
self.cookies
.lock()
.expect("cookie jar mutex poisoned")
.store(&headers);
Ok(super::sse::TestSseStream::new(body))
}
pub(crate) fn build_request(
&self,
method: Method,
path: &str,
query: &[(String, String)],
headers: Vec<TestHeader>,
body: PendingBody,
) -> Result<http::Request<ReqBody>> {
let uri = if query.is_empty() {
path.to_owned()
} else {
let encoded = serde_urlencoded::to_string(query)
.map_err(|_| Error::internal("failed to encode query parameters"))?;
format!("{path}?{encoded}")
};
let mut request = http::Request::new(box_body(Full::new(body.bytes)));
*request.method_mut() = method;
*request.uri_mut() = uri
.parse()
.map_err(|_| Error::bad_request(format!("invalid request URI: {uri}")))?;
self.reject_in_process_sensitive_headers(&headers)?;
let map = request.headers_mut();
for (name, value) in self.default_headers.iter() {
map.insert(name, value.clone());
}
for (name, value) in self.unsafe_default_headers.iter() {
map.insert(name, value.clone());
}
for header in headers {
map.insert(header.name, header.value);
}
self.cookies
.lock()
.expect("cookie jar mutex poisoned")
.apply(map);
if let Some(content_type) = body.content_type {
map.insert(super::request::CONTENT_TYPE_HEADER, content_type);
}
Ok(request)
}
pub(crate) fn reject_in_process_sensitive_headers(&self, headers: &[TestHeader]) -> Result<()> {
if !matches!(self.transport, Transport::InProcess(_)) {
return Ok(());
}
let mut blocked = Vec::new();
for header in headers {
if !header.unsafe_allowed && is_sensitive_test_header(&header.name) {
blocked.push(header.name.as_str().to_owned());
}
}
if blocked.is_empty() {
Ok(())
} else {
Err(sensitive_header_error(&blocked))
}
}
}
/// Reports whether `name` is listed in `SENSITIVE_TEST_HEADERS`.
fn is_sensitive_test_header(name: &HeaderName) -> bool {
    let wanted = name.as_str();
    SENSITIVE_TEST_HEADERS.contains(&wanted)
}
fn sensitive_header_error(headers: &[String]) -> Error {
Error::bad_request(format!(
"in-process test clients reject security-sensitive header(s): {}; use unsafe_header/unsafe_default_header or TestClient::serve(...).bind_random_port()",
headers.join(", ")
))
.with_code("TEST_UNSAFE_HEADER_REQUIRES_OPT_IN")
}
/// HTTP test client for an app. It either calls the app in-process or talks to it over a real
/// TCP port (see `TestClient::serve`).
pub struct TestClient {
/// Transport, default headers, and cookie jar shared with the builders this client hands out.
shared: Arc<Shared>,
/// What `TestClient::shutdown` must release.
teardown: Teardown,
/// When set, keeps the log recorder installed as the thread's default `tracing`
/// subscriber until the client is dropped.
_log_guard: Option<tracing::subscriber::DefaultGuard>,
}
/// Resources a `TestClient` releases on shutdown, depending on its transport.
enum Teardown {
/// The built in-process app; shut down via its own `shutdown`.
InProcess(Box<TestApp>),
/// A background server started by `ServeBuilder::bind_random_port`.
RealPort {
/// Sending on (or dropping) this completes the shutdown future given to
/// `serve_with_shutdown`. This file always sets it to `Some`.
shutdown: Option<oneshot::Sender<()>>,
/// The spawned task running the server.
handle: JoinHandle<()>,
},
}
impl TestClient {
pub async fn new(app: TestApp) -> Result<Self> {
Ok(Self {
shared: Arc::new(Shared {
transport: Transport::InProcess(app.inner.clone()),
default_headers: HeaderMap::new(),
unsafe_default_headers: HeaderMap::new(),
cookies: Mutex::new(CookieJar::default()),
}),
teardown: Teardown::InProcess(Box::new(app)),
_log_guard: None,
})
}
pub fn builder(app: App) -> TestClientBuilder {
TestClientBuilder::new(app)
}
pub fn serve(app: App) -> ServeBuilder {
ServeBuilder { app }
}
pub fn local_addr(&self) -> Option<SocketAddr> {
self.shared.transport.address()
}
pub fn websocket(&self, path: &str) -> TestWebSocketBuilder {
TestWebSocketBuilder::new(self.shared.clone(), path)
}
pub fn get(&self, path: &str) -> TestRequestBuilder {
TestRequestBuilder::new(self.shared.clone(), Method::GET, path)
}
pub fn post(&self, path: &str) -> TestRequestBuilder {
TestRequestBuilder::new(self.shared.clone(), Method::POST, path)
}
pub fn put(&self, path: &str) -> TestRequestBuilder {
TestRequestBuilder::new(self.shared.clone(), Method::PUT, path)
}
pub fn patch(&self, path: &str) -> TestRequestBuilder {
TestRequestBuilder::new(self.shared.clone(), Method::PATCH, path)
}
pub fn delete(&self, path: &str) -> TestRequestBuilder {
TestRequestBuilder::new(self.shared.clone(), Method::DELETE, path)
}
pub async fn shutdown(self) -> Result<()> {
match self.teardown {
Teardown::InProcess(app) => app.shutdown().await,
Teardown::RealPort { shutdown, handle } => {
if let Some(sender) = shutdown {
let _ = sender.send(());
}
let _ = handle.await;
Ok(())
}
}
}
}
/// Builder returned by [`TestClient::serve`] that runs the app on a real TCP socket.
pub struct ServeBuilder {
    app: App,
}

impl ServeBuilder {
    /// Serves the app on `127.0.0.1` with an OS-assigned port, and returns a client that talks
    /// to it over real HTTP/1 connections.
    ///
    /// Resolves once the app's `on_ready` hook reports the bound address.
    ///
    /// # Errors
    ///
    /// - the server's own error, if it fails before becoming ready;
    /// - an internal error, if the server task ends without reporting an address.
    pub async fn bind_random_port(self) -> Result<TestClient> {
        let (addr_tx, addr_rx) = oneshot::channel::<Result<SocketAddr>>();
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        // The address sender is shared by the `on_ready` hook (success path) and the server
        // task (failure path). Whichever runs first takes it, so at most one message is sent.
        let sender = Arc::new(Mutex::new(Some(addr_tx)));
        let ready_sender = Arc::clone(&sender);
        let app = self.app.on_ready(move |ctx| {
            let sender = Arc::clone(&ready_sender);
            async move {
                if let Some(tx) = sender.lock().expect("address sender mutex poisoned").take() {
                    let _ = tx.send(Ok(ctx.addr()));
                }
                Ok(())
            }
        });
        let handle = tokio::spawn(async move {
            let result = app
                .serve_with_shutdown("127.0.0.1:0", async move {
                    // Completes on an explicit signal or when the sender is dropped.
                    let _ = shutdown_rx.await;
                })
                .await;
            // Errors after the address was reported have no receiver and are dropped here.
            if let (Err(error), Some(tx)) = (
                result,
                sender.lock().expect("address sender mutex poisoned").take(),
            ) {
                let _ = tx.send(Err(error));
            }
        });
        let addr = addr_rx
            .await
            .map_err(|_| Error::internal("the test server failed to start"))??;
        Ok(TestClient {
            shared: Arc::new(Shared {
                transport: Transport::RealPort(addr),
                default_headers: HeaderMap::new(),
                unsafe_default_headers: HeaderMap::new(),
                cookies: Mutex::new(CookieJar::default()),
            }),
            teardown: Teardown::RealPort {
                shutdown: Some(shutdown_tx),
                handle,
            },
            _log_guard: None,
        })
    }
}
pub struct TestClientBuilder {
app: App,
resources: Vec<ResourceRegister>,
overrides: TestOverrides,
default_headers: HeaderMap,
unsafe_default_headers: HeaderMap,
blocked_sensitive_headers: Vec<String>,
cookies: CookieJar,
recorder: Option<LogRecorder>,
}
impl TestClientBuilder {
fn new(app: App) -> Self {
Self {
app,
resources: Vec::new(),
overrides: TestOverrides::default(),
default_headers: HeaderMap::new(),
unsafe_default_headers: HeaderMap::new(),
blocked_sensitive_headers: Vec::new(),
cookies: CookieJar::default(),
recorder: None,
}
}
pub fn logger(mut self, recorder: LogRecorder) -> Self {
self.recorder = Some(recorder);
self
}
pub fn resource<S: Send + Sync + 'static>(mut self, value: S) -> Self {
self.resources
.push(Box::new(move |state| state.insert(value)));
self
}
pub fn override_dependency<T: Clone + Send + Sync + 'static>(mut self, value: T) -> Self {
self.overrides.insert::<T, _>(move || value.clone());
self
}
pub fn override_dependency_with<T, F>(mut self, factory: F) -> Self
where
T: Send + 'static,
F: Fn() -> T + Send + Sync + 'static,
{
self.overrides.insert::<T, F>(factory);
self
}
pub fn default_header(mut self, name: &str, value: &str) -> Self {
if let (Ok(name), Ok(value)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
if is_sensitive_test_header(&name) {
self.blocked_sensitive_headers
.push(name.as_str().to_owned());
} else {
self.default_headers.insert(name, value);
}
}
self
}
pub fn unsafe_default_header(mut self, name: &str, value: &str) -> Self {
if let (Ok(name), Ok(value)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
self.unsafe_default_headers.insert(name, value);
}
self
}
pub fn cookie(mut self, name: &str, value: &str) -> Self {
self.cookies.set(name, value);
self
}
pub async fn build(self) -> Result<TestClient> {
let resources = self.resources;
let overrides = self.overrides;
let default_headers = self.default_headers;
let unsafe_default_headers = self.unsafe_default_headers;
let blocked_sensitive_headers = self.blocked_sensitive_headers;
let cookies = self.cookies;
let recorder = self.recorder;
if !blocked_sensitive_headers.is_empty() {
return Err(sensitive_header_error(&blocked_sensitive_headers));
}
let app = self
.app
.build_test_with(move |state| {
for register in resources {
register(state);
}
if !overrides.is_empty() {
state.insert(overrides);
}
})
.await?;
let log_guard = recorder.map(|recorder| {
use tracing_subscriber::layer::SubscriberExt;
let subscriber = tracing_subscriber::registry().with(recorder);
tracing::subscriber::set_default(subscriber)
});
Ok(TestClient {
shared: Arc::new(Shared {
transport: Transport::InProcess(app.inner.clone()),
default_headers,
unsafe_default_headers,
cookies: Mutex::new(cookies),
}),
teardown: Teardown::InProcess(Box::new(app)),
_log_guard: log_guard,
})
}
}
// Unit tests for request assembly (`Shared::build_request`) and the real-port transport.
#[cfg(test)]
mod tests {
use super::*;
use crate::app::App;
use crate::body::{BoxError, RespBody};
use crate::response::Response as TorkResponse;
use crate::router::{BoxFuture, HandlerFn, Route, Router};
use bytes::Bytes;
use futures_util::stream;
use http::header::{CONTENT_TYPE, COOKIE};
use http_body::Frame;
use http_body_util::{BodyExt, StreamBody};
use std::sync::Arc;
/// Handler that answers with status 200 and the JSON body `{"ok": true}`.
fn json_handler() -> HandlerFn {
Arc::new(|_ctx: crate::extract::RequestContext| -> BoxFuture<'static, crate::Result<TorkResponse>> {
Box::pin(async { Ok(crate::json_response(crate::StatusCode::OK, &serde_json::json!({ "ok": true }))) })
})
}
/// Handler that answers with status 200 and `content-type: text/event-stream`. Its body is
/// streamed as two data frames, `one` then `two`.
fn stream_handler() -> HandlerFn {
Arc::new(|_ctx: crate::extract::RequestContext| -> BoxFuture<'static, crate::Result<TorkResponse>> {
Box::pin(async {
let frames = stream::iter(vec![
Ok::<_, BoxError>(Frame::data(Bytes::from_static(b"one"))),
Ok(Frame::data(Bytes::from_static(b"two"))),
]);
let body = RespBody::stream(StreamBody::new(frames));
let mut response = TorkResponse::new(body);
*response.status_mut() = crate::StatusCode::OK;
response.headers_mut().insert(
CONTENT_TYPE,
http::HeaderValue::from_static("text/event-stream"),
);
Ok(response)
})
})
}
/// In-process `Shared` fixture with one default header (`x-default: on`), one stored cookie
/// (`sid=abc`), and no unsafe default headers.
fn shared() -> Shared {
let mut default_headers = HeaderMap::new();
default_headers.insert("x-default", HeaderValue::from_static("on"));
let mut cookies = CookieJar::default();
cookies.set("sid", "abc");
Shared {
transport: Transport::InProcess(Arc::new(App::new().build().unwrap())),
default_headers,
unsafe_default_headers: HeaderMap::new(),
cookies: Mutex::new(cookies),
}
}
/// Query pairs are form-encoded onto the path, and default headers, per-request headers,
/// jar cookies, and the body content type all end up on the request.
#[test]
fn build_request_merges_defaults_headers_cookies_and_content_type() {
let request = shared()
.build_request(
Method::POST,
"/items",
&[("q".to_owned(), "hello world".to_owned())],
vec![TestHeader::safe(
HeaderName::from_static("x-custom"),
HeaderValue::from_static("yes"),
)],
PendingBody {
content_type: Some(HeaderValue::from_static("application/json")),
bytes: Bytes::from_static(b"{}"),
},
)
.unwrap();
assert_eq!(request.uri(), "/items?q=hello+world");
assert_eq!(request.headers()["x-default"], "on");
assert_eq!(request.headers()["x-custom"], "yes");
assert_eq!(request.headers()[COOKIE], "sid=abc");
assert_eq!(request.headers()[CONTENT_TYPE], "application/json");
}
/// An unparsable URI is reported as a bad request naming the URI.
#[test]
fn build_request_rejects_invalid_uri() {
let error = shared()
.build_request(
Method::GET,
"http://[",
&[],
Vec::new(),
PendingBody::default(),
)
.unwrap_err();
assert_eq!(error.kind(), crate::error::ErrorKind::BadRequest);
assert!(error.message().starts_with("invalid request URI:"));
}
/// In-process clients refuse a sensitive header (`host`) that was added without opting in.
#[test]
fn build_request_rejects_sensitive_headers_in_process_without_opt_in() {
let error = shared()
.build_request(
Method::GET,
"/items",
&[],
vec![TestHeader::safe(
HeaderName::from_static("host"),
HeaderValue::from_static("example.com"),
)],
PendingBody::default(),
)
.unwrap_err();
assert_eq!(error.code(), "TEST_UNSAFE_HEADER_REQUIRES_OPT_IN");
assert!(error.message().contains("host"));
}
/// The same sensitive header is accepted when the caller opts in.
#[test]
fn build_request_allows_sensitive_headers_with_opt_in() {
let request = shared()
.build_request(
Method::GET,
"/items",
&[],
vec![TestHeader::unsafe_allowed(
HeaderName::from_static("host"),
HeaderValue::from_static("example.com"),
)],
PendingBody::default(),
)
.unwrap();
assert_eq!(request.headers()["host"], "example.com");
}
/// Serves the app on a random port and exercises both the buffered (`execute`) and the
/// streaming (`execute_streaming`) transport paths end to end. It then shuts the server down.
#[tokio::test]
async fn real_port_transport_exercises_execute_and_execute_streaming() {
let app = App::new().include_router(
Router::new()
.route(Route::new(Method::GET, "/json", json_handler()))
.route(Route::new(Method::GET, "/stream", stream_handler())),
);
let client = TestClient::serve(app).bind_random_port().await.unwrap();
assert!(client.local_addr().is_some());
assert!(client.shared.transport.address().is_some());
let request = client
.shared
.build_request(
Method::GET,
"/json",
&[],
Vec::new(),
PendingBody::default(),
)
.unwrap();
let (status, headers, bytes) = client.shared.transport.execute(request).await.unwrap();
assert_eq!(status, StatusCode::OK);
assert_eq!(headers[CONTENT_TYPE], "application/json");
assert!(bytes.contains(&b'o'));
let request = client
.shared
.build_request(
Method::GET,
"/stream",
&[],
Vec::new(),
PendingBody::default(),
)
.unwrap();
let (status, headers, mut body) = client
.shared
.transport
.execute_streaming(request)
.await
.unwrap();
assert_eq!(status, StatusCode::OK);
assert_eq!(headers[CONTENT_TYPE], "text/event-stream");
// Drain the stream frame by frame; at least one data frame must arrive.
let mut saw_data = false;
while let Some(frame) = body.frame().await {
let frame = frame.unwrap();
if frame.into_data().is_ok() {
saw_data = true;
}
}
assert!(saw_data);
client.shutdown().await.unwrap();
}
}