use std::error::Error;
use std::fmt::{self, Display};
use bytesbuf::BytesView;
use futures::TryFutureExt;
use http::{Extensions, Version};
use http_body_util::BodyExt;
use http_extensions::timeout::BodyTimeout;
use http_extensions::{HttpBody, HttpBodyOptions, HttpError, HttpRequest, HttpResponse, Result};
use hyper_util::client::legacy::connect::{CaptureConnection, capture_connection};
use hyper_util::client::legacy::{self, Client};
use layered::Service;
use opentelemetry::metrics::Meter;
use crate::builder::HyperTransportBuilder;
use crate::connection::client_connector::ClientConnector;
use crate::connection::connect::Connect;
use crate::connection::hyper_connector_adapter::HyperConnectorAdapter;
use crate::connection::io::HyperIo;
use crate::connection::tracked_stream::TrackedStream;
use crate::error_labels::LABEL_REQUEST_HYPER;
use crate::recoverability::detect_recoverability;
use crate::telemetry::ConnectionInfo;
use crate::tls::TlsConnector;
/// Complete connector stack handed to hyper's legacy [`Client`].
///
/// Innermost first: the user connector `C` wrapped in a [`TlsConnector`],
/// that wrapped in a [`ClientConnector`] working on `Box<dyn HyperIo>`
/// streams, and finally a [`HyperConnectorAdapter`].
/// The adapter is parameterized with `TrackedStream<Box<dyn HyperIo>>`,
/// which is presumably the stream type hyper sees (confirm in `hyper_connector_adapter`).
type WrappedConnector<C, S> = HyperConnectorAdapter<
    ClientConnector<TlsConnector<C, S>, Box<dyn HyperIo>>,
    TrackedStream<Box<dyn HyperIo>>,
>;
/// Request handler backed by hyper-util's legacy pooled [`Client`].
///
/// Constructed by [`build_hyper_handler`]. Each request is sent through
/// `client`, and the hyper response body is re-wrapped with `body_builder`.
pub(crate) struct HyperHandler<C: Connect<S>, S: HyperIo> {
    /// Pooled hyper client driving the full [`WrappedConnector`] stack.
    client: Client<WrappedConnector<C, S>, HttpBody>,
    /// Builder used to turn hyper response bodies into the response body type.
    body_builder: http_extensions::HttpBodyBuilder,
}
impl<C: Connect<S>, S: HyperIo> fmt::Debug for HyperHandler<C, S> {
    /// Renders only the handler's type name.
    ///
    /// The client and the body builder are not printed; `finish_non_exhaustive`
    /// marks the output as incomplete.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = std::any::type_name::<Self>();
        f.debug_struct(type_name).finish_non_exhaustive()
    }
}
impl<C, S> Service<HttpRequest> for HyperHandler<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    type Out = Result<HttpResponse>;

    /// Sends `input` through the pooled hyper client and adapts the response.
    ///
    /// If the request carries a [`BodyTimeout`] extension, its duration becomes
    /// the timeout in the options used to build the response body.
    /// When a response arrives, `handle_poisoning` inspects its extensions and
    /// may poison the connection that served it.
    /// Client errors are mapped by `create_http_error_from_hyper_util`.
    /// Body errors are mapped by `create_http_error_from_hyper`.
    fn execute(&self, mut input: HttpRequest) -> impl Future<Output = Result<HttpResponse>> + Send {
        // Registered on the request before hyper sees it, so the serving
        // connection's metadata can be inspected once the response is in.
        let captured = capture_connection::<HttpBody>(&mut input);

        let body_options = match input.extensions().get::<BodyTimeout>() {
            Some(timeout) => HttpBodyOptions::default().timeout(timeout.duration()),
            None => Default::default(),
        };

        let body_builder = &self.body_builder;
        let pending = self.client.request(input);

        pending
            .map_err(create_http_error_from_hyper_util)
            .map_ok(move |response| {
                let (parts, incoming) = response.into_parts();
                handle_poisoning(&captured, &parts.extensions);
                let adapted = incoming
                    .map_frame(|frame| frame.map_data(BytesView::from))
                    .map_err(create_http_error_from_hyper);
                let body = body_builder.body(adapted, &body_options);
                HttpResponse::from_parts(parts, body)
            })
    }
}
/// Assembles a [`HyperHandler`] from a configured [`HyperTransportBuilder`].
///
/// The connector stack is built innermost first:
/// - The user connector is wrapped in a [`TlsConnector`], together with the TLS
///   backend, request filter and supported versions.
/// - That is wrapped in a [`ClientConnector`], which receives the clock, connect
///   timeout, supported versions, `meter`, pool index and connection lifetime.
/// - That is wrapped in a [`HyperConnectorAdapter`], which the builder's hyper
///   client is built on.
///
/// If the supported versions are exactly `[HTTP/2]`, the hyper client is put
/// into HTTP/2-only mode.
pub(crate) fn build_hyper_handler<C, S>(builder: HyperTransportBuilder<C, S>, meter: &Meter) -> HyperHandler<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    let HyperTransportBuilder {
        connector: user_connector,
        clock,
        tls: tls_backend,
        body_builder,
        request_filter: filter,
        supported_http_versions: versions,
        connection_lifetime: lifetime,
        connect_timeout,
        pool_index,
        mut hyper_builder,
        ..
    } = builder;

    let only_http2 = versions.len() == 1 && versions[0] == Version::HTTP_2;
    if only_http2 {
        hyper_builder.http2_only(true);
    }

    let tls_layer = TlsConnector::new(tls_backend, user_connector, filter, &versions);
    let client_connector =
        ClientConnector::new(tls_layer, clock, connect_timeout, versions, meter, pool_index, lifetime);
    let client = hyper_builder.build(HyperConnectorAdapter::new(client_connector));

    HyperHandler { client, body_builder }
}
fn create_http_error_from_hyper_util(error: legacy::Error) -> HttpError {
let recovery = detect_recoverability(&error);
HttpError::other(HyperError::Legacy(error), recovery, LABEL_REQUEST_HYPER)
}
fn create_http_error_from_hyper(error: hyper::Error) -> HttpError {
let recovery = detect_recoverability(&error);
HttpError::other(HyperError::Hyper(error), recovery, LABEL_REQUEST_HYPER)
}
/// Hyper-side failure wrapped so it can be attached to an [`HttpError`].
#[derive(Debug)]
enum HyperError {
/// Error returned by hyper-util's legacy client while sending a request.
Legacy(legacy::Error),
/// Error produced by hyper while streaming a response body.
Hyper(hyper::Error),
}
impl Error for HyperError {
    /// Exposes the wrapped hyper or hyper-util error as the direct source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner: &(dyn Error + 'static) = match self {
            Self::Hyper(hyper_error) => hyper_error,
            Self::Legacy(legacy_error) => legacy_error,
        };
        Some(inner)
    }
}
impl Display for HyperError {
    /// Writes the wrapped error's message, then one `caused by:` line for each
    /// error further down that error's source chain.
    ///
    /// The walk starts at the wrapped error's own `source()`, not at
    /// `self.source()`. `HyperError::source` returns the wrapped error itself,
    /// and its message is already the first line, so starting there would
    /// print it twice.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &(dyn Error + 'static) = match self {
            Self::Legacy(error) => error,
            Self::Hyper(error) => error,
        };
        write!(f, "{inner}")?;
        let mut current = inner.source();
        while let Some(source) = current {
            write!(f, "\ncaused by: {source}")?;
            current = source.source();
        }
        Ok(())
    }
}
fn handle_poisoning(capture: &CaptureConnection, extensions: &Extensions) {
if let Some(info) = extensions.get::<ConnectionInfo>()
&& info.is_expired()
&& let Some(connected) = capture.connection_metadata().as_ref()
{
connected.poison();
info.mark_poisoned();
}
}
// Unit tests for the hyper-backed handler: Debug rendering, error wrapping,
// the guard conditions of `handle_poisoning`, and end-to-end requests served
// by `FakeConnector`.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use std::time::Duration;
use anyspawn::Spawner;
use bytes::Bytes;
use http_body_util::BodyExt as _;
use http_extensions::{HttpBodyBuilder, HttpRequestBuilder};
use layered::Service as _;
use tick::Clock;
use super::*;
use crate::HyperTransport;
use crate::options::{ConnectionLifetime, RequestFilter};
use crate::testing::{FakeConnector, create_hyper_error, fake_body_builder};
use crate::tls::TlsBackend;
// TLS backend for tests; panics if a native-tls connector cannot be created.
fn tls() -> TlsBackend {
native_tls::TlsConnector::new().unwrap().into()
}
// Canned HTTP/1.1 `200 OK` response whose 5-byte body is `hello`.
fn http_response_bytes() -> Bytes {
Bytes::from_static(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello")
}
// Builds a full `HyperTransport` around `connector` with the given connection
// lifetime. It creates its own auto-advancing clock, independent of any clock
// the connector was created with. It allows both plain HTTP and HTTPS,
// presumably because `test_request` targets an `http://` URI.
fn make_handler(connector: FakeConnector, lifetime: ConnectionLifetime) -> HyperTransport {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
HyperTransportBuilder::new(connector, Spawner::new_tokio(), clock, tls(), HttpBodyBuilder::new_fake())
.request_filter(RequestFilter::HttpAndHttps)
.connection_lifetime(lifetime)
.build()
}
// Request for `http://example.com/path`, built with the fake body builder.
fn test_request() -> HttpRequest {
HttpRequestBuilder::new(&fake_body_builder())
.uri("http://example.com/path")
.build()
.unwrap()
}
// The `Debug` output of a handler built directly by `build_hyper_handler`
// must mention the `HyperHandler` type name.
#[test]
#[cfg_attr(miri, ignore)]
fn debug_renders_handler_type() {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let connector = FakeConnector::new_success(http_response_bytes(), clock.clone());
let handler: HyperHandler<FakeConnector, crate::testing::FakeStream> = build_hyper_handler(
HyperTransportBuilder::new(connector, Spawner::new_tokio(), clock, tls(), HttpBodyBuilder::new_fake())
.request_filter(RequestFilter::HttpAndHttps),
&opentelemetry::global::meter("test"),
);
let rendered = format!("{handler:?}");
assert!(rendered.contains("HyperHandler"), "got: {rendered}");
}
// If the connector replies with bytes that are not HTTP, the request must fail
// with an error whose message is non-empty.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn malformed_response_yields_hyper_util_error() {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let connector = FakeConnector::new_success(Bytes::from_static(b"NOT A VALID HTTP RESPONSE"), clock.clone());
let handler = make_handler(connector, ConnectionLifetime::Unlimited);
let err = handler.execute(test_request()).await.expect_err("expected error");
assert!(!err.to_string().is_empty());
}
// Smoke test: runs a request through a transport restricted to HTTP/2 only.
// The result is discarded, so this test fails only on a panic.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn http2_only_configures_hyper_correctly() {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let connector = FakeConnector::new_success(http_response_bytes(), clock.clone());
let handler = HyperTransportBuilder::new(connector, Spawner::new_tokio(), clock, tls(), HttpBodyBuilder::new_fake())
.request_filter(RequestFilter::HttpAndHttps)
.supported_http_versions(&[Version::HTTP_2])
.build();
let _ = handler.execute(test_request()).await;
}
// `handle_poisoning` must return quietly when the extensions hold no
// `ConnectionInfo`.
#[test]
fn poison_path_no_op_when_no_connection_info() {
let extensions = Extensions::new();
let mut req = test_request();
let capture = capture_connection::<HttpBody>(&mut req);
handle_poisoning(&capture, &extensions);
}
// With a frozen clock and a 60-second lifetime the connection is not expired,
// so it must not be marked as poisoned.
#[test]
fn poison_path_no_op_when_connection_not_expired() {
let mut extensions = Extensions::new();
let info = ConnectionInfo::new(&Clock::new_frozen(), 0, Some(Duration::from_secs(60)));
extensions.insert(info.clone());
let mut req = test_request();
let capture = capture_connection::<HttpBody>(&mut req);
handle_poisoning(&capture, &extensions);
assert!(!info.poisoned(), "should not be poisoned when not expired");
}
// The connection is expired (the clock is advanced 5s past a 1s lifetime), but
// the request was never sent, so the capture has no connection metadata and
// nothing is marked as poisoned.
#[test]
fn poison_path_no_op_when_no_capture_metadata() {
let mut extensions = Extensions::new();
let control = tick::ClockControl::new();
let clock = control.to_clock();
let info = ConnectionInfo::new(&clock, 0, Some(Duration::from_secs(1)));
control.advance(Duration::from_secs(5));
assert!(info.is_expired());
extensions.insert(info.clone());
let mut req = test_request();
let capture = capture_connection::<HttpBody>(&mut req);
handle_poisoning(&capture, &extensions);
assert!(!info.poisoned());
}
// Successful round trip: status 200 and the `hello` body arrive intact.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn end_to_end_response_is_returned_with_body() {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let connector = FakeConnector::new_success(http_response_bytes(), clock.clone());
let handler = make_handler(connector, ConnectionLifetime::Unlimited);
let resp = handler.execute(test_request()).await.unwrap();
assert_eq!(resp.status(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&*body, b"hello");
}
// Wrapped hyper errors have a non-empty message and carry the
// `request_hyper` label.
#[test]
fn create_http_error_from_hyper_wraps_with_label() {
use ohno::Labeled;
let err = create_http_error_from_hyper(create_hyper_error());
assert!(!err.to_string().is_empty());
assert_eq!(err.label().as_str(), "request_hyper");
}
// `HyperError` exposes a source. The `caused by` check is applied only when
// that source has a source of its own.
#[test]
fn hyper_error_display_includes_source_chain() {
let err = create_hyper_error();
let wrapped = HyperError::Hyper(err);
let rendered = format!("{wrapped}");
let src = std::error::Error::source(&wrapped);
assert!(src.is_some());
if src.and_then(std::error::Error::source).is_some() {
assert!(rendered.contains("caused by"), "expected chain in: {rendered}");
}
}
}