#![allow(clippy::unwrap_used, reason = "test code")]
use std::time::Duration;
use anyspawn::Spawner;
use bytes::Bytes;
use fetch_hyper::{HyperTransportBuilder, RequestFilter};
use fetch_tls::TlsBackend;
use http::{Method, StatusCode, Version};
use http_extensions::{HttpBodyBuilder, HttpRequestBuilder, Result};
use hyper_util::rt::TokioIo;
use layered::Service as _;
use ohno::ErrorExt;
use templated_uri::BaseUri;
use tick::{Clock, ClockControl};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Test connector that opens a plain TCP connection to the URI's host and effective port
/// using tokio, and wraps the stream in `TokioIo` for hyper.
#[derive(Clone)]
struct TokioConnector;

impl layered::Service<BaseUri> for TokioConnector {
    type Out = Result<TokioIo<tokio::net::TcpStream>>;

    /// Connects to `input`'s host/port.
    ///
    /// Panics if the URI has no effective port (acceptable in test code). Connection
    /// failures are propagated as errors via `?`.
    async fn execute(&self, input: BaseUri) -> Self::Out {
        let authority = input.authority();
        let host = authority.host();
        // An IPv6 literal host may be reported with its brackets (e.g. "[::1]"). Socket
        // address resolution does not accept the bracketed form, so strip one surrounding
        // pair. Names and IPv4 addresses pass through unchanged.
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let port = input.effective_port().unwrap();
        let stream = tokio::net::TcpStream::connect((host, port)).await?;
        Ok(TokioIo::new(stream))
    }
}
/// Creates the TLS backend used by every test from a default `native_tls` connector.
///
/// Panics if the platform TLS connector cannot be constructed.
fn build_tls() -> TlsBackend {
    let connector = native_tls::TlsConnector::new().unwrap();
    connector.into()
}
/// Returns a `Clock` obtained from a fresh `ClockControl` with `auto_advance_timers(true)`.
fn test_clock() -> Clock {
    ClockControl::new()
        .auto_advance_timers(true)
        .to_clock()
}
/// Starts a wiremock server that answers `GET /hello-world` with status 200 and `body`.
///
/// The returned server must be kept alive for as long as requests are sent to it.
async fn serve(body: impl Into<Bytes>) -> MockServer {
    let payload: Bytes = body.into();
    let hello_world = ResponseTemplate::new(200).set_body_bytes(payload.to_vec());

    let server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/hello-world"))
        .respond_with(hello_world)
        .mount(&server)
        .await;
    server
}
/// Sends a GET over plain HTTP to a local mock server and expects 200 OK.
///
/// Plain HTTP is allowed explicitly through `RequestFilter::HttpAndHttps`.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn real_http_request_succeeds() {
    let transport = HyperTransportBuilder::new(
        TokioConnector,
        Spawner::new_tokio(),
        test_clock(),
        build_tls(),
        HttpBodyBuilder::new_fake(),
    )
    .connect_timeout(Duration::from_secs(5))
    .request_filter(RequestFilter::HttpAndHttps)
    .build();

    let server = serve(Bytes::from_static(b"Hello World!")).await;
    let url = format!("{}/hello-world", server.uri());

    let bodies = HttpBodyBuilder::new_fake();
    let req = HttpRequestBuilder::new(&bodies)
        .method(Method::GET)
        .uri(url)
        .build()
        .unwrap();

    let resp = transport.execute(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
}
/// Checks that a transport built without an explicit `request_filter` refuses a plain-HTTP URI.
///
/// On failure, the assertion reports both the expected fragment and the actual message.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn https_only_filter_rejects_http_request() {
    // No `request_filter(...)` call: this test relies on the builder's default filter.
    let handler = HyperTransportBuilder::new(
        TokioConnector,
        Spawner::new_tokio(),
        test_clock(),
        build_tls(),
        HttpBodyBuilder::new_fake(),
    )
    .connect_timeout(Duration::from_secs(5))
    .build();
    let server = serve(Bytes::from_static(b"Hello World!")).await;
    let body_builder = HttpBodyBuilder::new_fake();
    let request = HttpRequestBuilder::new(&body_builder)
        .method(Method::GET)
        .uri(server.uri() + "/hello-world")
        .build()
        .unwrap();
    let error = handler.execute(request).await.unwrap_err();
    let message = error.message();
    let expected = "https required but URI was not https";
    assert!(
        message.contains(expected),
        "expected error message to contain {expected:?}, got: {message}"
    );
}
/// Restricts the transport to HTTP/2 and HTTP/3, then expects a request to the mock server to
/// fail with an error naming HTTP/1.1 as the unsupported negotiated version.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn http2_only_rejected_when_server_negotiates_http1() {
    let transport = HyperTransportBuilder::new(
        TokioConnector,
        Spawner::new_tokio(),
        test_clock(),
        build_tls(),
        HttpBodyBuilder::new_fake(),
    )
    .connect_timeout(Duration::from_secs(5))
    .request_filter(RequestFilter::HttpAndHttps)
    .supported_http_versions(&[Version::HTTP_2, Version::HTTP_3])
    .build();

    let server = serve(Bytes::from_static(b"Hello World!")).await;
    let url = format!("{}/hello-world", server.uri());

    let bodies = HttpBodyBuilder::new_fake();
    let req = HttpRequestBuilder::new(&bodies)
        .method(Method::GET)
        .uri(url)
        .build()
        .unwrap();

    let failure = transport.execute(req).await.unwrap_err();
    let text = failure.message();
    // Checked with `contains`, so this is a fragment that may appear anywhere in the message.
    let expected_fragment =
        "the connection was established with unsupported HTTP version: HTTP/1.1, supported versions are: [HTTP/2.0, HTTP/3.0]";
    assert!(
        text.contains(expected_fragment),
        "expected error message to contain {expected_fragment:?}, got: {text}"
    );
}
/// With HTTP/2 as the only supported version, a plain-HTTP request to the mock server is
/// expected to succeed and come back as an HTTP/2 response.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn http2_only_with_single_supported_version_uses_prior_knowledge() {
    let transport = HyperTransportBuilder::new(
        TokioConnector,
        Spawner::new_tokio(),
        test_clock(),
        build_tls(),
        HttpBodyBuilder::new_fake(),
    )
    .connect_timeout(Duration::from_secs(5))
    .request_filter(RequestFilter::HttpAndHttps)
    .supported_http_versions(&[Version::HTTP_2])
    .build();

    let server = serve(Bytes::from_static(b"Hello World!")).await;
    let url = format!("{}/hello-world", server.uri());

    let bodies = HttpBodyBuilder::new_fake();
    let req = HttpRequestBuilder::new(&bodies)
        .method(Method::GET)
        .uri(url)
        .build()
        .unwrap();

    let resp = transport.execute(req).await.unwrap();
    let (status, version) = (resp.status(), resp.version());
    assert_eq!(status, StatusCode::OK);
    assert_eq!(version, Version::HTTP_2);
}
/// With HTTP/1.1 as the only supported version, the request is expected to succeed over
/// HTTP/1.1 rather than switching to an HTTP/2-only mode.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn single_http1_version_does_not_enable_http2_only() {
    let transport = HyperTransportBuilder::new(
        TokioConnector,
        Spawner::new_tokio(),
        test_clock(),
        build_tls(),
        HttpBodyBuilder::new_fake(),
    )
    .connect_timeout(Duration::from_secs(5))
    .request_filter(RequestFilter::HttpAndHttps)
    .supported_http_versions(&[Version::HTTP_11])
    .build();

    let server = serve(Bytes::from_static(b"Hello World!")).await;
    let url = format!("{}/hello-world", server.uri());

    let bodies = HttpBodyBuilder::new_fake();
    let req = HttpRequestBuilder::new(&bodies)
        .method(Method::GET)
        .uri(url)
        .build()
        .unwrap();

    let resp = transport.execute(req).await.unwrap();
    let (status, version) = (resp.status(), resp.version());
    assert_eq!(status, StatusCode::OK);
    assert_eq!(version, Version::HTTP_11);
}
/// With a fixed connection lifetime of zero, the connection that served a request is expected
/// to be reported as poisoned in the response's `ConnectionInfo` extension.
///
/// This test uses `Clock::new_tokio()` rather than the auto-advancing `test_clock()`.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn zero_lifetime_poisons_connection_after_request() {
    use fetch_hyper::ConnectionInfo;

    let transport = HyperTransportBuilder::new(
        TokioConnector,
        Spawner::new_tokio(),
        Clock::new_tokio(),
        build_tls(),
        HttpBodyBuilder::new_fake(),
    )
    .connect_timeout(Duration::from_secs(5))
    .request_filter(RequestFilter::HttpAndHttps)
    .connection_lifetime(fetch_hyper::ConnectionLifetime::Fixed(Duration::ZERO))
    .build();

    let server = serve(Bytes::from_static(b"Hello World!")).await;
    let url = format!("{}/hello-world", server.uri());

    let bodies = HttpBodyBuilder::new_fake();
    let req = HttpRequestBuilder::new(&bodies)
        .method(Method::GET)
        .uri(url)
        .build()
        .unwrap();

    let resp = transport.execute(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);

    let conn = resp.extensions().get::<ConnectionInfo>().unwrap();
    let poisoned = conn.poisoned();
    assert!(poisoned, "connection should have been poisoned by zero lifetime");
}