fetch_hyper 0.1.1

Hyper-based HTTP transport utilities for fetch.
Documentation
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Internal generic [`HyperHandler`] driving hyper-util's `legacy::Client`.
//!
//! Implements [`Service<HttpRequest>`]. Type-erased into
//! [`HyperTransport`](crate::HyperTransport) by
//! [`HyperTransportBuilder::build`](crate::HyperTransportBuilder::build).

use std::error::Error;
use std::fmt::{self, Display};

use bytesbuf::BytesView;
use futures::TryFutureExt;
use http::{Extensions, Version};
use http_body_util::BodyExt;
use http_extensions::timeout::BodyTimeout;
use http_extensions::{HttpBody, HttpBodyOptions, HttpError, HttpRequest, HttpResponse, Result};
use hyper_util::client::legacy::connect::{CaptureConnection, capture_connection};
use hyper_util::client::legacy::{self, Client};
use layered::Service;
use opentelemetry::metrics::Meter;

use crate::builder::HyperTransportBuilder;
use crate::connection::client_connector::ClientConnector;
use crate::connection::connect::Connect;
use crate::connection::hyper_connector_adapter::HyperConnectorAdapter;
use crate::connection::io::HyperIo;
use crate::connection::tracked_stream::TrackedStream;
use crate::error_labels::LABEL_REQUEST_HYPER;
use crate::recoverability::detect_recoverability;
use crate::telemetry::ConnectionInfo;
use crate::tls::TlsConnector;

/// Connector stack handed to `hyper`'s [`Client`].
///
/// Built innermost-first in `build_hyper_handler`: the user connector `C` is
/// wrapped by [`TlsConnector`], then by [`ClientConnector`], and the result is
/// wrapped by [`HyperConnectorAdapter`] before being passed to
/// `hyper_builder.build`. Connections are carried as boxed `dyn HyperIo`
/// streams, surfaced to hyper inside a [`TrackedStream`].
type WrappedConnector<C, S> = HyperConnectorAdapter<
    ClientConnector<TlsConnector<C, S>, Box<dyn HyperIo>>,
    TrackedStream<Box<dyn HyperIo>>,
>;

/// A Hyper-backed request handler, parameterized by the user-supplied
/// connector and stream types. Public consumers see only the
/// type-erased [`HyperTransport`](crate::HyperTransport).
///
/// Constructed by [`build_hyper_handler`]; requests are executed through its
/// [`Service<HttpRequest>`] implementation.
pub(crate) struct HyperHandler<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    /// hyper-util legacy client running over the full connector chain
    /// (user connector → TLS → `ClientConnector` → adapter); request bodies
    /// are [`HttpBody`].
    client: Client<WrappedConnector<C, S>, HttpBody>,
    /// Wraps each hyper response body into an [`HttpBody`], applying the
    /// per-request body options (e.g. a [`BodyTimeout`] taken from the
    /// request extensions).
    body_builder: http_extensions::HttpBodyBuilder,
}

impl<C, S> fmt::Debug for HyperHandler<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    /// Prints only the concrete handler type name; no fields are listed, so
    /// the output has the shape `Name { .. }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_label = std::any::type_name::<Self>();
        let mut rendered = f.debug_struct(type_label);
        rendered.finish_non_exhaustive()
    }
}

impl<C, S> Service<HttpRequest> for HyperHandler<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    type Out = Result<HttpResponse>;

    /// Sends `input` through the hyper client and adapts the response.
    ///
    /// - Client-level failures are converted with
    ///   `create_http_error_from_hyper_util`.
    /// - On success, the response body is re-wrapped as an [`HttpBody`]. A
    ///   [`BodyTimeout`] found in the request extensions is honoured, and body
    ///   errors are converted with `create_http_error_from_hyper`.
    /// - The connection that served the request goes through
    ///   `handle_poisoning`.
    fn execute(&self, mut input: HttpRequest) -> impl Future<Output = Result<HttpResponse>> + Send {
        // Set up the connection capture while the request is still mutably
        // owned here; it is read back in `handle_poisoning` once the response
        // arrives.
        let connection = capture_connection::<HttpBody>(&mut input);

        let body_options = match input.extensions().get::<BodyTimeout>() {
            Some(timeout) => HttpBodyOptions::default().timeout(timeout.duration()),
            None => HttpBodyOptions::default(),
        };

        let body_builder = &self.body_builder;

        self.client
            .request(input)
            .map_err(create_http_error_from_hyper_util)
            .map_ok(move |response| {
                let (parts, incoming) = response.into_parts();

                let adapted = incoming
                    .map_frame(|frame| frame.map_data(BytesView::from))
                    .map_err(create_http_error_from_hyper);

                handle_poisoning(&connection, &parts.extensions);

                HttpResponse::from_parts(parts, body_builder.body(adapted, &body_options))
            })
    }
}

/// Assembles a [`HyperHandler`] from a configured [`HyperTransportBuilder`].
///
/// The builder is consumed and its parts are layered into the connector
/// chain: the user connector goes inside a [`TlsConnector`], that goes inside
/// a [`ClientConnector`] (which also receives `meter` and the pool and
/// lifetime settings), and the result goes inside a [`HyperConnectorAdapter`].
/// The adapter is handed to the builder's hyper client builder.
///
/// If HTTP/2 is the only supported version, hyper is switched into
/// HTTP/2-only mode.
pub(crate) fn build_hyper_handler<C, S>(builder: HyperTransportBuilder<C, S>, meter: &Meter) -> HyperHandler<C, S>
where
    C: Connect<S>,
    S: HyperIo,
{
    let HyperTransportBuilder {
        connector,
        clock,
        tls,
        body_builder,
        request_filter,
        supported_http_versions: versions,
        connection_lifetime,
        connect_timeout,
        pool_index,
        mut hyper_builder,
        ..
    } = builder;

    let http2_only = versions.len() == 1 && versions[0] == Version::HTTP_2;
    if http2_only {
        hyper_builder.http2_only(true);
    }

    let tls_layer = TlsConnector::new(tls, connector, request_filter, &versions);
    let client_connector = ClientConnector::new(
        tls_layer,
        clock,
        connect_timeout,
        versions,
        meter,
        pool_index,
        connection_lifetime,
    );
    let adapter = HyperConnectorAdapter::new(client_connector);
    let client = hyper_builder.build(adapter);

    HyperHandler { client, body_builder }
}

fn create_http_error_from_hyper_util(error: legacy::Error) -> HttpError {
    let recovery = detect_recoverability(&error);
    HttpError::other(HyperError::Legacy(error), recovery, LABEL_REQUEST_HYPER)
}

fn create_http_error_from_hyper(error: hyper::Error) -> HttpError {
    let recovery = detect_recoverability(&error);
    HttpError::other(HyperError::Hyper(error), recovery, LABEL_REQUEST_HYPER)
}

/// Internal wrapper unifying the two hyper error types this handler can
/// surface, so either can be passed to [`HttpError::other`].
#[derive(Debug)]
enum HyperError {
    /// Failure of the request future returned by hyper-util's legacy
    /// [`Client`] (produced by `create_http_error_from_hyper_util`).
    Legacy(legacy::Error),
    /// Failure while reading the response body from hyper (produced by
    /// `create_http_error_from_hyper`).
    Hyper(hyper::Error),
}

impl Error for HyperError {
    /// Exposes the wrapped hyper / hyper-util error as the immediate source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner: &(dyn Error + 'static) = match self {
            Self::Legacy(legacy_error) => legacy_error,
            Self::Hyper(hyper_error) => hyper_error,
        };
        Some(inner)
    }
}

impl Display for HyperError {
    /// Renders the wrapped error on the first line, followed by one
    /// `caused by:` line for each level of its source chain.
    ///
    /// The walk starts at the wrapped error's own `source()`, not at
    /// `self.source()`. `HyperError::source` returns the wrapped error itself,
    /// so starting there would print its message a second time as the first
    /// `caused by:` line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &(dyn Error + 'static) = match self {
            Self::Legacy(error) => error,
            Self::Hyper(error) => error,
        };

        write!(f, "{inner}")?;

        let mut current = inner.source();
        while let Some(cause) = current {
            write!(f, "\ncaused by: {cause}")?;
            current = cause.source();
        }

        Ok(())
    }
}

fn handle_poisoning(capture: &CaptureConnection, extensions: &Extensions) {
    if let Some(info) = extensions.get::<ConnectionInfo>()
        && info.is_expired()
        && let Some(connected) = capture.connection_metadata().as_ref()
    {
        connected.poison();
        info.mark_poisoned();
    }
}

// Unit tests for `HyperHandler`. They cover:
// - end-to-end execution through a `FakeConnector`;
// - mapping of hyper / hyper-util errors into `HttpError`;
// - the `Display` cause chain of `HyperError`;
// - each early-exit path of `handle_poisoning`.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::time::Duration;

    use anyspawn::Spawner;
    use bytes::Bytes;
    use http_body_util::BodyExt as _;
    use http_extensions::{HttpBodyBuilder, HttpRequestBuilder};
    use layered::Service as _;
    use tick::Clock;

    use super::*;
    use crate::HyperTransport;
    use crate::options::{ConnectionLifetime, RequestFilter};
    use crate::testing::{FakeConnector, create_hyper_error, fake_body_builder};
    use crate::tls::TlsBackend;

    /// Default native-tls backend. The requests built by `test_request` are
    /// plain `http://`.
    fn tls() -> TlsBackend {
        native_tls::TlsConnector::new().unwrap().into()
    }

    /// A minimal, well-formed HTTP/1.1 `200 OK` response whose body is `hello`.
    fn http_response_bytes() -> Bytes {
        Bytes::from_static(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello")
    }

    /// Builds a type-erased transport over `connector`. It uses an
    /// auto-advancing test clock, the `HttpAndHttps` request filter and the
    /// given connection lifetime.
    fn make_handler(connector: FakeConnector, lifetime: ConnectionLifetime) -> HyperTransport {
        let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
        HyperTransportBuilder::new(connector, Spawner::new_tokio(), clock, tls(), HttpBodyBuilder::new_fake())
            .request_filter(RequestFilter::HttpAndHttps)
            .connection_lifetime(lifetime)
            .build()
    }

    /// A request to `http://example.com/path` built with the fake body builder.
    fn test_request() -> HttpRequest {
        HttpRequestBuilder::new(&fake_body_builder())
            .uri("http://example.com/path")
            .build()
            .unwrap()
    }

    // Builds the generic handler directly (not the erased transport) so its
    // hand-written `Debug` impl is exercised.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn debug_renders_handler_type() {
        let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
        let connector = FakeConnector::new_success(http_response_bytes(), clock.clone());
        let handler: HyperHandler<FakeConnector, crate::testing::FakeStream> = build_hyper_handler(
            HyperTransportBuilder::new(connector, Spawner::new_tokio(), clock, tls(), HttpBodyBuilder::new_fake())
                .request_filter(RequestFilter::HttpAndHttps),
            &opentelemetry::global::meter("test"),
        );
        let rendered = format!("{handler:?}");
        assert!(rendered.contains("HyperHandler"), "got: {rendered}");
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn malformed_response_yields_hyper_util_error() {
        // The byte stream is not a valid HTTP/1 response, so hyper's client
        // request future fails with a `legacy::Error`, exercising
        // `create_http_error_from_hyper_util`.
        let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
        let connector = FakeConnector::new_success(Bytes::from_static(b"NOT A VALID HTTP RESPONSE"), clock.clone());
        let handler = make_handler(connector, ConnectionLifetime::Unlimited);
        let err = handler.execute(test_request()).await.expect_err("expected error");
        assert!(!err.to_string().is_empty());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn http2_only_configures_hyper_correctly() {
        // Builder with HTTP/2-only flips `http2_only(true)` on hyper's builder.
        // Using FakeStream over HTTP/1.1-style data will fail, but we want to
        // simply exercise the build path and request execution.
        // NOTE(review): nothing is asserted here; this is a smoke test only.
        let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
        let connector = FakeConnector::new_success(http_response_bytes(), clock.clone());
        let handler = HyperTransportBuilder::new(connector, Spawner::new_tokio(), clock, tls(), HttpBodyBuilder::new_fake())
            .request_filter(RequestFilter::HttpAndHttps)
            .supported_http_versions(&[Version::HTTP_2])
            .build();
        // Execute to drive the http2 path; we don't care if it fails or not.
        let _ = handler.execute(test_request()).await;
    }

    // First early exit of `handle_poisoning`: no `ConnectionInfo` present.
    #[test]
    fn poison_path_no_op_when_no_connection_info() {
        let extensions = Extensions::new();
        let mut req = test_request();
        let capture = capture_connection::<HttpBody>(&mut req);
        // No ConnectionInfo on extensions → handle_poisoning is a no-op.
        handle_poisoning(&capture, &extensions);
    }

    // Second early exit: the connection has a lifetime (60s) but is not
    // expired, because the clock is frozen.
    #[test]
    fn poison_path_no_op_when_connection_not_expired() {
        let mut extensions = Extensions::new();
        let info = ConnectionInfo::new(&Clock::new_frozen(), 0, Some(Duration::from_secs(60)));
        extensions.insert(info.clone());

        let mut req = test_request();
        let capture = capture_connection::<HttpBody>(&mut req);
        handle_poisoning(&capture, &extensions);
        assert!(!info.poisoned(), "should not be poisoned when not expired");
    }

    // Third early exit: the connection is expired (the clock is advanced 5s
    // past a 1s lifetime), but no request was sent, so there is no captured
    // metadata to poison.
    #[test]
    fn poison_path_no_op_when_no_capture_metadata() {
        let mut extensions = Extensions::new();
        let control = tick::ClockControl::new();
        let clock = control.to_clock();
        let info = ConnectionInfo::new(&clock, 0, Some(Duration::from_secs(1)));
        control.advance(Duration::from_secs(5));
        assert!(info.is_expired());
        extensions.insert(info.clone());

        let mut req = test_request();
        let capture = capture_connection::<HttpBody>(&mut req);
        // capture.connection_metadata() returns None until hyper populates it.
        handle_poisoning(&capture, &extensions);
        // No metadata available → mark_poisoned must NOT be called.
        assert!(!info.poisoned());
    }

    // Happy path: status and body bytes survive the hyper → `HttpBody`
    // adaptation performed in `execute`.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn end_to_end_response_is_returned_with_body() {
        let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
        let connector = FakeConnector::new_success(http_response_bytes(), clock.clone());
        let handler = make_handler(connector, ConnectionLifetime::Unlimited);
        let resp = handler.execute(test_request()).await.unwrap();
        assert_eq!(resp.status(), 200);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&*body, b"hello");
    }

    // Body-error mapping carries the `LABEL_REQUEST_HYPER` label, which is
    // expected to render as `request_hyper`.
    #[test]
    fn create_http_error_from_hyper_wraps_with_label() {
        use ohno::Labeled;
        let err = create_http_error_from_hyper(create_hyper_error());
        assert!(!err.to_string().is_empty());
        assert_eq!(err.label().as_str(), "request_hyper");
    }

    #[test]
    fn hyper_error_display_includes_source_chain() {
        let err = create_hyper_error();
        let wrapped = HyperError::Hyper(err);
        let rendered = format!("{wrapped}");
        // HyperError::Hyper always exposes its inner error as a source, and
        // create_hyper_error produces a hyper::Error with at least one source
        // level (an io::Error).
        // NOTE(review): the assertion below only checks for the substring
        // `caused by`. It does not verify that the wrapped error's own message
        // is absent from the cause lines, so a duplicated first line would
        // go unnoticed.
        let src = std::error::Error::source(&wrapped);
        assert!(src.is_some());
        if src.and_then(std::error::Error::source).is_some() {
            assert!(rendered.contains("caused by"), "expected chain in: {rendered}");
        }
    }
}