use std::fmt;
use std::marker::PhantomData;
use std::time::Duration;
use anyspawn::Spawner;
use http::Version;
use http_extensions::{HttpBodyBuilder, HttpRequest, HttpResponse, Result};
use hyper_util::client::legacy;
use layered::{DynamicService, DynamicServiceExt, Service};
use opentelemetry::metrics::Meter;
use tick::Clock;
use crate::HyperIo;
use crate::connection::Connect;
use crate::connection::hyper_handler::build_hyper_handler;
use crate::options::{ConnectionLifetime, RequestFilter};
use crate::tls::TlsBackend;
/// HTTP transport backed by hyper's legacy client.
///
/// Produced by [`HyperTransportBuilder::build`]. Every request passed to
/// [`Service::execute`] is forwarded unchanged to the wrapped service. The
/// transport can also be unwrapped into that service via [`From`].
#[derive(Debug, Clone)]
pub struct HyperTransport {
    /// The service every request is delegated to.
    service: DynamicService<HttpRequest, Result<HttpResponse>>,
}

impl From<HyperTransport> for DynamicService<HttpRequest, Result<HttpResponse>> {
    /// Consumes the transport and returns the service it delegates to.
    fn from(transport: HyperTransport) -> Self {
        let HyperTransport { service } = transport;
        service
    }
}

impl HyperTransport {
    /// Wraps an already-assembled service.
    pub(crate) fn new(service: DynamicService<HttpRequest, Result<HttpResponse>>) -> Self {
        Self { service }
    }
}

impl Service<HttpRequest> for HyperTransport {
    type Out = Result<HttpResponse>;

    /// Forwards `input` to the wrapped service and returns its future as-is.
    fn execute(&self, input: HttpRequest) -> impl Future<Output = Self::Out> + Send {
        let Self { service } = self;
        service.execute(input)
    }
}
/// Connect timeout installed by [`HyperTransportBuilder::new`]; callers can
/// override it with [`HyperTransportBuilder::connect_timeout`].
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Adapts an `anyspawn` [`Spawner`] to hyper's [`hyper::rt::Executor`] trait.
///
/// This lets the hyper client spawn its background futures on the spawner.
#[derive(Clone)]
pub(crate) struct SpawnerExecutor(pub(crate) Spawner);

impl<F> hyper::rt::Executor<F> for SpawnerExecutor
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    /// Spawns `fut` on the wrapped spawner.
    ///
    /// The returned handle is dropped immediately because hyper never awaits the
    /// tasks it hands to its executor. The task is expected to keep running
    /// detached.
    fn execute(&self, fut: F) {
        let Self(spawner) = self;
        let handle = spawner.spawn(fut);
        drop(handle);
    }
}
/// Builder for [`HyperTransport`].
///
/// `C` is the connector (bounded by `Connect<S>`), and `S` is the [`HyperIo`]
/// stream type it works with.
///
/// Create a builder with [`HyperTransportBuilder::new`], adjust it with the
/// setter methods, and finish with [`HyperTransportBuilder::build`].
pub struct HyperTransportBuilder<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    /// Connector used to open the underlying connections.
    pub(crate) connector: C,
    /// Clock the builder was created with; `new` also derives hyper's timers from it.
    pub(crate) clock: Clock,
    /// TLS backend supplied to `new`.
    pub(crate) tls: TlsBackend,
    /// Body builder supplied to `new`.
    pub(crate) body_builder: HttpBodyBuilder,
    /// Request filter; `RequestFilter::default()` unless overridden.
    pub(crate) request_filter: RequestFilter,
    /// HTTP versions the transport may use; `new` enables HTTP/1.1 and HTTP/2.
    pub(crate) supported_http_versions: Vec<Version>,
    /// Connection lifetime policy; `ConnectionLifetime::default()` unless overridden.
    pub(crate) connection_lifetime: ConnectionLifetime,
    /// Connect timeout; 30 seconds unless overridden.
    pub(crate) connect_timeout: Duration,
    /// Pool index forwarded to the handler (defaults to `0`).
    pub(crate) pool_index: usize,
    /// Explicit meter; `build` falls back to the global `fetch_hyper` meter when `None`.
    pub(crate) meter: Option<Meter>,
    /// hyper client builder.
    ///
    /// `new` pre-configures its executor and timers; `configure_hyper` can
    /// adjust it further.
    pub(crate) hyper_builder: legacy::Builder,
    /// Ties the builder to `S` without storing one.
    ///
    /// `fn() -> S` keeps the auto traits independent of `S`.
    pub(crate) _marker: PhantomData<fn() -> S>,
}
impl<C, S> fmt::Debug for HyperTransportBuilder<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    /// Prints the user-tunable settings only.
    ///
    /// The connector, clock, TLS backend, body builder, meter and hyper builder
    /// are omitted, which is why the output ends with `..`
    /// (`finish_non_exhaustive`). The field order is relied on by snapshot tests.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = std::any::type_name::<Self>();
        let mut out = f.debug_struct(type_name);
        out.field("request_filter", &self.request_filter);
        out.field("supported_http_versions", &self.supported_http_versions);
        out.field("connect_timeout", &self.connect_timeout);
        out.field("connection_lifetime", &self.connection_lifetime);
        out.field("pool_index", &self.pool_index);
        out.finish_non_exhaustive()
    }
}
impl<C, S> HyperTransportBuilder<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    /// Creates a builder with default settings.
    ///
    /// Defaults:
    /// - `RequestFilter::default()`
    /// - HTTP/1.1 and HTTP/2 enabled
    /// - `ConnectionLifetime::default()`
    /// - a 30 second connect timeout
    /// - pool index `0`
    /// - no explicit meter
    ///
    /// The hyper client builder spawns its background futures on `spawner`. It
    /// uses a timer derived from `clock` both as its general timer and as its
    /// pool timer.
    #[must_use]
    pub fn new(
        connector: C,
        spawner: Spawner,
        clock: Clock,
        tls: impl Into<TlsBackend>,
        body_builder: HttpBodyBuilder,
    ) -> Self {
        let timer = crate::timer::ClockTimer::new(clock.clone());
        let mut hyper_builder = legacy::Client::builder(SpawnerExecutor(spawner));
        hyper_builder.timer(timer.clone());
        hyper_builder.pool_timer(timer);

        Self {
            connector,
            clock,
            tls: tls.into(),
            body_builder,
            request_filter: RequestFilter::default(),
            supported_http_versions: vec![Version::HTTP_11, Version::HTTP_2],
            connection_lifetime: ConnectionLifetime::default(),
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
            pool_index: 0,
            meter: None,
            hyper_builder,
            _marker: PhantomData,
        }
    }

    /// Replaces the request filter.
    #[must_use]
    pub fn request_filter(self, filter: RequestFilter) -> Self {
        Self { request_filter: filter, ..self }
    }

    /// Replaces the set of HTTP versions the transport may use.
    ///
    /// The slice is copied.
    #[must_use]
    pub fn supported_http_versions(self, versions: &[Version]) -> Self {
        Self { supported_http_versions: versions.to_vec(), ..self }
    }

    /// Replaces the connect timeout (30 seconds by default).
    #[must_use]
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self { connect_timeout: timeout, ..self }
    }

    /// Replaces the connection lifetime policy.
    #[must_use]
    pub fn connection_lifetime(self, lifetime: ConnectionLifetime) -> Self {
        Self { connection_lifetime: lifetime, ..self }
    }

    /// Replaces the pool index (`0` by default).
    #[must_use]
    pub fn pool_index(self, pool_index: usize) -> Self {
        Self { pool_index, ..self }
    }

    /// Uses `meter` instead of the global `fetch_hyper` meter when building.
    #[must_use]
    pub fn meter(self, meter: Meter) -> Self {
        Self { meter: Some(meter), ..self }
    }

    /// Calls `configure` right away with mutable access to the hyper client builder.
    #[must_use]
    pub fn configure_hyper<F>(mut self, configure: F) -> Self
    where
        F: FnOnce(&mut legacy::Builder),
    {
        let Self { hyper_builder, .. } = &mut self;
        configure(hyper_builder);
        self
    }

    /// Consumes the builder and assembles the transport.
    ///
    /// Metrics go to the meter set via [`Self::meter`]. Without one, they go to
    /// `opentelemetry::global::meter("fetch_hyper")`.
    #[must_use]
    pub fn build(self) -> HyperTransport {
        let meter = match &self.meter {
            Some(explicit) => explicit.clone(),
            None => opentelemetry::global::meter("fetch_hyper"),
        };
        let handler = build_hyper_handler(self, &meter);
        HyperTransport::new(handler.into_dynamic())
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use bytes::Bytes;
    use opentelemetry::metrics::MeterProvider;
    use opentelemetry_sdk::metrics::SdkMeterProvider;

    use super::*;
    use crate::testing::{FakeConnector, FakeStream};

    /// Canned response the fake connector serves in end-to-end tests.
    const OK_RESPONSE: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";

    fn tls() -> TlsBackend {
        native_tls::TlsConnector::new().unwrap().into()
    }

    /// Builder whose fake connector replies with `response`.
    ///
    /// The connector and the transport share one auto-advancing clock, so both
    /// observe the same time source.
    fn builder_with_response(response: Bytes) -> HyperTransportBuilder<FakeConnector, FakeStream> {
        let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
        HyperTransportBuilder::new(
            FakeConnector::new_success(response, clock.clone()),
            Spawner::new_tokio(),
            clock,
            tls(),
            HttpBodyBuilder::new_fake(),
        )
    }

    /// Builder with an empty canned response, for tests that never send a request.
    fn make_builder() -> HyperTransportBuilder<FakeConnector, FakeStream> {
        builder_with_response(Bytes::new())
    }

    /// Transport (HTTP and HTTPS allowed) that answers with [`OK_RESPONSE`].
    fn ok_transport() -> HyperTransport {
        builder_with_response(Bytes::from_static(OK_RESPONSE))
            .request_filter(RequestFilter::HttpAndHttps)
            .build()
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn builder_defaults_and_setters() {
        let defaults = make_builder();
        assert!(defaults.meter.is_none(), "meter is not part of Debug output");
        insta::assert_debug_snapshot!("defaults", defaults);

        let configured = make_builder()
            .request_filter(RequestFilter::HttpAndHttps)
            .supported_http_versions(&[Version::HTTP_2])
            .connect_timeout(Duration::from_secs(7))
            .connection_lifetime(ConnectionLifetime::Fixed(Duration::from_secs(60)))
            .pool_index(42);
        insta::assert_debug_snapshot!("configured", configured);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn meter_setter_stores_meter() {
        let provider = SdkMeterProvider::builder().build();
        let builder = make_builder().meter(provider.meter("test"));
        assert!(builder.meter.is_some());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn configure_hyper_runs_callback_synchronously() {
        let mut called = false;
        let _builder = make_builder().configure_hyper(|_| {
            called = true;
        });
        assert!(called);
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn build_with_explicit_meter_yields_working_transport() {
        let provider = SdkMeterProvider::builder().build();
        let transport = builder_with_response(Bytes::from_static(OK_RESPONSE))
            .request_filter(RequestFilter::HttpAndHttps)
            .meter(provider.meter("test"))
            .build();
        let resp = transport.execute(crate::testing::create_test_request()).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    /// hyper's client builder does not expose its configured flags for
    /// inspection. This test therefore only checks that an HTTP/2-only
    /// configuration builds without panicking.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn build_with_h2_only_succeeds() {
        let _transport = make_builder().supported_http_versions(&[Version::HTTP_2]).build();
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn hyper_transport_clones_share_underlying_service() {
        let transport = ok_transport();
        let cloned = transport.clone();
        // Exercise the derived Debug impl on the clone.
        let _ = format!("{cloned:?}");
        let resp = cloned.execute(crate::testing::create_test_request()).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn hyper_transport_into_dynamic_service_executes_request() {
        let service: DynamicService<HttpRequest, Result<HttpResponse>> = ok_transport().into();
        let resp = service.execute(crate::testing::create_test_request()).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn spawner_executor_runs_future() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let executor = SpawnerExecutor(Spawner::new_tokio());
        let fired = Arc::new(AtomicBool::new(false));
        let fired_clone = Arc::clone(&fired);
        hyper::rt::Executor::execute(&executor, async move {
            fired_clone.store(true, Ordering::SeqCst);
        });
        // Yield so the runtime gets a chance to poll the spawned task.
        for _ in 0..50 {
            if fired.load(Ordering::SeqCst) {
                break;
            }
            tokio::task::yield_now().await;
        }
        assert!(fired.load(Ordering::SeqCst));
    }
}