use std::fmt;
use std::io::Error as IoError;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use bytes::Bytes;
use http_extensions::{HttpBody, HttpBodyBuilder, HttpError, HttpRequest, HttpRequestBuilder};
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};
use seatbelt::RecoveryInfo;
use templated_uri::BaseUri;
use tick::Clock;
use crate::error_labels::LABEL_CONNECT;
/// Stream placeholder that must never actually be used.
///
/// Every [`Read`], [`Write`] and [`Connection`] method — including `is_write_vectored` and
/// `connected` — panics with a message naming the method that was called. Any unexpected use
/// therefore fails the test loudly. The struct is `#[non_exhaustive]`, so code outside this
/// crate obtains one through [`Default`].
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct PanickingStream;

impl Read for PanickingStream {
    fn poll_read(self: Pin<&mut Self>, _context: &mut Context<'_>, _cursor: ReadBufCursor<'_>) -> Poll<std::io::Result<()>> {
        panic!("poll_read")
    }
}

impl Write for PanickingStream {
    fn poll_write(self: Pin<&mut Self>, _context: &mut Context<'_>, _bytes: &[u8]) -> Poll<std::io::Result<usize>> {
        panic!("poll_write")
    }

    fn poll_flush(self: Pin<&mut Self>, _context: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        panic!("poll_flush")
    }

    fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        panic!("poll_shutdown")
    }

    fn is_write_vectored(&self) -> bool {
        panic!("is_write_vectored")
    }

    fn poll_write_vectored(self: Pin<&mut Self>, _context: &mut Context<'_>, _slices: &[std::io::IoSlice<'_>]) -> Poll<std::io::Result<usize>> {
        panic!("poll_write_vectored")
    }
}

impl Connection for PanickingStream {
    fn connected(&self) -> Connected {
        panic!("connected")
    }
}
/// In-memory stream that replays a canned response once a request has been written to it.
///
/// Reads stay pending until the write side has seen at least one write; the coordination
/// lives in [`FakeStreamState`].
pub struct FakeStream {
    /// Canned result handed to the reader; `None` once it has been consumed.
    result: Option<std::result::Result<Bytes, TestError>>,
    /// Request/response handshake state, updated by `poll_write` and consulted by `poll_read`.
    state: Arc<Mutex<FakeStreamState>>,
}

/// Coordination state between the write and read sides of a [`FakeStream`].
#[derive(Debug)]
struct FakeStreamState {
    /// Set by the first `poll_write`; until then `poll_read` returns `Pending`.
    request_received: bool,
    /// Waker of a reader parked while waiting for the request to arrive.
    read_waker: Option<Waker>,
}

impl fmt::Debug for FakeStream {
    /// Prints only the type name; the canned payload and shared state are elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = std::any::type_name::<Self>();
        f.debug_struct(type_name).finish_non_exhaustive()
    }
}
impl Read for FakeStream {
    /// Replays the canned response, but only after a request has been written.
    ///
    /// Until `poll_write` has marked the request as received, this parks the current task's
    /// waker in the shared state and returns `Poll::Pending`. Once the request has arrived:
    /// - canned `Ok` bytes are copied into `buf`. If they do not fit in the cursor's remaining
    ///   capacity, the unread tail is kept and served by subsequent reads.
    /// - a canned `Err` is returned once, converted with `TestError::into_io_error`.
    /// - after the result has been fully handed out, `Ready(Ok(()))` is returned without
    ///   filling `buf`, signalling end-of-stream.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, mut buf: ReadBufCursor<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        {
            let mut state = this.state.lock().unwrap();
            if !state.request_received {
                state.read_waker = Some(cx.waker().clone());
                return Poll::Pending;
            }
            state.read_waker = None;
        }

        match this.result.take() {
            None => Poll::Ready(Ok(())),
            Some(Err(error)) => Poll::Ready(Err(error.into_io_error())),
            Some(Ok(mut bytes)) => {
                // `ReadBufCursor::put_slice` panics when handed more than `remaining()` bytes,
                // so copy only what fits and keep the rest for the next read.
                let chunk = bytes.split_to(bytes.len().min(buf.remaining()));
                buf.put_slice(&chunk);
                if !bytes.is_empty() {
                    this.result = Some(Ok(bytes));
                }
                Poll::Ready(Ok(()))
            }
        }
    }
}
impl Write for FakeStream {
    /// Accepts and discards all request bytes, marking the request as received.
    ///
    /// The first write wakes a reader parked in `poll_read`. The waker is taken out under the
    /// lock but woken only after the guard is released. That way the woken task never has to
    /// contend for a mutex this call is still holding.
    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
        let parked_reader = {
            let mut state = self.state.lock().unwrap();
            let first_write = !std::mem::replace(&mut state.request_received, true);
            if first_write { state.read_waker.take() } else { None }
        };
        if let Some(waker) = parked_reader {
            waker.wake();
        }
        Poll::Ready(Ok(buf.len()))
    }

    /// Nothing is buffered, so flushing always succeeds immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shutdown is a no-op that always succeeds immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

impl Connection for FakeStream {
    /// Reports a plain connection with default `Connected` metadata.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
/// Fake connector that either hands out a [`FakeStream`] preloaded with a canned result or
/// fails the connect step outright.
///
/// Each connect attempt first waits `delay` on the configured [`Clock`].
#[derive(Debug, Clone)]
pub struct FakeConnector {
    /// Canned result given to every stream this connector produces.
    response: Option<std::result::Result<Bytes, TestError>>,
    /// When set, every connect attempt fails with this error.
    connect_error: Option<TestError>,
    /// Clock on which `delay` is awaited.
    clock: Clock,
    /// How long each connect attempt waits before resolving.
    pub delay: Duration,
}

impl FakeConnector {
    /// Shared constructor: every public constructor starts with a zero delay.
    fn from_parts(
        response: Option<std::result::Result<Bytes, TestError>>,
        connect_error: Option<TestError>,
        clock: Clock,
    ) -> Self {
        Self {
            response,
            connect_error,
            clock,
            delay: Duration::ZERO,
        }
    }

    /// Creates a connector whose streams answer with the raw bytes `data`.
    #[must_use]
    pub fn new_success(data: impl Into<Bytes>, clock: Clock) -> Self {
        Self::from_parts(Some(Ok(data.into())), None, clock)
    }

    /// Creates a connector that connects successfully, but whose streams fail the read with
    /// `error` (surfaced as an I/O error).
    #[must_use]
    pub fn new_failure(error: TestError, clock: Clock) -> Self {
        Self::from_parts(Some(Err(error)), None, clock)
    }

    /// Creates a connector whose every connect attempt fails with `error`.
    #[must_use]
    pub fn new_connect_failure(error: TestError, clock: Clock) -> Self {
        Self::from_parts(None, Some(error), clock)
    }

    /// Returns this connector with each connect attempt delayed by `delay`.
    #[must_use]
    pub fn with_delay(self, delay: Duration) -> Self {
        Self { delay, ..self }
    }
}
impl layered::Service<BaseUri> for FakeConnector {
type Out = http_extensions::Result<FakeStream>;
fn execute(&self, _input: BaseUri) -> impl Future<Output = Self::Out> + Send {
let response = self.response.clone();
let connect_error = self.connect_error.clone();
let clock = self.clock.clone();
let delay = self.delay;
async move {
clock.delay(delay).await;
if let Some(error) = connect_error {
return Err(HttpError::other(error, RecoveryInfo::retry(), LABEL_CONNECT));
}
Ok(FakeStream {
result: response,
state: Arc::new(Mutex::new(FakeStreamState {
request_received: false,
read_waker: None,
})),
})
}
}
}
#[derive(Debug, Clone)]
pub struct TestError {
message: String,
inner: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
impl TestError {
#[must_use]
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
inner: None,
}
}
#[must_use]
pub fn with_inner<E: std::error::Error + Send + Sync + 'static>(mut self, inner: E) -> Self {
self.inner = Some(Arc::new(inner));
self
}
#[must_use]
pub fn with_inner_recoverability(self, recoverability: RecoveryInfo) -> Self {
self.with_inner(HttpError::other("inner error", recoverability, "other"))
}
#[must_use]
pub fn into_io_error(self) -> IoError {
IoError::other(self)
}
}
impl fmt::Display for TestError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.message)
}
}
impl std::error::Error for TestError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.inner.as_ref().map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
}
}
/// Builds a minimal `HttpRequest` for `http://example.com/some-custom-path`, using a fake body
/// builder.
///
/// # Panics
///
/// Panics if the request fails to build, which would indicate broken test setup.
#[must_use]
pub fn create_test_request() -> HttpRequest {
    let body_builder = HttpBodyBuilder::new_fake();
    let built = HttpRequestBuilder::new(&body_builder)
        .uri("http://example.com/some-custom-path")
        .build();
    built.expect("test request should build")
}

/// Returns a fake `HttpBodyBuilder` for wiring into transports under test.
#[must_use]
pub fn fake_body_builder() -> HttpBodyBuilder {
    let builder = HttpBodyBuilder::new_fake();
    builder
}
/// Renders OpenTelemetry attributes as `(key, value)` string pairs, sorted lexicographically
/// (by key, then by value), so assertions do not depend on attribute order.
#[must_use]
pub fn sorted_attributes(attrs: &[opentelemetry::KeyValue]) -> Vec<(String, String)> {
    let mut rendered = Vec::with_capacity(attrs.len());
    for attr in attrs {
        rendered.push((attr.key.to_string(), attr.value.to_string()));
    }
    // Equal `(String, String)` pairs are indistinguishable, so an unstable sort gives the
    // same result as a stable one.
    rendered.sort_unstable();
    rendered
}
/// Produces a genuine `hyper::Error` by driving an HTTP/1 client connection over
/// `FailingStream`, on which every I/O operation fails.
///
/// # Panics
///
/// Panics if the handshake fails, or if the connection completes without an error.
#[must_use]
pub fn create_hyper_error() -> hyper::Error {
    use futures::executor::block_on;

    let builder = hyper::client::conn::http1::Builder::new();
    let handshake = builder.handshake::<_, HttpBody>(FailingStream);
    // The `SendRequest` half is not bound; it is dropped at the end of this statement and
    // only the connection driver is kept.
    let (_, connection) = block_on(handshake).expect("handshake should succeed against in-memory stream");
    block_on(connection).expect_err("connection driven against FailingStream must fail")
}
/// In-memory stream on which every read, write, flush and shutdown fails immediately. Each
/// failure is an `io::Error` built with `io::Error::other` whose message names the operation.
#[derive(Debug)]
struct FailingStream;

impl FailingStream {
    /// Builds the ready failure returned by every I/O method.
    fn fail<T>(message: &'static str) -> Poll<std::io::Result<T>> {
        Poll::Ready(Err(IoError::other(message)))
    }
}

impl Read for FailingStream {
    fn poll_read(self: Pin<&mut Self>, _context: &mut Context<'_>, _cursor: ReadBufCursor<'_>) -> Poll<std::io::Result<()>> {
        Self::fail("FailingStream read error")
    }
}

impl Write for FailingStream {
    fn poll_write(self: Pin<&mut Self>, _context: &mut Context<'_>, _bytes: &[u8]) -> Poll<std::io::Result<usize>> {
        Self::fail("FailingStream write error")
    }

    fn poll_flush(self: Pin<&mut Self>, _context: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Self::fail("FailingStream flush error")
    }

    fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Self::fail("FailingStream shutdown error")
    }
}
// End-to-end tests that drive the fakes in this module through `HyperTransportBuilder`.
// The module is excluded from coverage instrumentation on nightly via `coverage(off)`.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use std::time::Duration;
use anyspawn::Spawner;
use bytes::Bytes;
use fetch_options::RequestFilter;
use fetch_tls::TlsBackend;
use http_body_util::BodyExt;
use layered::Service as _;
use native_tls::TlsConnector;
use seatbelt::RecoveryInfo;
use crate::HyperTransportBuilder;
use crate::testing::{FakeConnector, TestError, create_test_request, fake_body_builder};
/// Creates the native-TLS backend passed to `HyperTransportBuilder::build`.
///
/// Panics if the native TLS connector cannot be created.
fn build_tls() -> TlsBackend {
TlsConnector::new().unwrap().into()
}
/// Raw HTTP/1.1 `200 OK` response with a 13-byte `Hello, World!` body. `FakeConnector`
/// replays these bytes verbatim.
fn http_1_response() -> Bytes {
Bytes::from_static(b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, World!")
}
// NOTE(review): skipped under Miri, presumably because of the tokio runtime / native TLS —
// TODO confirm.
#[cfg_attr(miri, ignore)]
#[tokio::test]
/// A successful fake connection yields the canned status and body.
async fn fake_connector_serves_canned_response() {
// `auto_advance_timers(true)` presumably lets clock timers fire without real waiting —
// TODO confirm against the `tick` docs.
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let mut options = fetch_options::TransportOptions::default();
// The test request uses `http://`, so widen the filter to allow plain HTTP as well.
options.request_filter = RequestFilter::HttpAndHttps;
let handler = HyperTransportBuilder::new(
FakeConnector::new_success(http_1_response(), clock.clone()),
Spawner::new_tokio(),
clock,
options,
)
.body_builder(fake_body_builder())
.build(build_tls());
let response = handler.execute(create_test_request()).await.unwrap();
assert_eq!(response.status(), 200);
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&*body, b"Hello, World!");
}
#[cfg_attr(miri, ignore)]
#[tokio::test]
/// A connect-step failure surfaces as a request error whose message mentions the original
/// `TestError` text.
async fn fake_connector_propagates_connect_failure() {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let mut options = fetch_options::TransportOptions::default();
options.request_filter = RequestFilter::HttpAndHttps;
options.connect_timeout = Duration::from_secs(5);
let handler = HyperTransportBuilder::new(
FakeConnector::new_connect_failure(
TestError::new("forced connect error").with_inner_recoverability(RecoveryInfo::retry()),
clock.clone(),
),
Spawner::new_tokio(),
clock,
options,
)
.body_builder(fake_body_builder())
.build(build_tls());
let error = handler
.execute(create_test_request())
.await
.expect_err("connect failure should propagate");
let rendered = error.to_string();
assert!(
rendered.contains("forced connect error"),
"expected error to mention forced connect error, got: {rendered}"
);
}
#[cfg_attr(miri, ignore)]
#[tokio::test]
/// With default `TransportOptions`, a plain `http://` request is rejected. The assertion only
/// checks that the error text mentions "scheme" or "http".
async fn https_only_filter_rejects_http_request() {
let clock = tick::ClockControl::new().auto_advance_timers(true).to_clock();
let handler = HyperTransportBuilder::new(
FakeConnector::new_success(http_1_response(), clock.clone()),
Spawner::new_tokio(),
clock,
fetch_options::TransportOptions::default(),
)
.body_builder(fake_body_builder())
.build(build_tls());
let error = handler
.execute(create_test_request())
.await
.expect_err("http request should be rejected when only https is allowed");
assert!(
error.to_string().to_lowercase().contains("scheme") || error.to_string().to_lowercase().contains("http"),
"expected scheme/http error, got: {error}"
);
}
}