use crate::handler::NewHandler;
use crate::helpers::http::Body;
use bytes::Bytes;
use http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use http::{request, Method, Request, Response, Uri, Version};
use http_body::Body as HttpBody;
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::BodyExt as _;
use hyper::body::Incoming;
use hyper_util::client::legacy::connect::Connect;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use mime::Mime;
use std::any::Any;
use std::convert::TryInto;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
/// Shared state behind a running asynchronous test server.
///
/// Records the address the listener is bound to, the per-request timeout handed to every
/// client created from this server, and the handle of the spawned server task.
pub(crate) struct AsyncTestServerInner {
    /// Local address reported by the listener after binding.
    addr: SocketAddr,
    /// Timeout passed on to each [`AsyncTestClient`] created by [`Self::client`].
    timeout: Duration,
    /// Handle of the spawned task running `crate::bind_server`; aborted on drop.
    handle: tokio::task::JoinHandle<()>,
}

impl AsyncTestServerInner {
    /// Binds a listener on `127.0.0.1` with port `0` and spawns the server on it.
    ///
    /// `new_handler` and `wrap` are forwarded unchanged to `crate::bind_server`, which runs
    /// on a spawned tokio task.
    ///
    /// # Errors
    ///
    /// Fails if the loopback address cannot be parsed, the listener cannot be bound, or the
    /// listener's local address cannot be read.
    pub(crate) async fn new<NH, F, Wrapped, Wrap>(
        new_handler: NH,
        timeout: Duration,
        wrap: Wrap,
    ) -> anyhow::Result<Self>
    where
        NH: NewHandler + 'static,
        F: Future<Output = Result<Wrapped, ()>> + Unpin + Send + 'static,
        Wrapped: Unpin + AsyncRead + AsyncWrite + Send + 'static,
        Wrap: Fn(TcpStream) -> F + Send + 'static,
    {
        // Port 0 asks the OS for any free port; the real one is read back below.
        let bind_addr: SocketAddr = "127.0.0.1:0".parse()?;
        let listener = TcpListener::bind(bind_addr).await?;
        let addr = listener.local_addr()?;

        let server = async move {
            crate::bind_server(listener, new_handler, wrap).await;
        };
        let handle = tokio::spawn(server);

        Ok(Self {
            addr,
            timeout,
            handle,
        })
    }

    /// Creates a client whose connector is built from this server's address.
    ///
    /// The returned client keeps a clone of this `Arc`, so the server state lives at least
    /// as long as the client.
    pub(crate) fn client<TestC>(self: &Arc<Self>) -> AsyncTestClient<TestC>
    where
        TestC: From<SocketAddr> + Connect + Clone + Send + Sync + 'static,
    {
        let connector = TestC::from(self.addr);
        let hyper_client = Client::builder(TokioExecutor::new()).build(connector);
        AsyncTestClient::new(hyper_client, self.timeout, Arc::clone(self))
    }
}

impl Drop for AsyncTestServerInner {
    /// Aborts the spawned server task when the server state is dropped.
    fn drop(&mut self) {
        let server_task = &self.handle;
        server_task.abort();
    }
}
/// An HTTP client bound to one asynchronous test server.
///
/// Every request is limited by the server's configured timeout. The client holds an `Arc`
/// to the server state so the server is not dropped while the client is alive.
pub struct AsyncTestClient<C: Connect> {
    client: Client<C, Body>,
    timeout: Duration,
    _test_server: Arc<AsyncTestServerInner>,
}

impl<C: Connect + Clone + Send + Sync + 'static> AsyncTestClient<C> {
    /// Bundles a `hyper_util` client with a request timeout and the server it targets.
    pub(crate) fn new(
        client: Client<C, Body>,
        timeout: Duration,
        test_server: Arc<AsyncTestServerInner>,
    ) -> Self {
        Self {
            _test_server: test_server,
            timeout,
            client,
        }
    }

    /// Sends an already-built request and waits for the response.
    ///
    /// The response body is not read here; use the methods on [`AsyncTestResponse`].
    ///
    /// # Errors
    ///
    /// Fails with `tokio::time::error::Elapsed` if no response arrives within the configured
    /// timeout, or with the client's error if the request itself fails.
    pub async fn request(&self, request: Request<Body>) -> anyhow::Result<AsyncTestResponse> {
        let pending = self.client.request(request);
        let response = timeout(self.timeout, pending).await??;
        Ok(AsyncTestResponse::from(response))
    }

    /// Starts a `HEAD` request to `uri`.
    pub fn head(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::HEAD, uri)
    }

    /// Starts a `GET` request to `uri`.
    pub fn get(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::GET, uri)
    }

    /// Starts an `OPTIONS` request to `uri`.
    pub fn options(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::OPTIONS, uri)
    }

    /// Starts a `POST` request to `uri`.
    pub fn post(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::POST, uri)
    }

    /// Starts a `PUT` request to `uri`.
    pub fn put(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::PUT, uri)
    }

    /// Starts a `PATCH` request to `uri`.
    pub fn patch(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::PATCH, uri)
    }

    /// Starts a `DELETE` request to `uri`.
    pub fn delete(
        &self,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        self.request_builder_with_method_and_uri(Method::DELETE, uri)
    }

    /// Starts a request from a fresh `request::Builder`, with no body attached yet.
    ///
    /// The builder's own defaults apply until the method, URI, etc. are overridden.
    pub fn build_request(&self) -> AsyncTestRequestBuilder<'_, C> {
        AsyncTestRequestBuilder {
            body: None,
            request_builder: request::Builder::new(),
            test_client: self,
        }
    }

    /// Shared implementation of the per-method helpers above.
    fn request_builder_with_method_and_uri(
        &self,
        method: Method,
        uri: impl TryInto<Uri, Error: Into<http::Error>>,
    ) -> AsyncTestRequestBuilder<'_, C> {
        // The URI is applied before the method, as when chaining on `request::Builder` directly.
        self.build_request().uri(uri).method(method)
    }
}

impl<C: Connect> From<AsyncTestClient<C>> for Client<C, Body> {
    /// Unwraps the underlying `hyper_util` client; the wrapper's `Arc` to the test server
    /// is dropped.
    fn from(test_client: AsyncTestClient<C>) -> Self {
        let AsyncTestClient { client, .. } = test_client;
        client
    }
}
/// Builder for a single request sent through an [`AsyncTestClient`].
///
/// Wraps an `http::request::Builder` (also reachable through `Deref`/`DerefMut`) plus an
/// optional body. Call [`perform`](Self::perform) to send the request.
pub struct AsyncTestRequestBuilder<'client, C: Connect> {
    test_client: &'client AsyncTestClient<C>,
    request_builder: request::Builder,
    body: Option<Body>,
}

impl<'client, C: Connect + Clone + Send + Sync + 'static> AsyncTestRequestBuilder<'client, C> {
    /// Finishes the request and sends it through the owning client.
    ///
    /// If no body was set, `Body::default()` is sent.
    ///
    /// # Errors
    ///
    /// Returns any error recorded by the underlying `request::Builder` (for example an
    /// invalid URI or header value), or any error from [`AsyncTestClient::request`],
    /// including a timeout.
    pub async fn perform(self) -> anyhow::Result<AsyncTestResponse> {
        let Self {
            test_client,
            request_builder,
            body,
        } = self;
        let request = request_builder.body(body.unwrap_or_default())?;
        test_client.request(request).await
    }

    /// Sets the `Content-Type` header to `mime`.
    ///
    /// A MIME string that is not a valid header value is not a panic here. The failure is
    /// stored in the builder and reported as an error by [`perform`](Self::perform).
    pub fn mime(self, mime: Mime) -> Self {
        // Hand the borrowed string to `header` instead of `to_string()` + `parse().unwrap()`.
        // This avoids an allocation and routes conversion errors through the builder.
        let value: &str = mime.as_ref();
        self.header(CONTENT_TYPE, value)
    }

    /// Sets the request body, replacing any body set earlier.
    pub fn body(
        mut self,
        body: impl HttpBody<Data = Bytes, Error = io::Error> + Send + 'static,
    ) -> Self {
        self.body = Some(UnsyncBoxBody::new(body));
        self
    }

    /// Attaches a typed extension to the request.
    pub fn extension(self, extension: impl Clone + Any + Send + Sync + 'static) -> Self {
        self.replace_request_builder(|builder| builder.extension(extension))
    }

    /// Appends a header; conversion errors are reported by [`perform`](Self::perform).
    pub fn header(
        self,
        key: impl TryInto<HeaderName, Error: Into<http::Error>>,
        value: impl TryInto<HeaderValue, Error: Into<http::Error>>,
    ) -> Self {
        self.replace_request_builder(|builder| builder.header(key, value))
    }

    /// Sets the request method; conversion errors are reported by [`perform`](Self::perform).
    pub fn method(self, method: impl TryInto<Method, Error: Into<http::Error>>) -> Self {
        self.replace_request_builder(|builder| builder.method(method))
    }

    /// Sets the request URI; conversion errors are reported by [`perform`](Self::perform).
    pub fn uri(self, uri: impl TryInto<Uri, Error: Into<http::Error>>) -> Self {
        self.replace_request_builder(|builder| builder.uri(uri))
    }

    /// Sets the HTTP version of the request.
    pub fn version(self, version: Version) -> Self {
        self.replace_request_builder(|builder| builder.version(version))
    }

    /// Threads the wrapped `request::Builder` (whose methods consume `self`) through `replacer`.
    fn replace_request_builder(
        mut self,
        replacer: impl FnOnce(request::Builder) -> request::Builder,
    ) -> Self {
        self.request_builder = replacer(self.request_builder);
        self
    }
}

impl<'client, C: Connect> Deref for AsyncTestRequestBuilder<'client, C> {
    type Target = request::Builder;

    /// Read-only access to the wrapped `request::Builder`.
    fn deref(&self) -> &Self::Target {
        &self.request_builder
    }
}

impl<'client, C: Connect> DerefMut for AsyncTestRequestBuilder<'client, C> {
    /// Mutable access to the wrapped `request::Builder`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.request_builder
    }
}
/// A response received from a test server.
///
/// Dereferences to the underlying `Response<Incoming>` for access to status and headers.
/// The body can be collected with [`read_body`](Self::read_body) or
/// [`read_utf8_body`](Self::read_utf8_body).
pub struct AsyncTestResponse {
    response: Response<Incoming>,
}

impl AsyncTestResponse {
    /// Collects the whole response body into a byte vector, consuming the response.
    ///
    /// # Errors
    ///
    /// Fails if reading the body stream fails.
    pub async fn read_body(self) -> anyhow::Result<Vec<u8>> {
        let bytes = self.response.into_body().collect().await?.to_bytes();
        Ok(bytes.to_vec())
    }

    /// Collects the whole response body and decodes it as UTF-8.
    ///
    /// # Errors
    ///
    /// Fails if reading the body fails or the bytes are not valid UTF-8.
    pub async fn read_utf8_body(self) -> anyhow::Result<String> {
        let bytes = self.read_body().await?;
        Ok(String::from_utf8(bytes)?)
    }
}

impl From<Response<Incoming>> for AsyncTestResponse {
    fn from(response: Response<Incoming>) -> Self {
        Self { response }
    }
}

impl From<AsyncTestResponse> for Response<Incoming> {
    fn from(test_response: AsyncTestResponse) -> Self {
        test_response.response
    }
}

impl Deref for AsyncTestResponse {
    type Target = Response<Incoming>;

    fn deref(&self) -> &Self::Target {
        &self.response
    }
}

impl DerefMut for AsyncTestResponse {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.response
    }
}

impl Debug for AsyncTestResponse {
    /// Shows the response head (status, version, headers). The body stream is omitted
    /// because printing it would require consuming it.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("AsyncTestResponse")
            .field("status", &self.response.status())
            .field("version", &self.response.version())
            .field("headers", self.response.headers())
            .finish_non_exhaustive()
    }
}
/// Test scenarios shared by the async test-server variants.
///
/// Each function takes a `server_factory` that builds a server around a `TestHandler`, and a
/// `client_factory` that produces an [`AsyncTestClient`] for that server. Concrete test
/// modules call these with their own factories.
#[cfg(test)]
pub(crate) mod common_tests {
    use super::*;
    use crate::handler::IntoBody;
    use crate::test::helper::TestHandler;
    use http::StatusCode;

    /// A `GET /` is answered with status 200 and the body the handler was built from.
    pub(crate) async fn serves_requests<TS, F, C>(
        server_factory: fn(TestHandler) -> F,
        client_factory: fn(&TS) -> AsyncTestClient<C>,
    ) where
        F: Future<Output = anyhow::Result<TS>>,
        C: Connect + Clone + Send + Sync + 'static,
    {
        let test_server = server_factory(TestHandler::from("response")).await.unwrap();
        let response = client_factory(&test_server)
            .get("http://localhost/")
            .perform()
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.read_utf8_body().await.unwrap(), "response");
    }

    /// A request to `/timeout` fails with `tokio::time::error::Elapsed` once the server's
    /// timeout has passed.
    ///
    /// NOTE(review): presumably `TestHandler` never answers `/timeout`; confirm in the helper.
    pub(crate) async fn times_out<TS, F, C>(
        server_factory: fn(TestHandler, Duration) -> F,
        client_factory: fn(&TS) -> AsyncTestClient<C>,
    ) where
        F: Future<Output = anyhow::Result<TS>>,
        C: Connect + Clone + Send + Sync + 'static,
    {
        let timeout = Duration::from_secs(10);
        let test_server = server_factory(TestHandler::default(), timeout)
            .await
            .unwrap();
        let client = client_factory(&test_server);
        // Freeze tokio's clock so the sleep below advances virtual time instead of waiting
        // 10 real seconds. tokio requires a current-thread runtime for `pause()`.
        tokio::time::pause();
        let request_handle =
            tokio::spawn(async move { client.get("http://localhost/timeout").perform().await });
        // Advance virtual time by the full timeout so the request's deadline is reached.
        tokio::time::sleep(timeout).await;
        tokio::time::resume();
        let request_result = request_handle.await.unwrap();
        assert!(request_result
            .unwrap_err()
            .is::<tokio::time::error::Elapsed>());
    }

    /// A body POSTed to `/echo` (including non-ASCII text) comes back unchanged.
    pub(crate) async fn echo<TS, F, C>(
        server_factory: fn(TestHandler) -> F,
        client_factory: fn(&TS) -> AsyncTestClient<C>,
    ) where
        F: Future<Output = anyhow::Result<TS>>,
        C: Connect + Clone + Send + Sync + 'static,
    {
        let server = server_factory(TestHandler::default()).await.unwrap();
        let data = "This text should get reflected back to us. Even this fancy piece of unicode: \
                    \u{3044}\u{308d}\u{306f}\u{306b}\u{307b}";
        let response = client_factory(&server)
            .post("http://localhost/echo")
            .body(data.into_body())
            .perform()
            .await
            .unwrap();
        let response_text = response.read_utf8_body().await.unwrap();
        assert_eq!(response_text, data);
    }

    /// Two servers created from the same factory run side by side, and each client reaches
    /// only its own server.
    pub(crate) async fn supports_multiple_servers<TS, F, C>(
        server_factory: fn(TestHandler) -> F,
        client_factory: fn(&TS) -> AsyncTestClient<C>,
    ) where
        F: Future<Output = anyhow::Result<TS>>,
        C: Connect + Clone + Send + Sync + 'static,
    {
        let server_a = server_factory(TestHandler::from("A")).await.unwrap();
        let server_b = server_factory(TestHandler::from("B")).await.unwrap();
        let client_a = client_factory(&server_a);
        let client_b = client_factory(&server_b);
        let response_a = client_a
            .get("http://localhost/")
            .perform()
            .await
            .unwrap()
            .read_utf8_body()
            .await
            .unwrap();
        let response_b = client_b
            .get("http://localhost/")
            .perform()
            .await
            .unwrap()
            .read_utf8_body()
            .await
            .unwrap();
        assert_eq!(response_a, "A");
        assert_eq!(response_b, "B");
    }

    /// The body returned for `/myaddr` starts with `127.0.0.1`. It is presumably the client
    /// address recorded by the server (TODO confirm in `TestHandler`).
    pub(crate) async fn adds_client_address_to_state<TS, F, C>(
        server_factory: fn(TestHandler) -> F,
        client_factory: fn(&TS) -> AsyncTestClient<C>,
    ) where
        F: Future<Output = anyhow::Result<TS>>,
        C: Connect + Clone + Send + Sync + 'static,
    {
        let server = server_factory(TestHandler::default()).await.unwrap();
        let client = client_factory(&server);
        let client_address = client
            .get("http://localhost/myaddr")
            .perform()
            .await
            .unwrap()
            .read_utf8_body()
            .await
            .unwrap();
        assert!(client_address.starts_with("127.0.0.1"));
    }
}